
Commit a69ebcd

Merge pull request #9 from adonath/safetensors_io
Add safetensors io
2 parents: c7dc705 + 2673f46

File tree

4 files changed: +64, -1 lines

gmmx/fit.py

Lines changed: 1 addition & 1 deletion
@@ -161,4 +161,4 @@ def em_cond(
             body_fun=em_step,
             init_val=(x, gmm, 0, jnp.asarray(jnp.inf), jnp.array(jnp.inf)),
         )
-        return EMFitterResult(*result, converged=result[2] < self.max_iter)  # type: ignore [misc]
+        return EMFitterResult(*result, converged=result[2] < self.max_iter)

gmmx/gmm.py

Lines changed: 45 additions & 0 deletions
@@ -48,6 +48,7 @@
 
 from __future__ import annotations
 
+import logging
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
@@ -69,6 +70,7 @@
     "GaussianMixtureSKLearn",
 ]
 
+log = logging.getLogger()
 
 AnyArray = Union[np.typing.NDArray, jax.Array]
 Device = Union[str, None]
@@ -663,6 +665,49 @@ def n_parameters(self) -> int:
             - 1
         )
 
+    def write(self, filename: str) -> None:
+        """Save the model parameters to a file in safetensors format."""
+        from safetensors.flax import save_file
+
+        data = {
+            "means": self.means_numpy,
+            "weights": self.weights_numpy,
+            "covariances": self.covariances.values_numpy,
+        }
+
+        metadata = {"covariance-type": self.covariances.type}
+
+        log.info(f"Writing {filename}")
+        save_file(data, metadata=metadata, filename=filename)  # type: ignore [arg-type]
+
+    @classmethod
+    def read(cls, filename: str, device: str = "cpu") -> GaussianMixtureModelJax:
+        """Read model parameters from a safetensors file.
+
+        Parameters
+        ----------
+        filename : str
+            Path to the safetensors file.
+        device : str, optional
+            Device to load the tensors onto (default: "cpu").
+
+        Returns
+        -------
+        GaussianMixtureModelJax
+            Loaded Gaussian Mixture Model instance.
+        """
+        from safetensors import safe_open
+
+        data = {}
+
+        with safe_open(filename, framework="flax", device=device) as f:
+            for key in f.keys():  # noqa: SIM118
+                data[key] = f.get_tensor(key)
+
+            covariance_type = f.metadata()["covariance-type"]
+
+        return cls.from_squeezed(**data, covariance_type=covariance_type)
+
     @property
     def log_weights(self) -> jax.Array:
         """Log weights (~jax.ndarray)"""

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ classifiers = [
 dependencies = [
     "jax>=0.4.30",
     "numpy>=1.26.0",
+    "safetensors>=0.5.0",
 ]
 
 [project.urls]
@@ -39,6 +40,7 @@ dev-dependencies = [
     "mkdocs-material>=8.5.10",
     "mkdocstrings[python]>=0.26.1",
     "scikit-learn>=1.0",
+    "safetensors>=0.5.0",
 ]
 
 [build-system]
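
The new safetensors>=0.5.0 dependency provides the on-disk format that write() and read() build on. A minimal standalone sketch of that round-trip (the file name and tensor key below are illustrative only):

import jax.numpy as jnp

from safetensors import safe_open
from safetensors.flax import save_file

# Save a dict of arrays plus string-valued metadata
tensors = {"weights": jnp.array([0.2, 0.8])}
save_file(tensors, "example.safetensors", metadata={"covariance-type": "full"})

# Load the arrays back as jax arrays and recover the metadata
with safe_open("example.safetensors", framework="flax", device="cpu") as f:
    restored = {key: f.get_tensor(key) for key in f.keys()}
    covariance_type = f.metadata()["covariance-type"]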

tests/test_gmm.py

Lines changed: 16 additions & 0 deletions
@@ -149,6 +149,22 @@ def test_fit(gmm_jax):
     assert_allclose(result.gmm.weights_numpy, [0.2, 0.8], rtol=0.05)
 
 
+def test_io(gmm_jax, tmpdir):
+    filename = tmpdir / "model.safetensors"
+
+    gmm_jax.write(filename)
+
+    new_model = GaussianMixtureModelJax.read(filename)
+
+    assert_allclose(gmm_jax.means_numpy, new_model.means_numpy)
+    assert_allclose(gmm_jax.weights_numpy, new_model.weights_numpy)
+    assert_allclose(
+        gmm_jax.covariances.values_numpy, new_model.covariances.values_numpy
+    )
+
+    assert gmm_jax.covariances.type == new_model.covariances.type
+
+
 def test_fit_against_sklearn(gmm_jax):
     # Fitting is hard to test, especially we cannot guarantee the fit converges to the same solution
     # However the "global" likelihood (summed across all components) for a given feature vector
