Skip to content

Commit 9b5e2a4

Browse files
committed
Add ndarray.matmul
1 parent 8300a45 commit 9b5e2a4

1 file changed

Lines changed: 57 additions & 0 deletions

File tree

lib/NDArray.chpl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,63 @@ proc type ndarray.matvecmul(mat: ndarray(2,?eltType),vec: ndarray(1,eltType)): n
19251925
return u;
19261926
}
19271927

1928+
proc type ndarray.matmul(a: ndarray(?aRank,?eltType),b: ndarray(?bRank,eltType)): ndarray(?) where (aRank >= 2 && bRank >= 2) {
1929+
assert(a.shape[aRank-1] == b.shape[bRank-2], "Invalid dimensions for matrix multiplication");
1930+
const outRank = if aRank >= bRank then aRank else bRank;
1931+
var mat1 = if aRank >= bRank then new ndarray(a.data) else new ndarray(b.data);
1932+
var preMat2 = if aRank < bRank then new ndarray(a.data) else new ndarray(b.data);
1933+
const first = if aRank >= bRank then 1 else 2;
1934+
1935+
var reshapeDims: mat1.rank*int;
1936+
var expandDims: mat1.rank*int;
1937+
const rankDiff = mat1.rank - preMat2.rank;
1938+
for i in 0..<mat1.rank {
1939+
if i < rankDiff {
1940+
reshapeDims[i] = 1;
1941+
expandDims[i] = mat1.shape[i];
1942+
} else {
1943+
reshapeDims[i] = preMat2.shape[i-rankDiff];
1944+
expandDims[i] = preMat2.shape[i-rankDiff];
1945+
}
1946+
}
1947+
var mat2 = preMat2.reshape((...reshapeDims)).expand((...expandDims));
1948+
1949+
var outShape: mat1.rank*int;
1950+
for i in 0..<mat1.rank {
1951+
if i == mat1.rank-1 {
1952+
outShape[i] = mat2.shape[mat2.rank - 1];
1953+
} else {
1954+
outShape[i] = mat1.shape[i];
1955+
}
1956+
}
1957+
var outDom = util.domainFromShape((...outShape));
1958+
var prod = new ndarray(outDom,eltType);
1959+
1960+
ref m1 = mat1.data;
1961+
ref m2 = mat2.data;
1962+
ref pd = prod.data;
1963+
1964+
for idx in outDom.every() {
1965+
const i = idx[outRank - 2];
1966+
const j = idx[outRank - 1];
1967+
var sum: eltType = 0;
1968+
for n in 0..<a.shape[aRank-1] {
1969+
var idxLM = idx;
1970+
idxLM[outRank-2] = i;
1971+
idxLM[outRank-1] = n;
1972+
var idxRM = idx;
1973+
idxRM[outRank-2] = n;
1974+
if first == 1 {
1975+
sum += m1[(...idxLM)]*m2[(...idxRM)];
1976+
} else {
1977+
sum += m2[(...idxLM)]*m1[(...idxRM)];
1978+
}
1979+
}
1980+
pd[idx] = sum;
1981+
}
1982+
return prod;
1983+
}
1984+
19281985
proc type ndarray.batchNormTrain(
19291986
features: ndarray(?rank,?eltType),
19301987
weight: ndarray(1,eltType),

0 commit comments

Comments
 (0)