We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dfc35d5, commit c184951 (Copy full SHA for c184951)
dalle_pytorch/dalle_pytorch.py
@@ -166,10 +166,11 @@ def forward(
166
167
# kl divergence
168
169
+ logits = rearrange(logits, 'b n h w -> b (h w) n')
170
qy = F.softmax(logits, dim = -1)
171
log_qy = torch.log(qy + 1e-20)
172
g = torch.log(torch.Tensor([1. / num_tokens]))
- kl_div = (qy * (log_qy - g)).sum(dim = -1).mean()
173
+ kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
174
175
return recon_loss + (kl_div * kl_div_loss_weight)
176
setup.py
@@ -3,7 +3,7 @@
3
setup(
4
name = 'dalle-pytorch',
5
packages = find_packages(),
6
- version = '0.1.2',
+ version = '0.1.4',
7
license='MIT',
8
description = 'DALL-E - Pytorch',
9
author = 'Phil Wang',
0 commit comments