Commit 21caa99

Brooooooklyn and claude committed
style: apply clang-format and black formatting

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

1 parent 105d051 · commit 21caa99

8 files changed, +187 -139 lines

benchmarks/python/sdpa_vjp_bench.py (33 additions, 12 deletions)

@@ -8,6 +8,7 @@
 
 import argparse
 import time
+
 import mlx.core as mx
 
 N_warmup = 10
@@ -115,7 +116,9 @@ def run_backward_only_benchmark(B, H_q, H_kv, L, D, dtype=mx.float16):
 
     # Unfused backward
     def unfused_bwd():
-        _, grads = mx.vjp(lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan])
+        _, grads = mx.vjp(
+            lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan]
+        )
         return grads
 
     # Fused backward
@@ -142,7 +145,9 @@ def verify_correctness(B, H_q, H_kv, L, D, dtype=mx.float16):
     v = mx.random.normal((B, H_kv, L, D), dtype=dtype)
     cotan = mx.ones((B, H_q, L, D), dtype=dtype)
 
-    _, ref_grads = mx.vjp(lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan])
+    _, ref_grads = mx.vjp(
+        lambda q, k, v: mlx_ref_attn(q, k, v, scale), [q, k, v], [cotan]
+    )
     _, fused_grads = mx.vjp(
         lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale),
         [q, k, v],
@@ -154,7 +159,9 @@ def verify_correctness(B, H_q, H_kv, L, D, dtype=mx.float16):
     for i, (r, f) in enumerate(zip(ref_grads, fused_grads)):
         if not mx.allclose(r, f, rtol=rtol, atol=atol):
             max_diff = mx.max(mx.abs(r - f)).item()
-            print(f" WARNING: Gradient {['dQ', 'dK', 'dV'][i]} mismatch, max_diff={max_diff:.2e}")
+            print(
+                f" WARNING: Gradient {['dQ', 'dK', 'dV'][i]} mismatch, max_diff={max_diff:.2e}"
+            )
             all_match = False
 
     return all_match
@@ -168,9 +175,15 @@ def main():
         default="vjp",
         help="Benchmark mode: vjp (fwd+bwd), forward only, backward only, or all",
     )
-    parser.add_argument("--verify", action="store_true", help="Verify correctness before benchmarking")
-    parser.add_argument("--dtype", choices=["float16", "bfloat16", "float32"], default="float16")
-    parser.add_argument("--quick", action="store_true", help="Run quick subset of benchmarks")
+    parser.add_argument(
+        "--verify", action="store_true", help="Verify correctness before benchmarking"
+    )
+    parser.add_argument(
+        "--dtype", choices=["float16", "bfloat16", "float32"], default="float16"
+    )
+    parser.add_argument(
+        "--quick", action="store_true", help="Run quick subset of benchmarks"
+    )
     args = parser.parse_args()
 
     dtype = getattr(mx, args.dtype)
@@ -208,16 +221,18 @@ def main():
         (1, 32, 8, 1024, 128),
         (1, 32, 8, 2048, 128),
         # GQA configurations
-        (2, 32, 8, 256, 64), # 4:1 GQA
-        (2, 32, 4, 256, 64), # 8:1 GQA
+        (2, 32, 8, 256, 64),  # 4:1 GQA
+        (2, 32, 4, 256, 64),  # 8:1 GQA
     ]
 
     print(f"SDPA VJP Benchmark - dtype={args.dtype}")
     print("=" * 85)
 
     if args.mode in ["vjp", "all"]:
         print("\n[Forward + Backward (VJP)]")
-        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
+        print(
+            f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}"
+        )
         print("-" * 85)
 
         for B, H_q, H_kv, L, D in configs:
@@ -235,7 +250,9 @@ def main():
 
     if args.mode in ["forward", "all"]:
         print("\n[Forward Only]")
-        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
+        print(
+            f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}"
+        )
         print("-" * 85)
 
         for B, H_q, H_kv, L, D in configs:
@@ -248,7 +265,9 @@ def main():
 
     if args.mode in ["backward", "all"]:
         print("\n[Backward Only]")
-        print(f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}")
+        print(
+            f"{'B':>3} {'H_q':>4} {'H_kv':>5} {'L':>6} {'D':>4} | {'unfused':>10} {'fused':>10} {'speedup':>8} {'path':>8}"
+        )
         print("-" * 85)
 
         for B, H_q, H_kv, L, D in configs:
@@ -261,7 +280,9 @@ def main():
 
     print("\n" + "=" * 85)
    print("Legend:")
-    print(" - unfused: Reference implementation using separate matmul + softmax + matmul")
+    print(
+        " - unfused: Reference implementation using separate matmul + softmax + matmul"
+    )
     print(" - fused: mx.fast.scaled_dot_product_attention with Flash Attention VJP")
     print(" - path: 'vector' for L<=8 (vector kernel), 'STEEL' for L>8 (tiled kernel)")
     print(" - speedup > 1.0 means fused is faster")
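For reference, the verification pattern this benchmark uses can be reproduced standalone. The sketch below is illustrative and not part of the commit: ref_attn is a hypothetical stand-in for the benchmark's mlx_ref_attn (which also handles GQA), while mx.vjp, mx.fast.scaled_dot_product_attention, and mx.allclose are the same MLX calls that appear in the diff above.

import mlx.core as mx


def ref_attn(q, k, v, scale):
    # Unfused path: separate matmul + softmax + matmul, all differentiable mx ops.
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    return mx.softmax(scores, axis=-1) @ v


B, H, L, D = 1, 8, 256, 64
scale = D**-0.5
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
cotan = mx.ones((B, H, L, D))

# Pull gradients through both paths with the same cotangent and compare.
_, ref_grads = mx.vjp(lambda q, k, v: ref_attn(q, k, v, scale), [q, k, v], [cotan])
_, fused_grads = mx.vjp(
    lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale),
    [q, k, v],
    [cotan],
)
for name, r, f in zip(["dQ", "dK", "dV"], ref_grads, fused_grads):
    print(name, mx.allclose(r, f, rtol=1e-2, atol=1e-2).item())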

mlx/backend/metal/kernels/sdpa_vector.h (2 additions, 1 deletion)

@@ -93,7 +93,8 @@ template <typename T, int D, int V = D>
   U sum_exp_score = 0;
   if (has_sinks && simd_gid == 0) {
     // Scale sink by M_LOG2E_F to match log2 domain
-    max_score = static_cast<U>(M_LOG2E_F) * static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
+    max_score = static_cast<U>(M_LOG2E_F) *
+        static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
     sum_exp_score = 1;
   }
 
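The M_LOG2E_F scaling touched in this hunk follows the kernel's log2-domain convention: exp(x) equals 2 raised to x * log2(e), so sink and mask values added to scores that will later go through exp2 must be pre-multiplied by log2(e). A minimal numeric check of that identity (plain Python, not kernel code):

import math

M_LOG2E_F = math.log2(math.e)  # the constant the kernel multiplies by
x = 0.7
# exp(x) and 2**(log2(e) * x) are the same quantity, so a softmax computed
# with exp2 on pre-scaled scores matches a softmax computed with exp.
assert math.isclose(math.exp(x), 2.0 ** (M_LOG2E_F * x))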

mlx/backend/metal/kernels/sdpa_vector_vjp.h (15 additions, 9 deletions)

@@ -62,8 +62,9 @@ template <typename T, int D, int V = D>
     const constant size_t& v_head_stride [[buffer(13)]],
     const constant size_t& v_seq_stride [[buffer(14)]],
     const constant float& scale [[buffer(15)]],
-    // Output (O/dO) stride parameters - STEEL forward may produce non-row-major layout
-    // Physical layout can be BLHV (strides [L*H*V, V, H*V, 1]) vs logical BHLV
+    // Output (O/dO) stride parameters - STEEL forward may produce non-row-major
+    // layout Physical layout can be BLHV (strides [L*H*V, V, H*V, 1]) vs
+    // logical BHLV
     const constant int& num_q_heads [[buffer(16)]],
     const constant size_t& o_batch_stride [[buffer(17)]],
     const constant size_t& o_head_stride [[buffer(18)]],
@@ -138,7 +139,8 @@ template <typename T, int D, int V = D>
 
   // Set up output/gradient pointers
   // Use explicit strides for O/dO to handle BLHV physical layout from STEEL
-  // For BLHV strides: o_batch_stride = L*H*V, o_head_stride = V, o_seq_stride = H*V
+  // For BLHV strides: o_batch_stride = L*H*V, o_head_stride = V, o_seq_stride =
+  // H*V
   out += batch_idx * o_batch_stride + head_idx * o_head_stride +
       q_seq_idx * o_seq_stride + simd_lid * v_per_thread;
   d_out += batch_idx * o_batch_stride + head_idx * o_head_stride +
@@ -232,7 +234,8 @@ template <typename T, int D, int V = D>
     }
 
     // Reconstruct attention probability: P = exp2(S - logsumexp)
-    // Using exp2 to match STEEL attention domain (logsumexp is in log2 domain)
+    // Using exp2 to match STEEL attention domain (logsumexp is in log2
+    // domain)
     U prob = fast::exp2(score - lse);
 
     // Compute dP = dO @ V^T for this KV position
@@ -247,8 +250,8 @@ template <typename T, int D, int V = D>
 
     // Accumulate dQ += scale * dS @ K
     // Note: Although Q was scaled by M_LOG2E_F internally, the softmax
-    // gradient dS compensates for this because the overall softmax(S') = softmax(S).
-    // The gradient dQ = scale * dS @ K matches the reference.
+    // gradient dS compensates for this because the overall softmax(S') =
+    // softmax(S). The gradient dQ = scale * dS @ K matches the reference.
     for (int j = 0; j < qk_per_thread; j++) {
       dq[j] += static_cast<U>(scale) * dS * k[j];
     }
@@ -347,7 +350,8 @@ template <typename T, int D, int V = D>
     const constant size_t& v_head_stride [[buffer(14)]],
     const constant size_t& v_seq_stride [[buffer(15)]],
     const constant float& scale [[buffer(16)]],
-    // Output (O/dO) stride parameters - STEEL forward may produce non-row-major layout
+    // Output (O/dO) stride parameters - STEEL forward may produce non-row-major
+    // layout
     const constant int& num_q_heads [[buffer(17)]],
     const constant size_t& o_batch_stride [[buffer(18)]],
     const constant size_t& o_head_stride [[buffer(19)]],
@@ -489,11 +493,13 @@ template <typename T, int D, int V = D>
 
     if (float_mask) {
      // Scale float mask by M_LOG2E_F to match log2 domain
-      score += static_cast<U>(M_LOG2E_F) * static_cast<U>(fm_ptr[mask_offset]);
+      score +=
+          static_cast<U>(M_LOG2E_F) * static_cast<U>(fm_ptr[mask_offset]);
     }
 
     // Reconstruct probability: P = exp2(S - logsumexp)
-    // Using exp2 to match STEEL attention domain (logsumexp is in log2 domain)
+    // Using exp2 to match STEEL attention domain (logsumexp is in log2
+    // domain)
     U prob = fast::exp2(score - lse);
 
     // Compute dP
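The comments in this file describe the backward math the kernel implements: reconstruct P = exp2(S' - logsumexp) from the log2-domain logsumexp saved by the forward pass, form dP = dO @ V^T, apply the softmax VJP to get dS, then dQ = scale * dS @ K. A NumPy sketch of that math for a single head (illustrative only; shapes, names, and the single-head setup are assumptions, not the kernel):

import numpy as np

rng = np.random.default_rng(0)
L, D = 16, 8
scale = D**-0.5
Q, K, V = (rng.standard_normal((L, D)) for _ in range(3))
dO = rng.standard_normal((L, D))

M_LOG2E = np.log2(np.e)
S = scale * (Q @ K.T)            # raw scores
S2 = M_LOG2E * S                 # scores in the log2 domain
m = S2.max(-1, keepdims=True)
lse = m + np.log2(np.exp2(S2 - m).sum(-1, keepdims=True))
P = np.exp2(S2 - lse)            # reconstructed probabilities

dP = dO @ V.T                                     # dP = dO @ V^T
dS = P * (dP - (dP * P).sum(-1, keepdims=True))   # softmax VJP
dQ = scale * dS @ K                               # matches "dQ += scale * dS @ K"
dK = scale * dS.T @ Q
dV = P.T @ dO

# softmax(S') == softmax(S), so the log2-domain reconstruction changes nothing.
P_ref = np.exp(S - S.max(-1, keepdims=True))
P_ref /= P_ref.sum(-1, keepdims=True)
assert np.allclose(P, P_ref)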

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h (9 additions, 8 deletions)

@@ -464,19 +464,20 @@ template <
 
   // Output logsumexp if requested for VJP backward pass
   // LSE = max_score + log2(sum_score) in log2 domain (matches STEEL convention)
-  // Physical storage shape: [B*H, qL], laid out as linear array indexed by (B*H + head)*qL + query_pos
-  // LSE_strides[0] = qL (stride between (batch, head) rows)
-  // LSE_strides[1] = 1 (stride between query positions within a row)
+  // Physical storage shape: [B*H, qL], laid out as linear array indexed by (B*H
+  // + head)*qL + query_pos LSE_strides[0] = qL (stride between (batch, head)
+  // rows) LSE_strides[1] = 1 (stride between query positions within a row)
   if (output_logsumexp) {
     // Compute linear index for (batch, head) combination
-    // This matches the VJP kernel's indexing: (tidl.z * H + tidl.y) * LSE_strides[0]
-    device float* lse_out = LSE +
-        (tidl.z * params->H + tidl.y) * params->LSE_strides[0];
+    // This matches the VJP kernel's indexing: (tidl.z * H + tidl.y) *
+    // LSE_strides[0]
+    device float* lse_out =
+        LSE + (tidl.z * params->H + tidl.y) * params->LSE_strides[0];
 
     // Write one logsumexp per query position in this tile
     // Each thread handles kRowsPT query positions
-    // align_Q=true means query length is aligned (all blocks full), so always write
-    // align_Q=false means last block is partial, so check bounds
+    // align_Q=true means query length is aligned (all blocks full), so always
+    // write align_Q=false means last block is partial, so check bounds
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kRowsPT; ++i) {
       int row_pos = tid.x * BQ + tm + sm + (i * decltype(Stile)::kFragRows);
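The reflowed comment above describes how logsumexp values are laid out for the backward pass: a flat [B*H, qL] buffer with LSE_strides = [qL, 1], so the (tidl.z * H + tidl.y) row index times qL plus the query position gives the linear offset. A tiny sketch of that indexing (hypothetical helper, not kernel code):

def lse_index(batch, head, query_pos, H, qL):
    # LSE_strides[0] = qL (row stride per (batch, head)), LSE_strides[1] = 1.
    lse_strides = (qL, 1)
    return (batch * H + head) * lse_strides[0] + query_pos * lse_strides[1]


# (batch=1, head=2, query=3) with H=4 heads and qL=8 queries per row:
assert lse_index(1, 2, 3, H=4, qL=8) == (1 * 4 + 2) * 8 + 3 == 51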

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h (4 additions, 4 deletions)

@@ -284,8 +284,8 @@ void attention_vjp_dkv(
   const device T* O_base =
       O + tidl.z * params->O_strides[0] + q_head_idx * params->O_strides[1];
 
-  const device T* dO_base =
-      dO + tidl.z * params->O_strides[0] + q_head_idx * params->O_strides[1];
+  const device T* dO_base = dO + tidl.z * params->O_strides[0] +
+      q_head_idx * params->O_strides[1];
 
   const device float* LSE_base =
       LSE + (tidl.z * params->H + q_head_idx) * params->LSE_strides[0];
@@ -835,8 +835,8 @@ void attention_vjp_dkv(
 // tname is the string name used in kernel lookup (e.g., "float32", "float16")
 // dtype is the actual C++ type (e.g., float, half, bfloat16_t)
 #define instantiate_attention_vjp_dkv_kernel(tname, dtype, bq, bk, bd, wm, wn) \
-  template [[host_name( \
-      "attention_vjp_dkv_" #tname "_" #bq "_" #bk "_" #bd)]] [[kernel]] void \
+  template [[host_name("attention_vjp_dkv_" #tname "_" #bq "_" #bk \
+                       "_" #bd)]] [[kernel]] void \
   attention_vjp_dkv<dtype, bq, bk, bd, wm, wn>( \
       const device dtype*, \
       const device dtype*, \
