Commit 0289b6c

add VQ-VAE as an alternative
1 parent 3aabe77 commit 0289b6c

File tree

3 files changed: +121 -15

dalle_pytorch/__init__.py  (+1 -1)

@@ -1 +1 @@
-from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
+from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE, VQVAE

dalle_pytorch/dalle_pytorch.py  (+117 -12)
@@ -5,6 +5,7 @@

 from einops import rearrange
 from axial_positional_embedding import AxialPositionalEmbedding
+from vector_quantize_pytorch import VectorQuantize
 from dalle_pytorch.transformer import Transformer

 # helpers
@@ -123,6 +124,7 @@ def __init__(
         self.kl_div_loss_weight = kl_div_loss_weight

     @torch.no_grad()
+    @eval_decorator
     def get_codebook_indices(self, images):
         logits = self.forward(images, return_logits = True)
         codebook_indices = logits.argmax(dim = 1).flatten(1)
@@ -175,6 +177,119 @@ def forward(

         return recon_loss + (kl_div * kl_div_loss_weight)

+class VQVAE(nn.Module):
+    def __init__(
+        self,
+        image_size = 256,
+        num_tokens = 512,
+        codebook_dim = 512,
+        num_layers = 3,
+        num_resnet_blocks = 0,
+        hidden_dim = 64,
+        channels = 3,
+        temperature = 0.9,
+        straight_through = False,
+        vq_decay = 0.8,
+        commitment_weight = 1.
+    ):
+        super().__init__()
+        assert log2(image_size).is_integer(), 'image size must be a power of 2'
+        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
+        has_resblocks = num_resnet_blocks > 0
+
+        self.image_size = image_size
+        self.num_tokens = num_tokens
+        self.num_layers = num_layers
+
+        self.vq = VectorQuantize(
+            dim = codebook_dim,
+            n_embed = num_tokens,
+            decay = vq_decay,
+            commitment = commitment_weight
+        )
+
+        hdim = hidden_dim
+
+        enc_chans = [hidden_dim] * num_layers
+        dec_chans = list(reversed(enc_chans))
+
+        enc_chans = [channels, *enc_chans]
+
+        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
+        dec_chans = [dec_init_chan, *dec_chans]
+
+        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+
+        enc_layers = []
+        dec_layers = []
+
+        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
+            enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
+            dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
+
+        for _ in range(num_resnet_blocks):
+            dec_layers.insert(0, ResBlock(dec_chans[1]))
+            enc_layers.append(ResBlock(enc_chans[-1]))
+
+        if num_resnet_blocks > 0:
+            dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))
+
+        enc_layers.append(nn.Conv2d(enc_chans[-1], codebook_dim, 1))
+        dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
+
+        self.encoder = nn.Sequential(*enc_layers)
+        self.decoder = nn.Sequential(*dec_layers)
+
+    @torch.no_grad()
+    @eval_decorator
+    def get_codebook_indices(self, images):
+        encoded = self.forward(images, return_encoded = True)
+        encoded = rearrange(encoded, 'b c h w -> b (h w) c')
+        _, indices, _ = self.vq(encoded)
+        return indices
+
+    def decode(
+        self,
+        img_seq
+    ):
+        codebook = rearrange(self.vq.embed, 'd n -> n d')
+        image_embeds = codebook[img_seq]
+        b, n, d = image_embeds.shape
+        h = w = int(sqrt(n))
+
+        image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
+        images = self.decoder(image_embeds)
+        return images
+
+    def forward(
+        self,
+        img,
+        return_loss = False,
+        return_encoded = False
+    ):
+        shape, device = img.shape, img.device
+
+        encoded = self.encoder(img)
+
+        if return_encoded:
+            return encoded
+
+        h, w = encoded.shape[-2:]
+
+        encoded = rearrange(encoded, 'b c h w -> b (h w) c')
+        quantized, _, commit_loss = self.vq(encoded)
+        quantized = rearrange(quantized, 'b (h w) c -> b c h w', h = h, w = w)
+        out = self.decoder(quantized)
+
+        if not return_loss:
+            return out
+
+        # reconstruction loss and VQ commitment loss
+
+        recon_loss = F.mse_loss(img, out)
+
+        return recon_loss + commit_loss
+
 # main classes

 class CLIP(nn.Module):
@@ -273,11 +388,10 @@ def __init__(
         ff_dropout = 0,
         sparse_attn = False,
         ignore_index = -100,
-        attn_types = None,
-        tie_codebook_image_emb = False,
+        attn_types = None
     ):
         super().__init__()
-        assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
+        assert isinstance(vae, (DiscreteVAE, VQVAE)), 'vae must be an instance of DiscreteVAE or VQVAE'

         image_size = vae.image_size
         num_image_tokens = vae.num_tokens
@@ -302,12 +416,6 @@ def __init__(
         self.total_seq_len = seq_len

         self.vae = vae
-        self.tie_codebook_image_emb = tie_codebook_image_emb
-        if exists(self.vae):
-            self.vae = vae
-
-        if tie_codebook_image_emb:
-            self.image_emb = vae.codebook

         self.transformer = Transformer(
             dim = dim,
@@ -415,9 +523,6 @@ def forward(
             image_len = image.shape[1]
             image_emb = self.image_emb(image)

-            if self.tie_codebook_image_emb:
-                image_emb.detach_()
-
             image_emb += self.image_pos_emb(image_emb)

             tokens = torch.cat((tokens, image_emb), dim = 1)
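
Below is a minimal sketch of how the new VQVAE class could be exercised on its own. Everything it calls (the constructor arguments, forward with return_loss = True, get_codebook_indices, decode) comes from the hunks above; the random batch, the training step, and the chosen hyperparameter values are placeholders, not part of the commit.

import torch
from dalle_pytorch import VQVAE

vae = VQVAE(
    image_size = 256,
    num_tokens = 512,       # codebook size
    codebook_dim = 512,
    num_layers = 3,         # three stride-2 convs -> 32 x 32 latent grid
    hidden_dim = 64
)

images = torch.randn(4, 3, 256, 256)    # stand-in batch of images

loss = vae(images, return_loss = True)  # reconstruction loss + VQ commitment loss
loss.backward()                         # one placeholder training step

# after training, images can be discretized for the DALL-E stage and decoded back
indices = vae.get_codebook_indices(images)   # (4, 1024) codebook ids
recon = vae.decode(indices)                  # (4, 3, 256, 256)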

setup.py  (+3 -2)
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.1.7',
+  version = '0.1.8',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
@@ -18,7 +18,8 @@
   install_requires=[
     'axial_positional_embedding',
     'einops>=0.3',
-    'torch>=1.6'
+    'torch>=1.6',
+    'vector-quantize-pytorch'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
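
With the relaxed assert in dalle_pytorch.py, a trained VQ-VAE can be handed to DALLE anywhere a DiscreteVAE was accepted before. A rough sketch follows; only the vae argument is confirmed by this diff, while the remaining DALLE keyword arguments and the forward signature are assumptions based on the repository's README of this period.

import torch
from dalle_pytorch import DALLE, VQVAE

vae = VQVAE(image_size = 256, num_tokens = 512, num_layers = 3)   # assume pretrained as above

dalle = DALLE(
    dim = 512,
    vae = vae,                # a VQVAE instance now passes the isinstance check
    num_text_tokens = 10000,  # assumed values, per README examples of this period
    text_seq_len = 256,
    depth = 6,
    heads = 8
)

text = torch.randint(0, 10000, (2, 256))
images = torch.randn(2, 3, 256, 256)

loss = dalle(text, images, return_loss = True)   # assumed forward signature
loss.backward()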
