Skip to content

Commit 29e40a3

Browse files
committed
add kl loss, per code-review from @karpathy !
1 parent c27f48c commit 29e40a3

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

dalle_pytorch/dalle_pytorch.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def __init__(
7373
hidden_dim = 64,
7474
channels = 3,
7575
temperature = 0.9,
76-
straight_through = False
76+
straight_through = False,
77+
kl_div_loss_weight = 1.
7778
):
7879
super().__init__()
7980
assert log2(image_size).is_integer(), 'image size must be a power of 2'
@@ -119,6 +120,8 @@ def __init__(
119120
self.encoder = nn.Sequential(*enc_layers)
120121
self.decoder = nn.Sequential(*dec_layers)
121122

123+
self.kl_div_loss_weight = kl_div_loss_weight
124+
122125
@torch.no_grad()
123126
def get_codebook_indices(self, images):
124127
logits = self.forward(images, return_logits = True)
@@ -143,6 +146,8 @@ def forward(
143146
return_recon_loss = False,
144147
return_logits = False
145148
):
149+
num_tokens, kl_div_loss_weight = self.num_tokens, self.kl_div_loss_weight
150+
146151
logits = self.encoder(img)
147152

148153
if return_logits:
@@ -155,8 +160,18 @@ def forward(
155160
if not return_recon_loss:
156161
return out
157162

158-
loss = F.mse_loss(img, out)
159-
return loss
163+
# reconstruction loss
164+
165+
recon_loss = F.mse_loss(img, out)
166+
167+
# kl divergence
168+
169+
qy = F.softmax(logits, dim = -1)
170+
log_qy = torch.log(qy + 1e-20)
171+
g = torch.Tensor([1. / num_tokens])
172+
kl_div = (qy * (log_qy - g)).sum(dim = -1).mean()
173+
174+
return recon_loss + (kl_div * kl_div_loss_weight)
160175

161176
# main classes
162177

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.63',
6+
version = '0.1.0',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)