File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -278,6 +278,17 @@ record reluOp : serializable {
278278 proc spec : GradOpSpec do return new dict ((" operation" ," ReLU" ));
279279}
280280
281+ record squareOp : serializable {
282+ var input: shared BaseTensorResource(?);
283+
284+ proc children do return (input,);
285+
286+ proc forward () do
287+ return input.array.square();
288+
289+ proc spec : GradOpSpec do return new dict ((" operation" ," Square" ));
290+ }
291+
281292record expOp : serializable {
282293 var input: shared BaseTensorResource(?);
283294
Original file line number Diff line number Diff line change @@ -244,6 +244,15 @@ proc dynamicTensor.relu(): dynamicTensor(eltType) {
244244 return new dynamicTensor(eltType);
245245}
246246
247+ proc dynamicTensor.square(): dynamicTensor(eltType) {
248+ for param rank in 1 ..maxRank {
249+ if this .checkRank(rank) then
250+ return this .forceRank(rank).square().eraseRank();
251+ }
252+ halt(" Could not determine rank in dynamicTensor.square." );
253+ return new dynamicTensor(eltType);
254+ }
255+
247256proc dynamicTensor.gelu(): dynamicTensor(eltType) {
248257 for param rank in 1 ..maxRank {
249258 if this .checkRank(rank) then
Original file line number Diff line number Diff line change @@ -582,6 +582,18 @@ record ndarray : serializable {
582582 return rl;
583583 }
584584
585+ inline proc square() {
586+ const ref thisData = data;
587+ const dom = this .domain ;
588+ var rl = new ndarray(dom,eltType);
589+ ref rlD = rl.data;
590+ forall i in dom.every() {
591+ const x = thisData[ i] ;
592+ rlD[ i] = x * x;
593+ }
594+ return rl;
595+ }
596+
585597 inline proc gelu() {
586598 const ref thisData = data;
587599 const dom = this .domain ;
Original file line number Diff line number Diff line change @@ -177,6 +177,11 @@ proc staticTensor.relu() {
177177 return tensorFromCtx(rank,eltType,ctx);
178178}
179179
180+ proc staticTensor.square() {
181+ var ctx = new squareOp(meta);
182+ return tensorFromCtx(rank,eltType,ctx);
183+ }
184+
180185proc staticTensor.gelu() {
181186 var t = new staticTensor(rank,eltType);
182187 on this .device {
You can’t perform that action at this time.
0 commit comments