Two questions:

1. `alpha_t` is computed in `forward` but never used in the loss below — is that intentional?
2. Should we use the following to derive the binary cross-entropy (BCE) term instead:

```python
pt = torch.log(p) * target.float() + torch.log(1.0 - p) * (1 - target).float()
```
For reference, the current implementation:

```python
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, device="cuda:0", eps=1e-10):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.device = device
        self.eps = eps

    def forward(self, input, target):
        p = torch.sigmoid(input)
        # Probability assigned to the true class for each element.
        pt = p * target.float() + (1.0 - p) * (1 - target).float()
        # Class-balancing weight; note it is never applied to the loss below.
        alpha_t = (1.0 - self.alpha) * target.float() + self.alpha * (1 - target).float()
        # Focal term (1 - pt)^gamma scales the cross-entropy -log(pt + eps).
        loss = -1.0 * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)
        return loss.sum()
```
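If the intended behavior is the alpha-balanced focal loss, a minimal sketch of how both points could be addressed is below (this is an assumption about the intent, not a confirmed fix): apply `alpha_t` as a per-element weight, and take the `-log(pt)` term from `F.binary_cross_entropy_with_logits`, which computes the expression from question 2 (negated) in a numerically stable way, so the `eps` clamp and the unused `device` argument can be dropped.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Hypothetical rewrite, not the repository's actual implementation."""

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, input, target):
        target = target.float()
        p = torch.sigmoid(input)
        # Probability of the true class, as in the original snippet.
        pt = p * target + (1.0 - p) * (1 - target)
        # Same alpha_t convention as the snippet above (alpha weights the negatives).
        alpha_t = (1.0 - self.alpha) * target + self.alpha * (1 - target)
        # ce = -log(pt), computed stably from the logits; this equals the negation of
        # target * log(p) + (1 - target) * log(1 - p) from question 2.
        ce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        loss = alpha_t * torch.pow(1 - pt, self.gamma) * ce
        return loss.sum()
```

One design note: this keeps the snippet's own `alpha_t` convention; for comparison, torchvision's `sigmoid_focal_loss` applies `alpha` to the positive class instead.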