Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,21 @@ struct TokenRing {

/// Bulk-load from a prompt. Keeps the last `capacity` tokens.
mutating func loadPrompt(_ prompt: MLXArray) {
let n = prompt.dim(0)
let promptTokens = prompt.asType(.int32)
// Flatten first so dim(0) always returns the token count,
// even when the prompt has a batch dimension (e.g. VLM tokens are [1, seq_len]).
let promptTokens = prompt.reshaped(-1).asType(.int32)
let n = promptTokens.dim(0)
if n <= capacity {
if n < capacity {
let padding = MLXArray.zeros([capacity - n], type: Int32.self)
buffer = concatenated([promptTokens.reshaped(-1), padding])
buffer = concatenated([promptTokens, padding])
} else {
buffer = promptTokens.reshaped(-1)
buffer = promptTokens
}
count = n
writeIndex = n % capacity
} else {
buffer = promptTokens[(-capacity)...].reshaped(-1)
buffer = promptTokens[(-capacity)...]
count = capacity
writeIndex = 0
}
Expand All @@ -355,7 +357,7 @@ struct TokenRing {
/// Append a single token using GPU-only mask write (no CPU←GPU sync).
mutating func append(_ token: MLXArray) {
let mask = positions .== Int32(writeIndex)
buffer = MLX.where(mask, token.asType(.int32), buffer)
buffer = MLX.where(mask, token.reshaped().asType(.int32), buffer)
writeIndex = (writeIndex + 1) % capacity
count = min(count + 1, capacity)
}
Expand Down
4 changes: 2 additions & 2 deletions Libraries/MLXVLM/Models/Qwen35.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,8 @@ public class Qwen35: Module, VLMModel {
cache: [any KVCache],
windowSize _: Int?
) throws -> PrepareResult {
languageModel.resetPositionState()

let inputIds = input.text.tokens

var pixelValues: MLXArray?
Expand Down Expand Up @@ -1145,8 +1147,6 @@ public class Qwen35: Module, VLMModel {
videoTokenIndex: config.videoTokenIndex
)
inputEmbeddings = mergedEmbeds
} else {
languageModel.resetPositionState()
}

let typedCache = castCache(cache)
Expand Down
82 changes: 82 additions & 0 deletions Tests/MLXLMTests/SampleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,86 @@ public class SampleTests: XCTestCase {
XCTAssertEqual(values[2], 0.0, accuracy: 1e-4)
XCTAssertEqual(values[3], -0.5, accuracy: 1e-4)
}

// MARK: - 2D prompt handling (VLM models produce [1, seq_len] tokens)

func testPresencePenaltyContextHandles2DPrompt() {
var processor = PresencePenaltyContext(presencePenalty: 0.5, presenceContextSize: 5)
// VLM processors create 2D tokens [1, seq_len] via expandedDimensions(axis: 0)
let prompt2D = MLXArray([0, 0, 0, 1, 1]).expandedDimensions(axis: 0)
processor.prompt(prompt2D)

let logits = MLXArray.zeros([1, 4], type: Float.self)
let processed = processor.process(logits: logits)
let values = processed[0].asArray(Float.self)

// Same expected values as testPresencePenaltyContextPenalizesUniqueSeenTokens
XCTAssertEqual(values[0], -0.5, accuracy: 1e-6)
XCTAssertEqual(values[1], -0.5, accuracy: 1e-6)
XCTAssertEqual(values[2], 0.0, accuracy: 1e-6)
XCTAssertEqual(values[3], 0.0, accuracy: 1e-6)
}

func testPresencePenaltyDidSampleWithBatchedToken() {
var processor = PresencePenaltyContext(presencePenalty: 0.5, presenceContextSize: 5)
processor.prompt(MLXArray([0, 0, 0, 1, 1]))

// Sampler returns shape [1] (not scalar) for [1, vocab_size] input
let token = MLXArray([2])
processor.didSample(token: token)

let logits = MLXArray.zeros([1, 4], type: Float.self)
let processed = processor.process(logits: logits)
let values = processed[0].asArray(Float.self)

// Token 2 should now also be penalized
XCTAssertEqual(values[0], -0.5, accuracy: 1e-6)
XCTAssertEqual(values[1], -0.5, accuracy: 1e-6)
XCTAssertEqual(values[2], -0.5, accuracy: 1e-6)
XCTAssertEqual(values[3], 0.0, accuracy: 1e-6)
}

func testRepetitionContextHandlesLong2DPrompt() {
var processor = RepetitionContext(repetitionPenalty: 2.0, repetitionContextSize: 3)
// 2D prompt with more tokens than context size
let prompt2D = MLXArray([10, 20, 30, 40, 50]).expandedDimensions(axis: 0)
processor.prompt(prompt2D)

// Only last 3 tokens (30, 40, 50) should be in the ring
let logits = MLXArray.ones([1, 100], type: Float.self)
let processed = processor.process(logits: logits)

// Positive logits divided by penalty: 1.0 / 2.0 = 0.5
XCTAssertEqual(processed[0, 30].item(Float.self), 0.5, accuracy: 1e-6)
XCTAssertEqual(processed[0, 40].item(Float.self), 0.5, accuracy: 1e-6)
XCTAssertEqual(processed[0, 50].item(Float.self), 0.5, accuracy: 1e-6)
// Token 10 should NOT be penalized (outside context window)
XCTAssertEqual(processed[0, 10].item(Float.self), 1.0, accuracy: 1e-6)
}

func testPenaltyProcessorComposesWith2DPrompt() {
var processor = GenerateParameters(
repetitionPenalty: 1.5, repetitionContextSize: 5,
presencePenalty: 0.5, presenceContextSize: 5,
frequencyPenalty: 0.25, frequencyContextSize: 5
).processor()
XCTAssertNotNil(processor)

// 2D prompt as produced by VLM processors
let prompt2D = MLXArray([0, 0, 0, 1, 1]).expandedDimensions(axis: 0)
processor?.prompt(prompt2D)
let logits = MLXArray([1.0 as Float, 0.5 as Float, 0.0 as Float, -0.5 as Float])[
.newAxis, .ellipsis
]
let processed = processor?.process(logits: logits)
guard let values = processed?[0].asArray(Float.self) else {
XCTFail("Expected processed logits")
return
}
// Same expected values as testGenerateParametersPenaltyProcessorComposesPenaltiesInOrder
XCTAssertEqual(values[0], -0.5833, accuracy: 1e-4)
XCTAssertEqual(values[1], -0.6667, accuracy: 1e-4)
XCTAssertEqual(values[2], 0.0, accuracy: 1e-4)
XCTAssertEqual(values[3], -0.5, accuracy: 1e-4)
}
}