-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlosses.py
More file actions
73 lines (55 loc) · 1.85 KB
/
losses.py
File metadata and controls
73 lines (55 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
from segmentation_models_pytorch.losses import (
DiceLoss,
JaccardLoss,
FocalLoss,
)
class LossCombined(nn.Module):
    """Equal-weight blend of the Jaccard loss and the binary cross-entropy loss.

    The returned value is the arithmetic mean of ``LossJaccard`` and
    ``LossCE`` evaluated on the same logits and targets.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept stable because they appear in module paths.
        self.jaccard_loss = LossJaccard()
        self.ce_loss = LossCE()

    def forward(self, y_hat, y):
        """Return ``(jaccard + bce) / 2`` for logits ``y_hat`` and targets ``y``."""
        iou_term = self.jaccard_loss(y_hat, y)
        bce_term = self.ce_loss(y_hat, y)
        # Scaling by 0.5 is exact in IEEE arithmetic, identical to dividing by 2.
        return 0.5 * (iou_term + bce_term)
class LossCE(nn.Module):
    """Binary cross-entropy computed directly on raw logits.

    Wraps ``nn.BCEWithLogitsLoss`` constructed with its default arguments and
    casts the target to float first, so integer or boolean masks can be
    passed unchanged.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, y_hat, y):
        """Return the BCE-with-logits loss between ``y_hat`` and ``y``."""
        return self.loss(y_hat, y.float())
class LossDice(nn.Module):
    """Binary Dice loss applied to raw logits.

    Thin wrapper over segmentation_models_pytorch's ``DiceLoss`` configured
    with ``mode="binary"`` and ``from_logits=True``.
    """

    def __init__(self):
        super().__init__()
        self.loss = DiceLoss(mode="binary", from_logits=True)

    def forward(self, y_hat, y):
        """Return the Dice loss between logits ``y_hat`` and targets ``y``.

        The target is cast to ``long`` before being passed to ``DiceLoss``.
        """
        return self.loss(y_hat, y.long())
class LossJaccard(nn.Module):
    """Binary Jaccard (IoU) loss on raw logits.

    Wraps segmentation_models_pytorch's ``JaccardLoss`` (``mode="binary"``,
    ``from_logits=True``). When the thresholded prediction and the target
    are both entirely empty, the loss is defined as zero.
    """

    def __init__(self):
        super().__init__()
        self.loss = JaccardLoss(mode="binary", from_logits=True)

    def forward(self, y_hat, y):
        """Compute the Jaccard loss between logits ``y_hat`` and targets ``y``.

        Args:
            y_hat: Raw logits.
            y: Target mask; cast to ``long`` (non-integer values truncate).

        Returns:
            A scalar tensor. It is zero, on ``y_hat``'s device and dtype and
            attached to the autograd graph, when no logit crosses the 0.5
            sigmoid threshold and every target entry is zero. Otherwise it
            is the ``JaccardLoss`` value.
        """
        y = y.long()
        no_predicted_positives = not torch.any(torch.sigmoid(y_hat) > 0.5)
        if no_predicted_positives and torch.all(y == 0):
            # A bare ``torch.tensor(0.0)`` would be a detached CPU leaf, which
            # makes ``backward()`` fail when this loss is used on its own and
            # ignores y_hat's device/dtype. Multiplying elementwise before
            # summing keeps the graph, yields zero gradients, and avoids fp16
            # overflow that ``y_hat.sum() * 0`` could hit on large inputs.
            # NOTE(review): smp's JaccardLoss may already zero the loss when
            # the target has no positives; confirm whether this branch is
            # still needed.
            return (y_hat * 0.0).sum()
        return self.loss(y_hat, y)
class LossFocal(nn.Module):
    """Binary focal loss with ``alpha=0.25``.

    Thin wrapper over segmentation_models_pytorch's ``FocalLoss`` in
    ``mode="binary"``; all other options keep that library's defaults.
    """

    def __init__(self):
        super().__init__()
        self.loss = FocalLoss(mode="binary", alpha=0.25)

    def forward(self, y_hat, y):
        """Return the focal loss between predictions ``y_hat`` and targets ``y``.

        The target is cast to ``long`` before being passed to ``FocalLoss``.
        """
        return self.loss(y_hat, y.long())