DiffusionGemma long-context prefill: the masks are eating all the RAM #1355

@johntdavies


Hi Prince, I was playing with diffusiongemma-26B-A4B-it-4bit on a 128 GB M-series Mac and went down a rabbit hole on long-context memory. Saw you were just in here (#1348), so I figured I'd share what I found before it goes stale. Bringing data, not just vibes.

tl;dr

Prefill OOMs way before it should. Tracked it to two O(N²) terms, and the big one surprised me — it's not the attention scores, it's the masks:

_make_encoder_masks builds a separate dense N×N mask for every layer. The 25 sliding layers each get a full create_causal_mask(N, 1024) array, all alive at once:

| context | dense masks alone |
|---|---|
| 32k | ~27 GB |
| 128k | ~429 GB |

(The scores are the other O(N²) term: head_dim=256/512 is above MLX's fused-SDPA limit of 64/80/128, so SDPA materialises the full 16·N²·2-byte score tensor. That's the one that trips Metal's ~86.5 GB single-buffer cap at around 52k.)
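
A quick back-of-envelope check of both terms, as a plain-Python sketch. The constants (25 sliding layers with 1-byte boolean masks, 16 score heads at 2 bytes per element) are read off the numbers above rather than the model config, so treat them as assumptions.

```python
import math

GB = 1e9
SLIDING_LAYERS = 25   # assumed: sliding-window layers, each with its own dense mask
MASK_ELEM_BYTES = 1   # assumed: boolean mask element
SCORE_HEADS = 16      # assumed: the factor 16 in the score-size estimate above
SCORE_ELEM_BYTES = 2  # assumed: fp16 / bf16 score element


def dense_mask_bytes(n: int) -> int:
    """All per-layer N x N boolean masks, alive at the same time."""
    return SLIDING_LAYERS * n * n * MASK_ELEM_BYTES


def dense_score_bytes(n: int) -> int:
    """One materialised score tensor of SCORE_HEADS x N x N elements."""
    return SCORE_HEADS * n * n * SCORE_ELEM_BYTES


def max_context_under_cap(cap_bytes: int) -> int:
    """Largest N whose score tensor still fits in a single buffer of cap_bytes."""
    return math.isqrt(cap_bytes // (SCORE_HEADS * SCORE_ELEM_BYTES))


if __name__ == "__main__":
    for n in (32 * 1024, 128 * 1024):
        print(f"{n // 1024}k context: masks {dense_mask_bytes(n) / GB:.1f} GB")
    cap = int(86.5 * GB)
    print(f"score tensor exceeds an 86.5 GB buffer beyond N = {max_context_under_cap(cap)}")
```

Run as a script, it reports 26.8 GB and 429.5 GB for the masks and N = 51991 for the buffer cap, which lines up with the ~27 GB, ~429 GB and ~52k figures.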

the fix

Opt-in query-chunked prefill (env-gated, default off → zero behaviour change). Two moves:

  1. Chunk the queries, eval each chunk before the next → scores bounded to O(chunk·N).
  2. When chunking, hand the attention a lightweight ("causal", window) spec instead of a dense mask, and rebuild each chunk's mask on the fly → the 429 GB never happens.
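
The two moves can be checked in miniature. The pure-Python sketch below (single head, toy sizes, lists standing in for MLX arrays; all function names are made up for illustration) compares one dense masked pass against a query-chunked pass that rebuilds each chunk's causal/sliding-window mask rows on the fly. It uses the same inequalities and end-of-keys query alignment as the diff. Each query row's softmax depends only on that row and its mask row, which is why the two agree.

```python
import math
import random
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def mask_rows(q_positions, key_len: int, window: Optional[int]) -> List[List[bool]]:
    """Causal (+ optional sliding-window) mask rows for absolute query positions."""
    return [
        [k <= q and (window is None or k > q - window) for k in range(key_len)]
        for q in q_positions
    ]


def attend(queries: Matrix, keys: Matrix, values: Matrix, scale: float,
           rows: List[List[bool]]) -> Matrix:
    """Masked softmax attention, one output row per query row."""
    out = []
    for q, allowed in zip(queries, rows):
        scores = [scale * dot(q, k) if ok else -math.inf for k, ok in zip(keys, allowed)]
        top = max(scores)  # finite: every query may attend to its own position
        weights = [math.exp(s - top) for s in scores]  # masked entries become 0.0
        total = sum(weights)
        out.append([sum(w * v[d] for w, v in zip(weights, values)) / total
                    for d in range(len(values[0]))])
    return out


def dense_prefill(queries, keys, values, scale, window):
    """Reference: one pass with the full L x key_len mask built up front."""
    start = len(keys) - len(queries)  # queries aligned to the end of the keys
    rows = mask_rows(range(start, len(keys)), len(keys), window)
    return attend(queries, keys, values, scale, rows)


def chunked_prefill(queries, keys, values, scale, window, chunk):
    """Query-chunked pass: only a chunk x key_len slice of the mask exists at a time."""
    start = len(keys) - len(queries)
    out = []
    for i in range(0, len(queries), chunk):
        end = min(i + chunk, len(queries))
        rows = mask_rows(range(start + i, start + end), len(keys), window)
        out.extend(attend(queries[i:end], keys, values, scale, rows))
    return out


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    q, k, v = (random_matrix(10, 4, rng) for _ in range(3))
    full = dense_prefill(q, k, v, 4 ** -0.5, window=3)
    parts = chunked_prefill(q, k, v, 4 ** -0.5, window=3, chunk=4)
    print("identical:", full == parts)
```

This only demonstrates equivalence. In MLX the saving comes from the score and mask tensors being chunk × key_len instead of L × key_len, whereas plain Python never holds a full score matrix in either version.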

receipts

With DIFFUSION_PREFILL_CHUNK=2048 (output verified coherent and ~identical to stock at temp 0):

| context | stock peak | chunked peak |
|---|---|---|
| 32k | 60.6 GB | 29.5 GB |
| 49k | 109 GB | 36 GB |
| 57k | 💥 OOM | fine |
| 128k | 💥 (~607 GB projected) | 71 GB |

Peak memory goes from quadratic to roughly linear. The cost is ~4–7% lower prefill throughput from the eval syncs. Compute is still O(N²) (there's no fused kernel for these head dims), so 128k is slow, just no longer fatal.
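
Why the curve flattens, as arithmetic: with query chunking only one chunk's scores and mask slice are live, so the attention term grows with chunk·N rather than N². This sketch reuses the assumed constants from the issue (16 score heads at 2 bytes, 1-byte boolean mask). It deliberately ignores weights, the KV cache and other activations, so it is not a model of the measured peaks in the table.

```python
GB = 1e9


def chunk_score_bytes(chunk: int, key_len: int, heads: int = 16, elem_bytes: int = 2) -> int:
    """Live (heads, chunk, key_len) score tensor for one query chunk."""
    return heads * chunk * key_len * elem_bytes


def chunk_mask_bytes(chunk: int, key_len: int) -> int:
    """On-the-fly (chunk, key_len) boolean mask, 1 byte per element."""
    return chunk * key_len


if __name__ == "__main__":
    for n in (32 * 1024, 64 * 1024, 128 * 1024):
        scores = chunk_score_bytes(2048, n)
        mask = chunk_mask_bytes(2048, n)
        print(f"{n // 1024:>3}k: scores {scores / GB:.2f} GB + mask {mask / GB:.2f} GB per chunk")
```

At 128k with 2048-token chunks that comes to about 8.59 GB of scores plus 0.27 GB of mask per chunk, and doubling the context only doubles it.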

the honest bit

This is a prototype, not a polished PR — I didn't want to guess at your preferences:

  • env-var gate should become a real param (reuse prefill_step_size?)
  • only the no-padding / no-overlay path is handled
  • the per-chunk mx.eval is a blunt-but-necessary hammer (without it, lazy eval batches all the chunks into one evaluation and the saving disappears: 2.55 GB vs 0.60 GB in a micro-bench)
  • open question: this isn't diffusion-specific. Anything with head_dim > 128 plus dense masks hits it. There's already an unused chunked_attention in base.py; would you rather wire that in generally than take a diffusion-local copy? Happy to redo it whichever way you prefer, and to add tests and clean it up into a proper PR if you're into the approach.

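For the first bullet, one possible shape for the gate: a hypothetical `resolve_prefill_chunk` where an explicit argument takes precedence and `DIFFUSION_PREFILL_CHUNK` remains only as a fallback. Whether that argument should be the existing `prefill_step_size` is exactly the open question, so the parameter name is a placeholder borrowed from the bullet, not a confirmed mlx-vlm API.

```python
import os
from typing import Mapping, Optional

ENV_VAR = "DIFFUSION_PREFILL_CHUNK"


def resolve_prefill_chunk(
    prefill_step_size: Optional[int] = None,
    env: Optional[Mapping[str, str]] = None,
) -> int:
    """Hypothetical: an explicit argument wins, the env var is only a fallback, 0 means off."""
    if prefill_step_size is None:
        source = os.environ if env is None else env
        raw = source.get(ENV_VAR, "0")
        try:
            prefill_step_size = int(raw)
        except ValueError:
            raise ValueError(f"{ENV_VAR} must be an integer, got {raw!r}") from None
    if prefill_step_size < 0:
        raise ValueError("prefill chunk size must be >= 0")
    return prefill_step_size


def use_chunked_prefill(seq_len: int, chunk: int, decoder: bool) -> bool:
    """Same shape as the diff's gate (minus its quantized-cache check)."""
    return not decoder and chunk > 0 and seq_len > chunk
```

Passing `env` explicitly keeps the helper testable without touching the real environment; in normal use it falls back to `os.environ`.
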
the diff (prototype, against main)
diff --git a/mlx_vlm/models/diffusion_gemma/language.py b/mlx_vlm/models/diffusion_gemma/language.py
index 47e5720..b624cf2 100644
--- a/mlx_vlm/models/diffusion_gemma/language.py
+++ b/mlx_vlm/models/diffusion_gemma/language.py
@@ -1,3 +1,4 @@
+import os
 import weakref
 from functools import partial
 from typing import Any, Optional
@@ -137,6 +138,58 @@ class Experts(nn.Module):
         return (y * top_k_weights[..., None]).sum(axis=-2)
 
 
+# Chunked prefill: DiffusionGemma's head_dim=256 is above MLX's fused
+# flash-attention limit (64/80/128), so SDPA materialises the full
+# (heads, L, key_len) score tensor — O(L^2) memory that caps context near ~50k
+# on a 128GB Mac. Splitting the *queries* into chunks and evaluating each chunk
+# before the next bounds the live scores to (heads, chunk, key_len) = O(chunk*L)
+# at the cost of extra sync. Off by default; set DIFFUSION_PREFILL_CHUNK=N (e.g.
+# 2048) to enable. Only affects the prompt prefill pass (decoder=False).
+_PREFILL_CHUNK = int(os.environ.get("DIFFUSION_PREFILL_CHUNK", "0"))
+
+
+def _chunked_prefill_sdpa(queries, keys, values, scale, mask, chunk_size):
+    """Query-chunked SDPA that bounds peak attention memory.
+
+    Mathematically identical to a single SDPA call; only the evaluation order
+    differs. ``mask`` may be:
+      * None,
+      * the string "causal",
+      * a ("causal", window_or_None) spec — the per-chunk boolean mask is built
+        on the fly (so the full dense N×N mask is never materialised), or
+      * a dense array whose second-to-last axis is the query axis (sliced).
+    """
+    L = queries.shape[2]
+    key_len = keys.shape[2]
+    # MLX "causal" aligns queries to the *end* of the keys; reproduce that
+    # alignment so a middle chunk masks against the right absolute positions.
+    q_global_start = key_len - L
+    spec = isinstance(mask, (str, tuple))
+    kpos = mx.arange(key_len) if spec else None
+    window = mask[1] if isinstance(mask, tuple) else None
+
+    outputs = []
+    for i in range(0, L, chunk_size):
+        end = min(i + chunk_size, L)
+        q_chunk = queries[:, :, i:end, :]
+        if mask is None:
+            m = None
+        elif spec:  # "causal" or ("causal", window)
+            q_abs = (q_global_start + i) + mx.arange(end - i)
+            causal = kpos[None, :] <= q_abs[:, None]
+            if window:
+                causal = causal & (kpos[None, :] > q_abs[:, None] - window)
+            m = causal[None, None, :, :]
+        else:
+            m = mask[..., i:end, :]
+        o = mx.fast.scaled_dot_product_attention(
+            q_chunk, keys, values, scale=scale, mask=m
+        )
+        mx.eval(o)  # finish this chunk's scores before allocating the next
+        outputs.append(o)
+    return mx.concatenate(outputs, axis=2)
+
+
 class Attention(nn.Module):
     def __init__(self, config: TextConfig, layer_idx: int):
         super().__init__()
@@ -251,9 +304,19 @@ class Attention(nn.Module):
                 keys, values = cache.update_and_fetch(keys, values)
             attn_cache = cache
 
-        output = scaled_dot_product_attention(
-            queries, keys, values, cache=attn_cache, scale=self.scale, mask=mask
-        )
+        if (
+            not decoder
+            and _PREFILL_CHUNK
+            and L > _PREFILL_CHUNK
+            and not hasattr(attn_cache, "bits")  # plain (non-quantized) cache only
+        ):
+            output = _chunked_prefill_sdpa(
+                queries, keys, values, self.scale, mask, _PREFILL_CHUNK
+            )
+        else:
+            output = scaled_dot_product_attention(
+                queries, keys, values, cache=attn_cache, scale=self.scale, mask=mask
+            )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)
 
@@ -635,6 +698,20 @@ class EncoderModel(nn.Module):
             overlay = None
 
         if attention_mask is None and overlay is None:
+            # With chunked prefill, return lightweight ("causal", window) specs
+            # instead of dense per-layer N×N masks. Pre-building 25 sliding-layer
+            # masks is the real O(N^2) memory hog (~429 GB at 128k); the chunked
+            # SDPA rebuilds each chunk's mask on the fly instead.
+            if _PREFILL_CHUNK and N > _PREFILL_CHUNK and (not cache or _cache_offset(cache[0]) == 0):
+                return [
+                    (
+                        "causal",
+                        self.text_config.sliding_window
+                        if layer.layer_type == "sliding_attention"
+                        else None,
+                    )
+                    for layer in self.decoder.layers
+                ]
             return [
                 create_attention_mask(
                     h,
