From a47f33498ebb4edeb5d4582a789b761f40ffc8ba Mon Sep 17 00:00:00 2001 From: Terence Pae Date: Sun, 1 Feb 2026 06:19:55 -0800 Subject: [PATCH 01/10] added prefix matching for flexible parsing --- .../MLXLMCommon/Tool/ToolCallFormat.swift | 20 +++++++++++++------ Tests/MLXLMTests/ToolTests.swift | 9 +++++++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift index 3b39bf608..b110e464e 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift @@ -97,15 +97,23 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// - Parameter modelType: The `model_type` value from config.json /// - Returns: The appropriate `ToolCallFormat`, or `nil` to use the default format public static func infer(from modelType: String) -> ToolCallFormat? { - switch modelType.lowercased() { - case "lfm2", "lfm2_moe": + let type = modelType.lowercased() + + // LFM2 family (lfm2, lfm2_moe, lfm2_5, lfm25, etc.) + if type.hasPrefix("lfm2") { return .lfm2 - case "glm4", "glm4_moe", "glm4_moe_lite": + } + + // GLM4 family (glm4, glm4_moe, glm4_moe_lite, etc.) + if type.hasPrefix("glm4") { return .glm4 - case "gemma": + } + + // Gemma + if type == "gemma" { return .gemma - default: - return nil } + + return nil } } diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index b2b312b8b..eb36ef4d9 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -317,15 +317,20 @@ struct ToolTests { @Test("Test ToolCallFormat Inference from Model Type") func testToolCallFormatInference() throws { - // LFM2 models + // LFM2 models (prefix matching) #expect(ToolCallFormat.infer(from: "lfm2") == .lfm2) #expect(ToolCallFormat.infer(from: "LFM2") == .lfm2) #expect(ToolCallFormat.infer(from: "lfm2_moe") == .lfm2) + #expect(ToolCallFormat.infer(from: "lfm2_5") == .lfm2) + #expect(ToolCallFormat.infer(from: "LFM2_5") == .lfm2) + #expect(ToolCallFormat.infer(from: "lfm25") == .lfm2) - // GLM4 models + // GLM4 models (prefix matching) #expect(ToolCallFormat.infer(from: "glm4") == .glm4) #expect(ToolCallFormat.infer(from: "glm4_moe") == .glm4) #expect(ToolCallFormat.infer(from: "glm4_moe_lite") == .glm4) + #expect(ToolCallFormat.infer(from: "glm4_5") == .glm4) + #expect(ToolCallFormat.infer(from: "GLM4_5") == .glm4) // Gemma models #expect(ToolCallFormat.infer(from: "gemma") == .gemma) From 6be5dfcbd4205253279f0ed1ba7c7142cd8a8568 Mon Sep 17 00:00:00 2001 From: Terence Pae Date: Sun, 1 Feb 2026 06:48:50 -0800 Subject: [PATCH 02/10] convert to pythonic tool converter --- .../Tool/Parsers/PythonicToolCallParser.swift | 100 ++++++++++++++++++ .../MLXLMCommon/Tool/ToolCallFormat.swift | 7 +- Tests/MLXLMTests/ToolTests.swift | 90 +++++++++++++++- 3 files changed, 190 insertions(+), 7 deletions(-) create mode 100644 Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift diff --git a/Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift new file mode 100644 index 000000000..73daec340 --- /dev/null +++ b/Libraries/MLXLMCommon/Tool/Parsers/PythonicToolCallParser.swift @@ -0,0 +1,100 @@ +// Copyright © 2025 Apple Inc. + +import Foundation + +/// Parser for Pythonic tool call format: [function_name(arg1='value1', arg2='value2')] +/// Used by LFM2.5 and similar models that output tool calls in Python function call syntax. +/// Reference: LiquidAI LFM2.5 chat template format +public struct PythonicToolCallParser: ToolCallParser, Sendable { + public let startTag: String? + public let endTag: String? + + public init(startTag: String? = nil, endTag: String? = nil) { + self.startTag = startTag + self.endTag = endTag + } + + public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? { + var text = content + + // Strip tags if present + if let start = startTag, let startRange = text.range(of: start) { + text = String(text[startRange.upperBound...]) + } + if let end = endTag, let endRange = text.range(of: end) { + text = String(text[.. [String: any Sendable] { + var arguments: [String: any Sendable] = [:] + + // Pattern for key=value pairs, handling quoted strings with possible commas inside + // This handles: key='value', key="value", key=123, key=True, key=None + let argPattern = #"(\w+)\s*=\s*('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*"|[^,\)]+)"# + + guard let regex = try? NSRegularExpression(pattern: argPattern, options: []) else { + return arguments + } + + let matches = regex.matches( + in: argsString, options: [], range: NSRange(argsString.startIndex..., in: argsString)) + + for match in matches { + guard let keyRange = Range(match.range(at: 1), in: argsString), + let valueRange = Range(match.range(at: 2), in: argsString) + else { continue } + + let key = String(argsString[keyRange]) + var value = String(argsString[valueRange]).trimmingCharacters(in: .whitespaces) + + // Remove surrounding quotes if present + if (value.hasPrefix("'") && value.hasSuffix("'")) + || (value.hasPrefix("\"") && value.hasSuffix("\"")) + { + value = String(value.dropFirst().dropLast()) + // Unescape escaped quotes + value = value.replacingOccurrences(of: "\\'", with: "'") + value = value.replacingOccurrences(of: "\\\"", with: "\"") + value = value.replacingOccurrences(of: "\\\\", with: "\\") + } + + // Convert value based on schema type if available + arguments[key] = convertParameterValue( + value, paramName: key, funcName: funcName, tools: tools) + } + + return arguments + } +} diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift index b110e464e..33a79b03d 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift @@ -42,8 +42,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `{"name": "func", "arguments": {...}}` case json - /// LFM2 JSON format with model-specific tags. - /// Example: `<|tool_call_start|>{"name": "func", "arguments": {...}}<|tool_call_end|>` + /// LFM2/LFM2.5 Pythonic format with model-specific tags. + /// Example: `<|tool_call_start|>[func(arg='value')]<|tool_call_end|>` case lfm2 /// XML function format used by Qwen3 Coder. @@ -75,7 +75,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { case .json: return JSONToolCallParser(startTag: "", endTag: "") case .lfm2: - return JSONToolCallParser(startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + return PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") case .xmlFunction: return XMLFunctionParser() case .glm4: diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index eb36ef4d9..96f124fd5 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -100,8 +100,8 @@ struct ToolTests { #expect(toolCall.function.arguments["location"] == .string("Paris")) } - @Test("Test JSON Tool Call Parser - LFM2 Tags") - func testJSONParserLFM2Tags() throws { + @Test("Test JSON Tool Call Parser - Custom Tags") + func testJSONParserCustomTags() throws { let parser = JSONToolCallParser( startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") let content = @@ -113,11 +113,93 @@ struct ToolTests { #expect(toolCall.function.arguments["query"] == .string("swift programming")) } - @Test("Test LFM2 Format via ToolCallProcessor") + // MARK: - Pythonic Format Tests (LFM2/LFM2.5) + + @Test("Test Pythonic Tool Call Parser - Basic") + func testPythonicParserBasic() throws { + let parser = PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[get_weather(location='Paris', unit='celsius')]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Paris")) + #expect(toolCall.function.arguments["unit"] == .string("celsius")) + } + + @Test("Test Pythonic Tool Call Parser - Double Quotes") + func testPythonicParserDoubleQuotes() throws { + let parser = PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[search(query=\"swift programming\")]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "search") + #expect(toolCall.function.arguments["query"] == .string("swift programming")) + } + + @Test("Test Pythonic Tool Call Parser - Without Brackets") + func testPythonicParserWithoutBrackets() throws { + let parser = PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>current_time(timezone='UTC')<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "current_time") + #expect(toolCall.function.arguments["timezone"] == .string("UTC")) + } + + @Test("Test Pythonic Tool Call Parser - No Arguments") + func testPythonicParserNoArguments() throws { + let parser = PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[current_time()]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "current_time") + #expect(toolCall.function.arguments.isEmpty) + } + + @Test("Test Pythonic Tool Call Parser - Type Conversion") + func testPythonicParserTypeConversion() throws { + let parser = PythonicToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let tools: [[String: any Sendable]] = [ + [ + "function": [ + "name": "set_temperature", + "parameters": [ + "properties": [ + "value": ["type": "integer"], + "enabled": ["type": "boolean"], + ] + ], + ] as [String: any Sendable] + ] + ] + let content = + "<|tool_call_start|>[set_temperature(value='25', enabled='true')]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: tools)) + + #expect(toolCall.function.name == "set_temperature") + #expect(toolCall.function.arguments["value"] == .int(25)) + #expect(toolCall.function.arguments["enabled"] == .bool(true)) + } + + @Test("Test LFM2 Format via ToolCallProcessor - Pythonic") func testLFM2FormatProcessor() throws { let processor = ToolCallProcessor(format: .lfm2) let content = - "<|tool_call_start|>{\"name\": \"calculator\", \"arguments\": {\"expression\": \"2+2\"}}<|tool_call_end|>" + "<|tool_call_start|>[calculator(expression='2+2')]<|tool_call_end|>" _ = processor.processChunk(content) From 49110a13949e1895e8ecbad59e3f225d88f4bd3e Mon Sep 17 00:00:00 2001 From: Terence Pae Date: Tue, 3 Mar 2026 12:15:12 -0800 Subject: [PATCH 03/10] added qwen3_5 tool calling support --- .../Tool/Parsers/XMLFunctionParser.swift | 9 ++- .../MLXLMCommon/Tool/ToolCallFormat.swift | 11 ++++ Tests/MLXLMTests/ToolTests.swift | 63 +++++++++++++++++++ 3 files changed, 80 insertions(+), 3 deletions(-) diff --git a/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift index 71c91f5e4..d4e21d124 100644 --- a/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift +++ b/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift @@ -5,10 +5,13 @@ import Foundation /// Parser for XML function format: value /// Reference: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/qwen3_coder.py public struct XMLFunctionParser: ToolCallParser, Sendable { - public let startTag: String? = nil // Inline format - no wrapper tags - public let endTag: String? = nil + public let startTag: String? + public let endTag: String? - public init() {} + public init(startTag: String? = nil, endTag: String? = nil) { + self.startTag = startTag + self.endTag = endTag + } public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? { // Pattern: — [\s\S] matches newlines diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift index 33a79b03d..f0c5029f1 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift @@ -66,6 +66,10 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `v` case minimaxM2 = "minimax_m2" + /// Qwen3.5 format: XML function syntax wrapped in tool_call tags. + /// Example: `value` + case qwen35 = "qwen3_5" + // MARK: - Factory Methods /// Create the appropriate parser for this format. @@ -87,6 +91,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return KimiK2ToolCallParser() case .minimaxM2: return MiniMaxM2ToolCallParser() + case .qwen35: + return XMLFunctionParser(startTag: "", endTag: "") } } @@ -115,6 +121,11 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return .gemma } + // Qwen3.5 family (qwen3_5, qwen3_5_moe, etc.) + if type.hasPrefix("qwen3_5") { + return .qwen35 + } + return nil } } diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index 634c2913e..28d5f7526 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -284,6 +284,63 @@ struct ToolTests { #expect(toolCall.function.arguments["location"] == .string("Tokyo")) } + // MARK: - Qwen3.5 Format Tests (XML Function with tool_call wrapper) + + @Test("Test Qwen3.5 XML Function Parser - With tool_call Tags") + func testQwen35Parser() throws { + let parser = XMLFunctionParser(startTag: "", endTag: "") + let content = """ + + + + San Francisco + + + celsius + + + + """ + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("San Francisco")) + #expect(toolCall.function.arguments["unit"] == .string("celsius")) + } + + @Test("Test Qwen3.5 Format via ToolCallProcessor") + func testQwen35FormatProcessor() throws { + let processor = ToolCallProcessor(format: .qwen35) + let chunks: [String] = [ + "", "\n\n", + "\nTokyo\n", + "\n\n", + ] + + for chunk in chunks { + _ = processor.processChunk(chunk) + } + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Tokyo")) + } + + @Test("Test Qwen3.5 Format - No Arguments") + func testQwen35FormatNoArgs() throws { + let processor = ToolCallProcessor(format: .qwen35) + let content = "\n\n\n" + + _ = processor.processChunk(content) + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "get_current_datetime") + #expect(toolCall.function.arguments.isEmpty) + } + // MARK: - GLM4 Format Tests @Test("Test GLM4 Tool Call Parser") @@ -422,6 +479,7 @@ struct ToolTests { #expect(ToolCallFormat.gemma.rawValue == "gemma") #expect(ToolCallFormat.kimiK2.rawValue == "kimi_k2") #expect(ToolCallFormat.minimaxM2.rawValue == "minimax_m2") + #expect(ToolCallFormat.qwen35.rawValue == "qwen3_5") // Test round-trip via raw value for format in ToolCallFormat.allCases { @@ -452,6 +510,11 @@ struct ToolTests { #expect(ToolCallFormat.infer(from: "gemma") == .gemma) #expect(ToolCallFormat.infer(from: "GEMMA") == .gemma) + // Qwen3.5 models (prefix matching) + #expect(ToolCallFormat.infer(from: "qwen3_5") == .qwen35) + #expect(ToolCallFormat.infer(from: "qwen3_5_moe") == .qwen35) + #expect(ToolCallFormat.infer(from: "QWEN3_5") == .qwen35) + // Unknown models should return nil (use default) #expect(ToolCallFormat.infer(from: "llama") == nil) #expect(ToolCallFormat.infer(from: "qwen2") == nil) From b672eac55cf72f6a7b7fbb6cf25cee9e785b6e8c Mon Sep 17 00:00:00 2001 From: Terence Pae Date: Tue, 3 Mar 2026 12:25:02 -0800 Subject: [PATCH 04/10] added detection at vlm level --- Libraries/MLXVLM/VLMModelFactory.swift | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 42e594e8b..b64084ba5 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -329,6 +329,11 @@ public final class VLMModelFactory: ModelFactory { var mutableConfiguration = configuration mutableConfiguration.eosTokenIds = eosTokenIds + // Auto-detect tool call format from model type if not explicitly set + if mutableConfiguration.toolCallFormat == nil { + mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType) + } + // Load tokenizer, processor config, and weights in parallel using async let. // Note: loadProcessorConfig does synchronous I/O but is marked async to enable // parallel scheduling. This may briefly block a cooperative thread pool thread, From 2bf261d5eb51b5827f84bd206a4b42db93d0baa7 Mon Sep 17 00:00:00 2001 From: Terence Pae Date: Tue, 10 Mar 2026 11:59:34 -0700 Subject: [PATCH 05/10] updated per feedback, added nemotron --- .../Tool/Parsers/XMLFunctionParser.swift | 2 +- .../MLXLMCommon/Tool/ToolCallFormat.swift | 19 +++++++-------- Tests/MLXLMTests/ToolTests.swift | 24 ++++++++++--------- .../mlx-swift-lm/references/tool-calling.md | 2 +- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift index d4e21d124..2598fc571 100644 --- a/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift +++ b/Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift @@ -8,7 +8,7 @@ public struct XMLFunctionParser: ToolCallParser, Sendable { public let startTag: String? public let endTag: String? - public init(startTag: String? = nil, endTag: String? = nil) { + public init(startTag: String, endTag: String) { self.startTag = startTag self.endTag = endTag } diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift index 361740179..f5bcba711 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift @@ -70,8 +70,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `<|tool_call_start|>[func(arg='value')]<|tool_call_end|>` case lfm2 - /// XML function format used by Qwen3 Coder. - /// Example: `value` + /// XML function format used by Nemotron, Qwen3 Coder, Qwen3.5, and similar models. + /// Example: `value` case xmlFunction = "xml_function" /// GLM4 format with arg_key/arg_value tags. @@ -90,10 +90,6 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `v` case minimaxM2 = "minimax_m2" - /// Qwen3.5 format: XML function syntax wrapped in tool_call tags. - /// Example: `value` - case qwen35 = "qwen3_5" - /// Mistral V11+ format with [TOOL_CALLS] and [ARGS] delimiters. /// Example: `[TOOL_CALLS]get_weather [ARGS]{"location": "Tokyo"}` case mistral @@ -110,7 +106,7 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return PythonicToolCallParser( startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") case .xmlFunction: - return XMLFunctionParser() + return XMLFunctionParser(startTag: "", endTag: "") case .glm4: return GLM4ToolCallParser() case .gemma: @@ -119,8 +115,6 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return KimiK2ToolCallParser() case .minimaxM2: return MiniMaxM2ToolCallParser() - case .qwen35: - return XMLFunctionParser(startTag: "", endTag: "") case .mistral: return MistralToolCallParser() } @@ -151,9 +145,14 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { return .gemma } + // Nemotron family (nemotron_h, etc.) + if type.hasPrefix("nemotron") { + return .xmlFunction + } + // Qwen3.5 family (qwen3_5, qwen3_5_moe, etc.) if type.hasPrefix("qwen3_5") { - return .qwen35 + return .xmlFunction } // Mistral3 family (mistral3, mistral3_text, etc.) diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index ef497907a..7b5c348ee 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -213,7 +213,7 @@ struct ToolTests { @Test("Test XML Function Parser - Qwen3 Coder Format") func testXMLFunctionParser() throws { - let parser = XMLFunctionParser() + let parser = XMLFunctionParser(startTag: "", endTag: "") let content = "Tokyocelsius" @@ -226,7 +226,7 @@ struct ToolTests { @Test("Test XML Function Parser - With Type Conversion") func testXMLFunctionParserTypeConversion() throws { - let parser = XMLFunctionParser() + let parser = XMLFunctionParser(startTag: "", endTag: "") let tools: [[String: any Sendable]] = [ [ "function": [ @@ -252,7 +252,7 @@ struct ToolTests { @Test("Test XML Function Parser - Multiline Content (Qwen3.5 style)") func testXMLFunctionParserMultiline() throws { - let parser = XMLFunctionParser() + let parser = XMLFunctionParser(startTag: "", endTag: "") // Qwen3.5 models generate newlines between the XML tags let content = """ @@ -269,7 +269,7 @@ struct ToolTests { @Test("Test XML Function Parser - Multiline Parameters") func testXMLFunctionParserMultilineParams() throws { - let parser = XMLFunctionParser() + let parser = XMLFunctionParser(startTag: "", endTag: "") let content = """ @@ -311,7 +311,7 @@ struct ToolTests { @Test("Test Qwen3.5 Format via ToolCallProcessor") func testQwen35FormatProcessor() throws { - let processor = ToolCallProcessor(format: .qwen35) + let processor = ToolCallProcessor(format: .xmlFunction) let chunks: [String] = [ "", "\n\n", "\nTokyo\n", @@ -330,7 +330,7 @@ struct ToolTests { @Test("Test Qwen3.5 Format - No Arguments") func testQwen35FormatNoArgs() throws { - let processor = ToolCallProcessor(format: .qwen35) + let processor = ToolCallProcessor(format: .xmlFunction) let content = "\n\n\n" _ = processor.processChunk(content) @@ -479,7 +479,6 @@ struct ToolTests { #expect(ToolCallFormat.gemma.rawValue == "gemma") #expect(ToolCallFormat.kimiK2.rawValue == "kimi_k2") #expect(ToolCallFormat.minimaxM2.rawValue == "minimax_m2") - #expect(ToolCallFormat.qwen35.rawValue == "qwen3_5") #expect(ToolCallFormat.mistral.rawValue == "mistral") // Test round-trip via raw value @@ -511,12 +510,15 @@ struct ToolTests { #expect(ToolCallFormat.infer(from: "gemma") == .gemma) #expect(ToolCallFormat.infer(from: "GEMMA") == .gemma) + // Nemotron models (prefix matching) + #expect(ToolCallFormat.infer(from: "nemotron_h") == .xmlFunction) + #expect(ToolCallFormat.infer(from: "NEMOTRON_H") == .xmlFunction) + // Qwen3.5 models (prefix matching) - #expect(ToolCallFormat.infer(from: "qwen3_5") == .qwen35) - #expect(ToolCallFormat.infer(from: "qwen3_5_moe") == .qwen35) - #expect(ToolCallFormat.infer(from: "QWEN3_5") == .qwen35) + #expect(ToolCallFormat.infer(from: "qwen3_5") == .xmlFunction) + #expect(ToolCallFormat.infer(from: "qwen3_5_moe") == .xmlFunction) + #expect(ToolCallFormat.infer(from: "QWEN3_5") == .xmlFunction) - // Unknown models should return nil (use default) // Mistral3 models (prefix matching) #expect(ToolCallFormat.infer(from: "mistral3") == .mistral) #expect(ToolCallFormat.infer(from: "Mistral3") == .mistral) diff --git a/skills/mlx-swift-lm/references/tool-calling.md b/skills/mlx-swift-lm/references/tool-calling.md index 9adb8f30b..556a67dad 100644 --- a/skills/mlx-swift-lm/references/tool-calling.md +++ b/skills/mlx-swift-lm/references/tool-calling.md @@ -30,7 +30,7 @@ mlx-swift-lm supports function calling / tool use with multiple model-specific f |--------|--------|----------------| | `.json` | Llama, Qwen, most models | `{"name":"f","arguments":{...}}` | | `.lfm2` | LFM2 | `<\|tool_call_start\|>{"name":"f",...}<\|tool_call_end\|>` | -| `.xmlFunction` | Qwen3 Coder | `v` | +| `.xmlFunction` | Nemotron, Qwen3 Coder, Qwen3.5 | `v` | | `.glm4` | GLM4 | `funckv` | | `.gemma` | Gemma | `call:name{key:value}` | | `.kimiK2` | Kimi K2 | `functions.name:0<\|tool_call_argument_begin\|>{...}` | From 34b828b1231e7e9523084fa7a5b031f93c2155d3 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 11 Mar 2026 12:19:12 +0100 Subject: [PATCH 06/10] Add Nemotron tool integration test --- .../ToolCallIntegrationTests.swift | 158 +++++++++++++++--- 1 file changed, 136 insertions(+), 22 deletions(-) diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index 3f40d9882..576d0489c 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -23,12 +23,14 @@ public class ToolCallIntegrationTests: XCTestCase { static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" + static let nemotronModelId = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit" // MARK: - Shared State nonisolated(unsafe) static var lfm2Container: ModelContainer? nonisolated(unsafe) static var glm4Container: ModelContainer? nonisolated(unsafe) static var mistral3Container: ModelContainer? + nonisolated(unsafe) static var nemotronContainer: ModelContainer? // MARK: - Tool Schema @@ -65,42 +67,31 @@ public class ToolCallIntegrationTests: XCTestCase { let lfm2Expectation = XCTestExpectation(description: "Load LFM2") let glm4Expectation = XCTestExpectation(description: "Load GLM4") let mistral3Expectation = XCTestExpectation(description: "Load Mistral3") + let nemotronExpectation = XCTestExpectation(description: "Load Nemotron") Task { - do { - lfm2Container = try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: lfm2ModelId) - ) - } catch { - print("Failed to load LFM2: \(error)") - } + lfm2Container = await loadModelContainer(modelId: lfm2ModelId) lfm2Expectation.fulfill() } Task { - do { - glm4Container = try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: glm4ModelId) - ) - } catch { - print("Failed to load GLM4: \(error)") - } + glm4Container = await loadModelContainer(modelId: glm4ModelId) glm4Expectation.fulfill() } Task { - do { - mistral3Container = try await VLMModelFactory.shared.loadContainer( - configuration: .init(id: mistral3ModelId) - ) - } catch { - print("Failed to load Mistral3: \(error)") - } + mistral3Container = await loadModelContainer(modelId: mistral3ModelId) mistral3Expectation.fulfill() } + Task { + nemotronContainer = await loadModelContainer(modelId: nemotronModelId) + nemotronExpectation.fulfill() + } + _ = XCTWaiter.wait( - for: [lfm2Expectation, glm4Expectation, mistral3Expectation], timeout: 600) + for: [lfm2Expectation, glm4Expectation, mistral3Expectation, nemotronExpectation], + timeout: 600) } // MARK: - LFM2 Tests @@ -325,8 +316,131 @@ public class ToolCallIntegrationTests: XCTestCase { } } + // MARK: - Nemotron Tests + + func testNemotronToolCallFormatAutoDetection() async throws { + guard let container = Self.nemotronContainer else { + throw XCTSkip("Nemotron model not available") + } + + let config = await container.configuration + XCTAssertEqual( + config.toolCallFormat, .xmlFunction, + "Nemotron model should auto-detect .xmlFunction tool call format" + ) + } + + func testNemotronEndToEndToolCallGeneration() async throws { + guard let container = Self.nemotronContainer else { + throw XCTSkip("Nemotron model not available") + } + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user("What's the weather in Tokyo?"), + ], + tools: Self.weatherToolSchema, + additionalContext: ["enable_thinking": false] + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 150 + ) + + print("Nemotron Output: \(result)") + print("Nemotron Tool Calls: \(toolCalls)") + + if !toolCalls.isEmpty { + let toolCall = toolCalls.first! + XCTAssertEqual(toolCall.function.name, "get_weather") + if let location = toolCall.function.arguments["location"]?.asString { + XCTAssertTrue( + location.lowercased().contains("tokyo"), + "Expected location to contain 'Tokyo', got: \(location)" + ) + } + } + } + + func testNemotronMultipleToolCallGeneration() async throws { + guard let container = Self.nemotronContainer else { + throw XCTSkip("Nemotron model not available") + } + + let multiToolSchema: [[String: any Sendable]] = + Self.weatherToolSchema + [ + [ + "type": "function", + "function": [ + "name": "get_time", + "description": "Get the current time in a given timezone", + "parameters": [ + "type": "object", + "properties": [ + "timezone": [ + "type": "string", + "description": + "The timezone, e.g. America/New_York, Asia/Tokyo", + ] as [String: any Sendable] + ] as [String: any Sendable], + "required": ["timezone"], + ] as [String: any Sendable], + ] as [String: any Sendable], + ] + ] + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." + ), + .user( + "What's the weather in Tokyo and what time is it there?" + ), + ], + tools: multiToolSchema + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 600 + ) + + print("Nemotron Output: \(result)") + print("Nemotron Calls: \(toolCalls)") + + let validNames: Set = ["get_weather", "get_time"] + for toolCall in toolCalls { + XCTAssertTrue( + validNames.contains(toolCall.function.name), + "Unexpected tool call: \(toolCall.function.name)" + ) + } + + if toolCalls.count > 1 { + print("Successfully parsed \(toolCalls.count) tool calls from Nemotron") + } + } + // MARK: - Helper Methods + private static func loadModelContainer(modelId: String) async -> ModelContainer? { + do { + return try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: modelId) + ) + } catch { + print("Failed to load model \(modelId): \(error)") + return nil + } + } + /// Generate text and collect any tool calls private func generateWithTools( container: ModelContainer, From ec015a2f1a194fe06c274ec28ad13848acaf0d7f Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 11 Mar 2026 12:35:24 +0100 Subject: [PATCH 07/10] Use IntegrationTestModels inside of ToolCallIntegrationTests --- .../IntegrationTestModels.swift | 68 +++++++++ .../ToolCallIntegrationTests.swift | 131 ++++++------------ 2 files changed, 110 insertions(+), 89 deletions(-) diff --git a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift index fbd84d164..d91e3117a 100644 --- a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift +++ b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift @@ -8,14 +8,26 @@ import MLXVLM enum IntegrationTestModelIDs { static let llmModelId = "mlx-community/Qwen3-4B-Instruct-2507-4bit" static let vlmModelId = "mlx-community/Qwen3-VL-4B-Instruct-4bit" + + static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" + static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" + static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" + static let nemotronModelId = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit" } actor IntegrationTestModels { static let shared = IntegrationTestModels() + private init() {} + private var llmTask: Task? private var vlmTask: Task? + private var lfm2Task: Task? + private var glm4Task: Task? + private var mistral3Task: Task? + private var nemotronTask: Task? + func llmContainer() async throws -> ModelContainer { if let task = llmTask { return try await task.value @@ -43,4 +55,60 @@ actor IntegrationTestModels { vlmTask = task return try await task.value } + + func lfm2Container() async throws -> ModelContainer { + if let task = lfm2Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.lfm2ModelId) + ) + } + lfm2Task = task + return try await task.value + } + + func glm4Container() async throws -> ModelContainer { + if let task = glm4Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.glm4ModelId) + ) + } + glm4Task = task + return try await task.value + } + + func mistral3Container() async throws -> ModelContainer { + if let task = mistral3Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.mistral3ModelId) + ) + } + mistral3Task = task + return try await task.value + } + + func nemotronContainer() async throws -> ModelContainer { + if let task = nemotronTask { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.nemotronModelId) + ) + } + nemotronTask = task + return try await task.value + } } diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index 576d0489c..81f4223c0 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -17,21 +17,6 @@ import XCTest /// - LFM2: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/default.py /// - GLM4: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/glm47.py public class ToolCallIntegrationTests: XCTestCase { - - // MARK: - Model IDs - - static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" - static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" - static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" - static let nemotronModelId = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit" - - // MARK: - Shared State - - nonisolated(unsafe) static var lfm2Container: ModelContainer? - nonisolated(unsafe) static var glm4Container: ModelContainer? - nonisolated(unsafe) static var mistral3Container: ModelContainer? - nonisolated(unsafe) static var nemotronContainer: ModelContainer? - // MARK: - Tool Schema static let weatherToolSchema: [[String: any Sendable]] = [ @@ -59,49 +44,52 @@ public class ToolCallIntegrationTests: XCTestCase { ] ] - // MARK: - Setup + // MARK: - Model Loading - override public class func setUp() { - super.setUp() - - let lfm2Expectation = XCTestExpectation(description: "Load LFM2") - let glm4Expectation = XCTestExpectation(description: "Load GLM4") - let mistral3Expectation = XCTestExpectation(description: "Load Mistral3") - let nemotronExpectation = XCTestExpectation(description: "Load Nemotron") - - Task { - lfm2Container = await loadModelContainer(modelId: lfm2ModelId) - lfm2Expectation.fulfill() + private var lfm2Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.lfm2Container() + } catch { + throw XCTSkip("LFM2 model not available: \(error)") + } } + } - Task { - glm4Container = await loadModelContainer(modelId: glm4ModelId) - glm4Expectation.fulfill() + private var glm4Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.glm4Container() + } catch { + throw XCTSkip("GLM4 model not available: \(error)") + } } + } - Task { - mistral3Container = await loadModelContainer(modelId: mistral3ModelId) - mistral3Expectation.fulfill() + private var mistral3Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.mistral3Container() + } catch { + throw XCTSkip("Mistral3 model not available: \(error)") + } } + } - Task { - nemotronContainer = await loadModelContainer(modelId: nemotronModelId) - nemotronExpectation.fulfill() + private var nemotronContainer: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.nemotronContainer() + } catch { + throw XCTSkip("Nemotron model not available: \(error)") + } } - - _ = XCTWaiter.wait( - for: [lfm2Expectation, glm4Expectation, mistral3Expectation, nemotronExpectation], - timeout: 600) } // MARK: - LFM2 Tests func testLFM2ToolCallFormatAutoDetection() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } - - let config = await container.configuration + let config = try await lfm2Container.configuration XCTAssertEqual( config.toolCallFormat, .lfm2, "LFM2 model should auto-detect .lfm2 tool call format" @@ -109,9 +97,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testLFM2EndToEndToolCallGeneration() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } + let container = try await lfm2Container // Create input with tool schema let input = UserInput( @@ -151,11 +137,7 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - GLM4 Tests func testGLM4ToolCallFormatAutoDetection() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } - - let config = await container.configuration + let config = try await glm4Container.configuration XCTAssertEqual( config.toolCallFormat, .glm4, "GLM4 model should auto-detect .glm4 tool call format" @@ -163,9 +145,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testGLM4EndToEndToolCallGeneration() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } + let container = try await glm4Container // Create input with tool schema let input = UserInput( @@ -205,11 +185,7 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - Mistral3 Tests func testMistral3ToolCallFormatAutoDetection() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } - - let config = await container.configuration + let config = try await mistral3Container.configuration XCTAssertEqual( config.toolCallFormat, .mistral, "Mistral3 model should auto-detect .mistral tool call format" @@ -217,9 +193,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testMistral3EndToEndToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } + let container = try await mistral3Container let input = UserInput( chat: [ @@ -254,9 +228,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testMistral3MultipleToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } + let container = try await mistral3Container let multiToolSchema: [[String: any Sendable]] = Self.weatherToolSchema + [ @@ -319,11 +291,7 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - Nemotron Tests func testNemotronToolCallFormatAutoDetection() async throws { - guard let container = Self.nemotronContainer else { - throw XCTSkip("Nemotron model not available") - } - - let config = await container.configuration + let config = try await nemotronContainer.configuration XCTAssertEqual( config.toolCallFormat, .xmlFunction, "Nemotron model should auto-detect .xmlFunction tool call format" @@ -331,9 +299,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testNemotronEndToEndToolCallGeneration() async throws { - guard let container = Self.nemotronContainer else { - throw XCTSkip("Nemotron model not available") - } + let container = try await nemotronContainer let input = UserInput( chat: [ @@ -368,9 +334,7 @@ public class ToolCallIntegrationTests: XCTestCase { } func testNemotronMultipleToolCallGeneration() async throws { - guard let container = Self.nemotronContainer else { - throw XCTSkip("Nemotron model not available") - } + let container = try await nemotronContainer let multiToolSchema: [[String: any Sendable]] = Self.weatherToolSchema + [ @@ -430,17 +394,6 @@ public class ToolCallIntegrationTests: XCTestCase { // MARK: - Helper Methods - private static func loadModelContainer(modelId: String) async -> ModelContainer? { - do { - return try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: modelId) - ) - } catch { - print("Failed to load model \(modelId): \(error)") - return nil - } - } - /// Generate text and collect any tool calls private func generateWithTools( container: ModelContainer, From 6efe31c78427f8ec7761d73cbf214332ab66bc09 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 11 Mar 2026 12:40:15 +0100 Subject: [PATCH 08/10] Skip Nemotron tests in ToolCallIntegrationTests by default --- Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index 81f4223c0..cc6b8ccd4 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -17,6 +17,7 @@ import XCTest /// - LFM2: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/default.py /// - GLM4: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/glm47.py public class ToolCallIntegrationTests: XCTestCase { + // MARK: - Tool Schema static let weatherToolSchema: [[String: any Sendable]] = [ @@ -78,6 +79,8 @@ public class ToolCallIntegrationTests: XCTestCase { private var nemotronContainer: ModelContainer { get async throws { + try XCTSkipIf(true, "Nemotron model is opt-in only because of its size") + do { return try await IntegrationTestModels.shared.nemotronContainer() } catch { From 233c467aa4a2d88d8c013d7afb05cb2c6f67fe12 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 11 Mar 2026 13:11:24 +0100 Subject: [PATCH 09/10] Add Qwen3.5 tool call integration tests --- .../IntegrationTestModels.swift | 16 +++ .../ToolCallIntegrationTests.swift | 114 ++++++++++++++++++ 2 files changed, 130 insertions(+) diff --git a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift index d91e3117a..7aa32e794 100644 --- a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift +++ b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift @@ -13,6 +13,7 @@ enum IntegrationTestModelIDs { static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" static let nemotronModelId = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit" + static let qwen35ModelId = "mlx-community/Qwen3.5-2B-4bit" } actor IntegrationTestModels { @@ -27,6 +28,7 @@ actor IntegrationTestModels { private var glm4Task: Task? private var mistral3Task: Task? private var nemotronTask: Task? + private var qwen35Task: Task? func llmContainer() async throws -> ModelContainer { if let task = llmTask { @@ -111,4 +113,18 @@ actor IntegrationTestModels { nemotronTask = task return try await task.value } + + func qwen35Container() async throws -> ModelContainer { + if let task = qwen35Task { + return try await task.value + } + + let task = Task { + try await LLMModelFactory.shared.loadContainer( + configuration: .init(id: IntegrationTestModelIDs.qwen35ModelId) + ) + } + qwen35Task = task + return try await task.value + } } diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index cc6b8ccd4..b0d47bec6 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -89,6 +89,16 @@ public class ToolCallIntegrationTests: XCTestCase { } } + private var qwen35Container: ModelContainer { + get async throws { + do { + return try await IntegrationTestModels.shared.qwen35Container() + } catch { + throw XCTSkip("Qwen3.5 model not available: \(error)") + } + } + } + // MARK: - LFM2 Tests func testLFM2ToolCallFormatAutoDetection() async throws { @@ -395,6 +405,110 @@ public class ToolCallIntegrationTests: XCTestCase { } } + // MARK: - Qwen3.5 Tests + + func testQwen35ToolCallFormatAutoDetection() async throws { + let config = try await qwen35Container.configuration + XCTAssertEqual( + config.toolCallFormat, .xmlFunction, + "Qwen3.5 model should auto-detect .xmlFunction tool call format" + ) + } + + func testQwen35EndToEndToolCallGeneration() async throws { + let container = try await qwen35Container + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user("What's the weather in Tokyo?"), + ], + tools: Self.weatherToolSchema + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 150 + ) + + print("Qwen3.5 Output: \(result)") + print("Qwen3.5 Tool Calls: \(toolCalls)") + + if !toolCalls.isEmpty { + let toolCall = toolCalls.first! + XCTAssertEqual(toolCall.function.name, "get_weather") + if let location = toolCall.function.arguments["location"]?.asString { + XCTAssertTrue( + location.lowercased().contains("tokyo"), + "Expected location to contain 'Tokyo', got: \(location)" + ) + } + } + } + + func testQwen35MultipleToolCallGeneration() async throws { + let container = try await qwen35Container + + let multiToolSchema: [[String: any Sendable]] = + Self.weatherToolSchema + [ + [ + "type": "function", + "function": [ + "name": "get_time", + "description": "Get the current time in a given timezone", + "parameters": [ + "type": "object", + "properties": [ + "timezone": [ + "type": "string", + "description": + "The timezone, e.g. America/New_York, Asia/Tokyo", + ] as [String: any Sendable] + ] as [String: any Sendable], + "required": ["timezone"], + ] as [String: any Sendable], + ] as [String: any Sendable], + ] + ] + + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." + ), + .user( + "What's the weather in Tokyo and what time is it there?" + ), + ], + tools: multiToolSchema, + additionalContext: ["enable_thinking": true] + ) + + let (result, toolCalls) = try await generateWithTools( + container: container, + input: input, + maxTokens: 300 + ) + + print("Qwen3.5 Output: \(result)") + print("Qwen3.5 Calls: \(toolCalls)") + + let validNames: Set = ["get_weather", "get_time"] + for toolCall in toolCalls { + XCTAssertTrue( + validNames.contains(toolCall.function.name), + "Unexpected tool call: \(toolCall.function.name)" + ) + } + + if toolCalls.count > 1 { + print("Successfully parsed \(toolCalls.count) tool calls from Qwen3.5") + } + } + // MARK: - Helper Methods /// Generate text and collect any tool calls From f888b37f57b599b01837f6afc51f57b6c1042291 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Wed, 11 Mar 2026 13:34:28 +0100 Subject: [PATCH 10/10] Disable Nemotron thinking because it uses way too many tokens to think --- Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift index b0d47bec6..4c143f57d 100644 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift @@ -380,7 +380,8 @@ public class ToolCallIntegrationTests: XCTestCase { "What's the weather in Tokyo and what time is it there?" ), ], - tools: multiToolSchema + tools: multiToolSchema, + additionalContext: ["enable_thinking": false] ) let (result, toolCalls) = try await generateWithTools(