Skip to content

Commit 49f9eea

Browse files
committed
Adding code
1 parent 01440eb commit 49f9eea

4 files changed

Lines changed: 237 additions & 3 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,4 @@ cython_debug/
174174
.pypirc
175175
notebooks
176176
data
177+
old_impl

aion/codecs/catalog.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from collections import OrderedDict
2+
from typing import Type, Optional, Dict, List
3+
4+
from huggingface_hub import PyTorchModelHubMixin
5+
import torch
6+
from jaxtyping import Float
7+
from torch import Tensor
8+
9+
from aion.codecs.base import Codec
10+
from aion.codecs.quantizers import Quantizer
11+
from aion.codecs.quantizers.scalar import ComposedScalarQuantizer
12+
from aion.modalities import Catalog
13+
14+
__all__ = ["CatalogIdentityCodec"]
15+
16+
17+
class CatalogIdentityCodec(Codec, PyTorchModelHubMixin):
    """Codec for catalog quantities.

    A codec that embeds catalog quantities through an identity mapping. A
    quantizer is applied if specified; when no quantizer is given the raw
    (unquantized) values are passed through unchanged.

    Args:
        catalog_keys: List[str]
            List of catalog keys to encode.
        quantizers: Optional[List[Quantizer]]
            Optional list of quantizers for each catalog key.
        mask_value: int
            Value used to indicate masked/missing data.
    """

    def __init__(
        self,
        catalog_keys: List[str],
        quantizers: Optional[List[Quantizer]] = None,
        mask_value: int = 9999,
    ):
        super().__init__()
        self._modality = Catalog
        self._catalog_keys = catalog_keys
        self.mask_value = mask_value
        if quantizers:
            assert len(catalog_keys) == len(quantizers), (
                "Number of catalog keys and quantizers must match"
            )
            # Key order defines the channel order of the composed quantizer.
            self._quantizer = ComposedScalarQuantizer(
                OrderedDict(zip(catalog_keys, quantizers))
            )
        else:
            self._quantizer = None

    @property
    def modality(self) -> Type[Catalog]:
        return self._modality

    @property
    def quantizer(self) -> Optional[Quantizer]:
        return self._quantizer

    def _encode(self, x: Catalog) -> Dict[str, Tensor]:
        """Select the non-masked entries of every catalog key.

        Returns an ordered mapping from each catalog key to a 1D tensor of its
        non-masked values, plus a ``"mask"`` entry holding the boolean mask.
        """
        encoded = OrderedDict()
        for key in self._catalog_keys:
            # `x` is indexed by the modality class first, then by catalog key.
            catalog_value = x[self.modality][key]
            mask = catalog_value != self.mask_value
            encoded[key] = catalog_value[mask]
        # Only the mask of the last key is kept: the per-key values are later
        # stacked channel-wise, so all keys are assumed to share the same mask.
        encoded["mask"] = mask
        return encoded

    def encode(self, x: Catalog) -> Float[Tensor, "b c1 *code_shape"]:
        """Encodes a given batch of samples into latent space.

        Returns a tensor of shape ``(B, L * C)`` where ``(B, L)`` is the shape
        of the mask and ``C`` the number of catalog keys. Masked positions are
        filled with ``mask_value``.
        """
        embedding = self._encode(x)
        mask = embedding["mask"]

        quantizer = self.quantizer
        if quantizer is not None:
            # (b, C), where b is the number of non-masked samples
            _encoded = quantizer.encode(embedding)
        else:
            # Identity mapping: no quantizer configured, so stack the raw
            # values into the same (b, C) layout the quantizer would produce.
            _encoded = torch.stack(
                [embedding[key] for key in self._catalog_keys], dim=1
            )

        # B: batch size, L: sequence length for each catalog key
        B, L = mask.shape
        C = len(self._catalog_keys)
        encoded = self.mask_value * torch.ones(
            B, L, C, dtype=_encoded.dtype, device=_encoded.device
        )
        encoded[mask] = _encoded
        return encoded.reshape(B, -1)

    def _decode(self, z: Dict[str, Tensor]) -> Catalog:
        return Catalog(data=z)

    def decode(self, z: Float[Tensor, "b c1 *code_shape"]) -> Catalog:
        """Decodes a ``(B, L * C)`` tensor back into a catalog.

        Each catalog key is recovered as a ``(B, L)`` tensor. With a quantizer
        the codes are decoded through it; without one the values are returned
        as they are.
        """
        B, LC = z.shape
        C = len(self._catalog_keys)
        L = LC // C
        z = z[:, : C * L]  # Truncate the z if it is longer than the expected length
        z = z.reshape(B * L, C)

        quantizer = self.quantizer
        if quantizer is not None:
            decoded = quantizer.decode(z)
        else:
            # Identity mapping: column i of z holds the values of key i.
            decoded = {key: z[:, i] for i, key in enumerate(self._catalog_keys)}

        for key in self._catalog_keys:
            decoded[key] = decoded[key].reshape(B, L)
        return self._decode(decoded)

aion/codecs/quantizers/scalar.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Dict
3+
from collections import OrderedDict
34

45
import scipy.interpolate
56
import torch
@@ -504,3 +505,121 @@ def codebook(self) -> torch.Tensor:
504505
def embedding_dim(self) -> int:
505506
"""Returns the dimension of the codebook entries."""
506507
return 1
508+
509+
510+
class ComposedScalarQuantizer(Quantizer):
    """
    Composed scalar quantizer module.

    Combines multiple scalar quantizers into a single quantizer. Each quantizer
    operates on a different channel/feature and maintains its own codebook.
    Codes of the individual quantizers are shifted by a per-quantizer offset
    (the cumulative codebook size of the preceding quantizers) so that they
    occupy disjoint ranges of one shared code space.

    Args:
        quantizers: OrderedDict[str, Quantizer]
            Ordered dictionary mapping feature names to their respective quantizers.
    """

    def __init__(self, quantizers: OrderedDict[str, Quantizer]):
        super().__init__()
        # Running sum of codebook sizes: entry i is the first code of quantizer i.
        _offsets = [0]
        for quantizer in quantizers.values():
            _offsets.append(_offsets[-1] + quantizer.codebook_size)
        self.offsets = _offsets[:-1]
        self._codebook_size = _offsets[-1]
        self.quantizers = nn.ModuleDict(quantizers)

    def forward(
        self, z_es: Dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs a forward pass through every sub-quantizer.

        Args:
            z_es: Dict[str, torch.Tensor]
                Mapping from feature name to the tensor to be quantized.
        Returns:
            z_qs: torch.Tensor
                The quantized tensors stacked along dim 1.
            loss: torch.Tensor
                The embedding loss, averaged over the sub-quantizers.
            codebook_usage: torch.Tensor
                The fraction of codes used, averaged over the sub-quantizers.
        """
        z_qs = []
        losses = []
        usages = []
        for key, quantizer in self.quantizers.items():
            z_q, _loss, _usage = quantizer(z_es[key])
            z_qs.append(z_q)
            losses.append(_loss)
            usages.append(_usage)

        C = len(z_qs)
        z_qs = torch.stack(z_qs, dim=1)  # (B, C)
        # Accumulate out of place so the result follows the device and dtype
        # of the sub-quantizer outputs instead of being tied to a tensor
        # allocated on the CPU and mutated in place.
        loss = torch.as_tensor(sum(losses) / C)
        codebook_usage = torch.as_tensor(sum(usages) / C)
        return z_qs, loss, codebook_usage

    def quantize(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Quantize the input tensors in z, returns the corresponding
        codebook entries stacked along dim 1.
        """
        quantized = [
            quantizer.quantize(z[key]) for key, quantizer in self.quantizers.items()
        ]
        return torch.stack(quantized, dim=1)  # (B, C)

    def encode(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encodes the input tensors in z, returns the corresponding
        codebook indices.

        Args:
            z: Dict[str, torch.Tensor]
                Mapping from feature name to the tensor to be encoded.

        Returns:
            codes: torch.Tensor (B, C)
                Encoded tensor. Column i holds the codes of the i-th
                quantizer shifted by its offset.
        """
        codes = [
            quantizer.encode(z[key]) + offset
            for offset, (key, quantizer) in zip(self.offsets, self.quantizers.items())
        ]
        return torch.stack(codes, dim=1)  # (B, C)

    def decode(self, codes: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Decodes the input code index into corresponding codebook entry of
        dimension (embedding_dim).

        Args:
            codes: torch.Tensor (B, C)
                Codes to be decoded.

        Returns:
            z: Dict[str, torch.Tensor]
                Decoded sample, one entry per feature name. Positions whose
                code falls outside the quantizer's valid range are set to -1.
        """
        z = {}
        for i, (offset, (key, quantizer)) in enumerate(
            zip(self.offsets, self.quantizers.items())
        ):
            # Undo the offset applied in `encode` to recover local codes.
            codes_i = codes[:, i] - offset
            # clamp the codes to the valid range so that decoding cannot fail
            _codes_i = codes_i.clamp(0, quantizer.codebook_size - 1)
            decoded_i = quantizer.decode(_codes_i)
            # set the entries whose codes had to be clamped to -1
            is_clamped = _codes_i != codes_i
            decoded_i[is_clamped] = -1
            z[key] = decoded_i
        return z

    @property
    def codebook_size(self) -> int:
        """Returns the size of the codebook."""
        return self._codebook_size

    @property
    def embedding_dim(self) -> int:
        """Returns the dimension of the codebook entries."""
        return 1

aion/modalities.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import List, Union, ClassVar
44
from pydantic import BaseModel, Field, ConfigDict
5-
from jaxtyping import Float, Bool
5+
from jaxtyping import Float, Bool, Dict
66
from torch import Tensor
77

88

@@ -70,6 +70,18 @@ def __repr__(self) -> str:
7070
return f"{self.__class__.__name__}(shape={list(self.value.shape)})"
7171

7272

73+
# Catalog modality
74+
class Catalog(Modality):
    """Catalog modality data.

    Represents a catalog of scalar values, declared as a nested string-keyed
    mapping whose leaves are float tensors annotated with shape ``"b t"``.
    """

    # NOTE(review): `Dict` is imported from `jaxtyping` at the top of this
    # module; jaxtyping does not appear to export `Dict` -- confirm whether it
    # should be imported from `typing` instead.
    data: Dict[str, Dict[str, Float[Tensor, "b t"]]] = Field(
        description="Dictionary of dictionaries of scalar values."
    )
83+
84+
7385
# Flux measurements in different bands
7486
class FluxG(ScalarModality):
7587
"""G-band flux measurement."""
@@ -310,4 +322,4 @@ class XpRp(ScalarModality):
310322
]
311323

312324
# Convenience type for any modality data
313-
ModalityType = Union[Image, Spectrum, ScalarModality]
325+
ModalityType = Union[Image, Spectrum, ScalarModality, Catalog]

0 commit comments

Comments
 (0)