zebra_model.py
import itertools
from abc import ABC, abstractmethod
import hydra
import torch
from lightning import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from omegaconf import DictConfig
from ice_station_zebra.types import DataSpace, ModelTestOutput, TensorNTCHW


class ZebraModel(LightningModule, ABC):
    def __init__(
        self,
        *,
        name: str,
        input_spaces: list[DictConfig],
        n_forecast_steps: int,
        n_history_steps: int,
        output_space: DictConfig,
        optimizer: DictConfig,
    ) -> None:
        super().__init__()

        # Save model name
        self.name = name

        # Save history and forecast steps
        self.n_forecast_steps = n_forecast_steps
        self.n_history_steps = n_history_steps

        # Construct the input and output spaces
        self.input_spaces = [DataSpace.from_dict(space) for space in input_spaces]
        self.output_space = DataSpace.from_dict(output_space)

        # Store the optimizer config
        self.optimizer_cfg = optimizer

        # Save all of the arguments to __init__ as hyperparameters
        # This will also save the parameters of whichever child class is used
        # Note that W&B will log all hyperparameters
        self.save_hyperparameters()

    @abstractmethod
    def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW:
        """Forward step of the model

        - start with multiple [NTCHW] inputs, each with shape [batch, n_history_steps, C_input_k, H_input_k, W_input_k]
        - return a single [NTCHW] output with shape [batch, n_forecast_steps, C_output, H_output, W_output]
        """

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Construct the optimizer from the config"""
        return hydra.utils.instantiate(
            dict(**self.optimizer_cfg)
            | {
                "params": itertools.chain(
                    *[module.parameters() for module in self.children()]
                )
            }
        )
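
    # A minimal sketch of an optimizer config that would work here, assuming
    # the usual Hydra convention (the exact target and values are
    # illustrative, not taken from this repository's configs):
    #
    #   optimizer:
    #     _target_: torch.optim.Adam
    #     lr: 1.0e-4
    #
    # hydra.utils.instantiate resolves _target_ to the class and passes the
    # remaining keys, plus the injected "params" entry, as keyword arguments.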

    def loss(self, output: TensorNTCHW, target: TensorNTCHW) -> torch.Tensor:
        """Calculate the L1 loss between model output and target"""
        return torch.nn.functional.l1_loss(output, target)

    def test_step(
        self, batch: dict[str, TensorNTCHW], batch_idx: int
    ) -> ModelTestOutput:
        """Run the test step in PyTorch eval mode (i.e. no gradients)

        A batch contains one tensor for each input dataset and one for the target.
        These are [NTCHW] tensors of shape (batch_size, n_history_steps, C, H, W).

        - Separate the batch into inputs and target
        - Run the inputs through the model
        - Return the output, target and loss
        """
        target = batch.pop("target")
        prediction = self(batch)
        loss = self.loss(prediction, target)
        return ModelTestOutput(prediction, target, loss)
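
    # Illustrative batch layout (hypothetical dataset names): the "target"
    # key is popped off, and every remaining key is treated as a model input,
    # e.g. {"era5": Tensor[...], "icenet": Tensor[...], "target": Tensor[...]}.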

    def training_step(
        self, batch: dict[str, TensorNTCHW], batch_idx: int
    ) -> torch.Tensor:
        """Run the training step

        A batch contains one tensor for each input dataset and one for the target.
        These are [NTCHW] tensors of shape (batch_size, n_history_steps, C, H, W).

        - Separate the batch into inputs and target
        - Run the inputs through the model
        - Calculate the loss wrt. the target
        """
        target = batch.pop("target")
        output = self(batch)
        return self.loss(output, target)

    def validation_step(
        self, batch: dict[str, TensorNTCHW], batch_idx: int
    ) -> torch.Tensor:
        """Run the validation step

        A batch contains one tensor for each input dataset and one for the target.
        These are [NTCHW] tensors of shape (batch_size, n_history_steps, C, H, W).

        - Separate the batch into inputs and target
        - Run the inputs through the model
        - Calculate the loss wrt. the target
        """
        target = batch.pop("target")
        output = self(batch)
        loss = self.loss(output, target)
        self.log("validation_loss", loss)
        return loss
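

# A minimal sketch of a concrete subclass, for illustration only: the class
# name and the "persistence_source" dataset key are hypothetical (not part of
# this repository), and the example assumes that input's (C, H, W) dimensions
# already match the output space.
class PersistenceModel(ZebraModel):
    """Baseline that repeats the most recent history step n_forecast_steps times"""

    def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW:
        last_step = inputs["persistence_source"][:, -1:]  # [N, 1, C, H, W]
        return last_step.repeat(1, self.n_forecast_steps, 1, 1, 1)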