@@ -206,6 +206,94 @@ extern "C" bridge_tensor_t max_pool2d(
206206 return torch_to_bridge (output);
207207}
208208
209+ /*
210+
211+ * Resize the last two dimensions of a tensor, mimicking torchvision.transforms.Resize.
212+ *
213+ * • Works with 2‑D (H,W), 3‑D (C,H,W or N,H,W), 4‑D (N,C,H,W) … tensors – any rank ≥ 2.
214+ * • Leading dimensions are preserved; only the final H × W are resized.
215+ * • Defaults match torchvision: bilinear for floating tensors, align_corners=false.
216+ *
217+ * @param input Tensor on CPU or CUDA.
218+ * @param new_h Target height.
219+ * @param new_w Target width.
220+ * @param mode Interpolation mode (“bilinear”, “nearest”, “bicubic”, …).
221+ * @param align_corners Forwarded to F::interpolate (ignored for “nearest”).
222+ * @return Tensor with shape …, new_h, new_w and same dtype / device.
223+
224+ inline torch::Tensor resize_tensor_last2(
225+ const torch::Tensor& input,
226+ int64_t new_h,
227+ int64_t new_w,
228+ const std::string& mode = "bilinear",
229+ bool align_corners = false) {
230+
231+ // Keep dtype/device; cast to float if interpolate needs it
232+ const bool need_cast = !input.is_floating_point() && mode != "nearest";
233+ auto x = need_cast ? input.to(torch::kFloat32) : input;
234+ x = x.contiguous(); // guarantees a re‑view is safe
235+
236+ // Collapse every axis except the last two into a single batch dimension.
237+ const int64_t h = x.size(-2);
238+ const int64_t w = x.size(-1);
239+ const int64_t flat = x.numel() / (h * w); // product of leading dims
240+
241+ auto x4d = x.view({flat, 1, h, w}); // N=flat, C=1, H, W
242+
243+ // Interpolate – equivalent to torchvision.transforms.Resize for tensors.
244+ auto y4d = torch::nn::functional::interpolate(
245+ x4d,
246+ torch::nn::functional::InterpolateFuncOptions()
247+ .size(std::vector<int64_t>{new_h, new_w})
248+ .mode(mode)
249+ .align_corners(align_corners));
250+
251+ // Restore the original leading shape.
252+ std::vector<int64_t> out_shape(input.sizes().begin(), input.sizes().end() - 2);
253+ out_shape.push_back(new_h);
254+ out_shape.push_back(new_w);
255+
256+ auto y = y4d.view(out_shape);
257+ return need_cast ? y.to(input.scalar_type()) : y;
258+ }
259+ */
260+ extern " C" bridge_tensor_t resize (
261+ bridge_tensor_t input,
262+ int height,
263+ int width
264+ ) {
265+ auto image = bridge_to_torch (input);
266+
267+ // auto output = resize_tensor_last2(image, height, width);
268+
269+ // at::Tensor output = at::upsample_bilinear2d(t_input.unsqueeze(0), {height, width}, false);
270+ if (image.dim () == 3 ) {
271+ auto output = torch::nn::functional::interpolate (
272+ image.unsqueeze (0 ),
273+ torch::nn::functional::InterpolateFuncOptions ()
274+ .size (std::vector<int64_t >({ height, width }))
275+ .mode (torch::kBilinear )
276+ .align_corners (false )
277+ ).squeeze (0 );
278+ return torch_to_bridge (output);
279+ } else if (image.dim () == 4 ) {
280+ auto output = torch::nn::functional::interpolate (
281+ image,
282+ torch::nn::functional::InterpolateFuncOptions ()
283+ .size (std::vector<int64_t >({ height, width }))
284+ .mode (torch::kBilinear )
285+ .align_corners (false )
286+ );
287+ return torch_to_bridge (output);
288+ } else {
289+ std::cerr << " Unsupported tensor dimension: " << image.dim () << std::endl;
290+ std::cerr.flush ();
291+ std::cout << " Unsupported tensor dimension: " << image.dim () << std::endl;
292+ std::cout.flush ();
293+ return input; // Return the original tensor if the dimension is unsupported
294+ }
295+ }
296+
209297
210298// extern "C"
211299
0 commit comments