Skip to content

Commit a70fe92

Browse files
authored
Fix pre-commit run --all-files (#870)
1 parent dd56492 commit a70fe92

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/retool/generate_with_retool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
266266
return sample
267267

268268
cur_response = output["text"]
269-
269+
270270
if "output_token_logprobs" in output["meta_info"]:
271271
cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
272272
cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
@@ -276,7 +276,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
276276
else:
277277
cur_response = postprocess_responses(cur_response)
278278
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
279-
279+
280280
response += cur_response
281281
response_token_ids += cur_response_token_ids
282282
loss_masks += [1] * len(cur_response_token_ids)
@@ -304,7 +304,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
304304
# Check if maximum tool call count reached
305305
if sample.rollout_log_probs is not None:
306306
sample.rollout_log_probs += [0.0] * len(obs_tokens_ids)
307-
307+
308308
assert len(response_token_ids) == len(
309309
sample.rollout_log_probs
310310
), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"

0 commit comments

Comments
 (0)