
Commit 2d77dfa

sgreenbury, cisprague, and marjanfamili committed

Update initial API

- Remove preprocessor (encoder to be used)
- Change batch types to dataclasses
- Add EncodedBatch
- Initial rollout method

Co-authored-by: Christopher Iliffe Sprague <cisprague@users.noreply.github.com>
Co-authored-by: Marjan Famili <marjanfamili@users.noreply.github.com>

1 parent 3202252 · commit 2d77dfa

6 files changed: 97 additions & 55 deletions


src/auto_cast/decoders/base.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from typing import Any

-from torch import nn
+from torch import Tensor, nn


 class Decoder(nn.Module):
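The widened import exists so decoder implementations can annotate their forward pass with Tensor. A minimal sketch of such a subclass, assuming nothing beyond what the diff shows (LinearDecoder, latent_dim, and field_dim are hypothetical names, not part of this commit):

class LinearDecoder(Decoder):  # hypothetical example, not in the commit
    """Projects a latent tensor back to field space with one linear layer."""

    def __init__(self, latent_dim: int, field_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(latent_dim, field_dim)

    def forward(self, z: Tensor) -> Tensor:
        # Both the argument and the return value use the newly imported Tensor alias.
        return self.proj(z)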
Lines changed: 46 additions & 7 deletions
@@ -1,31 +1,70 @@
 from typing import Any

 import lightning as L
+import torch
+from torch import nn

 from auto_cast.models.encoder_decoder import EncoderDecoder
-from auto_cast.preprocessor.base import Preprocessor
 from auto_cast.processors.base import Processor
-from auto_cast.types import Batch, Tensor
+from auto_cast.types import Batch, RolloutOutput, Tensor


 class EncoderProcessorDecoder(L.LightningModule):
     """Encoder-Processor-Decoder Model."""

     encoder_decoder: EncoderDecoder
     processor: Processor
-    preprocessor: Preprocessor
+    teacher_forcing_ratio: float
+    stride: int
+    max_rollout_steps: int
+    loss_func: nn.Module

     def __init__(self): ...

+    def from_encoder_processor_decoder(
+        self, encoder_decoder: EncoderDecoder, processor: Processor
+    ) -> None:
+        self.encoder_decoder = encoder_decoder
+        self.processor = processor
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         return self.encoder_decoder.decoder(
             self.processor(self.encoder_decoder.encoder(*args, **kwargs))
         )

     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        x = self.preprocessor(batch)
-        output = self(x)
-        loss = self.processor.loss_func(output, batch["output_fields"])
+        output = self(batch)
+        loss = self.processor.loss_func(output, batch.output_fields)
         return loss  # noqa: RET504

-    def configure_optmizers(self): ...
+    def configure_optimizers(self): ...
+
+    def rollout(self, batch: Batch) -> RolloutOutput:
+        """Rollout over multiple time steps."""
+        pred_outs, gt_outs = [], []
+        for _ in range(0, self.max_rollout_steps, self.stride):
+            x = self.encoder_decoder.encoder(batch)
+            pred_outs.append(self.processor.map(x))
+            # TODO: combine teacher-forcing logic
+            gt_outs.append(batch.output_fields)  # This assumes we have output fields
+        return torch.stack(pred_outs), torch.stack(gt_outs)
+
+
+# TODO: consider if a separate rollout class would be better
+class Rollout:
+    max_rollout_steps: int
+    stride: int
+
+    def rollout(
+        self,
+        batch: Batch,
+        model: Processor | EncoderProcessorDecoder,
+    ) -> RolloutOutput:
+        """Rollout over multiple time steps."""
+        pred_outs, gt_outs = [], []
+        for _ in range(0, self.max_rollout_steps, self.stride):
+            output = model(batch)
+            pred_outs.append(output)
+            # TODO: logic for moving window with teacher forcing that assigns
+            gt_outs.append(batch.output_fields)  # This assumes we have output fields
+        return torch.stack(pred_outs), torch.stack(gt_outs)
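For orientation, here is how the new pieces would be wired together and rolled out. This is a sketch only: my_encoder_decoder, my_processor, batch, and the hyperparameter values are illustrative, and __init__ is still a stub in the diff above.

model = EncoderProcessorDecoder()
model.from_encoder_processor_decoder(my_encoder_decoder, my_processor)
model.stride = 1              # loop step size
model.max_rollout_steps = 8   # loop runs len(range(0, 8, 1)) == 8 iterations

# rollout() stacks the per-step predictions and ground truths with torch.stack,
# so both returned tensors gain a leading dimension of 8 here.
preds, targets = model.rollout(batch)

Note that with stride > 1 the loop body executes fewer than max_rollout_steps times, since range(0, max_rollout_steps, stride) is what drives it.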

src/auto_cast/preprocessor/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/auto_cast/preprocessor/base.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/auto_cast/processors/base.py

Lines changed: 21 additions & 20 deletions
@@ -5,8 +5,7 @@
 import torch
 from torch import nn

-from auto_cast.preprocessor import Preprocessor
-from auto_cast.types import Batch, RolloutOutput, Tensor
+from auto_cast.types import Batch, EncodedBatch, RolloutOutput, Tensor


 class Processor(L.LightningModule):
@@ -15,43 +14,45 @@ class Processor(L.LightningModule):
     teacher_forcing_ratio: float
     stride: int
     max_rollout_steps: int
-    preprocessor: Preprocessor
     loss_func: nn.Module

-    def forward(self, *args: Any, **kwargs: Any) -> Any:
+    def forward(self, *args, **kwargs: Any) -> Any:
         """Forward pass through the Processor."""
         msg = "To implement."
         raise NotImplementedError(msg)

-    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        x = self.preprocessor(batch)
-        output = self(x)
-        loss = self.loss_func(output, batch["output_fields"])
+    def training_step(self, batch: EncodedBatch, batch_idx: int) -> Tensor:  # noqa: ARG002
+        output = self.map(batch.encoded_inputs)
+        loss = self.loss_func(output, batch.encoded_output_fields)
         return loss  # noqa: RET504

-    def configure_optmizers(self):
-        pass
+    @abstractmethod
+    def map(self, x: Tensor) -> Tensor:
+        """Map input window of states/times to output window."""
+
+    def configure_optimizers(self): ...

-    def rollout(self, batch: Batch) -> RolloutOutput:
+    def rollout(self, batch: EncodedBatch) -> RolloutOutput:
         """Rollout over multiple time steps."""
-        pred_outs = []
-        gt_outs = []
-        for _time_step in range(0, self.max_rollout_steps, self.stride):
-            x = self.preprocessor(batch)
-            pred_outs.append(self(x))
-            gt_outs.append(batch["output_fields"])  # This assumes we have output fields
+        pred_outs, gt_outs = [], []
+        for _ in range(0, self.max_rollout_steps, self.stride):
+            pred_outs.append(self.map(batch.encoded_inputs))
+            # TODO: combine teacher-forcing logic
+            gt_outs.append(
+                batch.encoded_output_fields
+            )  # This assumes we have output fields
         return torch.stack(pred_outs), torch.stack(gt_outs)


 class DiscreteProcessor(Processor, ABC):
     """DiscreteProcessor."""

     @abstractmethod
-    def map(self, x: Batch) -> Tensor:
+    def map(self, x: Tensor) -> Tensor:
         ...
         # Map input window of states/times to output window

-    def rollout(self, batch: Batch) -> RolloutOutput:
+    def rollout(self, batch: EncodedBatch) -> RolloutOutput:
         ...

         # Use self.map to generate trajectory
@@ -60,7 +61,7 @@ def rollout(self, batch: Batch) -> RolloutOutput:
 class FlowBasedGenerativeProcessor(DiscreteProcessor):
     """Flow-based generative processor."""

-    def map(self, x: Batch) -> Tensor:
+    def map(self, x: Tensor) -> Tensor:
         ...
         # Sample generative model; def loss(self, ...): ...
         # Flow matching
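Because map is now abstract on the base Processor, training_step and rollout are written once against that contract, and a concrete subclass only supplies the window-to-window mapping. A sketch under that contract (IdentityProcessor is a hypothetical smoke-test class, not in the commit):

class IdentityProcessor(Processor):  # hypothetical, for illustration only
    """Trivial processor: predicts the output window to equal the input window."""

    def map(self, x: Tensor) -> Tensor:
        # A real processor would advance the encoded state in time here.
        return x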

src/auto_cast/types/__init__.py

Lines changed: 29 additions & 7 deletions
@@ -1,16 +1,38 @@
+from dataclasses import dataclass
+
 import torch
 from torch.utils.data import DataLoader

 Tensor = torch.Tensor
 Input = Tensor | DataLoader
 RolloutOutput = tuple[Tensor, None] | tuple[Tensor, Tensor]

-Batch = dict[str, Tensor]
+# Batch = dict[str, Tensor]
+# EncodedBatch = dict[str, Tensor]

 # TODO: Could be a dataclass if we want more structure
-# @dataclass
-# class Batch:
-#     input_fields: Tensor
-#     output_fields: Tensor
-#     constant_scalars: Tensor
-#     constant_fields: Tensor
+@dataclass
+class Batch:
+    input_fields: Tensor
+    output_fields: Tensor
+    constant_scalars: Tensor
+    constant_fields: Tensor
+
+
+@dataclass
+class EncodedBatch:
+    encoded_inputs: Tensor
+    encoded_output_fields: Tensor
+    encoded_info: dict[str, Tensor]
+
+
+class EncoderForBatch:
+    """EncoderForBatch."""
+
+    def __call__(self, batch: Batch) -> EncodedBatch:
+        return EncodedBatch(
+            encoded_inputs=batch.input_fields,
+            encoded_output_fields=batch.output_fields,
+            encoded_info={},
+        )
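The move from dict batches to dataclasses means fields are now attributes (batch.output_fields) rather than keys (batch["output_fields"]), which the training steps above already reflect. A quick sketch of the new flow; the tensor shapes are illustrative, not from the commit:

import torch

from auto_cast.types import Batch, EncoderForBatch

batch = Batch(
    input_fields=torch.randn(4, 3, 64, 64),   # (batch, channels, H, W); shapes illustrative
    output_fields=torch.randn(4, 3, 64, 64),
    constant_scalars=torch.zeros(4, 2),
    constant_fields=torch.zeros(4, 1, 64, 64),
)

encoded = EncoderForBatch()(batch)  # currently an identity pass-through
assert encoded.encoded_inputs is batch.input_fields
assert encoded.encoded_info == {}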
