Skip to content

Commit e7ff7d7

Browse files
committed
take care of a small pain point
1 parent 35c4c84 commit e7ff7d7

File tree

6 files changed

+126
-110
lines changed

6 files changed

+126
-110
lines changed

examples/autoencoder.py

+20-35
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,29 @@
88
from torchvision import datasets, transforms
99
from torch.utils.data import DataLoader
1010

11-
from vector_quantize_pytorch import VectorQuantize
12-
11+
from vector_quantize_pytorch import VectorQuantize, Sequential
1312

1413
lr = 3e-4
1514
train_iter = 1000
1615
num_codes = 256
1716
seed = 1234
17+
rotation_trick = True
1818
device = "cuda" if torch.cuda.is_available() else "cpu"
1919

20-
21-
class SimpleVQAutoEncoder(nn.Module):
22-
def __init__(self, **vq_kwargs):
23-
super().__init__()
24-
self.layers = nn.ModuleList(
25-
[
26-
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
27-
nn.MaxPool2d(kernel_size=2, stride=2),
28-
nn.GELU(),
29-
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
30-
nn.MaxPool2d(kernel_size=2, stride=2),
31-
VectorQuantize(dim=32, accept_image_fmap = True, **vq_kwargs),
32-
nn.Upsample(scale_factor=2, mode="nearest"),
33-
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
34-
nn.GELU(),
35-
nn.Upsample(scale_factor=2, mode="nearest"),
36-
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
37-
]
38-
)
39-
return
40-
41-
def forward(self, x):
42-
for layer in self.layers:
43-
if isinstance(layer, VectorQuantize):
44-
x, indices, commit_loss = layer(x)
45-
else:
46-
x = layer(x)
47-
48-
return x.clamp(-1, 1), indices, commit_loss
49-
20+
def SimpleVQAutoEncoder(**vq_kwargs):
    """Build the example convolutional autoencoder with a VectorQuantize bottleneck.

    Every keyword argument is forwarded unchanged to ``VectorQuantize``
    (on top of the fixed ``dim=32`` and ``accept_image_fmap=True``).

    Returns the library's ``Sequential`` wrapping encoder, quantizer and decoder.
    """
    # modules are created in the same order as they are chained, so seeded
    # parameter initialisation is consumed in forward order

    # encoder: 1 -> 16 -> 32 channels with two 2x max-pool stages
    encoder = [
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]

    # bottleneck: quantizes the 32-channel feature map
    quantizer = VectorQuantize(dim=32, accept_image_fmap=True, **vq_kwargs)

    # decoder: 32 -> 16 -> 1 channels with two 2x nearest-neighbour upsamples
    decoder = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    ]

    return Sequential(*encoder, quantizer, *decoder)
5034

5135
def train(model, train_loader, train_iterations=1000, alpha=10):
5236
def iterate_dataset(data_loader):
@@ -62,7 +46,10 @@ def iterate_dataset(data_loader):
6246
for _ in (pbar := trange(train_iterations)):
6347
opt.zero_grad()
6448
x, _ = next(iterate_dataset(train_loader))
49+
6550
out, indices, cmt_loss = model(x)
51+
out = out.clamp(-1., 1.)
52+
6653
rec_loss = (out - x).abs().mean()
6754
(rec_loss + alpha * cmt_loss).backward()
6855

@@ -72,8 +59,6 @@ def iterate_dataset(data_loader):
7259
+ f"cmt loss: {cmt_loss.item():.3f} | "
7360
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
7461
)
75-
return
76-
7762

7863
transform = transforms.Compose(
7964
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
@@ -91,7 +76,7 @@ def iterate_dataset(data_loader):
9176

9277
model = SimpleVQAutoEncoder(
9378
codebook_size=num_codes,
94-
rotation_trick=True
79+
rotation_trick=rotation_trick
9580
).to(device)
9681

9782
opt = torch.optim.AdamW(model.parameters(), lr=lr)

examples/autoencoder_fsq.py

+19-32
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchvision import datasets, transforms
1010
from torch.utils.data import DataLoader
1111

12-
from vector_quantize_pytorch import FSQ
12+
from vector_quantize_pytorch import FSQ, Sequential
1313

1414

1515
lr = 3e-4
@@ -20,36 +20,22 @@
2020
device = "cuda" if torch.cuda.is_available() else "cpu"
2121

2222

23-
class SimpleFSQAutoEncoder(nn.Module):
24-
def __init__(self, levels: list[int]):
25-
super().__init__()
26-
self.layers = nn.ModuleList(
27-
[
28-
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
29-
nn.MaxPool2d(kernel_size=2, stride=2),
30-
nn.GELU(),
31-
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
32-
nn.MaxPool2d(kernel_size=2, stride=2),
33-
nn.Conv2d(32, len(levels), kernel_size=1),
34-
FSQ(levels),
35-
nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1),
36-
nn.Upsample(scale_factor=2, mode="nearest"),
37-
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
38-
nn.GELU(),
39-
nn.Upsample(scale_factor=2, mode="nearest"),
40-
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
41-
]
42-
)
43-
return
44-
45-
def forward(self, x):
46-
for layer in self.layers:
47-
if isinstance(layer, FSQ):
48-
x, indices = layer(x)
49-
else:
50-
x = layer(x)
51-
52-
return x.clamp(-1, 1), indices
23+
def SimpleFSQAutoEncoder(levels: list[int]):
    """Build the example convolutional autoencoder with an FSQ bottleneck.

    Args:
        levels: passed straight to ``FSQ``; its length is also used as the
            channel count going into and out of the quantizer.

    Returns the library's ``Sequential`` wrapping encoder, quantizer and decoder.
    """
    code_dim = len(levels)

    # modules are created in the same order as they are chained, so seeded
    # parameter initialisation is consumed in forward order

    # encoder: 1 -> 16 -> 32 channels, then a 1x1 conv down to `code_dim`
    encoder = [
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, code_dim, kernel_size=1),
    ]

    quantizer = FSQ(levels)

    # decoder: `code_dim` -> 32 -> 16 -> 1 channels with two 2x upsamples
    decoder = [
        nn.Conv2d(code_dim, 32, kernel_size=3, stride=1, padding=1),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    ]

    return Sequential(*encoder, quantizer, *decoder)
5339

5440

5541
def train(model, train_loader, train_iterations=1000):
@@ -67,6 +53,8 @@ def iterate_dataset(data_loader):
6753
opt.zero_grad()
6854
x, _ = next(iterate_dataset(train_loader))
6955
out, indices = model(x)
56+
out = out.clamp(-1., 1.)
57+
7058
rec_loss = (out - x).abs().mean()
7159
rec_loss.backward()
7260

@@ -75,7 +63,6 @@ def iterate_dataset(data_loader):
7563
f"rec loss: {rec_loss.item():.3f} | "
7664
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
7765
)
78-
return
7966

8067

8168
transform = transforms.Compose(

examples/autoencoder_lfq.py

+27-42
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchvision import datasets, transforms
1111
from torch.utils.data import DataLoader
1212

13-
from vector_quantize_pytorch import LFQ
13+
from vector_quantize_pytorch import LFQ, Sequential
1414

1515
lr = 3e-4
1616
train_iter = 1000
@@ -22,46 +22,31 @@
2222

2323
device = "cuda" if torch.cuda.is_available() else "cpu"
2424

25-
class LFQAutoEncoder(nn.Module):
26-
def __init__(
27-
self,
28-
codebook_size,
29-
**vq_kwargs
30-
):
31-
super().__init__()
32-
assert log2(codebook_size).is_integer()
33-
quantize_dim = int(log2(codebook_size))
34-
35-
self.encode = nn.Sequential(
36-
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
37-
nn.MaxPool2d(kernel_size=2, stride=2),
38-
nn.GELU(),
39-
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
40-
nn.MaxPool2d(kernel_size=2, stride=2),
41-
# In general norm layers are commonly used in Resnet-based encoder/decoders
42-
# explicitly add one here with affine=False to avoid introducing new parameters
43-
nn.GroupNorm(4, 32, affine=False),
44-
nn.Conv2d(32, quantize_dim, kernel_size=1),
45-
)
46-
47-
self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)
48-
49-
self.decode = nn.Sequential(
50-
nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
51-
nn.Upsample(scale_factor=2, mode="nearest"),
52-
nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
53-
nn.GELU(),
54-
nn.Upsample(scale_factor=2, mode="nearest"),
55-
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
56-
)
57-
return
58-
59-
def forward(self, x):
60-
x = self.encode(x)
61-
x, indices, entropy_aux_loss = self.quantize(x)
62-
x = self.decode(x)
63-
return x.clamp(-1, 1), indices, entropy_aux_loss
64-
25+
def LFQAutoEncoder(
    codebook_size,
    **vq_kwargs
):
    """Build the example convolutional autoencoder with an LFQ bottleneck.

    Args:
        codebook_size: must be a power of two; its base-2 logarithm is used as
            the channel count at the quantizer.
        **vq_kwargs: forwarded unchanged to ``LFQ`` (alongside ``dim``).

    Returns the library's ``Sequential`` wrapping encoder, quantizer and decoder.
    """
    num_bits = log2(codebook_size)
    assert num_bits.is_integer()
    quantize_dim = int(num_bits)

    # modules are created in the same order as they are chained, so seeded
    # parameter initialisation is consumed in forward order

    encoder = [
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # In general norm layers are commonly used in Resnet-based encoder/decoders
        # explicitly add one here with affine=False to avoid introducing new parameters
        nn.GroupNorm(4, 32, affine=False),
        nn.Conv2d(32, quantize_dim, kernel_size=1),
    ]

    quantizer = LFQ(dim=quantize_dim, **vq_kwargs)

    decoder = [
        nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    ]

    return Sequential(*encoder, quantizer, *decoder)
6550

6651
def train(model, train_loader, train_iterations=1000):
6752
def iterate_dataset(data_loader):
@@ -78,6 +63,7 @@ def iterate_dataset(data_loader):
7863
opt.zero_grad()
7964
x, _ = next(iterate_dataset(train_loader))
8065
out, indices, entropy_aux_loss = model(x)
66+
out = out.clamp(-1., 1.)
8167

8268
rec_loss = F.l1_loss(out, x)
8369
(rec_loss + entropy_aux_loss).backward()
@@ -88,7 +74,6 @@ def iterate_dataset(data_loader):
8874
+ f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
8975
+ f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
9076
)
91-
return
9277

9378
transform = transforms.Compose(
9479
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.19.3"
3+
version = "1.19.4"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@
66
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
77
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
88
from vector_quantize_pytorch.latent_quantization import LatentQuantize
9+
10+
from vector_quantize_pytorch.utils import Sequential

vector_quantize_pytorch/utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import torch
2+
from torch import nn
3+
from torch.nn import Module, ModuleList
4+
5+
# quantization
6+
7+
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
8+
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
9+
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
10+
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
11+
from vector_quantize_pytorch.lookup_free_quantization import LFQ
12+
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
13+
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
14+
from vector_quantize_pytorch.latent_quantization import LatentQuantize
15+
16+
# quantizer classes that `Sequential` recognises via isinstance checks,
# grouped by family (order is unchanged)
QUANTIZE_KLASSES = (
    VectorQuantize, ResidualVQ, GroupedResidualVQ,
    RandomProjectionQuantizer,
    FSQ,
    LFQ, ResidualLFQ, GroupedResidualLFQ,
    ResidualFSQ, GroupedResidualFSQ,
    LatentQuantize,
)
29+
30+
# classes
31+
32+
class Sequential(Module):
    """Sequential container that holds exactly one quantizer.

    Non-quantizer modules are applied as ``x = fn(x)``. The single module that
    is an instance of ``QUANTIZE_KLASSES`` is called with the extra keyword
    arguments given to ``forward``; its first output continues through the
    remaining modules, and any further outputs are appended to the result.

    Args:
        *fns: the modules to chain, in order. Exactly one must be a quantizer.

    Raises:
        AssertionError: if the number of quantizers among ``fns`` is not one.
    """

    def __init__(
        self,
        *fns: Module
    ):
        super().__init__()

        # bools sum as 0/1, so no intermediate list of ints is needed
        num_quantizers = sum(isinstance(fn, QUANTIZE_KLASSES) for fn in fns)
        assert num_quantizers == 1, 'this special Sequential must contain exactly one quantizer'

        self.fns = ModuleList(fns)

    def forward(
        self,
        x,
        **kwargs
    ):
        """Run ``x`` through every module and return ``(x, *quantizer_extras)``.

        ``kwargs`` are passed only to the quantizer.
        """
        # bound up front so the return below is well defined even if the
        # constructor assert was stripped (python -O) and no quantizer ran
        rest = ()

        for fn in self.fns:

            if not isinstance(fn, QUANTIZE_KLASSES):
                x = fn(x)
                continue

            quantized = fn(x, **kwargs)

            # a quantizer that returns a bare tensor must not be star-unpacked:
            # that would iterate over the tensor's first dimension
            if torch.is_tensor(quantized):
                quantized = (quantized,)

            x, *rest = quantized

        return (x, *rest)

0 commit comments

Comments
 (0)