File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 66from aion .codecs .quantizers import Quantizer
77
88
9- class Codec (ABC ):
9+ class Codec (ABC , torch . nn . Module ):
1010 """Abstract definition of a Codec.
1111
1212 A codec embeds a specific type of data into a sequence of either
@@ -53,6 +53,13 @@ def decode(
5353 """Encodes a given batch of samples into latent space."""
5454 return self ._decode (z )
5555
56+ def forward (
57+ self ,
58+ x : Float [torch .Tensor , " b c *input_shape" ],
59+ channel_mask : Bool [torch .Tensor , " b c" ],
60+ ) -> Float [torch .Tensor , " b c1 *code_shape" ]:
61+ return self .encode (x , channel_mask )
62+
5663
5764class QuantizedCodec (Codec ):
5865 def __init__ (self , quantizer : Quantizer ):
Original file line number Diff line number Diff line change @@ -114,7 +114,7 @@ def __init__(
114114 mult_factor: Multiplication factor.
115115 """
116116 # Get MagViT architecture
117- self . model = MagVitAE (
117+ model = MagVitAE (
118118 n_bands = multisurvey_projection_dims ,
119119 hidden_dims = hidden_dims ,
120120 n_compressions = n_compressions ,
@@ -123,11 +123,12 @@ def __init__(
123123 super ().__init__ (
124124 n_bands ,
125125 quantizer ,
126- self . model .encode ,
127- self . model .decode ,
126+ model .encode ,
127+ model .decode ,
128128 hidden_dims ,
129129 embedding_dim ,
130130 multisurvey_projection_dims ,
131131 range_compression_factor ,
132132 mult_factor ,
133133 )
134+ self .model = model
You can’t perform that action at this time.
0 commit comments