Skip to content

Commit 8cd4911

Browse files
deploy changes
1 parent 806d505 commit 8cd4911

12 files changed

Lines changed: 789 additions & 16 deletions

File tree

asparagus/functional/pos_embed.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from einops import rearrange
4+
5+
6+
def interpolate_patch_embed_3d(patch_embed, in_shape, out_shape):
7+
"""Resizes patch embeddings using 3D trilinear interpolation.
8+
9+
Copied from SSL3D_classification/models/eva_mae_openneuro.py
10+
"""
11+
patch_embed = patch_embed.permute(0, 2, 1)
12+
patch_embed = rearrange(patch_embed, "B C (x y z) -> B C x y z", **in_shape)
13+
patch_embed = F.interpolate(patch_embed, size=list(out_shape.values()), mode="trilinear", align_corners=False)
14+
patch_embed = rearrange(patch_embed, "B C x y z -> B C (x y z)", **out_shape)
15+
return patch_embed.permute(0, 2, 1)
16+
17+
18+
def resize_pos_embed_3d(
19+
ckpt_pos_embed, model_pos_embed, num_prefix_tokens, pretrained_target_size, target_size, patch_embed_size
20+
):
21+
"""Resize a pos_embed tensor to match the model's expected shape.
22+
23+
Separates prefix tokens (cls/register), applies 3D trilinear interpolation
24+
to the patch tokens, and reattaches the prefix.
25+
"""
26+
if num_prefix_tokens > 0:
27+
prefix = ckpt_pos_embed[:, :num_prefix_tokens, :]
28+
patch_pos_embed = ckpt_pos_embed[:, num_prefix_tokens:, :]
29+
else:
30+
prefix = None
31+
patch_pos_embed = ckpt_pos_embed
32+
33+
in_shape = {
34+
"x": pretrained_target_size[0] // patch_embed_size[0],
35+
"y": pretrained_target_size[1] // patch_embed_size[1],
36+
"z": pretrained_target_size[2] // patch_embed_size[2],
37+
}
38+
39+
out_shape = {
40+
"x": target_size[0] // patch_embed_size[0],
41+
"y": target_size[1] // patch_embed_size[1],
42+
"z": target_size[2] // patch_embed_size[2],
43+
}
44+
45+
orig_dtype = patch_pos_embed.dtype
46+
resized = interpolate_patch_embed_3d(patch_pos_embed.float(), in_shape, out_shape).to(orig_dtype)
47+
48+
if prefix is not None:
49+
return torch.cat([prefix, resized], dim=1)
50+
return resized

asparagus/modules/data_modules/training.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
val_transforms: Optional[Compose] = None,
112112
test_transforms: Optional[Compose] = None,
113113
test_samples: Optional[list] = [],
114+
use_random_datasampler: Optional[bool] = True,
114115
):
115116
super().__init__()
116117
self.batch_size = batch_size
@@ -121,6 +122,7 @@ def __init__(
121122
self.train_split = train_split
122123
self.val_split = val_split
123124
self.test_samples = test_samples
125+
self.use_random_datasampler = use_random_datasampler
124126
logging.info(f"Using {self.num_workers} workers")
125127

126128
def setup(self, stage: Literal["fit", "test", "predict"]):
@@ -149,9 +151,10 @@ def setup_test(self):
149151
)
150152

151153
def train_dataloader(self):
152-
sampler = RandomSampler(self.train_dataset, num_samples=999999, replacement=True)
153-
if dist.is_initialized():
154-
sampler = DistributedSamplerWrapper(sampler)
154+
sampler = None
155+
if self.use_random_datasampler:
156+
sampler = RandomSampler(self.train_dataset, num_samples=999999, replacement=True)
157+
sampler = DistributedSamplerWrapper(sampler) if dist.is_initialized() else sampler
155158

156159
return DataLoader(
157160
self.train_dataset,
@@ -160,23 +163,25 @@ def train_dataloader(self):
160163
pin_memory=False,
161164
persistent_workers=True,
162165
drop_last=True,
166+
shuffle=sampler is None,
163167
sampler=sampler,
164168
)
165169

166170
def val_dataloader(self):
167-
sampler = RandomSampler(self.val_dataset, num_samples=999999, replacement=True)
168-
if dist.is_initialized():
169-
sampler = DistributedSamplerWrapper(sampler)
171+
sampler = None
172+
if self.use_random_datasampler:
173+
sampler = RandomSampler(self.val_dataset, num_samples=999999, replacement=True)
174+
sampler = DistributedSamplerWrapper(sampler) if dist.is_initialized() else sampler
170175

171176
return DataLoader(
172177
self.val_dataset,
173178
num_workers=self.num_workers // 2,
174179
batch_size=self.batch_size,
175180
pin_memory=False,
176-
shuffle=False,
177181
persistent_workers=True,
178-
drop_last=True,
182+
drop_last=False,
179183
sampler=sampler,
184+
shuffle=False,
180185
)
181186

182187
def test_dataloader(self):

asparagus/modules/lightning_modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .clsreg_module import ClassificationModule, RegressionModule
2+
from .linear_probe_module import LinearProbeModule
23
from .segmentation_module import SegmentationModule
34
from .self_supervised import SelfSupervisedModule
45

@@ -7,4 +8,5 @@
78
"ClassificationModule",
89
"RegressionModule",
910
"SelfSupervisedModule",
11+
"LinearProbeModule",
1012
]

asparagus/modules/lightning_modules/base_module.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
separate_encoder_decoder_weights,
1111
simple_warmup_cosine_decay_schedule,
1212
)
13+
from asparagus.functional.pos_embed import resize_pos_embed_3d
1314
from asparagus.functional.visualization import (
1415
get_logger_compatible_image_output_target,
1516
log_image_output_target_to_mlflow,
@@ -39,12 +40,16 @@ def __init__(
3940
nesterov: bool = True,
4041
momentum: float = 0.99,
4142
repeat_stem_weights: bool = True,
43+
pretrained_target_size: Optional[tuple] = None,
44+
target_size: Optional[tuple] = None,
4245
):
4346
super().__init__()
4447
self.learning_rate = learning_rate
4548
self.train_transforms = train_transforms
4649
self.test_transforms = test_transforms
4750
self.val_transforms = val_transforms
51+
self.pretrained_target_size = pretrained_target_size
52+
self.target_size = target_size
4853

4954
self.loss = None
5055
self.train_metrics = None
@@ -168,6 +173,24 @@ def load_state_dict(self, state_dict, load_decoder=True, *args, **kwargs):
168173
print(f"Repeating stem weights from {pt_input_channels} to {ft_input_channels} channels for {stem_name}.")
169174
state_dict[stem_name] = state_dict[stem_name].repeat(1, ft_input_channels, 1, 1, 1) / ft_input_channels
170175

176+
# Interpolate positional embeddings when spatial dimensions differ
177+
if self.pretrained_target_size is not None and self.target_size is not None:
178+
for key in list(state_dict.keys()):
179+
if key not in old_params or old_params[key].shape == state_dict[key].shape:
180+
continue
181+
if key.endswith("pos_embed"):
182+
num_prefix_tokens = getattr(self.model.eva, "num_prefix_tokens", 0)
183+
patch_embed_size = tuple(self.model.encoder.proj.weight.shape[2:])
184+
print(f"Interpolating {key}: {state_dict[key].shape} -> {old_params[key].shape}")
185+
state_dict[key] = resize_pos_embed_3d(
186+
state_dict[key],
187+
old_params[key],
188+
num_prefix_tokens=num_prefix_tokens,
189+
pretrained_target_size=self.pretrained_target_size,
190+
target_size=self.target_size,
191+
patch_embed_size=patch_embed_size,
192+
)
193+
171194
# Filter out keys that are not in the old state dict or have different shapes
172195
def should_load_key(key, state_dict, old_params, load_decoder):
173196
# reject all decoder keys regardless of their shape

0 commit comments

Comments
 (0)