Skip to content

Commit 0d17124

Browse files
deploy changes
1 parent fe723a7 commit 0d17124

14 files changed

Lines changed: 149 additions & 34 deletions

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Download and load models from HuggingFace Hub."""
2+
3+
from huggingface_hub import hf_hub_download, list_repo_files
4+
5+
6+
def download_hf_checkpoint(repo_id: str) -> str:
7+
"""Download a checkpoint file from HuggingFace Hub. Returns local path."""
8+
all_files = list_repo_files(repo_id)
9+
checkpoint_files = [f for f in all_files if f.endswith((".ckpt", ".pt", ".pth"))]
10+
11+
if not checkpoint_files:
12+
raise ValueError(f"No checkpoint files found in {repo_id}. Available: {all_files}")
13+
14+
filename = checkpoint_files[0]
15+
print(f"Downloading {filename} from {repo_id}...")
16+
17+
return hf_hub_download(repo_id, filename)
18+
19+
20+
class HuggingFaceWeightMapper:
21+
"""Remaps weights to asparagus format."""
22+
23+
def __init__(self, state_dict: dict):
24+
self.state_dict = state_dict
25+
26+
def remap_keys(self) -> dict:
27+
"""Add 'model.' prefix for Lightning module compatibility. Subclasses should override and call super()."""
28+
first_key = next(iter(self.state_dict.keys()))
29+
if not first_key.startswith("model."):
30+
return {f"model.{k}": v for k, v in self.state_dict.items()}
31+
return self.state_dict
32+
33+
34+
class OpenMindResEncWeightMapper(HuggingFaceWeightMapper):
35+
"""Remaps OpenMind ResEncUNet weights to asparagus format."""
36+
37+
def remap_keys(self) -> dict:
38+
original_keys = self.state_dict.keys()
39+
self.state_dict = {
40+
k.replace(".convs.0.", ".conv1.").replace(".norm.", ".norm_op."): v for k, v in self.state_dict.items()
41+
}
42+
if self.state_dict.keys() != original_keys:
43+
print("Remapped OpenMind ResEncUNet keys to asparagus naming conventions.")
44+
45+
return super().remap_keys()
46+
47+
48+
class OpenMindPrimusWeightMapper(HuggingFaceWeightMapper):
49+
"""Remaps OpenMind Primus weights to asparagus format."""
50+
51+
def remap_keys(self) -> dict:
52+
original_keys = self.state_dict.keys()
53+
self.state_dict = {
54+
k.replace("encoder.eva.", "eva.")
55+
.replace("encoder.down_projection.proj.", "encoder.proj.")
56+
.replace("encoder.mask_token", "mask_token"): v
57+
for k, v in self.state_dict.items()
58+
}
59+
if self.state_dict.keys() != original_keys:
60+
print("Remapped OpenMind Primus keys to asparagus naming conventions.")
61+
62+
return super().remap_keys()
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from dataclasses import dataclass
2+
from typing import Optional
23

34

45
@dataclass
56
class PathingConfig:
67
run_dir: str
78
ckpt_save_dir: str
8-
ckpt_path: str
9-
ckpt_parent_folder: str
9+
ckpt_path: Optional[str]
10+
ckpt_parent_folder: Optional[str]
1011
dataset_json_path: str

asparagus/modules/lightning_modules/base_module.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ def __init__(
2929
decoder_warmup_epochs: int = 0,
3030
cosine_period_ratio: float = 1,
3131
compile_mode: str = None,
32-
weights: str = None,
32+
weights: dict = None,
3333
load_decoder: bool = True,
34-
repeat_stem_weights: bool = True,
3534
optimizer: str = "SGD",
3635
train_transforms: Optional[transforms.Compose] = None,
3736
test_transforms: Optional[transforms.Compose] = None,
3837
val_transforms: Optional[transforms.Compose] = None,
3938
weight_decay: float = 3e-5,
4039
nesterov: bool = True,
4140
momentum: float = 0.99,
41+
repeat_stem_weights: bool = True,
4242
):
4343
super().__init__()
4444
self.learning_rate = learning_rate
@@ -60,11 +60,11 @@ def __init__(
6060
self.repeat_stem_weights = repeat_stem_weights
6161
assert 0 < cosine_period_ratio <= 1
6262

63-
self.save_hyperparameters(ignore=["model", "train_transforms", "val_transforms", "test_transforms"])
63+
self.save_hyperparameters(ignore=["model", "weights", "train_transforms", "val_transforms", "test_transforms"])
6464
self.model = model
6565

6666
if weights is not None:
67-
self.load_weights(weights, load_decoder=load_decoder)
67+
self.load_state_dict(weights, load_decoder=load_decoder, strict=False)
6868

6969
self.model = torch.compile(model, mode=compile_mode) if compile_mode is not None else model
7070

@@ -143,11 +143,6 @@ def configure_optimizers(self):
143143

144144
return [optimizer], [scheduler_config]
145145

146-
def load_weights(self, weights, load_decoder=True):
147-
ckpt = torch.load(weights, map_location="cpu", weights_only=False)
148-
print(f"Loading weights trained for {ckpt['global_step']} steps / {ckpt['epoch']} epochs.")
149-
self.load_state_dict(ckpt["state_dict"], load_decoder=load_decoder, strict=False)
150-
151146
def load_state_dict(self, state_dict, load_decoder=True, *args, **kwargs):
152147
old_params = copy.deepcopy(self.state_dict())
153148

@@ -161,10 +156,10 @@ def load_state_dict(self, state_dict, load_decoder=True, *args, **kwargs):
161156
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
162157

163158
# Repeat stem weights when state_dict num_channels is smaller than new_state_dict num_channels
164-
if self.model.stem_weight_name is not None and self.repeat_stem_weights:
159+
if hasattr(self.model, "stem_weight_name") and self.model.stem_weight_name is not None and self.repeat_stem_weights:
165160
prefix = "model._orig_mod." if "_orig_mod" in list(state_dict.keys())[0] else "model."
166161
stem_name = f"{prefix}{self.model.stem_weight_name}"
167-
pt_input_channels = state_dict[stem_name].shape[1] # (N, C, H, W, Z) where N is num tokens.
162+
pt_input_channels = state_dict[stem_name].shape[1]
168163
ft_input_channels = old_params[stem_name].shape[1]
169164
if pt_input_channels < ft_input_channels:
170165
assert pt_input_channels == 1, (

asparagus/modules/lightning_modules/clsreg_module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
decoder_warmup_epochs: int = 0,
2424
cosine_period_ratio: float = 1,
2525
compile_mode: str = None,
26-
weights: str = None,
26+
weights: dict = None,
2727
optimizer: str = "SGD",
2828
train_transforms: Optional[transforms.Compose] = None,
2929
test_transforms: Optional[transforms.Compose] = None,
@@ -75,6 +75,7 @@ def training_step(self, batch, batch_idx):
7575

7676
pred = self.model(x)
7777
loss = self.loss(pred, y)
78+
7879
self.log(
7980
"train/loss", loss, on_step=False, on_epoch=True, sync_dist=True, batch_size=self.trainer.datamodule.batch_size
8081
)

asparagus/modules/lightning_modules/segmentation_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
decoder_warmup_epochs: int = 0,
4545
cosine_period_ratio: float = 1,
4646
compile_mode: str = None,
47-
weights: str = None,
47+
weights: dict = None,
4848
deep_supervision: bool = False,
4949
train_transforms: Optional[transforms.Compose] = None,
5050
test_transforms: Optional[transforms.Compose] = None,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import torch
3+
from asparagus.functional.huggingface import download_hf_checkpoint
4+
from asparagus.functional.versioning import detect_id
5+
from hydra.utils import get_class
6+
7+
8+
def load_checkpoint_state_dict(path):
9+
"""Load a checkpoint file and return the state_dict."""
10+
ckpt = torch.load(path, map_location="cpu", weights_only=False)
11+
12+
if "state_dict" in ckpt:
13+
print(f"Loading weights trained for {ckpt.get('global_step', '?')} steps / {ckpt.get('epoch', '?')} epochs.")
14+
return ckpt["state_dict"]
15+
elif "network_weights" in ckpt:
16+
print("Loading weights from external checkpoint (network_weights key).")
17+
return ckpt["network_weights"]
18+
else:
19+
raise ValueError("Unsupported checkpoint format. Expected 'state_dict' or 'network_weights' key.")
20+
21+
22+
def resolve_checkpoint_path(cfg):
23+
"""Resolve checkpoint file path from config. Returns path or None."""
24+
if cfg.checkpoint_run_id:
25+
folder = detect_id(cfg.checkpoint_run_id)
26+
return os.path.join(folder, "checkpoints", cfg.load_checkpoint_name)
27+
if cfg.checkpoint_path:
28+
return cfg.checkpoint_path
29+
return None
30+
31+
32+
def resolve_checkpoint(cfg):
33+
"""Resolve and load checkpoint from config. Returns a state_dict or None."""
34+
hf_id = getattr(cfg, "hf_model_id", None) or None
35+
ckpt_path = resolve_checkpoint_path(cfg)
36+
37+
sources = [s for s in [ckpt_path, hf_id] if s]
38+
if len(sources) > 1:
39+
raise ValueError("Provide only one of: checkpoint_run_id, checkpoint_path, hf_model_id")
40+
if len(sources) == 0:
41+
return None
42+
43+
if ckpt_path:
44+
return load_checkpoint_state_dict(ckpt_path)
45+
46+
path = download_hf_checkpoint(hf_id)
47+
state_dict = load_checkpoint_state_dict(path)
48+
49+
weight_mapper = get_class(cfg.hf_weight_format)
50+
return weight_mapper(state_dict).remap_keys()
Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from asparagus.functional.versioning import detect_id, detect_mlflow_id, detect_wandb_id
33
from asparagus.modules.dataclasses import PathingConfig, VersioningConfig
4+
from asparagus.pipeline.auto_configuration.checkpoint import resolve_checkpoint_path
45
from hydra.core.hydra_config import HydraConfig
56

67

@@ -21,25 +22,18 @@ def pathing(cfg, train=True):
2122
run_dir = HydraConfig.get().runtime.output_dir
2223
os.makedirs(run_dir, exist_ok=True)
2324

24-
if cfg.checkpoint_run_id is not None and cfg.checkpoint_run_id != "":
25-
model_folder = detect_id(cfg.checkpoint_run_id)
26-
pretrained_ckpt = os.path.join(model_folder, "checkpoints", cfg.load_checkpoint_name)
27-
assert cfg.checkpoint_path is None, "You cannot provide both a checkpoint path and a checkpoint run id"
28-
elif cfg.checkpoint_path is not None and cfg.checkpoint_path != "":
29-
model_folder = None
30-
pretrained_ckpt = cfg.checkpoint_path
31-
else:
32-
model_folder, pretrained_ckpt = None, None
25+
ckpt_path = resolve_checkpoint_path(cfg)
26+
ckpt_parent_folder = detect_id(cfg.checkpoint_run_id) if cfg.checkpoint_run_id else None
3327

3428
if train:
3529
dataset_json_path = cfg.data.data_path + "/dataset.json"
3630
else:
3731
dataset_json_path = cfg.data.test_data_path + "/dataset.json"
38-
pathingcfg = PathingConfig(
32+
33+
return PathingConfig(
3934
run_dir=run_dir,
4035
ckpt_save_dir=os.path.join(run_dir, "checkpoints"),
41-
ckpt_parent_folder=model_folder,
42-
ckpt_path=pretrained_ckpt,
36+
ckpt_parent_folder=ckpt_parent_folder,
37+
ckpt_path=ckpt_path,
4338
dataset_json_path=dataset_json_path,
4439
)
45-
return pathingcfg

asparagus/pipeline/run/finetune_cls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from asparagus.modules.hydra.plugins.searchpath_plugins import FinetuneSearchpathPlugin
77
from asparagus.modules.transforms.presets import CPU_clsreg_val_test_transforms_crop
88
from asparagus.paths import get_config_path
9+
from asparagus.pipeline.auto_configuration.checkpoint import resolve_checkpoint
910
from asparagus.pipeline.auto_configuration.experiment_setup import (
1011
prepare_standard_experiment,
1112
)
@@ -39,6 +40,7 @@ def main(cfg: DictConfig) -> None:
3940
print(f"{OmegaConf.to_yaml(cfg)}\n Version: {cfg.run_id}\n Run dir: {HydraConfig.get().run.dir}\n")
4041
logging_safe_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
4142
file_store, path_store, version_store = prepare_standard_experiment(cfg)
43+
weights = resolve_checkpoint(cfg)
4244
pl.seed_everything(seed=cfg.training.seed, workers=True)
4345

4446
loggers = logging(
@@ -107,7 +109,7 @@ def main(cfg: DictConfig) -> None:
107109
decoder_warmup_epochs=cfg.training.decoder_warmup_epochs,
108110
train_transforms=gpu_tr_transforms,
109111
val_transforms=None,
110-
weights=path_store.ckpt_path,
112+
weights=weights,
111113
log_image_every_n_epochs=cfg.logger.log_images_every_n_epoch,
112114
optimizer=cfg.model.finetune_optim,
113115
learning_rate=cfg.model.finetune_lr,

asparagus/pipeline/run/finetune_reg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from asparagus.modules.hydra.plugins.searchpath_plugins import FinetuneSearchpathPlugin
77
from asparagus.modules.transforms.presets import CPU_clsreg_val_test_transforms_crop
88
from asparagus.paths import get_config_path
9+
from asparagus.pipeline.auto_configuration.checkpoint import resolve_checkpoint
910
from asparagus.pipeline.auto_configuration.experiment_setup import (
1011
prepare_standard_experiment,
1112
)
@@ -38,6 +39,7 @@
3839
def main(cfg: DictConfig) -> None:
3940
print(f"{OmegaConf.to_yaml(cfg)}\n Version: {cfg.run_id}\n Run dir: {HydraConfig.get().run.dir}\n")
4041
file_store, path_store, version_store = prepare_standard_experiment(cfg)
42+
weights = resolve_checkpoint(cfg)
4143

4244
pl.seed_everything(seed=cfg.training.seed, workers=True)
4345

@@ -105,7 +107,7 @@ def main(cfg: DictConfig) -> None:
105107
model=model,
106108
train_transforms=gpu_tr_transforms,
107109
val_transforms=None,
108-
weights=path_store.ckpt_path,
110+
weights=weights,
109111
log_image_every_n_epochs=cfg.logger.log_images_every_n_epoch,
110112
optimizer=cfg.model.finetune_optim,
111113
learning_rate=cfg.model.finetune_lr,

asparagus/pipeline/run/finetune_seg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from asparagus.modules.hydra.plugins.searchpath_plugins import FinetuneSearchpathPlugin
77
from asparagus.modules.transforms.presets import CPU_seg_test_transforms
88
from asparagus.paths import get_config_path
9+
from asparagus.pipeline.auto_configuration.checkpoint import resolve_checkpoint
910
from asparagus.pipeline.auto_configuration.experiment_setup import (
1011
prepare_standard_experiment,
1112
)
@@ -39,10 +40,10 @@ def main(cfg: DictConfig) -> None:
3940
print(f"{OmegaConf.to_yaml(cfg)}\n Version: {cfg.run_id}\n Run dir: {HydraConfig.get().run.dir}\n")
4041
logging_safe_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
4142
file_store, path_store, version_store = prepare_standard_experiment(cfg)
43+
weights = resolve_checkpoint(cfg)
4244
pl.seed_everything(seed=cfg.training.seed, workers=True)
4345

4446
assert "load_checkpoint_name" in cfg.keys(), "load_checkpoint_name not in config. Did you supply a scratch config?"
45-
assert path_store.ckpt_path is not None, "Checkpoint must be provided for finetuning."
4647

4748
loggers = logging(
4849
ckpt_wandb_id=version_store.wandb_id,
@@ -112,7 +113,7 @@ def main(cfg: DictConfig) -> None:
112113
model=model,
113114
warmup_epochs=cfg.training.warmup_epochs,
114115
decoder_warmup_epochs=cfg.training.decoder_warmup_epochs,
115-
weights=path_store.ckpt_path,
116+
weights=weights,
116117
train_transforms=gpu_tr_transforms,
117118
val_transforms=None,
118119
optimizer=cfg.model.finetune_optim,

0 commit comments

Comments
 (0)