Skip to content

Commit a4b140e

Browse files
committed
Add resizing function for images in ndarray.
1 parent 2015f6e commit a4b140e

12 files changed

Lines changed: 105 additions & 41 deletions

File tree

bridge/include/bridge.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ float* unsafe(const float* arr);
3333
bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
3434
bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key);
3535
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
36+
/*
 * Resize a tensor to `height` x `width`.
 *
 * The implementation in bridge/lib/bridge.cpp handles 3-D and 4-D tensors by
 * bilinear interpolation (align_corners = false) to the requested size; a 3-D
 * tensor is given a temporary leading batch axis for the call. Presumably the
 * layouts are (C,H,W) and (N,C,H,W) - TODO confirm against callers.
 *
 * For any other rank a diagnostic is printed and `input` itself is returned
 * unchanged.
 */
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
3637

3738

3839
int baz(void);

bridge/lib/bridge.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,94 @@ extern "C" bridge_tensor_t max_pool2d(
206206
return torch_to_bridge(output);
207207
}
208208

209+
/*
210+
211+
* Resize the last two dimensions of a tensor, mimicking torchvision.transforms.Resize.
212+
*
213+
* • Works with 2‑D (H,W), 3‑D (C,H,W or N,H,W), 4‑D (N,C,H,W) … tensors – any rank ≥ 2.
214+
* • Leading dimensions are preserved; only the final H × W are resized.
215+
* • Defaults match torchvision: bilinear for floating tensors, align_corners=false.
216+
*
217+
* @param input Tensor on CPU or CUDA.
218+
* @param new_h Target height.
219+
* @param new_w Target width.
220+
* @param mode Interpolation mode (“bilinear”, “nearest”, “bicubic”, …).
221+
* @param align_corners Forwarded to F::interpolate (ignored for “nearest”).
222+
* @return Tensor with shape …, new_h, new_w and same dtype / device.
223+
224+
inline torch::Tensor resize_tensor_last2(
225+
const torch::Tensor& input,
226+
int64_t new_h,
227+
int64_t new_w,
228+
const std::string& mode = "bilinear",
229+
bool align_corners = false) {
230+
231+
// Keep dtype/device; cast to float if interpolate needs it
232+
const bool need_cast = !input.is_floating_point() && mode != "nearest";
233+
auto x = need_cast ? input.to(torch::kFloat32) : input;
234+
x = x.contiguous(); // guarantees a re‑view is safe
235+
236+
// Collapse every axis except the last two into a single batch dimension.
237+
const int64_t h = x.size(-2);
238+
const int64_t w = x.size(-1);
239+
const int64_t flat = x.numel() / (h * w); // product of leading dims
240+
241+
auto x4d = x.view({flat, 1, h, w}); // N=flat, C=1, H, W
242+
243+
// Interpolate – equivalent to torchvision.transforms.Resize for tensors.
244+
auto y4d = torch::nn::functional::interpolate(
245+
x4d,
246+
torch::nn::functional::InterpolateFuncOptions()
247+
.size(std::vector<int64_t>{new_h, new_w})
248+
.mode(mode)
249+
.align_corners(align_corners));
250+
251+
// Restore the original leading shape.
252+
std::vector<int64_t> out_shape(input.sizes().begin(), input.sizes().end() - 2);
253+
out_shape.push_back(new_h);
254+
out_shape.push_back(new_w);
255+
256+
auto y = y4d.view(out_shape);
257+
return need_cast ? y.to(input.scalar_type()) : y;
258+
}
259+
*/
260+
/*
 * Resize the trailing two dimensions of an image tensor to height x width
 * using bilinear interpolation with align_corners = false.
 *
 * Supported ranks:
 *   - 3-D: a leading batch axis is added for the interpolate call and
 *          removed again afterwards, so the result is 3-D as well.
 *   - 4-D: passed to interpolate as-is.
 *
 * @param input   Tensor to resize, in bridge form.
 * @param height  Target size of the second-to-last dimension; must be > 0.
 * @param width   Target size of the last dimension; must be > 0.
 * @return        The resized tensor converted with torch_to_bridge. If the
 *                rank is unsupported or the target size is not positive, a
 *                diagnostic is written to stderr and `input` is returned
 *                unchanged.
 *
 * NOTE(review): the fallback paths hand back the caller's own `input` rather
 * than a tensor produced by torch_to_bridge - confirm callers do not release
 * both the argument and the result.
 */
extern "C" bridge_tensor_t resize(
    bridge_tensor_t input,
    int height,
    int width
) {
    namespace F = torch::nn::functional;

    // interpolate() throws on a non-positive output size; this function is
    // called through a C ABI, so reject bad sizes here instead of letting an
    // exception unwind into the caller.
    if (height <= 0 || width <= 0) {
        std::cerr << "Invalid resize target size: " << height << " x " << width << std::endl;
        return input;
    }

    auto image = bridge_to_torch(input);

    const auto rank = image.dim();
    if (rank != 3 && rank != 4) {
        // Diagnostics go to stderr only; std::endl already flushes.
        std::cerr << "Unsupported tensor dimension: " << rank << std::endl;
        return input; // Return the original tensor if the dimension is unsupported
    }

    // Bilinear interpolate expects a 4-D input, so a 3-D tensor temporarily
    // gets a batch axis of size 1.
    const bool unbatched = (rank == 3);
    auto batched = unbatched ? image.unsqueeze(0) : image;

    auto output = F::interpolate(
        batched,
        F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({ height, width }))
            .mode(torch::kBilinear)
            .align_corners(false)
    );

    if (unbatched) {
        output = output.squeeze(0);
    }
    return torch_to_bridge(output);
}
296+
209297

210298
// extern "C"
211299

clib/Makefile

Lines changed: 0 additions & 18 deletions
This file was deleted.

clib/libmylib.a

-712 Bytes
Binary file not shown.

clib/mylib.c

Lines changed: 0 additions & 9 deletions
This file was deleted.

clib/mylib.h

Lines changed: 0 additions & 4 deletions
This file was deleted.

clib/mylib.h.gch

-565 KB
Binary file not shown.

clib/mylib.o

-520 Bytes
Binary file not shown.

clib/test.chpl

Lines changed: 0 additions & 10 deletions
This file was deleted.

examples/torch_model_loading/torch_load.chpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,8 @@ proc main(args: [] string) {
1111
writeln("Loaded image: ", args[1]);
1212
writeln("Image shape: ", image.shape);
1313

14+
image = image.resize(224,224);
15+
writeln("Resized image: ", image.shape);
16+
1417
image.saveImage("test.jpg");
1518
}

0 commit comments

Comments
 (0)