Skip to content

Commit b076e2d

Browse files
authored
Merge pull request #97 from alan-turing-institute/eval-processor-bug-fix
Small bug fixes
2 parents dc84c87 + 5942fbd commit b076e2d

2 files changed

Lines changed: 6 additions & 7 deletions

File tree

configs/trainer/default.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ devices: 1
55
log_every_n_steps: 10
66
enable_checkpointing: true
77
detect_anomaly: false
8-
default_root_dir: ${hydra:run.dir}
8+
default_root_dir: null
9+

src/autocast/eval/processor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,6 @@ def _evaluate_metrics(
249249
dataloader,
250250
metrics: dict[str, nn.Module],
251251
device: torch.device,
252-
n_spatial_dims: int,
253252
) -> list[dict[str, object]]:
254253
rows: list[dict[str, object]] = []
255254
totals = dict.fromkeys(metrics, 0.0)
@@ -270,7 +269,7 @@ def _evaluate_metrics(
270269
"num_samples": batch_size,
271270
}
272271
for name, metric in metrics.items():
273-
value = metric(preds, trues, n_spatial_dims)
272+
value = metric(preds, trues)
274273
scalar = float(value.mean().item())
275274
row[name] = scalar
276275
totals[name] += scalar * batch_size
@@ -496,8 +495,8 @@ def main() -> None:
496495
channel_count,
497496
inferred_n_steps_input,
498497
inferred_n_steps_output,
499-
_input_shape,
500-
output_shape,
498+
_,
499+
_,
501500
) = prepare_datamodule(cfg)
502501

503502
configure_module_dimensions(
@@ -508,15 +507,14 @@ def main() -> None:
508507
)
509508
normalize_processor_cfg(cfg)
510509

511-
n_spatial_dims = _infer_spatial_dims(args, output_shape)
512510
metrics = _build_metrics(args.metrics or ("mse", "rmse"))
513511

514512
model = _load_model(cfg, args.checkpoint)
515513
device = _resolve_device(args.device)
516514
model.to(device)
517515

518516
test_loader = datamodule.test_dataloader()
519-
rows = _evaluate_metrics(model, test_loader, metrics, device, n_spatial_dims)
517+
rows = _evaluate_metrics(model, test_loader, metrics, device)
520518
_write_csv(rows, csv_path, list(metrics.keys()))
521519

522520
aggregate_row = next((row for row in rows if row.get("batch_index") == "all"), None)

0 commit comments

Comments
 (0)