Skip to content

Commit 8c37cc4

Browse files
committed
Make the tokenizer a pytorch module
Ease model weight loading.
1 parent 1cf702b commit 8c37cc4

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

aion/codecs/tokenizers/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from aion.codecs.quantizers import Quantizer
77

88

9-
class Codec(ABC):
9+
class Codec(ABC, torch.nn.Module):
1010
"""Abstract definition of a Codec.
1111
1212
A codec embeds a specific type of data into a sequence of either
@@ -53,6 +53,13 @@ def decode(
5353
"""Encodes a given batch of samples into latent space."""
5454
return self._decode(z)
5555

56+
def forward(
57+
self,
58+
x: Float[torch.Tensor, " b c *input_shape"],
59+
channel_mask: Bool[torch.Tensor, " b c"],
60+
) -> Float[torch.Tensor, " b c1 *code_shape"]:
61+
return self.encode(x, channel_mask)
62+
5663

5764
class QuantizedCodec(Codec):
5865
def __init__(self, quantizer: Quantizer):

aion/codecs/tokenizers/image.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
mult_factor: Multiplication factor.
115115
"""
116116
# Get MagViT architecture
117-
self.model = MagVitAE(
117+
model = MagVitAE(
118118
n_bands=multisurvey_projection_dims,
119119
hidden_dims=hidden_dims,
120120
n_compressions=n_compressions,
@@ -123,11 +123,12 @@ def __init__(
123123
super().__init__(
124124
n_bands,
125125
quantizer,
126-
self.model.encode,
127-
self.model.decode,
126+
model.encode,
127+
model.decode,
128128
hidden_dims,
129129
embedding_dim,
130130
multisurvey_projection_dims,
131131
range_compression_factor,
132132
mult_factor,
133133
)
134+
self.model = model

0 commit comments

Comments (0)