Skip to content

Commit 8cbf45a

Browse files
deploy changes
1 parent 317a854 commit 8cbf45a

11 files changed

Lines changed: 61 additions & 39 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,4 @@ tail_log.bash
194194
CLAUDE.md
195195

196196
.vscode/
197+
debug_logs

asparagus/functional/metrics/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def compute_alignment_uniformity(
102102
# Flatten spatial dimensions if present
103103
if features.dim() > 2:
104104
B = features.shape[0]
105-
features = features.view(B, -1).float() # (B, D)
105+
features = features.reshape(B, -1).float() # (B, D)
106106
else:
107107
features = features.float()
108108

asparagus/functional/metrics/features.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77
from typing import Dict
88

99

10+
def _to_channel_samples(features: torch.Tensor) -> torch.Tensor:
11+
"""Reshape [B, C, *spatial] → [B*N_spatial, C] for channel-space analysis.
12+
2D inputs [B, C] are returned unchanged."""
13+
if features.dim() <= 2:
14+
return features
15+
C = features.shape[1]
16+
spatial_dims = list(range(2, features.dim()))
17+
return features.permute(0, *spatial_dims, 1).reshape(-1, C)
18+
19+
1020
def compute_train(encoder_features: torch.Tensor) -> Dict[str, float]:
1121
"""Metrics computed every training step."""
1222
return compute_embedding_metrics(encoder_features)
@@ -35,10 +45,7 @@ def compute_feature_covariance(features: torch.Tensor) -> Dict[str, float]:
3545
if features is None:
3646
return {}
3747

38-
# Flatten spatial dimensions if present
39-
if features.dim() > 2:
40-
B = features.shape[0]
41-
features = features.view(B, -1)
48+
features = _to_channel_samples(features)
4249

4350
features_centered = features - features.mean(dim=0, keepdim=True)
4451
cov = torch.mm(features_centered.T, features_centered) / (features.shape[0] - 1)
@@ -84,9 +91,7 @@ def compute_collapse_score(features: torch.Tensor, eps: float = 1e-8) -> Dict[st
8491
if features is None or features.numel() == 0:
8592
return {}
8693

87-
if features.dim() > 2:
88-
B = features.shape[0]
89-
features = features.view(B, -1)
94+
features = _to_channel_samples(features)
9095

9196
dim_variance = features.var(dim=0, unbiased=False)
9297

@@ -135,11 +140,7 @@ def compute_participation_ratio(features: torch.Tensor, k_values: list = [10, 50
135140
if features is None or features.numel() == 0:
136141
return {}
137142

138-
if features.dim() > 2:
139-
B = features.shape[0]
140-
features = features.view(B, -1).float()
141-
else:
142-
features = features.float()
143+
features = _to_channel_samples(features).float()
143144

144145
features_centered = features - features.mean(dim=0, keepdim=True)
145146

@@ -183,11 +184,7 @@ def compute_whitening_diagnostics(features: torch.Tensor) -> Dict[str, float]:
183184
if features is None or features.numel() == 0:
184185
return {}
185186

186-
if features.dim() > 2:
187-
B = features.shape[0]
188-
features = features.view(B, -1).float()
189-
else:
190-
features = features.float()
187+
features = _to_channel_samples(features).float()
191188

192189
# Compute correlation matrix
193190
features_centered = features - features.mean(dim=0, keepdim=True)

asparagus/functional/metrics/stability.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ def compute_feature_stability(
105105
# Flatten spatial dimensions if present
106106
if current_features.dim() > 2:
107107
B = current_features.shape[0]
108-
current_features = current_features.view(B, -1).float()
108+
current_features = current_features.reshape(B, -1).float()
109109
else:
110110
current_features = current_features.float()
111111

112112
if previous_features.dim() > 2:
113113
B = previous_features.shape[0]
114-
previous_features = previous_features.view(B, -1).float()
114+
previous_features = previous_features.reshape(B, -1).float()
115115
else:
116116
previous_features = previous_features.float()
117117

asparagus/modules/networks/primus.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,11 @@ def freeze_backbone(self):
176176

177177

178178
@depends_on_timm()
179-
def primus_s(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
179+
def primus_s(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
180180
model = Primus(
181181
input_channels=input_channels,
182182
embed_dim=396,
183-
patch_embed_size=patch_embed_size,
183+
patch_embed_size=(patch_embed_size,) * len(patch_size),
184184
num_classes=output_channels,
185185
eva_depth=12,
186186
eva_numheads=6,
@@ -193,11 +193,11 @@ def primus_s(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
193193

194194

195195
@depends_on_timm()
196-
def primus_b(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
196+
def primus_b(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
197197
model = Primus(
198198
input_channels=input_channels,
199199
embed_dim=792,
200-
patch_embed_size=patch_embed_size,
200+
patch_embed_size=(patch_embed_size,) * len(patch_size),
201201
num_classes=output_channels,
202202
eva_depth=12,
203203
eva_numheads=12,
@@ -211,11 +211,11 @@ def primus_b(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
211211

212212

213213
@depends_on_timm()
214-
def primus_m(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
214+
def primus_m(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
215215
model = Primus(
216216
input_channels=input_channels,
217217
embed_dim=864,
218-
patch_embed_size=patch_embed_size,
218+
patch_embed_size=(patch_embed_size,) * len(patch_size),
219219
num_classes=output_channels,
220220
eva_depth=16,
221221
eva_numheads=12,
@@ -229,11 +229,11 @@ def primus_m(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
229229

230230

231231
@depends_on_timm()
232-
def primus_l(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
232+
def primus_l(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
233233
model = Primus(
234234
input_channels=input_channels,
235235
embed_dim=1056,
236-
patch_embed_size=patch_embed_size,
236+
patch_embed_size=(patch_embed_size,) * len(patch_size),
237237
num_classes=output_channels,
238238
eva_depth=24,
239239
eva_numheads=16,
@@ -247,11 +247,11 @@ def primus_l(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
247247

248248

249249
@depends_on_timm()
250-
def primus_h(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
250+
def primus_h(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
251251
model = Primus(
252252
input_channels=input_channels,
253253
embed_dim=1248,
254-
patch_embed_size=patch_embed_size,
254+
patch_embed_size=(patch_embed_size,) * len(patch_size),
255255
num_classes=output_channels,
256256
eva_depth=32,
257257
eva_numheads=16,
@@ -265,11 +265,11 @@ def primus_h(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
265265

266266

267267
@depends_on_timm()
268-
def primus_g(input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), patch_drop_rate=0.0):
268+
def primus_g(input_channels, output_channels, patch_size, patch_embed_size=8, patch_drop_rate=0.0):
269269
model = Primus(
270270
input_channels=input_channels,
271271
embed_dim=1584,
272-
patch_embed_size=patch_embed_size,
272+
patch_embed_size=(patch_embed_size,) * len(patch_size),
273273
num_classes=output_channels,
274274
eva_depth=32,
275275
eva_numheads=24,
@@ -284,13 +284,13 @@ def primus_g(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
284284

285285
@depends_on_timm()
286286
def primus_m_clsreg(
287-
input_channels, output_channels, patch_size, patch_embed_size=(8, 8, 8), dropout_rate=0.0, late_fusion: bool = False
287+
input_channels, output_channels, patch_size, patch_embed_size=8, dropout_rate=0.0, late_fusion: bool = False
288288
):
289289
return PrimusCLSREG(
290290
input_channels=input_channels,
291291
output_channels=output_channels,
292292
embed_dim=864,
293-
patch_embed_size=patch_embed_size,
293+
patch_embed_size=(patch_embed_size,) * len(patch_size),
294294
eva_depth=16,
295295
eva_numheads=12,
296296
input_shape=patch_size,

configs/default_pretrain.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ training:
2525
accumulate_grad_batches: 1
2626
patch_size: [160, 160, 160]
2727
seed: ${random:0,1000000}
28-
mask_patch_size: 4
28+
mask_patch_size: ${model.patch_embed_size}
2929
mask_ratio: 0.6
3030
max_samples: 6_000_000
3131
warmup_ratio: 0.02

configs/model/core/primus.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,33 @@ _pretrain_net:
22
_target_: asparagus.modules.networks.primus.${model.pretrain_net}
33
patch_size: ${training.patch_size}
44
patch_drop_rate: ${training.mask_ratio}
5+
patch_embed_size: ${model.patch_embed_size}
56

67
_seg_net:
78
_target_: asparagus.modules.networks.primus.${model.seg_net}
89
patch_size: ${training.patch_size}
910
patch_drop_rate: 0.0
11+
patch_embed_size: ${model.patch_embed_size}
1012

1113
_cls_net:
1214
_target_: asparagus.modules.networks.primus.${model.cls_net}
13-
patch_size: ${training.target_size}
15+
patch_size: ${training.patch_size}
16+
patch_drop_rate: 0.0
17+
patch_embed_size: ${model.patch_embed_size}
1418

1519
_plugin_seg_net:
1620
_target_: asparagus.modules.networks.primus.${model.plugin_seg_net}
1721
patch_size: ${training.patch_size}
1822
patch_drop_rate: 0.0
23+
patch_embed_size: ${model.patch_embed_size}
1924

2025
pretrain_optim: AdamW
2126
pretrain_lr: 3e-4
2227
train_optim: AdamW
2328
train_lr: 3e-4
2429
finetune_optim: AdamW
2530
finetune_lr: 3e-5
31+
patch_embed_size: 8
2632

2733
weight_decay: 5e-2
2834
nesterov: False

configs/model/core/resenc_unet.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ finetune_lr: 1e-3
3232
weight_decay: 3e-5
3333
nesterov: True
3434
momentum: 0.99
35+
patch_embed_size: 4
3536

3637
min_test_patch_size: [96, 96, 96]
3738
deep_supervision: False

configs/model/core/unet.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ finetune_lr: 1e-3
2626
weight_decay: 3e-5
2727
nesterov: True
2828
momentum: 0.99
29+
patch_embed_size: 4
2930

3031
use_skip_connections: True
3132
deep_supervision: False
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# @package _global_
2+
defaults:
3+
- /default_pretrain
4+
- /model/primus_m@model
5+
- /hardware/1node8gpus@hardware
6+
- _self_
7+
8+
task: PT900_FOMO300K
9+
root: datapaper
10+
stem: pretrain
11+
12+
checkpoint_run_id:
13+
14+
training:
15+
batch_size: 16
16+
max_samples: 6_000_000

0 commit comments

Comments
 (0)