@@ -2019,40 +2019,12 @@ proc type ndarray.multiheadAttention(
20192019 var q = ndarray.matmul(features,q_weight);
20202020 var k = ndarray.matmul(features,k_weight);
20212021 var v = ndarray.matmul(features,v_weight);
2022- var z = (ndarray.matmul(q,k.permute(0 ,2 ,1 ))/ Math.sqrt (head_dim))._softmax (axis= 2 );
2022+ var z = (ndarray.matmul(q,k.permute(0 ,2 ,1 ))/ Math.sqrt (head_dim)).softmax (axis= 2 );
20232023 var a = ndarray.matmul(z,v);
20242024
20252025 return a;
20262026}
20272027
2028- /* Softmax, but with the option to specify the axis.
2029-
2030- :returns: For a tensor ``t``, :math:`\frac{\exp{t}}{\Sigma \exp{t}}`.
2031- :rtype: ndarray(rank, eltType)
2032- */
2033- proc ndarray._softmax(axis: int = (this .rank- 1 )): ndarray(this .rank, this .eltType)
2034- // where isSubtype(this.eltType, real)
2035- {
2036- const dom = this .domain ;
2037- var exps = new ndarray(this .eltType, dom);
2038- var smxd = new ndarray(this .eltType, dom);
2039- const ref thisData = this .data;
2040- ref expsData = exps.data;
2041- ref outData = smxd.data;
2042-
2043- forall i in dom.every() {
2044- expsData[ i] = Math.exp(thisData[ i] );
2045- }
2046- var sums = exps.sum(axis).expand((...this .shape));
2047- ref sumsData = sums.data;
2048-
2049- forall i in dom.every() {
2050- outData[ i] = expsData[ i] / sumsData[ i] ;
2051- }
2052-
2053- return smxd;
2054- }
2055-
20562028proc type ndarray.batchNormTrain(
20572029 features: ndarray(?rank,?eltType),
20582030 weight: ndarray(1 ,eltType),
@@ -2603,24 +2575,27 @@ proc type ndarray.einsum(param subscripts: string,a: ndarray(?rankA,?eltType), b
26032575 :returns: For a tensor ``t``, :math:`\frac{\exp{t}}{\Sigma \exp{t}}`.
26042576 :rtype: ndarray(rank, eltType)
26052577*/
2606- proc ndarray.softmax(): ndarray(this .rank, this .eltType)
2607- where isSubtype(this .eltType, real )
2578+ proc ndarray.softmax(axis : int = ( this .rank - 1 ) ): ndarray(this .rank, this .eltType)
2579+ // where isSubtype(this.eltType, real)
26082580{
26092581 const dom = this .domain ;
2582+ var exps = new ndarray(this .eltType, dom);
2583+ var outs = new ndarray(this .eltType, dom);
26102584 const ref thisData = this .data;
2585+ ref expsData = exps.data;
2586+ ref outsData = outs.data;
26112587
2612- var denom: this .eltType = 0.0 ;
2613- forall i in dom.every() with (+ reduce denom) {
2614- denom += Math.exp(thisData[ i] );
2588+ forall i in dom.every() {
2589+ expsData[ i] = Math.exp(thisData[ i] );
26152590 }
2591+ var sums = exps.sum(axis).expand((...this .shape));
2592+ ref sumsData = sums.data;
26162593
2617- var softmaxxed = new ndarray(this .eltType, dom);
2618- ref softmaxData = softmaxxed.data;
26192594 forall i in dom.every() {
2620- softmaxData [ i] = Math.exp(thisData [ i] ) / denom ;
2595+ outsData [ i] = expsData [ i] / sumsData [ i ] ;
26212596 }
26222597
2623- return softmaxxed ;
2598+ return outs ;
26242599}
26252600
26262601
0 commit comments