Describe the feature or idea you want to propose
Weights are used with cross entropy loss (and with several of the custom losses) to help compensate for an imbalanced dataset. However, these weights need to be factored into the denominator of the reduction term in order to normalize the loss properly. Otherwise, the loss value for batches containing rare classes will be much higher, and a consistently higher loss for those batches is roughly equivalent to using a higher learning rate just for them.
When using vanilla cross entropy loss with soft labels and weights, the denominator of the reduction does not include the weights, as shown in the equation in the PyTorch documentation: https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html. This impacts many of the loss functions that are implemented via soft labels on top of a base loss such as cross entropy.
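To make the difference concrete, here is a small sketch (not from the issue; the tensors are illustrative) comparing the reduction PyTorch applies for hard-label targets versus soft-label (probability) targets when a weight vector is supplied:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)
hard_targets = torch.tensor([0, 2, 2, 1])
class_weights = torch.tensor([1.0, 5.0, 10.0])

ce = nn.CrossEntropyLoss(weight=class_weights)

# Hard labels: the "mean" reduction divides by the sum of the per-sample weights.
loss_hard = ce(logits, hard_targets)

# Soft labels (here just one-hot probabilities): the weights still scale the
# numerator, but the mean divides by the batch size N rather than the weight
# sum, so the value swings with the class mix of each batch.
loss_soft = ce(logits, F.one_hot(hard_targets, num_classes=3).float())

print(loss_hard.item(), loss_soft.item())  # the two values differ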
Similarly, the code for CDWCELoss does not include either the distance weights or the user-supplied class weights in the denominator: https://dlordinal.readthedocs.io/en/latest/_modules/dlordinal/losses/cdw.html#CDWCELoss
In either case, the reported loss will be much less stable when user weights are supplied, and this can also affect the training results, as described for example in this comment:
https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/39
Describe your proposed solution
I've now rolled my own version of CDWCELoss in response to this issue, among other things. To compute the loss, I combine a cross entropy term that uses the custom class weights with a properly weighted denominator, together with a CDWCE term that uses only its own distance weights. Each term can then be scaled by its own lambda to balance the two.
import torch
import torch.nn as nn

class CDW_CCELoss(nn.Module):
    def __init__(self, classes, *, alpha=1.0, weight=None):
        super().__init__()
        # Pairwise class-distance weights: |i - j| ** alpha for every (target, class) pair.
        indices = torch.arange(classes)
        distance_list = []
        for i in range(classes):
            distance_list.append(torch.abs(i - indices) ** alpha)
        distance_weight = torch.stack(distance_list)
        self.register_buffer("distance_weight", distance_weight)
        if weight is not None:
            self.register_buffer("custom_weight", weight)
        else:
            self.custom_weight = None

    def forward(self, pred, target):
        """pred is of shape NxC, target is of shape N"""
        index = torch.arange(pred.shape[0], device=pred.device)
        if self.custom_weight is not None:
            # Class-weighted CE with the weights included in the denominator of the reduction.
            weight = self.custom_weight[target]
            ce = torch.log_softmax(pred, dim=1)[index, target] * (weight / weight.sum())
        else:
            # Unweighted CE reduces to a plain mean over the batch.
            ce = torch.log_softmax(pred, dim=1)[index, target] / pred.shape[0]
        # CDW term: log(1 - p) weighted by the per-row-normalized class distances.
        distance = self.distance_weight[target]
        cdw = torch.log(1 - torch.softmax(pred, dim=1))
        cdw = cdw * (distance / distance.sum(dim=1, keepdim=True))
        cdw = cdw / pred.shape[0]
        # The two terms are returned separately so the caller can attach its own lambdas.
        return -ce.sum(), -cdw.sum() * 100
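A minimal usage sketch of the class above, assuming the two returned terms are combined with caller-chosen lambdas (the coefficient values below are illustrative, not recommendations):

criterion = CDW_CCELoss(5, alpha=1.0, weight=torch.tensor([1.0, 2.0, 5.0, 2.0, 1.0]))
logits = torch.randn(8, 5, requires_grad=True)
targets = torch.randint(0, 5, (8,))

ce_term, cdw_term = criterion(logits, targets)
loss = 1.0 * ce_term + 0.1 * cdw_term  # lambda_ce = 1.0, lambda_cdw = 0.1 (illustrative)
loss.backward()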
For the other losses you implement, you can construct the base cross entropy with reduction="none" and perform the properly weighted reduction in your own wrapper, as sketched below.
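The wrapper name and the choice of effective per-sample weight here are mine, not from dlordinal; it is just one way to put the weights back into the denominator:

import torch
import torch.nn as nn

class SoftLabelWeightedCE(nn.Module):
    """Illustrative wrapper: the base CE returns per-sample losses, and the
    wrapper applies class weights with the weights also in the denominator."""
    def __init__(self, weight):
        super().__init__()
        self.register_buffer("class_weight", weight)
        self.base = nn.CrossEntropyLoss(reduction="none")

    def forward(self, pred, soft_target):
        per_sample = self.base(pred, soft_target)           # shape (N,)
        w = (soft_target * self.class_weight).sum(dim=1)    # effective per-sample weight
        return (per_sample * w).sum() / w.sum()              # weights in numerator and denominator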
Feel free to use the CDW_CCELoss code above, or a variant of it, if you decide you want to handle custom weights this way. Either way, thanks for the great resource. I ended up having to roll my own loss, but testing against the options in your library and using it has been invaluable.
Describe alternatives you have considered, if relevant
No response
Additional context
No response