diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 65ef1fd93..c29c146b4 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -334,19 +334,21 @@ struct TokenRing { /// Bulk-load from a prompt. Keeps the last `capacity` tokens. mutating func loadPrompt(_ prompt: MLXArray) { - let n = prompt.dim(0) - let promptTokens = prompt.asType(.int32) + // Flatten first so dim(0) always returns the token count, + // even when the prompt has a batch dimension (e.g. VLM tokens are [1, seq_len]). + let promptTokens = prompt.reshaped(-1).asType(.int32) + let n = promptTokens.dim(0) if n <= capacity { if n < capacity { let padding = MLXArray.zeros([capacity - n], type: Int32.self) - buffer = concatenated([promptTokens.reshaped(-1), padding]) + buffer = concatenated([promptTokens, padding]) } else { - buffer = promptTokens.reshaped(-1) + buffer = promptTokens } count = n writeIndex = n % capacity } else { - buffer = promptTokens[(-capacity)...].reshaped(-1) + buffer = promptTokens[(-capacity)...] count = capacity writeIndex = 0 } @@ -355,7 +357,7 @@ struct TokenRing { /// Append a single token using GPU-only mask write (no CPU←GPU sync). mutating func append(_ token: MLXArray) { let mask = positions .== Int32(writeIndex) - buffer = MLX.where(mask, token.asType(.int32), buffer) + buffer = MLX.where(mask, token.reshaped().asType(.int32), buffer) writeIndex = (writeIndex + 1) % capacity count = min(count + 1, capacity) } diff --git a/Libraries/MLXVLM/Models/Qwen35.swift b/Libraries/MLXVLM/Models/Qwen35.swift index 8a3c3984a..2e1a16ce3 100644 --- a/Libraries/MLXVLM/Models/Qwen35.swift +++ b/Libraries/MLXVLM/Models/Qwen35.swift @@ -1106,6 +1106,8 @@ public class Qwen35: Module, VLMModel { cache: [any KVCache], windowSize _: Int? ) throws -> PrepareResult { + languageModel.resetPositionState() + let inputIds = input.text.tokens var pixelValues: MLXArray? @@ -1145,8 +1147,6 @@ public class Qwen35: Module, VLMModel { videoTokenIndex: config.videoTokenIndex ) inputEmbeddings = mergedEmbeds - } else { - languageModel.resetPositionState() } let typedCache = castCache(cache) diff --git a/Tests/MLXLMTests/SampleTests.swift b/Tests/MLXLMTests/SampleTests.swift index cb9e44162..82865bea8 100644 --- a/Tests/MLXLMTests/SampleTests.swift +++ b/Tests/MLXLMTests/SampleTests.swift @@ -204,4 +204,86 @@ public class SampleTests: XCTestCase { XCTAssertEqual(values[2], 0.0, accuracy: 1e-4) XCTAssertEqual(values[3], -0.5, accuracy: 1e-4) } + + // MARK: - 2D prompt handling (VLM models produce [1, seq_len] tokens) + + func testPresencePenaltyContextHandles2DPrompt() { + var processor = PresencePenaltyContext(presencePenalty: 0.5, presenceContextSize: 5) + // VLM processors create 2D tokens [1, seq_len] via expandedDimensions(axis: 0) + let prompt2D = MLXArray([0, 0, 0, 1, 1]).expandedDimensions(axis: 0) + processor.prompt(prompt2D) + + let logits = MLXArray.zeros([1, 4], type: Float.self) + let processed = processor.process(logits: logits) + let values = processed[0].asArray(Float.self) + + // Same expected values as testPresencePenaltyContextPenalizesUniqueSeenTokens + XCTAssertEqual(values[0], -0.5, accuracy: 1e-6) + XCTAssertEqual(values[1], -0.5, accuracy: 1e-6) + XCTAssertEqual(values[2], 0.0, accuracy: 1e-6) + XCTAssertEqual(values[3], 0.0, accuracy: 1e-6) + } + + func testPresencePenaltyDidSampleWithBatchedToken() { + var processor = PresencePenaltyContext(presencePenalty: 0.5, presenceContextSize: 5) + processor.prompt(MLXArray([0, 0, 0, 1, 1])) + + // Sampler returns shape [1] (not scalar) for [1, vocab_size] input + let token = MLXArray([2]) + processor.didSample(token: token) + + let logits = MLXArray.zeros([1, 4], type: Float.self) + let processed = processor.process(logits: logits) + let values = processed[0].asArray(Float.self) + + // Token 2 should now also be penalized + XCTAssertEqual(values[0], -0.5, accuracy: 1e-6) + XCTAssertEqual(values[1], -0.5, accuracy: 1e-6) + XCTAssertEqual(values[2], -0.5, accuracy: 1e-6) + XCTAssertEqual(values[3], 0.0, accuracy: 1e-6) + } + + func testRepetitionContextHandlesLong2DPrompt() { + var processor = RepetitionContext(repetitionPenalty: 2.0, repetitionContextSize: 3) + // 2D prompt with more tokens than context size + let prompt2D = MLXArray([10, 20, 30, 40, 50]).expandedDimensions(axis: 0) + processor.prompt(prompt2D) + + // Only last 3 tokens (30, 40, 50) should be in the ring + let logits = MLXArray.ones([1, 100], type: Float.self) + let processed = processor.process(logits: logits) + + // Positive logits divided by penalty: 1.0 / 2.0 = 0.5 + XCTAssertEqual(processed[0, 30].item(Float.self), 0.5, accuracy: 1e-6) + XCTAssertEqual(processed[0, 40].item(Float.self), 0.5, accuracy: 1e-6) + XCTAssertEqual(processed[0, 50].item(Float.self), 0.5, accuracy: 1e-6) + // Token 10 should NOT be penalized (outside context window) + XCTAssertEqual(processed[0, 10].item(Float.self), 1.0, accuracy: 1e-6) + } + + func testPenaltyProcessorComposesWith2DPrompt() { + var processor = GenerateParameters( + repetitionPenalty: 1.5, repetitionContextSize: 5, + presencePenalty: 0.5, presenceContextSize: 5, + frequencyPenalty: 0.25, frequencyContextSize: 5 + ).processor() + XCTAssertNotNil(processor) + + // 2D prompt as produced by VLM processors + let prompt2D = MLXArray([0, 0, 0, 1, 1]).expandedDimensions(axis: 0) + processor?.prompt(prompt2D) + let logits = MLXArray([1.0 as Float, 0.5 as Float, 0.0 as Float, -0.5 as Float])[ + .newAxis, .ellipsis + ] + let processed = processor?.process(logits: logits) + guard let values = processed?[0].asArray(Float.self) else { + XCTFail("Expected processed logits") + return + } + // Same expected values as testGenerateParametersPenaltyProcessorComposesPenaltiesInOrder + XCTAssertEqual(values[0], -0.5833, accuracy: 1e-4) + XCTAssertEqual(values[1], -0.6667, accuracy: 1e-4) + XCTAssertEqual(values[2], 0.0, accuracy: 1e-4) + XCTAssertEqual(values[3], -0.5, accuracy: 1e-4) + } }