33
44import numpy as np
55import torch
6+ from huggingface_hub import hf_hub_download
67from PIL .Image import Image
78
89from cellseg_models_pytorch .decoders .multitask_decoder import (
910 SoftInstanceOutput ,
1011 SoftSemanticOutput ,
1112)
13+ from cellseg_models_pytorch .models .base import PRETRAINED
1214
1315__all__ = ["BaseModelInst" ]
1416
@@ -24,24 +26,60 @@ def set_inference_mode(self) -> None:
2426 @classmethod
2527 def from_pretrained (
2628 cls ,
27- weights_path : Union [str , Path ],
28- n_nuc_classes : int ,
29- enc_name : str = "efficientnet_b5" ,
30- enc_freeze : bool = False ,
29+ weights : Union [str , Path ],
3130 device : torch .device = torch .device ("cuda" ),
3231 model_kwargs : Dict [str , Any ] = {},
33- ) -> None :
34- """Load the model from pretrained weights."""
32+ ) -> "BaseModelInst" :
33+ """Load the model from pretrained weights.
34+
35+ Parameters:
36+ model_name (str):
37+ Name of the pretrained model.
38+ device (torch.device, default=torch.device("cuda")):
39+ Device to run the model on. Default is "cuda".
40+ model_kwargs (Dict[str, Any], default={}):
41+ Additional arguments for the model.
42+ """
43+ weights_path = Path (weights )
44+ if not weights_path .is_file ():
45+ if weights_path .as_posix () in PRETRAINED [cls .model_name ].keys ():
46+ weights_path = Path (
47+ hf_hub_download (
48+ repo_id = PRETRAINED [cls .model_name ][weights ]["repo_id" ],
49+ filename = PRETRAINED [cls .model_name ][weights ]["filename" ],
50+ )
51+ )
52+
53+ else :
54+ raise ValueError (
55+ "Please provide a valid path. or a pre-trained model downloaded from the"
56+ f" csmp-hub. One of { list (PRETRAINED [cls .model_name ].keys ())} ."
57+ )
58+
59+ try :
60+ from safetensors .torch import load_model
61+ except ImportError :
62+ raise ImportError (
63+ "Please install `safetensors` package to load .safetensors files."
64+ )
65+
66+ enc_name , n_nuc_classes , state_dict = cls ._get_state_dict (
67+ weights_path , device = device
68+ )
69+
3570 model_inst = cls (
3671 n_nuc_classes = n_nuc_classes ,
3772 enc_name = enc_name ,
3873 enc_pretrain = False ,
39- enc_freeze = enc_freeze ,
74+ enc_freeze = False ,
4075 device = device ,
4176 model_kwargs = model_kwargs ,
4277 )
43- state_dict = torch .load (weights_path , map_location = device )
44- model_inst .model .load_state_dict (state_dict , strict = True )
78+
79+ if weights_path .suffix == ".safetensors" :
80+ load_model (model_inst .model , weights_path , device .type )
81+ else :
82+ model_inst .model .load_state_dict (state_dict , strict = True )
4583
4684 return model_inst
4785
@@ -174,3 +212,34 @@ def post_process(
174212 )
175213
176214 return x
215+
216+ @staticmethod
217+ def _get_state_dict (
218+ weights_path : Union [str , Path ], device : torch .device = torch .device ("cuda" )
219+ ) -> None :
220+ """Load the model from pretrained weights."""
221+ weights_path = Path (weights_path )
222+ if not weights_path .exists ():
223+ raise ValueError (f"Model weights not found at { weights_path } " )
224+ if weights_path .suffix == ".safetensors" :
225+ try :
226+ from safetensors .torch import load_file
227+ except ImportError :
228+ raise ImportError (
229+ "Please install `safetensors` package to load .safetensors files."
230+ )
231+ state_dict = load_file (weights_path , device = device .type )
232+ else :
233+ state_dict = torch .load (weights_path , map_location = device )
234+
235+ # infer encoder name and number of classes from state_dict
236+ enc_keys = [key for key in state_dict .keys () if "encoder." in key ]
237+ enc_name = enc_keys [0 ].split ("." )[0 ] if enc_keys else None
238+ nuc_type_head_key = next (
239+ key
240+ for key in state_dict .keys ()
241+ if "nuc_type_head.head" in key and "weight" in key
242+ )
243+ n_nuc_classes = state_dict [nuc_type_head_key ].shape [0 ]
244+
245+ return enc_name , n_nuc_classes , state_dict
0 commit comments