TL;DR
mx.quantized_matmul is 1.58× and 1.79× slower than mx.matmul (fp16) at M=96 and
M=128, respectively, on a common transformer FFN-down shape (N=2048, K=8192) on
M5 Max NAX hardware. At M ≥ 160 it is at rough parity (1.00–1.09× of fp16), as
expected. The likely cause is
qmm_splitk's aggressive split-K threshold: it forces a split-K reduction
even when the underlying NAX qmm_t_nax would be faster as a single kernel
launch.
Fix candidate: lower the 512 threshold in mlx/backend/metal/quantized.cpp:793
(or remove split-K from the NAX path entirely, since NAX qmm is already fast at
small M).
Reproducer
See bench/bench_splitk_boundary.py in this repo. It sweeps M on the FFN-down
shape and prints the split_k value and timing ratios.
Observed on M5 Max (g17s, NAX), macOS 26.5, MLX 0.31.2:
    M=  64   fp16: 0.300 ms   qmm: 0.307 ms   qmm/fp16: 1.02x   (split_k=4)
    M=  96   fp16: 0.265 ms   qmm: 0.418 ms   qmm/fp16: 1.58x   (split_k=2)   ← regression
    M= 128   fp16: 0.272 ms   qmm: 0.486 ms   qmm/fp16: 1.79x   (split_k=2)   ← regression
    M= 160   fp16: 0.354 ms   qmm: 0.354 ms   qmm/fp16: 1.00x   (split_k=1)
    M= 192   fp16: 0.341 ms   qmm: 0.349 ms   qmm/fp16: 1.02x   (split_k=1)
    M= 256   fp16: 0.349 ms   qmm: 0.381 ms   qmm/fp16: 1.09x   (split_k=1)
Root cause
In mlx/backend/metal/quantized.cpp:787-805:
    void qmm_splitk(...) {
      int bm = 32, bn = 32;
      int n_tiles = (N + bn - 1) / bn;
      int m_tiles = (M + bm - 1) / bm;
      int current_tgs = n_tiles * m_tiles;
      int split_k = std::max(1, 512 / current_tgs);
      ...
      if (split_k <= 1) {
        return qmm(...); // fall back to non-split kernel
      }
      // ... else allocate intermediate buffer, launch split-K kernel, sum-reduce
    }
For FFN-down (N=2048):
n_tiles = 64
- M=64 → m_tiles=2, current_tgs=128, split_k=4
- M=96 → m_tiles=3, current_tgs=192, split_k=2
- M=128 → m_tiles=4, current_tgs=256, split_k=2
- M=160 → m_tiles=5, current_tgs=320, split_k=1 (integer div: 512/320=1)
- M=192+ → split_k=1
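The arithmetic in the list above can be checked with a small standalone sketch. It only mirrors the formula quoted from qmm_splitk; the function name split_k_for and the target_tgs parameter are illustrative and do not exist in MLX.

```cpp
#include <algorithm>
#include <cassert>

// Mirrors the split_k arithmetic quoted from qmm_splitk above.
// The 32x32 tile size and the 512 default come from that excerpt;
// the function name and the target_tgs parameter are hypothetical.
int split_k_for(int M, int N, int target_tgs = 512) {
    assert(M > 0 && N > 0 && target_tgs > 0);
    const int bm = 32, bn = 32;
    const int n_tiles = (N + bn - 1) / bn;      // ceil(N / bn)
    const int m_tiles = (M + bm - 1) / bm;      // ceil(M / bm)
    const int current_tgs = n_tiles * m_tiles;  // threadgroups without split-K
    return std::max(1, target_tgs / current_tgs);  // integer division
}
```

Calling split_k_for(M, 2048) for each M in the sweep should reproduce the split_k column of the benchmark output.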
The split-K path allocates an intermediate [split_k, M, N] buffer and then
runs a sum-along-axis-0 reduce after the matmul. On the NAX-capable path,
qmm_t_nax appears to be fast enough that this extra allocation + reduce step
costs more than the added parallelism gains. The split_k = 2 cases at M=96 and
M=128 point to this: if the regular qmm() matched fp16 there, as it does at
M ≥ 160, it would run in roughly 56–63% of the measured split-K time. At
split_k = 4 (M=64) the split-K path is still at parity with fp16 (1.02×).
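To make the two extra steps concrete, here is a minimal CPU reference for split-K in plain float arithmetic. It is not the MLX kernel: there is no quantization, no GPU dispatch, and the name matmul_split_k is made up for illustration. It only shows where the [split_k, M, N] buffer and the axis-0 reduce come from.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative split-K matmul: out[M, N] = x[M, K] * w[K, N], row-major.
// Assumes K is divisible by split_k. Hypothetical helper, not MLX code.
std::vector<float> matmul_split_k(const std::vector<float>& x,
                                  const std::vector<float>& w,
                                  int M, int N, int K, int split_k) {
    assert(split_k >= 1 && K % split_k == 0);
    assert(x.size() == std::size_t(M) * K && w.size() == std::size_t(K) * N);
    const int k_chunk = K / split_k;

    // Step 1: each split writes its partial product into its own [M, N]
    // slice of an intermediate [split_k, M, N] buffer (the extra allocation).
    std::vector<float> partial(std::size_t(split_k) * M * N, 0.0f);
    for (int s = 0; s < split_k; ++s)
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (int k = s * k_chunk; k < (s + 1) * k_chunk; ++k)
                    acc += x[std::size_t(m) * K + k] * w[std::size_t(k) * N + n];
                partial[(std::size_t(s) * M + m) * N + n] = acc;
            }

    // Step 2: sum along axis 0 (the extra reduce pass).
    std::vector<float> out(std::size_t(M) * N, 0.0f);
    for (int s = 0; s < split_k; ++s)
        for (std::size_t i = 0; i < out.size(); ++i)
            out[i] += partial[std::size_t(s) * M * N + i];
    return out;
}
```

With split_k = 1 the intermediate buffer is just the output and the reduce is a copy, which is why the real code can fall back to the plain kernel in that case.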
Suggested fix
Option 1 (minimal): lower the threshold so split-K only triggers when the
output has very few tiles (current_tgs ≤ 64):
    int split_k = std::max(1, 128 / current_tgs); // was 512
Option 2 (NAX-aware): bypass split-K entirely on NAX:
    if (metal::is_nax_available() && transpose) {
      return qmm(x, w, scales, biases, out, true, group_size, bits, M, N, K, d, s, mode);
    }
Option 3 (full re-tuning): empirically sweep the threshold per (N, K) shape on
representative hardware. The 512 was likely tuned for the pre-NAX kernel.
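When comparing candidate thresholds, it may help to see which M range each one leaves on the split-K path. The sketch below derives that from the quoted formula alone; the function name largest_m_using_split_k is hypothetical, and it says nothing about which threshold is actually fastest, which still needs measurement.

```cpp
#include <cassert>

// Largest M for which the quoted formula still yields split_k >= 2,
// or 0 if split-K never triggers for this N. Hypothetical helper that
// assumes the 32x32 tiles from the qmm_splitk excerpt.
int largest_m_using_split_k(int N, int target_tgs) {
    assert(N > 0 && target_tgs > 0);
    const int bm = 32, bn = 32;
    const int n_tiles = (N + bn - 1) / bn;
    // target_tgs / (n_tiles * m_tiles) >= 2 under integer division
    //   <=>  m_tiles <= target_tgs / (2 * n_tiles)
    const int max_m_tiles = target_tgs / (2 * n_tiles);
    return max_m_tiles * bm;
}
```

For N=2048 this gives 128 with the current threshold of 512 and 32 with the proposed 128, so Option 1 would move both regressing points (M=96, M=128) onto the non-split path.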
Impact
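This regression falls in a common prefill range: short chatbot prompts (system
prompt + short user message) often land around 96-128 tokens. By the formula
above, every M from 65 to 128 gets split_k = 2 on this shape; M=96 and M=128
are the measured points. If the non-split path matches fp16 there, as it does
at M ≥ 160, fixing this would speed up these quantized matmuls by roughly
1.6–1.8× at that operating point, on NAX-capable Apple Silicon (M5-generation
GPUs; only M5 Max was tested here).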
Hardware tested
Apple M5 Max (applegpu_g17s), 128 GB unified memory, macOS 26.5, Xcode 26.5
Metal Toolchain 32023.883.
Reference
This was found while building a custom NAX-accelerated fused dequant+matmul
kernel for testing. The custom kernel doesn't use split-K and lands at
1.39× faster than mx.quantized_matmul at M=128 on the same shape,
specifically because it avoids this dispatch overhead. See
bench/bench_splitk_boundary.py and bench/bench_smollm2_real.py in this
repo for full reproducers.