@@ -234,6 +234,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
234234 payload = {
235235 "text" : prompt + response ,
236236 "sampling_params" : sampling_params ,
237+ "return_logprob" : True , # Request log probabilities for training
237238 }
238239
239240 # Log payload to wandb for debugging
@@ -265,10 +266,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
265266 return sample
266267
267268 cur_response = output ["text" ]
268- cur_response = postprocess_responses (cur_response )
269-
270- # Record current response tokens
271- cur_response_token_ids = state .tokenizer (cur_response , add_special_tokens = False )["input_ids" ]
269+
270+ if "output_token_logprobs" in output ["meta_info" ]:
271+ cur_response_token_ids = [item [1 ] for item in output ["meta_info" ]["output_token_logprobs" ]]
272+ cur_log_probs = [item [0 ] for item in output ["meta_info" ]["output_token_logprobs" ]]
273+ if sample .rollout_log_probs is None :
274+ sample .rollout_log_probs = []
275+ sample .rollout_log_probs += cur_log_probs
276+ else :
277+ cur_response = postprocess_responses (cur_response )
278+ cur_response_token_ids = state .tokenizer (cur_response , add_special_tokens = False )["input_ids" ]
279+
272280 response += cur_response
273281 response_token_ids += cur_response_token_ids
274282 loss_masks += [1 ] * len (cur_response_token_ids )
@@ -292,7 +300,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
292300 response_token_ids += obs_tokens_ids
293301 loss_masks += [0 ] * len (obs_tokens_ids )
294302
303+ # Add dummy log probs for observation tokens (they won't be used due to loss_mask=0)
295304 # Check if maximum tool call count reached
305+ if sample .rollout_log_probs is not None :
306+ sample .rollout_log_probs += [0.0 ] * len (obs_tokens_ids )
307+
308+ assert len (response_token_ids ) == len (
309+ sample .rollout_log_probs
310+ ), f"Token/logp length mismatch at turn { turn } : { len (response_token_ids )} tokens vs { len (sample .rollout_log_probs )} logps"
311+
296312 if turn >= TOOL_CONFIGS ["max_tool_calls" ]:
297313 break
298314
0 commit comments