|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import torch |
| 15 | +import torch.nn as nn |
| 16 | +from torch.nn.modules.loss import _Loss |
| 17 | + |
| 18 | +from monai.utils import LossReduction |
| 19 | + |
| 20 | + |
| 21 | +class AUCMLoss(_Loss): |
| 22 | + """ |
| 23 | + AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC. |
| 24 | +
|
| 25 | + The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints |
| 26 | + on positive and negative predictions. It supports two versions: 'v1' includes class prior |
| 27 | + information, while 'v2' removes this dependency for better generalization. |
| 28 | +
|
| 29 | + Reference: |
| 30 | + Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao. |
| 31 | + "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification." |
| 32 | + Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021. |
| 33 | + https://arxiv.org/abs/2012.03173 |
| 34 | +
|
| 35 | + Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py |
| 36 | +
|
| 37 | + Example: |
| 38 | + >>> import torch |
| 39 | + >>> from monai.losses import AUCMLoss |
| 40 | + >>> loss_fn = AUCMLoss() |
| 41 | + >>> input = torch.randn(32, 1, requires_grad=True) |
| 42 | + >>> target = torch.randint(0, 2, (32, 1)).float() |
| 43 | + >>> loss = loss_fn(input, target) |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + margin: float = 1.0, |
| 49 | + imratio: float | None = None, |
| 50 | + version: str = "v1", |
| 51 | + reduction: LossReduction | str = LossReduction.MEAN, |
| 52 | + ) -> None: |
| 53 | + """ |
| 54 | + Args: |
| 55 | + margin: margin for squared-hinge surrogate loss (default: ``1.0``). |
| 56 | + imratio: the ratio of the number of positive samples to the number of total samples in the training dataset. |
| 57 | + If this value is not given, it will be automatically calculated with mini-batch samples. |
| 58 | + This value is ignored when ``version`` is set to ``'v2'``. |
| 59 | + version: whether to include prior class information in the objective function (default: ``'v1'``). |
| 60 | + 'v1' includes class prior, 'v2' removes this dependency. |
| 61 | + reduction: {``"none"``, ``"mean"``, ``"sum"``} |
| 62 | + Specifies the reduction to apply to the output. Defaults to ``"mean"``. |
| 63 | + Note: This loss is computed at the batch level and always returns a scalar. |
| 64 | + The reduction parameter is accepted for API consistency but has no effect. |
| 65 | +
|
| 66 | + Raises: |
| 67 | + ValueError: When ``version`` is not one of ["v1", "v2"]. |
| 68 | + ValueError: When ``imratio`` is not in [0, 1]. |
| 69 | +
|
| 70 | + Example: |
| 71 | + >>> import torch |
| 72 | + >>> from monai.losses import AUCMLoss |
| 73 | + >>> loss_fn = AUCMLoss(version='v2') |
| 74 | + >>> input = torch.randn(32, 1, requires_grad=True) |
| 75 | + >>> target = torch.randint(0, 2, (32, 1)).float() |
| 76 | + >>> loss = loss_fn(input, target) |
| 77 | + """ |
| 78 | + super().__init__(reduction=LossReduction(reduction).value) |
| 79 | + if version not in ["v1", "v2"]: |
| 80 | + raise ValueError(f"version should be 'v1' or 'v2', got {version}") |
| 81 | + if imratio is not None and not (0.0 <= imratio <= 1.0): |
| 82 | + raise ValueError(f"imratio must be in [0, 1], got {imratio}") |
| 83 | + self.margin = margin |
| 84 | + self.imratio = imratio |
| 85 | + self.version = version |
| 86 | + self.a = nn.Parameter(torch.tensor(0.0)) |
| 87 | + self.b = nn.Parameter(torch.tensor(0.0)) |
| 88 | + self.alpha = nn.Parameter(torch.tensor(0.0)) |
| 89 | + |
| 90 | + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
| 91 | + """ |
| 92 | + Args: |
| 93 | + input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification. |
| 94 | + target: the shape should be B1HW[D], with values 0 or 1. |
| 95 | +
|
| 96 | + Returns: |
| 97 | + torch.Tensor: scalar AUCM loss. |
| 98 | +
|
| 99 | + Raises: |
| 100 | + ValueError: When input or target have incorrect shapes. |
| 101 | + ValueError: When input or target have fewer than 2 dimensions. |
| 102 | + ValueError: When target contains non-binary values. |
| 103 | + """ |
| 104 | + if input.ndim < 2 or target.ndim < 2: |
| 105 | + raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)") |
| 106 | + if input.shape[1] != 1: |
| 107 | + raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") |
| 108 | + if target.shape[1] != 1: |
| 109 | + raise ValueError(f"Target should have 1 channel, got {target.shape[1]}") |
| 110 | + if input.shape != target.shape: |
| 111 | + raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}") |
| 112 | + |
| 113 | + input = input.flatten() |
| 114 | + target = target.flatten() |
| 115 | + |
| 116 | + if input.numel() == 0: |
| 117 | + raise ValueError("Input and target must contain at least one element.") |
| 118 | + |
| 119 | + if not torch.all((target == 0) | (target == 1)): |
| 120 | + raise ValueError("Target must contain only binary values (0 or 1)") |
| 121 | + |
| 122 | + pos_mask = (target == 1).float() |
| 123 | + neg_mask = (target == 0).float() |
| 124 | + |
| 125 | + mean_pos_sq = (input - self.a) ** 2 |
| 126 | + mean_neg_sq = (input - self.b) ** 2 |
| 127 | + |
| 128 | + # Note: |
| 129 | + # v1 uses global expectations (normalized by total number of samples), |
| 130 | + # following the original LibAUC implementation. |
| 131 | + # v2 uses class-conditional expectations (normalized by number of samples |
| 132 | + # in each class), implemented via non-zero averaging. |
| 133 | + # These behaviors differ and should not be unified. |
| 134 | + if self.version == "v1": |
| 135 | + p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item()) |
| 136 | + p1 = 1.0 - p |
| 137 | + |
| 138 | + mean_pos = self._global_mean(mean_pos_sq, pos_mask) |
| 139 | + mean_neg = self._global_mean(mean_neg_sq, neg_mask) |
| 140 | + |
| 141 | + interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask) |
| 142 | + |
| 143 | + loss = ( |
| 144 | + p1 * mean_pos |
| 145 | + + p * mean_neg |
| 146 | + + 2 * self.alpha * (p * p1 * self.margin + interaction) |
| 147 | + - p * p1 * self.alpha**2 |
| 148 | + ) |
| 149 | + |
| 150 | + else: # v2 |
| 151 | + mean_pos = self._class_mean(mean_pos_sq, pos_mask) |
| 152 | + mean_neg = self._class_mean(mean_neg_sq, neg_mask) |
| 153 | + |
| 154 | + mean_input_pos = self._class_mean(input, pos_mask) |
| 155 | + mean_input_neg = self._class_mean(input, neg_mask) |
| 156 | + |
| 157 | + loss = ( |
| 158 | + mean_pos + mean_neg + 2 * self.alpha * (self.margin + mean_input_neg - mean_input_pos) - self.alpha**2 |
| 159 | + ) |
| 160 | + |
| 161 | + return loss |
| 162 | + |
| 163 | + def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
| 164 | + """ |
| 165 | + Compute the global mean of a masked tensor. |
| 166 | +
|
| 167 | + This computes the mean over all elements, where values outside the mask |
| 168 | + are zeroed out. The result is normalized by the total number of elements, |
| 169 | + not by the number of masked elements. |
| 170 | +
|
| 171 | + This corresponds to a global expectation: |
| 172 | + E[mask * tensor] |
| 173 | +
|
| 174 | + Args: |
| 175 | + tensor: Input tensor. |
| 176 | + mask: Binary mask tensor of the same shape as ``tensor``. |
| 177 | +
|
| 178 | + Returns: |
| 179 | + Scalar tensor representing the global mean. |
| 180 | + """ |
| 181 | + masked = tensor * mask |
| 182 | + if masked.numel() == 0: |
| 183 | + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) |
| 184 | + return masked.mean() |
| 185 | + |
| 186 | + def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
| 187 | + """ |
| 188 | + Compute the class-conditional mean of a masked tensor. |
| 189 | +
|
| 190 | + This computes the mean over only the masked (non-zero) elements, i.e., |
| 191 | + the result is normalized by the number of masked elements. |
| 192 | +
|
| 193 | + This corresponds to a class-conditional expectation: |
| 194 | + E[tensor | mask] |
| 195 | +
|
| 196 | + Args: |
| 197 | + tensor: Input tensor. |
| 198 | + mask: Binary mask tensor of the same shape as ``tensor``. |
| 199 | +
|
| 200 | + Returns: |
| 201 | + Scalar tensor representing the class-conditional mean. |
| 202 | + Returns 0 if no elements are selected by the mask. |
| 203 | + """ |
| 204 | + denom = mask.sum() |
| 205 | + if denom.item() == 0: |
| 206 | + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) |
| 207 | + return (tensor * mask).sum() / denom |
0 commit comments