Commit 0a20ca2

refactor resample

1 parent 09974ef · commit 0a20ca2
File tree

swift/megatron/trainers/gkd_trainer.py
swift/megatron/trainers/grpo_trainer.py

2 files changed (+18, -37 lines)


swift/megatron/trainers/gkd_trainer.py

Lines changed: 9 additions & 16 deletions
@@ -58,16 +58,11 @@ def __init__(self, args: MegatronArguments, template, **kwargs):
         self.truncation_strategy = args.truncation_strategy
         self.max_completion_length = args.max_completion_length
 
-        # Resample iterator will be initialized lazily
         self.resample_data_iterator = None
-        self._train_dataset = None
 
     def train(self, train_dataset, val_dataset):
-        """Override train to initialize resample iterator for truncation_strategy='delete'."""
-        # Store dataset provider for lazy resample iterator initialization
         if self.truncation_strategy == 'delete':
-            self._train_dataset = train_dataset
-
+            self.resample_data_iterator = self._init_resample_data_iterator(train_dataset)
         super().train(train_dataset, val_dataset)
 
     def prepare_model(self):
@@ -184,13 +179,16 @@ def _determine_data_source(self) -> DataSource:
         # Mode 3: Off-Policy learning, use dataset responses
         return DataSource.DATASET
 
-    def _init_resample_data_iterator(self):
-        """Initialize an independent data iterator for dynamic resampling (lazy initialization).
+    def _init_resample_data_iterator(self, train_dataset):
+        """Initialize an independent data iterator for resampling.
 
         Uses a different seed (args.seed + 1) to avoid overlapping with training samples.
 
+        Args:
+            train_dataset: The training dataset to create the resample iterator from.
+
         Returns:
-            train_data_iterator: Independent data iterator with different random seed
+            The resample data iterator (first element of the iterator tuple).
         """
         args = self.args
         resample_seed = getattr(args, 'seed', 42) + 1
@@ -200,8 +198,7 @@ def _init_resample_data_iterator(self):
                 args.data_parallel_random_init,
                 args.te_rng_tracker,
             )
-            resample_data_iterator = self._prepare_data_iterator(self._train_dataset, use_origin_cyclic=True)
-            self._train_dataset = None
+            resample_data_iterator = self._prepare_data_iterator(train_dataset, use_origin_cyclic=True)[0]
         finally:
             set_random_seed(
                 args.seed,
@@ -225,10 +222,6 @@ def resample_encode_failed_inputs(self, inputs: List[Dict], max_resample_rounds:
         valid_samples = []
         pending_samples = list(inputs)
 
-        # Lazy initialization of resample_data_iterator
-        if self.resample_data_iterator is None:
-            self.resample_data_iterator = self._init_resample_data_iterator()[0]
-
         for _ in range(max_resample_rounds + 1):
             still_needed = required_count - len(valid_samples)
             if still_needed <= 0:
@@ -283,7 +276,7 @@ def _replace_data_iterator(self, data_iterator):
             raw_batch = next(data_iterator)
 
             # Resample for encoding failed data when truncation_strategy is 'delete'
-            if self.truncation_strategy == 'delete' and self._train_dataset is not None:
+            if self.truncation_strategy == 'delete' and self.resample_data_iterator is not None:
                 raw_batch = self.resample_encode_failed_inputs(raw_batch)
 
             global_batch.extend(raw_batch)
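Worth noting in _init_resample_data_iterator: the try/finally around set_random_seed switches the RNG to args.seed + 1 only while the iterator is built, then restores args.seed so the training stream is unaffected even if construction raises. Below is a minimal standalone sketch of that save-and-restore pattern; it uses Python's random module in place of Megatron's set_random_seed, and make_cyclic_iterator is a hypothetical stand-in for _prepare_data_iterator(..., use_origin_cyclic=True), not the trainer's real API.

import random
from itertools import cycle

def make_cyclic_iterator(dataset, seed):
    # Hypothetical stand-in: shuffle once under the given seed, then
    # repeat that order forever (a "cyclic" iterator).
    rng = random.Random(seed)
    order = list(dataset)
    rng.shuffle(order)
    return cycle(order)

def init_resample_iterator(dataset, seed=42):
    # Reseed with seed + 1 so resampled items do not mirror the training
    # stream; restore the training seed even if construction fails.
    random.seed(seed + 1)
    try:
        return make_cyclic_iterator(dataset, seed + 1)
    finally:
        random.seed(seed)  # global RNG state is left as training expects

resample_iter = init_resample_iterator(list(range(8)))
print([next(resample_iter) for _ in range(10)])  # shuffled order, repeating

Putting the restore in finally rather than after the try body is what makes the pattern safe: a failure while building the iterator cannot leave the trainer running on the offset seed.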

swift/megatron/trainers/grpo_trainer.py

Lines changed: 9 additions & 21 deletions
@@ -54,13 +54,11 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
         self._init_rollout_engine()
         self._prepare_rewards()
         self._prepare_scheduler()
-        self._train_dataset = None
+        self.resample_data_iterator = None
 
     def train(self, train_dataset, val_dataset):
-        # Store dataset provider for lazy resample iterator initialization
-        # Used by both dynamic_sample and truncation_strategy='delete'
         if self.dynamic_sample or self.truncation_strategy == 'delete':
-            self._train_dataset = train_dataset
+            self.resample_data_iterator = self._init_resample_data_iterator(train_dataset)
         super().train(train_dataset, val_dataset)
 
     def _init_grpo_params(self):
@@ -215,31 +213,28 @@ def _prepare_scheduler(self):
         assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler)
         self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler
 
-    def _init_resample_data_iterator(self):
-        """Initialize an independent data iterator for dynamic resampling (lazy initialization).
+    def _init_resample_data_iterator(self, train_dataset):
+        """Initialize an independent data iterator for resampling.
 
         Uses a different seed (args.seed + 1) to avoid overlapping with training samples.
 
+        Args:
+            train_dataset: The training dataset to create the resample iterator from.
+
         Returns:
-            train_data_iterator: Independent data iterator with different random seed
+            The resample data iterator (first element of the iterator tuple).
         """
         args = self.args
-        # Use different seed for resample iterator (offset by 1 to avoid overlap)
         resample_seed = getattr(args, 'seed', 42) + 1
         try:
-            # Set new seed for resample iterator creation
             set_random_seed(
                 resample_seed,
                 args.data_parallel_random_init,
                 args.te_rng_tracker,
             )
-
-            # Build data iterators with new seed
             # TODO: VPP (Virtual Pipeline Parallelism)
-            resample_data_iterator = self._prepare_data_iterator(self._train_dataset, use_origin_cyclic=True)
-            self._train_dataset = None
+            resample_data_iterator = self._prepare_data_iterator(train_dataset, use_origin_cyclic=True)[0]
         finally:
-            # Restore original random states to avoid affecting training
             set_random_seed(
                 args.seed,
                 args.data_parallel_random_init,
@@ -909,9 +904,6 @@ def _dynamic_sampling(self, rollout_batch: DataType,
             if len(valid_samples) >= self.generation_batch_size:
                 break
 
-        # Lazy initialization of resample_data_iterator
-        if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None:
-            self.resample_data_iterator = self._init_resample_data_iterator()[0]
         num_iters_per_step = self.get_num_iters_per_step()
         next_rollout_prompt_batch = []
         for _ in range(num_iters_per_step):
@@ -1562,11 +1554,7 @@ def resample_encode_failed_inputs(self, inputs: DataType, max_resample_rounds: i
         required_count = len(inputs)
         valid_samples = []
 
-        # Buffer for samples waiting to be validated
         pending_samples = list(inputs)
-        # Lazy initialization of resample_data_iterator
-        if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None:
-            self.resample_data_iterator = self._init_resample_data_iterator()[0]
         for _ in range(max_resample_rounds + 1):
             # Calculate how many more samples we need
             still_needed = required_count - len(valid_samples)
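The visible context of resample_encode_failed_inputs fixes the shape of the top-up loop: one pass over the original batch plus at most max_resample_rounds replacement rounds drawn from the independent resample_data_iterator. A hedged sketch of that control flow follows; try_encode is a hypothetical stand-in for the template's encoding step, and everything beyond the diff context shown above is an assumption, not the method's actual body.

import itertools
from typing import Callable, Iterator, List, Optional, TypeVar

T = TypeVar('T')

def resample_encode_failed(inputs: List[T],
                           resample_iter: Iterator[T],
                           try_encode: Callable[[T], Optional[T]],
                           max_resample_rounds: int = 3) -> List[T]:
    # Sketch only: keep samples that encode successfully; replace the
    # failures with fresh draws from the independent resample iterator.
    required_count = len(inputs)
    valid_samples: List[T] = []
    pending_samples = list(inputs)
    for _ in range(max_resample_rounds + 1):
        still_needed = required_count - len(valid_samples)
        if still_needed <= 0:
            break
        valid_samples.extend(s for s in pending_samples if try_encode(s) is not None)
        del valid_samples[required_count:]  # never overshoot the batch size
        # Draw replacement candidates for the next round.
        pending_samples = [next(resample_iter)
                           for _ in range(required_count - len(valid_samples))]
    return valid_samples

# Example: odd numbers "fail to encode" and are replaced from the stream.
batch = resample_encode_failed([1, 2, 3, 4], itertools.cycle(range(100)),
                               try_encode=lambda x: x if x % 2 == 0 else None)
print(batch)  # four even samples

The max_resample_rounds + 1 bound mirrors the diff: one round for the original inputs plus up to max_resample_rounds top-ups, so in this sketch a stream of persistently failing samples terminates with a short batch instead of looping forever.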
