
Commit a0405b9

Revise loss API with generic BatchT input
1 parent aeb0118 commit a0405b9

8 files changed: 26 additions & 25 deletions
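In short: Processor.loss used to take a precomputed output plus a target; it now takes the batch itself and runs the forward map internally. A minimal before/after sketch of the call pattern, with processor and batch standing in for any concrete processor and its batch object:

    # Before: every call-site ran the forward map itself.
    output = processor.map(batch.encoded_inputs)
    loss = processor.loss(output, batch.encoded_output_fields)

    # After: the processor owns both the forward map and the comparison.
    loss = processor.loss(batch)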

src/auto_cast/models/encoder_processor_decoder.py

Lines changed: 1 addition & 2 deletions
@@ -165,8 +165,7 @@ def __init__(

     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
         encoded_batch = self.encoder_decoder.encoder.encode_batch(batch)
-        output = self.processor.map(encoded_batch.encoded_inputs)
-        loss = self.processor.loss(output, encoded_batch.encoded_output_fields)
+        loss = self.processor.loss(encoded_batch)
         self.log(
             "train_loss", loss, prog_bar=True, batch_size=batch.input_fields.shape[0]
         )
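For reference, loss() now consumes the encoded batch whole. The diff only guarantees the two attribute names used above (encoded_inputs and encoded_output_fields), so a stand-in with just those fields is enough to exercise the new signature. The dataclass below is illustrative, not the real EncodedBatch:

    from dataclasses import dataclass

    import torch


    @dataclass
    class FakeEncodedBatch:  # illustrative stand-in for auto_cast.types.EncodedBatch
        encoded_inputs: torch.Tensor
        encoded_output_fields: torch.Tensor


    batch = FakeEncodedBatch(
        encoded_inputs=torch.randn(2, 3, 8, 8),
        encoded_output_fields=torch.randn(2, 3, 8, 8),
    )
    # loss = processor.loss(batch)  # any Processor[EncodedBatch]-style instance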

src/auto_cast/models/processor.py

Lines changed: 2 additions & 4 deletions
@@ -45,16 +45,14 @@ def forward(self, x: TensorBMStarL) -> TensorBMStarL:
         return self.processor.map(x)

     def training_step(self, batch: EncodedBatch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        output = self.processor.map(batch.encoded_inputs)
-        loss = self.processor.loss(output, batch.encoded_output_fields)
+        loss = self.processor.loss(batch)
         self.log(
             "train_loss", loss, prog_bar=True, batch_size=batch.encoded_inputs.shape[0]
         )
         return loss

     def validation_step(self, batch: EncodedBatch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        output = self.processor.map(batch.encoded_inputs)
-        loss = self.processor.loss(output, batch.encoded_output_fields)
+        loss = self.processor.loss(batch)
         self.log(
             "val_loss", loss, prog_bar=True, batch_size=batch.encoded_inputs.shape[0]
         )
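With both steps reduced to the same one-liner, they now differ only in the logged metric name. A hypothetical follow-up consolidation, not part of this commit, could share the body:

    # Hypothetical helper (not in this commit); all names below come from the diff.
    def _shared_step(self, batch: EncodedBatch, stage: str) -> Tensor:
        loss = self.processor.loss(batch)
        self.log(
            f"{stage}_loss", loss, prog_bar=True,
            batch_size=batch.encoded_inputs.shape[0],
        )
        return loss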

src/auto_cast/nn/fno.py

Lines changed: 5 additions & 4 deletions
@@ -4,7 +4,7 @@
 from torch import nn

 from auto_cast.processors.base import Processor
-from auto_cast.types import Tensor
+from auto_cast.types import EncodedBatch, Tensor


 @runtime_checkable
@@ -13,7 +13,7 @@ class _HasGridCache(Protocol):
     _res: Any | None


-class FNOProcessor(Processor):
+class FNOProcessor(Processor[EncodedBatch]):
     """Fourier Neural Operator Module.

     A discrete processor that uses a Fourier Neural Operator (FNO) to learn
@@ -91,5 +91,6 @@ def _apply(self, fn, recurse: bool = True):
     def map(self, x: Tensor) -> Tensor:
         return self(x)

-    def loss(self, output: Tensor, target: Tensor) -> Tensor:
-        return self.loss_func(output, target)
+    def loss(self, batch: EncodedBatch) -> Tensor:
+        output = self.map(batch.encoded_inputs)
+        return self.loss_func(output, batch.encoded_output_fields)
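Note the division of labor this leaves in place: map() stays the plain forward entry point, while loss() now bundles the forward pass with the comparison against targets. Sketched with placeholder names (proc and batch are not defined in this diff):

    # proc is an FNOProcessor, batch an EncodedBatch (placeholders).
    preds = proc.map(batch.encoded_inputs)  # forward only, e.g. at inference time
    loss = proc.loss(batch)                 # forward + loss_func, for training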

src/auto_cast/processors/base.py

Lines changed: 4 additions & 4 deletions
@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Generic

 from torch import nn

-from auto_cast.types import Tensor
+from auto_cast.types import BatchT, Tensor


-class Processor(ABC, nn.Module):
+class Processor(ABC, nn.Module, Generic[BatchT]):
     """Processor Base Class."""

     def __init__(
@@ -27,7 +27,7 @@ def __init__(
             setattr(self, key, value)

     @abstractmethod
-    def loss(self, output: Tensor, target: Tensor) -> Tensor:
+    def loss(self, batch: BatchT) -> Tensor:
         """Compute loss between output and target."""

     @abstractmethod
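Subclasses now pin the batch type at inheritance time, and type checkers can hold loss() to it. A hedged sketch of the pattern; the toy batch type and the MSE comparison are assumptions, not library code:

    from dataclasses import dataclass

    import torch
    from torch import nn

    from auto_cast.processors.base import Processor


    @dataclass
    class MyBatch:  # illustrative batch type; any shape a processor needs
        inputs: torch.Tensor
        targets: torch.Tensor


    class MyProcessor(Processor[MyBatch]):  # binds BatchT = MyBatch
        def map(self, x: torch.Tensor) -> torch.Tensor:
            return x  # identity map keeps the sketch minimal

        def loss(self, batch: MyBatch) -> torch.Tensor:
            # The processor runs its own forward map, then compares.
            return nn.functional.mse_loss(self.map(batch.inputs), batch.targets)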

src/auto_cast/processors/rollout.py

Lines changed: 2 additions & 4 deletions
@@ -1,13 +1,11 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from typing import Generic

 import torch

-from auto_cast.types import RolloutOutput, Tensor
-
-BatchT = TypeVar("BatchT")
+from auto_cast.types import BatchT, RolloutOutput, Tensor


 class RolloutMixin(ABC, Generic[BatchT]):
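The module-local TypeVar is gone; RolloutMixin and Processor now draw on the single declaration in auto_cast.types:

    from typing import Generic

    from auto_cast.types import BatchT  # single shared declaration

    # Both generic classes are parameterized by the same type variable:
    #   class Processor(ABC, nn.Module, Generic[BatchT]): ...
    #   class RolloutMixin(ABC, Generic[BatchT]): ...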

src/auto_cast/types/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
+from typing import TypeVar

 import torch
 from jaxtyping import Float
@@ -56,6 +57,9 @@
 # Rollout output type
 RolloutOutput = tuple[Tensor, None] | tuple[Tensor, Tensor]

+# Generic batch type variable
+BatchT = TypeVar("BatchT")
+

 @dataclass
 class Sample:
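A TypeVar is an ordinary runtime object, so declaring it once here and importing it elsewhere keeps one canonical definition rather than several same-named copies scattered across modules:

    from typing import TypeVar

    # Two declarations with the same name are still distinct objects:
    a = TypeVar("BatchT")
    b = TypeVar("BatchT")
    assert a is not b  # hence the single shared declaration in auto_cast.types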

tests/models/test_encoder_processor_decoder.py

Lines changed: 5 additions & 4 deletions
@@ -6,10 +6,10 @@
 from auto_cast.models.encoder_decoder import EncoderDecoder
 from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
 from auto_cast.processors.base import Processor
-from auto_cast.types import Tensor
+from auto_cast.types import EncodedBatch, Tensor


-class TinyProcessor(Processor):
+class TinyProcessor(Processor[EncodedBatch]):
     def __init__(self, in_channels: int = 1) -> None:
         super().__init__()
         self.conv = nn.Conv2d(
@@ -25,8 +25,9 @@ def forward(self, x: Tensor) -> Tensor:
     def map(self, x: Tensor) -> Tensor:
         return self(x)

-    def loss(self, output: Tensor, target: Tensor) -> Tensor:
-        return self.loss_func(output, target)
+    def loss(self, batch: EncodedBatch) -> Tensor:
+        outputs = self(batch.encoded_inputs)
+        return self.loss_func(outputs, batch.encoded_output_fields)


 def test_encoder_processor_decoder_training_step_runs(make_toy_batch, dummy_loader):

tests/processors/test_processors.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@ def _toy_encoded_batch(
     )


-class _IdentityProcessor(Processor):
+class _IdentityProcessor(Processor[EncodedBatch]):
     def __init__(self) -> None:
         super().__init__(
             loss_func=nn.MSELoss(),
@@ -33,8 +33,8 @@ def __init__(self) -> None:
     def map(self, x: Tensor) -> Tensor:
         return x

-    def loss(self, output: Tensor, target: Tensor) -> Tensor:
-        return self.loss_func(output, target)
+    def loss(self, batch: EncodedBatch) -> Tensor:
+        return self.loss_func(batch.encoded_inputs, batch.encoded_output_fields)


 def test_processor_rollout_handles_encoded_batches():
