Skip to content

Commit cd14d1a

Browse files
authored
Refactor Retool recipe with rollout_log_probs recorded (#828)
1 parent 2cd8b34 commit cd14d1a

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

examples/retool/generate_with_retool.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
234234
payload = {
235235
"text": prompt + response,
236236
"sampling_params": sampling_params,
237+
"return_logprob": True, # Request log probabilities for training
237238
}
238239

239240
# Log payload to wandb for debugging
@@ -265,10 +266,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
265266
return sample
266267

267268
cur_response = output["text"]
268-
cur_response = postprocess_responses(cur_response)
269-
270-
# Record current response tokens
271-
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
269+
270+
if "output_token_logprobs" in output["meta_info"]:
271+
cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
272+
cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
273+
if sample.rollout_log_probs is None:
274+
sample.rollout_log_probs = []
275+
sample.rollout_log_probs += cur_log_probs
276+
else:
277+
cur_response = postprocess_responses(cur_response)
278+
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
279+
272280
response += cur_response
273281
response_token_ids += cur_response_token_ids
274282
loss_masks += [1] * len(cur_response_token_ids)
@@ -292,7 +300,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
292300
response_token_ids += obs_tokens_ids
293301
loss_masks += [0] * len(obs_tokens_ids)
294302

303+
# Add dummy log probs for observation tokens (they won't be used due to loss_mask=0)
295304
# Check if maximum tool call count reached
305+
if sample.rollout_log_probs is not None:
306+
sample.rollout_log_probs += [0.0] * len(obs_tokens_ids)
307+
308+
assert len(response_token_ids) == len(
309+
sample.rollout_log_probs
310+
), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"
311+
296312
if turn >= TOOL_CONFIGS["max_tool_calls"]:
297313
break
298314

0 commit comments

Comments
 (0)