@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:
2626
2727
2828def _preprocess_model_id (
29- model_id : str , subfolder : Optional [str ], same_as_pretrained : bool , use_pretrained : bool
30- ) -> Tuple [str , Optional [str ], bool , bool ]:
29+ model_id : str ,
30+ subfolder : Optional [str ],
31+ same_as_pretrained : bool ,
32+ use_pretrained : bool ,
33+ submodule : Optional [str ] = None ,
34+ ) -> Tuple [str , Optional [str ], bool , bool , Optional [str ]]:
35+ if "::" in model_id :
36+ assert (
37+ not submodule
38+ ), f"submodule={ submodule !r} cannot be defined in model_id={ model_id !r} as well"
39+ model_id , submodule = model_id .split ("::" , maxsplit = 1 )
3140 if subfolder or "//" not in model_id :
32- return model_id , subfolder , same_as_pretrained , use_pretrained
41+ return model_id , subfolder , same_as_pretrained , use_pretrained , submodule
3342 spl = model_id .split ("//" )
3443 if spl [- 1 ] == "pretrained" :
35- return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True )
44+ return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True , submodule )
3645 if spl [- 1 ] in {"transformer" , "vae" }:
3746 # known subfolder
38- return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained
39- return model_id , subfolder , same_as_pretrained , use_pretrained
47+ return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained , submodule
48+ return model_id , subfolder , same_as_pretrained , use_pretrained , submodule
4049
4150
4251def get_untrained_model_with_inputs (
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
5463 subfolder : Optional [str ] = None ,
5564 use_only_preinstalled : bool = False ,
5665 config_reduction : Optional [Callable [[Any , str ], Dict ]] = None ,
66+ submodule : Optional [str ] = None ,
5767) -> Dict [str , Any ]:
5868 """
5969 Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
8292 <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
8393 this function takes a configuration and a task (string)
8494 as arguments
95+ :param submodule: use a submodule instead of the main model
8596 :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
8697 some necessary rewriting as well
8798
@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
108119 f"model_id={ model_id !r} , preinstalled model is only available "
109120 f"if use_only_preinstalled is False."
110121 )
111- model_id , subfolder , same_as_pretrained , use_pretrained = _preprocess_model_id (
122+ model_id , subfolder , same_as_pretrained , use_pretrained , submodule = _preprocess_model_id (
112123 model_id ,
113124 subfolder ,
114125 same_as_pretrained = same_as_pretrained ,
115126 use_pretrained = use_pretrained ,
127+ submodule = submodule ,
116128 )
117129 if verbose :
118130 print (
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
147159 if verbose :
148160 print (f"[get_untrained_model_with_inputs] architecture={ arch !r} " )
149161 print (f"[get_untrained_model_with_inputs] cls={ config .__class__ .__name__ !r} " )
162+ if submodule :
163+ print (f"[get_untrained_model_with_inputs] submodule={ submodule !r} " )
150164 if task is None :
151165 task = task_from_arch (arch , model_id = model_id , subfolder = subfolder )
152166 if verbose :
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
357371 if diff_config is not None :
358372 res ["dump_info" ] = dict (config_diff = diff_config )
359373
374+ if submodule :
375+ path = submodule .split ("::" ) if "::" in submodule else [submodule ]
376+ for p in path :
377+ assert hasattr (model , p ), (
378+ f"Unable to find submodule { p !r} in in class { type (model )} , "
379+ f"submodule={ submodule !r} , possible candidates: "
380+ f"{ [k for k in dir (model ) if isinstance (getattr (model , k ), torch .nn .Module )]} "
381+ )
382+ model = getattr (model , p )
383+
384+ if verbose :
385+ print (f"[get_untrained_model_with_inputs] model class={ model .__class__ .__name__ !r} " )
386+
360387 sizes = compute_model_size (model )
361388 res ["model" ] = model
362389 res ["configuration" ] = config
0 commit comments