@@ -1925,6 +1925,63 @@ proc type ndarray.matvecmul(mat: ndarray(2,?eltType),vec: ndarray(1,eltType)): n
19251925 return u;
19261926}
19271927
1928+ proc type ndarray.matmul(a: ndarray(?aRank,?eltType),b: ndarray(?bRank,eltType)): ndarray(?) where (aRank >= 2 && bRank >= 2 ) {
1929+ assert(a.shape[ aRank- 1 ] == b.shape[ bRank- 2 ] , " Invalid dimensions for matrix multiplication" );
1930+ const outRank = if aRank >= bRank then aRank else bRank;
1931+ var mat1 = if aRank >= bRank then new ndarray(a.data) else new ndarray(b.data);
1932+ var preMat2 = if aRank < bRank then new ndarray(a.data) else new ndarray(b.data);
1933+ const first = if aRank >= bRank then 1 else 2 ;
1934+
1935+ var reshapeDims: mat1.rank* int ;
1936+ var expandDims: mat1.rank* int ;
1937+ const rankDiff = mat1.rank - preMat2.rank;
1938+ for i in 0 ..< mat1.rank {
1939+ if i < rankDiff {
1940+ reshapeDims[ i] = 1 ;
1941+ expandDims[ i] = mat1.shape[ i] ;
1942+ } else {
1943+ reshapeDims[ i] = preMat2.shape[ i- rankDiff] ;
1944+ expandDims[ i] = preMat2.shape[ i- rankDiff] ;
1945+ }
1946+ }
1947+ var mat2 = preMat2.reshape((...reshapeDims)).expand((...expandDims));
1948+
1949+ var outShape: mat1.rank* int ;
1950+ for i in 0 ..< mat1.rank {
1951+ if i == mat1.rank- 1 {
1952+ outShape[ i] = mat2.shape[ mat2.rank - 1 ] ;
1953+ } else {
1954+ outShape[ i] = mat1.shape[ i] ;
1955+ }
1956+ }
1957+ var outDom = util.domainFromShape((...outShape));
1958+ var prod = new ndarray(outDom,eltType);
1959+
1960+ ref m1 = mat1.data;
1961+ ref m2 = mat2.data;
1962+ ref pd = prod.data;
1963+
1964+ for idx in outDom.every() {
1965+ const i = idx[ outRank - 2 ] ;
1966+ const j = idx[ outRank - 1 ] ;
1967+ var sum: eltType = 0 ;
1968+ for n in 0 ..< a.shape[ aRank- 1 ] {
1969+ var idxLM = idx;
1970+ idxLM[ outRank- 2 ] = i;
1971+ idxLM[ outRank- 1 ] = n;
1972+ var idxRM = idx;
1973+ idxRM[ outRank- 2 ] = n;
1974+ if first == 1 {
1975+ sum += m1[ (...idxLM)] * m2[ (...idxRM)] ;
1976+ } else {
1977+ sum += m2[ (...idxLM)] * m1[ (...idxRM)] ;
1978+ }
1979+ }
1980+ pd[ idx] = sum;
1981+ }
1982+ return prod;
1983+ }
1984+
19281985proc type ndarray.batchNormTrain(
19291986 features: ndarray(?rank,?eltType),
19301987 weight: ndarray(1 ,eltType),
0 commit comments