Skip to content

Commit 34d4447

Browse files
committed
Improve vgg code. Going to add pytorch tensor file loading.
1 parent f5efcae commit 34d4447

5 files changed

Lines changed: 53 additions & 22 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ set(CHAI_LINKER_ARGS
270270
add_executable(vgg
271271
"${PROJECT_ROOT_DIR}/examples/vgg/test.chpl"
272272
${PROJECT_ROOT_DIR}/examples/vgg/VGG.chpl
273+
${CHAI_LIB_FILES}
273274
)
274275
add_dependencies(vgg bridge)
275276
add_dependencies(vgg ChAI)

examples/vgg/test.chpl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ proc getLabels(): [] {
2222
return lines;
2323
}
2424

25-
proc confidence(x: []): [] {
25+
proc confidence(x: ndarray(1,real(32))): [] {
2626
use Math;
27-
var expSum = + reduce exp(x);
28-
return (exp(x) / expSum) * 100.0;
27+
var expSum = + reduce exp(x.data);
28+
return (exp(x.data) / expSum) * 100.0;
2929
}
3030

3131
// returns (top k indicies, top k condiences)
@@ -41,16 +41,16 @@ proc run(model: shared VGG16(real(32)), file: string) {
4141
writeln("Converted image to dynamicTensor (or Tensor).");
4242

4343
writeln("Running model on image.");
44-
var output: dynamicTensor(real(32)) = model(image);
44+
const output: dynamicTensor(real(32)) = model(image);
4545
writeln("Output shape: ", output.shape());
4646
writeln("Output type: ", output.type:string);
4747

48-
const top = output.topk(k);
49-
var topArr = top.forceRank(1).array.data;
50-
var percent = confidence(output.forceRank(1).array.data);
51-
52-
var percentTopk = [i in 0..<k] percent(topArr[i]);
53-
return (topArr, percentTopk);
48+
const predictions: ndarray(1,real(32)) = output.forceRank(rank=1).array;
49+
const percent = confidence(predictions);
50+
51+
const topPredictions: ndarray(1,int) = predictions.topk(k);
52+
var percentTopk = [i in 0..<k] percent[topPredictions[i]];
53+
return (topPredictions.data, percentTopk);
5454
}
5555

5656
proc main(args: [] string) {

lib/DynamicTensor.chpl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,22 @@ record dynamicTensor : serializable {
126126
return this;
127127
}
128128

129-
proc array(param rank: int) ref : ndarray(rank,eltType) do
129+
inline proc ref rankedArray(param rank: int) ref : ndarray(rank,eltType) do
130130
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).array;
131131

132-
proc grad(param rank: int) ref : ndarray(rank,eltType) do
132+
inline proc rankedArray(param rank: int): ndarray(rank,eltType) do
133+
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).array;
134+
135+
inline proc ref rankedGradArray(param rank: int) ref : ndarray(rank,eltType) do
136+
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).grad;
137+
138+
inline proc rankedGradArray(param rank: int): ndarray(rank,eltType) do
133139
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).grad;
134140

135-
proc data(param rank: int) ref : [] eltType do
141+
inline proc ref rankedData(param rank: int) ref : [] eltType do
142+
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).data;
143+
144+
inline proc rankedData(param rank: int): [] eltType do
136145
return (this.meta.borrow() : borrowed BaseTensorResource(eltType, rank)).data;
137146

138147

@@ -156,14 +165,23 @@ record dynamicTensor : serializable {
156165
}
157166
}
158167

159-
operator :(in t: dynamicTensor(?eltType), type toType): dynamicTensor(toType) {
168+
operator :(in t: dynamicTensor(?eltType), type toType): dynamicTensor(toType)
169+
where isNumericType(toType) {
160170
if eltType == toType then return t;
161171
for param rank in 1..maxRank do
162172
if t.checkRank(rank) then
163173
return (t.forceRank(rank) : toType).eraseRank();
164174
halt("Could not identify rank for this: ", t);
165175
}
166176

177+
operator :(in t: dynamicTensor(?eltType), type toType: ndarray(?rank,?toEltType)): ndarray(rank,toEltType)
178+
where isNumericType(eltType) && isNumericType(toEltType) {
179+
if eltType == toEltType then
180+
return t.toNDArray(rank);
181+
else
182+
return t.toNDArray(rank) : toEltType;
183+
}
184+
167185
proc type dynamicTensor.detachMode() param : bool {
168186
return defaultDetachedMode;
169187
}

lib/NDArray.chpl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,11 @@ record ndarray : serializable {
260260
}
261261

262262

263-
proc ref ndarray.this(args: int...rank) ref {
263+
// proc ref ndarray.this(args: int...rank) ref {
264+
// return data.this((...args));
265+
// }
266+
267+
proc ndarray.this(args: int...rank) {
264268
return data.this((...args));
265269
}
266270

@@ -750,7 +754,7 @@ proc ndarray.topk(k: int): ndarray(1, int) where rank == 1 {
750754
const mySize = myDom.size;
751755
if k > mySize then util.err("Cannot get top ", k, " from ", mySize, " elements.");
752756
var topK: [0..<k] int = 0..<k;
753-
var topKData: [0..<k] eltType = myData(0..<k);
757+
var topKData: [0..<k] eltType = myData[0..<k];
754758

755759
// Repeatedly find the minimum from the elements of topKData,
756760
// and then swap it out with some element from the remaining portion
@@ -2304,8 +2308,9 @@ proc ref ndarray.read(fr: IO.fileReader(?)) throws {
23042308
s[i] = fr.read(int);
23052309
var d = util.domainFromShape((...s));
23062310
this._domain = d;
2307-
for i in d do
2308-
this.data[i] = fr.read(eltType);
2311+
// for i in d do
2312+
// this.data[i] = fr.read(eltType);
2313+
fr.read(this.data);
23092314
}
23102315

23112316

lib/StaticTensor.chpl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,19 @@ record staticTensor : serializable {
8888
}
8989
}
9090

91-
operator :(in t: staticTensor(?rank,?eltType), type toType): staticTensor(rank,toType) {
91+
operator :(in t: staticTensor(?rank,?eltType), type toType): staticTensor(rank,toType)
92+
where isNumericType(toType) {
9293
if toType == t.eltType then
9394
return t;
94-
const a = t.array;
95-
const b = a : toType;
96-
return new staticTensor(b);
95+
96+
const device = t.device;
97+
var newDataResource = new shared Remote(ndarray(rank,eltType),device);
98+
ref dat = newDataResource.ptr;
99+
on device do
100+
dat = t.array : toType;
101+
var newTR = new shared TensorResource(newDataResource);
102+
103+
return new staticTensor(newTR);
97104
}
98105

99106
proc staticTensor.shapeArray(): [] int {

0 commit comments

Comments
 (0)