Skip to content

Commit 5745a11

Browse files
Merge pull request #48 from gridfm/add_perf_fix
fix config restoration
2 parents ca05c73 + 3d904c6 commit 5745a11

5 files changed

Lines changed: 132 additions & 31 deletions

File tree

gridfm_graphkit/__main__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ def main():
7676
choices=["simple", "advanced", "pytorch"],
7777
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
7878
)
79-
80-
# ---- FINETUNE SUBCOMMAND ----
79+
train_parser.add_argument(
80+
"--report-performance",
81+
dest="report_performance",
82+
action="store_true",
83+
help="Print the last training epoch time and a single test metric to stdout.",
84+
)
8185
finetune_parser = subparsers.add_parser("finetune", help="Run fine-tuning")
8286
finetune_parser.add_argument("--config", type=str, required=True)
8387
finetune_parser.add_argument("--model_path", type=str, required=True)
@@ -119,6 +123,12 @@ def main():
119123
choices=["simple", "advanced", "pytorch"],
120124
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
121125
)
126+
finetune_parser.add_argument(
127+
"--report-performance",
128+
dest="report_performance",
129+
action="store_true",
130+
help="Print the last training epoch time and a single test metric to stdout.",
131+
)
122132

123133
# ---- EVALUATE SUBCOMMAND ----
124134
evaluate_parser = subparsers.add_parser(

gridfm_graphkit/cli.py

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from gridfm_graphkit.datasets.hetero_powergrid_datamodule import LitGridHeteroDataModule
22
from gridfm_graphkit.io.param_handler import NestedNamespace
33
from gridfm_graphkit.io.registries import DATASET_WRAPPER_REGISTRY
4-
from gridfm_graphkit.training.callbacks import SaveBestModelStateDict
4+
from gridfm_graphkit.training.callbacks import (
5+
SaveBestModelStateDict,
6+
EpochTimerCallback,
7+
)
58
import importlib
69
import numpy as np
710
import os
811
import time
912
import yaml
1013
import torch
14+
import torch.distributed as dist
1115
import pandas as pd
1216

1317
from gridfm_graphkit.io.param_handler import get_task
@@ -186,6 +190,13 @@ def main_cli(args):
186190
trainer_kwargs["precision"] = precision
187191
profiler = getattr(args, "profiler", None)
188192

193+
report_performance = getattr(args, "report_performance", False)
194+
epoch_timer = EpochTimerCallback() if report_performance else None
195+
196+
training_callbacks = get_training_callbacks(config_args)
197+
if epoch_timer is not None:
198+
training_callbacks = training_callbacks + [epoch_timer]
199+
189200
trainer = L.Trainer(
190201
logger=logger,
191202
accelerator=config_args.training.accelerator,
@@ -194,43 +205,80 @@ def main_cli(args):
194205
log_every_n_steps=1000,
195206
default_root_dir=args.log_dir,
196207
max_epochs=config_args.training.epochs,
197-
callbacks=get_training_callbacks(config_args),
208+
callbacks=training_callbacks,
198209
**trainer_kwargs,
199210
profiler=profiler,
200211
)
201212
if args.command == "train" or args.command == "finetune":
202213
trainer.fit(model=model, datamodule=litGrid)
214+
if (
215+
report_performance
216+
and epoch_timer is not None
217+
and epoch_timer.last_epoch_time is not None
218+
):
219+
print(f"[performance] last epoch time : {epoch_timer.last_epoch_time:.3f}s")
220+
if (
221+
epoch_timer.last_epoch_iters_per_sec is not None
222+
and epoch_timer._last_batch_count > 0
223+
):
224+
print(
225+
f"[performance] last epoch it/s : {epoch_timer.last_epoch_iters_per_sec:.2f}",
226+
)
203227

204228
if args.command != "predict":
205-
test_trainer = L.Trainer(
206-
logger=logger,
207-
accelerator=config_args.training.accelerator,
208-
devices=1,
209-
num_nodes=1,
210-
log_every_n_steps=1,
211-
default_root_dir=args.log_dir,
212-
**trainer_kwargs,
213-
profiler=profiler,
214-
)
215-
test_trainer.test(model=model, datamodule=litGrid)
216-
217-
artifacts_dir = os.path.join(
218-
logger.save_dir,
219-
logger.experiment_id,
220-
logger.run_id,
221-
"artifacts",
229+
# Reuse the fit trainer when coming from train/finetune so that
230+
# torch.compile kernel caches are already warm (avoids a second
231+
# AUTOTUNE pass on the first test batch).
232+
if args.command in ("train", "finetune"):
233+
test_trainer = trainer
234+
else:
235+
test_trainer = L.Trainer(
236+
logger=logger,
237+
accelerator=config_args.training.accelerator,
238+
devices=1,
239+
num_nodes=1,
240+
log_every_n_steps=1,
241+
default_root_dir=args.log_dir,
242+
**trainer_kwargs,
243+
profiler=profiler,
244+
)
245+
test_results = test_trainer.test(model=model, datamodule=litGrid)
246+
if report_performance:
247+
# test_results[0] may be empty when metrics are routed to the logger
248+
# only; fall back to trainer.callback_metrics (which may be empty as well).
249+
metrics = (
250+
test_results[0]
251+
if test_results and test_results[0]
252+
else dict(test_trainer.callback_metrics)
253+
)
254+
if metrics:
255+
first_metric, first_value = next(iter(metrics.items()))
256+
print(f"[performance] {first_metric} : {first_value}")
257+
else:
258+
print("[performance] no test metrics available")
259+
260+
artifacts_dir = None
261+
is_rank0 = (
262+
not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
222263
)
264+
if is_rank0:
265+
artifacts_dir = os.path.join(
266+
logger.save_dir,
267+
logger.experiment_id,
268+
logger.run_id,
269+
"artifacts",
270+
)
223271

224272
compute_dc_ac = getattr(args, "compute_dc_ac_metrics", False)
225-
if compute_dc_ac:
273+
if is_rank0 and compute_dc_ac:
226274
sn_mva = config_args.data.baseMVA
227275
for grid_name in config_args.data.networks:
228276
raw_dir = os.path.join(args.data_path, grid_name, "raw")
229277
print(f"\nComputing ground-truth AC/DC metrics for {grid_name}...")
230278
compute_ac_dc_metrics(artifacts_dir, raw_dir, grid_name, sn_mva)
231279

232280
save_output = getattr(args, "save_output", False) or args.command == "predict"
233-
if save_output:
281+
if is_rank0 and save_output:
234282
if len(config_args.data.networks) > 1:
235283
raise NotImplementedError(
236284
"Predict/save_output with multiple grids is not yet supported.",

gridfm_graphkit/tasks/utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,27 @@ def plot_residuals_histograms(outputs, dataset_name, plot_dir):
3030

3131
for stat_key, title in stats:
3232
# Gather all data first to compute common bin edges
33-
all_data = torch.cat(
34-
[
35-
torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs])
36-
for bus_type in bus_types
37-
],
38-
).numpy()
33+
all_data = (
34+
torch.cat(
35+
[
36+
torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs])
37+
for bus_type in bus_types
38+
],
39+
)
40+
.float()
41+
.numpy()
42+
)
3943

4044
# Define bins across the entire data range
4145
bins = np.linspace(all_data.min(), all_data.max(), 61) # 61 edges -> 60 bins of equal width
4246

4347
plt.figure(figsize=(10, 6))
4448
for bus_type, color in zip(bus_types, colors):
45-
data = torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs]).numpy()
49+
data = (
50+
torch.cat([d[f"{stat_key}_{bus_type}"] for d in outputs])
51+
.float()
52+
.numpy()
53+
)
4654
plt.hist(data, bins=bins, alpha=0.6, label=bus_type, color=color)
4755

4856
plt.title(f"{title} per Bus Type in {dataset_name}")

gridfm_graphkit/training/callbacks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,44 @@
22
from pytorch_lightning.utilities.rank_zero import rank_zero_only
33
from lightning.pytorch.loggers import MLFlowLogger
44
import os
5+
import time
56
import torch
67

78

9+
class EpochTimerCallback(Callback):
10+
"""Records the wall-clock duration of every training epoch and the iteration rate of the most recent one."""
11+
12+
def __init__(self):
13+
self.epoch_times: list[float] = []
14+
self._epoch_start: float | None = None
15+
self._batch_count: int = 0
16+
self._last_batch_count: int = 0
17+
18+
def on_train_epoch_start(self, trainer, pl_module):
19+
self._epoch_start = time.perf_counter()
20+
self._batch_count = 0
21+
22+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
23+
self._batch_count += 1
24+
25+
def on_train_epoch_end(self, trainer, pl_module):
26+
if self._epoch_start is not None:
27+
self.epoch_times.append(time.perf_counter() - self._epoch_start)
28+
self._last_batch_count = self._batch_count
29+
self._epoch_start = None
30+
31+
@property
32+
def last_epoch_time(self) -> float | None:
33+
return self.epoch_times[-1] if self.epoch_times else None
34+
35+
@property
36+
def last_epoch_iters_per_sec(self) -> float | None:
37+
t = self.last_epoch_time
38+
if t is None or t == 0 or self._last_batch_count == 0:
39+
return None
40+
return self._last_batch_count / t
41+
42+
843
class SaveBestModelStateDict(Callback):
944
def __init__(
1045
self,

integrationtests/test_base_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def cleanup_test_artifacts():
7070
"""
7171
Backup modified files and remove generated artifacts after the test.
7272
"""
73-
training_config = " "
73+
training_config = "examples/config/HGNS_PF_datakit_case14.yaml"
7474
backup_config = training_config + ".bak"
7575

7676
if os.path.exists(training_config):

0 commit comments

Comments
 (0)