Skip to content

Commit caeed1b

Browse files
committed
Add support for Mistral3 tool calling
1 parent 78a2457 commit caeed1b

6 files changed

Lines changed: 311 additions & 6 deletions

File tree

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,8 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
10801080
}
10811081
}
10821082

1083+
handler.onGenerationEnd(emit: continuation.yield)
1084+
10831085
let now = Date.timeIntervalSinceReferenceDate
10841086
let generateTime = now - start
10851087

@@ -1292,6 +1294,11 @@ private protocol TokenLoopHandler: Sendable {
12921294
emit: (sending Output) -> AsyncStream<Output>.Continuation.YieldResult
12931295
) -> Bool
12941296

1297+
/// Called after the token loop finishes, before the info event.
1298+
mutating func onGenerationEnd(
1299+
emit: (sending Output) -> AsyncStream<Output>.Continuation.YieldResult
1300+
)
1301+
12951302
func infoEvent(_ info: GenerateCompletionInfo) -> Output
12961303
}
12971304

@@ -1337,6 +1344,16 @@ private struct TextToolTokenLoopHandler: TokenLoopHandler, @unchecked Sendable {
13371344
true
13381345
}
13391346

1347+
mutating func onGenerationEnd(
1348+
emit: (sending Generation) -> AsyncStream<Generation>.Continuation.YieldResult
1349+
) {
1350+
toolCallProcessor.flush()
1351+
while let toolCall = toolCallProcessor.toolCalls.first {
1352+
toolCallProcessor.toolCalls.removeFirst()
1353+
_ = emit(.toolCall(toolCall))
1354+
}
1355+
}
1356+
13401357
func infoEvent(_ info: GenerateCompletionInfo) -> Generation {
13411358
.info(info)
13421359
}
@@ -1365,6 +1382,10 @@ private struct RawTokenLoopHandler: TokenLoopHandler {
13651382
return true
13661383
}
13671384

1385+
mutating func onGenerationEnd(
1386+
emit: (sending TokenGeneration) -> AsyncStream<TokenGeneration>.Continuation.YieldResult
1387+
) {}
1388+
13681389
func infoEvent(_ info: GenerateCompletionInfo) -> TokenGeneration {
13691390
.info(info)
13701391
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
/// Parser for Mistral V13 tool call format: `[TOOL_CALLS]name[ARGS]{"json_args"}`
6+
///
7+
/// This format is used by Mistral3/Ministral-3 2512 models and Devstral 2.
8+
/// The special tokens `[TOOL_CALLS]` (token ID 9) and `[ARGS]` (ID 32) are used
9+
/// as delimiters. Multiple tool calls use repeated `[TOOL_CALLS]` tokens.
10+
///
11+
/// Also handles the older V11 format which includes an optional `[CALL_ID]`
12+
/// between the function name and `[ARGS]` (V13 does not use `[CALL_ID]`).
13+
///
14+
/// Examples:
15+
/// - `[TOOL_CALLS]get_weather[ARGS]{"location": "Tokyo"}`
16+
/// - `[TOOL_CALLS]fn1[ARGS]{...}[TOOL_CALLS]fn2[ARGS]{...}` (multiple calls)
17+
///
18+
/// The end tag is `</s>` (EOS token). Since stop tokens are intercepted at the
19+
/// token ID level before detokenization, the EOS text never reaches the processor
20+
/// — tool calls are extracted via `ToolCallProcessor.flush()` at generation end.
21+
public struct MistralToolCallParser: ToolCallParser, Sendable {
22+
public let startTag: String? = "[TOOL_CALLS]"
23+
public let endTag: String? = "</s>"
24+
25+
public init() {}
26+
27+
public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
28+
var text = content
29+
30+
// Strip [TOOL_CALLS] prefix if present
31+
if let range = text.range(of: "[TOOL_CALLS]") {
32+
text = String(text[range.upperBound...])
33+
}
34+
35+
text = text.trimmingCharacters(in: .whitespacesAndNewlines)
36+
37+
// Split on [ARGS] to get function name and arguments
38+
guard let argsRange = text.range(of: "[ARGS]") else {
39+
return nil
40+
}
41+
42+
var namePart = String(text[..<argsRange.lowerBound])
43+
.trimmingCharacters(in: .whitespacesAndNewlines)
44+
let argsPart = String(text[argsRange.upperBound...])
45+
.trimmingCharacters(in: .whitespacesAndNewlines)
46+
47+
// Handle optional [CALL_ID] between name and [ARGS]
48+
if let callIdRange = namePart.range(of: "[CALL_ID]") {
49+
namePart = String(namePart[..<callIdRange.lowerBound])
50+
.trimmingCharacters(in: .whitespacesAndNewlines)
51+
}
52+
53+
guard !namePart.isEmpty else { return nil }
54+
55+
// Parse arguments as JSON using deserialize from ParserUtilities
56+
let arguments = deserialize(argsPart)
57+
58+
guard let argsDict = arguments as? [String: any Sendable] else {
59+
return nil
60+
}
61+
62+
return ToolCall(
63+
function: ToolCall.Function(name: namePart, arguments: argsDict)
64+
)
65+
}
66+
}

Libraries/MLXLMCommon/Tool/ToolCallFormat.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
6666
/// Example: `<invoke name="f"><parameter name="k">v</parameter></invoke>`
6767
case minimaxM2 = "minimax_m2"
6868

69+
/// Mistral V11+ format with [TOOL_CALLS] and [ARGS] delimiters.
70+
/// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}`
71+
case mistral
72+
6973
// MARK: - Factory Methods
7074

7175
/// Create the appropriate parser for this format.
@@ -87,6 +91,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
8791
return KimiK2ToolCallParser()
8892
case .minimaxM2:
8993
return MiniMaxM2ToolCallParser()
94+
case .mistral:
95+
return MistralToolCallParser()
9096
}
9197
}
9298

@@ -115,6 +121,11 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
115121
return .gemma
116122
}
117123

124+
// Mistral3 family (mistral3, mistral3_text, etc.)
125+
if type.hasPrefix("mistral3") {
126+
return .mistral
127+
}
128+
118129
return nil
119130
}
120131
}

Libraries/MLXLMCommon/Tool/ToolCallProcessor.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,46 @@ public class ToolCallProcessor {
6767

6868
// MARK: - Public Methods
6969

70+
/// Flush any buffered content, attempting to parse it as tool call(s).
71+
///
72+
/// Call this when generation ends (e.g., on EOS token) to handle formats
73+
/// whose end tag is never delivered as text (e.g., Mistral where `</s>`
74+
/// is intercepted at the token ID level).
75+
///
76+
/// For formats with end tags that appear in the text stream, the buffer
77+
/// will already be empty at generation end, making this a no-op.
78+
public func flush() {
79+
guard state == .collectingToolCall || state == .potentialToolCall else { return }
80+
guard !toolCallBuffer.isEmpty else {
81+
state = .normal
82+
return
83+
}
84+
85+
if let startTag = parser.startTag {
86+
// Split buffer into individual tool call segments to handle
87+
// multiple tool calls (e.g., Mistral V13):
88+
// "[TOOL_CALLS]fn1[ARGS]{...}[TOOL_CALLS]fn2[ARGS]{...}"
89+
// → ["fn1[ARGS]{...}", "fn2[ARGS]{...}"]
90+
let segments =
91+
toolCallBuffer
92+
.components(separatedBy: startTag)
93+
.filter { !$0.isEmpty }
94+
95+
for segment in segments {
96+
if let toolCall = parser.parse(content: segment, tools: tools) {
97+
toolCalls.append(toolCall)
98+
}
99+
}
100+
} else {
101+
if let toolCall = parser.parse(content: toolCallBuffer, tools: tools) {
102+
toolCalls.append(toolCall)
103+
}
104+
}
105+
106+
toolCallBuffer = ""
107+
state = .normal
108+
}
109+
70110
/// Process a generated text chunk and extract any tool call content.
71111
/// - Parameter chunk: The text chunk to process
72112
/// - Returns: Regular text that should be displayed (non-tool call content), or `nil` if buffering

Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,15 @@ public class ToolCallIntegrationTests: XCTestCase {
213213

214214
// MARK: - Mistral3 Tests
215215

216-
func testMistral3ToolCallFormatDefaultsToJSON() async throws {
216+
func testMistral3ToolCallFormatAutoDetection() async throws {
217217
guard let container = Self.mistral3Container else {
218218
throw XCTSkip("Mistral3 model not available")
219219
}
220220

221221
let config = await container.configuration
222-
// Mistral3 uses the default JSON tool call format (infer returns nil)
223-
let format = config.toolCallFormat ?? .json
224222
XCTAssertEqual(
225-
format, .json,
226-
"Mistral3 model should use default .json tool call format"
223+
config.toolCallFormat, .mistral,
224+
"Mistral3 model should auto-detect .mistral tool call format"
227225
)
228226
}
229227

@@ -264,6 +262,69 @@ public class ToolCallIntegrationTests: XCTestCase {
264262
}
265263
}
266264

265+
func testMistral3MultipleToolCallGeneration() async throws {
266+
guard let container = Self.mistral3Container else {
267+
throw XCTSkip("Mistral3 model not available")
268+
}
269+
270+
let multiToolSchema: [[String: any Sendable]] =
271+
Self.weatherToolSchema + [
272+
[
273+
"type": "function",
274+
"function": [
275+
"name": "get_time",
276+
"description": "Get the current time in a given timezone",
277+
"parameters": [
278+
"type": "object",
279+
"properties": [
280+
"timezone": [
281+
"type": "string",
282+
"description":
283+
"The timezone, e.g. America/New_York, Asia/Tokyo",
284+
] as [String: any Sendable]
285+
] as [String: any Sendable],
286+
"required": ["timezone"],
287+
] as [String: any Sendable],
288+
] as [String: any Sendable],
289+
]
290+
]
291+
292+
let input = UserInput(
293+
chat: [
294+
.system(
295+
"You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed."
296+
),
297+
.user(
298+
"What's the weather in Tokyo and what time is it there?"
299+
),
300+
],
301+
tools: multiToolSchema
302+
)
303+
304+
let (result, toolCalls) = try await generateWithTools(
305+
container: container,
306+
input: input,
307+
maxTokens: 150
308+
)
309+
310+
print("Mistral3 Output: \(result)")
311+
print("Mistral3 Calls: \(toolCalls)")
312+
313+
// Verify all returned tool calls have valid names from our schema
314+
let validNames: Set<String> = ["get_weather", "get_time"]
315+
for toolCall in toolCalls {
316+
XCTAssertTrue(
317+
validNames.contains(toolCall.function.name),
318+
"Unexpected tool call: \(toolCall.function.name)"
319+
)
320+
}
321+
322+
// If the model made multiple calls, verify we got more than one
323+
if toolCalls.count > 1 {
324+
print("Successfully parsed \(toolCalls.count) tool calls from Mistral3")
325+
}
326+
}
327+
267328
// MARK: - Helper Methods
268329

269330
/// Generate text and collect any tool calls

0 commit comments

Comments
 (0)