11from gridfm_graphkit .datasets .hetero_powergrid_datamodule import LitGridHeteroDataModule
22from gridfm_graphkit .io .param_handler import NestedNamespace
33from gridfm_graphkit .io .registries import DATASET_WRAPPER_REGISTRY
4- from gridfm_graphkit .training .callbacks import SaveBestModelStateDict
4+ from gridfm_graphkit .training .callbacks import (
5+ SaveBestModelStateDict ,
6+ EpochTimerCallback ,
7+ )
58import importlib
69import numpy as np
710import os
811import time
912import yaml
1013import torch
14+ import torch .distributed as dist
1115import pandas as pd
1216
1317from gridfm_graphkit .io .param_handler import get_task
@@ -186,6 +190,13 @@ def main_cli(args):
186190 trainer_kwargs ["precision" ] = precision
187191 profiler = getattr (args , "profiler" , None )
188192
193+ report_performance = getattr (args , "report_performance" , False )
194+ epoch_timer = EpochTimerCallback () if report_performance else None
195+
196+ training_callbacks = get_training_callbacks (config_args )
197+ if epoch_timer is not None :
198+ training_callbacks = training_callbacks + [epoch_timer ]
199+
189200 trainer = L .Trainer (
190201 logger = logger ,
191202 accelerator = config_args .training .accelerator ,
@@ -194,43 +205,80 @@ def main_cli(args):
194205 log_every_n_steps = 1000 ,
195206 default_root_dir = args .log_dir ,
196207 max_epochs = config_args .training .epochs ,
197- callbacks = get_training_callbacks ( config_args ) ,
208+ callbacks = training_callbacks ,
198209 ** trainer_kwargs ,
199210 profiler = profiler ,
200211 )
201212 if args .command == "train" or args .command == "finetune" :
202213 trainer .fit (model = model , datamodule = litGrid )
214+ if (
215+ report_performance
216+ and epoch_timer is not None
217+ and epoch_timer .last_epoch_time is not None
218+ ):
219+ print (f"[performance] last epoch time : { epoch_timer .last_epoch_time :.3f} s" )
220+ if (
221+ epoch_timer .last_epoch_iters_per_sec is not None
222+ and epoch_timer ._last_batch_count > 0
223+ ):
224+ print (
225+ f"[performance] last epoch it/s : { epoch_timer .last_epoch_iters_per_sec :.2f} " ,
226+ )
203227
204228 if args .command != "predict" :
205- test_trainer = L .Trainer (
206- logger = logger ,
207- accelerator = config_args .training .accelerator ,
208- devices = 1 ,
209- num_nodes = 1 ,
210- log_every_n_steps = 1 ,
211- default_root_dir = args .log_dir ,
212- ** trainer_kwargs ,
213- profiler = profiler ,
214- )
215- test_trainer .test (model = model , datamodule = litGrid )
216-
217- artifacts_dir = os .path .join (
218- logger .save_dir ,
219- logger .experiment_id ,
220- logger .run_id ,
221- "artifacts" ,
229+ # Reuse the fit trainer when coming from train/finetune so that
230+ # torch.compile kernel caches are already warm (avoids a second
231+ # AUTOTUNE pass on the first test batch).
232+ if args .command in ("train" , "finetune" ):
233+ test_trainer = trainer
234+ else :
235+ test_trainer = L .Trainer (
236+ logger = logger ,
237+ accelerator = config_args .training .accelerator ,
238+ devices = 1 ,
239+ num_nodes = 1 ,
240+ log_every_n_steps = 1 ,
241+ default_root_dir = args .log_dir ,
242+ ** trainer_kwargs ,
243+ profiler = profiler ,
244+ )
245+ test_results = test_trainer .test (model = model , datamodule = litGrid )
246+ if report_performance :
247+ # test_results[0] may be empty when metrics are routed to the logger
248+ # only; fall back to trainer.callback_metrics, which may also be empty.
249+ metrics = (
250+ test_results [0 ]
251+ if test_results and test_results [0 ]
252+ else dict (test_trainer .callback_metrics )
253+ )
254+ if metrics :
255+ first_metric , first_value = next (iter (metrics .items ()))
256+ print (f"[performance] { first_metric } : { first_value } " )
257+ else :
258+ print ("[performance] no test metrics available" )
259+
260+ artifacts_dir = None
261+ is_rank0 = (
262+ not (dist .is_available () and dist .is_initialized ()) or dist .get_rank () == 0
222263 )
264+ if is_rank0 :
265+ artifacts_dir = os .path .join (
266+ logger .save_dir ,
267+ logger .experiment_id ,
268+ logger .run_id ,
269+ "artifacts" ,
270+ )
223271
224272 compute_dc_ac = getattr (args , "compute_dc_ac_metrics" , False )
225- if compute_dc_ac :
273+ if is_rank0 and compute_dc_ac :
226274 sn_mva = config_args .data .baseMVA
227275 for grid_name in config_args .data .networks :
228276 raw_dir = os .path .join (args .data_path , grid_name , "raw" )
229277 print (f"\n Computing ground-truth AC/DC metrics for { grid_name } ..." )
230278 compute_ac_dc_metrics (artifacts_dir , raw_dir , grid_name , sn_mva )
231279
232280 save_output = getattr (args , "save_output" , False ) or args .command == "predict"
233- if save_output :
281+ if is_rank0 and save_output :
234282 if len (config_args .data .networks ) > 1 :
235283 raise NotImplementedError (
236284 "Predict/save_output with multiple grids is not yet supported." ,
0 commit comments