@@ -22,16 +22,16 @@ proc getLabels(): [] {
2222 return lines;
2323}
2424
25- proc confidence(x: ndarray(2 ,real (32 ))): [] {
25+ proc confidence(x: ndarray(2 ,real (32 ))): ndarray( 1 , real ( 32 )) {
2626// use Math;
2727// var expSum = + reduce exp(x.data);
2828// return (exp(x.data) / expSum) * 100.0;
2929
3030 // const X: ndarray(1,real(32)) = x.squeeze(1);
3131 const (_,i) = x.shape;
32- const X: ndarray(1 ,real (32 )) = x.reshape(i );
32+ const X: ndarray(1 ,real (32 )) = x.squeeze( 1 );
3333 const smX = X.softmax();
34- return smX.data ;
34+ return smX;
3535
3636
3737}
@@ -83,8 +83,8 @@ proc runX(file: string) {
8383 const predictions: ndarray(2 ,real (32 )) = output;
8484 const percent = confidence(predictions);
8585
86- const topPredictions: ndarray(2 ,int ) = predictions.topk(k);
87- var percentTopk = [ i in 0 ..< k] percent[ topPredictions[ 0 , i]] ;
86+ const topPredictions: ndarray(1 ,int ) = predictions.squeeze( 1 ) .topk(k);
87+ var percentTopk = [ i in 0 ..< k] percent[ topPredictions[ i]] ;
8888 return (topPredictions.data, percentTopk);
8989}
9090
@@ -103,7 +103,7 @@ proc main(args: [] string) {
103103 return ;
104104
105105
106-
106+ /*
107107 writeln("Loading labels from ", labelFile);
108108 const labels = getLabels();
109109 writeln("Loaded ", labels.size, " labels.");
@@ -127,5 +127,5 @@ proc main(args: [] string) {
127127 }
128128 writeln();
129129 }
130-
130+ */
131131}
0 commit comments