Skip to content

Commit 1820a0e

Browse files
committed
feat(tool): add qwen xml tool parser and rename tool call parser to tool parser
1 parent b2f50fa commit 1820a0e

File tree

15 files changed

+685
-190
lines changed

15 files changed

+685
-190
lines changed

CHANGELOG.md

Lines changed: 0 additions & 91 deletions
This file was deleted.

CLAUDE.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pytest tests/unit/ -v
2929
pytest tests/unit/test_sglang.py -v
3030

3131
# Single test
32-
pytest tests/unit/test_tool_parser.py::TestHermesToolCallParser::test_parse_single_tool_call -v
32+
pytest tests/unit/test_tool_parser.py::TestHermesToolParser::test_parse_single_tool_call -v
3333

3434
# Unit tests with coverage
3535
pytest tests/unit/ -v --cov=src/strands_sglang --cov-report=html
@@ -49,7 +49,7 @@ The package lives in `src/strands_sglang/` with 5 core modules:
4949

5050
**TokenManager** (`token.py`) - Segment-based token accumulation for TITO. Tokens organized into PROMPT segments (loss_mask=0) and RESPONSE segments (loss_mask=1) matching multi-turn conversation structure. Exposes `token_ids`, `loss_mask`, `logprobs`, and `segments` properties.
5151

52-
**ToolCallParser** (`tool_parser.py`) - Abstract base with `HermesToolCallParser` implementation. Parses XML-wrapped JSON tool calls (`<tool_call>{"name": ..., "arguments": ...}</tool_call>`). Strict parsing: only catches JSONDecodeError, propagates failures as tool calls with `raw` content for model feedback. Excludes tool calls inside `<think>` blocks.
52+
**ToolParser** (`tool_parsers/`) - Abstract base with `HermesToolParser` and `QwenXMLToolParser` implementations. Parses tool calls from model output. Strict parsing: only catches JSONDecodeError, propagates failures as tool calls with `raw` content for model feedback. Excludes tool calls inside `<think>` blocks. New parsers self-register via `@register_tool_parser` decorator.
5353

5454
**ToolIterationLimiter** (`tool_limiter.py`) - Strands hook enforcing max tool iterations per invocation. One iteration = model response with tool calls + execution + result returned. Raises `MaxToolIterationsReachedError`.
5555

examples/math_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from strands_sglang import SGLangModel
2626
from strands_sglang.client import SGLangClient
27-
from strands_sglang.tool_parser import HermesToolCallParser
27+
from strands_sglang.tool_parsers import HermesToolParser
2828

2929

3030
async def main():
@@ -44,7 +44,7 @@ async def main():
4444
model = SGLangModel(
4545
tokenizer=tokenizer,
4646
client=client,
47-
tool_call_parser=HermesToolCallParser(),
47+
tool_parser=HermesToolParser(),
4848
model_id=model_id,
4949
params={"max_new_tokens": 16384}, # Limit response length
5050
)

examples/retokenization_drift/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from strands_sglang import SGLangModel
2929
from strands_sglang.client import SGLangClient
30-
from strands_sglang.tool_parser import HermesToolCallParser
30+
from strands_sglang.tool_parsers import HermesToolParser
3131

3232

3333
def find_drift_index(original: list[int], re_encoded: list[int]) -> int | None:
@@ -53,7 +53,7 @@ async def main():
5353
model = SGLangModel(
5454
tokenizer=tokenizer,
5555
client=client,
56-
tool_call_parser=HermesToolCallParser(),
56+
tool_parser=HermesToolParser(),
5757
model_id=model_id,
5858
params={"max_new_tokens": 32768},
5959
)

src/strands_sglang/sglang.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555
from .client import SGLangClient
5656
from .token import TokenManager
57-
from .tool_parsers import HermesToolCallParser, ToolCallParser, ToolCallParseResult
57+
from .tool_parsers import HermesToolParser, ToolParser, ToolParseResult
5858

5959
if TYPE_CHECKING:
6060
from transformers import PreTrainedTokenizerBase
@@ -72,7 +72,7 @@ class SGLangModel(Model):
7272
tokenizer: HuggingFace tokenizer for encoding/decoding.
7373
client: SGLangClient for HTTP communication with the SGLang server.
7474
token_manager: Tracks tokens, logprobs, and masks for on-policy training.
75-
tool_call_parser: Parser for extracting tool calls from model output.
75+
tool_parser: Parser for extracting tool calls from model output.
7676
7777
Example:
7878
>>> from transformers import AutoTokenizer
@@ -99,22 +99,22 @@ def __init__(
9999
*,
100100
tokenizer: PreTrainedTokenizerBase,
101101
client: SGLangClient,
102-
tool_call_parser: ToolCallParser | None = None,
102+
tool_parser: ToolParser | None = None,
103103
**model_config: Unpack[SGLangConfig],
104104
) -> None:
105105
"""Initialize SGLang model provider.
106106
107107
Args:
108108
tokenizer: HuggingFace tokenizer for chat template and tokenization.
109109
client: SGLangClient for HTTP communication with the SGLang server.
110-
tool_call_parser: Parser for tool calls (default: HermesToolCallParser).
110+
tool_parser: Parser for tool calls (default: HermesToolParser).
111111
**model_config: See SGLangConfig for available options.
112112
"""
113113

114114
# Essential attributes
115115
self.tokenizer = tokenizer
116116
self.client = client
117-
self.tool_call_parser = tool_call_parser or HermesToolCallParser()
117+
self.tool_parser = tool_parser or HermesToolParser()
118118

119119
# Config
120120
self.config = dict(model_config)
@@ -274,8 +274,8 @@ def tokenize_prompt_messages(
274274
# Prepend message separator to align with chat template.
275275
# The model generates up to <|im_end|>, but the chat template adds
276276
# a separator (e.g., "\n") before the next <|im_start|>.
277-
if self.tool_call_parser:
278-
formatted = self.tool_call_parser.message_separator + formatted
277+
if self.tool_parser:
278+
formatted = self.tool_parser.message_separator + formatted
279279

280280
return self.tokenizer.encode(formatted, add_special_tokens=False)
281281

@@ -300,7 +300,7 @@ def _sort_tool_results(self, messages: Messages) -> Messages:
300300

301301
def _yield_tool_use_events(
302302
self,
303-
tool_calls: list[ToolCallParseResult],
303+
tool_calls: list[ToolParseResult],
304304
) -> Iterator[StreamEvent]:
305305
"""Yield toolUse stream events for parsed tool calls.
306306
@@ -424,7 +424,7 @@ async def stream(
424424
yield {"contentBlockStop": {}}
425425

426426
# Parse tool calls and yield events
427-
parsed_tool_calls = self.tool_call_parser.parse(text)
427+
parsed_tool_calls = self.tool_parser.parse(text)
428428
for event in self._yield_tool_use_events(parsed_tool_calls):
429429
yield event
430430

src/strands_sglang/tool_parsers/__init__.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,49 +21,27 @@
2121
- Only handle `JSONDecodeError` (can't extract anything from malformed JSON)
2222
- Let Strands validate arguments against tool schemas
2323
- Parse errors become tool calls with error info for model feedback
24-
"""
25-
26-
from typing import Any
27-
28-
from .base import UNKNOWN_TOOL_NAME, ToolCallParser, ToolCallParseResult
29-
from .hermes import HermesToolCallParser
30-
31-
# Parser registry
32-
TOOL_PARSER_REGISTRY: dict[str, type[ToolCallParser]] = {
33-
"hermes": HermesToolCallParser,
34-
}
35-
3624
37-
def get_tool_parser(name: str, **kwargs: Any) -> ToolCallParser:
38-
"""Get a tool parser by name.
39-
40-
Args:
41-
name: Parser name (e.g., "hermes").
42-
**kwargs: Arguments passed to the parser constructor.
43-
44-
Returns:
45-
Instantiated parser.
46-
47-
Raises:
48-
KeyError: If parser name is not registered.
25+
Adding a new parser:
26+
1. Create a new file (e.g., `my_parser.py`)
27+
2. Decorate the class with `@register_tool_parser("my_parser")`
28+
3. Import the module in this package's `__init__.py` so the decorator runs and the parser is registered
29+
"""
4930

50-
Example:
51-
>>> parser = get_tool_parser("hermes")
52-
>>> parser = get_tool_parser("hermes", think_tokens=None)
53-
"""
54-
if name not in TOOL_PARSER_REGISTRY:
55-
available = ", ".join(sorted(TOOL_PARSER_REGISTRY.keys()))
56-
raise KeyError(f"Unknown tool parser: {name!r}. Available: {available}")
57-
return TOOL_PARSER_REGISTRY[name](**kwargs)
31+
from .base import TOOL_PARSER_REGISTRY, UNKNOWN_TOOL_NAME, ToolParser, ToolParseResult, get_tool_parser
5832

33+
# Import parsers to trigger registration via @register_tool_parser decorator
34+
from .hermes import HermesToolParser
35+
from .qwen_xml import QwenXMLToolParser
5936

6037
__all__ = [
6138
# Base
62-
"ToolCallParseResult",
63-
"ToolCallParser",
39+
"ToolParseResult",
40+
"ToolParser",
6441
"UNKNOWN_TOOL_NAME",
6542
# Parsers
66-
"HermesToolCallParser",
43+
"HermesToolParser",
44+
"QwenXMLToolParser",
6745
# Registry
6846
"TOOL_PARSER_REGISTRY",
6947
"get_tool_parser",

src/strands_sglang/tool_parsers/base.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
import json
2020
from abc import ABC, abstractmethod
2121
from dataclasses import dataclass, field
22-
from typing import Any
22+
from typing import Any, Callable, TypeVar
2323

2424
# Fallback tool name when we can't identify which tool the model tried to call
2525
UNKNOWN_TOOL_NAME = "unknown_tool"
2626

27+
# Parser registry - populated by @register_tool_parser decorator
28+
TOOL_PARSER_REGISTRY: dict[str, type[ToolParser]] = {}
29+
30+
T = TypeVar("T", bound="ToolParser")
31+
2732

2833
@dataclass(frozen=True, slots=True)
29-
class ToolCallParseResult:
34+
class ToolParseResult:
3035
"""A parsed tool call request.
3136
3237
For successful parses: name and input are populated, raw is None.
@@ -55,7 +60,7 @@ def payload(self) -> str:
5560
return json.dumps(self.input)
5661

5762

58-
class ToolCallParser(ABC):
63+
class ToolParser(ABC):
5964
"""Base class for tool call parsers.
6065
6166
Subclasses implement `parse` to extract tool calls from model output.
@@ -82,7 +87,7 @@ def message_separator(self) -> str:
8287
return ""
8388

8489
@abstractmethod
85-
def parse(self, text: str) -> list[ToolCallParseResult]:
90+
def parse(self, text: str) -> list[ToolParseResult]:
8691
"""Parse tool calls from model output text.
8792
8893
Args:
@@ -104,3 +109,48 @@ def __call__(self, text: str) -> list[dict[str, Any]]:
104109
"""
105110
results = self.parse(text)
106111
return [{"id": tc.id, "name": tc.name, "input": tc.input} for tc in results if not tc.is_error]
112+
113+
114+
def register_tool_parser(name: str) -> Callable[[type[T]], type[T]]:
115+
"""Decorator to register a tool parser class.
116+
117+
Args:
118+
name: Registry name for the parser. Registering a name that is already in use replaces the existing entry.
119+
120+
Returns:
121+
Decorator that registers the class and returns it unchanged.
122+
123+
Example:
124+
>>> @register_tool_parser("my_parser")
125+
... class MyParser(ToolParser):
126+
... def parse(self, text): ...
127+
"""
128+
129+
def decorator(cls: type[T]) -> type[T]:
130+
TOOL_PARSER_REGISTRY[name] = cls
131+
return cls
132+
133+
return decorator
134+
135+
136+
def get_tool_parser(name: str, **kwargs: Any) -> ToolParser:
137+
"""Get a tool parser by name.
138+
139+
Args:
140+
name: Parser name (e.g., "hermes", "qwen_xml").
141+
**kwargs: Arguments passed to the parser constructor.
142+
143+
Returns:
144+
Instantiated parser.
145+
146+
Raises:
147+
KeyError: If parser name is not registered; the message lists the available parser names.
148+
149+
Example:
150+
>>> parser = get_tool_parser("hermes")
151+
>>> parser = get_tool_parser("hermes", think_tokens=None)
152+
"""
153+
if name not in TOOL_PARSER_REGISTRY:
154+
available = ", ".join(sorted(TOOL_PARSER_REGISTRY.keys()))
155+
raise KeyError(f"Unknown tool parser: {name!r}. Available: {available}")
156+
return TOOL_PARSER_REGISTRY[name](**kwargs)

0 commit comments

Comments
 (0)