@@ -249,7 +249,6 @@ def _evaluate_metrics(
249249 dataloader ,
250250 metrics : dict [str , nn .Module ],
251251 device : torch .device ,
252- n_spatial_dims : int ,
253252) -> list [dict [str , object ]]:
254253 rows : list [dict [str , object ]] = []
255254 totals = dict .fromkeys (metrics , 0.0 )
@@ -270,7 +269,7 @@ def _evaluate_metrics(
270269 "num_samples" : batch_size ,
271270 }
272271 for name , metric in metrics .items ():
273- value = metric (preds , trues , n_spatial_dims )
272+ value = metric (preds , trues )
274273 scalar = float (value .mean ().item ())
275274 row [name ] = scalar
276275 totals [name ] += scalar * batch_size
@@ -496,8 +495,8 @@ def main() -> None:
496495 channel_count ,
497496 inferred_n_steps_input ,
498497 inferred_n_steps_output ,
499- _input_shape ,
500- output_shape ,
498+ _ ,
499+ _ ,
501500 ) = prepare_datamodule (cfg )
502501
503502 configure_module_dimensions (
@@ -508,15 +507,14 @@ def main() -> None:
508507 )
509508 normalize_processor_cfg (cfg )
510509
511- n_spatial_dims = _infer_spatial_dims (args , output_shape )
512510 metrics = _build_metrics (args .metrics or ("mse" , "rmse" ))
513511
514512 model = _load_model (cfg , args .checkpoint )
515513 device = _resolve_device (args .device )
516514 model .to (device )
517515
518516 test_loader = datamodule .test_dataloader ()
519- rows = _evaluate_metrics (model , test_loader , metrics , device , n_spatial_dims )
517+ rows = _evaluate_metrics (model , test_loader , metrics , device )
520518 _write_csv (rows , csv_path , list (metrics .keys ()))
521519
522520 aggregate_row = next ((row for row in rows if row .get ("batch_index" ) == "all" ), None )
0 commit comments