Skip to content

Commit 5c46965

Browse files
committed
Image normalization working for VGG preprocessing.
1 parent 4b2ee64 commit 5c46965

5 files changed

Lines changed: 44 additions & 59 deletions

File tree

bridge/include/bridge.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
3434
bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key);
3535
bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tensor_t input);
3636
bridge_tensor_t resize(bridge_tensor_t input,int height,int width);
37+
/* Per-channel mean/std normalization of `input`; takes and returns a bridge_tensor_t.
 * The definition lives in bridge/lib/bridge.cpp. */
bridge_tensor_t imagenet_normalize(bridge_tensor_t input);
3738

3839

3940
int baz(void);

bridge/lib/bridge.cpp

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -206,57 +206,6 @@ extern "C" bridge_tensor_t max_pool2d(
206206
return torch_to_bridge(output);
207207
}
208208

209-
/*
210-
211-
* Resize the last two dimensions of a tensor, mimicking torchvision.transforms.Resize.
212-
*
213-
* • Works with 2‑D (H,W), 3‑D (C,H,W or N,H,W), 4‑D (N,C,H,W) … tensors – any rank ≥ 2.
214-
* • Leading dimensions are preserved; only the final H × W are resized.
215-
* • Defaults match torchvision: bilinear for floating tensors, align_corners=false.
216-
*
217-
* @param input Tensor on CPU or CUDA.
218-
* @param new_h Target height.
219-
* @param new_w Target width.
220-
* @param mode Interpolation mode (“bilinear”, “nearest”, “bicubic”, …).
221-
* @param align_corners Forwarded to F::interpolate (ignored for “nearest”).
222-
* @return Tensor with shape …, new_h, new_w and same dtype / device.
223-
224-
inline torch::Tensor resize_tensor_last2(
225-
const torch::Tensor& input,
226-
int64_t new_h,
227-
int64_t new_w,
228-
const std::string& mode = "bilinear",
229-
bool align_corners = false) {
230-
231-
// Keep dtype/device; cast to float if interpolate needs it
232-
const bool need_cast = !input.is_floating_point() && mode != "nearest";
233-
auto x = need_cast ? input.to(torch::kFloat32) : input;
234-
x = x.contiguous(); // guarantees a re‑view is safe
235-
236-
// Collapse every axis except the last two into a single batch dimension.
237-
const int64_t h = x.size(-2);
238-
const int64_t w = x.size(-1);
239-
const int64_t flat = x.numel() / (h * w); // product of leading dims
240-
241-
auto x4d = x.view({flat, 1, h, w}); // N=flat, C=1, H, W
242-
243-
// Interpolate – equivalent to torchvision.transforms.Resize for tensors.
244-
auto y4d = torch::nn::functional::interpolate(
245-
x4d,
246-
torch::nn::functional::InterpolateFuncOptions()
247-
.size(std::vector<int64_t>{new_h, new_w})
248-
.mode(mode)
249-
.align_corners(align_corners));
250-
251-
// Restore the original leading shape.
252-
std::vector<int64_t> out_shape(input.sizes().begin(), input.sizes().end() - 2);
253-
out_shape.push_back(new_h);
254-
out_shape.push_back(new_w);
255-
256-
auto y = y4d.view(out_shape);
257-
return need_cast ? y.to(input.scalar_type()) : y;
258-
}
259-
*/
260209
extern "C" bridge_tensor_t resize(
261210
bridge_tensor_t input,
262211
int height,
@@ -294,6 +243,26 @@ extern "C" bridge_tensor_t resize(
294243
}
295244
}
296245

246+
// Normalizes an image tensor channel-wise: output = (image - mean) / std,
// using the three-element mean/std constants defined below.
//
// @param input  Bridge tensor handle; converted to a torch::Tensor on entry.
// @return       Bridge tensor wrapping the normalized result.
//
// No dtype conversion or rescaling happens here (the `.to(kFloat32)` and
// `/ 255.0` on the `image` line are commented out), so the caller presumably
// passes floating-point data already scaled to [0, 1] -- TODO confirm.
extern "C" bridge_tensor_t imagenet_normalize(bridge_tensor_t input) {
247+
auto t_input = bridge_to_torch(input);
248+
torch::Tensor image = t_input; // disabled: .to(torch::kFloat32) and / 255.0
249+
250+
// Per-channel constants, one value per channel (three channels).
static const std::vector<float> kMean{0.485, 0.456, 0.406};
251+
static const std::vector<float> kStd {0.229, 0.224, 0.225};
252+
// NOTE(review): `opts` is never used below. `mean`/`std` are built without it,
// so they presumably get the default dtype/device rather than the image's --
// confirm behaviour for non-float or non-CPU inputs.
auto opts = image.options();
253+
auto mean = torch::tensor(kMean).reshape({3, 1, 1}); // (3,1,1): broadcasts over the last two dims
254+
// NOTE(review): the local name `std` hides the `std` namespace for the rest of this scope.
auto std = torch::tensor(kStd).reshape({3, 1, 1});
255+
256+
// Disabled explicit handling of 4-D input; assumes the (3,1,1) shape above
// broadcasts correctly against the input, i.e. channels sit at dim -3 -- TODO confirm.
// if (image.dim() == 4) {
257+
// mean = mean.unsqueeze(0); // (1,3,1,1)
258+
// std = std.unsqueeze(0);
259+
// }
260+
261+
auto output = (image - mean) / std;
262+
return torch_to_bridge(output);
263+
}
264+
265+
297266

298267
// extern "C"
299268

examples/torch_model_loading/torch_load.chpl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,24 @@ proc main(args: [] string) {
1111
writeln("Loaded image: ", args[1]);
1212
writeln("Image shape: ", image.shape);
1313

14-
image = image.resize(224,224);
15-
writeln("Resized image: ", image.shape);
14+
writeln("image : ", max reduce image.data);
1615

17-
var batchedImage = ndarray.loadFrom(args[1],3,real(32)).unsqueeze(0);
18-
writeln("Batched image: ", batchedImage.shape);
16+
image = image.imageNetNormalize();
17+
writeln("image : ", max reduce image.data);
1918

20-
batchedImage = batchedImage.resize(224,224);
21-
writeln("Batched image resized: ", batchedImage.shape);
2219

23-
image = batchedImage.squeeze(3);
24-
writeln("Squeezed image: ", image.shape);
20+
21+
// image = image.resize(224,224).imageNetNormalize();
22+
// writeln("Resized image: ", image.shape);
23+
24+
// var batchedImage = ndarray.loadFrom(args[1],3,real(32)).unsqueeze(0);
25+
// writeln("Batched image: ", batchedImage.shape);
26+
27+
// batchedImage = batchedImage.resize(224,224);
28+
// writeln("Batched image resized: ", batchedImage.shape);
29+
30+
// image = batchedImage.squeeze(3).imageNetNormalize();
31+
// writeln("Squeezed image: ", image.shape);
2532

2633
image.saveImage("test.jpg");
2734
}

lib/Bridge.chpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ module Bridge {
7878
in input: bridge_tensor_t,
7979
in height: int(32),
8080
in width: int(32)): bridge_tensor_t;
81+
82+
// Chapel binding for the external C symbol "imagenet_normalize".
// Takes a bridge_tensor_t by `in` intent and returns a bridge_tensor_t.
extern "imagenet_normalize" proc imageNetNormalize(
83+
in input: bridge_tensor_t): bridge_tensor_t;
8184

8285

8386

lib/NDArray.chpl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2230,6 +2230,11 @@ proc ndarray.resize(height: int,width: int) {
22302230
width : int(32)) : ndarray(rank,eltType);
22312231
}
22322232

2233+
// Returns the result of Bridge.imageNetNormalize applied to this array,
// as an ndarray with the same rank and eltType as the receiver.
proc ndarray.imageNetNormalize() {
2234+
return Bridge.imageNetNormalize(
2235+
// Cast the receiver to a bridge tensor handle for the call, then cast the
// returned value back to ndarray(rank,eltType).
this : Bridge.tensorHandle(eltType)) : ndarray(rank,eltType);
2236+
}
2237+
22332238
proc type ndarray.loadImage(imagePath: string, type eltType = defaultEltType): ndarray(3,eltType) throws {
22342239
import Image;
22352240

0 commit comments

Comments
 (0)