Skip to content

Commit c02aaf8

Browse files
committed
fix device error bug
1 parent 706f06d commit c02aaf8

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

dalle_pytorch/dalle_pytorch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def forward(
146146
return_loss = False,
147147
return_logits = False
148148
):
149-
num_tokens, kl_div_loss_weight = self.num_tokens, self.kl_div_loss_weight
149+
device, num_tokens, kl_div_loss_weight = img.device, self.num_tokens, self.kl_div_loss_weight
150150

151151
logits = self.encoder(img)
152152

@@ -169,7 +169,7 @@ def forward(
169169
logits = rearrange(logits, 'b n h w -> b (h w) n')
170170
qy = F.softmax(logits, dim = -1)
171171
log_qy = torch.log(qy + 1e-20)
172-
g = torch.log(torch.Tensor([1. / num_tokens]))
172+
g = torch.log(torch.Tensor([1. / num_tokens], device = device))
173173
kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
174174

175175
return recon_loss + (kl_div * kl_div_loss_weight)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.1.4',
6+
version = '0.1.5',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)