Skip to content

Commit 01440eb

Browse files
authored
Merge pull request #18 from PolymathicAI/gaia
Adding codecs for Gaia
2 parents a01a9ac + 3b5e875 commit 01440eb

31 files changed

Lines changed: 916 additions & 4 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,5 @@ cython_debug/
172172

173173
# PyPI configuration file
174174
.pypirc
175+
notebooks
176+
data

aion/codecs/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from .image import ImageCodec
2-
from .scalar import ScalarCodec, LogScalarCodec
2+
from .scalar import ScalarCodec, LogScalarCodec, MultiScalarCodec
33
from .spectrum import SpectrumCodec
44

55

6-
__all__ = ["ImageCodec", "ScalarCodec", "LogScalarCodec", "SpectrumCodec"]
6+
__all__ = [
7+
"ImageCodec",
8+
"ScalarCodec",
9+
"LogScalarCodec",
10+
"MultiScalarCodec",
11+
"SpectrumCodec",
12+
]

aion/codecs/quantizers/scalar.py

Lines changed: 235 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import scipy.interpolate
55
import torch
6+
import torch.nn as nn
67

78
from aion.codecs.quantizers import Quantizer
89

@@ -11,7 +12,7 @@ class ScalarReservoirQuantizer(Quantizer):
1112
"""
1213
Scalar quantizer module.
1314
14-
The sclar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
15+
The scalar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
1516
The CDF estimate is updated using reservoir sampling, allowing you to stream through data.
1617
1718
Args:
@@ -182,7 +183,7 @@ class ScalarLogReservoirQuantizer(ScalarReservoirQuantizer):
182183
"""
183184
Scalar quantizer module.
184185
185-
The sclar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
186+
The scalar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
186187
The CDF estimate is updated using reservoir sampling, allowing you to stream through data.
187188
188189
Args:
@@ -271,3 +272,235 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:
271272
Decoded sample.
272273
"""
273274
return torch.exp(super().decode(codes))
275+
276+
277+
class ScalarCompressedReservoirQuantizer(ScalarReservoirQuantizer):
278+
"""
279+
Scalar quantizer module with compression/decompression functions.
280+
281+
The scalar quantizer module takes a batch of scalars, applies compression functions,
282+
and quantizes them using a CDF codebook. The CDF estimate is updated using reservoir
283+
sampling, allowing you to stream through data.
284+
285+
Args:
286+
compression_fns: list[str]
287+
List of torch function names to apply for compression (e.g., ['arcsinh']).
288+
decompression_fns: list[str]
289+
List of torch function names to apply for decompression (e.g., ['sinh']).
290+
codebook_size: int
291+
The number of codes in the codebook.
292+
reservoir_size: int
293+
The size of the reservoir to keep in memory.
294+
reservoir_default: float
295+
Optional default value of reservoir samples. Only relevant if there
296+
are fewer samples in your dataset than the size of your codebook.
297+
"""
298+
299+
def __init__(
300+
self,
301+
compression_fns: list[str],
302+
decompression_fns: list[str],
303+
codebook_size: int,
304+
reservoir_size: int,
305+
reservoir_default: Optional[float] = 0.0,
306+
):
307+
super().__init__(codebook_size, reservoir_size, reservoir_default)
308+
assert len(compression_fns) == len(decompression_fns), (
309+
"Mismatched compression/decompression functions"
310+
)
311+
self.compression_fns = compression_fns
312+
self.decompression_fns = decompression_fns
313+
314+
assert self._check_identity(torch.tensor([1.0])), (
315+
"Identity check failed, compression/decompression functions are not inverses."
316+
)
317+
318+
def compress(self, x: torch.Tensor) -> torch.Tensor:
319+
"""Apply compression functions to input tensor.
320+
321+
Args:
322+
x: torch.Tensor
323+
Input tensor to compress.
324+
325+
Returns:
326+
torch.Tensor
327+
Compressed tensor.
328+
"""
329+
for c in self.compression_fns:
330+
x = getattr(torch, c)(x)
331+
return x
332+
333+
def decompress(self, x: torch.Tensor) -> torch.Tensor:
334+
"""Apply decompression functions to input tensor.
335+
336+
Args:
337+
x: torch.Tensor
338+
Input tensor to decompress.
339+
340+
Returns:
341+
torch.Tensor
342+
Decompressed tensor.
343+
"""
344+
for c in self.decompression_fns[::-1]:
345+
x = getattr(torch, c)(x)
346+
return x
347+
348+
def _check_identity(self, x: torch.Tensor) -> bool:
349+
"""Check if compression and decompression are inverses.
350+
351+
Args:
352+
x: torch.Tensor
353+
Test tensor.
354+
355+
Returns:
356+
bool
357+
True if compress(decompress(x)) ≈ x.
358+
"""
359+
return torch.allclose(self.decompress(self.compress(x)), x)
360+
361+
def _update_reservoirs(self, z_e: torch.Tensor):
362+
z_e = self.compress(z_e)
363+
super()._update_reservoirs(z_e)
364+
365+
def encode(self, z: torch.Tensor) -> torch.Tensor:
366+
z = self.compress(z)
367+
return super().encode(z)
368+
369+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
370+
return self.decompress(super().decode(codes))
371+
372+
373+
class MultiScalarCompressedReservoirQuantizer(Quantizer):
374+
"""
375+
Multi-channel scalar quantizer with compression.
376+
377+
Wraps multiple ScalarCompressedReservoirQuantizers to quantize multi-channel tensors.
378+
Each channel is quantized independently with its own reservoir.
379+
380+
Args:
381+
compression_fns: list[str]
382+
List of torch function names to apply for compression (e.g., ['arcsinh']).
383+
decompression_fns: list[str]
384+
List of torch function names to apply for decompression (e.g., ['sinh']).
385+
codebook_size: int
386+
The number of codes in the codebook.
387+
reservoir_size: int
388+
The size of the reservoir to keep in memory.
389+
reservoir_default: float
390+
Optional default value of reservoir samples.
391+
num_quantizers: int
392+
Number of channels/quantizers to create.
393+
"""
394+
395+
def __init__(
396+
self,
397+
compression_fns: list[str],
398+
decompression_fns: list[str],
399+
codebook_size: int,
400+
reservoir_size: int,
401+
reservoir_default: Optional[float] = 0.0,
402+
num_quantizers: int = 1,
403+
):
404+
super().__init__()
405+
self.quantizers = nn.ModuleList(
406+
[
407+
ScalarCompressedReservoirQuantizer(
408+
compression_fns,
409+
decompression_fns,
410+
codebook_size,
411+
reservoir_size,
412+
reservoir_default,
413+
)
414+
for _ in range(num_quantizers)
415+
]
416+
)
417+
self.num_quantizers = num_quantizers
418+
419+
def encode(self, z: torch.Tensor) -> torch.Tensor:
420+
"""Encodes the input tensor z, returns the corresponding
421+
codebook index.
422+
423+
Args:
424+
z: torch.Tensor (B, C)
425+
The input tensor to be encoded.
426+
427+
Returns:
428+
codes: torch.Tensor (B, C)
429+
Encoded tensor.
430+
"""
431+
return torch.stack(
432+
[q.encode(z[:, i]) for i, q in enumerate(self.quantizers)],
433+
dim=1,
434+
)
435+
436+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
437+
"""Decodes the input code index into corresponding codebook entry of
438+
dimension (embedding_dim).
439+
440+
Args:
441+
codes: torch.Tensor (B, C)
442+
Codes to be decoded.
443+
444+
Returns:
445+
z: torch.Tensor (B, C)
446+
Decoded sample.
447+
"""
448+
return torch.stack(
449+
[q.decode(codes[:, i]) for i, q in enumerate(self.quantizers)],
450+
dim=1,
451+
)
452+
453+
def quantize(self, z: torch.Tensor) -> torch.Tensor:
454+
"""Quantize the input tensor z, returns corresponding
455+
codebook entry.
456+
457+
Args:
458+
z: torch.Tensor (B, C)
459+
The input tensor to be quantized.
460+
461+
Returns:
462+
z: torch.Tensor (B, C)
463+
Quantized tensor.
464+
"""
465+
return self.decode(self.encode(z))
466+
467+
def _update_reservoirs(self, z_e: torch.Tensor):
468+
for i, q in enumerate(self.quantizers):
469+
q._update_reservoirs(z_e[:, i])
470+
471+
def forward(
472+
self, z_e: torch.Tensor
473+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
474+
"""Performs a forward pass through the vector quantizer.
475+
Args:
476+
z_e: torch.Tensor (B, C, ...)
477+
The input tensor to be quantized.
478+
Returns:
479+
z_q: torch.Tensor
480+
The quantized tensor.
481+
loss: torch.Tensor
482+
The embedding loss for the quantization.
483+
codebook_usage: torch.Tensor
484+
The fraction of codes used in the codebook.
485+
"""
486+
self._update_reservoirs(z_e)
487+
indices = self.encode(z_e)
488+
z_q = self.decode(indices)
489+
num_unique = sum([len(torch.unique(c)) for c in indices.T])
490+
codebook_usage = num_unique / (self.codebook_size * self.num_quantizers)
491+
return z_q, torch.nn.functional.mse_loss(z_q, z_e), torch.tensor(codebook_usage)
492+
493+
@property
494+
def codebook_size(self) -> int:
495+
"""Returns the size of the codebook."""
496+
return self.quantizers[0].codebook_size
497+
498+
@property
499+
def codebook(self) -> torch.Tensor:
500+
"""Returns the codebook."""
501+
return self.quantizers[0].codebook
502+
503+
@property
504+
def embedding_dim(self) -> int:
505+
"""Returns the dimension of the codebook entries."""
506+
return 1

aion/codecs/scalar.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aion.codecs.quantizers.scalar import (
99
ScalarLogReservoirQuantizer,
1010
ScalarReservoirQuantizer,
11+
MultiScalarCompressedReservoirQuantizer,
1112
)
1213
from aion.codecs.base import Codec
1314
from aion.modalities import ScalarModality, ScalarModalities
@@ -85,3 +86,70 @@ def __init__(
8586
reservoir_size=reservoir_size,
8687
min_log_value=min_log_value,
8788
)
89+
90+
91+
class MultiScalarCodec(BaseScalarIdentityCodec):
92+
"""Codec for multi-channel scalar quantities with compression.
93+
94+
A codec that handles multi-channel scalar modalities using compression
95+
and decompression functions before quantization. This is particularly useful
96+
for spectral coefficients or other multi-dimensional scalar data that
97+
benefits from preprocessing transformations.
98+
99+
Each channel is quantized independently using a compressed reservoir quantizer,
100+
allowing for different statistical distributions across channels while
101+
maintaining the ability to handle streaming data.
102+
103+
Args:
104+
modality: str
105+
The name of the modality this codec is designed for. Must match
106+
a modality name defined in the ScalarModalities registry.
107+
compression_fns: list[str]
108+
List of PyTorch function names to apply for compression (e.g., ['arcsinh']).
109+
These functions are applied in order to transform the data before quantization.
110+
decompression_fns: list[str]
111+
List of PyTorch function names to apply for decompression (e.g., ['sinh']).
112+
These functions are applied in reverse order during decoding to restore
113+
the original data range.
114+
codebook_size: int
115+
The number of codes in each quantizer's codebook.
116+
reservoir_size: int
117+
The size of the reservoir to keep in memory for each channel's quantizer.
118+
num_quantizers: int
119+
Number of channels/quantizers to create, corresponding to the number
120+
of dimensions in the multi-channel scalar data.
121+
122+
Note:
123+
The compression and decompression functions must be mathematical inverses
124+
of each other. The codec will verify this during initialization and raise
125+
an assertion error if the functions are not properly inverse.
126+
127+
Example:
128+
>>> codec = MultiScalarCodec(
129+
... modality="bp_coefficients",
130+
... compression_fns=["arcsinh"],
131+
... decompression_fns=["sinh"],
132+
... codebook_size=1024,
133+
... reservoir_size=10000,
134+
... num_quantizers=55
135+
... )
136+
"""
137+
138+
def __init__(
139+
self,
140+
modality: str,
141+
compression_fns: list[str],
142+
decompression_fns: list[str],
143+
codebook_size: int,
144+
reservoir_size: int,
145+
num_quantizers: int,
146+
):
147+
super().__init__()
148+
self._modality_class = next(m for m in ScalarModalities if m.name == modality)
149+
self._quantizer = MultiScalarCompressedReservoirQuantizer(
150+
compression_fns=compression_fns,
151+
decompression_fns=decompression_fns,
152+
codebook_size=codebook_size,
153+
reservoir_size=reservoir_size,
154+
num_quantizers=num_quantizers,
155+
)

0 commit comments

Comments
 (0)