Skip to content

Commit 0408ba0

Browse files
thesuryashawni
andauthored
Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)
* Implementing Complex Matmul using Karatsuba Algorithm * Implemented Karatsuba's Algorithm for complex matmul and pre-commit them * fix --------- Co-authored-by: Awni Hannun <[email protected]>
1 parent cbad6c3 commit 0408ba0

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

mlx/ops.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2862,21 +2862,30 @@ array matmul(
28622862
<< " second input with shape " << b.shape() << ".";
28632863
throw std::invalid_argument(msg.str());
28642864
}
2865-
// Type promotion
2866-
auto out_type = promote_types(a.dtype(), b.dtype());
2867-
// Complex matmul in terms of real matmuls
2868-
if (out_type == complex64) {
2865+
2866+
// complex matmul using Karatsuba's Algorithm
2867+
if (a.dtype() == complex64 || b.dtype() == complex64) {
2868+
// Extract real and imaginary parts
28692869
auto a_real = real(a, s);
2870-
auto b_real = real(b, s);
28712870
auto a_imag = imag(a, s);
2871+
auto b_real = real(b, s);
28722872
auto b_imag = imag(b, s);
2873-
auto c_real =
2874-
subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
2875-
auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
2873+
2874+
// Compute real and imaginary components of the result
2875+
auto m1 = matmul(a_real, b_real, s);
2876+
auto m2 = matmul(a_imag, b_imag, s);
2877+
auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
2878+
2879+
auto c_real = subtract(m1, m2, s);
2880+
auto c_imag = subtract(m3, add(m1, m2, s), s);
2881+
28762882
return add(
28772883
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
28782884
}
28792885

2886+
// Type promotion
2887+
auto out_type = promote_types(a.dtype(), b.dtype());
2888+
28802889
if (!issubdtype(out_type, floating)) {
28812890
std::ostringstream msg;
28822891
msg << "[matmul] Only real floating point types are supported but "

python/tests/test_blas.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,20 +1210,20 @@ def rand(shape):
12101210
self.assertTrue(np.allclose(c, c_np))
12111211

12121212
# Test addmm
1213-
M = 16
1214-
K = 50
1215-
N = 32
1216-
1217-
def rand(shape):
1218-
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
1219-
12201213
a = rand((M, K))
12211214
b = rand((K, N))
12221215
c = rand((M, N))
12231216
out = mx.addmm(c, a, b, 2.0, 2.0)
12241217
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
12251218
self.assertTrue(np.allclose(out, out_np))
12261219

1220+
# complex with real
1221+
a = rand((M, K)).real
1222+
b = rand((K, N))
1223+
c = mx.matmul(a, b)
1224+
c_np = np.matmul(a, b)
1225+
self.assertTrue(np.allclose(out, out_np))
1226+
12271227

12281228
if __name__ == "__main__":
12291229
unittest.main()

0 commit comments

Comments
 (0)