@@ -86,17 +86,28 @@ async def generate_rollout_async(
8686 # target_data_size is the total number of valid samples to get
8787 target_data_size = args .rollout_batch_size
8888
89+ # When oversampling is disabled, submit exactly target_data_size groups once and keep
90+ # whatever survives the dynamic filter (the batch may end short); otherwise keep sampling
91+ # in over_sampling_batch_size chunks until target_data_size groups are collected.
92+ submit_budget = target_data_size if args .disable_oversampling else None
93+
8994 pendings = set ()
9095 data = []
9196 all_data = []
97+ submitted = 0
9298 do_print = True
9399 pbar = tqdm (total = target_data_size * args .n_samples_per_prompt , desc = "Rollout generation" )
94100 while len (data ) < target_data_size :
95- while len (data ) + len (pendings ) < target_data_size :
101+ while len (data ) + len (pendings ) < target_data_size and ( submit_budget is None or submitted < submit_budget ) :
96102 # get samples from the buffer and submit the generation requests.
97- samples = data_source (args .over_sampling_batch_size )
103+ n = args .over_sampling_batch_size if submit_budget is None else submit_budget - submitted
104+ samples = data_source (n )
105+ submitted += len (samples )
98106 pendings .update (submit_generate_tasks (state , samples ))
99107
108+ if not pendings :
109+ break
110+
100111 # wait for the generation to finish
101112 logger .debug (f"[rollout] Waiting on { len (pendings )} pending tasks, data={ len (data )} /{ target_data_size } " )
102113 done , pendings = await asyncio .wait (pendings , return_when = asyncio .FIRST_COMPLETED )
@@ -129,15 +140,22 @@ async def generate_rollout_async(
129140 pbar .update (args .n_samples_per_prompt )
130141
131142 pbar .close ()
132- sample = data [- 1 ][0 ][0 ] if isinstance (data [- 1 ][0 ], list ) else data [- 1 ][0 ]
133- logger .info (
134- f"Finish rollout: { [str (sample .prompt ) + sample .response ]} , label: { sample .label } , reward: { sample .reward } " ,
135- )
143+ if data :
144+ sample = data [- 1 ][0 ][0 ] if isinstance (data [- 1 ][0 ], list ) else data [- 1 ][0 ]
145+ logger .info (
146+ f"Finish rollout: { [str (sample .prompt ) + sample .response ]} , label: { sample .label } , reward: { sample .reward } " ,
147+ )
136148
137149 # there are still some unfinished requests, abort them
138150 aborted_samples = await abort (state , pendings , rollout_id )
139151
140- assert len (data ) == args .rollout_batch_size , f"Got { len (data )} samples, expected { args .rollout_batch_size } "
152+ if args .disable_oversampling :
153+ if len (data ) < args .rollout_batch_size :
154+ logger .warning (
155+ f"[rollout] oversampling disabled: { len (data )} /{ args .rollout_batch_size } groups survived the dynamic filter"
156+ )
157+ else :
158+ assert len (data ) == args .rollout_batch_size , f"Got { len (data )} samples, expected { args .rollout_batch_size } "
141159 data = sorted (data , key = lambda group : group [0 ][0 ].index if isinstance (group [0 ], list ) else group [0 ].index )
142160 all_samples = sorted (
143161 all_data , key = lambda group : group [0 ][0 ].index if isinstance (group [0 ], list ) else group [0 ].index
0 commit comments