33
44import scipy .interpolate
55import torch
6+ import torch .nn as nn
67
78from aion .codecs .quantizers import Quantizer
89
@@ -11,7 +12,7 @@ class ScalarReservoirQuantizer(Quantizer):
1112 """
1213 Scalar quantizer module.
1314
14- The sclar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
15+ The scalar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
1516 The CDF estimate is updated using reservoir sampling, allowing you to stream through data.
1617
1718 Args:
@@ -182,7 +183,7 @@ class ScalarLogReservoirQuantizer(ScalarReservoirQuantizer):
182183 """
183184 Scalar quantizer module.
184185
185- The sclar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
186+ The scalar quantizer module takes a batch of scalars and quantizes them using a CDF codebook.
186187 The CDF estimate is updated using reservoir sampling, allowing you to stream through data.
187188
188189 Args:
@@ -271,3 +272,235 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:
271272 Decoded sample.
272273 """
273274 return torch .exp (super ().decode (codes ))
275+
276+
277+ class ScalarCompressedReservoirQuantizer (ScalarReservoirQuantizer ):
278+ """
279+ Scalar quantizer module with compression/decompression functions.
280+
281+ The scalar quantizer module takes a batch of scalars, applies compression functions,
282+ and quantizes them using a CDF codebook. The CDF estimate is updated using reservoir
283+ sampling, allowing you to stream through data.
284+
285+ Args:
286+ compression_fns: list[str]
287+ List of torch function names to apply for compression (e.g., ['arcsinh']).
288+ decompression_fns: list[str]
289+ List of torch function names to apply for decompression (e.g., ['sinh']).
290+ codebook_size: int
291+ The number of codes in the codebook.
292+ reservoir_size: int
293+ The size of the reservoir to keep in memory.
294+ reservoir_default: float
295+ Optional default value of reservoir samples. Only relevant if there
296+ are fewer samples in your dataset than the size of your codebook.
297+ """
298+
299+ def __init__ (
300+ self ,
301+ compression_fns : list [str ],
302+ decompression_fns : list [str ],
303+ codebook_size : int ,
304+ reservoir_size : int ,
305+ reservoir_default : Optional [float ] = 0.0 ,
306+ ):
307+ super ().__init__ (codebook_size , reservoir_size , reservoir_default )
308+ assert len (compression_fns ) == len (decompression_fns ), (
309+ "Mismatched compression/decompression functions"
310+ )
311+ self .compression_fns = compression_fns
312+ self .decompression_fns = decompression_fns
313+
314+ assert self ._check_identity (torch .tensor ([1.0 ])), (
315+ "Identity check failed, compression/decompression functions are not inverses."
316+ )
317+
318+ def compress (self , x : torch .Tensor ) -> torch .Tensor :
319+ """Apply compression functions to input tensor.
320+
321+ Args:
322+ x: torch.Tensor
323+ Input tensor to compress.
324+
325+ Returns:
326+ torch.Tensor
327+ Compressed tensor.
328+ """
329+ for c in self .compression_fns :
330+ x = getattr (torch , c )(x )
331+ return x
332+
333+ def decompress (self , x : torch .Tensor ) -> torch .Tensor :
334+ """Apply decompression functions to input tensor.
335+
336+ Args:
337+ x: torch.Tensor
338+ Input tensor to decompress.
339+
340+ Returns:
341+ torch.Tensor
342+ Decompressed tensor.
343+ """
344+ for c in self .decompression_fns [::- 1 ]:
345+ x = getattr (torch , c )(x )
346+ return x
347+
348+ def _check_identity (self , x : torch .Tensor ) -> bool :
349+ """Check if compression and decompression are inverses.
350+
351+ Args:
352+ x: torch.Tensor
353+ Test tensor.
354+
355+ Returns:
356+ bool
357+ True if compress(decompress(x)) ≈ x.
358+ """
359+ return torch .allclose (self .decompress (self .compress (x )), x )
360+
361+ def _update_reservoirs (self , z_e : torch .Tensor ):
362+ z_e = self .compress (z_e )
363+ super ()._update_reservoirs (z_e )
364+
365+ def encode (self , z : torch .Tensor ) -> torch .Tensor :
366+ z = self .compress (z )
367+ return super ().encode (z )
368+
369+ def decode (self , codes : torch .Tensor ) -> torch .Tensor :
370+ return self .decompress (super ().decode (codes ))
371+
372+
373+ class MultiScalarCompressedReservoirQuantizer (Quantizer ):
374+ """
375+ Multi-channel scalar quantizer with compression.
376+
377+ Wraps multiple ScalarCompressedReservoirQuantizers to quantize multi-channel tensors.
378+ Each channel is quantized independently with its own reservoir.
379+
380+ Args:
381+ compression_fns: list[str]
382+ List of torch function names to apply for compression (e.g., ['arcsinh']).
383+ decompression_fns: list[str]
384+ List of torch function names to apply for decompression (e.g., ['sinh']).
385+ codebook_size: int
386+ The number of codes in the codebook.
387+ reservoir_size: int
388+ The size of the reservoir to keep in memory.
389+ reservoir_default: float
390+ Optional default value of reservoir samples.
391+ num_quantizers: int
392+ Number of channels/quantizers to create.
393+ """
394+
395+ def __init__ (
396+ self ,
397+ compression_fns : list [str ],
398+ decompression_fns : list [str ],
399+ codebook_size : int ,
400+ reservoir_size : int ,
401+ reservoir_default : Optional [float ] = 0.0 ,
402+ num_quantizers : int = 1 ,
403+ ):
404+ super ().__init__ ()
405+ self .quantizers = nn .ModuleList (
406+ [
407+ ScalarCompressedReservoirQuantizer (
408+ compression_fns ,
409+ decompression_fns ,
410+ codebook_size ,
411+ reservoir_size ,
412+ reservoir_default ,
413+ )
414+ for _ in range (num_quantizers )
415+ ]
416+ )
417+ self .num_quantizers = num_quantizers
418+
419+ def encode (self , z : torch .Tensor ) -> torch .Tensor :
420+ """Encodes the input tensor z, returns the corresponding
421+ codebook index.
422+
423+ Args:
424+ z: torch.Tensor (B, C)
425+ The input tensor to be encoded.
426+
427+ Returns:
428+ codes: torch.Tensor (B, C)
429+ Encoded tensor.
430+ """
431+ return torch .stack (
432+ [q .encode (z [:, i ]) for i , q in enumerate (self .quantizers )],
433+ dim = 1 ,
434+ )
435+
436+ def decode (self , codes : torch .Tensor ) -> torch .Tensor :
437+ """Decodes the input code index into corresponding codebook entry of
438+ dimension (embedding_dim).
439+
440+ Args:
441+ codes: torch.Tensor (B, C)
442+ Codes to be decoded.
443+
444+ Returns:
445+ z: torch.Tensor (B, C)
446+ Decoded sample.
447+ """
448+ return torch .stack (
449+ [q .decode (codes [:, i ]) for i , q in enumerate (self .quantizers )],
450+ dim = 1 ,
451+ )
452+
453+ def quantize (self , z : torch .Tensor ) -> torch .Tensor :
454+ """Quantize the input tensor z, returns corresponding
455+ codebook entry.
456+
457+ Args:
458+ z: torch.Tensor (B, C)
459+ The input tensor to be quantized.
460+
461+ Returns:
462+ z: torch.Tensor (B, C)
463+ Quantized tensor.
464+ """
465+ return self .decode (self .encode (z ))
466+
467+ def _update_reservoirs (self , z_e : torch .Tensor ):
468+ for i , q in enumerate (self .quantizers ):
469+ q ._update_reservoirs (z_e [:, i ])
470+
471+ def forward (
472+ self , z_e : torch .Tensor
473+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
474+ """Performs a forward pass through the vector quantizer.
475+ Args:
476+ z_e: torch.Tensor (B, C, ...)
477+ The input tensor to be quantized.
478+ Returns:
479+ z_q: torch.Tensor
480+ The quantized tensor.
481+ loss: torch.Tensor
482+ The embedding loss for the quantization.
483+ codebook_usage: torch.Tensor
484+ The fraction of codes used in the codebook.
485+ """
486+ self ._update_reservoirs (z_e )
487+ indices = self .encode (z_e )
488+ z_q = self .decode (indices )
489+ num_unique = sum ([len (torch .unique (c )) for c in indices .T ])
490+ codebook_usage = num_unique / (self .codebook_size * self .num_quantizers )
491+ return z_q , torch .nn .functional .mse_loss (z_q , z_e ), torch .tensor (codebook_usage )
492+
493+ @property
494+ def codebook_size (self ) -> int :
495+ """Returns the size of the codebook."""
496+ return self .quantizers [0 ].codebook_size
497+
498+ @property
499+ def codebook (self ) -> torch .Tensor :
500+ """Returns the codebook."""
501+ return self .quantizers [0 ].codebook
502+
503+ @property
504+ def embedding_dim (self ) -> int :
505+ """Returns the dimension of the codebook entries."""
506+ return 1
0 commit comments