from einops import rearrange
from axial_positional_embedding import AxialPositionalEmbedding
+ from vector_quantize_pytorch import VectorQuantize
from dalle_pytorch.transformer import Transformer

# helpers
@@ -123,6 +124,7 @@ def __init__(
        self.kl_div_loss_weight = kl_div_loss_weight

    @torch.no_grad()
+     @eval_decorator
    def get_codebook_indices(self, images):
        logits = self.forward(images, return_logits = True)
        codebook_indices = logits.argmax(dim = 1).flatten(1)
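For context, `eval_decorator` comes from the helpers section of this file (not shown in this diff). A minimal sketch of the usual dalle-pytorch helper, assumed rather than verbatim: it records the module's training flag, switches the module to eval mode for the duration of the call, then restores the previous state, which is why it pairs naturally with `torch.no_grad()` on `get_codebook_indices`.

# sketch of the eval_decorator helper assumed above (defined elsewhere in this file)
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()                # run the wrapped method in eval mode
        out = fn(model, *args, **kwargs)
        model.train(was_training)   # restore whatever mode the module was in before
        return out
    return inner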
@@ -175,6 +177,119 @@ def forward(
        return recon_loss + (kl_div * kl_div_loss_weight)

+ class VQVAE(nn.Module):
+     def __init__(
+         self,
+         image_size = 256,
+         num_tokens = 512,
+         codebook_dim = 512,
+         num_layers = 3,
+         num_resnet_blocks = 0,
+         hidden_dim = 64,
+         channels = 3,
+         temperature = 0.9,
+         straight_through = False,
+         vq_decay = 0.8,
+         commitment_weight = 1.
+     ):
+         super().__init__()
+         assert log2(image_size).is_integer(), 'image size must be a power of 2'
+         assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
+         has_resblocks = num_resnet_blocks > 0
+
+         self.image_size = image_size
+         self.num_tokens = num_tokens
+         self.num_layers = num_layers
+
+         self.vq = VectorQuantize(
+             dim = codebook_dim,
+             n_embed = num_tokens,
+             decay = vq_decay,
+             commitment = commitment_weight
+         )
+
+         hdim = hidden_dim
+
+         enc_chans = [hidden_dim] * num_layers
+         dec_chans = list(reversed(enc_chans))
+
+         enc_chans = [channels, *enc_chans]
+
+         dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
+         dec_chans = [dec_init_chan, *dec_chans]
+
+         enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+
+         enc_layers = []
+         dec_layers = []
+
+         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
+             enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
+             dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
+
+         for _ in range(num_resnet_blocks):
+             dec_layers.insert(0, ResBlock(dec_chans[1]))
+             enc_layers.append(ResBlock(enc_chans[-1]))
+
+         if num_resnet_blocks > 0:
+             dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))
+
+         enc_layers.append(nn.Conv2d(enc_chans[-1], codebook_dim, 1))
+         dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
+
+         self.encoder = nn.Sequential(*enc_layers)
+         self.decoder = nn.Sequential(*dec_layers)
+
+     @torch.no_grad()
+     @eval_decorator
+     def get_codebook_indices(self, images):
+         encoded = self.forward(images, return_encoded = True)
+         encoded = rearrange(encoded, 'b c h w -> b (h w) c')
+         _, indices, _ = self.vq(encoded)
+         return indices
+
+     def decode(
+         self,
+         img_seq
+     ):
+         codebook = rearrange(self.vq.embed, 'd n -> n d')
+         image_embeds = codebook[img_seq]
+         b, n, d = image_embeds.shape
+         h = w = int(sqrt(n))
+
+         image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
+         images = self.decoder(image_embeds)
+         return images
+
+     def forward(
+         self,
+         img,
+         return_loss = False,
+         return_encoded = False
+     ):
+         shape, device = img.shape, img.device
+
+         encoded = self.encoder(img)
+
+         if return_encoded:
+             return encoded
+
+         h, w = encoded.shape[-2:]
+
+         encoded = rearrange(encoded, 'b c h w -> b (h w) c')
+         quantized, _, commit_loss = self.vq(encoded)
+         quantized = rearrange(quantized, 'b (h w) c -> b c h w', h = h, w = w)
+         out = self.decoder(quantized)
+
+         if not return_loss:
+             return out
+
+         # reconstruction loss and VQ commitment loss
+
+         recon_loss = F.mse_loss(img, out)
+
+         return recon_loss + commit_loss
+
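For orientation, here is a minimal usage sketch of the VQVAE added above, assuming `vector-quantize-pytorch` is installed and that the version targeted by this diff accepts the `n_embed`, `decay` and `commitment` keyword arguments used in `__init__` and returns `(quantized, indices, commit_loss)` from its forward pass. The import path, batch size and hyperparameters are illustrative only.

import torch
from dalle_pytorch.dalle_pytorch import VQVAE  # import path assumed; VQVAE may also be re-exported from the package root

vae = VQVAE(
    image_size = 256,
    num_tokens = 512,        # codebook size = number of discrete image tokens
    codebook_dim = 512,
    num_layers = 3,          # three stride-2 convs -> a 32 x 32 latent grid
    num_resnet_blocks = 2,
    vq_decay = 0.8,
    commitment_weight = 1.
)

images = torch.randn(4, 3, 256, 256)

loss = vae(images, return_loss = True)       # MSE reconstruction loss + VQ commitment loss
loss.backward()

# after training: discrete codes for the transformer stage, and decoding back to pixels
indices = vae.get_codebook_indices(images)   # (4, 1024) token ids in [0, num_tokens)
recon = vae.decode(indices)                  # (4, 3, 256, 256)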
# main classes

class CLIP(nn.Module):
@@ -273,11 +388,10 @@ def __init__(
        ff_dropout = 0,
        sparse_attn = False,
        ignore_index = -100,
-         attn_types = None,
-         tie_codebook_image_emb = False,
+         attn_types = None
    ):
        super().__init__()
-         assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
+         assert isinstance(vae, (DiscreteVAE, VQVAE)), 'vae must be an instance of DiscreteVAE or VQVAE'

        image_size = vae.image_size
        num_image_tokens = vae.num_tokens
@@ -302,12 +416,6 @@ def __init__(
        self.total_seq_len = seq_len

        self.vae = vae
-         self.tie_codebook_image_emb = tie_codebook_image_emb
-         if exists(self.vae):
-             self.vae = vae
-
-         if tie_codebook_image_emb:
-             self.image_emb = vae.codebook

        self.transformer = Transformer(
            dim = dim,
@@ -415,9 +523,6 @@ def forward(
            image_len = image.shape[1]
            image_emb = self.image_emb(image)

-             if self.tie_codebook_image_emb:
-                 image_emb.detach_()
-
            image_emb += self.image_pos_emb(image_emb)

            tokens = torch.cat((tokens, image_emb), dim = 1)
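Since the isinstance check above now accepts both VAE flavours, the wrapper that owns it (presumably the DALLE class, given the references to `vae`, `image_emb` and `image_pos_emb`) should take a trained VQVAE in place of a DiscreteVAE. A sketch under the assumption that the DALLE constructor keeps the keyword arguments shown in the repository README; all values are illustrative.

import torch
from dalle_pytorch import DALLE   # as in the repository README

vae = VQVAE(image_size = 256, num_tokens = 512, codebook_dim = 512, num_layers = 3)  # trained beforehand

dalle = DALLE(
    dim = 512,
    vae = vae,                    # needs .image_size, .num_tokens, .num_layers and .get_codebook_indices
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 2,
    heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()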