Skip to content

Commit d45a944

Browse files
committed
Refactor ChannelsLast decoder
1 parent d45594e commit d45a944

1 file changed

Lines changed: 11 additions & 9 deletions

File tree

src/auto_cast/decoders/channels_last.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from einops import rearrange
22

33
from auto_cast.decoders.base import Decoder
4-
from auto_cast.types import Tensor, TensorBCWH, TensorBTWHC
4+
from auto_cast.types import Tensor, TensorBCTSPlus, TensorBTSPlusC
55

66

77
class ChannelsLast(Decoder):
8-
"""Decoder that splits merged (channel*time) back to (time, channel) and reorders to channels-last format.""" # noqa: E501
8+
"""Decoder that splits channels and time and reorders to channels-last format.
9+
10+
The decoder splits (channel*time) back to (time, channel) and moves the channels
11+
to the last dimension assuming one or more spatial dimensions that are the last
12+
dimensions of the encoded tensor.
13+
"""
914

1015
def __init__(self, output_channels: int, time_steps: int = 1) -> None:
1116
"""Initialize the ChannelsLast decoder.
@@ -24,15 +29,12 @@ def __init__(self, output_channels: int, time_steps: int = 1) -> None:
2429
def forward(self, x: Tensor) -> Tensor:
2530
"""Forward pass through the ChannelsLast decoder.
2631
27-
Expects input shape (B, C*T, W, H) and outputs (B, T, W, H, C).
32+
Expects input shape (B, C*T, spatial...) and outputs (B, T, spatial..., C).
2833
"""
29-
# Split merged (C*T) dimension back into separate C and T
30-
# x: (B, C*T, W, H) -> (B, C, T, W, H)
3134
x = rearrange(
32-
x, "b (c t) w h -> b c t w h", c=self.output_channels, t=self.time_steps
35+
x, "b (c t) ... -> b c t ...", c=self.output_channels, t=self.time_steps
3336
)
34-
# Rearrange to channels-last: (B, C, T, W, H) -> (B, T, W, H, C)
35-
return rearrange(x, "b c t w h -> b t w h c")
37+
return rearrange(x, "b c t ... -> b t ... c")
3638

37-
def decode(self, z: TensorBCWH) -> TensorBTWHC:
39+
def decode(self, z: TensorBCTSPlus) -> TensorBTSPlusC:
3840
return self.forward(z)

0 commit comments

Comments
 (0)