Skip to content

Commit afff1b3

Browse files
committed
support toolcall+partial rollout
1 parent 71d5306 commit afff1b3

2 files changed

Lines changed: 352 additions & 11 deletions

File tree

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import re
2+
import unittest
3+
4+
from xtuner.v1.data_proto import RolloutState, SampleParams, Status
5+
from xtuner.v1.rl.agent_loop.gsm8k_with_tool import GSM8KToolAgentLoop
6+
from xtuner.v1.rl.agent_loop.utils import PartialRolloutHandler
7+
8+
9+
def _encode(text: str) -> list[int]:
    """Encode *text* as one token id per character, using each code point.

    This is the inverse of ``_FakeTokenizer.decode`` and lets the tests build
    token sequences that are trivially readable when decoded back.
    """
    return [ord(char) for char in text]
11+
12+
13+
class _FakeTokenizer:
14+
def __init__(self) -> None:
15+
self.tool_messages: list[list[dict]] = []
16+
17+
def decode(self, ids: list[int] | None) -> str:
18+
return "".join(chr(token_id) for token_id in ids or [])
19+
20+
def apply_chat_template(self, messages: list[dict], remove_system_prompt: bool = True) -> list[int]:
21+
self.tool_messages.append(messages)
22+
content = "".join(message["content"] for message in messages)
23+
return _encode(f"<tool>{content}</tool>")
24+
25+
26+
class _FakeGenerate:
    """Scripted replacement for the rollout controller's ``generate`` handle.

    Each ``remote`` call consumes the next ``(text, status)`` pair from the
    script and records what the caller asked for.
    """

    def __init__(self, responses: list[tuple[str, Status]]) -> None:
        # Copy so that popping scripted responses never mutates the caller's list.
        self.responses = list(responses)
        # One entry per ``remote`` call: the input tokens and the max_tokens budget.
        self.requests: list[dict] = []

    async def remote(self, rollout_state: RolloutState) -> RolloutState:
        """Record the request, then fill *rollout_state* with the next scripted response.

        The state is mutated in place and also returned. Raises ``IndexError``
        when more calls are made than responses were scripted.
        """
        self.requests.append(
            {
                "tokens": list(rollout_state.tokens or []),
                "max_tokens": rollout_state.sample_params.max_tokens,
            }
        )
        text, status = self.responses.pop(0)
        rollout_state.response = text
        rollout_state.response_ids = _encode(text)
        # One dummy logprob per generated token.
        rollout_state.logprobs = [0.0] * len(rollout_state.response_ids)
        rollout_state.status = status
        rollout_state.finish_reason = "stop" if status == Status.COMPLETED else "abort"
        return rollout_state
45+
46+
47+
class _FakeRolloutController:
    """Minimal rollout controller exposing only a scripted ``generate`` attribute."""

    def __init__(self, responses: list[tuple[str, Status]]) -> None:
        # Tests inspect ``self.generate.requests`` to see what was requested.
        self.generate = _FakeGenerate(responses)
50+
51+
52+
class _FakeJudger:
    """Judger stand-in that counts invocations and assigns a fixed reward."""

    def __init__(self) -> None:
        # Number of times ``judge`` has been awaited.
        self.calls = 0

    async def judge(self, rollout_state: RolloutState) -> RolloutState:
        """Increment the call counter, set ``reward`` to ``{"score": 1.0}``, return the state."""
        self.calls += 1
        rollout_state.reward = {"score": 1.0}
        return rollout_state
60+
61+
62+
class TestGSM8KToolPartialRollout(unittest.IsolatedAsyncioTestCase):
    """Tests for ``GSM8KToolAgentLoop.generate_sample`` with partial rollout.

    The loop under test is wired to in-file fakes (tokenizer, rollout
    controller, judger), so each test fully scripts the generated text.
    """

    def _build_loop(
        self,
        responses: list[tuple[str, Status]],
        *,
        max_tokens: int = 16,
        max_turns: int = 2,
        judger: _FakeJudger | None = None,
    ) -> GSM8KToolAgentLoop:
        """Create an agent loop backed by fakes that replays *responses*.

        ``__new__`` is used so the real ``__init__`` is bypassed; every
        attribute set here is assigned by hand instead.
        """
        loop = GSM8KToolAgentLoop.__new__(GSM8KToolAgentLoop)
        loop.max_turns = max_turns
        loop.rollout_ctl = _FakeRolloutController(responses)
        loop.sample_params = SampleParams(max_tokens=max_tokens)
        loop.max_tokens = max_tokens
        loop.partial_rollout_handler = PartialRolloutHandler(max_tokens=max_tokens)
        loop.tokenizer = _FakeTokenizer()
        loop.judger = judger
        loop.tool_call_pattern = re.compile(r"\n*<tool_call>(.*?)</tool_call>", re.DOTALL)
        loop.tool_call_start_token = "<tool_call>"
        loop.tool_call_end_token = "</tool_call>"
        # Stub the reward tool so a tool call always yields 1.0.
        loop.calc_gsm8k_reward = lambda answer, ground_truth: 1.0
        return loop

    def _make_aborted_state(self, response_text: str, *, max_tokens: int = 16) -> RolloutState:
        """Return an ABORTED state whose history is *response_text*.

        Every history token gets mask 1, rollout step 1 and logprob 0.0.
        """
        response_ids = _encode(response_text)
        return RolloutState(
            uid=1,
            message=[{"role": "user", "content": "question"}],
            prompt_ids=[7],
            sample_params=SampleParams(max_tokens=max_tokens),
            status=Status.ABORTED,
            response_ids=response_ids,
            response=response_text,
            logprobs=[0.0] * len(response_ids),
            response_mask=[1] * len(response_ids),
            response_rollout_steps=[1] * len(response_ids),
            reward_model={"ground_truth": "#### 42"},
            extra_fields={},
        )

    async def test_partial_rollout_appends_history_and_updates_staleness(self):
        """Resuming an aborted sample continues from its history within the remaining budget."""
        judger = _FakeJudger()
        loop = self._build_loop([("c", Status.COMPLETED)], max_tokens=5, judger=judger)
        state = self._make_aborted_state("ab", max_tokens=5)

        result = await loop.generate_sample(state, enable_partial_rollout=True, rollout_step=3)

        # The request carries prompt + history and only the leftover budget (5 - 2).
        self.assertEqual(loop.rollout_ctl.generate.requests[0]["tokens"], [7] + _encode("ab"))
        self.assertEqual(loop.rollout_ctl.generate.requests[0]["max_tokens"], 3)
        self.assertEqual(loop.tokenizer.decode(result.response_ids), "abc")
        self.assertEqual(result.response_mask, [1, 1, 1])
        # History tokens keep step 1; the new token is stamped with rollout_step=3.
        self.assertEqual(result.response_rollout_steps, [1, 1, 3])
        self.assertEqual(result.seq_staleness, 2)
        tool_state = result.extra_fields["gsm8k_tool_agent_loop_state"]
        self.assertEqual(tool_state["cur_turn"], 1)
        self.assertEqual(tool_state["current_turn_response_start_idx"], 0)
        self.assertEqual(judger.calls, 1)

    async def test_non_partial_path_does_not_write_partial_state(self):
        """Without partial rollout, no resume state is stored in ``extra_fields``."""
        first_turn = '<tool_call>{"name":"calc_gsm8k_reward","arguments":{"answer":"42"}}</tool_call>'
        loop = self._build_loop([(first_turn, Status.COMPLETED), ("done", Status.COMPLETED)], max_tokens=128)
        state = RolloutState(
            uid=2,
            message=[{"role": "user", "content": "question"}],
            prompt_ids=[7],
            sample_params=SampleParams(max_tokens=128),
            reward_model={"ground_truth": "#### 42"},
            extra_fields={},
        )

        result = await loop.generate_sample(state)

        self.assertNotIn("gsm8k_tool_agent_loop_state", result.extra_fields)
        # Exactly one tool response was rendered through the chat template.
        self.assertEqual(len(loop.tokenizer.tool_messages), 1)
        self.assertIn("done", loop.tokenizer.decode(result.response_ids))

    async def test_aborted_partial_rollout_is_not_judged(self):
        """A resumed sample that aborts again keeps its tokens and skips the judger."""
        judger = _FakeJudger()
        loop = self._build_loop([("c", Status.ABORTED)], max_tokens=5, judger=judger)
        state = self._make_aborted_state("ab", max_tokens=5)

        result = await loop.generate_sample(state, enable_partial_rollout=True, rollout_step=3)

        self.assertEqual(result.status, Status.ABORTED)
        self.assertEqual(loop.tokenizer.decode(result.response_ids), "abc")
        self.assertEqual(result.response_rollout_steps, [1, 1, 3])
        tool_state = result.extra_fields["gsm8k_tool_agent_loop_state"]
        # The turn counter does not advance for an aborted generation.
        self.assertEqual(tool_state["cur_turn"], 0)
        self.assertEqual(tool_state["current_turn_response_start_idx"], 0)
        self.assertEqual(judger.calls, 0)

    async def test_tool_call_can_complete_across_partial_boundary(self):
        """A tool call opened in the history and closed after resume is still executed."""
        history = '<tool_call>{"name":"calc_gsm8k_reward","arguments":{"answer":"42"}}'
        loop = self._build_loop([("</tool_call>", Status.COMPLETED)], max_tokens=128, max_turns=1)
        state = self._make_aborted_state(history, max_tokens=128)

        result = await loop.generate_sample(state, enable_partial_rollout=True, rollout_step=2)

        decoded = loop.tokenizer.decode(result.response_ids)
        self.assertIn(history + "</tool_call>", decoded)
        self.assertIn("<tool>", decoded)
        self.assertEqual(len(loop.tokenizer.tool_messages), 1)
        # The trailing tool-response tokens must be masked out with zeros.
        tool_response_ids = _encode('<tool>{"result": 1.0}</tool>')
        self.assertEqual(result.response_mask[-len(tool_response_ids) :], [0] * len(tool_response_ids))
        tool_state = result.extra_fields["gsm8k_tool_agent_loop_state"]
        self.assertEqual(tool_state["current_turn_response_start_idx"], len(result.response_ids))

    async def test_does_not_reparse_previous_completed_turn(self):
        """A tool call before the saved turn start index is not executed a second time."""
        history = (
            '<tool_call>{"name":"calc_gsm8k_reward","arguments":{"answer":"42"}}</tool_call>'
            '<tool>{"result": 1.0}</tool>'
        )
        loop = self._build_loop([("final answer", Status.COMPLETED)], max_tokens=128, max_turns=2)
        state = self._make_aborted_state(history, max_tokens=128)
        # Saved resume state: one turn done, current turn starts after the whole history.
        state.extra_fields["gsm8k_tool_agent_loop_state"] = {
            "cur_turn": 1,
            "current_turn_response_start_idx": len(_encode(history)),
        }

        result = await loop.generate_sample(state, enable_partial_rollout=True, rollout_step=2)

        self.assertEqual(loop.tokenizer.tool_messages, [])
        self.assertIn("final answer", loop.tokenizer.decode(result.response_ids))
185+
186+
187+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)