Skip to content

Commit 03a92e4

Browse files
committed
aim for clarity
1 parent 4b0d4a2 commit 03a92e4

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

dalle_pytorch/dalle_pytorch.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,10 @@ def forward(
168168

169169
logits = rearrange(logits, 'b n h w -> b (h w) n')
170170
qy = F.softmax(logits, dim = -1)
171+
171172
log_qy = torch.log(qy + 1e-20)
172-
g = torch.log(torch.tensor([1. / num_tokens], device = device))
173-
kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
173+
log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
174+
kl_div = F.kl_div(log_uniform, log_qy, None, None, 'sum', log_target = True)
174175

175176
return recon_loss + (kl_div * kl_div_loss_weight)
176177

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.1.6',
6+
version = '0.1.7',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)