Skip to content

Commit 4511d2d

Browse files
committed
remove some dead code, text mask is not needed when training DALL-E
1 parent 07f6f3f commit 4511d2d

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

README.md

+3-6
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,13 @@ dalle = DALLE(
123123

124124
text = torch.randint(0, 10000, (4, 256))
125125
images = torch.randn(4, 3, 256, 256)
126-
mask = torch.ones_like(text).bool()
127126

128-
loss = dalle(text, images, mask = mask, return_loss = True)
127+
loss = dalle(text, images, return_loss = True)
129128
loss.backward()
130129

131130
# do the above for a long time with a lot of data ... then
132131

133-
images = dalle.generate_images(text, mask = mask)
132+
images = dalle.generate_images(text)
134133
images.shape # (4, 3, 256, 256)
135134
```
136135

@@ -141,7 +140,6 @@ img_prime = torch.randn(4, 3, 256, 256)
141140

142141
images = dalle.generate_images(
143142
text,
144-
mask = mask,
145143
img = img_prime,
146144
num_init_img_tokens = (14 * 32) # you can set the size of the initial crop, defaults to a little less than ~1/2 of the tokens, as done in the paper
147145
)
@@ -179,9 +177,8 @@ dalle = DALLE(
179177

180178
text = torch.randint(0, 10000, (4, 256))
181179
images = torch.randn(4, 3, 256, 256)
182-
mask = torch.ones_like(text).bool()
183180

184-
loss = dalle(text, images, mask = mask, return_loss = True)
181+
loss = dalle(text, images, return_loss = True)
185182
loss.backward()
186183
```
187184

dalle_pytorch/dalle_pytorch.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,6 @@ def generate_images(
466466
text,
467467
*,
468468
clip = None,
469-
mask = None,
470469
filter_thres = 0.5,
471470
temperature = 1.,
472471
img = None,
@@ -494,7 +493,7 @@ def generate_images(
494493

495494
text, image = out[:, :text_seq_len], out[:, text_seq_len:]
496495

497-
logits = self(text, image, mask = mask)
496+
logits = self(text, image)
498497
logits = logits[:, -1, :]
499498

500499
filtered_logits = top_k(logits, thres = filter_thres)
@@ -503,9 +502,6 @@ def generate_images(
503502
sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
504503
out = torch.cat((out, sample[:, None]), dim=-1)
505504

506-
if out.shape[1] <= text_seq_len:
507-
mask = F.pad(mask, (0, 1), value = True)
508-
509505
text_seq = out[:, :text_seq_len]
510506

511507
img_seq = out[:, -image_seq_len:]
@@ -521,7 +517,6 @@ def forward(
521517
self,
522518
text,
523519
image = None,
524-
mask = None,
525520
return_loss = False
526521
):
527522
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '1.1.8',
7+
version = '1.2.0',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)