@@ -73,7 +73,8 @@ def __init__(
73
73
hidden_dim = 64 ,
74
74
channels = 3 ,
75
75
temperature = 0.9 ,
76
- straight_through = False
76
+ straight_through = False ,
77
+ kl_div_loss_weight = 1.
77
78
):
78
79
super ().__init__ ()
79
80
assert log2 (image_size ).is_integer (), 'image size must be a power of 2'
@@ -119,6 +120,8 @@ def __init__(
119
120
self .encoder = nn .Sequential (* enc_layers )
120
121
self .decoder = nn .Sequential (* dec_layers )
121
122
123
+ self .kl_div_loss_weight = kl_div_loss_weight
124
+
122
125
@torch .no_grad ()
123
126
def get_codebook_indices (self , images ):
124
127
logits = self .forward (images , return_logits = True )
@@ -140,9 +143,11 @@ def decode(
140
143
def forward (
141
144
self ,
142
145
img ,
143
- return_recon_loss = False ,
146
+ return_loss = False ,
144
147
return_logits = False
145
148
):
149
+ num_tokens , kl_div_loss_weight = self .num_tokens , self .kl_div_loss_weight
150
+
146
151
logits = self .encoder (img )
147
152
148
153
if return_logits :
@@ -152,11 +157,21 @@ def forward(
152
157
sampled = einsum ('b n h w, n d -> b d h w' , soft_one_hot , self .codebook .weight )
153
158
out = self .decoder (sampled )
154
159
155
- if not return_recon_loss :
160
+ if not return_loss :
156
161
return out
157
162
158
- loss = F .mse_loss (img , out )
159
- return loss
163
+ # reconstruction loss
164
+
165
+ recon_loss = F .mse_loss (img , out )
166
+
167
+ # kl divergence
168
+
169
+ qy = F .softmax (logits , dim = - 1 )
170
+ log_qy = torch .log (qy + 1e-20 )
171
+ g = torch .log (torch .Tensor ([1. / num_tokens ]))
172
+ kl_div = (qy * (log_qy - g )).sum (dim = - 1 ).mean ()
173
+
174
+ return recon_loss + (kl_div * kl_div_loss_weight )
160
175
161
176
# main classes
162
177
0 commit comments