Skip to content

Commit 36fb46a

Browse files
committed
fix readme and a small bug in DALLE2 class
1 parent 07abfcf commit 36fb46a

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ loss.backward()
371371
unet1 = Unet(
372372
dim = 128,
373373
image_embed_dim = 512,
374+
text_embed_dim = 512,
374375
cond_dim = 128,
375376
channels = 3,
376377
dim_mults=(1, 2, 4, 8),

dalle2_pytorch/dalle2_pytorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2938,7 +2938,7 @@ def forward(
29382938
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
29392939

29402940
text_cond = text if self.decoder_need_text_cond else None
2941-
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
2941+
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
29422942

29432943
if return_pil_images:
29442944
images = list(map(self.to_pil, images.unbind(dim = 0)))

dalle2_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.0'
1+
__version__ = '1.2.1'

0 commit comments

Comments
 (0)