Skip to content

Commit 29171c6

Browse files
committed
rename a keyword
1 parent 29e40a3 commit 29171c6

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ vae = DiscreteVAE(
4343

4444
images = torch.randn(4, 3, 256, 256)
4545

46-
loss = vae(images, return_recon_loss = True)
46+
loss = vae(images, return_loss = True)
4747
loss.backward()
4848

4949
# train with a lot of data to learn a good codebook

dalle_pytorch/dalle_pytorch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def decode(
143143
def forward(
144144
self,
145145
img,
146-
return_recon_loss = False,
146+
return_loss = False,
147147
return_logits = False
148148
):
149149
num_tokens, kl_div_loss_weight = self.num_tokens, self.kl_div_loss_weight
@@ -157,7 +157,7 @@ def forward(
157157
sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
158158
out = self.decoder(sampled)
159159

160-
if not return_recon_loss:
160+
if not return_loss:
161161
return out
162162

163163
# reconstruction loss

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)