1111from os import PathLike
1212from typing import Union , NewType , Literal , Optional
1313from pathlib import Path
14- from collections import defaultdict
1514from collections .abc import Sequence
1615from packaging .version import Version
1716
@@ -44,9 +43,10 @@ def load_safetensors(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tas
4443 """
4544 Loads one or more models in safetensors format and returns them.
4645 """
46+ from torch import nn
4747 from safetensors import safe_open , SafetensorError
48- weights = defaultdict ( dict )
49- models = {}
48+ from safetensors . torch import load_model
49+ models = nn . ModuleDict ()
5050 try :
5151 with safe_open (path , framework = "pt" ) as f :
5252 if (metadata := f .metadata ()) is not None :
@@ -64,17 +64,10 @@ def load_safetensors(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tas
6464 models [prefix ] = create_model (model_map [prefix ].get ('_model' ), ** model_map [prefix ])
6565 else :
6666 raise ValueError (f'No model metadata found in { path } .' )
67- for k in f .offset_keys ():
68- try :
69- prefix = prefixes [list (map (k .startswith , prefixes )).index (True )]
70- weights [prefix ][k .removeprefix (f'{ prefix } .' )] = f .get_tensor (k )
71- except ValueError :
72- continue
7367 except SafetensorError as e :
7468 raise ValueError (f'Invalid model file { path } ' ) from e
7569 # load weights into models
76- for prefix , weight in weights .items ():
77- models [prefix ].load_state_dict (weight )
70+ load_model (models , path )
7871 return list (models .values ())
7972
8073
0 commit comments