@@ -2005,6 +2005,34 @@ proc type ndarray.multiheadAttention(
20052005 return a;
20062006}
20072007
2008+ /* Softmax, but with the option to specify the axis.
2009+
2010+ :returns: For a tensor ``t``, :math:`\frac{\exp{t}}{\Sigma \exp{t}}`.
2011+ :rtype: ndarray(rank, eltType)
2012+ */
2013+ proc ndarray._softmax(axis: int = (this .rank- 1 )): ndarray(this .rank, this .eltType)
2014+ // where isSubtype(this.eltType, real)
2015+ {
2016+ const dom = this .domain ;
2017+ var exps = new ndarray(this .eltType, dom);
2018+ var smxd = new ndarray(this .eltType, dom);
2019+ const ref thisData = this .data;
2020+ ref expsData = exps.data;
2021+ ref outData = smxd.data;
2022+
2023+ forall i in dom.every() {
2024+ expsData[ i] = Math.exp(thisData[ i] );
2025+ }
2026+ var sums = exps.sum(axis).expand((...this .shape));
2027+ ref sumsData = sums.data;
2028+
2029+ forall i in dom.every() {
2030+ outData[ i] = expsData[ i] / sumsData[ i] ;
2031+ }
2032+
2033+ return smxd;
2034+ }
2035+
20082036proc type ndarray.batchNormTrain(
20092037 features: ndarray(?rank,?eltType),
20102038 weight: ndarray(1 ,eltType),
0 commit comments