Skip to content

Commit b7fdc18

Browse files
Commit message: refactor(test): fix tool parser tests and add logprob regression test
Parent commit: 192951b · This commit: b7fdc18

File tree

2 files changed: 51 additions (+51) and 63 deletions (-63).

tests/integration/test_sglang_integration.py

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -77,16 +77,12 @@ async def test_tool_call_generation(self, model, calculator_tool):
7777
system_prompt = "You are a calculator. Use the calculator tool for all math."
7878

7979
events = []
80-
async for event in model.stream(
81-
messages, tool_specs=[calculator_tool], system_prompt=system_prompt
82-
):
80+
async for event in model.stream(messages, tool_specs=[calculator_tool], system_prompt=system_prompt):
8381
events.append(event)
8482

8583
# Check for tool use events
8684
tool_starts = [e for e in events if "contentBlockStart" in e]
87-
tool_use_starts = [
88-
e for e in tool_starts if "toolUse" in e["contentBlockStart"].get("start", {})
89-
]
85+
tool_use_starts = [e for e in tool_starts if "toolUse" in e["contentBlockStart"].get("start", {})]
9086

9187
# Model should have called calculator tool
9288
if tool_use_starts:
@@ -101,9 +97,7 @@ async def test_multi_turn_with_tool_result(self, model, calculator_tool):
10197

10298
# First generation
10399
events = []
104-
async for event in model.stream(
105-
messages, tool_specs=[calculator_tool], system_prompt=system_prompt
106-
):
100+
async for event in model.stream(messages, tool_specs=[calculator_tool], system_prompt=system_prompt):
107101
events.append(event)
108102

109103
# Add assistant response and tool result
@@ -130,9 +124,7 @@ async def test_multi_turn_with_tool_result(self, model, calculator_tool):
130124

131125
# Second generation: model should respond after receiving tool result
132126
events = []
133-
async for event in model.stream(
134-
messages, tool_specs=[calculator_tool], system_prompt=system_prompt
135-
):
127+
async for event in model.stream(messages, tool_specs=[calculator_tool], system_prompt=system_prompt):
136128
events.append(event)
137129

138130
# Should have generated a response (content deltas or tool calls)
@@ -164,6 +156,22 @@ async def test_token_count_consistency(self, model):
164156
assert total_tokens == len(model.token_manager.loss_mask)
165157
assert total_tokens == len(model.token_manager.logprobs)
166158

159+
async def test_logprobs_no_none_when_return_logprob_enabled(self, model):
160+
"""Logprobs should never contain None when return_logprob=True (regression test for v0.2.0)."""
161+
# Ensure return_logprob is enabled (default is True)
162+
assert model.config.get("return_logprob", True) is True
163+
164+
messages = [{"role": "user", "content": [{"text": "Say hello"}]}]
165+
async for _ in model.stream(messages):
166+
pass
167+
168+
logprobs = model.token_manager.logprobs
169+
assert len(logprobs) > 0, "Should have logprobs after generation"
170+
assert all(lp is not None for lp in logprobs), (
171+
f"Logprobs should never contain None when return_logprob=True. "
172+
f"Found {logprobs.count(None)} None values out of {len(logprobs)} total."
173+
)
174+
167175
async def test_incremental_tokenization(self, model):
168176
"""Subsequent calls only tokenize new messages."""
169177
# First turn

tests/unit/test_tool_parser.py

Lines changed: 31 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -278,8 +278,8 @@ def test_parse_with_whitespace(self, parser):
278278
# --- Custom Tokens ---
279279

280280
def test_custom_tokens(self):
281-
"""Use custom tool_call_tokens."""
282-
parser = HermesToolParser(tool_call_tokens=("<function>", "</function>"))
281+
"""Use custom tool tokens."""
282+
parser = HermesToolParser(tool_start_token="<function>", tool_end_token="</function>")
283283
text = '<function>{"name": "custom", "arguments": {}}</function>'
284284
results = parser.parse(text)
285285

@@ -288,7 +288,7 @@ def test_custom_tokens(self):
288288

289289
def test_custom_tokens_ignore_default(self):
290290
"""Custom tokens ignore default format."""
291-
parser = HermesToolParser(tool_call_tokens=("<function>", "</function>"))
291+
parser = HermesToolParser(tool_start_token="<function>", tool_end_token="</function>")
292292
# Default format should not be parsed
293293
text = '<tool_call>{"name": "ignored", "arguments": {}}</tool_call>'
294294
results = parser.parse(text)
@@ -436,25 +436,9 @@ def test_no_think_blocks(self, parser):
436436
assert len(results) == 1
437437
assert results[0].name == "tool"
438438

439-
def test_disable_think_block_exclusion(self):
440-
"""Setting think_tokens=None disables exclusion."""
441-
parser = HermesToolParser(think_tokens=None)
442-
text = """
443-
<think>
444-
<tool_call>{"name": "inside_think", "arguments": {}}</tool_call>
445-
</think>
446-
<tool_call>{"name": "outside_think", "arguments": {}}</tool_call>
447-
"""
448-
results = parser.parse(text)
449-
450-
# Both should be parsed when exclusion is disabled
451-
assert len(results) == 2
452-
assert results[0].name == "inside_think"
453-
assert results[1].name == "outside_think"
454-
455439
def test_custom_think_tokens(self):
456440
"""Custom think tokens work correctly."""
457-
parser = HermesToolParser(think_tokens=("<reasoning>", "</reasoning>"))
441+
parser = HermesToolParser(think_start_token="<reasoning>", think_end_token="</reasoning>")
458442
text = """
459443
<reasoning>
460444
<tool_call>{"name": "draft", "arguments": {}}</tool_call>
@@ -468,7 +452,7 @@ def test_custom_think_tokens(self):
468452

469453
def test_custom_think_tokens_ignore_default(self):
470454
"""Custom think tokens don't exclude default <think> blocks."""
471-
parser = HermesToolParser(think_tokens=("<reasoning>", "</reasoning>"))
455+
parser = HermesToolParser(think_start_token="<reasoning>", think_end_token="</reasoning>")
472456
text = """
473457
<think>
474458
<tool_call>{"name": "in_think", "arguments": {}}</tool_call>
@@ -663,8 +647,8 @@ def test_parse_compact_format(self, parser):
663647
# --- Custom Tokens ---
664648

665649
def test_custom_tokens(self):
666-
"""Use custom tool_call_tokens."""
667-
parser = QwenXMLToolParser(tool_call_tokens=("<call>", "</call>"))
650+
"""Use custom tool tokens."""
651+
parser = QwenXMLToolParser(tool_start_token="<call>", tool_end_token="</call>")
668652
text = """<call>
669653
<function=custom>
670654
<parameter=x>1</parameter>
@@ -677,7 +661,7 @@ def test_custom_tokens(self):
677661

678662
def test_custom_tokens_ignore_default(self):
679663
"""Custom tokens ignore default format."""
680-
parser = QwenXMLToolParser(tool_call_tokens=("<call>", "</call>"))
664+
parser = QwenXMLToolParser(tool_start_token="<call>", tool_end_token="</call>")
681665
text = """<tool_call>
682666
<function=ignored>
683667
<parameter=x>1</parameter>
@@ -713,27 +697,25 @@ def test_exclude_tool_calls_inside_think_block(self, parser):
713697
assert results[0].name == "actual_tool"
714698
assert results[0].input == {"y": "2"}
715699

716-
def test_disable_think_block_exclusion(self):
717-
"""Setting think_tokens=None disables exclusion."""
718-
parser = QwenXMLToolParser(think_tokens=None)
700+
def test_custom_think_tokens(self):
701+
"""Custom think tokens work correctly."""
702+
parser = QwenXMLToolParser(think_start_token="<reasoning>", think_end_token="</reasoning>")
719703
text = """
720-
<think>
704+
<reasoning>
721705
<tool_call>
722-
<function=inside_think>
706+
<function=inside_reasoning>
723707
</function>
724708
</tool_call>
725-
</think>
709+
</reasoning>
726710
<tool_call>
727-
<function=outside_think>
711+
<function=outside_reasoning>
728712
</function>
729713
</tool_call>
730714
"""
731715
results = parser.parse(text)
732716

733-
# Both should be parsed when exclusion is disabled
734-
assert len(results) == 2
735-
assert results[0].name == "inside_think"
736-
assert results[1].name == "outside_think"
717+
assert len(results) == 1
718+
assert results[0].name == "outside_reasoning"
737719

738720
# --- Edge Cases ---
739721

@@ -1011,8 +993,8 @@ def test_parse_compact_format(self, parser):
1011993
# --- Custom Tokens ---
1012994

1013995
def test_custom_tokens(self):
1014-
"""Use custom tool_call_tokens."""
1015-
parser = GLMToolParser(tool_call_tokens=("<call>", "</call>"))
996+
"""Use custom tool tokens."""
997+
parser = GLMToolParser(tool_start_token="<call>", tool_end_token="</call>")
1016998
text = """<call>custom
1017999
<arg_key>x</arg_key>
10181000
<arg_value>1</arg_value>
@@ -1025,7 +1007,7 @@ def test_custom_tokens(self):
10251007

10261008
def test_custom_tokens_ignore_default(self):
10271009
"""Custom tokens ignore default format."""
1028-
parser = GLMToolParser(tool_call_tokens=("<call>", "</call>"))
1010+
parser = GLMToolParser(tool_start_token="<call>", tool_end_token="</call>")
10291011
text = """<tool_call>ignored
10301012
<arg_key>x</arg_key>
10311013
<arg_value>1</arg_value>
@@ -1058,23 +1040,21 @@ def test_exclude_tool_calls_inside_think_block(self, parser):
10581040
assert results[0].name == "actual_tool"
10591041
assert results[0].input == {"y": 2} # JSON-decoded as integer
10601042

1061-
def test_disable_think_block_exclusion(self):
1062-
"""Setting think_tokens=None disables exclusion."""
1063-
parser = GLMToolParser(think_tokens=None)
1043+
def test_custom_think_tokens(self):
1044+
"""Custom think tokens work correctly."""
1045+
parser = GLMToolParser(think_start_token="<reasoning>", think_end_token="</reasoning>")
10641046
text = """
1065-
<think>
1066-
<tool_call>inside_think
1047+
<reasoning>
1048+
<tool_call>inside_reasoning
10671049
</tool_call>
1068-
</think>
1069-
<tool_call>outside_think
1050+
</reasoning>
1051+
<tool_call>outside_reasoning
10701052
</tool_call>
10711053
"""
10721054
results = parser.parse(text)
10731055

1074-
# Both should be parsed when exclusion is disabled
1075-
assert len(results) == 2
1076-
assert results[0].name == "inside_think"
1077-
assert results[1].name == "outside_think"
1056+
assert len(results) == 1
1057+
assert results[0].name == "outside_reasoning"
10781058

10791059
# --- Edge Cases ---
10801060

@@ -1184,8 +1164,8 @@ def test_get_parser_with_kwargs(self):
11841164
"""Get parser with custom arguments."""
11851165
from strands_sglang.tool_parsers import get_tool_parser
11861166

1187-
parser = get_tool_parser("hermes", think_tokens=None)
1188-
assert parser.think_tokens is None
1167+
parser = get_tool_parser("hermes", think_start_token="<reasoning>")
1168+
assert parser.think_start_token == "<reasoning>"
11891169

11901170
def test_unknown_parser_raises(self):
11911171
"""Unknown parser name raises KeyError."""

0 commit comments

Comments
 (0)