diff --git a/mlx_vlm/models/diffusion_gemma/language.py b/mlx_vlm/models/diffusion_gemma/language.py
index 47e5720..b624cf2 100644
--- a/mlx_vlm/models/diffusion_gemma/language.py
+++ b/mlx_vlm/models/diffusion_gemma/language.py
@@ -1,3 +1,4 @@
+import os
import weakref
from functools import partial
from typing import Any, Optional
@@ -137,6 +138,58 @@ class Experts(nn.Module):
return (y * top_k_weights[..., None]).sum(axis=-2)
+# Chunked prefill: DiffusionGemma's head_dim=256 is above MLX's fused
+# flash-attention limit (64/80/128), so SDPA materialises the full
+# (heads, L, key_len) score tensor — O(L^2) memory that caps context near ~50k
+# on a 128GB Mac. Splitting the *queries* into chunks and evaluating each chunk
+# before the next bounds the live scores to (heads, chunk, key_len) = O(chunk*L)
+# at the cost of extra sync. Off by default; set DIFFUSION_PREFILL_CHUNK=N (e.g.
+# 2048) to enable. Only affects the prompt prefill pass (decoder=False).
+_PREFILL_CHUNK = int(os.environ.get("DIFFUSION_PREFILL_CHUNK", "0"))
+
+
+def _chunked_prefill_sdpa(queries, keys, values, scale, mask, chunk_size):
+ """Query-chunked SDPA that bounds peak attention memory.
+
+ Mathematically identical to a single SDPA call; only the evaluation order
+ differs. ``mask`` may be:
+ * None,
+ * the string "causal",
+ * a ("causal", window_or_None) spec — the per-chunk boolean mask is built
+ on the fly (so the full dense N×N mask is never materialised), or
+ * a dense array whose second-to-last axis is the query axis (sliced).
+ """
+ L = queries.shape[2]
+ key_len = keys.shape[2]
+ # MLX "causal" aligns queries to the *end* of the keys; reproduce that
+ # alignment so a middle chunk masks against the right absolute positions.
+ q_global_start = key_len - L
+ spec = isinstance(mask, (str, tuple))
+ kpos = mx.arange(key_len) if spec else None
+ window = mask[1] if isinstance(mask, tuple) else None
+
+ outputs = []
+ for i in range(0, L, chunk_size):
+ end = min(i + chunk_size, L)
+ q_chunk = queries[:, :, i:end, :]
+ if mask is None:
+ m = None
+ elif spec: # "causal" or ("causal", window)
+ q_abs = (q_global_start + i) + mx.arange(end - i)
+ causal = kpos[None, :] <= q_abs[:, None]
+ if window:
+ causal = causal & (kpos[None, :] > q_abs[:, None] - window)
+ m = causal[None, None, :, :]
+ else:
+ m = mask[..., i:end, :]
+ o = mx.fast.scaled_dot_product_attention(
+ q_chunk, keys, values, scale=scale, mask=m
+ )
+ mx.eval(o) # finish this chunk's scores before allocating the next
+ outputs.append(o)
+ return mx.concatenate(outputs, axis=2)
+
+
class Attention(nn.Module):
def __init__(self, config: TextConfig, layer_idx: int):
super().__init__()
@@ -251,9 +304,19 @@ class Attention(nn.Module):
keys, values = cache.update_and_fetch(keys, values)
attn_cache = cache
- output = scaled_dot_product_attention(
- queries, keys, values, cache=attn_cache, scale=self.scale, mask=mask
- )
+ if (
+ not decoder
+ and _PREFILL_CHUNK
+ and L > _PREFILL_CHUNK
+ and not hasattr(attn_cache, "bits") # plain (non-quantized) cache only
+ ):
+ output = _chunked_prefill_sdpa(
+ queries, keys, values, self.scale, mask, _PREFILL_CHUNK
+ )
+ else:
+ output = scaled_dot_product_attention(
+ queries, keys, values, cache=attn_cache, scale=self.scale, mask=mask
+ )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
@@ -635,6 +698,20 @@ class EncoderModel(nn.Module):
overlay = None
if attention_mask is None and overlay is None:
+ # With chunked prefill, return lightweight ("causal", window) specs
+ # instead of dense per-layer N×N masks. Pre-building 25 sliding-layer
+ # masks is the real O(N^2) memory hog (~429 GB at 128k); the chunked
+ # SDPA rebuilds each chunk's mask on the fly instead.
+ if _PREFILL_CHUNK and N > _PREFILL_CHUNK and (not cache or _cache_offset(cache[0]) == 0):
+ return [
+ (
+ "causal",
+ self.text_config.sliding_window
+ if layer.layer_type == "sliding_attention"
+ else None,
+ )
+ for layer in self.decoder.layers
+ ]
return [
create_attention_mask(
h,
Hi Prince, was playing with `diffusiongemma-26B-A4B-it-4bit` on a 128GB M-series and went down a rabbit hole on long-context memory. Saw you were just in here (#1348) so figured I'd share what I found before it goes stale. Bringing data, not just vibes.

## tl;dr

Prefill OOMs way before it should. Tracked it to two O(N²) terms, and the big one surprised me — it's not the attention scores, it's the masks:
`_make_encoder_masks` builds a separate dense N×N mask for every layer. The 25 sliding layers each get a full `create_causal_mask(N, 1024)` array, all alive at once (~429 GB at 128k).

(The scores are the other O(N²) term — `head_dim=256/512` is above MLX's fused-SDPA limit of 64/80/128, so it materialises `16·N²·2B`. That's the one that trips Metal's ~86.5 GB single-buffer cap around ~52k.)

## the fix

Opt-in query-chunked prefill (env-gated, default off → zero behaviour change). Two moves:
1. Split the queries into chunks and `eval` each chunk before the next → scores bounded to `O(chunk·N)`.
2. Return a `("causal", window)` spec instead of a dense mask, and rebuild each chunk's mask on the fly → the 429 GB never happens.

## receipts

`DIFFUSION_PREFILL_CHUNK=2048`, output verified coherent + ~identical at temp 0.

Memory goes quadratic → ~linear. Cost is ~4–7% prefill throughput from the `eval` syncs (compute is still O(N²) — no fused kernel for these head dims — so 128k is slow, just no longer fatal).

## the honest bit

This is a prototype, not a polished PR — I didn't want to guess at your preferences:
- … `prefill_step_size`?)
- `mx.eval` is a blunt-but-necessary hammer (without it lazy eval batches all chunks: 2.55 GB vs 0.60 GB in a micro-bench)
- … `head_dim>128` + dense masks hits it. There's an unused `chunked_attention` in `base.py` already — would you rather wire that generally than take a diffusion-local copy?

Happy to redo it whichever way you prefer, and to add tests + clean it up into a proper PR if you're into the approach.

## the diff (prototype, against `main`)