11use VGG;
22use Tensor;
33
4+ config param vggExampleDir = " ." ;
5+
6+ writeln (" VGG Example Directory: " , vggExampleDir);
7+
48config const k = 5 ;
5- config const labelFile = " imagenet/LOC_synset_mapping.txt" ;
9+ config const modelDir = vggExampleDir + " /models/vgg16/" ;
10+ config const labelFile = vggExampleDir + " /imagenet/LOC_synset_mapping.txt" ;
611
712
813proc getLabels(): [] {
@@ -24,22 +29,28 @@ proc confidence(x: []): [] {
2429}
2530
2631// returns (top k indicies, top k condiences)
27- proc run(model: borrowed , file: string ) {
28- const img = Tensor.load(file): real (32 );
32+ proc run(model: shared VGG16(real (32 )), file: string ) {
2933
30- writeln (" Loaded image: " , file);
31- writeln (" Image shape: " , img.shape());
3234
33- var output = model(img);
35+ writeln (" Loading image: " , file);
36+ // const image: dynamicTensor(real(32)) = dynamicTensor.loadImage(imagePath=file,eltType=real(32));
37+ const imageData: ndarray(3 ,real (32 )) = ndarray.loadImage(imagePath= file,eltType= real (32 ));
38+ writeln (" Loaded image: " , file);
39+ writeln (" Image shape: " , imageData.shape);
40+ const image: dynamicTensor(real (32 )) = imageData.toTensor(); // new dynamicTensor(imageData);
41+ writeln (" Converted image to dynamicTensor (or Tensor)." );
3442
35- writeln (" Output shape: " , output.shape());
43+ writeln (" Running model on image." );
44+ var output: dynamicTensor(real (32 )) = model(image);
45+ writeln (" Output shape: " , output.shape());
46+ writeln (" Output type: " , output.type : string );
3647
37- const top = output.topk(k);
38- var topArr = top.tensorize (1 ).array.data;
39- var percent = confidence(output.tensorize (1 ).array.data);
48+ const top = output.topk(k);
49+ var topArr = top.forceRank (1 ).array.data;
50+ var percent = confidence(output.forceRank (1 ).array.data);
4051
41- var percentTopk = [ i in 0 ..< k] percent(topArr[ i] );
42- return (topArr, percentTopk);
52+ var percentTopk = [ i in 0 ..< k] percent(topArr[ i] );
53+ return (topArr, percentTopk);
4354}
4455
4556proc main(args: [] string ) {
@@ -48,11 +59,11 @@ proc main(args: [] string) {
4859 writeln (" Loaded " , labels.size, " labels." );
4960
5061 writeln (" Constructing VGG16 model." );
51- const vgg = new VGG16(real (32 ));
62+ const vgg = new shared VGG16(real (32 ));
5263 writeln (" Constructed VGG16 model." );
5364
5465 writeln (" Loading VGG16 model weights." );
55- vgg.loadPyTorchDump(" models/vgg16/ " , false );
66+ vgg.loadPyTorchDump(modelDir , false );
5667 writeln (" Loaded VGG16 model." );
5768
5869
0 commit comments