File tree Expand file tree Collapse file tree
ice_station_zebra/callbacks Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11import logging
2+ from collections .abc import Sequence
23from typing import Any
34
45from lightning import LightningModule , Trainer
@@ -43,18 +44,16 @@ def on_test_batch_end(
4344 # Run plotting every `frequency` batches
4445 if batch_idx % self .frequency == 0 :
4546 # Get date for this batch
46- batch_size = outputs .target .shape [0 ]
47- test_dataloaders : DataLoader | list [DataLoader ] | None = (
48- trainer .test_dataloaders
49- )
50- if test_dataloaders is None :
47+ dataloaders : DataLoader | list [DataLoader ] | None = trainer .test_dataloaders
48+ if dataloaders is None :
5149 logger .debug ("No test dataloaders found, skipping plotting." )
5250 return
5351 dataset : CombinedDataset = (
54- test_dataloaders [dataloader_idx ]
55- if isinstance (test_dataloaders , list )
56- else test_dataloaders
57- ).dataset # type: ignore[assignment]
52+ dataloaders [dataloader_idx ]
53+ if isinstance (dataloaders , Sequence )
54+ else dataloaders
55+ ).dataset
56+ batch_size = outputs .target .shape [0 ]
5857 date_ = dataset .date_from_index (batch_size * batch_idx )
5958
6059 # Load the ground truth and prediction
You can’t perform that action at this time.
0 commit comments