Skip to content

Commit fce4d27

Browse files
authored
Merge pull request #56 from gridfm/add_opf_predict_outputs_add_bs_override
Add opf predict outputs add bs override
2 parents d351909 + 97b0776 commit fce4d27

4 files changed

Lines changed: 74 additions & 37 deletions

File tree

gridfm_graphkit/__main__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@ def main():
214214
evaluate_parser.add_argument("--run_name", type=str, default="run")
215215
evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
216216
evaluate_parser.add_argument("--data_path", type=str, default="data")
217+
evaluate_parser.add_argument(
218+
"--batch_size",
219+
type=int,
220+
default=None,
221+
help="Override training.batch_size from the YAML config for evaluation.",
222+
)
217223
evaluate_parser.add_argument("--compile", **_compile_kwargs)
218224
evaluate_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
219225
evaluate_parser.add_argument("--tf32", **_tf32_kwargs)
@@ -266,6 +272,12 @@ def main():
266272
predict_parser.add_argument("--run_name", type=str, default="run")
267273
predict_parser.add_argument("--log_dir", type=str, default="mlruns")
268274
predict_parser.add_argument("--data_path", type=str, default="data")
275+
predict_parser.add_argument(
276+
"--batch_size",
277+
type=int,
278+
default=None,
279+
help="Override training.batch_size from the YAML config for prediction.",
280+
)
269281
predict_parser.add_argument(
270282
"--dataset_wrapper",
271283
type=str,

gridfm_graphkit/cli.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ def main_cli(args):
177177
if num_workers_override is not None:
178178
config_args.data.workers = num_workers_override
179179

180+
batch_size_override = getattr(args, "batch_size", None)
181+
if batch_size_override is not None:
182+
config_args.training.batch_size = batch_size_override
183+
180184
_load_plugins(getattr(args, "plugins", []))
181185
_validate_dataset_wrapper(dataset_wrapper)
182186

@@ -222,12 +226,16 @@ def main_cli(args):
222226
if epoch_timer is not None:
223227
training_callbacks = training_callbacks + [epoch_timer]
224228

229+
_accelerator = config_args.training.accelerator
225230
_strategy = config_args.training.strategy
226-
if isinstance(_strategy, str) and _strategy in (
231+
# if mps is available and accelerator is auto, explicitely set accelerator to mps to select the right strategy in the next block
232+
if _accelerator == "auto" and torch.backends.mps.is_available():
233+
_accelerator = "mps"
234+
if _accelerator not in ("mps", "cpu") and isinstance(_strategy, str) and _strategy in (
227235
"auto",
228236
"ddp",
229237
"ddp_find_unused_parameters_true",
230-
):
238+
): # when using mps, we don't want to use ddp.
231239
_strategy = DDPStrategy(find_unused_parameters=True)
232240

233241
trainer = L.Trainer(

gridfm_graphkit/tasks/opf_task.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
QG_H,
66
VM_H,
77
VA_H,
8+
MIN_VM_H,
9+
MAX_VM_H,
810
MIN_QG_H,
911
MAX_QG_H,
1012
# Output feature indices
@@ -574,30 +576,37 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
574576
"bus": {
575577
"scenario": scenario_ids.cpu().numpy(),
576578
"bus": local_bus_idx.cpu().numpy(),
577-
"pd_mw": bus_x[:, PD_H].cpu().numpy(),
578-
"qd_mvar": bus_x[:, QD_H].cpu().numpy(),
579-
"vm_pu_target": bus_y[:, VM_H].cpu().numpy(),
580-
"va_target": bus_y[:, VA_H].cpu().numpy(),
581-
"pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(),
582-
"qg_mvar_target": bus_y[:, QG_H].cpu().numpy(),
583-
"is_pq": mask_PQ.cpu().numpy().astype(int),
584-
"is_pv": mask_PV.cpu().numpy().astype(int),
585-
"is_ref": mask_REF.cpu().numpy().astype(int),
586-
"vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(),
587-
"va": output["bus"][:, VA_OUT].detach().cpu().numpy(),
588-
"pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(),
589-
"qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(),
579+
"Pd": bus_x[:, PD_H].cpu().numpy(),
580+
"Qd": bus_x[:, QD_H].cpu().numpy(),
581+
"Vm_min": bus_x[:, MIN_VM_H].cpu().numpy(),
582+
"Vm_max": bus_x[:, MAX_VM_H].cpu().numpy(),
583+
"Qg_min": bus_x[:, MIN_QG_H].cpu().numpy(),
584+
"Qg_max": bus_x[:, MAX_QG_H].cpu().numpy(),
585+
"Vm_target": bus_y[:, VM_H].cpu().numpy(),
586+
"Va_target": bus_y[:, VA_H].cpu().numpy(),
587+
"Pg_target": agg_gen_on_bus.squeeze().cpu().numpy(),
588+
"Qg_target": bus_y[:, QG_H].cpu().numpy(),
589+
"PQ": mask_PQ.cpu().numpy().astype(int),
590+
"PV": mask_PV.cpu().numpy().astype(int),
591+
"REF": mask_REF.cpu().numpy().astype(int),
592+
"Vm_pred": output["bus"][:, VM_OUT].detach().cpu().numpy(),
593+
"Va_pred": output["bus"][:, VA_OUT].detach().cpu().numpy(),
594+
"Pg_pred": output["bus"][:, PG_OUT].detach().cpu().numpy(),
595+
"Qg_pred": output["bus"][:, QG_OUT].detach().cpu().numpy(),
590596
"active res. (MW)": residual_P.detach().cpu().numpy(),
591597
"reactive res. (MVar)": residual_Q.detach().cpu().numpy(),
592598
"PBE": residual_mva.detach().cpu().numpy(),
593599
},
594600
"gen": {
595601
"scenario": gen_scenario_ids.cpu().numpy(),
596-
"gen": local_gen_idx.cpu().numpy(),
597-
"connected_bus": local_bus_idx[gen_to_bus_index].cpu().numpy(),
598-
"pg_mw_target": gen_target.cpu().numpy(),
599-
"pg_mw": gen_pred.detach().cpu().numpy(),
600-
"min_pg_mw": gen_x[:, MIN_PG].cpu().numpy(),
601-
"max_pg_mw": gen_x[:, MAX_PG].cpu().numpy(),
602+
"idx": local_gen_idx.cpu().numpy(),
603+
"bus": local_bus_idx[gen_to_bus_index].cpu().numpy(),
604+
"p_mw_target": gen_target.cpu().numpy(),
605+
"p_mw_pred": gen_pred.detach().cpu().numpy(),
606+
"min_p_mw": gen_x[:, MIN_PG].cpu().numpy(),
607+
"max_p_mw": gen_x[:, MAX_PG].cpu().numpy(),
608+
"cp0_eur": gen_x[:, C0_H].cpu().numpy(),
609+
"cp1_eur_per_mw": gen_x[:, C1_H].cpu().numpy(),
610+
"cp2_eur_per_mw2": gen_x[:, C2_H].cpu().numpy(),
602611
},
603612
}

gridfm_graphkit/tasks/pf_task.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
QG_H,
66
VM_H,
77
VA_H,
8+
MIN_VM_H,
9+
MAX_VM_H,
10+
MIN_QG_H,
11+
MAX_QG_H,
812
# Output feature indices
913
VM_OUT,
1014
VA_OUT,
@@ -404,20 +408,24 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
404408
return {
405409
"scenario": scenario_ids.cpu().numpy(),
406410
"bus": local_bus_idx.cpu().numpy(),
407-
"pd_mw": bus_x[:, PD_H].cpu().numpy(), # from original input
408-
"qd_mvar": bus_x[:, QD_H].cpu().numpy(), # from original input
409-
"vm_pu_target": bus_y[:, VM_H].cpu().numpy(), # from original input
410-
"va_target": bus_y[:, VA_H].cpu().numpy(), # from original input
411-
"pg_mw_target": agg_gen_on_bus.squeeze().cpu().numpy(), # from original input
412-
"qg_mvar_target": bus_y[:, QG_H].cpu().numpy(), # from original input
413-
"is_pq": mask_PQ.cpu().numpy().astype(int), # from original input
414-
"is_pv": mask_PV.cpu().numpy().astype(int), # from original input
415-
"is_ref": mask_REF.cpu().numpy().astype(int), # from original input
416-
"vm_pu": output["bus"][:, VM_OUT].detach().cpu().numpy(), # predicted output
417-
"va": output["bus"][:, VA_OUT].detach().cpu().numpy(), # predicted output
418-
"pg_mw": output["bus"][:, PG_OUT].detach().cpu().numpy(), # predicted output
419-
"qg_mvar": output["bus"][:, QG_OUT].detach().cpu().numpy(), # predicted output
420-
"active res. (MW)": residual_P.detach().cpu().numpy(), # predicted output
421-
"reactive res. (MVar)": residual_Q.detach().cpu().numpy(), # predicted output
422-
"PBE": residual_mva.detach().cpu().numpy(), # predicted output
411+
"Pd": bus_x[:, PD_H].cpu().numpy(),
412+
"Qd": bus_x[:, QD_H].cpu().numpy(),
413+
"Vm_min": bus_x[:, MIN_VM_H].cpu().numpy(),
414+
"Vm_max": bus_x[:, MAX_VM_H].cpu().numpy(),
415+
"Qg_min": bus_x[:, MIN_QG_H].cpu().numpy(),
416+
"Qg_max": bus_x[:, MAX_QG_H].cpu().numpy(),
417+
"Vm_target": bus_y[:, VM_H].cpu().numpy(),
418+
"Va_target": bus_y[:, VA_H].cpu().numpy(),
419+
"Pg_target": agg_gen_on_bus.squeeze().cpu().numpy(),
420+
"Qg_target": bus_y[:, QG_H].cpu().numpy(),
421+
"PQ": mask_PQ.cpu().numpy().astype(int),
422+
"PV": mask_PV.cpu().numpy().astype(int),
423+
"REF": mask_REF.cpu().numpy().astype(int),
424+
"Vm_pred": output["bus"][:, VM_OUT].detach().cpu().numpy(),
425+
"Va_pred": output["bus"][:, VA_OUT].detach().cpu().numpy(),
426+
"Pg_pred": output["bus"][:, PG_OUT].detach().cpu().numpy(),
427+
"Qg_pred": output["bus"][:, QG_OUT].detach().cpu().numpy(),
428+
"active res. (MW)": residual_P.detach().cpu().numpy(),
429+
"reactive res. (MVar)": residual_Q.detach().cpu().numpy(),
430+
"PBE": residual_mva.detach().cpu().numpy(),
423431
}

0 commit comments

Comments
 (0)