-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathfocal_loss.py
More file actions
309 lines (271 loc) · 14 KB
/
focal_loss.py
File metadata and controls
309 lines (271 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from collections.abc import Sequence
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from monai.metrics.utils import create_ignore_mask
from monai.networks import one_hot
from monai.utils import LossReduction
class FocalLoss(_Loss):
    """
    FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
    high confidence correct predictions.

    Reimplementation of the Focal Loss described in:

        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
          Zhu et al., Medical Physics 2018

    Example:
        >>> import torch
        >>> from monai.losses import FocalLoss
        >>> from torch.nn import BCEWithLogitsLoss
        >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11
        >>> input = torch.rand(*shape)
        >>> target = torch.rand(*shape)
        >>> # Demonstrate equivalence to BCE when gamma=0
        >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0)
        >>> fl_g0_loss = fl_g0_criterion(input, target)
        >>> bce_criterion = BCEWithLogitsLoss(reduction='none')
        >>> bce_loss = bce_criterion(input, target)
        >>> assert torch.allclose(fl_g0_loss, bce_loss)
        >>> # Demonstrate "focus" by setting gamma > 0.
        >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2)
        >>> fl_g2_loss = fl_g2_criterion(input, target)
        >>> # Mark easy and hard cases
        >>> is_easy = (target > 0.7) & (input > 0.7)
        >>> is_hard = (target > 0.7) & (input < 0.3)
        >>> easy_loss_g0 = fl_g0_loss[is_easy].mean()
        >>> hard_loss_g0 = fl_g0_loss[is_hard].mean()
        >>> easy_loss_g2 = fl_g2_loss[is_easy].mean()
        >>> hard_loss_g2 = fl_g2_loss[is_hard].mean()
        >>> # Gamma > 0 causes the loss function to "focus" on the hard
        >>> # cases. IE, easy cases are downweighted, so hard cases
        >>> # receive a higher proportion of the loss.
        >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2
        >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0
        >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        gamma: float = 2.0,
        alpha: float | Sequence[float] | None = None,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        reduction: LossReduction | str = LossReduction.MEAN,
        use_softmax: bool = False,
        ignore_index: int | None = None,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
                The value should be in [0, 1].
                If a sequence is provided, its length must match the number of classes
                (excluding the background class if `include_background=False`).
                Defaults to None.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            use_softmax: whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, sigmoid is used. Defaults to False.
            ignore_index: class index to ignore from the loss computation.

        Example:
            >>> import torch
            >>> from monai.losses import FocalLoss
            >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
            >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
            >>> fl = FocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.gamma = gamma
        self.weight = weight
        self.use_softmax = use_softmax
        # normalize alpha once here: None stays None, numbers become float,
        # sequences become a tensor validated later against the class count
        self.alpha: float | torch.Tensor | None
        if alpha is None:
            self.alpha = None
        elif isinstance(alpha, (float, int)):
            self.alpha = float(alpha)
        else:
            self.alpha = torch.as_tensor(alpha)
        # register as a buffer so it moves with the module (`.to(device)`) and
        # is saved/restored with the state_dict
        weight = torch.as_tensor(weight) if weight is not None else None
        self.register_buffer("class_weight", weight)
        self.class_weight: None | torch.Tensor
        self.ignore_index = ignore_index

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                The input should be the original logits since it will be transformed by
                a sigmoid/softmax in the forward function.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            ValueError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``self.weight`` is a sequence and the length is not equal to the
                number of classes.
            ValueError: When ``self.weight`` is/contains a value that is less than 0.
        """
        n_pred_ch = input.shape[1]
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)
        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]
        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
        mask = create_ignore_mask(target, self.ignore_index)
        if mask is not None:
            # NOTE(review): zeroing logits via the mask maps them to sigmoid(0)=0.5, so
            # masked voxels still contribute a small constant loss rather than exactly 0
            # -- confirm this matches the intended `ignore_index` semantics.
            input = input * mask
            target = target * mask
        loss: torch.Tensor | None = None
        input = input.float()
        target = target.float()
        alpha_arg = self.alpha
        if self.use_softmax:
            if not self.include_background and self.alpha is not None:
                if isinstance(self.alpha, (float, int)):
                    # a scalar alpha encodes a background/foreground split, which is
                    # meaningless once the background channel has been dropped
                    alpha_arg = None
                    warnings.warn(
                        "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
                    )
            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
        else:
            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)

        num_of_classes = target.shape[1]
        if self.class_weight is not None and num_of_classes != 1:
            # Work on a local tensor: previously the registered buffer itself was
            # expanded and reshaped (and reassigned) on every forward call, which
            # mutated persistent module state, drifted the state_dict shape, and
            # is unsafe under DataParallel replication.
            if self.class_weight.ndim == 0:
                # a single scalar weight applies equally to every class
                class_weight = self.class_weight.repeat(num_of_classes)
            else:
                # make sure the lengths of weights are equal to the number of classes
                if self.class_weight.shape[0] != num_of_classes:
                    raise ValueError(
                        """the length of the `weight` sequence should be the same as the number of classes.
                        If `include_background=False`, the weight should not include
                        the background category class 0."""
                    )
                class_weight = self.class_weight
            if class_weight.min() < 0:
                raise ValueError("the value/values of the `weight` should be no less than 0.")
            # apply class_weight to loss, broadcasting over batch and spatial dims
            class_weight = class_weight.to(loss)
            broadcast_dims = [-1] + [1] * len(target.shape[2:])
            loss = class_weight.view(broadcast_dims) * loss

        if self.reduction == LossReduction.SUM.value:
            # Previously there was a mean over the last dimension, which did not
            # return a compatible BCE loss. To maintain backwards compatible
            # behavior we have a flag that performs this extra step, disable or
            # parameterize if necessary. (Or justify why the mean should be there)
            average_spatial_dims = True
            if average_spatial_dims:
                loss = loss.mean(dim=list(range(2, len(target.shape))))
            loss = loss.sum()
        elif self.reduction == LossReduction.MEAN.value:
            loss = loss.mean()
        elif self.reduction == LossReduction.NONE.value:
            pass
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
        return loss
def softmax_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
s_j is the unnormalized score for class j.
"""
input_ls = input.log_softmax(1)
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
if alpha is not None:
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
if alpha_t.ndim == 0: # scalar
alpha_val = alpha_t.item()
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
alpha_fac = alpha_t
broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_fac = alpha_fac.view(broadcast_dims)
loss = alpha_fac * loss
return loss
def sigmoid_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
"""
# computing binary cross entropy with logits
# equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
# see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
loss: torch.Tensor = input - input * target - F.logsigmoid(input)
# sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
# 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
# 1-p if t==1; p if t==0 <=>
# pfac, that is, the term (1 - pt)
invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow
# (pfac.log() * gamma).exp() <=>
# pfac.log().exp() ^ gamma <=>
# pfac ^ gamma
loss = (invprobs * gamma).exp() * loss
if alpha is not None:
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
if alpha_t.ndim == 0: # scalar
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
# Reshape alpha for broadcasting: (1, C, 1, 1...)
broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_t = alpha_t.view(broadcast_dims)
# Apply per-class weight only to positive samples
# For positive samples (target==1): multiply by alpha[c]
# For negative samples (target==0): keep weight as 1.0
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
loss = alpha_factor * loss
return loss