Skip to content

Commit 9b45c24

Browse files
committed
Batched vgg preprocessing working for ndarray loaded images in [0,1].
1 parent 5c46965 commit 9b45c24

2 files changed

Lines changed: 12 additions & 21 deletions

File tree

bridge/lib/bridge.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,10 @@ extern "C" bridge_tensor_t imagenet_normalize(bridge_tensor_t input) {
253253
auto mean = torch::tensor(kMean).reshape({3, 1, 1}); // (3,1,1)
254254
auto std = torch::tensor(kStd).reshape({3, 1, 1});
255255

256-
// if (image.dim() == 4) {
257-
// mean = mean.unsqueeze(0); // (1,3,1,1)
258-
// std = std.unsqueeze(0);
259-
// }
256+
if (image.dim() == 4) {
257+
mean = mean.unsqueeze(0); // (1,3,1,1)
258+
std = std.unsqueeze(0);
259+
}
260260

261261
auto output = (image - mean) / std;
262262
return torch_to_bridge(output);

examples/torch_model_loading/torch_load.chpl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,18 @@ proc main(args: [] string) {
88
// writeln("b sum: ", b.sum());
99

1010
var image = ndarray.loadFrom(args[1],3,real(32));
11-
writeln("Loaded image: ", args[1]);
12-
writeln("Image shape: ", image.shape);
1311

14-
writeln("image : ", max reduce image.data);
12+
image = image.resize(224,224).imageNetNormalize();
13+
writeln("Resized image: ", image.shape);
1514

16-
image = image.imageNetNormalize();
17-
writeln("image : ", max reduce image.data);
15+
var batchedImage = ndarray.loadFrom(args[1],3,real(32)).unsqueeze(0);
16+
writeln("Batched image: ", batchedImage.shape);
1817

18+
batchedImage = batchedImage.resize(224,224);
19+
writeln("Batched image resized: ", batchedImage.shape);
1920

20-
21-
// image = image.resize(224,224).imageNetNormalize();
22-
// writeln("Resized image: ", image.shape);
23-
24-
// var batchedImage = ndarray.loadFrom(args[1],3,real(32)).unsqueeze(0);
25-
// writeln("Batched image: ", batchedImage.shape);
26-
27-
// batchedImage = batchedImage.resize(224,224);
28-
// writeln("Batched image resized: ", batchedImage.shape);
29-
30-
// image = batchedImage.squeeze(3).imageNetNormalize();
31-
// writeln("Squeezed image: ", image.shape);
21+
image = batchedImage.squeeze(3).imageNetNormalize();
22+
writeln("Squeezed image: ", image.shape);
3223

3324
image.saveImage("test.jpg");
3425
}

0 commit comments

Comments
 (0)