Skip to content

Commit 70743b5

Browse files
committed
Adding code for catalog
1 parent 49f9eea commit 70743b5

8 files changed

Lines changed: 153 additions & 42 deletions

File tree

aion/codecs/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from .image import ImageCodec
22
from .scalar import ScalarCodec, LogScalarCodec, MultiScalarCodec
33
from .spectrum import SpectrumCodec
4-
4+
from .catalog import CatalogCodec
5+
from .base import Codec
56

67
__all__ = [
78
"ImageCodec",
89
"ScalarCodec",
910
"LogScalarCodec",
1011
"MultiScalarCodec",
1112
"SpectrumCodec",
13+
"CatalogCodec",
14+
"Codec",
1215
]

aion/codecs/catalog.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,38 @@
88

99
from aion.codecs.base import Codec
1010
from aion.codecs.quantizers import Quantizer
11-
from aion.codecs.quantizers.scalar import ComposedScalarQuantizer
11+
from aion.codecs.quantizers.scalar import ComposedScalarQuantizer, IdentityQuantizer, ScalarReservoirQuantizer
1212
from aion.modalities import Catalog
1313

14-
__all__ = ["CatalogIdentityCodec"]
1514

16-
17-
class CatalogIdentityCodec(Codec, PyTorchModelHubMixin):
15+
class CatalogCodec(Codec, PyTorchModelHubMixin):
1816
"""Codec for catalog quantities.
1917
2018
A codec that embeds catalog quantities through an identity mapping. A
2119
quantizer is applied if specified.
22-
23-
Args:
24-
catalog_keys: List[str]
25-
List of catalog keys to encode.
26-
quantizers: Optional[List[Quantizer]]
27-
Optional list of quantizers for each catalog key.
28-
mask_value: int
29-
Value used to indicate masked/missing data.
3020
"""
3121

3222
def __init__(
3323
self,
34-
catalog_keys: List[str],
35-
quantizers: Optional[List[Quantizer]] = None,
36-
mask_value: int = 9999,
24+
mask_value: int =9999,
3725
):
3826
super().__init__()
3927
self._modality = Catalog
40-
self._catalog_keys = catalog_keys
28+
catalog_keys = ['X', 'Y', 'SHAPE_E1', 'SHAPE_E2', 'SHAPE_R']
29+
quantizers = [IdentityQuantizer(96),
30+
IdentityQuantizer(96),
31+
ScalarReservoirQuantizer(1024, 100000),
32+
ScalarReservoirQuantizer(1024, 100000),
33+
ScalarReservoirQuantizer(1024, 100000)]
4134
self.mask_value = mask_value
42-
if quantizers:
43-
assert len(catalog_keys) == len(quantizers), (
44-
"Number of catalog keys and quantizers must match"
45-
)
46-
_quantizer = OrderedDict()
47-
for key, quantizer in zip(catalog_keys, quantizers):
48-
_quantizer[key] = quantizer
49-
self._quantizer = ComposedScalarQuantizer(_quantizer)
50-
else:
51-
self._quantizer = None
35+
self._catalog_keys = catalog_keys
36+
assert len(catalog_keys) == len(quantizers), (
37+
"Number of catalog keys and quantizers must match"
38+
)
39+
_quantizer = OrderedDict()
40+
for key, quantizer in zip(catalog_keys, quantizers):
41+
_quantizer[key] = quantizer
42+
self._quantizer = ComposedScalarQuantizer(_quantizer)
5243

5344
@property
5445
def modality(self) -> Type[Catalog]:
@@ -61,7 +52,7 @@ def quantizer(self) -> Optional[Quantizer]:
6152
def _encode(self, x: Catalog) -> Dict[str, Tensor]:
6253
encoded = OrderedDict()
6354
for key in self._catalog_keys:
64-
catalog_value = x[self.modality][key]
55+
catalog_value = getattr(x, key)
6556
mask = catalog_value != self.mask_value
6657
catalog_value = catalog_value[mask]
6758
encoded[key] = catalog_value
@@ -87,7 +78,7 @@ def encode(self, x: Catalog) -> Float[Tensor, "b c1 *code_shape"]:
8778
return encoded
8879

8980
def _decode(self, z: Dict[str, Tensor]) -> Catalog:
90-
return Catalog(data=z)
81+
return Catalog(**z)
9182

9283
def decode(self, z: Float[Tensor, "b c1 *code_shape"]) -> Catalog:
9384
B, LC = z.shape

aion/codecs/quantizers/scalar.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,3 +623,68 @@ def codebook_size(self) -> int:
623623
def embedding_dim(self) -> int:
624624
"""Returns the dimension of the codebook entries."""
625625
return 1
626+
627+
628+
class IdentityQuantizer(Quantizer):
629+
"""
630+
Identity quantizer module.
631+
632+
The identity quantizer module takes a batch of tensors and returns the same tensor.
633+
634+
Args:
635+
codebook_size: int
636+
The number of labels to be used as signature for the codebook.
637+
"""
638+
639+
def __init__(self, codebook_size: int):
640+
super().__init__()
641+
self.register_buffer("_codebook_size", torch.tensor(codebook_size))
642+
643+
def forward(
644+
self, z_e: torch.Tensor
645+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
646+
"""Performs a forward pass through the vector quantizer.
647+
Args:
648+
z_e: torch.Tensor (B, C, ...)
649+
The input tensor to be quantized.
650+
Returns:
651+
z_q: torch.Tensor
652+
The quantized tensor.
653+
loss: torch.Tensor
654+
The embedding loss for the quantization.
655+
codebook_usage: torch.Tensor
656+
The fraction of codes used in the codebook.
657+
"""
658+
codebook_usage = z_e.unique().numel() / self._codebook_size.item()
659+
return z_e, torch.tensor(0), torch.tensor(codebook_usage)
660+
661+
def quantize(self, z: torch.Tensor) -> torch.Tensor:
662+
"""Quantize the input tensor z, returns corresponding
663+
codebook entry.
664+
"""
665+
return z
666+
667+
def encode(self, z: torch.Tensor) -> torch.Tensor:
668+
"""Encodes the input tensor z, returns the corresponding codebook index."""
669+
return z
670+
671+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
672+
"""Decodes the input code index into corresponding codebook entry of
673+
dimension (embedding_dim).
674+
"""
675+
return codes
676+
677+
@property
678+
def codebook_size(self) -> int:
679+
"""Returns the size of the codebook."""
680+
return int(self._codebook_size.item())
681+
682+
@property
683+
def codebook(self) -> torch.Tensor:
684+
"""Returns the codebook."""
685+
return torch.arange(self._codebook_size.item())
686+
687+
@property
688+
def embedding_dim(self) -> int:
689+
"""Returns the dimension of the codebook entries."""
690+
return 1

aion/modalities.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import List, Union, ClassVar
44
from pydantic import BaseModel, Field, ConfigDict
5-
from jaxtyping import Float, Bool, Dict
5+
from jaxtyping import Float, Bool, Int
66
from torch import Tensor
77

88

@@ -54,6 +54,31 @@ def __repr__(self) -> str:
5454
return repr_str
5555

5656

57+
# Catalog modality
58+
class Catalog(Modality):
59+
"""Catalog modality data.
60+
61+
Represents a catalog of scalar values from the
62+
Legacy Survey.
63+
"""
64+
65+
X: Int[Tensor, "batch n"] = Field(
66+
description="X position of the object in the image."
67+
)
68+
Y: Int[Tensor, "batch n"] = Field(
69+
description="Y position of the object in the image."
70+
)
71+
SHAPE_E1: Float[Tensor, "batch n"] = Field(
72+
description="First ellipticity component of the object."
73+
)
74+
SHAPE_E2: Float[Tensor, "batch n"] = Field(
75+
description="Second ellipticity component of the object."
76+
)
77+
SHAPE_R: Float[Tensor, "batch n"] = Field(
78+
description="Size of the object."
79+
)
80+
81+
5782
class ScalarModality(Modality):
5883
"""Base class for scalar modality data.
5984
@@ -70,18 +95,6 @@ def __repr__(self) -> str:
7095
return f"{self.__class__.__name__}(shape={list(self.value.shape)})"
7196

7297

73-
# Catalog modality
74-
class Catalog(Modality):
75-
"""Catalog modality data.
76-
77-
Represents a catalog of scalar values.
78-
"""
79-
80-
data: Dict[str, Dict[str, Float[Tensor, "b t"]]] = Field(
81-
description="Dictionary of dictionaries of scalar values."
82-
)
83-
84-
8598
# Flux measurements in different bands
8699
class FluxG(ScalarModality):
87100
"""G-band flux measurement."""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:436f2a56496712c65fa571c91d85e15b58268673ac0ab77facebe510219466d5
3+
size 575936
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:6091832e50b9060014065e3ab85130250fe02c2960dd16a10b78561ca20abbb3
3+
size 820544
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:627dc9ab69c3d26f3293029f9106fba20868e5aa4b245d00f1723f947f6e6c35
3+
size 330158
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
3+
from aion.codecs import CatalogCodec
4+
from aion.modalities import Catalog
5+
6+
def test_catalog_tokenizer(data_dir):
7+
codec = CatalogCodec.from_pretrained(
8+
f"polymathic-ai/aion-catalog-codec"
9+
)
10+
codec.eval()
11+
input_batch = torch.load(
12+
data_dir / f"catalog_codec_input_batch.pt", weights_only=False
13+
)
14+
reference_encoded_batch = torch.load(
15+
data_dir / f"catalog_codec_encoded_batch.pt", weights_only=False
16+
)
17+
reference_decoded_batch = torch.load(
18+
data_dir / f"catalog_codec_decoded_batch.pt", weights_only=False
19+
)
20+
21+
with torch.no_grad():
22+
output = codec.encode(Catalog(**input_batch))
23+
decoded_output = codec.decode(output)
24+
25+
assert torch.allclose(output, reference_encoded_batch)
26+
assert torch.allclose(decoded_output.X, reference_decoded_batch['X'], atol=1e-5)
27+
assert torch.allclose(decoded_output.Y, reference_decoded_batch['Y'], atol=1e-5)
28+
assert torch.allclose(decoded_output.SHAPE_E1, reference_decoded_batch['SHAPE_E1'], atol=1e-5)
29+
assert torch.allclose(decoded_output.SHAPE_E2, reference_decoded_batch['SHAPE_E2'], atol=1e-5)
30+
assert torch.allclose(decoded_output.SHAPE_R, reference_decoded_batch['SHAPE_R'], atol=1e-5)

0 commit comments

Comments
 (0)