Fix #157 and #168#8
Merged
Merged
Conversation
There was a problem hiding this comment.
Pull request overview
Fixes two [broadcast_shapes] crash scenarios affecting Qwen3.5 VLM inference by normalizing token shapes for penalty processing and ensuring per-inference position state is reset in the Qwen3.5 language model.
Changes:
- Flatten prompt tokens in
TokenRing.loadPromptso 2D VLM prompts ([1, seq_len]) are handled correctly. - Reshape sampled tokens to a scalar in
TokenRing.appendto avoidMLX.wherebroadcast mismatches. - Reset Qwen3.5 position state at the start of every
preparecall and add targeted regression tests.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
Libraries/MLXLMCommon/Evaluate.swift |
Makes TokenRing robust to batched prompts and non-scalar sampled tokens to prevent broadcast crashes. |
Libraries/MLXVLM/Models/Qwen35.swift |
Resets cached position-related state unconditionally per inference to avoid stale-shape crashes across requests. |
Tests/MLXLMTests/SampleTests.swift |
Adds regression tests covering 2D prompts and batched sampled tokens across penalty processors. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Satchitananda
pushed a commit
to Satchitananda/mlx-swift-lm
that referenced
this pull request
Apr 15, 2026
VLM prompts are shaped [1, n]. TokenRing.loadPrompt called dim(0) which returns 1 (batch size) instead of n (sequence length). This caused the ring buffer to be sized as (n + capacity - 1) while positions stays at (capacity,). The first append() call then hit MLX.where broadcasting (capacity,) against (n + capacity - 1,) → fatal broadcast_shapes crash. Affects all VLM models: Gemma4, Qwen3-VL, Qwen2.5-VL, GLM-OCR, etc. Fix: flatten prompt with reshaped(-1) before dim(0) so n is always the true sequence length regardless of whether the input is 1-D or 2-D. Original fix discovered independently by: - spokvulcan: ml-explore#170 (open PR) - atdrendel: shareup#8 (merged to shareup fork) - bulentongun: Satchitananda/mlx-swift-lm commit 9169796 Fixes ml-explore#168, ml-explore#191 Co-Authored-By: spokvulcan <spokvulcan@users.noreply.github.com> Co-Authored-By: atdrendel <atdrendel@users.noreply.github.com> Co-Authored-By: bulentongun <bulentongun@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Qwen3.5 VLM broadcast crash fixes
These changes fix two crashes that affect VLM (Vision-Language Model) inference
with the Qwen3.5 model family. Both crashes produce
[broadcast_shapes]errorsbut have different root causes.
Related upstream issues and PRs:
TokenRing.loadPromptWhat happens when you ask a model to generate text
When your app sends a message to Qwen3.5, the text goes through several stages
before the model can work with it. The model doesn't understand words — it
understands numbers. So first, your text is broken into tokens — small
pieces of text (sometimes whole words, sometimes parts of words, sometimes
punctuation) — and each token is mapped to a number. The list of all tokens a
model knows is called its vocabulary. Qwen3.5 has a vocabulary of 248,320
tokens.
So if your prompt is "Tell me a joke", the tokenizer might turn that into
something like
[23998, 757, 264, 22458]— four numbers representing fourtokens.
How VLM models prepare those tokens
Qwen3.5 is a VLM — a Vision-Language Model. It can process both images
and text. Even when you're only sending text, the model's code is built to
handle both.
The VLM's input preparation creates the token array with a batch dimension.
A "dimension" here just means one axis of a multi-dimensional array. Think of it
like a spreadsheet:
[3187]is a single row of 3,187 numbers[1, 3187]is a spreadsheet with 1 row and 3,187 columnsThe VLM adds that extra dimension (the "1") because the model is designed to
potentially process multiple inputs at once (a "batch"). In practice, there's
just one input, so the batch size is 1. This is done with
expandedDimensions(axis: 0), which wraps the flat list in an outer container.What happens during token generation
The model generates text one token at a time. For each new token, it:
are called logits. Higher logits mean the model thinks that token is more
likely to come next. Your model produces 248,320 logits (one per vocabulary
entry).
strategies — pick the highest score ("argmax"), or randomly pick weighted by
the scores ("categorical"). The sampler returns a single number: the chosen
token ID.
Penalty processors and the TokenRing
Penalty processors are optional features that discourage the model from
repeating itself. Before the sampler picks a token, the penalty processor looks
at recent tokens the model has already generated and lowers the logits for those
tokens. This makes the model less likely to repeat them.
To track "recent tokens," each penalty processor uses a
TokenRing— acircular buffer that holds the last N token IDs (where N is the
contextSize,typically 20). Think of it as a fixed-size list that overwrites old entries when
it fills up.
The
TokenRinghas two key operations:loadPrompt— called once at the start, loads the prompt tokens into thering so that penalty processors know which tokens were in the original prompt.
append— called after each generated token, adds the new token to thering.
Fix 1: TokenRing buffer corruption from 2D prompts
The code path to the crash
Here's the exact sequence, matching the stack trace from the crash:
TokenIterator.initcreates aTokenIteratorand callsprepare.TokenIterator.preparedoes two things. First, it loads the prompt tokensinto the penalty processor's ring buffer:
Then it asks the model to process the full prompt. Since
Qwen35is a VLM, itspreparemethod processes everything at once and returns.logits— meaning"here are the scores, you figure out the first token."
Back in
TokenIterator.prepare, this hits the.logitscase, which callsconvertToToken.convertToTokentakes the model's logits and turns them into a singletoken:
The
processor?.didSample(token: y)call forwards to each penalty context,which calls
ring.append(token).TokenRing.append— this is where the crash happens:MLX.where(mask, token, buffer)needs all three arrays to have compatibleshapes.
maskis[20],tokenis[1], butbufferis[3206]. You can'tbroadcast
[20]and[3206]together — they're different sizes and neither is1 — so MLX crashes with:
How the buffer got corrupted
The corruption happened when
processor?.prompt(input.text.tokens)was called.This eventually calls
TokenRing.loadPromptwith the prompt tokens.The tokens have shape
[1, 3187](2D, because the VLM added a batchdimension). Here's what the old
loadPromptcode did with that:dim(0)returns the size of the first dimension. For[1, 3187], the firstdimension is the batch — size 1. So
n = 1, not 3187.The code then thinks there's only 1 token in the prompt, enters the "small
prompt" branch, and:
reshaped(-1)which flattens all 3,187 tokens into a 1D array[3187] + [19]=[3206]The buffer, which should be
[20](the ring capacity), becomes[3206].The number 3206 in the error is literally the prompt length (3187) plus the
padding (19).
The fix
Flatten the prompt to 1D before measuring its length:
Now
dim(0)returns 3187 regardless of whether the input was[3187]or[1, 3187]. The rest of the function correctly keeps only the last 20 tokensin the buffer.
Defensive fix in
appendThe
appendmethod now reshapes the token to a scalar (a single number with nodimensions) before using it in
MLX.where:A scalar broadcasts with any shape, so this can never cause a shape mismatch.
This is a safety net — if any future model or sampler returns a token in an
unexpected shape,
appendwill handle it gracefully.Fix 2: Stale position state between inferences
The problem
The Qwen3.5 VLM's language model stores two pieces of state on the model object
itself:
precomputedPositionIdsandropeDeltas. These are computed during thefirst inference to help the model understand where each token sits in the
sequence — its "position." They are cached so the model doesn't have to
recompute them on every token.
The problem is that these values persist between inference calls. If your app
runs a second inference reusing the same model but with a fresh KV cache (the
model's memory of what it has processed so far), the stale position values from
the first inference are still sitting on the model. They have the wrong shape
for the new input, causing a broadcast error in the attention layer.
The existing code only cleared this state conditionally — inside the
elsebranch of
prepare, which runs when there are no images. This left edge caseswhere stale state could survive into the next inference.
The fix
Move
resetPositionState()to the top ofQwen35.prepare, before anybranching, and remove the old conditional call:
This ensures position state is always fresh at the start of every inference,
regardless of whether images are present or whether a previous inference left
stale values behind.
The old conditional
resetPositionState()in theelsebranch was removedbecause it's now redundant — nothing between the new unconditional reset at the
top of
prepareand the old conditional one setsprecomputedPositionIdsorropeDeltas. Those fields are only set later, insidelanguageModel's forwardpass.
Summary of changes
Libraries/MLXLMCommon/Evaluate.swiftTokenRing.loadPrompt: flatten prompt to 1D before computing lengthLibraries/MLXLMCommon/Evaluate.swiftTokenRing.append: reshape token to scalar beforeMLX.whereLibraries/MLXVLM/Models/Qwen35.swiftQwen35.prepare: moveresetPositionState()to the top, remove old conditional callTests/MLXLMTests/SampleTests.swift