qmm_splitk dispatch causes 1.5-1.8x regression vs fp16 at M=96, M=128 on NAX hardware #3584

@marian-gg

TL;DR

mx.quantized_matmul is 1.5–1.8× slower than mx.matmul (fp16) at M=96 and
M=128 on common transformer FFN-down shapes (N=2048, K=8192) on M5 Max NAX
hardware. At M ≥ 160 it's at parity (~1.0× of fp16) as expected. The cause is
qmm_splitk's aggressive split-K threshold: it forces a split-K reduction
even when the underlying NAX qmm_t_nax would be faster as a single kernel
launch.

Fix candidate: lower the 512 threshold in mlx/backend/metal/quantized.cpp:793
(or remove split-K from the NAX path entirely, since NAX qmm is already fast
at small M).

Reproducer

See bench/bench_splitk_boundary.py in this repo. Sweeps M on FFN-down shape,
prints split_k value and timing ratios.

Observed on M5 Max (g17s, NAX), macOS 26.5, MLX 0.31.2:

M=  64  fp16: 0.300 ms   qmm: 0.307 ms   qmm/fp16: 1.02x    (split_k=4)
M=  96  fp16: 0.265 ms   qmm: 0.418 ms   qmm/fp16: 1.58x    (split_k=2) ← regression
M= 128  fp16: 0.272 ms   qmm: 0.486 ms   qmm/fp16: 1.79x    (split_k=2) ← regression
M= 160  fp16: 0.354 ms   qmm: 0.354 ms   qmm/fp16: 1.00x    (split_k=1)
M= 192  fp16: 0.341 ms   qmm: 0.349 ms   qmm/fp16: 1.02x    (split_k=1)
M= 256  fp16: 0.349 ms   qmm: 0.381 ms   qmm/fp16: 1.09x    (split_k=1)

Root cause

In mlx/backend/metal/quantized.cpp:787-805:

void qmm_splitk(...) {
  int bm = 32, bn = 32;
  int n_tiles = (N + bn - 1) / bn;
  int m_tiles = (M + bm - 1) / bm;
  int current_tgs = n_tiles * m_tiles;
  int split_k = std::max(1, 512 / current_tgs);
  ...
  if (split_k <= 1) {
    return qmm(...);  // fall back to non-split kernel
  }
  // ... else allocate intermediate buffer, launch split-K kernel, sum-reduce
}

For FFN-down (N=2048):

  • n_tiles = 64
  • M=64 → m_tiles=2, current_tgs=128, split_k=4
  • M=96 → m_tiles=3, current_tgs=192, split_k=2
  • M=128 → m_tiles=4, current_tgs=256, split_k=2
  • M=160 → m_tiles=5, current_tgs=320, split_k=1 (integer div: 512/320=1)
  • M=192+ → split_k=1
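
The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch that mirrors only the heuristic quoted from qmm_splitk (tile sizes 32x32, threshold 512); the function name split_k_for is made up for illustration and is not part of MLX.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division, same as (a + b - 1) / b in the C++ snippet."""
    return (a + b - 1) // b


def split_k_for(M: int, N: int, threshold: int = 512, bm: int = 32, bn: int = 32) -> int:
    """Hypothetical mirror of the split_k heuristic quoted in this issue."""
    current_tgs = ceil_div(N, bn) * ceil_div(M, bm)
    return max(1, threshold // current_tgs)


# FFN-down shape from the issue: N = 2048
for M in (64, 96, 128, 160, 192, 256):
    print(f"M={M:4d}  split_k={split_k_for(M, 2048)}")
```

With N=2048 this gives split_k > 1 only for M ≤ 128, which lines up with the rows marked as regressions in the timing table.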

The split-K path allocates an intermediate [split_k, M, N] buffer and then
runs a sum-along-axis-0 reduce after the matmul. On the NAX-capable path,
qmm_t_nax appears to be fast enough that this extra allocation + reduce step
costs more than the added parallelism gains. The split_k = 2 cases at M=96
and M=128 are consistent with this: qmm is 1.58x and 1.79x slower than fp16
there, but back at parity as soon as split_k drops to 1 at M=160, which
suggests the regular qmm() would be substantially faster at those sizes.
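
To get a rough sense of the extra memory the split-K path touches, the sketch below sizes the intermediate [split_k, M, N] buffer against the final [M, N] output. The 2-byte element size (fp16 partial sums) is an assumption for illustration; the issue does not state which dtype the intermediate buffer uses, and both helper names are hypothetical.

```python
def splitk_scratch_bytes(split_k: int, M: int, N: int, itemsize: int = 2) -> int:
    """Bytes for the intermediate [split_k, M, N] buffer.

    itemsize=2 assumes fp16 partial sums; this is an assumption, not taken from MLX.
    """
    return split_k * M * N * itemsize


def output_bytes(M: int, N: int, itemsize: int = 2) -> int:
    """Bytes for the final [M, N] output (same itemsize assumption)."""
    return M * N * itemsize


N = 2048
for M, split_k in ((96, 2), (128, 2)):
    scratch = splitk_scratch_bytes(split_k, M, N)
    out = output_bytes(M, N)
    print(f"M={M:4d}  scratch={scratch} B  output={out} B  ratio={scratch // out}x")
```

Under these assumptions the reduce step has to read split_k times the output size before writing the output once, on top of the extra kernel launch.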

Suggested fix

Option 1 (minimal): lower the threshold so split-K only triggers when very
few threadgroups would otherwise be launched (current_tgs ≤ 64):

int split_k = std::max(1, 128 / current_tgs);  // was 512
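
As a quick check of what Option 1 would change, the sketch below compares the split_k chosen under the current threshold (512) and the proposed one (128) for the FFN-down shape. It only re-implements the heuristic quoted in this issue; split_k_with_threshold is a hypothetical helper, not an MLX function.

```python
def split_k_with_threshold(M: int, N: int, threshold: int, bm: int = 32, bn: int = 32) -> int:
    """Hypothetical mirror of the quoted heuristic with a configurable threshold."""
    n_tiles = (N + bn - 1) // bn
    m_tiles = (M + bm - 1) // bm
    return max(1, threshold // (n_tiles * m_tiles))


N = 2048  # FFN-down shape from the issue
for M in (32, 64, 96, 128, 160):
    current = split_k_with_threshold(M, N, 512)
    proposed = split_k_with_threshold(M, N, 128)
    print(f"M={M:4d}  threshold=512 -> split_k={current}   threshold=128 -> split_k={proposed}")
```

Note that under this change M=64 would also stop splitting (split_k goes from 4 to 1). The measurements above show that case at parity with fp16 today, so it would be worth re-benchmarking M=64 alongside M=96 and M=128.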

Option 2 (NAX-aware): bypass split-K entirely on NAX:

if (metal::is_nax_available() && transpose) {
    return qmm(x, w, scales, biases, out, true, group_size, bits, M, N, K, d, s, mode);
}

Option 3 (full re-tuning): empirically sweep the threshold per (N, K) shape on
representative hardware. The 512 was likely tuned for the pre-NAX kernel.

Impact

This regression hits the prefill lengths that many short chatbot prompts
produce (roughly 96-160 tokens after system prompt + short user message).
Fixing it should recover roughly 1.5× throughput on the affected quantized
matmuls at that operating point, at no accuracy cost, on NAX-capable Apple
Silicon.

Hardware tested

Apple M5 Max (applegpu_g17s), 128 GB unified memory, macOS 26.5, Xcode 26.5
Metal Toolchain 32023.883.

Reference

This was found while building a custom NAX-accelerated fused dequant+matmul
kernel for testing. The custom kernel doesn't use split-K and lands at
1.39× faster than mx.quantized_matmul at M=128 on the same shape,
specifically because it avoids this dispatch overhead. See
bench/bench_splitk_boundary.py and bench/bench_smollm2_real.py in this
repo for full reproducers.
