Skip to content

Commit ef2acfb

Browse files
shubham-61969pre-commit-ci[bot]ericspod
authored
4609: Add AUC-Margin Loss for AUROC optimization (#8719)
Fixes #4609. ### Description This PR adds an implementation of **AUC-Margin Loss (AUCM)** for direct AUROC optimization in MONAI, based on: > [Yuan et al., *“Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification”*, ICCV 2021.](https://arxiv.org/abs/2012.03173) Implementation based on: [https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py](https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py) The loss is designed for imbalanced classification problems, which are common in medical imaging, where AUROC is often the primary evaluation metric. The implementation follows MONAI’s loss conventions, is fully PyTorch-native, and does not introduce any new dependencies. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 063081f commit ef2acfb

3 files changed

Lines changed: 419 additions & 0 deletions

File tree

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from .adversarial_loss import PatchAdversarialLoss
15+
from .aucm_loss import AUCMLoss
1516
from .barlow_twins import BarlowTwinsLoss
1617
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1718
from .contrastive import ContrastiveLoss

monai/losses/aucm_loss.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import torch
15+
import torch.nn as nn
16+
from torch.nn.modules.loss import _Loss
17+
18+
from monai.utils import LossReduction
19+
20+
21+
class AUCMLoss(_Loss):
22+
"""
23+
AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC.
24+
25+
The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints
26+
on positive and negative predictions. It supports two versions: 'v1' includes class prior
27+
information, while 'v2' removes this dependency for better generalization.
28+
29+
Reference:
30+
Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao.
31+
"Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification."
32+
Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
33+
https://arxiv.org/abs/2012.03173
34+
35+
Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py
36+
37+
Example:
38+
>>> import torch
39+
>>> from monai.losses import AUCMLoss
40+
>>> loss_fn = AUCMLoss()
41+
>>> input = torch.randn(32, 1, requires_grad=True)
42+
>>> target = torch.randint(0, 2, (32, 1)).float()
43+
>>> loss = loss_fn(input, target)
44+
"""
45+
46+
def __init__(
47+
self,
48+
margin: float = 1.0,
49+
imratio: float | None = None,
50+
version: str = "v1",
51+
reduction: LossReduction | str = LossReduction.MEAN,
52+
) -> None:
53+
"""
54+
Args:
55+
margin: margin for squared-hinge surrogate loss (default: ``1.0``).
56+
imratio: the ratio of the number of positive samples to the number of total samples in the training dataset.
57+
If this value is not given, it will be automatically calculated with mini-batch samples.
58+
This value is ignored when ``version`` is set to ``'v2'``.
59+
version: whether to include prior class information in the objective function (default: ``'v1'``).
60+
'v1' includes class prior, 'v2' removes this dependency.
61+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
62+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
63+
Note: This loss is computed at the batch level and always returns a scalar.
64+
The reduction parameter is accepted for API consistency but has no effect.
65+
66+
Raises:
67+
ValueError: When ``version`` is not one of ["v1", "v2"].
68+
ValueError: When ``imratio`` is not in [0, 1].
69+
70+
Example:
71+
>>> import torch
72+
>>> from monai.losses import AUCMLoss
73+
>>> loss_fn = AUCMLoss(version='v2')
74+
>>> input = torch.randn(32, 1, requires_grad=True)
75+
>>> target = torch.randint(0, 2, (32, 1)).float()
76+
>>> loss = loss_fn(input, target)
77+
"""
78+
super().__init__(reduction=LossReduction(reduction).value)
79+
if version not in ["v1", "v2"]:
80+
raise ValueError(f"version should be 'v1' or 'v2', got {version}")
81+
if imratio is not None and not (0.0 <= imratio <= 1.0):
82+
raise ValueError(f"imratio must be in [0, 1], got {imratio}")
83+
self.margin = margin
84+
self.imratio = imratio
85+
self.version = version
86+
self.a = nn.Parameter(torch.tensor(0.0))
87+
self.b = nn.Parameter(torch.tensor(0.0))
88+
self.alpha = nn.Parameter(torch.tensor(0.0))
89+
90+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
91+
"""
92+
Args:
93+
input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification.
94+
target: the shape should be B1HW[D], with values 0 or 1.
95+
96+
Returns:
97+
torch.Tensor: scalar AUCM loss.
98+
99+
Raises:
100+
ValueError: When input or target have incorrect shapes.
101+
ValueError: When input or target have fewer than 2 dimensions.
102+
ValueError: When target contains non-binary values.
103+
"""
104+
if input.ndim < 2 or target.ndim < 2:
105+
raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)")
106+
if input.shape[1] != 1:
107+
raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
108+
if target.shape[1] != 1:
109+
raise ValueError(f"Target should have 1 channel, got {target.shape[1]}")
110+
if input.shape != target.shape:
111+
raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}")
112+
113+
input = input.flatten()
114+
target = target.flatten()
115+
116+
if input.numel() == 0:
117+
raise ValueError("Input and target must contain at least one element.")
118+
119+
if not torch.all((target == 0) | (target == 1)):
120+
raise ValueError("Target must contain only binary values (0 or 1)")
121+
122+
pos_mask = (target == 1).float()
123+
neg_mask = (target == 0).float()
124+
125+
mean_pos_sq = (input - self.a) ** 2
126+
mean_neg_sq = (input - self.b) ** 2
127+
128+
# Note:
129+
# v1 uses global expectations (normalized by total number of samples),
130+
# following the original LibAUC implementation.
131+
# v2 uses class-conditional expectations (normalized by number of samples
132+
# in each class), implemented via non-zero averaging.
133+
# These behaviors differ and should not be unified.
134+
if self.version == "v1":
135+
p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
136+
p1 = 1.0 - p
137+
138+
mean_pos = self._global_mean(mean_pos_sq, pos_mask)
139+
mean_neg = self._global_mean(mean_neg_sq, neg_mask)
140+
141+
interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask)
142+
143+
loss = (
144+
p1 * mean_pos
145+
+ p * mean_neg
146+
+ 2 * self.alpha * (p * p1 * self.margin + interaction)
147+
- p * p1 * self.alpha**2
148+
)
149+
150+
else: # v2
151+
mean_pos = self._class_mean(mean_pos_sq, pos_mask)
152+
mean_neg = self._class_mean(mean_neg_sq, neg_mask)
153+
154+
mean_input_pos = self._class_mean(input, pos_mask)
155+
mean_input_neg = self._class_mean(input, neg_mask)
156+
157+
loss = (
158+
mean_pos + mean_neg + 2 * self.alpha * (self.margin + mean_input_neg - mean_input_pos) - self.alpha**2
159+
)
160+
161+
return loss
162+
163+
def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
164+
"""
165+
Compute the global mean of a masked tensor.
166+
167+
This computes the mean over all elements, where values outside the mask
168+
are zeroed out. The result is normalized by the total number of elements,
169+
not by the number of masked elements.
170+
171+
This corresponds to a global expectation:
172+
E[mask * tensor]
173+
174+
Args:
175+
tensor: Input tensor.
176+
mask: Binary mask tensor of the same shape as ``tensor``.
177+
178+
Returns:
179+
Scalar tensor representing the global mean.
180+
"""
181+
masked = tensor * mask
182+
if masked.numel() == 0:
183+
return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
184+
return masked.mean()
185+
186+
def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
187+
"""
188+
Compute the class-conditional mean of a masked tensor.
189+
190+
This computes the mean over only the masked (non-zero) elements, i.e.,
191+
the result is normalized by the number of masked elements.
192+
193+
This corresponds to a class-conditional expectation:
194+
E[tensor | mask]
195+
196+
Args:
197+
tensor: Input tensor.
198+
mask: Binary mask tensor of the same shape as ``tensor``.
199+
200+
Returns:
201+
Scalar tensor representing the class-conditional mean.
202+
Returns 0 if no elements are selected by the mask.
203+
"""
204+
denom = mask.sum()
205+
if denom.item() == 0:
206+
return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
207+
return (tensor * mask).sum() / denom

0 commit comments

Comments
 (0)