Skip to content

Commit ad330f7

Browse files
committed
♻️ Minor reordering of date extraction code
1 parent 47a136b commit ad330f7

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

ice_station_zebra/callbacks/plotting_callback.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections.abc import Sequence
23
from typing import Any
34

45
from lightning import LightningModule, Trainer
@@ -43,18 +44,16 @@ def on_test_batch_end(
4344
# Run plotting every `frequency` batches
4445
if batch_idx % self.frequency == 0:
4546
# Get date for this batch
46-
batch_size = outputs.target.shape[0]
47-
test_dataloaders: DataLoader | list[DataLoader] | None = (
48-
trainer.test_dataloaders
49-
)
50-
if test_dataloaders is None:
47+
dataloaders: DataLoader | list[DataLoader] | None = trainer.test_dataloaders
48+
if dataloaders is None:
5149
logger.debug("No test dataloaders found, skipping plotting.")
5250
return
5351
dataset: CombinedDataset = (
54-
test_dataloaders[dataloader_idx]
55-
if isinstance(test_dataloaders, list)
56-
else test_dataloaders
57-
).dataset # type: ignore[assignment]
52+
dataloaders[dataloader_idx]
53+
if isinstance(dataloaders, Sequence)
54+
else dataloaders
55+
).dataset
56+
batch_size = outputs.target.shape[0]
5857
date_ = dataset.date_from_index(batch_size * batch_idx)
5958

6059
# Load the ground truth and prediction

0 commit comments

Comments
 (0)