Skip to content

Commit d45594e

Browse files
committed
Fixes following merge
1 parent 2c31a32 commit d45594e

10 files changed

Lines changed: 38 additions & 23 deletions

File tree

src/auto_cast/decoders/base.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,6 @@
88
class Decoder(nn.Module, ABC):
99
"""Base Decoder."""
1010

11-
decoder_model: nn.Module
12-
latent_dim: int
13-
14-
def __init__(self, latent_dim: int, output_channels: int) -> None:
15-
super().__init__()
16-
self.latent_dim = latent_dim
17-
self.output_channels = output_channels
18-
1911
def postprocess(self, decoded: Tensor) -> TensorBTSPlusC:
2012
"""Optionally transform the decoded tensor before returning.
2113

src/auto_cast/decoders/channels_last.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from einops import rearrange
22

33
from auto_cast.decoders.base import Decoder
4-
from auto_cast.types import Tensor
4+
from auto_cast.types import Tensor, TensorBCWH, TensorBTWHC
55

66

77
class ChannelsLast(Decoder):
@@ -33,3 +33,6 @@ def forward(self, x: Tensor) -> Tensor:
3333
)
3434
# Rearrange to channels-last: (B, C, T, W, H) -> (B, T, W, H, C)
3535
return rearrange(x, "b c t w h -> b t w h c")
36+
37+
def decode(self, z: TensorBCWH) -> TensorBTWHC:
38+
return self.forward(z)

src/auto_cast/decoders/dc.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,9 @@ def __init__(
8686
checkpointing: bool = False,
8787
identity_init: bool = True,
8888
) -> None:
89-
super().__init__(latent_dim=in_channels, output_channels=out_channels)
90-
89+
super().__init__()
90+
self.latent_dim = in_channels
91+
self.output_channels = out_channels
9192
attention_heads = attention_heads or {}
9293
assert len(hid_blocks) == len(hid_channels)
9394

src/auto_cast/encoders/permute_concat.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from einops import rearrange
33

44
from auto_cast.encoders.base import Encoder
5-
from auto_cast.types import Batch, Tensor
5+
from auto_cast.types import Batch, Tensor, TensorBCWH
66

77

88
class PermuteConcat(Encoder):
@@ -27,3 +27,6 @@ def forward(self, batch: Batch) -> Tensor:
2727
scalars = scalars.expand(b, -1, t, w, h)
2828
x = torch.cat([x, scalars], dim=1)
2929
return rearrange(x, "b c t w h -> b (c t) w h")
30+
31+
def encode(self, batch: Batch) -> TensorBCWH:
32+
return self.forward(batch)

src/auto_cast/models/encoder_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Self
2+
13
import lightning as L
24
import torch
35
from torch import nn
@@ -18,6 +20,16 @@ class EncoderDecoder(L.LightningModule):
1820
def __init__(self):
1921
super().__init__()
2022

23+
@classmethod
24+
def from_encoder_decoder(
25+
cls, encoder: Encoder, decoder: Decoder, loss_func: nn.Module, **kwargs: Any
26+
) -> Self:
27+
instance = cls(**kwargs)
28+
instance.encoder = encoder
29+
instance.decoder = decoder
30+
instance.loss_func = loss_func
31+
return instance
32+
2133
def forward(self, batch: Batch) -> TensorBTSPlusC:
2234
return self.decoder(self.encoder(batch))
2335

src/auto_cast/models/encoder_processor_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from auto_cast.models.encoder_decoder import EncoderDecoder
88
from auto_cast.processors.base import Processor
99
from auto_cast.processors.rollout import RolloutMixin
10-
from auto_cast.types import Batch, EncodedBatch, Tensor, TensorBTSPlusC
10+
from auto_cast.types import Batch, EncodedBatch, Tensor, TensorBMStarL, TensorBTSPlusC
1111

1212

1313
class EncoderProcessorDecoder(RolloutMixin[Batch], L.LightningModule):

src/auto_cast/types/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
# Spatial only (no time dimension)
4040
TensorBCSPlus = Float[Tensor, "batch channel spatial *spatial"]
4141
TensorBWHC = Float[Tensor, "batch width height channel"]
42+
TensorBCWH = Float[Tensor, "batch channel width height"]
4243
TensorBWHDC = Float[Tensor, "batch width height depth channel"]
4344
TensorBSPlusC = Float[Tensor, "batch spatial *spatial channel"]
4445

tests/decoders/test_channels_last.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,18 +11,13 @@ def test_channels_last_reorders_dimensions():
1111

1212
# Input shape: (B, C*T, W, H) - simulating encoder output
1313
x = torch.randn(batch_size, channels * time_steps, width, height)
14-
1514
output = decoder(x)
1615

1716
# Expected output shape: (B, T, W, H, C)
1817
assert output.shape == (batch_size, time_steps, width, height, channels)
1918

2019
# Verify the transformation is correct by checking a round-trip
21-
# Create a known input in (B, T, W, H, C) format
2220
original = torch.randn(batch_size, time_steps, width, height, channels)
23-
# Simulate encoder: (B, T, W, H, C) -> (B, C, T, W, H) -> (B, C*T, W, H)
2421
encoded = rearrange(original, "b t w h c -> b (c t) w h")
25-
# Decode back
2622
decoded = decoder(encoded)
27-
# Should match original
28-
assert torch.allclose(decoded, original)
23+
assert torch.allclose(decoded, original), "Decoded does not match original input"

tests/models/test_encoder_processor_decoder.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,9 @@ def test_encoder_processor_decoder_training_step_runs(make_toy_batch, dummy_load
3535
encoder = PermuteConcat(with_constants=False)
3636
decoder = ChannelsLast(output_channels=output_channels, time_steps=time_steps)
3737
loss = nn.MSELoss()
38-
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder, loss_func=loss)
38+
encoder_decoder = EncoderDecoder.from_encoder_decoder(
39+
encoder=encoder, decoder=decoder, loss_func=loss
40+
)
3941

4042
processor = TinyProcessor(in_channels=merged_channels)
4143
model = EncoderProcessorDecoder.from_encoder_processor_decoder(
@@ -67,7 +69,9 @@ def test_encoder_processor_decoder_rollout_is_mixin_backed(make_toy_batch):
6769
encoder = PermuteConcat(with_constants=False)
6870
decoder = ChannelsLast(output_channels=output_channels, time_steps=time_steps)
6971
loss = nn.MSELoss()
70-
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder, loss_func=loss)
72+
encoder_decoder = EncoderDecoder.from_encoder_decoder(
73+
encoder=encoder, decoder=decoder, loss_func=loss
74+
)
7175
processor = TinyProcessor(in_channels=merged_channels)
7276
model = EncoderProcessorDecoder.from_encoder_processor_decoder(
7377
encoder_decoder=encoder_decoder,

tests/models/test_vae.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,9 @@ class _FlatDecoder(Decoder):
6969
"""Minimal decoder that reconstructs flat tensors for tests."""
7070

7171
def __init__(self, latent_dim: int, output_dim: int) -> None:
72-
super().__init__(latent_dim=latent_dim, output_channels=output_dim)
72+
super().__init__()
73+
self.latent_dim = latent_dim
74+
self.output_dim = output_dim
7375
self.net = nn.Sequential(
7476
nn.Linear(latent_dim, 2 * latent_dim),
7577
nn.GELU(),
@@ -111,7 +113,9 @@ class _FlatteningDecoder(Decoder):
111113
"""Decoder that maps flat latents back to spatial tensors."""
112114

113115
def __init__(self, latent_dim: int, output_shape: tuple[int, ...]) -> None:
114-
super().__init__(latent_dim=latent_dim, output_channels=output_shape[0])
116+
super().__init__()
117+
self.latent_dim = latent_dim
118+
self.output_channels = output_shape[0]
115119
self.output_shape = output_shape
116120
out_features = math.prod(output_shape)
117121
self.net = nn.Sequential(

0 commit comments

Comments (0)