Skip to content

Commit 99a9a17

Browse files
committed
Add softmax along specified axis
1 parent b15a558 commit 99a9a17

1 file changed

Lines changed: 28 additions & 0 deletions

File tree

lib/NDArray.chpl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,34 @@ proc type ndarray.multiheadAttention(
20052005
return a;
20062006
}
20072007

2008+
/* Softmax, but with the option to specify the axis.
2009+
2010+
:returns: For a tensor ``t``, :math:`\frac{\exp{t}}{\Sigma \exp{t}}`.
2011+
:rtype: ndarray(rank, eltType)
2012+
*/
2013+
proc ndarray._softmax(axis:int = (this.rank-1)): ndarray(this.rank, this.eltType)
2014+
// where isSubtype(this.eltType, real)
2015+
{
2016+
const dom = this.domain;
2017+
var exps = new ndarray(this.eltType, dom);
2018+
var smxd = new ndarray(this.eltType, dom);
2019+
const ref thisData = this.data;
2020+
ref expsData = exps.data;
2021+
ref outData = smxd.data;
2022+
2023+
forall i in dom.every() {
2024+
expsData[i] = Math.exp(thisData[i]);
2025+
}
2026+
var sums = exps.sum(axis).expand((...this.shape));
2027+
ref sumsData = sums.data;
2028+
2029+
forall i in dom.every() {
2030+
outData[i] = expsData[i]/sumsData[i];
2031+
}
2032+
2033+
return smxd;
2034+
}
2035+
20082036
proc type ndarray.batchNormTrain(
20092037
features: ndarray(?rank,?eltType),
20102038
weight: ndarray(1,eltType),

0 commit comments

Comments
 (0)