Skip to content

Commit 3917c87

Browse files
feat(rollout): add --disable-oversampling to cap submissions at rollout_batch_size
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 9d4262c commit 3917c87

2 files changed

Lines changed: 35 additions & 7 deletions

File tree

miles/rollout/inference_rollout/inference_rollout_train.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,28 @@ async def generate_rollout_async(
8686
# target_data_size is the total number of valid samples to get
8787
target_data_size = args.rollout_batch_size
8888

89+
# When oversampling is disabled, submit exactly target_data_size groups once and keep
90+
# whatever survives the dynamic filter (the batch may end short); otherwise keep sampling
91+
# in over_sampling_batch_size chunks until target_data_size groups are collected.
92+
submit_budget = target_data_size if args.disable_oversampling else None
93+
8994
pendings = set()
9095
data = []
9196
all_data = []
97+
submitted = 0
9298
do_print = True
9399
pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
94100
while len(data) < target_data_size:
95-
while len(data) + len(pendings) < target_data_size:
101+
while len(data) + len(pendings) < target_data_size and (submit_budget is None or submitted < submit_budget):
96102
# get samples from the buffer and submit the generation requests.
97-
samples = data_source(args.over_sampling_batch_size)
103+
n = args.over_sampling_batch_size if submit_budget is None else submit_budget - submitted
104+
samples = data_source(n)
105+
submitted += len(samples)
98106
pendings.update(submit_generate_tasks(state, samples))
99107

108+
if not pendings:
109+
break
110+
100111
# wait for the generation to finish
101112
logger.debug(f"[rollout] Waiting on {len(pendings)} pending tasks, data={len(data)}/{target_data_size}")
102113
done, pendings = await asyncio.wait(pendings, return_when=asyncio.FIRST_COMPLETED)
@@ -129,15 +140,22 @@ async def generate_rollout_async(
129140
pbar.update(args.n_samples_per_prompt)
130141

131142
pbar.close()
132-
sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
133-
logger.info(
134-
f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
135-
)
143+
if data:
144+
sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
145+
logger.info(
146+
f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
147+
)
136148

137149
# there are still some unfinished requests, abort them
138150
aborted_samples = await abort(state, pendings, rollout_id)
139151

140-
assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
152+
if args.disable_oversampling:
153+
if len(data) < args.rollout_batch_size:
154+
logger.warning(
155+
f"[rollout] oversampling disabled: {len(data)}/{args.rollout_batch_size} groups survived the dynamic filter"
156+
)
157+
else:
158+
assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
141159
data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)
142160
all_samples = sorted(
143161
all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index

miles/utils/arguments.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ def add_rollout_arguments(parser):
353353
"If this value is None, rollout_batch_size will be used as the default over_sampling_batch_size."
354354
),
355355
)
356+
parser.add_argument(
357+
"--disable-oversampling",
358+
action="store_true",
359+
default=False,
360+
help=(
361+
"Submit exactly rollout_batch_size groups and keep whatever survives the dynamic "
362+
"filter, instead of oversampling to refill groups dropped by the filter. The batch "
363+
"may end short; pair with --use-dynamic-global-batch-size when many groups are dropped."
364+
),
365+
)
356366
parser.add_argument(
357367
"--dynamic-sampling-filter-path",
358368
type=str,

0 commit comments

Comments
 (0)