Skip to content

Commit 0684044

Browse files
committed
Add layerNorm
1 parent 99a9a17 commit 0684044

5 files changed

Lines changed: 127 additions & 0 deletions

File tree

lib/Autograd.chpl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,21 @@ record batchNormOp : serializable {
11011101
proc spec : GradOpSpec do return new dict(("operation","BatchNorm"));
11021102
}
11031103

1104+
record layerNormOp : serializable {
1105+
type eltType = real;
1106+
var features: shared BaseTensorResource(?);
1107+
var weight: shared BaseTensorResource(eltType, ?);
1108+
var bias: shared BaseTensorResource(eltType, ?);
1109+
1110+
proc children do return (features, weight, bias);
1111+
1112+
proc forward() {
1113+
return ndarray.layerNorm(features.array, weight.array, bias.array);
1114+
}
1115+
1116+
proc spec : GradOpSpec do return new dict(("operation", "LayerNorm"));
1117+
}
1118+
11041119
record multiheadAttentionOp : serializable {
11051120
type eltType = real;
11061121
var features: shared BaseTensorResource(?);

lib/DynamicTensor.chpl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,29 @@ proc type dynamicTensor.batchnorm(
634634
return new dynamicTensor(eltType);
635635
}
636636

637+
proc type dynamicTensor.layerNorm(
638+
features: dynamicTensor(?eltType),
639+
weight: dynamicTensor(eltType),
640+
bias: dynamicTensor(eltType)
641+
// normalizedShape: (...?maxRank)
642+
): dynamicTensor(eltType) {
643+
// const n = normalizedShape.size;
644+
for param rankF in 2..4 {
645+
for param rankN in 1..4 {
646+
if features.checkRank(rankF) && weight.checkRank(rankN) && bias.checkRank(rankN) {
647+
return staticTensor.layerNorm(
648+
features.forceRank(rankF),
649+
weight.forceRank(rankN),
650+
bias.forceRank(rankN),
651+
rankN
652+
).eraseRank();
653+
}
654+
}
655+
}
656+
halt("Could not determine rank in dynamicTensor.layerNorm.");
657+
return new dynamicTensor(eltType);
658+
}
659+
637660
proc type dynamicTensor.multiheadAttention(
638661
features: dynamicTensor(?eltType),
639662
q_weight: dynamicTensor(eltType),

lib/NDArray.chpl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,16 @@ proc ndarray.variance(axes: int...?axesCount): ndarray(rank,eltType) {
532532
return ((this - this.mean((...axes)).expand((...shape)))**2).sum((...axes)) / (denom - 1);
533533
}
534534

535+
proc ndarray.variance(axes: int...?axesCount, correction: int): ndarray(rank,eltType) {
536+
const shape = this.shape;
537+
var denom: eltType = 0;
538+
for param i in 0..<axesCount {
539+
const reducedN = shape(axes(i));
540+
denom += reducedN : eltType;
541+
}
542+
return ((this - this.mean((...axes)).expand((...shape)))**2).sum((...axes)) / (denom - correction);
543+
}
544+
535545
proc ndarray.shrink(narg: 2*int ... rank,param exactBounds = false): ndarray(rank,eltType) {
536546
var newShape: rank * int;
537547
var sliceRanges: rank * range;
@@ -1387,6 +1397,16 @@ operator *(a: ndarray(?rank,?eltType),b: ndarray(rank,eltType)): ndarray(rank,el
13871397
return c;
13881398
}
13891399

1400+
operator **(a: ndarray(?rank,?eltType),b: real): ndarray(rank,eltType) {
1401+
const dom = a.domain;
1402+
var c: ndarray(rank,eltType) = new ndarray(a.domain,eltType);
1403+
ref cData = c.data;
1404+
const ref aData = a.data;
1405+
forall i in dom.every() do
1406+
cData[i] = (aData[i]**b):eltType;
1407+
return c;
1408+
}
1409+
13901410
operator -(a: ndarray(?rank, ?eltType)): ndarray(rank, eltType) {
13911411
const dom = a.domain;
13921412
var negged = new ndarray(dom, eltType);
@@ -2112,6 +2132,45 @@ proc type ndarray.batchNorm(
21122132
return outFeatures;
21132133
}
21142134

2135+
proc type ndarray.layerNorm(
2136+
features: ndarray(?rank,?eltType),
2137+
weight: ndarray(?n,eltType),
2138+
bias: ndarray(n,eltType)
2139+
): ndarray(rank,eltType) {
2140+
const fshape = features.shape;
2141+
const axis = rank - n - 1;
2142+
2143+
var args: n*int;
2144+
for i in 0..<n {
2145+
args[i] = i + axis + 1;
2146+
}
2147+
var avgs = features.mean((...args));
2148+
var vars = features.variance((...args), correction = 0);
2149+
2150+
ref f = features.data;
2151+
ref a = avgs.data;
2152+
ref v = vars.data;
2153+
ref w = weight.data;
2154+
ref b = bias.data;
2155+
2156+
var outDom = util.domainFromShape((...fshape));
2157+
var outFeatures = new ndarray(outDom,eltType);
2158+
ref dat = outFeatures.data;
2159+
2160+
forall idx in outDom.every() {
2161+
var c = idx;
2162+
var d: n*int;
2163+
for i in (axis + 1)..<rank {
2164+
c[i] = 0;
2165+
}
2166+
for i in 0..<n {
2167+
d[i] = idx[axis+1+i];
2168+
}
2169+
dat[idx] = ((f[idx] - a[c])/v[c])*w[d] + b[d];
2170+
}
2171+
return outFeatures;
2172+
}
2173+
21152174

21162175
inline proc type ndarray.fromRanges(type eltType = real, rngs: range...?rank): ndarray(rank,eltType) {
21172176
const dom_ = {(...rngs)};

lib/Network.chpl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,26 @@ class BatchNorm : Module(?) {
10801080
}
10811081
}
10821082

1083+
class LayerNorm : Module(?) {
1084+
var weight: owned Parameter(eltType);
1085+
var bias: owned Parameter(eltType);
1086+
var nShape;
1087+
proc init(type eltType = real, normalizedShape: ?nShapeRankP*int) {
1088+
this.weight = new Parameter(Tensor.ones((...normalizedShape)));
1089+
this.bias = new Parameter(Tensor.zeros((...normalizedShape)));
1090+
this.nShape = normalizedShape;
1091+
}
1092+
1093+
override proc forward(input: Tensor(eltType)): Tensor(eltType) {
1094+
return Tensor.layerNorm(input, weight.data, bias.data);
1095+
}
1096+
1097+
override proc setup() {
1098+
addModule("weight", weight);
1099+
addModule("bias", bias);
1100+
}
1101+
}
1102+
10831103
class MultiheadAttention : Module(?) {
10841104
var q_weight: owned Parameter(eltType);
10851105
var k_weight: owned Parameter(eltType);

lib/StaticTensor.chpl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,16 @@ proc type staticTensor.batchNorm(
453453
return tensorFromCtx(featureRank, eltType, ctx);
454454
}
455455

456+
proc type staticTensor.layerNorm(
457+
features: staticTensor(?featureRank,?eltType),
458+
weight: staticTensor(?n,eltType),
459+
bias: staticTensor(n,eltType),
460+
rankN: int
461+
): staticTensor(featureRank,eltType) {
462+
var ctx = new layerNormOp(eltType, features.meta, weight.meta, bias.meta);
463+
return tensorFromCtx(featureRank, eltType, ctx);
464+
}
465+
456466
proc type staticTensor.multiheadAttention(
457467
features: staticTensor(3, ?eltType),
458468
q_weight: staticTensor(2, eltType),

0 commit comments

Comments
 (0)