Skip to content

Commit 46014a0

Browse files
committed
Upload scalar tokenizer to HF
1 parent 223b12f commit 46014a0

2 files changed

Lines changed: 31 additions & 14 deletions

File tree

aion/codecs/tokenizers/scalar.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from typing import Dict, Optional
22

33
import torch
4+
from huggingface_hub import PyTorchModelHubMixin
45
from jaxtyping import Float
56
from torch import Tensor
67

78
from aion.codecs.quantizers import Quantizer
9+
from aion.codecs.quantizers.scalar import (
10+
ScalarLogReservoirQuantizer,
11+
ScalarReservoirQuantizer,
12+
)
813
from aion.codecs.tokenizers.base import QuantizedCodec
914

1015

11-
class ScalarIdentityCodec(QuantizedCodec):
16+
class BaseScalarIdentityCodec(QuantizedCodec, PyTorchModelHubMixin):
1217
"""Codec for scalar quantities.
1318
1419
A codec that embeds scalar quantities through an identity mapping. A
@@ -34,3 +39,21 @@ def _encode(self, x: Dict[str, Dict[str, Float[Tensor, "b t"]]]) -> Tensor:
3439

3540
def _decode(self, z: torch.FloatTensor) -> Dict[str, torch.FloatTensor]:
3641
return {self.modality: z}
42+
43+
44+
class ScalarReservoirCodec(BaseScalarIdentityCodec):
    """Scalar identity codec backed by a ``ScalarReservoirQuantizer``.

    Convenience subclass that constructs its own quantizer from plain
    hyperparameters, so the codec can round-trip through the
    ``PyTorchModelHubMixin`` config machinery (``from_pretrained`` /
    ``push_to_hub``) inherited via ``BaseScalarIdentityCodec`` without
    callers having to build a quantizer instance themselves.

    Args:
        modality: Name of the scalar modality this codec handles.
        codebook_size: Number of quantization bins in the codebook.
        reservoir_size: Capacity of the quantizer's sample reservoir.
    """

    def __init__(self, modality: str, codebook_size: int, reservoir_size: int):
        quantizer = ScalarReservoirQuantizer(
            codebook_size=codebook_size,
            reservoir_size=reservoir_size,
        )
        super().__init__(modality, quantizer)
51+
52+
53+
class ScalarLogReservoirCodec(BaseScalarIdentityCodec):
    """Scalar identity codec backed by a ``ScalarLogReservoirQuantizer``.

    Same construction pattern as ``ScalarReservoirCodec`` but using the
    log-space quantizer variant (see ``ScalarLogReservoirQuantizer``),
    so the codec is loadable/savable through the inherited
    ``PyTorchModelHubMixin`` interface from plain keyword hyperparameters.

    Args:
        modality: Name of the scalar modality this codec handles.
        codebook_size: Number of quantization bins in the codebook.
        reservoir_size: Capacity of the quantizer's sample reservoir.
    """

    def __init__(self, modality: str, codebook_size: int, reservoir_size: int):
        quantizer = ScalarLogReservoirQuantizer(
            codebook_size=codebook_size,
            reservoir_size=reservoir_size,
        )
        super().__init__(modality, quantizer)

tests/tokenizers/test_scalar_tokenizer.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import pytest
22
import torch
33

4-
from aion.codecs.quantizers.scalar import (
5-
ScalarLogReservoirQuantizer,
6-
ScalarReservoirQuantizer,
4+
from aion.codecs.tokenizers.scalar import (
5+
ScalarLogReservoirCodec,
6+
ScalarReservoirCodec,
77
)
8-
from aion.codecs.tokenizers.scalar import ScalarIdentityCodec
98

109

1110
@pytest.mark.parametrize(
@@ -23,14 +22,10 @@
2322
],
2423
)
2524
def test_log_reservoir_tokenizer(data_dir, modality):
26-
codec = ScalarIdentityCodec(
27-
modality=modality,
28-
quantizer=ScalarLogReservoirQuantizer(
29-
codebook_size=1024, reservoir_size=100000
30-
),
25+
codec = ScalarLogReservoirCodec.from_pretrained(
26+
f"polymathic-ai/aion-scalar-{modality.lower().replace('_', '-')}-codec"
3127
)
3228
codec.eval()
33-
codec.load_state_dict(torch.load(data_dir / f"{modality}_codec.pt"))
3429

3530
input_batch = torch.load(data_dir / f"{modality}_input.pt")
3631
output_batch = torch.load(data_dir / f"{modality}_output.pt")
@@ -42,9 +37,8 @@ def test_log_reservoir_tokenizer(data_dir, modality):
4237

4338
@pytest.mark.parametrize("modality", ["SHAPE_E1", "SHAPE_E2", "EBV"])
4439
def test_reservoir_tokenizer(data_dir, modality):
45-
codec = ScalarIdentityCodec(
46-
modality=modality,
47-
quantizer=ScalarReservoirQuantizer(codebook_size=1024, reservoir_size=100000),
40+
codec = ScalarReservoirCodec.from_pretrained(
41+
f"polymathic-ai/aion-scalar-{modality.lower().replace('_', '-')}-codec"
4842
)
4943
codec.eval()
5044
codec.load_state_dict(torch.load(data_dir / f"{modality}_codec.pt"))

0 commit comments

Comments
 (0)