@@ -42,6 +42,17 @@ def load_models(path: Union[str, 'PathLike'], tasks: Optional[Sequence[_T_tasks]
4242def load_safetensors (path : Union [str , PathLike ], tasks : Optional [Sequence [_T_tasks ]] = None ) -> list [BaseModel ]:
4343 """
4444 Loads one or more models in safetensors format and returns them.
45+
46+ Args:
47+ path: Path to the safetensors file.
48+ tasks: Filter for model types to load from file.
49+
50+ Returns:
51+ A list of models.
52+
53+ Raises:
54+ ValueError: When model metadata is incomplete or the safetensors file
55+ is invalid.
4556 """
4657 from torch import nn
4758 from safetensors import safe_open , SafetensorError
@@ -74,7 +85,18 @@ def load_safetensors(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tas
7485
7586def load_coreml (path : Union [str , PathLike ], tasks : Optional [Sequence [_T_tasks ]] = None ) -> list [BaseModel ]:
7687 """
77- Loads a model in coreml format.
88+ Loads a model in CoreML format.
89+
90+ Args:
91+ path: Path to the coreml file.
92+ tasks: Filter for model types to load from file.
93+
94+ Returns:
95+ A list of models.
96+
97+ Raises:
98+ ValueError: When model metadata is incomplete or the coreml file
99+ is invalid.
78100 """
79101 root_logger = logging .getLogger ()
80102 level = root_logger .getEffectiveLevel ()
@@ -83,6 +105,8 @@ def load_coreml(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tasks]]
83105 root_logger .setLevel (level )
84106 from google .protobuf .message import DecodeError
85107
108+ models = []
109+
86110 if isinstance (path , PathLike ):
87111 path = path .as_posix ()
88112 try :
@@ -111,12 +135,14 @@ def load_coreml(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tasks]]
111135 weights .update (cml_parser (spec ))
112136
113137 model .load_state_dict (weights )
138+ models .append (model )
114139
115140 # construct additional models if auxiliary layers are defined.
116141
117- # if 'aux_layers' in mlmodel.user_defined_metadata:
118- # logger.info('Deserializing auxiliary layers.')
142+ if 'aux_layers' in mlmodel .user_defined_metadata :
143+ logger .info ('Deserializing auxiliary layers.' )
144+
145+ nn .aux_layers = {k : cls (v ).nn .get_submodule (k ) for k , v in json .loads (mlmodel .user_defined_metadata ['aux_layers' ]).items ()}
119146
120- # nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()}
147+ return models
121148
122- return [model ]
0 commit comments