Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
b573cf2
Refactored `convert_ids_to_logits` and `convert_logits_to_ids` for en…
ATATC Feb 19, 2026
75ff257
Refactored `logitfy_no_grad` to simplify logic and ensure no-grad con…
ATATC Feb 19, 2026
b509191
Refactored `convert_ids_to_logits` to simplify shape modification log…
ATATC Feb 19, 2026
730db6e
Fixed output handling in deep supervision for segmentation preset. (#…
ATATC Feb 19, 2026
504bfcc
Refactored `forward` methods in loss functions for improved readabili…
ATATC Feb 19, 2026
e9251e8
Refactored dice metric to remove double summation. (#222)
ATATC Feb 19, 2026
ffcc803
Added `folder` method to `Dataset` class for retrieving folder attrib…
ATATC Feb 19, 2026
93dc0a3
Added validation to `convert_ids_to_logits` to enforce positive integ…
ATATC Feb 19, 2026
36c3584
Refactored type hints in `inspection.py` to explicitly use tuples for…
ATATC Feb 20, 2026
73ec891
Added worst validation case logging during training for improved trac…
ATATC Feb 20, 2026
0b5e9c4
Refactored statistical shape computation in `inspection.py` for simpl…
ATATC Feb 20, 2026
b5732b2
Refactored and renamed statistical shape computation in `inspection.p…
ATATC Feb 20, 2026
2b4bcc7
Added `TensorLoader` import to `__init__.py` in `data` module. (#222)
ATATC Feb 20, 2026
0710f5a
Added `order` parameter to `JointTransform` for configurable transfor…
ATATC Feb 20, 2026
e9cd384
Refactored `JointTransform` to use `Literal`-based type hints for `or…
ATATC Feb 20, 2026
5fabd42
Refactored `sanity_check` logic into a dedicated method `sanity_check…
ATATC Feb 21, 2026
d5005eb
Refactored `sanity_check` to accept `template_model` as a parameter a…
ATATC Feb 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions mipcandy/common/optim/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,11 @@ def __init__(self, num_classes: int, include_background: bool) -> None:
self.include_background: bool = include_background

def logitfy_no_grad(self, ids: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
if self.num_classes != 1 and ids.shape[1] == 1:
if (d := ids.ndim - 2) not in (1, 2, 3):
raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions")
return convert_ids_to_logits(ids.int(), d, self.num_classes)
if self.num_classes != 1 and ids.shape[1] == 1:
with torch.no_grad():
return convert_ids_to_logits(ids.int(), self.num_classes)
return ids.float()

def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
if not self.validation_mode:
return self._forward(outputs, labels)
with torch.no_grad():
c, metrics = self._forward(outputs, labels)
outputs = convert_logits_to_ids(outputs)
dice = 0
for i in range(0 if self.include_background else 1, self.num_classes):
class_dice = binary_dice(outputs == i, labels == i).item()
dice += class_dice
metrics[f"dice {i}"] = class_dice
metrics["dice"] = dice_similarity_coefficient(
self.logitfy_no_grad(outputs), self.logitfy_no_grad(labels)
).item()
return c, metrics


class DiceCELossWithLogits(_SegmentationLoss):
def __init__(self, num_classes: int, *, lambda_ce: float = 1, lambda_soft_dice: float = 1,
Expand All @@ -81,6 +63,20 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - dice)
return c, metrics

def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
if not self.validation_mode:
return self._forward(outputs, labels)
with torch.no_grad():
c, metrics = self._forward(outputs, labels)
outputs = convert_logits_to_ids(outputs)
for i in range(0 if self.include_background else 1, self.num_classes):
class_dice = binary_dice(outputs == i, labels == i).item()
metrics[f"dice {i}"] = class_dice
metrics["dice"] = dice_similarity_coefficient(
self.logitfy_no_grad(outputs), self.logitfy_no_grad(labels)
).item()
return c, metrics


class DiceBCELossWithLogits(_SegmentationLoss):
def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1,
Expand All @@ -99,3 +95,12 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
metrics = {"soft dice": dice.item(), "bce loss": bce.item()}
c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice)
return c, metrics

def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
if not self.validation_mode:
return self._forward(outputs, labels)
with torch.no_grad():
c, metrics = self._forward(outputs, labels)
outputs = convert_logits_to_ids(outputs).bool()
metrics["dice"] = binary_dice(outputs, labels.bool()).item()
return c, metrics
29 changes: 16 additions & 13 deletions mipcandy/data/convertion.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
from typing import Literal

import torch

from mipcandy.common import Normalize


def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: int) -> torch.Tensor:
if ids.dtype != torch.int or ids.min() < 0:
raise TypeError("`ids` should be positive integers")
d += 1
if ids.ndim != d:
if ids.ndim == d + 1 and ids.shape[1] == 1:
ids = ids.squeeze(1)
else:
raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel")
logits = torch.zeros((ids.shape[0], num_classes, *ids.shape[1:]), device=ids.device, dtype=torch.float32)
logits.scatter_(1, ids.unsqueeze(1).long(), 1)
def convert_ids_to_logits(ids: torch.Tensor, num_classes: int, *, channel_dim: int = 1) -> torch.Tensor:
"""
:param ids: class ids (..., 1, ...)
:param num_classes: number of classes
:param channel_dim: the index of the channel dimension
:return: logits (..., num_classes, ...)
"""
shape = list(ids.shape)
shape[channel_dim] = num_classes
logits = torch.zeros(shape, device=ids.device, dtype=torch.float32)
logits.scatter_(channel_dim, ids.long(), 1)
return logits


def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:
"""
:param logits: logits (..., num_classes, ...)
:param channel_dim: the index of the channel dimension
:return: class ids (..., 1, ...)
"""
return logits.round().int() if logits.shape[channel_dim] < 2 else logits.argmax(channel_dim, keepdim=True)


Expand Down
4 changes: 0 additions & 4 deletions mipcandy/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,4 @@ def soft_dice(outputs: torch.Tensor, labels: torch.Tensor, *, smooth: float = 1,
label_sum = labels.sum(axes)
intersection = (outputs * labels).sum(axes)
output_sum = outputs.sum(axes)
if batch_dice:
intersection = intersection.sum(0)
output_sum = output_sum.sum(0)
label_sum = label_sum.sum(0)
return do_reduction((2 * intersection + smooth) / (label_sum + output_sum + smooth), reduction)
5 changes: 4 additions & 1 deletion mipcandy/presets/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, tool
if self.deep_supervision:
if not isinstance(toolbox.criterion, DeepSupervisionWrapper):
raise TypeError("Deep supervision is enabled but criterion is not a `DeepSupervisionWrapper`")
output = output[0] if isinstance(output, (list, tuple)) else output[:, 0]
if isinstance(output, (list, tuple)):
output = output[0]
elif output.ndim > label.ndim:
output = output[:, 0]
loss, metrics = toolbox.criterion([output], [label])
else:
loss, metrics = toolbox.criterion(output, label)
Expand Down