@@ -511,10 +511,8 @@ def _batch_to_train_example(
     """
     # Create a merged training_input where each field from the original input
     # is repeated G times to align with the G completions.
-    num_generations = self.algo_config.num_generations
-    prompt_index = batch_results[0].pair_index // num_generations
-    if mode == rl_cluster_lib.Mode.TRAIN and self._full_batch_size:
-      expected_step = prompt_index // self._full_batch_size
+    if mode == rl_cluster_lib.Mode.TRAIN:
+      expected_step = batch_results[0].group_id // self._full_batch_size
     else:
       expected_step = self.rl_cluster.global_steps
 
@@ -710,7 +708,10 @@ def train(
     micro_batches_since_last_sync = 0
     micro_batches_per_full_batch = full_batch_size // train_micro_batch_size
     for train_micro_batch in train_data_gen:
-      if self.rl_cluster.global_steps >= self._training_config.max_steps:
+      if (
+          self._training_config.max_steps
+          and self.rl_cluster.global_steps >= self._training_config.max_steps
+      ):
         logging.info(
             "Reached max_steps: %d >= %d",
             self.rl_cluster.global_steps,
@@ -825,7 +826,17 @@ def _put_prompts_to_queue(
       prompt_queue: The queue to put the batch into.
      batch: The batch of prompts (TrainingInputT).
     """
-    if len(batch["prompts"]) != self._full_batch_size:
+    if (
+        self._training_config.max_steps
+        and self.rl_cluster.global_steps >= self._training_config.max_steps
+    ):
+      logging.info(
+          "Reached max_steps: %d >= %d",
+          self.rl_cluster.global_steps,
+          self._training_config.max_steps,
+      )
+      prompt_queue.put(None)
+    elif len(batch["prompts"]) != self._full_batch_size:
       logging.warning(
           "partial batch %d vs %d detected. The rest of the batch will be"
           " skipped.",
0 commit comments