@@ -184,7 +184,7 @@ async def abort(args, rollout_id: int, data_buffer):
184184 print (f"Abort request for { url } " , flush = True )
185185 # await post(f"{url}/abort_request", {"abort_all": True}, use_http2=False)
186186 # based on https://github.com/THUDM/slime/pull/63/files
187- await post (f"{ url } /abort_request" , {"rid" :"" , "abort_all" : True }, use_http2 = False )
187+ await post (f"{ url } /abort_request" , {"rid" : "" , "abort_all" : True }, use_http2 = False )
188188
189189 # make sure all the pending tasks are finished
190190 count = 0
@@ -281,26 +281,28 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer) -> list[lis
281281
282282 assert len (data ) == args .rollout_batch_size , f"Got { len (data )} samples, expected { args .rollout_batch_size } "
283283 data = sorted (data , key = lambda group : group [0 ].index )
284-
284+
285285 rollout_time = time .time () - state .rollout_start_time
286286
287287 completion_tokens_stats = {}
288288 if state .completion_tokens_list :
289289 completion_tokens_array = np .array (state .completion_tokens_list )
290290 completion_tokens_stats = {
291- ' total_completion_tokens' : np .sum (completion_tokens_array ).item (),
292- ' completion_tokens_mean' : np .mean (completion_tokens_array ).item (),
293- ' completion_tokens_std' : np .std (completion_tokens_array ).item (),
294- ' completion_tokens_count' : len (completion_tokens_array ),
291+ " total_completion_tokens" : np .sum (completion_tokens_array ).item (),
292+ " completion_tokens_mean" : np .mean (completion_tokens_array ).item (),
293+ " completion_tokens_std" : np .std (completion_tokens_array ).item (),
294+ " completion_tokens_count" : len (completion_tokens_array ),
295295 }
296296
297297 if len (data ) > 0 :
298- data [0 ][0 ].metadata .update ({
299- 'rollout_time' : rollout_time ,
300- 'completion_tokens_stats' : completion_tokens_stats ,
301- 'partial_samples' : state .partial_samples_count ,
302- 'total_off_policy_tokens' : state .total_off_policy_tokens ,
303- })
298+ data [0 ][0 ].metadata .update (
299+ {
300+ "rollout_time" : rollout_time ,
301+ "completion_tokens_stats" : completion_tokens_stats ,
302+ "partial_samples" : state .partial_samples_count ,
303+ "total_off_policy_tokens" : state .total_off_policy_tokens ,
304+ }
305+ )
304306 if completion_tokens_stats :
305307 print (f"[DEBUG] Rollout { rollout_id } : Completion tokens stats: { completion_tokens_stats } " , flush = True )
306308
0 commit comments