Commit 5bba7f2

add sdpa_vjp_bench

1 parent 9008d2e · commit 5bba7f2

1 file changed: 271 additions, 0 deletions
@@ -0,0 +1,271 @@
# Copyright © 2024-25 Apple Inc.
"""
Benchmark SDPA VJP: Fused Flash Attention vs Unfused Fallback

This benchmark measures the performance improvement from the fused VJP
implementation of the scaled dot product attention backward pass.
"""
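# Example invocations (illustrative; assumes this file is saved as
# sdpa_vjp_bench.py, a name not confirmed by this diff):
#   python sdpa_vjp_bench.py --mode vjp --verify
#   python sdpa_vjp_bench.py --mode all --dtype bfloat16 --quick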
import argparse
import time
import mlx.core as mx

N_warmup = 10
N_iter = 50

def bench(f, *args):
    """Warm up, then time the function; returns the average time per call in ms."""
    for _ in range(N_warmup):
        result = f(*args)
        mx.eval(result)

    # Make sure warmup work has finished before timing starts.
    mx.synchronize()
    start = time.perf_counter()
    for _ in range(N_iter):
        result = f(*args)
        mx.eval(result)
    mx.synchronize()
    return (time.perf_counter() - start) / N_iter * 1000  # ms

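# bench can also be used directly on any MLX op, e.g. (hypothetical sizes):
#   t_ms = bench(mx.matmul, mx.random.normal((512, 512)), mx.random.normal((512, 512)))
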
def mlx_ref_attn(q, k, v, scale):
    """Reference unfused attention implementation"""
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]

    if n_repeats > 1:
        # Grouped-query attention: view q as (B, n_kv_heads, n_repeats, L, D) and
        # broadcast k/v across the repeat axis instead of tiling them.
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    weights = mx.softmax(scores, axis=-1)
    out = weights @ v

    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out

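# Shape note (illustrative numbers, not from the original file): with B=1, H_q=32,
# H_kv=8, L=16, D=64, the GQA branch above views q as (1, 8, 4, 16, 64) and expands
# k/v to (1, 8, 1, 16, 64); broadcasting in the matmuls then shares each k/v head
# across its 4 query heads without materializing copies.
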
def run_forward_benchmark(B, H_q, H_kv, L, D, dtype=mx.float16):
    """Benchmark forward pass only"""
    scale = D**-0.5

    q = mx.random.normal((B, H_q, L, D), dtype=dtype)
    k = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    v = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    mx.eval(q, k, v)

    def unfused_fwd():
        return mlx_ref_attn(q, k, v, scale)

    def fused_fwd():
        return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

    t_unfused = bench(unfused_fwd)
    t_fused = bench(fused_fwd)

    return t_unfused, t_fused

def run_vjp_benchmark(B, H_q, H_kv, L, D, dtype=mx.float16):
    """Benchmark forward + backward (VJP) pass"""
    scale = D**-0.5

    q = mx.random.normal((B, H_q, L, D), dtype=dtype)
    k = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    v = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    mx.eval(q, k, v)

    # Unfused forward+backward
    def unfused_fwd_bwd():
        def loss(q, k, v):
            return mlx_ref_attn(q, k, v, scale).sum()

        # mx.grad differentiates w.r.t. the first positional argument (q) by default.
        return mx.grad(loss)(q, k, v)

    # Fused forward+backward
    def fused_fwd_bwd():
        def loss(q, k, v):
            return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale).sum()

        return mx.grad(loss)(q, k, v)

    t_unfused = bench(unfused_fwd_bwd)
    t_fused = bench(fused_fwd_bwd)

    return t_unfused, t_fused

def run_backward_only_benchmark(B, H_q, H_kv, L, D, dtype=mx.float16):
    """Benchmark backward pass only (isolate VJP performance)"""
    scale = D**-0.5

    q = mx.random.normal((B, H_q, L, D), dtype=dtype)
    k = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    v = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    cotan = mx.ones((B, H_q, L, D), dtype=dtype)
    mx.eval(q, k, v, cotan)

    # Unfused backward: mx.vjp returns (outputs, vjps); keep only the gradients.
    def unfused_bwd():
        _, grads = mx.vjp(lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan])
        return grads

    # Fused backward
    def fused_bwd():
        _, grads = mx.vjp(
            lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale),
            [q, k, v],
            [cotan],
        )
        return grads

    t_unfused = bench(unfused_bwd)
    t_fused = bench(fused_bwd)

    return t_unfused, t_fused

def verify_correctness(B, H_q, H_kv, L, D, dtype=mx.float16):
    """Verify that fused and unfused produce matching gradients"""
    scale = D**-0.5

    q = mx.random.normal((B, H_q, L, D), dtype=dtype)
    k = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    v = mx.random.normal((B, H_kv, L, D), dtype=dtype)
    cotan = mx.ones((B, H_q, L, D), dtype=dtype)

    _, ref_grads = mx.vjp(lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan])
    _, fused_grads = mx.vjp(
        lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale),
        [q, k, v],
        [cotan],
    )

    # Looser tolerances for the half-precision dtypes.
    rtol, atol = (1e-2, 1e-2) if dtype != mx.float32 else (1e-4, 1e-4)
    all_match = True
    for i, (r, f) in enumerate(zip(ref_grads, fused_grads)):
        if not mx.allclose(r, f, rtol=rtol, atol=atol):
            max_diff = mx.max(mx.abs(r - f)).item()
            print(f" WARNING: Gradient {['dQ', 'dK', 'dV'][i]} mismatch, max_diff={max_diff:.2e}")
            all_match = False

    return all_match

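# Standalone check (illustrative sizes): verify_correctness(1, 8, 8, 64, 64, mx.float32)
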
def main():
    parser = argparse.ArgumentParser(description="Benchmark SDPA VJP performance")
    parser.add_argument(
        "--mode",
        choices=["vjp", "forward", "backward", "all"],
        default="vjp",
        help="Benchmark mode: vjp (fwd+bwd), forward only, backward only, or all",
    )
    parser.add_argument("--verify", action="store_true", help="Verify correctness before benchmarking")
    parser.add_argument("--dtype", choices=["float16", "bfloat16", "float32"], default="float16")
    parser.add_argument("--quick", action="store_true", help="Run quick subset of benchmarks")
    args = parser.parse_args()

    dtype = getattr(mx, args.dtype)
    dtype_str = args.dtype[:4] if len(args.dtype) > 4 else args.dtype

    # Configurations to benchmark
    # (B, H_q, H_kv, L, D)
    if args.quick:
        configs = [
            # Vector path (L <= 8)
            (2, 8, 8, 1, 64),
            (2, 8, 8, 8, 128),
            # STEEL path (L > 8)
            (2, 8, 8, 128, 64),
            (2, 8, 8, 512, 128),
            (1, 32, 8, 1024, 128),
        ]
    else:
        configs = [
            # Vector path (L <= 8) - short sequences
            (2, 8, 8, 1, 64),
            (2, 8, 8, 4, 64),
            (2, 8, 8, 8, 64),
            (2, 8, 8, 8, 128),
            # STEEL path - medium sequences
            (2, 8, 8, 32, 64),
            (2, 8, 8, 64, 64),
            (2, 8, 8, 128, 64),
            (2, 8, 8, 128, 128),
            (2, 8, 8, 256, 128),
            # STEEL path - long sequences
            (1, 32, 8, 512, 64),
            (1, 32, 8, 512, 128),
            (1, 32, 8, 1024, 64),
            (1, 32, 8, 1024, 128),
            (1, 32, 8, 2048, 128),
            # GQA configurations
            (2, 32, 8, 256, 64),  # 4:1 GQA
            (2, 32, 4, 256, 64),  # 8:1 GQA
        ]

    print(f"SDPA VJP Benchmark - dtype={args.dtype}")
    print("=" * 85)

    if args.mode in ["vjp", "all"]:
        print("\n[Forward + Backward (VJP)]")
        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
        print("-" * 85)

        for B, H_q, H_kv, L, D in configs:
            if args.verify:
                correct = verify_correctness(B, H_q, H_kv, L, D, dtype)
                if not correct:
                    continue

            t_unfused, t_fused = run_vjp_benchmark(B, H_q, H_kv, L, D, dtype)
            speedup = t_unfused / t_fused
            path = "vector" if L <= 8 else "STEEL"
            print(
                f"{B:3d} {H_q:4d} {H_kv:5d} {L:6d} {D:4d} | {t_unfused:9.2f}ms {t_fused:9.2f}ms {speedup:7.2f}x {path:>8}"
            )

    if args.mode in ["forward", "all"]:
        print("\n[Forward Only]")
        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
        print("-" * 85)

        for B, H_q, H_kv, L, D in configs:
            t_unfused, t_fused = run_forward_benchmark(B, H_q, H_kv, L, D, dtype)
            speedup = t_unfused / t_fused
            path = "vector" if L <= 8 else "STEEL"
            print(
                f"{B:3d} {H_q:4d} {H_kv:5d} {L:6d} {D:4d} | {t_unfused:9.2f}ms {t_fused:9.2f}ms {speedup:7.2f}x {path:>8}"
            )

    if args.mode in ["backward", "all"]:
        print("\n[Backward Only]")
        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
        print("-" * 85)

        for B, H_q, H_kv, L, D in configs:
            t_unfused, t_fused = run_backward_only_benchmark(B, H_q, H_kv, L, D, dtype)
            speedup = t_unfused / t_fused
            path = "vector" if L <= 8 else "STEEL"
            print(
                f"{B:3d} {H_q:4d} {H_kv:5d} {L:6d} {D:4d} | {t_unfused:9.2f}ms {t_fused:9.2f}ms {speedup:7.2f}x {path:>8}"
            )

    print("\n" + "=" * 85)
    print("Legend:")
    print(" - unfused: Reference implementation using separate matmul + softmax + matmul")
    print(" - fused: mx.fast.scaled_dot_product_attention with Flash Attention VJP")
    print(" - path: 'vector' for L<=8 (vector kernel), 'STEEL' for L>8 (tiled kernel)")
    print(" - speedup > 1.0 means fused is faster")


if __name__ == "__main__":
    main()
