The code in utils.py that computes the KL divergence (lines 199 to 203 at 647f309) is shown below, but I think this may actually be the cross entropy rather than the KL divergence.
```python
# P: pretrained model; Q: current model.
prob_p = torch.nn.functional.softmax(pretrained_outputs.logits, -1)
prob_q = torch.nn.functional.softmax(normal_outputs.logits, -1)
loss = -(prob_p * torch.log(prob_q + 1e-12)).sum(-1).mean()
```
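For reference, here is a minimal sketch contrasting the two quantities, using hypothetical random logits in place of `pretrained_outputs.logits` and `normal_outputs.logits`. The snippet above matches the cross entropy H(P, Q), while KL(P || Q) additionally subtracts the entropy of P; if P comes from a frozen pretrained model, that entropy term is constant with respect to the current model, so the two losses differ in value but not in gradient.

```python
import torch
import torch.nn.functional as F

# Hypothetical example logits standing in for pretrained_outputs.logits (P)
# and normal_outputs.logits (Q); a (batch, num_classes) shape is assumed.
logits_p = torch.randn(4, 10)
logits_q = torch.randn(4, 10)

prob_p = F.softmax(logits_p, -1)
log_prob_p = F.log_softmax(logits_p, -1)
log_prob_q = F.log_softmax(logits_q, -1)

# What the snippet above computes: cross entropy H(P, Q) = -sum_x P(x) log Q(x)
cross_entropy = -(prob_p * log_prob_q).sum(-1).mean()

# KL divergence: KL(P || Q) = sum_x P(x) (log P(x) - log Q(x)) = H(P, Q) - H(P)
kl = (prob_p * (log_prob_p - log_prob_q)).sum(-1).mean()
```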
Why not use PyTorch's KLDivLoss directly?
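For comparison, a sketch of how the built-in functional form, `torch.nn.functional.kl_div`, could compute the same quantity (again with hypothetical logits): it expects the log-probabilities of Q as input and the probabilities of P as target by default, and returns KL(P || Q).

```python
import torch
import torch.nn.functional as F

# Hypothetical logits, as in the sketch above.
logits_p = torch.randn(4, 10)  # pretrained model (P)
logits_q = torch.randn(4, 10)  # current model (Q)

# kl_div(input, target) with default arguments computes KL(target || Q) where
# input = log Q. With (batch, num_classes) inputs, reduction='batchmean'
# matches the .sum(-1).mean() reduction used in the snippet above.
loss = F.kl_div(F.log_softmax(logits_q, -1),
                F.softmax(logits_p, -1),
                reduction='batchmean')
```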