Skip to content

Commit e03f1bf

Browse files
tianshub (The tunix Authors)
authored and committed
remove unnecessary rollout round
PiperOrigin-RevId: 876437319
1 parent df627a6 commit e03f1bf

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

tunix/rl/experimental/agentic_rl_learner.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -511,10 +511,8 @@ def _batch_to_train_example(
511511
"""
512512
# Create a merged training_input where each field from the original input
513513
# is repeated G times to align with the G completions.
514-
num_generations = self.algo_config.num_generations
515-
prompt_index = batch_results[0].pair_index // num_generations
516-
if mode == rl_cluster_lib.Mode.TRAIN and self._full_batch_size:
517-
expected_step = prompt_index // self._full_batch_size
514+
if mode == rl_cluster_lib.Mode.TRAIN:
515+
expected_step = batch_results[0].group_id // self._full_batch_size
518516
else:
519517
expected_step = self.rl_cluster.global_steps
520518

@@ -710,7 +708,10 @@ def train(
710708
micro_batches_since_last_sync = 0
711709
micro_batches_per_full_batch = full_batch_size // train_micro_batch_size
712710
for train_micro_batch in train_data_gen:
713-
if self.rl_cluster.global_steps >= self._training_config.max_steps:
711+
if (
712+
self._training_config.max_steps
713+
and self.rl_cluster.global_steps >= self._training_config.max_steps
714+
):
714715
logging.info(
715716
"Reached max_steps: %d >= %d",
716717
self.rl_cluster.global_steps,
@@ -825,7 +826,17 @@ def _put_prompts_to_queue(
825826
prompt_queue: The queue to put the batch into.
826827
batch: The batch of prompts (TrainingInputT).
827828
"""
828-
if len(batch["prompts"]) != self._full_batch_size:
829+
if (
830+
self._training_config.max_steps
831+
and self.rl_cluster.global_steps >= self._training_config.max_steps
832+
):
833+
logging.info(
834+
"Reached max_steps: %d >= %d",
835+
self.rl_cluster.global_steps,
836+
self._training_config.max_steps,
837+
)
838+
prompt_queue.put(None)
839+
elif len(batch["prompts"]) != self._full_batch_size:
829840
logging.warning(
830841
"partial batch %d vs %d detected. The rest of the batch will be"
831842
" skipped.",

0 commit comments

Comments (0)