+# Copyright © 2024 Apple Inc.
+
 import argparse
 import math
+import subprocess
+import time

 import mlx.core as mx
-from time_utils import time_fn
+import numpy as np

-MAX_SEQ = 300
-START_SEQ = 100
-SEQ_INCREMENT = 50
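+# Record which Apple silicon chip we are running on (queried via sysctl), so the
+# results printed below can be attributed to a device.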
+device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
+device_name = device_name.decode("utf-8").strip("\n")

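+# Benchmark configuration: untimed warmup runs, timed repetitions, and chained
+# attention calls inside each timed function invocation.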
+N_warmup = 5
+N_iter_bench = 40
+N_iter_func = 8

-def time_self_attention_primitives():
-    mx.random.seed(3)
-    B = 2
-    H = 38
-    D = 64
-    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
-        q = mx.random.uniform(shape=(B, H, R, D))
-        k = mx.random.uniform(shape=(B, H, R, D))
-        v = mx.random.uniform(shape=(B, H, R, D))
-        scale = 1.0 / math.sqrt(float(D))
-        mx.eval(q, k, v)

-        def sdpa_primitives(qs, ks, vs, alpha):
-            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
-            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-            o = p @ vs
-            return o
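+# Time f(*args): run N_warmup untimed calls, then N_iter_bench timed calls.
+# Returns the total elapsed seconds over the timed calls, not time per call.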
+def bench(f, *args):
+    for i in range(N_warmup):
+        f(*args)

-        time_fn(sdpa_primitives, q, k, v, scale)
+    s = time.perf_counter_ns()
+    for i in range(N_iter_bench):
+        f(*args)
+    e = time.perf_counter_ns()
+    return (e - s) * 1e-9


-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    B = 2
-    H = 38
-    D = 64
-    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
-        q = mx.random.uniform(shape=(B, H, R, D))
-        k = mx.random.uniform(shape=(B, H, R, D))
-        v = mx.random.uniform(shape=(B, H, R, D))
-        scale = 1.0 / math.sqrt(float(D))
-        mx.eval(q, k, v)
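+# Fused path: a single call into the mx.fast scaled-dot-product-attention kernel.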
+def mlx_sdpa_fused_inner(q, k, v, scale):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

-        def sdpa_fused(qs, ks, vs, alpha):
-            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
-            return o

-        time_fn(sdpa_fused, q, k, v, scale)
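+# Unfused reference path built from primitives: scale Q, Q @ K^T, softmax, @ V.
+# For grouped-query attention (n_q_heads > n_kv_heads), query heads are reshaped
+# into n_repeats groups so each K/V head broadcasts across its group.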
+def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
+    q_dtype = q.dtype
+    q = q * mx.array(scale, q_dtype)
+    n_q_heads = q.shape[-3]
+    n_kv_heads = k.shape[-3]
+    n_repeats = n_q_heads // n_kv_heads

+    B = q.shape[0]
+    L = q.shape[2]

-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("MLX benchmarks.")
-    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
-    args = parser.parse_args()
-    if args.gpu:
-        mx.set_default_device(mx.gpu)
+    if n_repeats > 1:
+        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
+        k = mx.expand_dims(k, 2)
+        v = mx.expand_dims(v, 2)
+
+    scores = q @ mx.swapaxes(k, -1, -2)
+    if f32softmax:
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
     else:
-        mx.set_default_device(mx.cpu)
+        scores = mx.softmax(scores, axis=-1)
+
+    out = scores @ v
+    if n_repeats > 1:
+        out = mx.reshape(out, [B, n_q_heads, L, -1])
+
+    return out
+
+
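+# Timed wrappers: chain N_iter_func attention calls and evaluate once at the end.
+# With transpose=True, inputs arrive as (B, L, H, D); K/V are converted to
+# (B, H, L, D) once up front, while the Q/output conversion is paid inside the loop.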
+def mlx_sdpa_unfused(q, k, v, scale, transpose):
+    q_out = q
+    if transpose:
+        k = mx.transpose(k, (0, 2, 1, 3))
+        v = mx.transpose(v, (0, 2, 1, 3))
+
+    for i in range(N_iter_func):
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+
+    mx.eval(q_out)
+    return q_out
+
+
+def mlx_sdpa_fused(q, k, v, scale, transpose):
+    q_out = q
+    if transpose:
+        k = mx.transpose(k, (0, 2, 1, 3))
+        v = mx.transpose(v, (0, 2, 1, 3))
+
+    for i in range(N_iter_func):
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
+        if transpose:
+            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+
+    mx.eval(q_out)
+    return q_out
+
+
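+# Benchmark one shape: draw random Q/K/V, time the unfused and fused paths, then
+# cross-check the fused output against the unfused f32-softmax result.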
+def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
+    shape_q = (
+        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
+    )
+    shape_kv = (
+        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
+    )
+
+    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
+
+    scale = math.sqrt(1.0 / head_dim)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    time_mlx_unfused = bench(mlx_sdpa_unfused, q_mx, k_mx, v_mx, scale, transpose)
+    time_mlx_fused = bench(mlx_sdpa_fused, q_mx, k_mx, v_mx, scale, transpose)
+
+    if transpose:
+        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
+        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
+        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
+
+    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
+    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
+
+    atol = 1e-5 if np_dtype == np.float32 else 1e-4
+
+    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
+        print(
+            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, "
+            f"n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with "
+            f"max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
+        )
+
+    return time_mlx_fused, time_mlx_unfused
+
+
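+# GEMM-style flop count (2*M*N*K per multiply) accumulated over every timed call;
+# currently not referenced by the reporting loop below.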
+def get_gflop_count(B, M, N, K):
+    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1e9)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run SDPA benchmarks")
+    parser.parse_args()
+
+    # Only float16 is benchmarked by default; drop the [:1] slice to add float32.
+    dtypes = ("float16", "float32")[:1]
+    transposes = (False,)
+
+    # fmt: off
+    shapes_64 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (   1,   32,   32,       64,   32,    32),
+        (   1,   64,   64,       64,   32,    32),
+        (   1,  128,  128,       64,   32,    32),
+        (   1,  256,  256,       64,   32,    32),
+        (   1,  512,  512,       64,   32,    32),
+        (   1, 1024, 1024,       64,   32,    32),
+        (   1, 2048, 2048,       64,   32,    32),
+        (   1, 4096, 4096,       64,   32,    32),
+    )
+
+    shapes_80 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (   1, 1024, 1024,       80,   32,    32),
+        (   1, 2048, 2048,       80,   32,    32),
+        (   1, 4096, 4096,       80,   32,    32),
+    )
+
+    shapes_128 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (   1, 1024, 1024,      128,   32,    32),
+        (   1, 2048, 2048,      128,   32,    32),
+        (   1, 4096, 4096,      128,   32,    32),
+    )
+    # fmt: on
+
+    shapes = shapes_64 + shapes_80 + shapes_128
+
+    print(f"Device: {device_name}")
+    print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")

-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    for dtype in dtypes:
+        for transpose in transposes:
+            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
+                np_dtype = getattr(np, dtype)
+                time_mlx_fused, time_mlx_unfused = bench_shape(
+                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
+                )
+                diff = time_mlx_unfused / time_mlx_fused - 1.0
+                t_str = 1 if transpose else 0
+                print(
+                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, "
+                    f"{n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, "
+                    f"{time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
+                )