File tree 2 files changed +11
-3
lines changed
2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -256,7 +256,8 @@ def __init__(
256
256
ff_dropout = 0 ,
257
257
sparse_attn = False ,
258
258
noncausal_attn_len = 0 ,
259
- ignore_index = - 100
259
+ ignore_index = - 100 ,
260
+ tie_codebook_image_emb = False
260
261
):
261
262
super ().__init__ ()
262
263
assert isinstance (vae , DiscreteVAE ), 'vae must be an instance of DiscreteVAE'
@@ -285,9 +286,12 @@ def __init__(
285
286
self .noncausal_attn_len = noncausal_attn_len
286
287
287
288
self .vae = vae
289
+ self .tie_codebook_image_emb = tie_codebook_image_emb
288
290
if exists (self .vae ):
289
291
self .vae = vae
290
- self .image_emb = vae .codebook
292
+
293
+ if tie_codebook_image_emb :
294
+ self .image_emb = vae .codebook
291
295
292
296
self .transformer = Transformer (
293
297
dim = dim ,
@@ -394,6 +398,10 @@ def forward(
394
398
395
399
image_len = image .shape [1 ]
396
400
image_emb = self .image_emb (image )
401
+
402
+ if self .tie_codebook_image_emb :
403
+ image_emb .detach_ ()
404
+
397
405
image_emb += self .image_pos_emb (image_emb )
398
406
399
407
tokens = torch .cat ((tokens , image_emb ), dim = 1 )
Original file line number Diff line number Diff line change 3
3
setup (
4
4
name = 'dalle-pytorch' ,
5
5
packages = find_packages (),
6
- version = '0.0.53 ' ,
6
+ version = '0.0.54 ' ,
7
7
license = 'MIT' ,
8
8
description = 'DALL-E - Pytorch' ,
9
9
author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments