Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions miles/rollout/generate_utils/openai_endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,23 +227,9 @@ def truncate_samples_by_total_tokens(
if allowed_output <= 0:
break

_truncate_sample_output(sample, allowed_output, tokenizer)
sample.strip_last_output_tokens(overshoot, tokenizer)
sample.status = Sample.Status.TRUNCATED
Comment thread
guapisolo marked this conversation as resolved.
result.append(sample)
break

return result


def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None:
    """Trim ``sample``'s generated output in-place to exactly ``keep_tokens`` tokens.

    The prompt portion of ``sample.tokens`` is preserved; only the response
    suffix is cut. The decoded ``response`` text, ``response_length``,
    ``rollout_log_probs``, and ``loss_mask`` are all shortened to match, and
    the sample is flagged as truncated.
    """
    # Everything before the response is the prompt; keep it intact.
    n_prompt = len(sample.tokens) - sample.response_length
    retained = sample.tokens[n_prompt : n_prompt + keep_tokens]

    sample.response_length = keep_tokens
    sample.response = tokenizer.decode(retained)
    sample.tokens = sample.tokens[:n_prompt] + retained
    # Per-token sidecar arrays (if present) are indexed over the response
    # only, so they are cut to the same length.
    if sample.loss_mask is not None:
        sample.loss_mask = sample.loss_mask[:keep_tokens]
    if sample.rollout_log_probs is not None:
        sample.rollout_log_probs = sample.rollout_log_probs[:keep_tokens]
    sample.status = Sample.Status.TRUNCATED
Loading