|
1 | 1 | import math |
2 | | -from typing import Optional |
| 2 | +from typing import Optional, Dict |
| 3 | +from collections import OrderedDict |
3 | 4 |
|
4 | 5 | import scipy.interpolate |
5 | 6 | import torch |
@@ -504,3 +505,121 @@ def codebook(self) -> torch.Tensor: |
504 | 505 | def embedding_dim(self) -> int: |
505 | 506 | """Returns the dimension of the codebook entries.""" |
506 | 507 | return 1 |
| 508 | + |
| 509 | + |
| 510 | +class ComposedScalarQuantizer(Quantizer): |
| 511 | + """ |
| 512 | + Composed scalar quantizer module. |
| 513 | +
|
| 514 | + Combines multiple scalar quantizers into a single quantizer. Each quantizer |
| 515 | + operates on a different channel/feature and maintains its own codebook. |
| 516 | +
|
| 517 | + Args: |
| 518 | + quantizers: OrderedDict[str, Quantizer] |
| 519 | + Ordered dictionary mapping feature names to their respective quantizers. |
| 520 | + """ |
| 521 | + |
| 522 | + def __init__(self, quantizers: OrderedDict[str, Quantizer]): |
| 523 | + super().__init__() |
| 524 | + _offsets = [0] |
| 525 | + for key, quantizer in quantizers.items(): |
| 526 | + _offsets.append(_offsets[-1] + quantizer.codebook_size) |
| 527 | + self.offsets = _offsets[:-1] |
| 528 | + self._codebook_size = _offsets[-1] |
| 529 | + self.quantizers = nn.ModuleDict(quantizers) |
| 530 | + |
| 531 | + def forward( |
| 532 | + self, z_es: Dict[str, torch.Tensor] |
| 533 | + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 534 | + """Performs a forward pass through the vector quantizer. |
| 535 | + Args: |
| 536 | + z_es: Dict[str, torch.Tensor] |
| 537 | + The input tensor to be quantized. |
| 538 | + Returns: |
| 539 | + z_qs: torch.Tensor |
| 540 | + The quantized tensor. |
| 541 | + loss: torch.Tensor |
| 542 | + The embedding loss for the quantization. |
| 543 | + codebook_usage: torch.Tensor |
| 544 | + The fraction of codes used in the codebook. |
| 545 | + """ |
| 546 | + z_qs = [] |
| 547 | + loss = torch.tensor(0.0) |
| 548 | + codebook_usage = torch.tensor(0.0) |
| 549 | + for key, quantizer in self.quantizers.items(): |
| 550 | + z_e = z_es[key] |
| 551 | + z_q, _loss, _usage = quantizer(z_e) |
| 552 | + z_qs.append(z_q) |
| 553 | + loss += _loss |
| 554 | + codebook_usage += _usage |
| 555 | + |
| 556 | + C = len(z_qs) |
| 557 | + z_qs = torch.stack(z_qs, dim=1) # (B, C) |
| 558 | + loss /= C |
| 559 | + codebook_usage /= C |
| 560 | + return z_qs, loss, codebook_usage |
| 561 | + |
| 562 | + def quantize(self, z: Dict[str, torch.Tensor]) -> torch.Tensor: |
| 563 | + """Quantize the input tensor z, returns corresponding |
| 564 | + codebook entry. |
| 565 | + """ |
| 566 | + quantized = [] |
| 567 | + for key, quantizer in self.quantizers.items(): |
| 568 | + quantized.append(quantizer.quantize(z[key])) |
| 569 | + |
| 570 | + quantized = torch.stack(quantized, dim=1) # (B, C) |
| 571 | + return quantized |
| 572 | + |
| 573 | + def encode(self, z: Dict[str, torch.Tensor]) -> torch.Tensor: |
| 574 | + """Encodes the input tensor z, returns the corresponding |
| 575 | + codebook index. |
| 576 | +
|
| 577 | + Args: |
| 578 | + z: Dict[str, torch.Tensor] |
| 579 | + The input tensor to be encoded. |
| 580 | +
|
| 581 | + Returns: |
| 582 | + codes: torch.Tensor (B, C) |
| 583 | + Encoded tensor. |
| 584 | + """ |
| 585 | + codes = [] |
| 586 | + for offset, (key, quantizer) in zip(self.offsets, self.quantizers.items()): |
| 587 | + codes.append((quantizer.encode(z[key]) + offset)) |
| 588 | + codes = torch.stack(codes, dim=1) # (B, C) |
| 589 | + return codes |
| 590 | + |
| 591 | + def decode(self, codes: torch.Tensor) -> Dict[str, torch.Tensor]: |
| 592 | + """Decodes the input code index into corresponding codebook entry of |
| 593 | + dimension (embedding_dim). |
| 594 | +
|
| 595 | + Args: |
| 596 | + codes: torch.Tensor (B, C) |
| 597 | + Codes to be decoded. |
| 598 | +
|
| 599 | + Returns: |
| 600 | + z: Dict[str, torch.Tensor] |
| 601 | + Decoded sample. |
| 602 | + """ |
| 603 | + z = {} |
| 604 | + for i, (offset, (key, quantizer)) in enumerate( |
| 605 | + zip(self.offsets, self.quantizers.items()) |
| 606 | + ): |
| 607 | + codes_i = codes[:, i] - offset |
| 608 | + # clamp the codes to the valid range |
| 609 | + _codes_i = codes_i.clamp(0, quantizer.codebook_size - 1) |
| 610 | + decoded_i = quantizer.decode(_codes_i) |
| 611 | + # set the clamped codes to -1 |
| 612 | + is_clamped = _codes_i != codes_i |
| 613 | + decoded_i[is_clamped] = -1 |
| 614 | + z[key] = decoded_i |
| 615 | + return z |
| 616 | + |
| 617 | + @property |
| 618 | + def codebook_size(self) -> int: |
| 619 | + """Returns the size of the codebook.""" |
| 620 | + return self._codebook_size |
| 621 | + |
| 622 | + @property |
| 623 | + def embedding_dim(self) -> int: |
| 624 | + """Returns the dimension of the codebook entries.""" |
| 625 | + return 1 |
0 commit comments