Skip to content
Merged
9 changes: 6 additions & 3 deletions Libraries/MLXLMCommon/Tool/Parsers/XMLFunctionParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import Foundation
/// Parser for XML function format: <function=name><parameter=key>value</parameter></function>,
/// which may be enclosed in configured wrapper tags (e.g. <tool_call>…</tool_call>).
/// Reference: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/qwen3_coder.py
public struct XMLFunctionParser: ToolCallParser, Sendable {
public let startTag: String? = nil // Inline format - no wrapper tags
public let endTag: String? = nil
public let startTag: String?
public let endTag: String?
Comment on lines 5 to +9

Copilot AI Mar 11, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc comment at the top still describes the XML function format as unwrapped (<function=...><parameter=...>...</parameter></function>), but the parser is now configured/used with <tool_call>...</tool_call> wrapper tags. Please update the comment to match the actual supported/expected format (and ideally mention whether unwrapped output is still supported).

Copilot uses AI. Check for mistakes.

/// Creates a parser without supplying any wrapper tags.
public init() {}
/// Creates a parser configured with explicit wrapper tags.
///
/// - Parameters:
///   - startTag: Value stored in `startTag`, e.g. `"<tool_call>"`.
///   - endTag: Value stored in `endTag`, e.g. `"</tool_call>"`.
public init(startTag: String, endTag: String) {
    // The two stored properties are independent, so assignment order is irrelevant.
    self.endTag = endTag
    self.startTag = startTag
}

public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
// Pattern: <function=(content)</function> — [\s\S] matches newlines
Expand Down
16 changes: 13 additions & 3 deletions Libraries/MLXLMCommon/Tool/ToolCallFormat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
/// Example: `<|tool_call_start|>[func(arg='value')]<|tool_call_end|>`
case lfm2

/// XML function format used by Qwen3 Coder.
/// Example: `<function=name><parameter=key>value</parameter></function>`
/// XML function format used by Nemotron, Qwen3 Coder, Qwen3.5, and similar models.
/// Example: `<tool_call><function=name><parameter=key>value</parameter></function></tool_call>`
case xmlFunction = "xml_function"

/// GLM4 format with arg_key/arg_value tags.
Expand Down Expand Up @@ -106,7 +106,7 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
return PythonicToolCallParser(
startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
case .xmlFunction:
return XMLFunctionParser()
return XMLFunctionParser(startTag: "<tool_call>", endTag: "</tool_call>")

Copilot AI Mar 11, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ToolCallFormat.xmlFunction now creates XMLFunctionParser with <tool_call> wrapper tags, which makes ToolCallProcessor(format: .xmlFunction) require those wrapper tags to detect tool calls. However, the XMLFunction unit tests still use the unwrapped Qwen3 Coder-style content (<function=...></function>), so streaming extraction would fail for that style. Consider supporting both wrapped and unwrapped XML function outputs (e.g., a fallback path in ToolCallProcessor/parser) or update the claimed format/tests so they are consistent and don’t regress Qwen3 Coder.

Suggested change
return XMLFunctionParser(startTag: "<tool_call>", endTag: "</tool_call>")
return XMLFunctionParser(startTag: nil, endTag: nil)

Copilot uses AI. Check for mistakes.
case .glm4:
return GLM4ToolCallParser()
case .gemma:
Expand Down Expand Up @@ -145,6 +145,16 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
return .gemma
}

// Nemotron family (nemotron_h, etc.)
if type.hasPrefix("nemotron") {
return .xmlFunction
}

// Qwen3.5 family (qwen3_5, qwen3_5_moe, etc.)
if type.hasPrefix("qwen3_5") {
return .xmlFunction
}

// Mistral3 family (mistral3, mistral3_text, etc.)
if type.hasPrefix("mistral3") {
return .mistral
Expand Down
84 changes: 84 additions & 0 deletions Tests/MLXLMIntegrationTests/IntegrationTestModels.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,28 @@ import MLXVLM
/// Namespace for the model identifiers used by the integration tests.
///
/// Each constant is passed as the `id` of a model configuration when the
/// corresponding container is loaded.
enum IntegrationTestModelIDs {
    // General-purpose language and vision-language models.
    static let llmModelId: String = "mlx-community/Qwen3-4B-Instruct-2507-4bit"
    static let vlmModelId: String = "mlx-community/Qwen3-VL-4B-Instruct-4bit"

    // Additional model families, one identifier per family.
    static let lfm2ModelId: String = "mlx-community/LFM2-2.6B-Exp-4bit"
    static let glm4ModelId: String = "mlx-community/GLM-4-9B-0414-4bit"
    static let mistral3ModelId: String = "mlx-community/Ministral-3-3B-Instruct-2512-4bit"
    static let nemotronModelId: String = "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-4bit"
    static let qwen35ModelId: String = "mlx-community/Qwen3.5-2B-4bit"
}

actor IntegrationTestModels {
static let shared = IntegrationTestModels()

private init() {}

private var llmTask: Task<ModelContainer, Error>?
private var vlmTask: Task<ModelContainer, Error>?

private var lfm2Task: Task<ModelContainer, Error>?
private var glm4Task: Task<ModelContainer, Error>?
private var mistral3Task: Task<ModelContainer, Error>?
private var nemotronTask: Task<ModelContainer, Error>?
private var qwen35Task: Task<ModelContainer, Error>?

func llmContainer() async throws -> ModelContainer {
if let task = llmTask {
return try await task.value
Expand Down Expand Up @@ -43,4 +57,74 @@ actor IntegrationTestModels {
vlmTask = task
return try await task.value
}

/// Returns the model container for `IntegrationTestModelIDs.lfm2ModelId`.
///
/// The first call starts a load task and remembers it in `lfm2Task`; every
/// later call awaits that same task instead of starting another load.
///
/// - Returns: The container produced by the load task.
/// - Throws: Any error thrown by the load task. The task is stored whether or
///   not it succeeds, so later calls await the same task again.
func lfm2Container() async throws -> ModelContainer {
    let pending: Task<ModelContainer, Error>
    if let existing = lfm2Task {
        // A load was already started; share it.
        pending = existing
    } else {
        pending = Task {
            try await LLMModelFactory.shared.loadContainer(
                configuration: .init(id: IntegrationTestModelIDs.lfm2ModelId)
            )
        }
        // Record the task before suspending so concurrent callers reuse it.
        lfm2Task = pending
    }
    return try await pending.value
}

/// Returns the model container for `IntegrationTestModelIDs.glm4ModelId`.
///
/// The first call starts a load task and remembers it in `glm4Task`; every
/// later call awaits that same task instead of starting another load.
///
/// - Returns: The container produced by the load task.
/// - Throws: Any error thrown by the load task. The task is stored whether or
///   not it succeeds, so later calls await the same task again.
func glm4Container() async throws -> ModelContainer {
    let pending: Task<ModelContainer, Error>
    if let existing = glm4Task {
        // A load was already started; share it.
        pending = existing
    } else {
        pending = Task {
            try await LLMModelFactory.shared.loadContainer(
                configuration: .init(id: IntegrationTestModelIDs.glm4ModelId)
            )
        }
        // Record the task before suspending so concurrent callers reuse it.
        glm4Task = pending
    }
    return try await pending.value
}

/// Returns the model container for `IntegrationTestModelIDs.mistral3ModelId`.
///
/// The first call starts a load task and remembers it in `mistral3Task`;
/// every later call awaits that same task instead of starting another load.
///
/// - Returns: The container produced by the load task.
/// - Throws: Any error thrown by the load task. The task is stored whether or
///   not it succeeds, so later calls await the same task again.
func mistral3Container() async throws -> ModelContainer {
    let pending: Task<ModelContainer, Error>
    if let existing = mistral3Task {
        // A load was already started; share it.
        pending = existing
    } else {
        pending = Task {
            try await LLMModelFactory.shared.loadContainer(
                configuration: .init(id: IntegrationTestModelIDs.mistral3ModelId)
            )
        }
        // Record the task before suspending so concurrent callers reuse it.
        mistral3Task = pending
    }
    return try await pending.value
}

/// Returns the model container for `IntegrationTestModelIDs.nemotronModelId`.
///
/// The first call starts a load task and remembers it in `nemotronTask`;
/// every later call awaits that same task instead of starting another load.
///
/// - Returns: The container produced by the load task.
/// - Throws: Any error thrown by the load task. The task is stored whether or
///   not it succeeds, so later calls await the same task again.
func nemotronContainer() async throws -> ModelContainer {
    let pending: Task<ModelContainer, Error>
    if let existing = nemotronTask {
        // A load was already started; share it.
        pending = existing
    } else {
        pending = Task {
            try await LLMModelFactory.shared.loadContainer(
                configuration: .init(id: IntegrationTestModelIDs.nemotronModelId)
            )
        }
        // Record the task before suspending so concurrent callers reuse it.
        nemotronTask = pending
    }
    return try await pending.value
}

/// Returns the model container for `IntegrationTestModelIDs.qwen35ModelId`.
///
/// The first call starts a load task and remembers it in `qwen35Task`; every
/// later call awaits that same task instead of starting another load.
///
/// - Returns: The container produced by the load task.
/// - Throws: Any error thrown by the load task. The task is stored whether or
///   not it succeeds, so later calls await the same task again.
func qwen35Container() async throws -> ModelContainer {
    let pending: Task<ModelContainer, Error>
    if let existing = qwen35Task {
        // A load was already started; share it.
        pending = existing
    } else {
        pending = Task {
            try await LLMModelFactory.shared.loadContainer(
                configuration: .init(id: IntegrationTestModelIDs.qwen35ModelId)
            )
        }
        // Record the task before suspending so concurrent callers reuse it.
        qwen35Task = pending
    }
    return try await pending.value
}
}
Loading