11from einops import rearrange
22
33from auto_cast .decoders .base import Decoder
4- from auto_cast .types import Tensor , TensorBCWH , TensorBTWHC
4+ from auto_cast .types import Tensor , TensorBCTSPlus , TensorBTSPlusC
55
66
77class ChannelsLast (Decoder ):
8- """Decoder that splits merged (channel*time) back to (time, channel) and reorders to channels-last format.""" # noqa: E501
8+ """Decoder that splits channels and time and reorders to channels-last format.
9+
10+ The decoder splits (channel*time) back to (time, channel) and moves the channels
11+ to the last dimension assuming one or more spatial dimensions that are the last
12+ dimensions of the encoded tensor.
13+ """
914
1015 def __init__ (self , output_channels : int , time_steps : int = 1 ) -> None :
1116 """Initialize the ChannelsLast decoder.
@@ -24,15 +29,12 @@ def __init__(self, output_channels: int, time_steps: int = 1) -> None:
2429 def forward (self , x : Tensor ) -> Tensor :
2530 """Forward pass through the ChannelsLast decoder.
2631
27- Expects input shape (B, C*T, W, H ) and outputs (B, T, W, H , C).
32+ Expects input shape (B, C*T, spatial... ) and outputs (B, T, spatial... , C).
2833 """
29- # Split merged (C*T) dimension back into separate C and T
30- # x: (B, C*T, W, H) -> (B, C, T, W, H)
3134 x = rearrange (
32- x , "b (c t) w h -> b c t w h " , c = self .output_channels , t = self .time_steps
35+ x , "b (c t) ... -> b c t ... " , c = self .output_channels , t = self .time_steps
3336 )
34- # Rearrange to channels-last: (B, C, T, W, H) -> (B, T, W, H, C)
35- return rearrange (x , "b c t w h -> b t w h c" )
37+ return rearrange (x , "b c t ... -> b t ... c" )
3638
37- def decode (self , z : TensorBCWH ) -> TensorBTWHC :
39+ def decode (self , z : TensorBCTSPlus ) -> TensorBTSPlusC :
3840 return self .forward (z )
0 commit comments