-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantizers.py
More file actions
101 lines (75 loc) · 2.99 KB
/
quantizers.py
File metadata and controls
101 lines (75 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
from vector_quantize_pytorch import FSQ, VectorQuantize
class DDCL_Bottleneck(nn.Module):
    """
    VAE bottleneck implementing the DDCL discretization procedure.

    Drop-in replacement for a VQ-VAE/FSQ codebook: forward() returns the
    same (quantized, indices, aux_loss) triple as the other wrappers in
    this module.
    """
    def __init__(self, delta, latent_dim=4):
        super().__init__()
        # Width of each quantization bin.
        self.delta = delta
        self.latent_dim = latent_dim

    def forward(self, z):
        """
        Dither-quantize `z` (used for BOTH training and inference).

        Returns:
            z_q: `z` plus uniform dither in [-delta/2, delta/2).
            bins: integer bin index of each dithered coordinate.
            rate: mean communication cost, log2(2|z|/delta + 1), per coord.
        """
        # Additive uniform dither — applied at inference as well as training.
        dither = (torch.rand_like(z) - 0.5) * self.delta
        z_q = z + dither
        # Which bin each dithered value lands in.
        bins = torch.floor(z_q / self.delta).long()
        # Bits needed to transmit each coordinate, averaged over all of them.
        rate = torch.log2(2 * z.abs() / self.delta + 1)
        return z_q, bins, rate.mean()
class FSQWrapper(nn.Module):
    """Adapter giving FSQ the (z_q, indices, loss) interface of DDCL_Bottleneck."""
    def __init__(self, levels):
        super().__init__()
        self.fsq = FSQ(levels=levels, channel_first=True)
        # Total codebook size is the product of per-dimension level counts.
        self.codebook_size = torch.tensor(levels).prod().item()

    def forward(self, z):
        """Quantize with FSQ; FSQ has no auxiliary loss, so the third slot is 0.0."""
        quantized, codes = self.fsq(z)
        return quantized, codes, 0.0
class VanillaVAE(nn.Module):
    """
    Standard Gaussian VAE bottleneck.

    Samples the latent via the reparameterization trick and returns the
    KL divergence to the standard normal prior as the regularization loss.
    """
    def __init__(self, latent_dim=4):
        super().__init__()
        self.latent_dim = latent_dim

    def forward(self, z):
        """
        Sample a latent from the encoder's Gaussian parameters.

        Args:
            z: Encoder output (batch, 2*latent_dim); first half is mu,
               second half is log-variance.

        Returns:
            sample: Reparameterized draw, shape (batch, latent_dim).
            None: No discrete indices for a continuous VAE.
            kl: Batch-mean KL(q(z|x) || N(0, I)) regularization loss.
        """
        d = self.latent_dim
        mu, logvar = z[:, :d], z[:, d:]
        # Reparameterization trick: sample = mu + sigma * eps, eps ~ N(0, I).
        sigma = torch.exp(0.5 * logvar)
        sample = mu + sigma * torch.randn_like(sigma)
        # KL to the unit Gaussian: -0.5 * sum(1 + log(s^2) - mu^2 - s^2).
        kl = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()
        return sample, None, kl
class VQVAEWrapper(nn.Module):
    """Adapter exposing VectorQuantize through the common (z_q, indices, loss) interface."""
    def __init__(self, codebook_size, latent_dim=4):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.vq = VectorQuantize(dim=latent_dim, codebook_size=codebook_size)

    def forward(self, z):
        # VectorQuantize already yields (quantized, indices, commitment_loss).
        quantized, codes, commit = self.vq(z)
        return quantized, codes, commit
class AEWrapper(nn.Module):
    """Identity bottleneck: a no-op 'quantizer' yielding a plain autoencoder."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pass through unchanged: no indices, no regularization loss.
        return x, None, 0.0