@@ -37,7 +37,7 @@ def rewrite_batch(
3737 batch : Dict [str , List [Any ]],
3838 mapper : MMIRAGEMapper ,
3939 renderer : TemplateRenderer ,
40- image_base_path : str = None ,
40+ image_base_path : Optional [ str ] = None ,
4141) -> Dict [str , List [Any ]]:
4242 """Rewrite a batch of samples by applying transformations.
4343 Args:
@@ -91,6 +91,8 @@ def main():
9191
9292 state_dir = shard_state_dir (shard_id , loading_params .get_state_root ())
9393
94+ gpu_poller : Optional [GpuUtilizationPoller ] = None
95+
9496 collect_stats = os .environ .get ("MMIRAGE_COLLECT_STATS" , "" ) == "1"
9597 if collect_stats :
9698 # Determine which physical GPU indices SGLang will use so the poller
@@ -112,9 +114,11 @@ def main():
112114 gpu_indices_for_polling : List [str ] = all_visible [:tp_size ] if all_visible else [str (i ) for i in range (tp_size )]
113115 else :
114116 gpu_indices_for_polling = [str (i ) for i in range (tp_size )]
115- gpu_poller : GpuUtilizationPoller = GpuUtilizationPoller (
117+
118+ gpu_poller = GpuUtilizationPoller (
116119 interval_seconds = 5.0 , gpu_indices = gpu_indices_for_polling
117120 )
121+
118122 try :
119123 retry_count = _mark_running (state_dir , shard_id , datasets_config )
120124 logger .info (f"Starting shard { shard_id } /{ last_shard_id } (attempt #{ retry_count } )" )
@@ -144,7 +148,7 @@ def main():
144148
145149 # Start GPU polling after model loading so utilisation samples reflect
146150 # inference only, not weight transfers during sgl.Engine() init.
147- if collect_stats :
151+ if collect_stats and gpu_poller is not None :
148152 gpu_poller .start ()
149153
150154 ds_processed_all : List [DatasetLike ] = []
@@ -180,7 +184,7 @@ def main():
180184 _save_dataset_atomic (ds_processed , out_dir )
181185 logger .info (f"✅ Saved dataset { ds_idx } shard in: { out_dir } " )
182186
183- gpu_info = gpu_poller .stop () if collect_stats else {"mean" : None , "min" : None , "max" : None , "samples" : 0 }
187+ gpu_info = gpu_poller .stop () if collect_stats and gpu_poller is not None else {"mean" : None , "min" : None , "max" : None , "samples" : 0 }
184188
185189 # Collect token counts accumulated by LLM processor(s).
186190 token_counts = mapper .get_token_counts ()
@@ -214,7 +218,7 @@ def main():
214218 error_msg = f"{ type (e ).__name__ } : { str (e )} "
215219 logger .error (f"❌ Shard { shard_id } failed: { error_msg } " )
216220 logger .error (traceback .format_exc ())
217- if collect_stats :
221+ if collect_stats and gpu_poller is not None :
218222 gpu_poller .stop ()
219223 _mark_failure (state_dir , error_msg )
220224 sys .exit (1 )
0 commit comments