77from omegaconf import DictConfig
88from torch .optim import Optimizer
99
10- from ice_station_zebra .types import CombinedTensorBatch , DataSpace
10+ from ice_station_zebra .types import DataSpace , TensorNTCHW
1111
1212
1313class ZebraModel (LightningModule , ABC ):
@@ -46,7 +46,7 @@ def __init__(
4646 self .save_hyperparameters ()
4747
4848 @abstractmethod
49- def forward (self , inputs : CombinedTensorBatch ) -> torch . Tensor :
49+ def forward (self , inputs : dict [ str , TensorNTCHW ] ) -> TensorNTCHW :
5050 """Forward step of the model
5151
5252 - start with multiple [NTCHW] inputs each with shape [batch, n_history_steps, C_input_k, H_input_k, W_input_k]
@@ -62,11 +62,11 @@ def configure_optimizers(self) -> Optimizer:
6262 dict (** self .optimizer_cfg ) | {"params" : self .model_list .parameters ()}
6363 )
6464
65- def loss (self , output : torch . Tensor , target : torch . Tensor ) -> torch .Tensor :
65+ def loss (self , output : TensorNTCHW , target : TensorNTCHW ) -> torch .Tensor :
6666 return torch .nn .functional .l1_loss (output , target )
6767
6868 def test_step (
69- self , batch : CombinedTensorBatch , batch_idx : int
69+ self , batch : dict [ str , TensorNTCHW ] , batch_idx : int
7070 ) -> dict [str , torch .Tensor ]:
7171 """Run the test step, in PyTorch eval model (i.e. no gradients)
7272
@@ -82,7 +82,9 @@ def test_step(
8282 loss = self .loss (output , target )
8383 return {"output" : output , "target" : target , "loss" : loss }
8484
85- def training_step (self , batch : CombinedTensorBatch , batch_idx : int ) -> torch .Tensor :
85+ def training_step (
86+ self , batch : dict [str , TensorNTCHW ], batch_idx : int
87+ ) -> torch .Tensor :
8688 """Run the training step
8789
8890 A batch contains one tensor for each input dataset and one for the target
@@ -97,7 +99,7 @@ def training_step(self, batch: CombinedTensorBatch, batch_idx: int) -> torch.Ten
9799 return self .loss (output , target )
98100
99101 def validation_step (
100- self , batch : CombinedTensorBatch , batch_idx : int
102+ self , batch : dict [ str , TensorNTCHW ] , batch_idx : int
101103 ) -> torch .Tensor :
102104 """Run the validation step
103105
0 commit comments