@@ -532,6 +532,16 @@ proc ndarray.variance(axes: int...?axesCount): ndarray(rank,eltType) {
532532 return ((this - this .mean((...axes)).expand((...shape)))** 2 ).sum((...axes)) / (denom - 1 );
533533}
534534
535+ proc ndarray.variance(axes: int ...?axesCount, correction: int ): ndarray(rank,eltType) {
536+ const shape = this .shape;
537+ var denom: eltType = 0 ;
538+ for param i in 0 ..< axesCount {
539+ const reducedN = shape(axes(i));
540+ denom += reducedN : eltType;
541+ }
542+ return ((this - this .mean((...axes)).expand((...shape)))** 2 ).sum((...axes)) / (denom - correction);
543+ }
544+
535545proc ndarray.shrink(narg: 2 * int ... rank,param exactBounds = false ): ndarray(rank,eltType) {
536546 var newShape: rank * int ;
537547 var sliceRanges: rank * range ;
@@ -1387,6 +1397,16 @@ operator *(a: ndarray(?rank,?eltType),b: ndarray(rank,eltType)): ndarray(rank,el
13871397 return c;
13881398}
13891399
1400+ operator ** (a: ndarray(?rank,?eltType),b: real ): ndarray(rank,eltType) {
1401+ const dom = a.domain ;
1402+ var c: ndarray(rank,eltType) = new ndarray(a.domain ,eltType);
1403+ ref cData = c.data;
1404+ const ref aData = a.data;
1405+ forall i in dom.every() do
1406+ cData[ i] = (aData[ i] ** b): eltType;
1407+ return c;
1408+ }
1409+
13901410operator - (a: ndarray(?rank, ?eltType)): ndarray(rank, eltType) {
13911411 const dom = a.domain ;
13921412 var negged = new ndarray(dom, eltType);
@@ -2112,6 +2132,45 @@ proc type ndarray.batchNorm(
21122132 return outFeatures;
21132133}
21142134
2135+ proc type ndarray.layerNorm(
2136+ features: ndarray(?rank,?eltType),
2137+ weight: ndarray(?n,eltType),
2138+ bias: ndarray(n,eltType)
2139+ ): ndarray(rank,eltType) {
2140+ const fshape = features.shape;
2141+ const axis = rank - n - 1 ;
2142+
2143+ var args: n* int ;
2144+ for i in 0 ..< n {
2145+ args[ i] = i + axis + 1 ;
2146+ }
2147+ var avgs = features.mean((...args));
2148+ var vars = features.variance((...args), correction = 0 );
2149+
2150+ ref f = features.data;
2151+ ref a = avgs.data;
2152+ ref v = vars.data;
2153+ ref w = weight.data;
2154+ ref b = bias.data;
2155+
2156+ var outDom = util.domainFromShape((...fshape));
2157+ var outFeatures = new ndarray(outDom,eltType);
2158+ ref dat = outFeatures.data;
2159+
2160+ forall idx in outDom.every() {
2161+ var c = idx;
2162+ var d: n* int ;
2163+ for i in (axis + 1 )..< rank {
2164+ c[ i] = 0 ;
2165+ }
2166+ for i in 0 ..< n {
2167+ d[ i] = idx[ axis+ 1 + i] ;
2168+ }
2169+ dat[ idx] = ((f[ idx] - a[ c] )/ v[ c] )* w[ d] + b[ d] ;
2170+ }
2171+ return outFeatures;
2172+ }
2173+
21152174
21162175inline proc type ndarray.fromRanges(type eltType = real , rngs: range ...?rank): ndarray(rank,eltType) {
21172176 const dom_ = {(...rngs)};
0 commit comments