
Commit c8e39e8

Merge pull request #33 from alan-turing-institute/add-base
Update base classes
2 parents 01ef346 + 2d77dfa commit c8e39e8

11 files changed

Lines changed: 1037 additions & 110 deletions


docs/DEVELOPER_GUIDE.md

Lines changed: 32 additions & 0 deletions
# Developer Guide

TODO

## Overview

TODO

## API notes

### Trainer and Model Integration

Example usage with the `lightning` Trainer:

```python
model = EncoderDecoder()  # Anything that inherits from L.LightningModule
trainer = L.Trainer()
trainer.fit(model, train_dataloader)  # train_dataloader should yield batches of data from an iterable.

model = EncoderProcessorDecoder()
trainer = L.Trainer()
trainer.fit(model, train_dataloader)
```

### Model API

Subclasses of `LightningModule` from `lightning` aim to have the API:

```python
def training_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
def forward(self, x: Tensor) -> Tensor: ...
```

Direct subclasses of `nn.Module` from `torch` aim to have the API:

```python
def forward(self, x: Tensor) -> Tensor: ...
```
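As a further illustration, a minimal `nn.Module` conforming to the second API might look like the following sketch (the `TinyEncoder` name and sizes are invented for this example, not part of the commit):

```python
# Hypothetical module matching the documented nn.Module API; illustrative only.
import torch
from torch import Tensor, nn


class TinyEncoder(nn.Module):
    """Toy encoder with the documented forward signature."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 8)

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(x)


x = torch.randn(2, 16)
assert TinyEncoder()(x).shape == (2, 8)
```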

pyproject.toml

Lines changed: 9 additions & 2 deletions
```diff
@@ -6,8 +6,14 @@ readme = "README.md"
 authors = [
     { name = "AI for Physical Systems Team at The Alan Turing Institute", email = "ai4physics@turing.ac.uk" },
 ]
-requires-python = ">=3.10,<3.13"
-dependencies = []
+requires-python = ">=3.11,<3.13"
+dependencies = [
+    "einops>=0.8.1",
+    "h5py>=3.15.1",
+    "lightning>=2.5.6",
+    "the-well>=1.1.0",
+    "torch>=2.9.1",
+]

 [project.optional-dependencies]
 dev = [
@@ -69,6 +75,7 @@ ignore = [
     "ISC001", # Conflicts with formatter
     # "D417", # Missing trailing new line in docstring
     "D100", # Missing docstring in public module
+    "D102", # Missing docstring in public method, TODO: remove in future
     "D104", # Missing docstring in public package
     "PLR0913", # too many arguments
 ]
```

src/auto_cast/decoders/__init__.py

Lines changed: 3 additions & 0 deletions
```python
from .base import Decoder

__all__ = ["Decoder"]
```

src/auto_cast/decoders/base.py

Lines changed: 9 additions & 0 deletions
```python
from typing import Any

from torch import Tensor, nn


class Decoder(nn.Module):
    """Base Decoder."""

    def forward(self, *args: Any, **kwargs: Any) -> Any: ...
```

src/auto_cast/encoders/__init__.py

Lines changed: 3 additions & 0 deletions
```python
from .base import Encoder

__all__ = ["Encoder"]
```

src/auto_cast/encoders/base.py

Lines changed: 12 additions & 0 deletions
```python
from typing import Any

from torch import nn


class Encoder(nn.Module):
    """Base encoder."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the Encoder."""
        msg = "To implement."
        raise NotImplementedError(msg)
```
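For illustration, a concrete subclass pair might fill in these `Encoder`/`Decoder` stubs roughly as follows (a hypothetical sketch; `LinearEncoder`, `LinearDecoder`, and the dimensions are invented here):

```python
# Hypothetical concrete encoder/decoder pair; illustrative only, not in the commit.
import torch
from torch import Tensor, nn


class LinearEncoder(nn.Module):  # stands in for auto_cast.encoders.Encoder
    def __init__(self, in_dim: int = 32, latent_dim: int = 8) -> None:
        super().__init__()
        self.net = nn.Linear(in_dim, latent_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Compress inputs into the latent space."""
        return self.net(x)


class LinearDecoder(nn.Module):  # stands in for auto_cast.decoders.Decoder
    def __init__(self, latent_dim: int = 8, out_dim: int = 32) -> None:
        super().__init__()
        self.net = nn.Linear(latent_dim, out_dim)

    def forward(self, z: Tensor) -> Tensor:
        """Reconstruct outputs from the latent space."""
        return self.net(z)


x = torch.randn(4, 32)
x_hat = LinearDecoder()(LinearEncoder()(x))
assert x_hat.shape == x.shape
```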
src/auto_cast/models/encoder_decoder.py

Lines changed: 59 additions & 0 deletions
```python
from typing import Any

import lightning as L
import torch
from torch import nn

from auto_cast.decoders import Decoder
from auto_cast.encoders import Encoder
from auto_cast.processors.base import Preprocessor
from auto_cast.types import Batch, Tensor


class EncoderDecoder(L.LightningModule):
    """Encoder-Decoder Model."""

    encoder: Encoder
    decoder: Decoder
    preprocessor: Preprocessor
    loss_func: nn.Module

    def __init__(self):
        pass

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.decoder(self.encoder(*args, **kwargs))

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
        x = self.preprocessor(batch)
        output = self(x)
        loss = self.loss_func(output, batch.output_fields)
        return loss  # noqa: RET504

    def validation_step(self, batch: Batch, batch_idx: int) -> Tensor: ...

    def test_step(self, batch: Batch, batch_idx: int) -> Tensor: ...

    def predict_step(self, batch: Batch, batch_idx: int) -> Tensor: ...

    def encode(self, x: Batch) -> Tensor:
        x = self.preprocessor(x)
        return self.encoder(x)

    def configure_optimizers(self):
        pass


class VAE(EncoderDecoder):
    """Variational Autoencoder Model."""

    def forward(self, x: Tensor) -> Tensor:
        mu, log_var = self.encoder(x)
        z = self.reparametrize(mu, log_var)
        x = self.decoder(z)
        return x  # noqa: RET504

    def reparametrize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
```
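To make the reparameterization step concrete, a standalone sketch of what `reparametrize` computes (tensor shapes here are arbitrary):

```python
# Standalone illustration of VAE.reparametrize; shapes are arbitrary.
import torch

mu = torch.zeros(4, 8)       # latent means
log_var = torch.zeros(4, 8)  # log-variances; std = exp(0.5 * log_var) = 1 here

std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)  # noise sampled outside the computation graph
z = mu + eps * std           # z ~ N(mu, std^2), differentiable w.r.t. mu and log_var

assert z.shape == mu.shape
```

Sampling `eps` separately keeps the stochastic node outside the parameters, which is what lets gradients flow through `mu` and `log_var`.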
Lines changed: 70 additions & 0 deletions
```python
from typing import Any

import lightning as L
import torch
from torch import nn

from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.processors.base import Processor
from auto_cast.types import Batch, RolloutOutput, Tensor


class EncoderProcessorDecoder(L.LightningModule):
    """Encoder-Processor-Decoder Model."""

    encoder_decoder: EncoderDecoder
    processor: Processor
    teacher_forcing_ratio: float
    stride: int
    max_rollout_steps: int
    loss_func: nn.Module

    def __init__(self): ...

    def from_encoder_processor_decoder(
        self, encoder_decoder: EncoderDecoder, processor: Processor
    ) -> None:
        self.encoder_decoder = encoder_decoder
        self.processor = processor

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.encoder_decoder.decoder(
            self.processor(self.encoder_decoder.encoder(*args, **kwargs))
        )

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
        output = self(batch)
        loss = self.processor.loss_func(output, batch.output_fields)
        return loss  # noqa: RET504

    def configure_optimizers(self): ...

    def rollout(self, batch: Batch) -> RolloutOutput:
        """Rollout over multiple time steps."""
        pred_outs, gt_outs = [], []
        for _ in range(0, self.max_rollout_steps, self.stride):
            x = self.encoder_decoder.encoder(batch)
            pred_outs.append(self.processor.map(x))
            # TODO: combine teacher forcing logic
            gt_outs.append(batch.output_fields)  # This assumes we have output fields
        return torch.stack(pred_outs), torch.stack(gt_outs)


# TODO: consider whether a separate rollout class would be better
class Rollout:
    max_rollout_steps: int
    stride: int

    def rollout(
        self,
        batch: Batch,
        model: Processor | EncoderProcessorDecoder,
    ) -> RolloutOutput:
        """Rollout over multiple time steps."""
        pred_outs, gt_outs = [], []
        for _ in range(0, self.max_rollout_steps, self.stride):
            output = model(batch)
            pred_outs.append(output)
            # TODO: logic for a moving window with teacher forcing
            gt_outs.append(batch.output_fields)  # This assumes we have output fields
        return torch.stack(pred_outs), torch.stack(gt_outs)
```
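To show the shape of the strided rollout loop in isolation, a toy sketch with an identity stand-in for the model (all names and shapes here are invented):

```python
# Illustrative only: a stand-in "model" and batch to show the strided loop shape.
import torch

max_rollout_steps, stride = 6, 2          # -> 3 iterations: steps 0, 2, 4
batch_output_fields = torch.randn(8, 16)  # stand-in for batch.output_fields

pred_outs, gt_outs = [], []
for _ in range(0, max_rollout_steps, stride):
    output = batch_output_fields * 1.0    # identity "model" call
    pred_outs.append(output)
    gt_outs.append(batch_output_fields)

preds, gts = torch.stack(pred_outs), torch.stack(gt_outs)
assert preds.shape == (3, 8, 16)  # one leading dim per rollout step
```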

src/auto_cast/processors/base.py

Lines changed: 67 additions & 0 deletions
```python
from abc import ABC, abstractmethod
from typing import Any

import lightning as L
import torch
from torch import nn

from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor


class Processor(L.LightningModule):
    """Processor Base Class."""

    teacher_forcing_ratio: float
    stride: int
    max_rollout_steps: int
    loss_func: nn.Module

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the Processor."""
        msg = "To implement."
        raise NotImplementedError(msg)

    def training_step(self, batch: EncodedBatch, batch_idx: int) -> Tensor:  # noqa: ARG002
        output = self.map(batch.encoded_inputs)
        loss = self.loss_func(output, batch.encoded_output_fields)
        return loss  # noqa: RET504

    @abstractmethod
    def map(self, x: Tensor) -> Tensor:
        """Map input window of states/times to output window."""

    def configure_optimizers(self): ...

    def rollout(self, batch: EncodedBatch) -> RolloutOutput:
        """Rollout over multiple time steps."""
        pred_outs, gt_outs = [], []
        for _ in range(0, self.max_rollout_steps, self.stride):
            pred_outs.append(self.map(batch.encoded_inputs))
            # TODO: combine teacher forcing logic
            gt_outs.append(
                batch.encoded_output_fields
            )  # This assumes we have output fields
        return torch.stack(pred_outs), torch.stack(gt_outs)


class DiscreteProcessor(Processor, ABC):
    """DiscreteProcessor."""

    @abstractmethod
    def map(self, x: Tensor) -> Tensor:
        ...
        # Map input window of states/times to output window

    def rollout(self, batch: EncodedBatch) -> RolloutOutput:
        ...
        # Use self.map to generate the trajectory


class FlowBasedGenerativeProcessor(DiscreteProcessor):
    """Flow-based generative processor."""

    def map(self, x: Tensor) -> Tensor:
        ...
        # Sample the generative model

    # def loss(self, ...): ...
    # Flow matching
```
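As background for the `Flow matching` comment, a minimal sketch of a conditional flow-matching loss that a future `loss` method might implement; the `v_theta` velocity network and this exact objective are assumptions, not part of the commit:

```python
# Hypothetical sketch of a conditional flow matching loss; not the commit's API.
import torch
from torch import Tensor, nn


def flow_matching_loss(v_theta: nn.Module, x0: Tensor, x1: Tensor) -> Tensor:
    """MSE between predicted velocity and the straight-line target x1 - x0."""
    t = torch.rand(x0.shape[0], *([1] * (x0.dim() - 1)))  # one t per sample
    x_t = (1 - t) * x0 + t * x1                           # linear interpolation path
    target = x1 - x0                                      # constant path velocity
    return nn.functional.mse_loss(v_theta(x_t, t), target)
```

Here `v_theta(x_t, t)` predicts the velocity of the probability path at time `t`; with linear interpolation paths the regression target is simply `x1 - x0`.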

src/auto_cast/types/__init__.py

Lines changed: 38 additions & 0 deletions
```python
from dataclasses import dataclass

import torch
from torch.utils.data import DataLoader

Tensor = torch.Tensor
Input = Tensor | DataLoader
RolloutOutput = tuple[Tensor, None] | tuple[Tensor, Tensor]

# Batch = dict[str, Tensor]
# EncodedBatch = dict[str, Tensor]


# TODO: Could be a dataclass if we want more structure
@dataclass
class Batch:
    input_fields: Tensor
    output_fields: Tensor
    constant_scalars: Tensor
    constant_fields: Tensor


@dataclass
class EncodedBatch:
    encoded_inputs: Tensor
    encoded_output_fields: Tensor
    encoded_info: dict[str, Tensor]


class EncoderForBatch:
    """EncoderForBatch."""

    def __call__(self, batch: Batch) -> EncodedBatch:
        return EncodedBatch(
            encoded_inputs=batch.input_fields,
            encoded_output_fields=batch.output_fields,
            encoded_info={},
        )
```
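A short example of how these types compose, using the pass-through `EncoderForBatch` defined above (the tensor shapes are arbitrary):

```python
# Example of round-tripping a Batch through EncoderForBatch; shapes are arbitrary.
import torch

from auto_cast.types import Batch, EncoderForBatch

batch = Batch(
    input_fields=torch.randn(4, 3, 32, 32),
    output_fields=torch.randn(4, 3, 32, 32),
    constant_scalars=torch.randn(4, 2),
    constant_fields=torch.randn(4, 1, 32, 32),
)

encoded = EncoderForBatch()(batch)
assert encoded.encoded_inputs is batch.input_fields  # identity pass-through
assert encoded.encoded_info == {}
```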
