Fix #157 and #168 #8

Merged
atdrendel merged 1 commit into main from fix-168-and-157 on Apr 5, 2026
@atdrendel

Qwen3.5 VLM broadcast crash fixes

These changes fix two crashes that affect VLM (Vision-Language Model) inference
with the Qwen3.5 model family. Both crashes produce [broadcast_shapes] errors
but have different root causes.

Related upstream issues and PRs:

What happens when you ask a model to generate text

When your app sends a message to Qwen3.5, the text goes through several stages
before the model can work with it. The model doesn't understand words — it
understands numbers. So first, your text is broken into tokens — small
pieces of text (sometimes whole words, sometimes parts of words, sometimes
punctuation) — and each token is mapped to a number. The list of all tokens a
model knows is called its vocabulary. Qwen3.5 has a vocabulary of 248,320
tokens.

So if your prompt is "Tell me a joke", the tokenizer might turn that into
something like [23998, 757, 264, 22458] — four numbers representing four
tokens.
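
As a rough illustration of the text-to-numbers step, here is a minimal sketch of a greedy longest-match tokenizer. The four-entry vocabulary and the function name `toyTokenize` are made up for this example; the real Qwen3.5 tokenizer works over a 248,320-entry vocabulary and is not reproduced here.

```swift
// Toy vocabulary: these entries and IDs are hypothetical, chosen to match
// the "something like" example in the text above.
let toyVocabulary: [String: Int] = [
    "Tell": 23998, " me": 757, " a": 264, " joke": 22458,
]

/// Greedy longest-match tokenizer over the toy vocabulary.
/// Returns nil if some part of the text cannot be matched.
func toyTokenize(_ text: String) -> [Int]? {
    var ids: [Int] = []
    var rest = Substring(text)
    while !rest.isEmpty {
        // Pick the longest vocabulary entry that prefixes the remaining text.
        let candidates = toyVocabulary.keys.filter { rest.hasPrefix($0) }
        guard let match = candidates.max(by: { $0.count < $1.count }),
              let id = toyVocabulary[match] else { return nil }
        ids.append(id)
        rest = rest.dropFirst(match.count)
    }
    return ids
}

print(toyTokenize("Tell me a joke") ?? [])   // [23998, 757, 264, 22458]
```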

How VLMs prepare those tokens

Qwen3.5 is a VLM — a Vision-Language Model. It can process both images
and text. Even when you're only sending text, the model's code is built to
handle both.

The VLM's input preparation creates the token array with a batch dimension.
A "dimension" here just means one axis of a multi-dimensional array. Think of it
like a spreadsheet:

  • A 1D array [3187] is a single row of 3,187 numbers
  • A 2D array [1, 3187] is a spreadsheet with 1 row and 3,187 columns

The VLM adds that extra dimension (the "1") because the model is designed to
potentially process multiple inputs at once (a "batch"). In practice, there's
just one input, so the batch size is 1. This is done with
expandedDimensions(axis: 0), which wraps the flat list in an outer container.
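
The difference between the two shapes can be sketched without MLX. `ToyArray` and `addingBatchAxis` below are hypothetical stand-ins for an MLX array and for `expandedDimensions(axis: 0)`; they only track the shape, which is all that matters for the crash discussed later.

```swift
/// A minimal stand-in for an array: flat values plus a logical shape.
struct ToyArray {
    var values: [Int32]
    var shape: [Int]

    /// Size of one axis, mirroring what the text calls dim(_:).
    func dim(_ axis: Int) -> Int { shape[axis] }

    /// Adds a leading axis of size 1 (the batch dimension).
    func addingBatchAxis() -> ToyArray {
        ToyArray(values: values, shape: [1] + shape)
    }
}

let flat = ToyArray(values: [Int32](repeating: 0, count: 3187), shape: [3187])
let batched = flat.addingBatchAxis()
print(flat.shape, batched.shape)   // [3187] [1, 3187]
print(batched.dim(0))              // 1 (the batch size, not the token count)
```

Note that the values are untouched; only the shape changes, so asking for the first axis now returns the batch size.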

What happens during token generation

The model generates text one token at a time. For each new token, it:

  1. Looks at all previous tokens.
  2. Produces a score for every token in the vocabulary. These scores
    are called logits. A higher logit means the model thinks that token is more
    likely to come next. Qwen3.5 produces 248,320 logits (one per vocabulary
    entry).
  3. Passes the logits to a sampler, which picks one token. Strategies differ:
    take the highest score ("argmax"), or draw at random weighted by
    the scores ("categorical"). The sampler returns a single number: the chosen
    token ID.
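
The two sampling strategies can be sketched in plain Swift over an array of logits. The function names here are illustrative and are not the library's sampler API; the softmax step is the standard way to turn logits into probabilities.

```swift
import Foundation

/// Index of the highest logit ("argmax" sampling).
func argmaxToken(_ logits: [Double]) -> Int {
    precondition(!logits.isEmpty)
    var best = 0
    for i in logits.indices where logits[i] > logits[best] { best = i }
    return best
}

/// Softmax: converts logits to probabilities that sum to 1.
/// Subtracting the maximum first keeps exp() from overflowing.
func softmax(_ logits: [Double]) -> [Double] {
    let maxLogit = logits.max() ?? 0
    let exps = logits.map { exp($0 - maxLogit) }
    let total = exps.reduce(0, +)
    return exps.map { $0 / total }
}

/// "Categorical" sampling: draws an index with probability given by softmax.
func categoricalToken<G: RandomNumberGenerator>(
    _ logits: [Double], using rng: inout G
) -> Int {
    let probs = softmax(logits)
    let r = Double.random(in: 0..<1, using: &rng)
    var cumulative = 0.0
    for (i, p) in probs.enumerated() {
        cumulative += p
        if r < cumulative { return i }
    }
    return probs.count - 1  // guards against floating-point rounding
}

print(argmaxToken([0.1, 2.5, -1.0]))   // 1
```

Either way the result is one integer, the chosen token ID.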

Penalty processors and the TokenRing

Penalty processors are optional features that discourage the model from
repeating itself. Before the sampler picks a token, the penalty processor looks
at recent tokens the model has already generated and lowers the logits for those
tokens. This makes the model less likely to repeat them.
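
One common formulation of a repetition penalty is sketched below: positive logits are divided by the penalty and negative logits are multiplied by it, so both move toward "less likely". This PR does not state which rule the library's processors apply, so treat the formula and the name `applyRepetitionPenalty` as assumptions.

```swift
/// Lowers the logits of recently seen tokens.
/// Assumes penalty > 1; token IDs outside the logits range are ignored.
func applyRepetitionPenalty(
    to logits: [Double], recentTokens: [Int], penalty: Double
) -> [Double] {
    var result = logits
    for token in Set(recentTokens) where logits.indices.contains(token) {
        let value = logits[token]
        // Dividing a positive score or multiplying a negative one both reduce it.
        result[token] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

let penalized = applyRepetitionPenalty(
    to: [2.0, -1.0, 0.5], recentTokens: [0, 1, 0], penalty: 2.0)
print(penalized)   // [1.0, -2.0, 0.5]
```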

To track "recent tokens," each penalty processor uses a TokenRing — a
circular buffer that holds the last N token IDs (where N is the contextSize,
typically 20). Think of it as a fixed-size list that overwrites old entries when
it fills up.

The TokenRing has two key operations:

  • loadPrompt — called once at the start, loads the prompt tokens into the
    ring so that penalty processors know which tokens were in the original prompt.
  • append — called after each generated token, adds the new token to the
    ring.
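
The two operations can be sketched with a plain Swift array. This is a simplified model, not the library's implementation (which stores the buffer as an MLX array and writes through a mask); `ToyTokenRing` and its fields are hypothetical names.

```swift
/// A fixed-capacity circular buffer of token IDs.
struct ToyTokenRing {
    private(set) var buffer: [Int32]
    private var writeIndex = 0
    private(set) var count = 0
    let capacity: Int

    init(capacity: Int) {
        precondition(capacity > 0)
        self.capacity = capacity
        self.buffer = [Int32](repeating: 0, count: capacity)
    }

    /// Keeps at most the last `capacity` prompt tokens, zero-padding the rest.
    mutating func loadPrompt(_ prompt: [Int32]) {
        let tail = Array(prompt.suffix(capacity))
        buffer = tail + [Int32](repeating: 0, count: capacity - tail.count)
        count = tail.count
        writeIndex = tail.count % capacity
    }

    /// Writes one token, overwriting the oldest slot once the ring is full.
    mutating func append(_ token: Int32) {
        buffer[writeIndex] = token
        writeIndex = (writeIndex + 1) % capacity
        count = min(count + 1, capacity)
    }
}

var ring = ToyTokenRing(capacity: 4)
ring.loadPrompt([1, 2, 3, 4, 5, 6])
print(ring.buffer)   // [3, 4, 5, 6]
ring.append(7)
print(ring.buffer)   // [7, 4, 5, 6]
```

The invariant to notice is that `buffer` always has exactly `capacity` entries. Fix 1 below is about a case where that invariant was broken.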

Fix 1: TokenRing buffer corruption from 2D prompts

The code path to the crash

Here's the exact sequence, matching the stack trace from the crash:

TokenIterator.init creates a TokenIterator and calls prepare.

TokenIterator.prepare does two things. First, it loads the prompt tokens
into the penalty processor's ring buffer:

// Libraries/MLXLMCommon/Evaluate.swift
processor?.prompt(input.text.tokens)   // loads tokens into TokenRing

Then it asks the model to process the full prompt. Since Qwen35 is a VLM, its
prepare method processes everything at once and returns .logits — meaning
"here are the scores, you figure out the first token."

Back in TokenIterator.prepare, this hits the .logits case, which calls
convertToToken.

convertToToken takes the model's logits and turns them into a single
token:

// Libraries/MLXLMCommon/Evaluate.swift
mutating func convertToToken(logits: MLXArray) -> MLXArray {
    var logits = logits[0..., -1, 0...]
    logits = processor?.process(logits: logits) ?? logits
    let y = sampler.sample(logits: logits)
    processor?.didSample(token: y)       // tells the ring about the new token
    return y
}

The processor?.didSample(token: y) call forwards to each penalty context,
which calls ring.append(token).

TokenRing.append — this is where the crash happens:

// Libraries/MLXLMCommon/Evaluate.swift
mutating func append(_ token: MLXArray) {
    let mask = positions .== Int32(writeIndex)
    buffer = MLX.where(mask, token.asType(.int32), buffer)  // CRASH
    ...
}

MLX.where(mask, token, buffer) needs all three arrays to have compatible
shapes. mask is [20], token is [1], but buffer is [3206]. You can't
broadcast [20] and [3206] together — they're different sizes and neither is
1 — so MLX crashes with:

[broadcast_shapes] Shapes (20) and (3206) cannot be broadcast.
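
The compatibility rule behind that error can be written as a small shape check. The sketch below follows the NumPy-style broadcasting rule (align shapes from the trailing axis; each pair of sizes must match or one must be 1); `broadcastShape` is a name made up for this example, not an MLX function.

```swift
/// Returns the broadcast shape of `a` and `b`, or nil if they are incompatible.
func broadcastShape(_ a: [Int], _ b: [Int]) -> [Int]? {
    var result: [Int] = []
    let rank = max(a.count, b.count)
    for i in 0..<rank {
        // Walk both shapes from the last axis; missing axes count as size 1.
        let da = i < a.count ? a[a.count - 1 - i] : 1
        let db = i < b.count ? b[b.count - 1 - i] : 1
        if da == db || db == 1 {
            result.append(da)
        } else if da == 1 {
            result.append(db)
        } else {
            return nil
        }
    }
    return Array(result.reversed())
}

print(broadcastShape([20], [1]) as Any)      // Optional([20])
print(broadcastShape([20], [3206]) as Any)   // nil
print(broadcastShape([], [20]) as Any)       // Optional([20]): a scalar fits any shape
```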

How the buffer got corrupted

The corruption happened when processor?.prompt(input.text.tokens) was called.
This eventually calls TokenRing.loadPrompt with the prompt tokens.

The tokens have shape [1, 3187] (2D, because the VLM added a batch
dimension). Here's what the old loadPrompt code did with that:

let n = prompt.dim(0)

dim(0) returns the size of the first dimension. For [1, 3187], the first
dimension is the batch — size 1. So n = 1, not 3187.

The code then thinks there's only 1 token in the prompt, enters the "small
prompt" branch, and:

  1. Calls reshaped(-1) which flattens all 3,187 tokens into a 1D array
  2. Creates 19 zeros of padding (capacity 20 minus the supposed 1 token)
  3. Concatenates them: [3187] + [19] = [3206]

The buffer, which should be [20] (the ring capacity), becomes [3206].
The number 3206 in the error is literally the prompt length (3187) plus the
padding (19).
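
The arithmetic can be checked with a shape-only model of the old small-prompt branch. `oldBufferLength` is a hypothetical helper that follows the three steps described above; the behavior of the large-prompt branch is an assumption, since this PR does not show it.

```swift
/// Buffer length produced by the old loadPrompt logic, per the steps above.
/// `shape` is the prompt's shape; `capacity` is the ring size.
func oldBufferLength(promptShape shape: [Int], capacity: Int) -> Int {
    let n = shape[0]                       // old code: size of the first axis
    let totalTokens = shape.reduce(1, *)   // what flattening actually yields
    if n < capacity {
        let padding = capacity - n         // zeros appended after the tokens
        return totalTokens + padding
    }
    return capacity  // assumption: the large-prompt branch keeps `capacity` tokens
}

print(oldBufferLength(promptShape: [1, 3187], capacity: 20))   // 3206
print(oldBufferLength(promptShape: [5], capacity: 20))         // 20
```

A genuinely short 1D prompt still ends up with a 20-entry buffer; only the batched shape slips through.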

The fix

Flatten the prompt to 1D before measuring its length:

let promptTokens = prompt.reshaped(-1).asType(.int32)
let n = promptTokens.dim(0)

Now dim(0) returns 3187 regardless of whether the input was [3187] or
[1, 3187]. The rest of the function correctly keeps only the last 20 tokens
in the buffer.
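
The same idea in plain Swift, with a nested array standing in for the 2D prompt and `flatMap` playing the role of `reshaped(-1)`. `fixedLoadPrompt` is an illustrative name, and the padding and truncation details are assumptions consistent with the description above.

```swift
/// Builds the ring buffer contents from a batched prompt.
/// Flattens first, so the length is the true token count.
func fixedLoadPrompt(_ prompt: [[Int32]], capacity: Int) -> [Int32] {
    let promptTokens = prompt.flatMap { $0 }   // [1, n] becomes [n]
    let n = promptTokens.count
    if n < capacity {
        // Short prompt: keep every token and zero-pad up to capacity.
        return promptTokens + [Int32](repeating: 0, count: capacity - n)
    }
    // Long prompt: keep only the most recent `capacity` tokens.
    return Array(promptTokens.suffix(capacity))
}

let batchedPrompt: [[Int32]] = [(0..<3187).map { Int32($0) }]
let loaded = fixedLoadPrompt(batchedPrompt, capacity: 20)
print(loaded.count)   // 20
```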

Defensive fix in append

The append method now reshapes the token to a scalar (a single number with no
dimensions) before using it in MLX.where:

buffer = MLX.where(mask, token.reshaped().asType(.int32), buffer)

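A scalar broadcasts with any shape, so the token can no longer cause a shape
mismatch in MLX.where, provided it holds exactly one element (reshaping to a
scalar requires that). This is a safety net: if any future model or sampler
returns a single token in an unexpected shape such as [1] or [1, 1], append
will handle it gracefully.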

Fix 2: Stale position state between inferences

The problem

The Qwen3.5 VLM's language model stores two pieces of state on the model object
itself: precomputedPositionIds and ropeDeltas. These are computed during the
first inference to help the model understand where each token sits in the
sequence — its "position." They are cached so the model doesn't have to
recompute them on every token.

The problem is that these values persist between inference calls. If your app
runs a second inference reusing the same model but with a fresh KV cache (the
model's memory of what it has processed so far), the stale position values from
the first inference are still sitting on the model. They have the wrong shape
for the new input, causing a broadcast error in the attention layer.

The existing code only cleared this state conditionally — inside the else
branch of prepare, which runs when there are no images. This left edge cases
where stale state could survive into the next inference.

The fix

Move resetPositionState() to the top of Qwen35.prepare, before any
branching, and remove the old conditional call:

public func prepare(...) throws -> PrepareResult {
    languageModel.resetPositionState()

    let inputIds = input.text.tokens
    ...

This ensures position state is always fresh at the start of every inference,
regardless of whether images are present or whether a previous inference left
stale values behind.

The old conditional resetPositionState() in the else branch was removed
because it's now redundant — nothing between the new unconditional reset at the
top of prepare and the old conditional one sets precomputedPositionIds or
ropeDeltas. Those fields are only set later, inside languageModel's forward
pass.
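
The failure mode and the fix can be reproduced with a toy model that caches position data on itself. Everything below is a hypothetical sketch: `ToyLanguageModel`, `ToyVLM`, and the `resetFirst` flag exist only to contrast the two behaviors and do not mirror the real `Qwen35` types.

```swift
final class ToyLanguageModel {
    // Stand-ins for precomputedPositionIds and ropeDeltas.
    var cachedPositionIds: [Int]? = nil
    var cachedDelta: Int? = nil

    func resetPositionState() {
        cachedPositionIds = nil
        cachedDelta = nil
    }

    /// Computes position IDs on first use, then reuses the cached values.
    func positionIds(forTokenCount count: Int) -> [Int] {
        if let cached = cachedPositionIds { return cached }
        let ids = Array(0..<count)
        cachedPositionIds = ids
        cachedDelta = 0
        return ids
    }
}

struct ToyVLM {
    let languageModel = ToyLanguageModel()

    /// `resetFirst: true` models the fix: clear cached state before anything else.
    func prepare(tokens: [Int], resetFirst: Bool) -> [Int] {
        if resetFirst { languageModel.resetPositionState() }
        return languageModel.positionIds(forTokenCount: tokens.count)
    }
}

let vlm = ToyVLM()
let first = vlm.prepare(tokens: [1, 2, 3, 4, 5], resetFirst: true)
let stale = vlm.prepare(tokens: [1, 2, 3], resetFirst: false)
let fresh = vlm.prepare(tokens: [1, 2, 3], resetFirst: true)
print(first.count, stale.count, fresh.count)   // 5 5 3
```

Without the reset, the second call gets five position IDs for a three-token input, which is the kind of shape disagreement that surfaces as a broadcast error downstream.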

Summary of changes

| File | Change |
| --- | --- |
| Libraries/MLXLMCommon/Evaluate.swift | TokenRing.loadPrompt: flatten prompt to 1D before computing length |
| Libraries/MLXLMCommon/Evaluate.swift | TokenRing.append: reshape token to scalar before MLX.where |
| Libraries/MLXVLM/Models/Qwen35.swift | Qwen35.prepare: move resetPositionState() to the top, remove old conditional call |
| Tests/MLXLMTests/SampleTests.swift | 4 new tests for 2D prompts and non-scalar tokens |

Copilot AI review requested due to automatic review settings April 5, 2026 20:38

Copilot AI left a comment

Pull request overview

Fixes two [broadcast_shapes] crash scenarios affecting Qwen3.5 VLM inference by normalizing token shapes for penalty processing and ensuring per-inference position state is reset in the Qwen3.5 language model.

Changes:

  • Flatten prompt tokens in TokenRing.loadPrompt so 2D VLM prompts ([1, seq_len]) are handled correctly.
  • Reshape sampled tokens to a scalar in TokenRing.append to avoid MLX.where broadcast mismatches.
  • Reset Qwen3.5 position state at the start of every prepare call and add targeted regression tests.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| Libraries/MLXLMCommon/Evaluate.swift | Makes TokenRing robust to batched prompts and non-scalar sampled tokens to prevent broadcast crashes. |
| Libraries/MLXVLM/Models/Qwen35.swift | Resets cached position-related state unconditionally per inference to avoid stale-shape crashes across requests. |
| Tests/MLXLMTests/SampleTests.swift | Adds regression tests covering 2D prompts and batched sampled tokens across penalty processors. |

@atdrendel atdrendel merged commit 22deba9 into main Apr 5, 2026
6 checks passed
Satchitananda pushed a commit to Satchitananda/mlx-swift-lm that referenced this pull request Apr 15, 2026
VLM prompts are shaped [1, n]. TokenRing.loadPrompt called dim(0) which
returns 1 (batch size) instead of n (sequence length). This caused the
ring buffer to be sized as (n + capacity - 1) while positions stays at
(capacity,). The first append() call then hit MLX.where broadcasting
(capacity,) against (n + capacity - 1,) → fatal broadcast_shapes crash.

Affects all VLM models: Gemma4, Qwen3-VL, Qwen2.5-VL, GLM-OCR, etc.

Fix: flatten prompt with reshaped(-1) before dim(0) so n is always the
true sequence length regardless of whether the input is 1-D or 2-D.

Original fix discovered independently by:
- spokvulcan: ml-explore#170 (open PR)
- atdrendel: shareup#8 (merged to shareup fork)
- bulentongun: Satchitananda/mlx-swift-lm commit 9169796

Fixes ml-explore#168, ml-explore#191

Co-Authored-By: spokvulcan <spokvulcan@users.noreply.github.com>
Co-Authored-By: atdrendel <atdrendel@users.noreply.github.com>
Co-Authored-By: bulentongun <bulentongun@users.noreply.github.com>