Skip to content

Commit c8b1e52

Browse files
committed
Make write_safetensors work with shared tensors
1 parent 8de82f1 commit c8b1e52

File tree

4 files changed

+88
-82
lines changed

4 files changed

+88
-82
lines changed

kraken/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .writers import * # NOQA
33
from .loaders import * # NOQA
44
from .utils import * # NOQA
5+
from .convert import * # NOQA

kraken/models/convert.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import importlib
2+
3+
from pathlib import Path
4+
from collections.abc import Iterable
5+
from typing import TYPE_CHECKING, Union
6+
from kraken.models.loaders import load_models
7+
8+
if TYPE_CHECKING:
9+
from os import PathLike
10+
11+
__all__ = ['convert_models']
12+
13+
14+
def convert_models(paths: Iterable[Union[str, 'PathLike']],
15+
output: Union[str, 'PathLike'],
16+
weights_format: str = 'safetensors') -> 'PathLike':
17+
"""
18+
Converts the models in a set of checkpoint or weights files into a single
19+
output weights file.
20+
21+
It accepts checkpoints and weights files interchangeably for all supported
22+
formats and models.
23+
24+
This function has a number of uses:
25+
26+
* it can be used to convert checkpoints into weights.
27+
28+
convert_models(['model.ckpt'], 'model.safetensors')
29+
30+
* it can be used to convert multiple related models into a single
31+
weights file for joint inference:
32+
33+
convert_models(['blla_line.ckpt', 'blla_region.ckpt'], 'model.safetensors')
34+
35+
* it can convert models between coreml and safetensors formats:
36+
37+
convert_models(['blla.mlmodel'], 'blla.safetensors')
38+
39+
Args:
40+
paths: Paths to checkpoint or weights files.
41+
output: Output path to the combined/converted file. The actual output
42+
path may be modified.
43+
weights_format: Serialization format to write the weights to.
44+
45+
Returns:
46+
The path the actual weights file was written to.
47+
"""
48+
try:
49+
(entry_point,) = importlib.metadata.entry_points(group='kraken.writers', name=weights_format)
50+
writer = entry_point.load()
51+
except ValueError:
52+
raise ValueError('No writer for format {weights_format} found.')
53+
54+
def _find_module(path):
55+
for entry_point in importlib.metadata.entry_points(group='kraken.lightning_modules'):
56+
module = entry_point.load()
57+
try:
58+
return module.load_from_checkpoint(path)
59+
except ValueError:
60+
continue
61+
raise ValueError(f'No lightning module found for checkpoint {path}')
62+
63+
models = []
64+
for ckpt in paths:
65+
ckpt = Path(ckpt)
66+
if ckpt.suffix == '.ckpt':
67+
models.append(_find_module(ckpt).net)
68+
else:
69+
models.extend(load_models(ckpt))
70+
71+
return writer(models, output)

kraken/models/utils.py

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import importlib
22

3-
from pathlib import Path
4-
from collections.abc import Iterable
5-
from typing import Union, TYPE_CHECKING
6-
from kraken.models.base import BaseModel
7-
from kraken.models.loaders import load_models
3+
from typing import TYPE_CHECKING
84

9-
__all__ = ['create_model', 'convert_models']
5+
6+
__all__ = ['create_model']
107

118
if TYPE_CHECKING:
12-
from os import PathLike
9+
from .base import BaseModel
1310

1411

15-
def create_model(name, *args, **kwargs) -> BaseModel:
12+
def create_model(name, *args, **kwargs) -> 'BaseModel':
1613
"""
1714
Constructs an empty model from the model registry.
1815
"""
@@ -26,63 +23,3 @@ def create_model(name, *args, **kwargs) -> BaseModel:
2623

2724
cls = entry_point.load()
2825
return cls(*args, **kwargs)
29-
30-
31-
def convert_models(paths: Iterable[Union[str, 'PathLike']],
32-
output: Union[str, 'PathLike'],
33-
weights_format: str = 'safetensors') -> 'PathLike':
34-
"""
35-
Converts the models in a set of checkpoint or weights files into a single
36-
output weights file.
37-
38-
It accepts checkpoints and weights files interchangeably for all supported
39-
formats and models.
40-
41-
This function has a number of uses:
42-
43-
* it can be used to convert checkpoints into weights.
44-
45-
convert_models(['model.ckpt'], 'model.safetensors')
46-
47-
* it can be used to convert multiple related models into a single
48-
weights file for joint inference:
49-
50-
convert_models(['blla_line.ckpt', 'blla_region.ckpt'], 'model.safetensors')
51-
52-
* it can convert models between coreml and safetensors formats:
53-
54-
convert_models(['blla.mlmodel'], 'blla.safetensors')
55-
56-
Args:
57-
paths: Paths to checkpoint or weights files.
58-
output: Output path to the combined/converted file. The actual output
59-
path may be modified.
60-
weights_format: Serialization format to write the weights to.
61-
62-
Returns:
63-
The path the actual weights file was written to.
64-
"""
65-
try:
66-
(entry_point,) = importlib.metadata.entry_points(group='kraken.writers', name=weights_format)
67-
writer = entry_point.load()
68-
except ValueError:
69-
raise ValueError('No writer for format {weights_format} found.')
70-
71-
def _find_module(path):
72-
for entry_point in importlib.metadata.entry_points(group='kraken.lightning_modules'):
73-
module = entry_point.load()
74-
try:
75-
return module.load_from_checkpoint(path)
76-
except ValueError:
77-
continue
78-
raise ValueError(f'No lightning module found for checkpoint {path}')
79-
80-
models = []
81-
for ckpt in paths:
82-
ckpt = Path(ckpt)
83-
if ckpt.suffix == '.ckpt':
84-
models.append(_find_module(ckpt).net)
85-
else:
86-
models.extend(load_models(ckpt))
87-
88-
return writer(models, output)

kraken/models/writers.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import importlib
1111

12+
from torch import nn
1213
from os import PathLike
1314
from pathlib import Path
1415
from typing import Union, TYPE_CHECKING
@@ -44,21 +45,17 @@ def write_safetensors(objs: list[BaseModel], path: Union[str, PathLike]) -> Path
4445
"""
4546
Writes a set of models as a safetensors.
4647
"""
47-
from safetensors.torch import save_file
48+
from safetensors.torch import save_model
4849
# assign unique prefixes to each model in model list
49-
prefixes = {str(uuid.uuid4()): model for model in objs}
50-
metadatas = {k: {'_kraken_min_version': v._kraken_min_version,
51-
'_tasks': v.model_type,
52-
'_model': v.__class__.__name__,
53-
**v.user_metadata} for k, v in prefixes.items()}
54-
55-
weights = {}
56-
for prefix, model in prefixes.items():
57-
for name in (state_dict := model.state_dict()):
58-
weights[f'{prefix}.{name}'] = state_dict[name]
59-
save_file(weights,
60-
filename=path,
61-
metadata={'kraken_meta': json.dumps(metadatas)})
50+
prefixes = nn.ModuleDict({str(uuid.uuid4()): model for model in objs})
51+
metadata = {k: {'_kraken_min_version': v._kraken_min_version,
52+
'_tasks': v.model_type,
53+
'_model': v.__class__.__name__,
54+
**v.user_metadata} for k, v in prefixes.items()}
55+
56+
save_model(prefixes,
57+
filename=path,
58+
metadata={'kraken_meta': json.dumps(metadata)})
6259
return Path(path)
6360

6461

0 commit comments

Comments
 (0)