
Commit 3202252

Add flexibility to lower level API
1 parent 4103038 commit 3202252

6 files changed

Lines changed: 49 additions & 58 deletions


src/auto_cast/decoders/base.py

Lines changed: 3 additions & 4 deletions
@@ -1,10 +1,9 @@
-from torch import nn
+from typing import Any

-from auto_cast.types import Tensor
+from torch import nn


 class Decoder(nn.Module):
     """Base Decoder."""

-    # Q: Should decoder handle all these input types
-    def forward(self, x: Tensor) -> Tensor: ...
+    def forward(self, *args: Any, **kwargs: Any) -> Any: ...
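With the base signature widened to *args/**kwargs, a concrete decoder can declare whatever inputs it actually needs. A minimal sketch of what this permits; LatentDecoder, its dimensions, and the optional scale argument are illustrative, not part of the commit:

import torch
from torch import nn

from auto_cast.decoders import Decoder


class LatentDecoder(Decoder):
    """Hypothetical concrete decoder with its own, narrower signature."""

    def __init__(self, latent_dim: int, out_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(latent_dim, out_dim)

    def forward(self, z: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # The old base signature forward(x: Tensor) -> Tensor could not
        # type an extra keyword argument; *args/**kwargs can.
        return self.proj(z) * scale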

src/auto_cast/encoders/base.py

Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
-from torch import nn
+from typing import Any

-from auto_cast.types import Tensor
+from torch import nn


 class Encoder(nn.Module):
     """Base encoder."""

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Forward Pass through the Encoder."""
         msg = "To implement."
         raise NotImplementedError(msg)
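The same widening applies on the encoder side. A sketch of a two-input encoder that the previous single-tensor signature could not express; FieldEncoder and its shapes are assumptions:

import torch
from torch import nn

from auto_cast.encoders import Encoder


class FieldEncoder(Encoder):
    """Hypothetical encoder taking a field tensor plus a timestamp."""

    def __init__(self, in_dim: int, latent_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim + 1, latent_dim)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_dim), t: (batch, 1); the second argument needs
        # the widened base signature the commit introduces.
        return self.proj(torch.cat([x, t], dim=-1))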

src/auto_cast/models/encoder_decoder.py

Lines changed: 13 additions & 23 deletions
@@ -1,10 +1,12 @@
+from typing import Any
+
 import lightning as L
 import torch
 from torch import nn

 from auto_cast.decoders import Decoder
 from auto_cast.encoders import Encoder
-from auto_cast.processors.base import Preprocessor, Processor
+from auto_cast.processors.base import Preprocessor
 from auto_cast.types import Batch, Tensor


@@ -19,16 +21,22 @@ class EncoderDecoder(L.LightningModule):
     def __init__(self):
         pass

-    def forward(self, x: Tensor) -> Tensor:
-        return self.decoder(self.encoder(x))
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        return self.decoder(self.encoder(*args, **kwargs))

     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
         x = self.preprocessor(batch)
         output = self(x)
         loss = self.loss_func(output, batch["output_fields"])
         return loss  # noqa: RET504

-    def encode_only(self, x: Batch) -> Tensor:
+    def validation_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
+
+    def test_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
+
+    def predict_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
+
+    def encode(self, x: Batch) -> Tensor:
         x = self.preprocessor(x)
         return self.encoder(x)

@@ -45,25 +53,7 @@ def forward(self, x: Tensor) -> Tensor:
         x = self.decoder(z)
         return x  # noqa: RET504

-    def reparametrize(self, mu, log_var):
+    def reparametrize(self, mu: Tensor, log_var: Tensor) -> Tensor:
         std = torch.exp(0.5 * log_var)
         eps = torch.randn_like(std)
         return mu + eps * std
-
-
-class EncoderProcessorDecoder(L.LightningModule):
-    """Encoder-Processor-Decoder Model."""
-
-    encoder_decoder: EncoderDecoder
-    processor: Processor
-
-    def __init__(self): ...
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.encoder_decoder.decoder(
-            self.processor(self.encoder_decoder.encoder(x))
-        )
-
-    def training_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
-
-    def configure_optmizers(self): ...
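For reference, the now-annotated reparametrize implements the standard VAE reparameterization trick: a sample z ~ N(mu, sigma^2) is rewritten as z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(0.5 * log_var), so the sampling step stays differentiable with respect to both mu and log_var. A self-contained check of the arithmetic:

import torch

mu = torch.zeros(4, 8)
log_var = torch.zeros(4, 8)     # log_var = 0  =>  sigma = exp(0) = 1
std = torch.exp(0.5 * log_var)  # same formula as reparametrize()
eps = torch.randn_like(std)     # eps ~ N(0, I)
z = mu + eps * std              # so here z ~ N(0, I)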
Lines changed: 11 additions & 3 deletions
@@ -1,6 +1,9 @@
+from typing import Any
+
 import lightning as L

 from auto_cast.models.encoder_decoder import EncoderDecoder
+from auto_cast.preprocessor.base import Preprocessor
 from auto_cast.processors.base import Processor
 from auto_cast.types import Batch, Tensor

@@ -10,14 +13,19 @@ class EncoderProcessorDecoder(L.LightningModule):

     encoder_decoder: EncoderDecoder
     processor: Processor
+    preprocessor: Preprocessor

     def __init__(self): ...

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         return self.encoder_decoder.decoder(
-            self.processor(self.encoder_decoder.encoder(x))
+            self.processor(self.encoder_decoder.encoder(*args, **kwargs))
         )

-    def training_step(self, batch: Batch, batch_idx: int) -> Tensor: ...
+    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
+        x = self.preprocessor(batch)
+        output = self(x)
+        loss = self.processor.loss_func(output, batch["output_fields"])
+        return loss  # noqa: RET504

     def configure_optmizers(self): ...
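The composed forward and the filled-in training_step amount to the pipeline below. The three nn.Linear stages and MSELoss are stand-ins; the commit itself only fixes the composition order and reads loss_func off the processor:

import torch
from torch import nn

encoder = nn.Linear(16, 8)    # stand-in for EncoderDecoder.encoder
processor = nn.Linear(8, 8)   # stand-in for Processor
decoder = nn.Linear(8, 16)    # stand-in for EncoderDecoder.decoder
loss_func = nn.MSELoss()      # assumed; the commit leaves the loss abstract

x = torch.randn(2, 16)
output = decoder(processor(encoder(x)))       # mirrors forward()
loss = loss_func(output, torch.randn(2, 16))  # mirrors training_step()
loss.backward()                               # gradients reach all three stages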

src/auto_cast/preprocessor/base.py

Lines changed: 10 additions & 5 deletions
@@ -1,12 +1,17 @@
-from torch import nn
+from typing import Any

-from auto_cast.types import Batch, Tensor
+from auto_cast.types import Batch


-class Preprocessor(nn.Module):
-    """Base Preprocessor."""
+class Preprocessor:
+    """Base Preprocessor.

-    def forward(self, x: Batch) -> Tensor:
+    This is not trainable but can combine the elements of the batch into a form that
+    can be passed to the call/forward of the models.
+
+    """
+
+    def __call__(self, x: Batch) -> Any:
         """Forward Pass through the Preprocessor."""
         msg = "To implement."
         raise NotImplementedError(msg)
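Since Preprocessor is no longer an nn.Module, a concrete implementation is just a callable that massages the batch. A sketch, assuming the batch carries an input_fields list of same-shaped tensors alongside the output_fields key used by the training steps:

import torch

from auto_cast.preprocessor.base import Preprocessor
from auto_cast.types import Batch


class StackPreprocessor(Preprocessor):
    """Hypothetical preprocessor: stacks per-variable tensors into the
    single array a model forward expects."""

    def __call__(self, x: Batch) -> torch.Tensor:
        # "input_fields" is an assumed batch layout, not defined by this
        # commit; the point is only that __call__ replaces forward here.
        return torch.stack(x["input_fields"], dim=1)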

src/auto_cast/processors/base.py

Lines changed: 9 additions & 20 deletions
@@ -1,42 +1,33 @@
 from abc import ABC, abstractmethod
+from typing import Any

 import lightning as L
 import torch
-from torch import nn  # noqa: F401
+from torch import nn

 from auto_cast.preprocessor import Preprocessor
 from auto_cast.types import Batch, RolloutOutput, Tensor


-class Processor(L.LightningModule, ABC):
+class Processor(L.LightningModule):
     """Processor Base Class."""

     teacher_forcing_ratio: float
     stride: int
     max_rollout_steps: int
     preprocessor: Preprocessor
+    loss_func: nn.Module

-    def __init__(self):
-        pass
-
-    # Option 1
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Forward pass through the Processor."""
         msg = "To implement."
         raise NotImplementedError(msg)

     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
         x = self.preprocessor(batch)
-        return self(x)
-
-    # # Option 2
-    # def forward(self, x: Batch) -> Tensor:
-    #     """Forward pass through the Processor."""
-    #     msg = "To implement."
-    #     raise NotImplementedError(msg)
-
-    # def training_step(self, batch: Batch, batch_idx: int):
-    #     self(batch)
+        output = self(x)
+        loss = self.loss_func(output, batch["output_fields"])
+        return loss  # noqa: RET504

     def configure_optmizers(self):
         pass
@@ -48,9 +39,7 @@ def rollout(self, batch: Batch) -> RolloutOutput:
         for _time_step in range(0, self.max_rollout_steps, self.stride):
             x = self.preprocessor(batch)
             pred_outs.append(self(x))
-            gt_outs.append(
-                batch["output_fields"]
-            )  # Q: this assumes we have output fields
+            gt_outs.append(batch["output_fields"])  # This assumes we have output fields
         return torch.stack(pred_outs), torch.stack(gt_outs)
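With loss_func promoted to a declared attribute, a concrete processor only supplies a network, a loss, and a preprocessor; the base training_step does the rest. A runnable sketch in which LinearProcessor, the lambda preprocessor, and MSELoss are all illustrative choices:

import torch
from torch import nn

from auto_cast.processors.base import Processor


class LinearProcessor(Processor):
    """Hypothetical concrete processor."""

    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self.net = nn.Linear(dim, dim)
        self.loss_func = nn.MSELoss()  # fills the new loss_func slot
        # Stand-in for a real Preprocessor; any callable on the batch works
        # now that Preprocessor is a plain __call__ class.
        self.preprocessor = lambda batch: batch["output_fields"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


batch = {"output_fields": torch.randn(2, 4)}
loss = LinearProcessor().training_step(batch, 0)  # loss_func(self(x), targets)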