Skip to content

Commit b5060e9

Browse files
committed
Add incredible improvements to vgg and vridge.
1 parent 0b180cd commit b5060e9

2 files changed

Lines changed: 8 additions & 10 deletions

File tree

examples/vgg/test.chpl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ proc getLabels(): [] {
2222
return lines;
2323
}
2424

25-
proc confidence(x: ndarray(2,real(32))): [] {
25+
proc confidence(x: ndarray(2,real(32))): ndarray(1,real(32)) {
2626
// use Math;
2727
// var expSum = + reduce exp(x.data);
2828
// return (exp(x.data) / expSum) * 100.0;
2929

3030
// const X: ndarray(1,real(32)) = x.squeeze(1);
3131
const (_,i) = x.shape;
32-
const X: ndarray(1,real(32)) = x.reshape(i);
32+
const X: ndarray(1,real(32)) = x.squeeze(1);
3333
const smX = X.softmax();
34-
return smX.data;
34+
return smX;
3535

3636

3737
}
@@ -83,8 +83,8 @@ proc runX(file: string) {
8383
const predictions: ndarray(2,real(32)) = output;
8484
const percent = confidence(predictions);
8585

86-
const topPredictions: ndarray(2,int) = predictions.topk(k);
87-
var percentTopk = [i in 0..<k] percent[topPredictions[0,i]];
86+
const topPredictions: ndarray(1,int) = predictions.squeeze(1).topk(k);
87+
var percentTopk = [i in 0..<k] percent[topPredictions[i]];
8888
return (topPredictions.data, percentTopk);
8989
}
9090

@@ -103,7 +103,7 @@ proc main(args: [] string) {
103103
return;
104104

105105

106-
106+
/*
107107
writeln("Loading labels from ", labelFile);
108108
const labels = getLabels();
109109
writeln("Loaded ", labels.size, " labels.");
@@ -127,5 +127,5 @@ proc main(args: [] string) {
127127
}
128128
writeln();
129129
}
130-
130+
*/
131131
}

lib/NDArray.chpl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,9 +2579,7 @@ proc type ndarray.einsum(param subscripts: string,a: ndarray(?rankA,?eltType), b
25792579
:returns: For a tensor ``t``, :math:`\frac{\exp{t}}{\Sigma \exp{t}}`.
25802580
:rtype: ndarray(rank, eltType)
25812581
*/
2582-
proc ndarray.softmax(): ndarray(this.rank, this.eltType)
2583-
where isSubtype(this.eltType, real)
2584-
{
2582+
proc ndarray.softmax(): ndarray(this.rank, this.eltType) {
25852583
const dom = this.domain;
25862584
const ref thisData = this.data;
25872585

0 commit comments

Comments
 (0)