import itertools
from abc import ABC, abstractmethod

import hydra
import torch
from lightning import LightningModule
from omegaconf import DictConfig
from torch.optim import Optimizer

from ice_station_zebra.types import DataSpace, LightningBatch


class ZebraModel(LightningModule, ABC):
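    """Abstract base class for ice_station_zebra models.

    Subclasses must implement `forward`; the training, validation and test
    steps, together with optimizer construction, are provided here.
    """
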
    def __init__(
        self,
        *,
        name: str,
        input_spaces: list[DictConfig],
        output_space: DictConfig,
        optimizer: DictConfig,
    ) -> None:
        super().__init__()

        # Save model name
        self.name = name

        # Construct the input and output spaces
        self.input_spaces = [DataSpace.from_dict(space) for space in input_spaces]
        self.output_space = DataSpace.from_dict(output_space)

        # Store the optimizer config
        self.optimizer_cfg = optimizer

        # Save all of the arguments to __init__ as hyperparameters
        # This will also save the parameters of whichever child class is used
        # Note that W&B will log all hyperparameters
        self.save_hyperparameters()

    @abstractmethod
    def forward(self, inputs: LightningBatch) -> torch.Tensor:
        """Forward step of the model: map one tensor per input dataset to a single output tensor"""

    def configure_optimizers(self) -> Optimizer:
        """Construct the optimizer from the config"""
        # Instantiate the optimizer class named in the Hydra config, passing
        # the parameters of all child modules as the `params` argument
        return hydra.utils.instantiate(
            dict(**self.optimizer_cfg)
            | {
                "params": itertools.chain(
                    *[module.parameters() for module in self.children()]
                )
            }
        )
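
    # For reference, `optimizer` must be a Hydra-instantiable config with a
    # `_target_` key; the values below are an assumed example, e.g.
    #
    #   optimizer:
    #     _target_: torch.optim.Adam
    #     lr: 1.0e-3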

    def loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Calculate the L1 loss between output and target; subclasses may override this"""
        return torch.nn.functional.l1_loss(output, target)

    def test_step(
        self, batch: LightningBatch, batch_idx: int
    ) -> dict[str, torch.Tensor]:
        """Run the test step in PyTorch eval mode (i.e. no gradients)

        A batch contains one tensor for each input dataset followed by one for the target.
        The shape of each of these tensors is (batch_size; variables; ensembles; position).

        - Separate the batch into inputs and target
        - Run inputs through the model
        - Return the output, target and loss
        """
        inputs, target = batch[:-1], batch[-1]
        output = self(inputs)
        loss = self.loss(output, target)
        return {"output": output, "target": target, "loss": loss}

    def training_step(self, batch: LightningBatch, batch_idx: int) -> torch.Tensor:
        """Run the training step

        A batch contains one tensor for each input dataset followed by one for the target.
        The shape of each of these tensors is (batch_size; variables; ensembles; position).

        - Separate the batch into inputs and target
        - Run inputs through the model
        - Calculate the loss wrt. the target
        """
        inputs, target = batch[:-1], batch[-1]
        output = self(inputs)
        return self.loss(output, target)

    def validation_step(self, batch: LightningBatch, batch_idx: int) -> torch.Tensor:
        """Run the validation step

        A batch contains one tensor for each input dataset followed by one for the target.
        The shape of each of these tensors is (batch_size; variables; ensembles; position).

        - Separate the batch into inputs and target
        - Run inputs through the model
        - Calculate the loss wrt. the target and log it
        """
        inputs, target = batch[:-1], batch[-1]
        output = self(inputs)
        loss = self.loss(output, target)
        self.log("validation_loss", loss)
        return loss
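

# A minimal concrete subclass, shown as a sketch of how the ZebraModel API is
# used. `ExamplePersistenceModel` is hypothetical and assumes the last input
# dataset's tensor already has the shape of the output space.
class ExamplePersistenceModel(ZebraModel):
    """Hypothetical model that predicts the last input tensor unchanged"""

    def forward(self, inputs: LightningBatch) -> torch.Tensor:
        # Persistence forecast: return the last input dataset's tensor as-is
        return inputs[-1]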