Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 165 additions & 111 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,218 +23,272 @@
class AsymmetricFocalTverskyLoss(_Loss):
    """
    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on the foreground classes.

    The background class (channel 0) is treated differently from all foreground classes (channels 1...N):
    the background term is the plain Tversky loss ``1 - TI``, while every foreground term is
    ``(1 - TI) ** (1 - gamma)``.

    Reimplementation of the Asymmetric Focal Tversky Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced
      Medical Image Segmentation", Michael Yeung, Computerized Medical Imaging and Graphics
    """

    def __init__(
        self,
        to_onehot_y: bool = False,
        include_background: bool = True,
        delta: float = 0.7,
        gamma: float = 0.75,
        epsilon: float = 1e-7,
        reduction: LossReduction | str = LossReduction.MEAN,
    ) -> None:
        """
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            include_background: whether to include loss computation for the background class
                (channel 0). Defaults to True.
            delta: weight of the false negatives in the Tversky index; false positives are weighted
                by ``1 - delta``. Defaults to 0.7.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
            epsilon: small constant used to clamp the predictions into ``[epsilon, 1 - epsilon]`` and
                to smooth the Tversky ratio. Defaults to 1e-7.
            reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.include_background = include_background
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: predicted probabilities, channel-first; the code indexes dim 1 as the class axis
                and reduces over every dim from 2 onwards.
            y_true: ground truth with the same shape as `y_pred`, or a single-channel label map
                when `to_onehot_y=True`.

        Returns:
            A scalar for "mean"/"sum" reduction, or the per-sample, per-class losses of shape
            (B, C) for "none".

        Raises:
            ValueError: when `y_true` and `y_pred` have different shapes (after optional one-hot).
            ValueError: when `self.reduction` is not one of "mean", "sum", "none".
        """
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

        # clip the prediction to avoid NaN
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)

        # Exclude background if needed
        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                y_pred = y_pred[:, 1:]
                y_true = y_true[:, 1:]

        axis = list(range(2, len(y_pred.shape)))

        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = torch.sum(y_true * y_pred, dim=axis)
        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
        # dice_class shape is (B, C)
        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

        if not self.include_background or dice_class.shape[1] == 1:
            # No background channel present (excluded, or single-channel input):
            # every channel gets the foreground formula.
            loss = torch.pow(1.0 - dice_class, 1.0 - self.gamma)  # (B, C)
        else:
            # Asymmetric logic: channel 0 is background, the others are foreground.
            back_dice_loss = 1.0 - dice_class[:, :1]  # (B, 1)
            fore_dice_loss = torch.pow(1.0 - dice_class[:, 1:], 1.0 - self.gamma)  # (B, C-1)
            loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)  # (B, C)

        # Apply reduction
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)  # mean over batch and classes
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)  # sum over batch and classes
        if self.reduction == LossReduction.NONE.value:
            return loss  # returns (B, C) losses
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')


class AsymmetricFocalLoss(_Loss):
    """
    AsymmetricFocalLoss is a variant of FocalLoss that focuses on the foreground classes.

    The background class (channel 0) is treated differently from all foreground classes (channels 1...N):

    - background class (0): cross entropy scaled by ``(1 - delta) * (1 - p) ** gamma``
    - foreground classes (1..N): cross entropy scaled by ``delta``, no gamma exponent

    Reimplementation of the Asymmetric Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced
      Medical Image Segmentation", Michael Yeung, Computerized Medical Imaging and Graphics
    """

    def __init__(
        self,
        to_onehot_y: bool = False,
        include_background: bool = True,
        delta: float = 0.7,
        gamma: float = 2.0,
        epsilon: float = 1e-7,
        reduction: LossReduction | str = LossReduction.MEAN,
    ):
        """
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            include_background: whether to include loss computation for the background class
                (channel 0). Defaults to True.
            delta: weight of the foreground classes; ``1 - delta`` is the weight of the background.
                Defaults to 0.7.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
            epsilon: small constant used to clamp the predictions into ``[epsilon, 1 - epsilon]``
                before taking the logarithm. Defaults to 1e-7.
            reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.include_background = include_background
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: predicted probabilities, channel-first; dim 1 is indexed as the class axis.
            y_true: ground truth with the same shape as `y_pred`, or a single-channel label map
                when `to_onehot_y=True`.

        Returns:
            A scalar for "mean"/"sum" reduction, or the element-wise loss with the same shape as
            the (possibly background-stripped) prediction for "none".

        Raises:
            ValueError: when `y_true` and `y_pred` have different shapes (after optional one-hot).
            ValueError: when `self.reduction` is not one of "mean", "sum", "none".
        """
        n_pred_ch = y_pred.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred.shape:
            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

        # Exclude background if needed
        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                y_pred = y_pred[:, 1:]
                y_true = y_true[:, 1:]

        # clip the prediction so that log() stays finite
        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred)  # same shape as y_pred

        if not self.include_background or y_pred.shape[1] == 1:
            # No background channel present (excluded, or single-channel input):
            # every channel gets the foreground weighting.
            loss = self.delta * cross_entropy
        else:
            # Asymmetric logic: channel 0 is background, the others are foreground.
            back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, :1], self.gamma) * cross_entropy[:, :1]
            fore_ce = self.delta * cross_entropy[:, 1:]
            loss = torch.cat([back_ce, fore_ce], dim=1)

        # Apply reduction
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)  # mean over batch, class, and spatial dims
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)
        if self.reduction == LossReduction.NONE.value:
            return loss  # element-wise losses, same shape as the prediction
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')


class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss, combining AsymmetricFocalLoss
    and AsymmetricFocalTverskyLoss as
    ``lambda_focal * AsymmetricFocalLoss + (1 - lambda_focal) * AsymmetricFocalTverskyLoss``.

    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced
      Medical Image Segmentation", Michael Yeung, Computerized Medical Imaging and Graphics
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        lambda_focal: float = 0.5,
        focal_loss_gamma: float = 2.0,
        focal_loss_delta: float = 0.7,
        tversky_loss_gamma: float = 0.75,
        tversky_loss_delta: float = 0.7,
        reduction: LossReduction | str = LossReduction.MEAN,
    ):
        """
        Args:
            include_background: whether to include loss computation for the background class. Defaults to True.
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            sigmoid: if True, apply a sigmoid activation to the input y_pred.
            softmax: if True, apply a softmax activation to the input y_pred.
            lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
                The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
            focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
            focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
            tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
            tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
            reduction: specifies the reduction to apply to the output: "none", "mean", "sum".

        Raises:
            ValueError: when both `sigmoid` and `softmax` are True.

        Example:
            >>> import torch
            >>> from monai.losses import AsymmetricUnifiedFocalLoss
            >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
            >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
            >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
            >>> fl(pred, grnd)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if sigmoid and softmax:
            raise ValueError("Both sigmoid and softmax cannot be True.")

        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.lambda_focal = lambda_focal

        # Activation and one-hot conversion are handled once in `forward`, so the two
        # component losses are built with their default `to_onehot_y=False`.
        self.asy_focal_loss = AsymmetricFocalLoss(
            include_background=self.include_background,
            gamma=focal_loss_gamma,
            delta=focal_loss_delta,
            reduction=self.reduction,
        )
        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
            include_background=self.include_background,
            gamma=tversky_loss_gamma,
            delta=tversky_loss_delta,
            reduction=self.reduction,
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: the shape should be BNH[WD]. Activated by sigmoid/softmax here when the
                corresponding flag is set; otherwise it is passed on unchanged.
            y_true: the shape should be BNH[WD] or B1H[WD]; a label map without a channel
                dimension, BH[WD], gets one inserted at dim 1.

        Returns:
            The weighted sum of the two component losses. For "mean"/"sum" reduction both
            components are already reduced. For "none" the (B, C) Tversky term is broadcast
            over the spatial dims of the element-wise focal term.

        Raises:
            ValueError: when `y_true` and `y_pred` still differ in shape after activation and
                one-hot conversion.
        """
        n_pred_ch = y_pred.shape[1]

        y_pred_act = y_pred
        if self.sigmoid:
            y_pred_act = torch.sigmoid(y_pred)
        elif self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, softmax=True ignored.")
            else:
                y_pred_act = torch.softmax(y_pred, dim=1)

        # Decide on rank, not on shape[1]: for a channel-less label map shape[1] is a
        # spatial size and may coincide with 1 or with the number of classes.
        if y_true.ndim == y_pred_act.ndim - 1:
            y_true = y_true.unsqueeze(1)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                # One-hot encoding into a single channel is meaningless (it would mark every
                # background voxel as foreground), so skip it even when sigmoid is used.
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                y_true = one_hot(y_true, num_classes=n_pred_ch)

        if y_true.shape != y_pred_act.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
                f"after activations/one-hot"
            )

        f_loss = self.asy_focal_loss(y_pred_act, y_true)
        t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true)

        if self.reduction == LossReduction.NONE.value:
            # Unreduced focal loss is (B, C, spatial...) while the Tversky loss is (B, C);
            # append singleton dims so the two can be added.
            t_loss = t_loss.reshape(tuple(t_loss.shape) + (1,) * (f_loss.ndim - t_loss.ndim))

        loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss

        return loss
2 changes: 1 addition & 1 deletion monai/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N

# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.steps_offset

Expand Down
Loading
Loading