Skip to content

Commit 7dbc0d8

Browse files
committed
Add {Tensor,tensor,ndarray}.square() example.
1 parent d495dc7 commit 7dbc0d8

4 files changed

Lines changed: 37 additions & 0 deletions

File tree

lib/Autograd.chpl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,17 @@ record reluOp : serializable {
278278
proc spec : GradOpSpec do return new dict(("operation","ReLU"));
279279
}
280280

281+
record squareOp : serializable {
282+
var input: shared BaseTensorResource(?);
283+
284+
proc children do return (input,);
285+
286+
proc forward() do
287+
return input.array.square();
288+
289+
proc spec : GradOpSpec do return new dict(("operation","Square"));
290+
}
291+
281292
record expOp : serializable {
282293
var input: shared BaseTensorResource(?);
283294

lib/DynamicTensor.chpl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ proc dynamicTensor.relu(): dynamicTensor(eltType) {
244244
return new dynamicTensor(eltType);
245245
}
246246

247+
proc dynamicTensor.square(): dynamicTensor(eltType) {
248+
for param rank in 1..maxRank {
249+
if this.checkRank(rank) then
250+
return this.forceRank(rank).square().eraseRank();
251+
}
252+
halt("Could not determine rank in dynamicTensor.square.");
253+
return new dynamicTensor(eltType);
254+
}
255+
247256
proc dynamicTensor.gelu(): dynamicTensor(eltType) {
248257
for param rank in 1..maxRank {
249258
if this.checkRank(rank) then

lib/NDArray.chpl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,18 @@ record ndarray : serializable {
582582
return rl;
583583
}
584584

585+
inline proc square() {
586+
const ref thisData = data;
587+
const dom = this.domain;
588+
var rl = new ndarray(dom,eltType);
589+
ref rlD = rl.data;
590+
forall i in dom.every() {
591+
const x = thisData[i];
592+
rlD[i] = x * x;
593+
}
594+
return rl;
595+
}
596+
585597
inline proc gelu() {
586598
const ref thisData = data;
587599
const dom = this.domain;

lib/StaticTensor.chpl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ proc staticTensor.relu() {
177177
return tensorFromCtx(rank,eltType,ctx);
178178
}
179179

180+
proc staticTensor.square() {
181+
var ctx = new squareOp(meta);
182+
return tensorFromCtx(rank,eltType,ctx);
183+
}
184+
180185
proc staticTensor.gelu() {
181186
var t = new staticTensor(rank,eltType);
182187
on this.device {

0 commit comments

Comments
 (0)