
Commit 30a13eb

do not tie vae codebook to dall-e image embedding by default, and if tying, make sure to detach
1 parent 40f4119
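
In short: the image-token embedding is no longer tied to the VAE codebook unless tie_codebook_image_emb = True is passed, and when it is tied, the embedding output is detached in forward so the transformer loss no longer updates the codebook weights. A minimal usage sketch follows; the surrounding DiscreteVAE / DALLE keyword arguments are illustrative assumptions in the style of the README example and are not part of this commit, only tie_codebook_image_emb is new here.

    from dalle_pytorch import DiscreteVAE, DALLE

    # keyword arguments below are illustrative / assumed;
    # only `tie_codebook_image_emb` is introduced by this commit
    vae = DiscreteVAE(
        image_size = 256,
        num_layers = 3,
        num_tokens = 8192,
        codebook_dim = 512,
        hidden_dim = 64
    )

    dalle = DALLE(
        dim = 512,                       # should match codebook_dim if tying
        vae = vae,
        num_text_tokens = 10000,
        text_seq_len = 256,
        depth = 6,
        heads = 8,
        tie_codebook_image_emb = True    # opt back into sharing vae.codebook (now detached in forward)
    )

With the flag left at its default of False, the image embedding presumably stays the model's own separately learned nn.Embedding (defined earlier in __init__, outside the hunks below), so the codebook is untouched either way.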

2 files changed: +11 -3 lines changed

dalle_pytorch/dalle_pytorch.py (+10 -2)
@@ -256,7 +256,8 @@ def __init__(
        ff_dropout = 0,
        sparse_attn = False,
        noncausal_attn_len = 0,
-       ignore_index = -100
+       ignore_index = -100,
+       tie_codebook_image_emb = False
    ):
        super().__init__()
        assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -285,9 +286,12 @@ def __init__(
        self.noncausal_attn_len = noncausal_attn_len

        self.vae = vae
+       self.tie_codebook_image_emb = tie_codebook_image_emb
        if exists(self.vae):
            self.vae = vae
-           self.image_emb = vae.codebook
+
+           if tie_codebook_image_emb:
+               self.image_emb = vae.codebook

        self.transformer = Transformer(
            dim = dim,
@@ -394,6 +398,10 @@ def forward(

        image_len = image.shape[1]
        image_emb = self.image_emb(image)
+
+       if self.tie_codebook_image_emb:
+           image_emb.detach_()
+
        image_emb += self.image_pos_emb(image_emb)

        tokens = torch.cat((tokens, image_emb), dim = 1)
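
The detach in forward is the substantive part of the change: when the codebook is shared, its weights would otherwise receive gradients from the transformer loss through the image-token embedding lookup. A self-contained sketch of that mechanic, with a plain nn.Embedding standing in for vae.codebook (the names and shapes here are illustrative, not from the repo):

    import torch
    import torch.nn as nn

    codebook = nn.Embedding(8192, 512)          # stands in for vae.codebook
    image_tokens = torch.randint(0, 8192, (1, 1024))

    # tied and NOT detached: the loss backpropagates into the shared codebook
    emb = codebook(image_tokens)
    emb.pow(2).mean().backward()
    print(codebook.weight.grad is not None)     # True -> the VAE codebook would get updated

    codebook.weight.grad = None                 # reset for the second case

    # tied and detached, as this commit does with image_emb.detach_()
    emb = codebook(image_tokens)
    emb.detach_()                               # cut the graph at the codebook lookup
    emb.requires_grad_()                        # downstream layers still receive gradients
    emb.pow(2).mean().backward()
    print(codebook.weight.grad is None)         # True -> the codebook is left untouched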

setup.py (+1 -1)
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.53',
+  version = '0.0.54',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
