Skip to content

Commit c184951

Browse files
committed
rearrange logits before calculating kl
1 parent dfc35d5 commit c184951

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

dalle_pytorch/dalle_pytorch.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,11 @@ def forward(
166166

167167
# kl divergence
168168

169+
logits = rearrange(logits, 'b n h w -> b (h w) n')
169170
qy = F.softmax(logits, dim = -1)
170171
log_qy = torch.log(qy + 1e-20)
171172
g = torch.log(torch.Tensor([1. / num_tokens]))
172-
kl_div = (qy * (log_qy - g)).sum(dim = -1).mean()
173+
kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
173174

174175
return recon_loss + (kl_div * kl_div_loss_weight)
175176

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.1.2',
6+
version = '0.1.4',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)