-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathutils.py
More file actions
92 lines (77 loc) · 2.84 KB
/
Copy pathutils.py
File metadata and controls
92 lines (77 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import numpy as np
from torch.nn import functional as F
import random
import torch.nn as nn
import os
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Args:
        gamma: focusing exponent; larger values down-weight well-classified
            elements more strongly.
        alpha: weight applied to the positive (mask == 1) term; the negative
            term is weighted by ``1 - alpha``.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, mask):
        """Return the focal loss averaged over every element.

        Args:
            pred: logits, expected shape [B, 1, H, W].
            mask: binary targets with the same shape as ``pred``
                (float, int or bool).

        Returns:
            Scalar tensor; 0 for an empty input.
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        # Work in pred's dtype so bool/int masks are accepted (``1 - bool`` raises).
        mask = mask.to(pred.dtype)
        prob = torch.sigmoid(pred)
        # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) taken
        # directly from the logits. Unlike log(prob + eps), this neither clamps
        # the loss nor zeroes the gradient once sigmoid saturates to 0 or 1.
        log_p = F.logsigmoid(pred)
        log_not_p = F.logsigmoid(-pred)
        w_pos = (1 - prob) ** self.gamma
        w_neg = prob ** self.gamma
        loss_pos = -self.alpha * mask * w_pos * log_p
        loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * log_not_p
        # Mean over all elements; max(..., 1) keeps an empty input at 0, not NaN.
        return (loss_pos.sum() + loss_neg.sum()) / max(mask.numel(), 1)
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, pooled over the entire batch.

    Args:
        smooth: additive smoothing applied to numerator and denominator so the
            score stays defined when both prediction and mask are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, mask):
        """Return ``1 - dice`` where dice is computed on sigmoid(pred).

        Args:
            pred: logits, expected shape [B, 1, H, W].
            mask: targets with the same shape as ``pred``.

        Returns:
            Scalar tensor. Sums run over all elements of the batch at once,
            not per sample.
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        probs = pred.sigmoid()
        overlap = (probs * mask).sum()
        total = probs.sum() + mask.sum()
        coefficient = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1 - coefficient
class MaskMSE(nn.Module):
    """MSE between predicted IoU scores and each sample's soft mask IoU."""

    def __init__(self):
        super(MaskMSE, self).__init__()

    def forward(self, pred, mask, pred_iou):
        """Return mean((iou - pred_iou) ** 2) with a per-sample soft IoU.

        Args:
            pred: logits, expected shape [B, 1, H, W].
            mask: targets with the same shape as ``pred``.
            pred_iou: predicted IoU per sample, [B, 1] (a flat [B] or a
                [B, K] layout also broadcasts correctly).

        Returns:
            Scalar tensor.
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        mask = mask.to(pred.dtype)
        # Flatten everything but the batch axis so the IoU is computed per
        # sample; summing over the whole batch would regress every sample's
        # pred_iou toward one shared, batch-aggregate IoU.
        prob = torch.sigmoid(pred).flatten(start_dim=1)
        target = mask.flatten(start_dim=1)
        intersection = (prob * target).sum(dim=1)
        union = prob.sum(dim=1) + target.sum(dim=1) - intersection
        iou = (intersection + 1e-7) / (union + 1e-7)  # [B]
        # Give iou trailing singleton axes so it lines up with [B, 1] / [B, K].
        if pred_iou.dim() > 1:
            iou = iou.reshape(iou.shape[0], *([1] * (pred_iou.dim() - 1)))
        # NOTE(review): iou is not detached, so this term also back-propagates
        # into the mask logits (same as before) -- confirm that is intended.
        return torch.mean((iou - pred_iou) ** 2)
class FocalDice_MSELoss(nn.Module):
    """Combined objective: ``weight * focal + dice + iou_scale * iou_mse``.

    Args:
        weight: multiplier on the focal term.
        iou_scale: multiplier on the IoU-regression (MSE) term.
    """

    def __init__(self, weight=20.0, iou_scale=1.0):
        super().__init__()
        self.weight = weight
        self.iou_scale = iou_scale
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()
        self.maskiou_mse = MaskMSE()

    def forward(self, pred, mask, pred_iou):
        """Return the weighted sum of the three component losses.

        Args:
            pred: logits, expected shape [B, 1, H, W].
            mask: targets with the same shape as ``pred``.
            pred_iou: predicted IoU scores, forwarded to ``self.maskiou_mse``.
        """
        assert pred.shape == mask.shape, "pred and mask should have the same shape."
        seg_term = self.weight * self.focal_loss(pred, mask) + self.dice_loss(pred, mask)
        iou_term = self.maskiou_mse(pred, mask, pred_iou)
        return seg_term + iou_term * self.iou_scale