@@ -54,13 +54,11 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
         self._init_rollout_engine()
         self._prepare_rewards()
         self._prepare_scheduler()
-        self._train_dataset = None
+        self.resample_data_iterator = None

     def train(self, train_dataset, val_dataset):
-        # Store dataset provider for lazy resample iterator initialization
-        # Used by both dynamic_sample and truncation_strategy='delete'
         if self.dynamic_sample or self.truncation_strategy == 'delete':
-            self._train_dataset = train_dataset
+            self.resample_data_iterator = self._init_resample_data_iterator(train_dataset)
         super().train(train_dataset, val_dataset)

     def _init_grpo_params(self):
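This first hunk swaps lazy setup for eager setup: `train()` now builds the resample iterator up front whenever dynamic sampling or `truncation_strategy='delete'` is active, so later call sites can assume the attribute exists. A minimal standalone sketch of that pattern; `Trainer`, `needs_resampling`, and `make_iterator` are illustrative stand-ins, not the project's real API:

```python
# Eager-initialization sketch: `Trainer`, `needs_resampling`, and
# `make_iterator` are illustrative stand-ins, not the project's real API.
import itertools


class Trainer:

    def __init__(self, needs_resampling: bool):
        self.needs_resampling = needs_resampling
        # The attribute always exists, so consumers never need hasattr() guards.
        self.resample_data_iterator = None

    def make_iterator(self, dataset):
        # Cycle endlessly so resampling can always draw more prompts.
        return itertools.cycle(dataset)

    def train(self, train_dataset):
        if self.needs_resampling:
            # Built once, up front, instead of lazily at first use.
            self.resample_data_iterator = self.make_iterator(train_dataset)
        # ... run the actual training loop here ...


trainer = Trainer(needs_resampling=True)
trainer.train([{'prompt': 'a'}, {'prompt': 'b'}])
print(next(trainer.resample_data_iterator))  # -> {'prompt': 'a'}
```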
@@ -215,31 +213,28 @@ def _prepare_scheduler(self):
         assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler)
         self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler

-    def _init_resample_data_iterator(self):
-        """Initialize an independent data iterator for dynamic resampling (lazy initialization).
+    def _init_resample_data_iterator(self, train_dataset):
+        """Initialize an independent data iterator for resampling.

         Uses a different seed (args.seed + 1) to avoid overlapping with training samples.

+        Args:
+            train_dataset: The training dataset to create the resample iterator from.
+
         Returns:
-            train_data_iterator: Independent data iterator with different random seed
+            The resample data iterator (first element of the iterator tuple).
         """
         args = self.args
-        # Use different seed for resample iterator (offset by 1 to avoid overlap)
         resample_seed = getattr(args, 'seed', 42) + 1
         try:
-            # Set new seed for resample iterator creation
             set_random_seed(
                 resample_seed,
                 args.data_parallel_random_init,
                 args.te_rng_tracker,
             )
-
-            # Build data iterators with new seed
             # TODO: VPP (Virtual Pipeline Parallelism)
-            resample_data_iterator = self._prepare_data_iterator(self._train_dataset, use_origin_cyclic=True)
-            self._train_dataset = None
+            resample_data_iterator = self._prepare_data_iterator(train_dataset, use_origin_cyclic=True)[0]
         finally:
-            # Restore original random states to avoid affecting training
             set_random_seed(
                 args.seed,
                 args.data_parallel_random_init,
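Note the try/finally structure around iterator construction: the global RNG is reseeded with `seed + 1` only while the iterator is built, then restored unconditionally, so training-side randomness stays untouched even if construction raises. A hedged sketch of the same pattern, with NumPy's global RNG standing in for Megatron's `set_random_seed` (which also reseeds the data-parallel and Transformer Engine RNG trackers):

```python
# Seed-offset + try/finally restore sketch. NumPy's global RNG stands in
# for Megatron's set_random_seed(); the real helper also handles the
# data-parallel and Transformer Engine RNG trackers.
import numpy as np


def build_resample_iterator(dataset, seed: int):
    np.random.seed(seed + 1)  # offset seed: don't replay the training order
    try:
        order = np.random.permutation(len(dataset))
        return [dataset[i] for i in order]  # stand-in for the built iterator
    finally:
        # Restore the training seed even if construction fails, so
        # training-side randomness is unaffected.
        np.random.seed(seed)


print(build_resample_iterator(['a', 'b', 'c'], seed=42))
```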
@@ -909,9 +904,6 @@ def _dynamic_sampling(self, rollout_batch: DataType,
             if len(valid_samples) >= self.generation_batch_size:
                 break

-        # Lazy initialization of resample_data_iterator
-        if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None:
-            self.resample_data_iterator = self._init_resample_data_iterator()[0]
         num_iters_per_step = self.get_num_iters_per_step()
         next_rollout_prompt_batch = []
         for _ in range(num_iters_per_step):
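With the hasattr/None guard gone, `_dynamic_sampling` relies on `train()` having built the iterator and simply draws replacement prompt batches from it until enough valid samples accumulate. A simplified, runnable sketch of that refill loop; the batch sizes, sample shapes, and the accept-all "rollout" are assumptions, not the real rollout and filtering logic:

```python
# Simplified refill loop in the spirit of _dynamic_sampling: keep drawing
# prompt batches from the always-present resample iterator until the
# valid-sample quota is met. Batch size, sample shape, and the accept-all
# step below are illustrative assumptions.
import itertools

generation_batch_size = 4
valid_samples = [{'prompt': 'kept'}]  # samples that survived filtering so far
resample_data_iterator = itertools.cycle({'prompt': f'p{i}'} for i in range(8))

while len(valid_samples) < generation_batch_size:
    # Draw one prompt micro-batch per iteration (2 prompts here).
    next_rollout_prompt_batch = [next(resample_data_iterator) for _ in range(2)]
    # The real trainer would roll these out and filter them (e.g. dropping
    # groups with identical rewards); this sketch accepts everything.
    valid_samples.extend(next_rollout_prompt_batch)

print(len(valid_samples))  # >= generation_batch_size
```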
@@ -1562,11 +1554,7 @@ def resample_encode_failed_inputs(self, inputs: DataType, max_resample_rounds: i
         required_count = len(inputs)
         valid_samples = []

-        # Buffer for samples waiting to be validated
         pending_samples = list(inputs)
-        # Lazy initialization of resample_data_iterator
-        if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None:
-            self.resample_data_iterator = self._init_resample_data_iterator()[0]
         for _ in range(max_resample_rounds + 1):
             # Calculate how many more samples we need
             still_needed = required_count - len(valid_samples)
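The same guarantee simplifies `resample_encode_failed_inputs`, which loops at most `max_resample_rounds + 1` times, keeping samples that encode successfully and topping up the shortfall from the resample iterator. A minimal sketch of that bounded retry pattern, where `encode_ok` is a hypothetical stand-in for the real template-encoding check:

```python
# Bounded resample loop sketch: keep samples that encode, top up the
# shortfall from the resample iterator, and stop after max_resample_rounds
# extra rounds. `encode_ok` is a hypothetical stand-in for the real
# template-encoding check.
import itertools


def encode_ok(sample) -> bool:
    return len(sample['prompt']) <= 4  # pretend long prompts fail to encode


def resample_until_valid(inputs, iterator, max_resample_rounds=2):
    required_count = len(inputs)
    valid_samples = []
    pending_samples = list(inputs)
    for _ in range(max_resample_rounds + 1):
        valid_samples += [s for s in pending_samples if encode_ok(s)]
        still_needed = required_count - len(valid_samples)
        if still_needed <= 0:
            break
        # Draw fresh candidates for the next validation round.
        pending_samples = [next(iterator) for _ in range(still_needed)]
    return valid_samples[:required_count]


pool = itertools.cycle([{'prompt': 'ok'}, {'prompt': 'too-long'}])
print(resample_until_valid([{'prompt': 'way-too-long'}], pool))
```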