Skip to content

Commit c487bad

Browse files
committed
And make the loader work with shared tensors
1 parent c8b1e52 commit c487bad

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

kraken/models/loaders.py

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from os import PathLike
1212
from typing import Union, NewType, Literal, Optional
1313
from pathlib import Path
14-
from collections import defaultdict
1514
from collections.abc import Sequence
1615
from packaging.version import Version
1716

@@ -44,9 +43,10 @@ def load_safetensors(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tas
4443
"""
4544
Loads one or more models in safetensors format and returns them.
4645
"""
46+
from torch import nn
4747
from safetensors import safe_open, SafetensorError
48-
weights = defaultdict(dict)
49-
models = {}
48+
from safetensors.torch import load_model
49+
models = nn.ModuleDict()
5050
try:
5151
with safe_open(path, framework="pt") as f:
5252
if (metadata := f.metadata()) is not None:
@@ -64,17 +64,10 @@ def load_safetensors(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tas
6464
models[prefix] = create_model(model_map[prefix].get('_model'), **model_map[prefix])
6565
else:
6666
raise ValueError(f'No model metadata found in {path}.')
67-
for k in f.offset_keys():
68-
try:
69-
prefix = prefixes[list(map(k.startswith, prefixes)).index(True)]
70-
weights[prefix][k.removeprefix(f'{prefix}.')] = f.get_tensor(k)
71-
except ValueError:
72-
continue
7367
except SafetensorError as e:
7468
raise ValueError(f'Invalid model file {path}') from e
7569
# load weights into models
76-
for prefix, weight in weights.items():
77-
models[prefix].load_state_dict(weight)
70+
load_model(models, path)
7871
return list(models.values())
7972

8073

0 commit comments

Comments
 (0)