
Commit 0d8473e

Merge pull request #17 from lukasugar/asparagus_evals/segmentation
Asparagus evals - segmentation
2 parents 00d8b5e + 84e4dda commit 0d8473e

9 files changed

Lines changed: 404 additions & 10 deletions


pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -60,3 +60,6 @@ line-length = 100
 
 [tool.ruff.lint]
 ignore = ["F722"]
+
+[tool.pytest.ini_options]
+norecursedirs = ["third_party", ".scratch"]

src/asparagus_bridge/README.md

Lines changed: 66 additions & 1 deletion
@@ -143,7 +143,72 @@ can be used as a reference for Task 3 regression runs.
 
 #### Segmentation
 
-TBD
+##### Task 2
+Task 2 is `SEG009_FOMO26_Meningioma`; for sMRI MAE use the FLAIR-only custom variant `SEG009_FOMO26_Meningioma_FLAIR`.
+
+Prepare the raw FOMO task folders and convert them to asparagus tensors:
+
+```sh
+cd "$ASPARAGUS_SOURCE"
+unzip -n Task_2.zip -d Task_2
+
+cd /Users/lukasecerovic/Documents/repos/sMRI/smri-fm
+uv run asp_process \
+  --dataset SEG009_FOMO26_Meningioma_CUSTOM \
+  --task_name SEG009_FOMO26_Meningioma_FLAIR \
+  --modalities flair \
+  --save_as_tensor \
+  --num_workers 4
+```
+
+The segmentation processors write `split_80_10_10.json` and
+`TEST_80_10_10.json`. Override the asparagus segmentation defaults to use
+those splits when finetuning.
+
+Convert the sMRI MAE checkpoint:
+
+```sh
+uv run python -c 'from asparagus_bridge.checkpoint import convert_checkpoint; convert_checkpoint("smri_mae", "runs/mae/checkpoint-last.pth", "runs/mae/asparagus.ckpt")'
+```
+
+Task 2 smoke test:
+
+```sh
+uv run asp_finetune_seg --config-name finetuning/smoke_test_seg_task_2.yaml
+```
+
+##### Task 4
+Task 4 is `SEG010_FOMO26_TrigeminalNeuralgia` and is already single-channel T2w.
+
+```sh
+cd "$ASPARAGUS_SOURCE"
+unzip -n Task_4.zip -d Task_4
+
+cd <repo_root>
+
+uv run asp_process --dataset SEG010 --save_as_tensor --num_workers 4
+```
+
+Task 4 smoke test:
+
+```sh
+uv run asp_finetune_seg --config-name finetuning/smoke_test_seg_task_4.yaml
+```
+
+##### Notes on segmentation
+`SmriMaeSegBackbone` currently inherits asparagus/gardening-tools sliding-window
+inference. That path is compatible with asparagus eval, but it sums overlapping
+logits without overlap-count normalization or Gaussian/Hann weighting. Treat
+that as a known follow-up if segmentation quality near patch borders matters.
+
+Future segmentation variants worth testing:
+
+- Canonical Task 2 two-channel finetuning (`flair`, `dwi`) with explicit
+  multi-channel checkpoint stem adaptation.
+- MAE reconstruction decoder reuse instead of the current Primus-like patch
+  segmentation decoder.
+- Normalized sliding-window blending with overlap-count averaging, Gaussian
+  weighting, or Hann weighting.
 
 #### Linear probing
 
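The README notes in this hunk say the inherited sliding-window path sums overlapping logits without overlap-count normalization. The following is a minimal 1D sketch of what overlap-count averaging could look like, in plain Python. The names `window_starts`, `blend`, and the toy `constant_logit` predictor are hypothetical and exist only for illustration; this is not the asparagus or gardening-tools implementation.

```python
from typing import Callable, List


def window_starts(length: int, window: int, stride: int) -> List[int]:
    """Start offsets covering [0, length); the last window sits flush with the end."""
    if not 0 < stride <= window <= length:
        raise ValueError("need 0 < stride <= window <= length")
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def blend(
    length: int,
    window: int,
    stride: int,
    predict: Callable[[int, int], List[float]],
    normalize: bool = True,
) -> List[float]:
    """Accumulate per-window predictions; optionally divide by the overlap count."""
    acc = [0.0] * length
    count = [0] * length
    for start in window_starts(length, window, stride):
        patch = predict(start, window)
        for offset, value in enumerate(patch):
            acc[start + offset] += value
            count[start + offset] += 1
    if not normalize:
        return acc
    return [a / c for a, c in zip(acc, count)]


def constant_logit(start: int, window: int) -> List[float]:
    """Toy predictor (assumption): every position in every window gets logit 1.0."""
    return [1.0] * window


summed = blend(6, 4, 2, constant_logit, normalize=False)
averaged = blend(6, 4, 2, constant_logit, normalize=True)
print(summed)    # [1.0, 1.0, 2.0, 2.0, 1.0, 1.0]
print(averaged)  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

With plain summation, the two positions covered by both windows end up with doubled logits. Dividing by the count restores a uniform scale. A Gaussian or Hann variant would accumulate a per-position weight instead of incrementing the count by 1. Note that dividing by a positive count does not change which class has the largest logit at a given position, so the effect of this particular normalization shows up mainly where logit magnitudes are used directly, for example in softmax probabilities or thresholds.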

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# @package _global_
+defaults:
+  - /default_finetune_seg
+  - /model/smri_mae@model
+  - _self_
+
+# hydra:
+#   job_logging:
+#     root:
+#       level: DEBUG
+
+task: SEG009_FOMO26_Meningioma_FLAIR
+
+checkpoint_path: /Users/lukasecerovic/Documents/repos/sMRI/smri-fm/.scratch/pretrained_mae_checkpoint/checkpoint-last-asparagus-NEW.ckpt
+
+data:
+  train_split: split_80_10_10
+  test_split: TEST_80_10_10
+
+hardware:
+  num_workers: 2
+  precision: 32-true
+  compile_mode: null
+
+training:
+  epochs: 1
+  batch_size: 1
+  patch_size: [64, 64, 64]
+  train_batches_per_epoch_per_device: 2
+  val_batches_per_epoch_per_device: 1
+  check_val_every_n_epoch: 1
+
+model:
+  _seg_net:
+    patch_size: 8
+
+logger:
+  wandb_logging: false
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# @package _global_
+defaults:
+  - /default_finetune_seg
+  - /model/smri_mae@model
+  - _self_
+
+task: SEG010_FOMO26_TrigeminalNeuralgia
+
+checkpoint_path: /Users/lukasecerovic/Documents/repos/sMRI/smri-fm/.scratch/pretrained_mae_checkpoint/checkpoint-last-asparagus-NEW.ckpt
+
+data:
+  train_split: split_80_10_10
+  test_split: TEST_80_10_10
+
+hardware:
+  num_workers: 2
+  precision: 32-true
+  compile_mode: null
+
+training:
+  epochs: 1
+  batch_size: 1
+  patch_size: [64, 64, 64]
+  train_batches_per_epoch_per_device: 2
+  val_batches_per_epoch_per_device: 1
+  check_val_every_n_epoch: 1
+
+model:
+  _seg_net:
+    patch_size: 8
+
+logger:
+  wandb_logging: false

src/asparagus_bridge/configs/model/smri_mae.yaml

Lines changed: 5 additions & 1 deletion
@@ -8,10 +8,14 @@ seg_net: smri_mae
 cls_net: smri_mae
 reg_net: smri_mae
 
-# Segmentation: placeholder.
 _seg_net:
   _target_: asparagus_bridge.models_smri_mae.SmriMaeSegBackbone
   dimensions: ${model.dimensions}
+  img_size: ${training.patch_size}
+  patch_size: 16
+  depth: 12
+  embed_dim: 768
+  num_heads: 12
 
 _cls_net:
   _target_: asparagus_bridge.models_smri_mae.SmriMaeClsRegBackbone

src/asparagus_bridge/models_smri_mae.py

Lines changed: 60 additions & 6 deletions
@@ -7,6 +7,10 @@
 - weights loaded later by asparagus.BaseModule via load_state_dict(strict=False)
 """
 
+import math
+
+from gardening_tools.modules.networks.BaseNet import BaseNet
+from gardening_tools.modules.networks.components.transformer import PatchDecode
 import torch.nn as nn
 from torch import Tensor
 
@@ -62,18 +66,68 @@ def _features(self, x: Tensor) -> Tensor:
     def forward(self, x: Tensor) -> Tensor:
         """Encoder + head """
         return self.head(self._features(x))
-    
+
     def _encode(self, x: Tensor) -> Tensor:
         """ Encoder output in format used for linear probing"""
         feat = self._features(x)
         return feat[:, :, None, None, None]
 
 
-class SmriMaeSegBackbone(nn.Module):
-    """Placeholder for ViT-based segmentation backbone."""
+class SmriMaeSegBackbone(BaseNet):
+    """MAE ViT segmentation backbone with a Primus-like patch decoder."""
+
+    def __init__(
+        self,
+        input_channels: int,
+        output_channels: int,
+        img_size: int | tuple[int, int, int] = (160, 160, 160),
+        patch_size: int | tuple[int, int, int] = (16, 16, 16),
+        depth: int = 12,
+        embed_dim: int = 768,
+        num_heads: int = 12,
+        dimensions: str = "3D",
+        **_ignored,
+    ):
+        super().__init__()
+        assert dimensions == "3D", f"only 3D supported, got dimensions={dimensions}"
+
+        self.num_classes = output_channels
+        self.stem_weight_name = "encoder.patch_embed.weight"
 
-    def __init__(self, *args, **kwargs):
-        raise NotImplementedError("SmriMaeSegBackbone is not yet implemented")
+        self.encoder = MaskedViT(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=input_channels,
+            depth=depth,
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            class_token=True,
+        )
+        self.grid_size = self.encoder.patchify.grid_size
+        self.decoder = PatchDecode(
+            patch_size=self.encoder.patchify.patch_size,
+            embed_dim=embed_dim,
+            out_channels=output_channels,
+        )
 
     def forward(self, x: Tensor) -> Tensor:
-        raise NotImplementedError
+        _, _, patch_embeds, _, _ = self.encoder(x)
+        expected_tokens = math.prod(self.grid_size)
+        if patch_embeds.shape[1] != expected_tokens:
+            raise ValueError(
+                "unexpected MAE patch token count: "
+                f"got {patch_embeds.shape[1]}, expected {expected_tokens}"
+            )
+
+        features = patch_embeds.reshape(
+            x.shape[0],
+            *self.grid_size,
+            patch_embeds.shape[-1],
+        )
+        features = features.permute(0, 4, 1, 2, 3).contiguous()
+        return self.decoder(features)
+
+    # Inherits BaseNet.sliding_window_predict. That implementation sums
+    # overlapping logits without normalizing by an overlap-count map.
+    # TODO: replace inherited sliding-window accumulation with normalized overlap averaging.
+    # TODO: consider Gaussian/Hann weighting so patch centers contribute more than patch borders.
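The token-count check in `forward` can be illustrated with plain arithmetic. This sketch assumes that `MaskedViT` derives its patch grid as `img_size // patch_size` per axis and that patch tokens are ordered row-major over that grid, which is the usual ViT convention but is not shown in this diff. `patch_grid` and `token_to_grid_index` are hypothetical helper names, not part of the repository.

```python
import math
from typing import Tuple

Triple = Tuple[int, int, int]


def patch_grid(img_size: Triple, patch_size: Triple) -> Triple:
    """Patches per axis, assuming non-overlapping patches that tile the volume exactly."""
    if any(i % p != 0 for i, p in zip(img_size, patch_size)):
        raise ValueError(f"img_size {img_size} is not divisible by patch_size {patch_size}")
    d, h, w = (i // p for i, p in zip(img_size, patch_size))
    return (d, h, w)


def token_to_grid_index(token: int, grid: Triple) -> Triple:
    """Row-major grid position of a flat token index, as implied by reshape(B, *grid, C)."""
    _, h, w = grid
    return (token // (h * w), (token // w) % h, token % w)


# Smoke-test configs: training.patch_size [64, 64, 64] with model patch_size 8.
smoke_grid = patch_grid((64, 64, 64), (8, 8, 8))
# Constructor defaults: img_size 160 and patch_size 16 per axis.
default_grid = patch_grid((160, 160, 160), (16, 16, 16))

print(smoke_grid, math.prod(smoke_grid))      # (8, 8, 8) 512
print(default_grid, math.prod(default_grid))  # (10, 10, 10) 1000
print(token_to_grid_index(73, smoke_grid))    # (1, 1, 1)
```

Under these assumptions, the smoke-test configuration expects 512 patch tokens per volume and the constructor defaults expect 1000. A mismatch, for example if the class token were included in `patch_embeds`, would trip the `ValueError` before the reshape.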
