88
99from aion .codecs .base import Codec
1010from aion .codecs .quantizers import Quantizer
11- from aion .codecs .quantizers .scalar import ComposedScalarQuantizer
11+ from aion .codecs .quantizers .scalar import ComposedScalarQuantizer , IdentityQuantizer , ScalarReservoirQuantizer
1212from aion .modalities import Catalog
1313
14- __all__ = ["CatalogIdentityCodec" ]
1514
16-
17- class CatalogIdentityCodec (Codec , PyTorchModelHubMixin ):
15+ class CatalogCodec (Codec , PyTorchModelHubMixin ):
1816 """Codec for catalog quantities.
1917
2018 A codec that embeds catalog quantities through an identity mapping. A
2119 quantizer is applied if specified.
22-
23- Args:
24- catalog_keys: List[str]
25- List of catalog keys to encode.
26- quantizers: Optional[List[Quantizer]]
27- Optional list of quantizers for each catalog key.
28- mask_value: int
29- Value used to indicate masked/missing data.
3020 """
3121
3222 def __init__ (
3323 self ,
34- catalog_keys : List [str ],
35- quantizers : Optional [List [Quantizer ]] = None ,
36- mask_value : int = 9999 ,
24+ mask_value : int = 9999 ,
3725 ):
3826 super ().__init__ ()
3927 self ._modality = Catalog
40- self ._catalog_keys = catalog_keys
28+ catalog_keys = ['X' , 'Y' , 'SHAPE_E1' , 'SHAPE_E2' , 'SHAPE_R' ]
29+ quantizers = [IdentityQuantizer (96 ),
30+ IdentityQuantizer (96 ),
31+ ScalarReservoirQuantizer (1024 , 100000 ),
32+ ScalarReservoirQuantizer (1024 , 100000 ),
33+ ScalarReservoirQuantizer (1024 , 100000 )]
4134 self .mask_value = mask_value
42- if quantizers :
43- assert len (catalog_keys ) == len (quantizers ), (
44- "Number of catalog keys and quantizers must match"
45- )
46- _quantizer = OrderedDict ()
47- for key , quantizer in zip (catalog_keys , quantizers ):
48- _quantizer [key ] = quantizer
49- self ._quantizer = ComposedScalarQuantizer (_quantizer )
50- else :
51- self ._quantizer = None
35+ self ._catalog_keys = catalog_keys
36+ assert len (catalog_keys ) == len (quantizers ), (
37+ "Number of catalog keys and quantizers must match"
38+ )
39+ _quantizer = OrderedDict ()
40+ for key , quantizer in zip (catalog_keys , quantizers ):
41+ _quantizer [key ] = quantizer
42+ self ._quantizer = ComposedScalarQuantizer (_quantizer )
5243
5344 @property
5445 def modality (self ) -> Type [Catalog ]:
@@ -61,7 +52,7 @@ def quantizer(self) -> Optional[Quantizer]:
6152 def _encode (self , x : Catalog ) -> Dict [str , Tensor ]:
6253 encoded = OrderedDict ()
6354 for key in self ._catalog_keys :
64- catalog_value = x [ self . modality ][ key ]
55+ catalog_value = getattr ( x , key )
6556 mask = catalog_value != self .mask_value
6657 catalog_value = catalog_value [mask ]
6758 encoded [key ] = catalog_value
@@ -87,7 +78,7 @@ def encode(self, x: Catalog) -> Float[Tensor, "b c1 *code_shape"]:
8778 return encoded
8879
8980 def _decode (self , z : Dict [str , Tensor ]) -> Catalog :
90- return Catalog (data = z )
81+ return Catalog (** z )
9182
9283 def decode (self , z : Float [Tensor , "b c1 *code_shape" ]) -> Catalog :
9384 B , LC = z .shape
0 commit comments