@@ -1982,6 +1982,29 @@ proc type ndarray.matmul(a: ndarray(?aRank,?eltType),b: ndarray(?bRank,eltType))
19821982 return prod;
19831983}
19841984
// Single-head scaled dot-product self-attention.
//
// Projects `features` with the query, key, and value weight matrices,
// forms the attention scores `q * k^T / sqrt(head_dim)`, normalizes them
// with a softmax over axis 2, and returns the weighted combination of the
// value projection.
//
// Arguments:
//   features  - rank-3 input tensor.
//               NOTE(review): `k.permute(0, 2, 1)` leaves axis 0 in place,
//               so the layout is presumably (batch, seq, embed) — confirm
//               that callers do not pass (seq, batch, embed).
//   q_weight  - rank-2 projection applied to `features` to obtain queries.
//   k_weight  - rank-2 projection applied to `features` to obtain keys.
//   v_weight  - rank-2 projection applied to `features` to obtain values.
//   num_heads - currently ignored; only one head is supported, so the
//               result is single-head attention regardless of this value.
//   embed_dim - embedding width; used only to derive the score scale.
//
// Returns: a rank-3 tensor holding the attention output.
proc type ndarray.multiheadAttention(
    features: ndarray(3, ?eltType),
    q_weight: ndarray(2, eltType),
    k_weight: ndarray(2, eltType),
    v_weight: ndarray(2, eltType),
    num_heads: int,
    embed_dim: int
): ndarray(3, eltType) {
    // With a single head, each head spans the full embedding width.
    // (Multi-head support would make this embed_dim / num_heads.)
    const head_dim = embed_dim;

    // Linear projections of the input into query, key, and value spaces.
    var q = ndarray.matmul(features, q_weight);
    var k = ndarray.matmul(features, k_weight);
    var v = ndarray.matmul(features, v_weight);

    // Raw scores: queries against keys with the last two key axes
    // reordered, scaled by sqrt(head_dim).
    var scores = ndarray.matmul(q, k.permute(0, 2, 1)) / Math.sqrt(head_dim);

    // Normalize the scores along axis 2 to obtain the attention weights.
    var weights = scores._softmax(axis=2);

    // Combine the value projection using the attention weights.
    var attended = ndarray.matmul(weights, v);

    return attended;
}

2007+
19852008proc type ndarray.batchNormTrain(
19862009 features: ndarray(?rank,?eltType),
19872010 weight: ndarray(1 ,eltType),
0 commit comments