Skip to content

Commit 36369ec

Browse files
committed
Update softmax to apply along an axis
1 parent 0684044 commit 36369ec

File tree

1 file changed

+13
-38
lines changed

1 file changed

+13
-38
lines changed

lib/NDArray.chpl

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2019,40 +2019,12 @@ proc type ndarray.multiheadAttention(
20192019
var q = ndarray.matmul(features,q_weight);
20202020
var k = ndarray.matmul(features,k_weight);
20212021
var v = ndarray.matmul(features,v_weight);
2022-
var z = (ndarray.matmul(q,k.permute(0,2,1))/Math.sqrt(head_dim))._softmax(axis=2);
2022+
var z = (ndarray.matmul(q,k.permute(0,2,1))/Math.sqrt(head_dim)).softmax(axis=2);
20232023
var a = ndarray.matmul(z,v);
20242024

20252025
return a;
20262026
}
20272027

2028-
/* Softmax, but with the option to specify the axis.
2029-
2030-
:returns: For a tensor ``t``, :math:`\frac{\exp(t)}{\sum \exp(t)}`, with the sum taken along ``axis``.
2031-
:rtype: ndarray(rank, eltType)
2032-
*/
2033-
proc ndarray._softmax(axis:int = (this.rank-1)): ndarray(this.rank, this.eltType)
2034-
// where isSubtype(this.eltType, real)
2035-
{
2036-
const dom = this.domain;
2037-
var exps = new ndarray(this.eltType, dom);
2038-
var smxd = new ndarray(this.eltType, dom);
2039-
const ref thisData = this.data;
2040-
ref expsData = exps.data;
2041-
ref outData = smxd.data;
2042-
2043-
forall i in dom.every() {
2044-
expsData[i] = Math.exp(thisData[i]);
2045-
}
2046-
var sums = exps.sum(axis).expand((...this.shape));
2047-
ref sumsData = sums.data;
2048-
2049-
forall i in dom.every() {
2050-
outData[i] = expsData[i]/sumsData[i];
2051-
}
2052-
2053-
return smxd;
2054-
}
2055-
20562028
proc type ndarray.batchNormTrain(
20572029
features: ndarray(?rank,?eltType),
20582030
weight: ndarray(1,eltType),
@@ -2603,24 +2575,27 @@ proc type ndarray.einsum(param subscripts: string,a: ndarray(?rankA,?eltType), b
26032575
:returns: For a tensor ``t``, :math:`\frac{\exp(t)}{\sum \exp(t)}`, with the sum taken along ``axis``.
26042576
:rtype: ndarray(rank, eltType)
26052577
*/
2606-
proc ndarray.softmax(): ndarray(this.rank, this.eltType)
2607-
where isSubtype(this.eltType, real)
2578+
proc ndarray.softmax(axis:int = (this.rank-1)): ndarray(this.rank, this.eltType)
2579+
// where isSubtype(this.eltType, real)
26082580
{
26092581
const dom = this.domain;
2582+
var exps = new ndarray(this.eltType, dom);
2583+
var outs = new ndarray(this.eltType, dom);
26102584
const ref thisData = this.data;
2585+
ref expsData = exps.data;
2586+
ref outsData = outs.data;
26112587

2612-
var denom: this.eltType = 0.0;
2613-
forall i in dom.every() with (+ reduce denom) {
2614-
denom += Math.exp(thisData[i]);
2588+
forall i in dom.every() {
2589+
expsData[i] = Math.exp(thisData[i]);
26152590
}
2591+
var sums = exps.sum(axis).expand((...this.shape));
2592+
ref sumsData = sums.data;
26162593

2617-
var softmaxxed = new ndarray(this.eltType, dom);
2618-
ref softmaxData = softmaxxed.data;
26192594
forall i in dom.every() {
2620-
softmaxData[i] = Math.exp(thisData[i]) / denom;
2595+
outsData[i] = expsData[i]/sumsData[i];
26212596
}
26222597

2623-
return softmaxxed;
2598+
return outs;
26242599
}
26252600

26262601

0 commit comments

Comments
 (0)