Commit baf9fa5

awni, dc-dc-dc, and angeloskath
authored
Einsum (#1269)
* einsum initial
* fix comma break
* sum axis was wrong
* small cleanups
* python binding
* changed bindings to resemble numpy
* remove todo comment
* comment changes
* add count of operands/inputs
* fail fast if operands list is empty
* ignore comma if no output
* einsum path matching numpy
* getting somewhere with path
* remove print
* it passes the first test
* moved einsum tests to separate file
* separated einsum path
* moved einsum naive
* remove space from equation
* fast fail if no operands passed
* update tests and remove printf
* small cleanup
* some more cleanups
* removed python helper file
* ack
* utilize std for finding min in vector
* duplicate def
* remove the tuple as it was unreadable
* moved einsum_naive back to ops
* remaining isn't needed
* avoid creating another set
* cleanup
* greedy path, start of naive einsum
* more einsum
* fix some bugs
* some more fixes, tests pass
* benchmark
* some simplify
* fix einsum and test

Co-authored-by: Angelos Katharopoulos <[email protected]>

* add a bunch more tests and fix a bunch more bugs
* some docs nits

---------

Co-authored-by: dc-dc-dc <[email protected]>
Co-authored-by: Angelos Katharopoulos <[email protected]>
1 parent 7f91436 commit baf9fa5
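The commit log above mentions "einsum path matching numpy" and a greedy path strategy. As an illustrative sketch of what a contraction path looks like (shown with NumPy's `einsum_path`, the API that MLX's version mirrors), consider a three-matrix chain:

```python
# Illustration only: NumPy's einsum_path shows the contraction-path format
# that MLX's einsum_path follows. Not code from this commit.
import numpy as np

x = np.ones((32, 32))
y = np.ones((32, 32))
z = np.ones((32, 32))

# einsum_path returns (path, human-readable report). The path starts with
# the 'einsum_path' sentinel, followed by which operand pairs to contract
# at each step.
path, report = np.einsum_path("ij,jk,kl->il", x, y, z, optimize="greedy")
print(path[0])   # 'einsum_path'
print(path[1:])  # pairwise contraction order chosen by the greedy strategy
```

The path can then be fed back to `einsum` via the `optimize` argument to skip re-planning on repeated calls.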

File tree

13 files changed

+1498
-65
lines changed


ACKNOWLEDGMENTS.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.

benchmarks/python/einsum_bench.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import mlx.core as mx
+import numpy as np
+
+
+def timeit(fn, its=100, args=[]):
+    for _ in range(5):
+        fn(*args)
+    tic = time.perf_counter()
+    for _ in range(its):
+        fn(*args)
+    toc = time.perf_counter()
+    return 1e3 * (toc - tic) / its
+
+
+def time_little_einsum_path():
+    subscripts = "ik,kj->ij"
+    x = mx.ones((32, 32))
+    y = mx.ones((32, 32))
+    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
+
+    x = np.array(x)
+    y = np.array(y)
+    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
+    print("Timing little einsum path...")
+    print(f"MLX ... {mx_time:.3f} ms")
+    print(f"NumPy... {np_time:.3f} ms")
+
+
+def time_big_einsum_path():
+    chars = list("abcdefgh")
+    char_to_dim = {c: v for v, c in enumerate(chars)}
+
+    num_inputs = 10
+    inputs = []
+    subscripts = []
+    for _ in range(num_inputs):
+        subscript = np.random.choice(chars, size=5, replace=False).tolist()
+        subscripts.append("".join(subscript))
+        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
+    subscripts = ",".join(subscripts)
+
+    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
+
+    inputs = [mx.array(x) for x in inputs]
+    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
+    print("Timing big einsum path...")
+    print(f"MLX ... {mx_time:.3f} ms")
+    print(f"NumPy... {np_time:.3f} ms")
+
+
+def time_attention():
+    def regular_attention(x):
+        # shape [batch, sequence, num_heads, head_dim]
+        queries, keys, values = x, x, x
+        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
+        scores = mx.softmax(scores, axis=-1)
+        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
+        mx.eval(output)
+
+    def einsum_attention(x):
+        # shape [batch, sequence, num_heads, head_dim]
+        queries, keys, values = x, x, x
+        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
+        scores = mx.softmax(scores, axis=-1)
+        output = mx.einsum("ijtu,iujk->itjk", scores, values)
+        mx.eval(output)
+
+    x = mx.random.uniform(shape=(8, 512, 32, 128))
+
+    regular_time = timeit(regular_attention, args=(x,))
+    ein_time = timeit(einsum_attention, args=(x,))
+    print("Timing einsum attention...")
+    print(f"Regular ... {regular_time:.3f} ms")
+    print(f"Einsum ... {ein_time:.3f} ms")
+
+
+if __name__ == "__main__":
+    time_little_einsum_path()
+    time_big_einsum_path()
+    time_attention()
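As a sanity sketch (not part of the commit), the two attention formulations benchmarked above compute the same result: the einsum subscripts `"itjk,iujk->ijtu"` and `"ijtu,iujk->itjk"` express exactly the transpose/matmul/swapaxes sequence of the regular path. Verified here with NumPy on small shapes:

```python
# Equivalence check of the two attention formulations from the benchmark,
# using NumPy on small shapes so it runs anywhere. Indices: i=batch,
# t/u=sequence, j=heads, k=head_dim.
import numpy as np

def softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.default_rng(0).standard_normal((2, 6, 3, 4))
q = k = v = x  # [batch, sequence, num_heads, head_dim]

# Transpose/matmul path (mirrors regular_attention).
scores = q.transpose(0, 2, 1, 3) @ k.transpose(0, 2, 3, 1)
ref = softmax(scores, axis=-1) @ v.transpose(0, 2, 1, 3)
ref = ref.swapaxes(1, 2)

# Einsum path (mirrors einsum_attention).
scores_e = np.einsum("itjk,iujk->ijtu", q, k)
out = np.einsum("ijtu,iujk->itjk", softmax(scores_e, axis=-1), v)

print(np.allclose(ref, out))  # True
```

The einsum version avoids materializing the transposed copies, which is where the benchmark's speed difference comes from.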

docs/src/python/ops.rst

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ Operations
    diagonal
    divide
    divmod
+   einsum
+   einsum_path
    equal
    erf
    erfinv
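The two ops added to the docs follow NumPy's einsum semantics. A minimal usage sketch, written with NumPy so it is self-contained (the same subscripts work with `mx.einsum` on MLX arrays):

```python
# Minimal einsum usage sketch. mx.einsum mirrors np.einsum's subscript
# notation, so NumPy is used here for a self-contained illustration.
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# Matrix multiply: the repeated index k is contracted (summed) away.
mm = np.einsum("ik,kj->ij", a, b)
print(np.array_equal(mm, a @ b))  # True

# Trace: a repeated index on one operand with empty output sums the diagonal.
tr = np.einsum("ii->", np.eye(3))
print(tr)  # 3.0
```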

mlx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
