11from typing import Dict , Optional
22
33import torch
4+ from huggingface_hub import PyTorchModelHubMixin
45from jaxtyping import Float
56from torch import Tensor
67
78from aion .codecs .quantizers import Quantizer
9+ from aion .codecs .quantizers .scalar import (
10+ ScalarLogReservoirQuantizer ,
11+ ScalarReservoirQuantizer ,
12+ )
813from aion .codecs .tokenizers .base import QuantizedCodec
914
1015
11- class ScalarIdentityCodec (QuantizedCodec ):
16+ class BaseScalarIdentityCodec (QuantizedCodec , PyTorchModelHubMixin ):
1217 """Codec for scalar quantities.
1318
1419 A codec that embeds scalar quantities through an identity mapping. A
@@ -34,3 +39,21 @@ def _encode(self, x: Dict[str, Dict[str, Float[Tensor, "b t"]]]) -> Tensor:
3439
3540 def _decode (self , z : torch .FloatTensor ) -> Dict [str , torch .FloatTensor ]:
3641 return {self .modality : z }
42+
43+
class ScalarReservoirCodec(BaseScalarIdentityCodec):
    """Identity scalar codec backed by a reservoir-sampled quantizer.

    Builds a ``ScalarReservoirQuantizer`` from the given sizes and hands it,
    together with the modality name, to the base-class constructor.

    Args:
        modality: Name of the modality this codec handles.
        codebook_size: Number of quantization bins for the quantizer.
        reservoir_size: Size of the sampling reservoir used to fit the bins.
    """

    def __init__(self, modality: str, codebook_size: int, reservoir_size: int):
        super().__init__(
            modality,
            ScalarReservoirQuantizer(
                codebook_size=codebook_size,
                reservoir_size=reservoir_size,
            ),
        )
51+
52+
class ScalarLogReservoirCodec(BaseScalarIdentityCodec):
    """Identity scalar codec backed by a log-space reservoir quantizer.

    Same shape as ``ScalarReservoirCodec`` but constructs a
    ``ScalarLogReservoirQuantizer`` — presumably binning in log space;
    see the quantizer's own documentation for the exact semantics.

    Args:
        modality: Name of the modality this codec handles.
        codebook_size: Number of quantization bins for the quantizer.
        reservoir_size: Size of the sampling reservoir used to fit the bins.
    """

    def __init__(self, modality: str, codebook_size: int, reservoir_size: int):
        super().__init__(
            modality,
            ScalarLogReservoirQuantizer(
                codebook_size=codebook_size,
                reservoir_size=reservoir_size,
            ),
        )
0 commit comments