Skip to content

Commit dfc35d5

Browse files
committed
add KL loss, per code review from @karpathy!
1 parent c27f48c commit dfc35d5

File tree

3 files changed, with 22 additions and 7 deletions.

README.md

+1 -1
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ vae = DiscreteVAE(
4343

4444
images = torch.randn(4, 3, 256, 256)
4545

46-
loss = vae(images, return_recon_loss = True)
46+
loss = vae(images, return_loss = True)
4747
loss.backward()
4848

4949
# train with a lot of data to learn a good codebook

dalle_pytorch/dalle_pytorch.py

+20 -5
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,8 @@ def __init__(
7373
hidden_dim = 64,
7474
channels = 3,
7575
temperature = 0.9,
76-
straight_through = False
76+
straight_through = False,
77+
kl_div_loss_weight = 1.
7778
):
7879
super().__init__()
7980
assert log2(image_size).is_integer(), 'image size must be a power of 2'
@@ -119,6 +120,8 @@ def __init__(
119120
self.encoder = nn.Sequential(*enc_layers)
120121
self.decoder = nn.Sequential(*dec_layers)
121122

123+
self.kl_div_loss_weight = kl_div_loss_weight
124+
122125
@torch.no_grad()
123126
def get_codebook_indices(self, images):
124127
logits = self.forward(images, return_logits = True)
@@ -140,9 +143,11 @@ def decode(
140143
def forward(
141144
self,
142145
img,
143-
return_recon_loss = False,
146+
return_loss = False,
144147
return_logits = False
145148
):
149+
num_tokens, kl_div_loss_weight = self.num_tokens, self.kl_div_loss_weight
150+
146151
logits = self.encoder(img)
147152

148153
if return_logits:
@@ -152,11 +157,21 @@ def forward(
152157
sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
153158
out = self.decoder(sampled)
154159

155-
if not return_recon_loss:
160+
if not return_loss:
156161
return out
157162

158-
loss = F.mse_loss(img, out)
159-
return loss
163+
# reconstruction loss
164+
165+
recon_loss = F.mse_loss(img, out)
166+
167+
# kl divergence
168+
169+
qy = F.softmax(logits, dim = -1)
170+
log_qy = torch.log(qy + 1e-20)
171+
g = torch.log(torch.Tensor([1. / num_tokens]))
172+
kl_div = (qy * (log_qy - g)).sum(dim = -1).mean()
173+
174+
return recon_loss + (kl_div * kl_div_loss_weight)
160175

161176
# main classes
162177

setup.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.63',
6+
version = '0.1.2',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)