Skip to content

Commit 1732af2

Browse files
authored
Fix `convert_ids_to_logits()` not working with single-channel IDs, which grew into a broader set of refactors (#223)
* Refactored `convert_ids_to_logits` and `convert_logits_to_ids` for enhanced flexibility and clarity. (#222)
* Refactored `logitfy_no_grad` to simplify logic and ensure no-grad context only where necessary. (#222)
* Refactored `convert_ids_to_logits` to simplify shape modification logic. (#222)
* Fixed output handling in deep supervision for segmentation preset. (#222)
* Refactored `forward` methods in loss functions for improved readability and consistency. (#222)
* Refactored dice metric to remove double summation. (#222)
* Added `folder` method to `Dataset` class for retrieving folder attribute. (#222)
* Added validation to `convert_ids_to_logits` to enforce positive integer class IDs. (#222)
* Refactored type hints in `inspection.py` to explicitly use tuples for shape definitions. (#222)
* Added worst validation case logging during training for improved tracking. (#222)
* Refactored statistical shape computation in `inspection.py` for simplification and removed redundant caching. (#222)
* Refactored and renamed statistical shape computation in `inspection.py` for improved clarity and consistency. (#222)
* Added `TensorLoader` import to `__init__.py` in `data` module. (#222)
* Added `order` parameter to `JointTransform` for configurable transform application sequence. (#222)
* Refactored `JointTransform` to use `Literal`-based type hints for `order` to enhance type safety. (#222)
* Refactored `sanity_check` logic into a dedicated method `sanity_check` in `training.py` for improved modularity and code reuse. (#222)
* Refactored `sanity_check` to accept `template_model` as a parameter and adjusted its usage in `training.py` for improved flexibility and clarity. (#222)
1 parent 37a3cf4 commit 1732af2

9 files changed

Lines changed: 96 additions & 62 deletions

File tree

mipcandy/common/optim/loss.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,29 +37,11 @@ def __init__(self, num_classes: int, include_background: bool) -> None:
3737
self.include_background: bool = include_background
3838

3939
def logitfy_no_grad(self, ids: torch.Tensor) -> torch.Tensor:
40-
with torch.no_grad():
41-
if self.num_classes != 1 and ids.shape[1] == 1:
42-
if (d := ids.ndim - 2) not in (1, 2, 3):
43-
raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions")
44-
return convert_ids_to_logits(ids.int(), d, self.num_classes)
40+
if self.num_classes != 1 and ids.shape[1] == 1:
41+
with torch.no_grad():
42+
return convert_ids_to_logits(ids.int(), self.num_classes)
4543
return ids.float()
4644

47-
def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
48-
if not self.validation_mode:
49-
return self._forward(outputs, labels)
50-
with torch.no_grad():
51-
c, metrics = self._forward(outputs, labels)
52-
outputs = convert_logits_to_ids(outputs)
53-
dice = 0
54-
for i in range(0 if self.include_background else 1, self.num_classes):
55-
class_dice = binary_dice(outputs == i, labels == i).item()
56-
dice += class_dice
57-
metrics[f"dice {i}"] = class_dice
58-
metrics["dice"] = dice_similarity_coefficient(
59-
self.logitfy_no_grad(outputs), self.logitfy_no_grad(labels)
60-
).item()
61-
return c, metrics
62-
6345

6446
class DiceCELossWithLogits(_SegmentationLoss):
6547
def __init__(self, num_classes: int, *, lambda_ce: float = 1, lambda_soft_dice: float = 1,
@@ -81,6 +63,20 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
8163
c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - dice)
8264
return c, metrics
8365

66+
def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
67+
if not self.validation_mode:
68+
return self._forward(outputs, labels)
69+
with torch.no_grad():
70+
c, metrics = self._forward(outputs, labels)
71+
outputs = convert_logits_to_ids(outputs)
72+
for i in range(0 if self.include_background else 1, self.num_classes):
73+
class_dice = binary_dice(outputs == i, labels == i).item()
74+
metrics[f"dice {i}"] = class_dice
75+
metrics["dice"] = dice_similarity_coefficient(
76+
self.logitfy_no_grad(outputs), self.logitfy_no_grad(labels)
77+
).item()
78+
return c, metrics
79+
8480

8581
class DiceBCELossWithLogits(_SegmentationLoss):
8682
def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1,
@@ -99,3 +95,12 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
9995
metrics = {"soft dice": dice.item(), "bce loss": bce.item()}
10096
c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice)
10197
return c, metrics
98+
99+
def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
100+
if not self.validation_mode:
101+
return self._forward(outputs, labels)
102+
with torch.no_grad():
103+
c, metrics = self._forward(outputs, labels)
104+
outputs = convert_logits_to_ids(outputs).bool()
105+
metrics["dice"] = binary_dice(outputs, labels.bool()).item()
106+
return c, metrics

mipcandy/data/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from mipcandy.data.convertion import convert_ids_to_logits, convert_logits_to_ids, auto_convert
2-
from mipcandy.data.dataset import Loader, UnsupervisedDataset, SupervisedDataset, DatasetFromMemory, MergedDataset, \
3-
PathBasedUnsupervisedDataset, SimpleDataset, PathBasedSupervisedDataset, NNUNetDataset, BinarizedDataset
2+
from mipcandy.data.dataset import Loader, TensorLoader, UnsupervisedDataset, SupervisedDataset, DatasetFromMemory, \
3+
MergedDataset, PathBasedUnsupervisedDataset, SimpleDataset, PathBasedSupervisedDataset, NNUNetDataset, \
4+
BinarizedDataset
45
from mipcandy.data.download import download_dataset
56
from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop
67
from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \

mipcandy/data/convertion.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1-
from typing import Literal
2-
31
import torch
42

53
from mipcandy.common import Normalize
64

75

8-
def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: int) -> torch.Tensor:
9-
if ids.dtype != torch.int or ids.min() < 0:
10-
raise TypeError("`ids` should be positive integers")
11-
d += 1
12-
if ids.ndim != d:
13-
if ids.ndim == d + 1 and ids.shape[1] == 1:
14-
ids = ids.squeeze(1)
15-
else:
16-
raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel")
17-
logits = torch.zeros((ids.shape[0], num_classes, *ids.shape[1:]), device=ids.device, dtype=torch.float32)
18-
logits.scatter_(1, ids.unsqueeze(1).long(), 1)
6+
def convert_ids_to_logits(ids: torch.Tensor, num_classes: int, *, channel_dim: int = 1) -> torch.Tensor:
7+
"""
8+
:param ids: class ids (..., 1, ...)
9+
:param num_classes: number of classes
10+
:param channel_dim: the index of the channel dimension
11+
:return: logits (..., num_classes, ...)
12+
"""
13+
if torch.is_floating_point(ids) or (ids < 0).any():
14+
raise TypeError("Class ids must be positive integers")
15+
shape = list(ids.shape)
16+
shape[channel_dim] = num_classes
17+
logits = torch.zeros(shape, device=ids.device, dtype=torch.float32)
18+
logits.scatter_(channel_dim, ids.long(), 1)
1919
return logits
2020

2121

2222
def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:
23+
"""
24+
:param logits: logits (..., num_classes, ...)
25+
:param channel_dim: the index of the channel dimension
26+
:return: class ids (..., 1, ...)
27+
"""
2328
return logits.round().int() if logits.shape[channel_dim] < 2 else logits.argmax(channel_dim, keepdim=True)
2429

2530

mipcandy/data/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ def __init__(self, folder: str | PathLike[str], *, split: str | Literal["Tr", "T
339339
self._prefix: str = prefix
340340
self._align_spacing: bool = align_spacing
341341

342+
def folder(self) -> str:
343+
return self._folder
344+
342345
@staticmethod
343346
def _create_subset(folder: str) -> None:
344347
if exists(folder) and len(listdir(folder)) > 0:

mipcandy/data/inspection.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def save(self, path: str | PathLike[str]) -> None:
9090
}, f)
9191

9292
def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape]) -> tuple[
93-
AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
93+
tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
9494
depths = []
9595
widths = []
9696
heights = []
@@ -105,26 +105,29 @@ def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape
105105
widths.append(shape[2])
106106
return tuple(depths) if depths else None, tuple(heights), tuple(widths)
107107

108-
def shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
108+
def shapes(self) -> tuple[tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
109109
if self._shapes:
110110
return self._shapes
111111
self._shapes = self._get_shapes(lambda annotation: annotation.shape)
112112
return self._shapes
113113

114-
def foreground_shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
114+
def statistical_shape(self, *, percentile: float = .95) -> Shape:
115+
depths, heights, widths = self.shapes()
116+
percentile *= 100
117+
sfs = (round(np.percentile(heights, percentile)), round(np.percentile(widths, percentile)))
118+
return (round(np.percentile(depths, percentile)),) + sfs if depths else sfs
119+
120+
def foreground_shapes(self) -> tuple[tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
115121
if self._foreground_shapes:
116122
return self._foreground_shapes
117123
self._foreground_shapes = self._get_shapes(lambda annotation: annotation.foreground_shape())
118124
return self._foreground_shapes
119125

120126
def statistical_foreground_shape(self, *, percentile: float = .95) -> Shape:
121-
if self._statistical_foreground_shape:
122-
return self._statistical_foreground_shape
123127
depths, heights, widths = self.foreground_shapes()
124128
percentile *= 100
125129
sfs = (round(np.percentile(heights, percentile)), round(np.percentile(widths, percentile)))
126-
self._statistical_foreground_shape = (round(np.percentile(depths, percentile)),) + sfs if depths else sfs
127-
return self._statistical_foreground_shape
130+
return (round(np.percentile(depths, percentile)),) + sfs if depths else sfs
128131

129132
def crop_foreground(self, i: int, *, expand_ratio: float = 1) -> tuple[torch.Tensor, torch.Tensor]:
130133
image, label = self._dataset.image(i), self._dataset.label(i)
@@ -371,10 +374,10 @@ def __init__(self, annotations: InspectionAnnotations, batch_size: int, *, num_p
371374
self._images, self._labels = images, images.copy()
372375
self._batch_size: int = batch_size
373376
self._oversample_rate: float = oversample_rate
374-
sfs = self._annotations.statistical_foreground_shape(percentile=self._percentile)
375-
sfs = [ceil(s / min_factor) * min_factor for s in sfs]
376-
self._roi_shape: Shape = (min(sfs[0], 2048), min(sfs[1], 2048)) if len(sfs) == 2 else (
377-
min(sfs[0], 128), min(sfs[1], 128), min(sfs[2], 128))
377+
median_shape = self._annotations.statistical_shape(percentile=self._percentile)
378+
median_shape = [ceil(s / min_factor) * min_factor for s in median_shape]
379+
self._roi_shape: Shape = (min(median_shape[0], 2048), min(median_shape[1], 2048)) if len(
380+
median_shape) == 2 else (min(median_shape[0], 128), min(median_shape[1], 128), min(median_shape[2], 128))
378381

379382
def convert_idx(self, idx: int) -> int:
380383
idx, idx2 = self._images[idx], self._labels[idx]

mipcandy/data/transform.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,38 @@
1+
from typing import Literal
2+
13
import torch
24
from torch import nn
35

46
from mipcandy.types import Transform
57

8+
type _Order = Literal["transform", "image_only", "label_only"]
9+
610

711
class JointTransform(nn.Module):
812
def __init__(self, *, transform: Transform | None = None, image_only: Transform | None = None,
9-
label_only: Transform | None = None, keys: tuple[str, str] = ("image", "label")) -> None:
13+
label_only: Transform | None = None, keys: tuple[str, str] = ("image", "label"),
14+
order: tuple[_Order, _Order, _Order] = ("transform", "image_only", "label_only")) -> None:
1015
super().__init__()
1116
self.transform: Transform | None = transform
1217
self.image_only: Transform | None = image_only
1318
self.label_only: Transform | None = label_only
1419
self._keys: tuple[str, str] = keys
20+
self._order: tuple[_Order, _Order, _Order] = order
1521

1622
def forward(self, image: torch.Tensor, label: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
1723
ik, lk = self._keys
1824
data = {ik: image, lk: label}
19-
if self.transform:
20-
data = self.transform(data)
21-
if self.image_only:
22-
data[ik] = self.image_only(data[ik])
23-
if self.label_only:
24-
data[lk] = self.label_only(data[lk])
25+
for t in self._order:
26+
transform = getattr(self, t)
27+
if not transform:
28+
continue
29+
match t:
30+
case "transform":
31+
data = transform(data)
32+
case "image_only":
33+
data[ik] = transform(data[ik])
34+
case "label_only":
35+
data[lk] = transform(data[lk])
2536
return data[ik], data[lk]
2637

2738

mipcandy/metrics.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,4 @@ def soft_dice(outputs: torch.Tensor, labels: torch.Tensor, *, smooth: float = 1,
7979
label_sum = labels.sum(axes)
8080
intersection = (outputs * labels).sum(axes)
8181
output_sum = outputs.sum(axes)
82-
if batch_dice:
83-
intersection = intersection.sum(0)
84-
output_sum = output_sum.sum(0)
85-
label_sum = label_sum.sum(0)
8682
return do_reduction((2 * intersection + smooth) / (label_sum + output_sum + smooth), reduction)

mipcandy/presets/segmentation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, tool
158158
if self.deep_supervision:
159159
if not isinstance(toolbox.criterion, DeepSupervisionWrapper):
160160
raise TypeError("Deep supervision is enabled but criterion is not a `DeepSupervisionWrapper`")
161-
output = output[0] if isinstance(output, (list, tuple)) else output[:, 0]
161+
if isinstance(output, (list, tuple)):
162+
output = output[0]
163+
elif output.ndim > label.ndim:
164+
output = output[:, 0]
162165
loss, metrics = toolbox.criterion([output], [label])
163166
else:
164167
loss, metrics = toolbox.criterion(output, label)

mipcandy/training.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from mipcandy.frontend import Frontend
2828
from mipcandy.layer import WithPaddingModule, WithNetwork
2929
from mipcandy.profiler import Profiler
30-
from mipcandy.sanity_check import sanity_check
30+
from mipcandy.sanity_check import sanity_check, SanityCheckResult
3131
from mipcandy.types import Params, Setting, AmbiguousShape
3232

3333

@@ -391,6 +391,12 @@ def empty_cache(self) -> None:
391391

392392
# Training methods
393393

394+
def sanity_check(self, template_model: nn.Module, example_shape: AmbiguousShape) -> SanityCheckResult:
395+
try:
396+
return sanity_check(template_model, example_shape, device=self._device)
397+
finally:
398+
del template_model
399+
394400
@abstractmethod
395401
def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
396402
str, float]]:
@@ -454,7 +460,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
454460
template_model = self.build_network(example_shape)
455461
model_name = template_model.__class__.__name__
456462
self.log(f"Model: {model_name}")
457-
sanity_check_result = sanity_check(template_model, example_shape, device=self._device)
463+
sanity_check_result = self.sanity_check(template_model, example_shape)
458464
self.log(str(sanity_check_result))
459465
self.log(f"Example output shape: {tuple(sanity_check_result.output.shape)}")
460466
self.record_profiler()
@@ -467,7 +473,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
467473
self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note,
468474
sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs,
469475
early_stop_tolerance)
470-
del sanity_check_result, template_model, example_input
476+
del sanity_check_result, example_input
471477
self.empty_cache()
472478
try:
473479
for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs):
@@ -506,6 +512,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
506512
self.log(f"Estimated time of completion in {etc:.1f} seconds at {datetime.fromtimestamp(
507513
time() + etc):%m-%d %H:%M:%S}")
508514
self.show_metrics_per_case(epoch, metrics)
515+
self.log(f"Validation worst case: {self._tracker.worst_case}")
509516
self.show_metrics(epoch, metrics, "validation", lookup_prefix="val ")
510517
if score > self._tracker.best_score:
511518
copy(checkpoint_path("latest"), checkpoint_path("best"))

0 commit comments

Comments
 (0)