
Commit c914e05

mtmd: adapt Pillow image resizing function
1 parent 95239f9 commit c914e05

1 file changed (+212, -3 lines)

tools/mtmd/clip.cpp

Lines changed: 212 additions & 3 deletions
@@ -4218,6 +4218,7 @@ struct img_tool {
     enum resize_algo {
         RESIZE_ALGO_BILINEAR,
         RESIZE_ALGO_BICUBIC,
+        RESIZE_ALGO_BICUBIC_PILLOW,
         // RESIZE_ALGO_LANCZOS, // TODO
     };

@@ -4247,6 +4248,9 @@ struct img_tool {
            case RESIZE_ALGO_BICUBIC:
                resize_bicubic(src, dst, target_resolution.width, target_resolution.height);
                break;
+           case RESIZE_ALGO_BICUBIC_PILLOW:
+               resize_bicubic_pillow(src, dst, target_resolution.width, target_resolution.height);
+               break;
            default:
                throw std::runtime_error("Unsupported resize algorithm");
        }
@@ -4266,6 +4270,9 @@ struct img_tool {
                case RESIZE_ALGO_BICUBIC:
                    resize_bicubic(src, resized_image, new_width, new_height);
                    break;
+               case RESIZE_ALGO_BICUBIC_PILLOW:
+                   resize_bicubic_pillow(src, resized_image, new_width, new_height);
+                   break;
                default:
                    throw std::runtime_error("Unsupported resize algorithm");
            }
@@ -4475,6 +4482,209 @@ struct img_tool {
        return true;
    }

+   // Bicubic resize function using Pillow's ImagingResample algorithm
+   // Adapted from https://github.com/python-pillow/Pillow/blob/main/src/libImaging/Resample.c
+   static bool resize_bicubic_pillow(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+       const int PRECISION_BITS = 32 - 8 - 2;
+
+       // Bicubic filter function
+       auto bicubic_filter = [](double x) -> double {
+           constexpr double a = -0.5;
+           if (x < 0.0) {
+               x = -x;
+           }
+           if (x < 1.0) {
+               return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
+           }
+           if (x < 2.0) {
+               return (((x - 5) * x + 8) * x - 4) * a;
+           }
+           return 0.0;
+       };
+
+       constexpr double filter_support = 2.0;
+
+       // Clipping function for 8-bit values
+       auto clip8 = [](int val) -> uint8_t {
+           if (val < 0) return 0;
+           if (val > 255) return 255;
+           return static_cast<uint8_t>(val);
+       };
+
+       // Precompute coefficients
+       auto precompute_coeffs = [&](int inSize, double in0, double in1, int outSize,
+                                    std::vector<int> & bounds, std::vector<int32_t> & kk) -> int {
+           double support, scale, filterscale;
+           double center, ww, ss;
+           int xx, x, ksize, xmin, xmax;
+
+           filterscale = scale = (in1 - in0) / outSize;
+           if (filterscale < 1.0) {
+               filterscale = 1.0;
+           }
+
+           support = filter_support * filterscale;
+           ksize = static_cast<int>(std::ceil(support)) * 2 + 1;
+
+           std::vector<double> prekk(outSize * ksize);
+           bounds.resize(outSize * 2);
+
+           for (xx = 0; xx < outSize; xx++) {
+               center = in0 + (xx + 0.5) * scale;
+               ww = 0.0;
+               ss = 1.0 / filterscale;
+
+               xmin = static_cast<int>(center - support + 0.5);
+               if (xmin < 0) {
+                   xmin = 0;
+               }
+
+               xmax = static_cast<int>(center + support + 0.5);
+               if (xmax > inSize) {
+                   xmax = inSize;
+               }
+               xmax -= xmin;
+
+               double * k = &prekk[xx * ksize];
+               for (x = 0; x < xmax; x++) {
+                   double w = bicubic_filter((x + xmin - center + 0.5) * ss);
+                   k[x] = w;
+                   ww += w;
+               }
+
+               for (x = 0; x < xmax; x++) {
+                   if (ww != 0.0) {
+                       k[x] /= ww;
+                   }
+               }
+
+               for (; x < ksize; x++) {
+                   k[x] = 0;
+               }
+
+               bounds[xx * 2 + 0] = xmin;
+               bounds[xx * 2 + 1] = xmax;
+           }
+
+           // Normalize coefficients to fixed-point
+           kk.resize(outSize * ksize);
+           for (int i = 0; i < outSize * ksize; i++) {
+               if (prekk[i] < 0) {
+                   kk[i] = static_cast<int32_t>(-0.5 + prekk[i] * (1 << PRECISION_BITS));
+               } else {
+                   kk[i] = static_cast<int32_t>(0.5 + prekk[i] * (1 << PRECISION_BITS));
+               }
+           }
+
+           return ksize;
+       };
+
+       // Horizontal resampling
+       auto resample_horizontal = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut,
+                                      int ksize, const std::vector<int> & bounds, const std::vector<int32_t> & kk) {
+           imOut.ny = imIn.ny;
+           imOut.buf.resize(3 * imOut.nx * imOut.ny);
+
+           for (int yy = 0; yy < imOut.ny; yy++) {
+               for (int xx = 0; xx < imOut.nx; xx++) {
+                   int xmin = bounds[xx * 2 + 0];
+                   int xmax = bounds[xx * 2 + 1];
+                   const int32_t * k = &kk[xx * ksize];
+
+                   int32_t ss0 = 1 << (PRECISION_BITS - 1);
+                   int32_t ss1 = 1 << (PRECISION_BITS - 1);
+                   int32_t ss2 = 1 << (PRECISION_BITS - 1);
+
+                   for (int x = 0; x < xmax; x++) {
+                       int src_idx = ((yy * imIn.nx) + (x + xmin)) * 3;
+                       ss0 += static_cast<uint8_t>(imIn.buf[src_idx + 0]) * k[x];
+                       ss1 += static_cast<uint8_t>(imIn.buf[src_idx + 1]) * k[x];
+                       ss2 += static_cast<uint8_t>(imIn.buf[src_idx + 2]) * k[x];
+                   }
+
+                   int dst_idx = (yy * imOut.nx + xx) * 3;
+                   imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS);
+                   imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS);
+                   imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS);
+               }
+           }
+       };
+
+       // Vertical resampling
+       auto resample_vertical = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut,
+                                    int ksize, const std::vector<int> & bounds, const std::vector<int32_t> & kk) {
+           imOut.nx = imIn.nx;
+           imOut.buf.resize(3 * imOut.nx * imOut.ny);
+
+           for (int yy = 0; yy < imOut.ny; yy++) {
+               int ymin = bounds[yy * 2 + 0];
+               int ymax = bounds[yy * 2 + 1];
+               const int32_t * k = &kk[yy * ksize];
+
+               for (int xx = 0; xx < imOut.nx; xx++) {
+                   int32_t ss0 = 1 << (PRECISION_BITS - 1);
+                   int32_t ss1 = 1 << (PRECISION_BITS - 1);
+                   int32_t ss2 = 1 << (PRECISION_BITS - 1);
+
+                   for (int y = 0; y < ymax; y++) {
+                       int src_idx = ((y + ymin) * imIn.nx + xx) * 3;
+                       ss0 += static_cast<uint8_t>(imIn.buf[src_idx + 0]) * k[y];
+                       ss1 += static_cast<uint8_t>(imIn.buf[src_idx + 1]) * k[y];
+                       ss2 += static_cast<uint8_t>(imIn.buf[src_idx + 2]) * k[y];
+                   }
+
+                   int dst_idx = (yy * imOut.nx + xx) * 3;
+                   imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS);
+                   imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS);
+                   imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS);
+               }
+           }
+       };
+
+       // Main resampling logic
+       const int src_width = img.nx;
+       const int src_height = img.ny;
+
+       dst.nx = target_width;
+       dst.ny = target_height;
+
+       bool need_horizontal = (target_width != src_width);
+       bool need_vertical = (target_height != src_height);
+
+       // Precompute coefficients for both passes
+       std::vector<int> bounds_horiz, bounds_vert;
+       std::vector<int32_t> kk_horiz, kk_vert;
+       int ksize_horiz = 0, ksize_vert = 0;
+
+       if (need_horizontal) {
+           ksize_horiz = precompute_coeffs(src_width, 0.0, src_width, target_width, bounds_horiz, kk_horiz);
+       }
+
+       if (need_vertical) {
+           ksize_vert = precompute_coeffs(src_height, 0.0, src_height, target_height, bounds_vert, kk_vert);
+       }
+
+       // Perform two-pass resampling
+       if (need_horizontal && need_vertical) {
+           // Both horizontal and vertical
+           clip_image_u8 temp;
+           temp.nx = target_width;
+           resample_horizontal(img, temp, ksize_horiz, bounds_horiz, kk_horiz);
+           resample_vertical(temp, dst, ksize_vert, bounds_vert, kk_vert);
+       } else if (need_horizontal) {
+           // Only horizontal
+           resample_horizontal(img, dst, ksize_horiz, bounds_horiz, kk_horiz);
+       } else if (need_vertical) {
+           // Only vertical
+           resample_vertical(img, dst, ksize_vert, bounds_vert, kk_vert);
+       } else {
+           // No resampling needed
+           dst.buf = img.buf;
+       }
+
+       return true;
+   }
+
    static inline int clip(int x, int lower, int upper) {
        return std::max(lower, std::min(x, upper));
    }
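Note on the fixed-point scheme in the hunk above: the per-output-pixel weights produced by precompute_coeffs are scaled by 1 << PRECISION_BITS (with PRECISION_BITS = 32 - 8 - 2 = 22), rounded to int32_t, and each channel accumulator is seeded with 1 << (PRECISION_BITS - 1) so that the final right shift rounds to nearest rather than truncating. The sketch below is a minimal, self-contained illustration of that idea on a single 8-bit row; it is not part of the commit, and the input values, output size, and main driver are invented for the example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int PRECISION_BITS = 32 - 8 - 2; // 22 fractional bits for 8-bit pixels

    // Keys bicubic kernel with a = -0.5, same shape as bicubic_filter above
    auto bicubic = [](double x) {
        const double a = -0.5;
        x = std::fabs(x);
        if (x < 1.0) return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0;
        if (x < 2.0) return (((x - 5.0) * x + 8.0) * x - 4.0) * a;
        return 0.0;
    };

    // Made-up example: shrink an 8-pixel row to 4 pixels (scale = 2, so the
    // filter support widens from 2 to 4 source pixels).
    const std::vector<uint8_t> src = { 0, 32, 64, 96, 128, 160, 192, 255 };
    const int out_size = 4;
    const double scale = (double) src.size() / out_size;

    for (int xx = 0; xx < out_size; xx++) {
        const double center  = (xx + 0.5) * scale;
        const double support = 2.0 * scale;
        const int xmin = std::max(0, (int) (center - support + 0.5));
        const int xmax = std::min((int) src.size(), (int) (center + support + 0.5));

        // Floating-point weights, normalized to sum to 1
        std::vector<double> w(xmax - xmin);
        double ww = 0.0;
        for (int x = xmin; x < xmax; x++) {
            w[x - xmin] = bicubic((x - center + 0.5) / scale);
            ww += w[x - xmin];
        }

        // Fixed-point accumulation, as in resize_bicubic_pillow: scale the
        // weights by 2^PRECISION_BITS, seed the accumulator with half a unit
        // so the final shift rounds to nearest, then clamp to 8 bits.
        int32_t acc = 1 << (PRECISION_BITS - 1);
        for (int x = xmin; x < xmax; x++) {
            const int32_t k = (int32_t) std::lround(w[x - xmin] / ww * (1 << PRECISION_BITS));
            acc += src[x] * k;
        }
        const int val = std::clamp(acc >> PRECISION_BITS, 0, 255);
        std::printf("out[%d] = %d\n", xx, val);
    }
    return 0;
}
```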
@@ -5101,7 +5311,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
            clip_image_u8_ptr resized_img(clip_image_u8_init());
            img_tool::resize(*img, *resized_img,
                             clip_image_size{image_size, image_size},
-                            img_tool::RESIZE_ALGO_BICUBIC, true, color); // Match PIL default
+                            img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false, color); // Match PIL default

            clip_image_f32_ptr res(clip_image_f32_init());
            normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std);
@@ -5124,7 +5334,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

            clip_image_u8_ptr scaled_img(clip_image_u8_init());
            img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
-                            img_tool::RESIZE_ALGO_BICUBIC, true, color);
+                            img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color);

            // Use mean color for padding
            unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
@@ -5815,7 +6025,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
    bool is_stored = false;
    std::vector<std::string> patterns = {
        /* Add tensor names here to dump (e.g. "sam_output") */
-       "inpL", "inp_raw_cpy"
    };

    for (auto & p : patterns) {
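The preprocessing call sites above now route through the new RESIZE_ALGO_BICUBIC_PILLOW path. A quick way to sanity-check the new resampler in isolation is a flat-color round trip: because each output pixel's weights are normalized to sum to 1, a constant image must come out constant. The snippet below is a hypothetical smoke test, not part of the commit; it assumes it is compiled in a translation unit that sees clip.cpp's clip_image_u8 and img_tool definitions.

```cpp
#include <cassert>

// Hypothetical smoke test for img_tool::resize_bicubic_pillow (not committed).
static void smoke_test_resize_bicubic_pillow() {
    clip_image_u8 src;
    src.nx = 7;
    src.ny = 5;
    src.buf.assign(3 * src.nx * src.ny, 128); // flat gray RGB image

    clip_image_u8 dst;
    const bool ok = img_tool::resize_bicubic_pillow(src, dst, 4, 4);

    assert(ok);
    assert(dst.nx == 4 && dst.ny == 4);
    assert(dst.buf.size() == (size_t) (3 * dst.nx * dst.ny));

    // Normalized weights leave a constant image unchanged after both passes.
    for (uint8_t v : dst.buf) {
        assert(v == 128);
    }
}
```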

0 commit comments
