Skip to content

Commit 50d8bed

Browse files
authored
Fused attention for single query (#1497)
1 parent 9dd72cd commit 50d8bed

File tree

6 files changed

+304
-747
lines changed

6 files changed

+304
-747
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import argparse
2+
import math
3+
4+
import mlx.core as mx
5+
from time_utils import time_fn
6+
7+
L = 1024
8+
H = 32
9+
H_k = 32 // 4
10+
D = 128
11+
12+
13+
def attention(q, k, v):
14+
B, Hq, L, D = q.shape
15+
_, Hk, S, _ = k.shape
16+
q = q.reshape(B, Hk, Hq // Hk, L, D)
17+
k = k[:, :, None, :, :]
18+
v = v[:, :, None, :, :]
19+
s = q @ k.transpose(0, 1, 2, 4, 3)
20+
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
21+
o = p @ v
22+
return o.reshape(B, Hq, L, D)
23+
24+
25+
def sdpa(q, k, v):
26+
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
27+
28+
29+
def time_self_attention_primitives():
30+
mx.random.seed(3)
31+
q = mx.random.uniform(shape=(1, H, 1, D))
32+
k = mx.random.uniform(shape=(1, H_k, L, D))
33+
v = mx.random.uniform(shape=(1, H_k, L, D))
34+
mx.eval(q, k, v)
35+
time_fn(attention, q, k, v)
36+
37+
38+
def time_self_attention_sdpa():
39+
mx.random.seed(3)
40+
q = mx.random.uniform(shape=(1, H, 1, D))
41+
k = mx.random.uniform(shape=(1, H_k, L, D))
42+
v = mx.random.uniform(shape=(1, H_k, L, D))
43+
mx.eval(q, k, v)
44+
time_fn(sdpa, q, k, v)
45+
46+
47+
if __name__ == "__main__":
48+
time_self_attention_sdpa()
49+
time_self_attention_primitives()

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ build_kernel(layer_norm)
3030
build_kernel(random)
3131
build_kernel(rms_norm)
3232
build_kernel(rope)
33-
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
34-
steel/defines.h steel/gemm/transforms.h steel/utils.h)
33+
build_kernel(
34+
scaled_dot_product_attention scaled_dot_product_attention_params.h
35+
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
3536

3637
set(STEEL_HEADERS
3738
steel/defines.h

0 commit comments

Comments
 (0)