|
48 | 48 |
|
49 | 49 | from __future__ import annotations |
50 | 50 |
|
| 51 | +import logging |
51 | 52 | from dataclasses import dataclass, field |
52 | 53 | from enum import Enum |
53 | 54 | from functools import partial |
|
69 | 70 | "GaussianMixtureSKLearn", |
70 | 71 | ] |
71 | 72 |
|
| 73 | +log = logging.getLogger() |
72 | 74 |
|
73 | 75 | AnyArray = Union[np.typing.NDArray, jax.Array] |
74 | 76 | Device = Union[str, None] |
@@ -663,6 +665,49 @@ def n_parameters(self) -> int: |
663 | 665 | - 1 |
664 | 666 | ) |
665 | 667 |
|
| 668 | + def write(self, filename: str) -> None: |
| 669 | + """Save the model parameters to a file in safetensors format.""" |
| 670 | + from safetensors.flax import save_file |
| 671 | + |
| 672 | + data = { |
| 673 | + "means": self.means_numpy, |
| 674 | + "weights": self.weights_numpy, |
| 675 | + "covariances": self.covariances.values_numpy, |
| 676 | + } |
| 677 | + |
| 678 | + metadata = {"covariance-type": self.covariances.type} |
| 679 | + |
| 680 | + log.info(f"Writing {filename}") |
| 681 | + save_file(data, metadata=metadata, filename=filename) # type: ignore [arg-type] |
| 682 | + |
| 683 | + @classmethod |
| 684 | + def read(cls, filename: str, device: str = "cpu") -> GaussianMixtureModelJax: |
| 685 | + """Read model parameters from a safetensors file. |
| 686 | +
|
| 687 | + Parameters |
| 688 | + ---------- |
| 689 | + filename : str |
| 690 | + Path to the safetensors file. |
| 691 | + device : str, optional |
| 692 | + Device to load the tensors onto (default: "cpu"). |
| 693 | +
|
| 694 | + Returns |
| 695 | + ------- |
| 696 | + GaussianMixtureModelJax |
| 697 | + Loaded Gaussian Mixture Model instance. |
| 698 | + """ |
| 699 | + from safetensors import safe_open |
| 700 | + |
| 701 | + data = {} |
| 702 | + |
| 703 | + with safe_open(filename, framework="flax", device=device) as f: |
| 704 | + for key in f.keys(): # noqa: SIM118 |
| 705 | + data[key] = f.get_tensor(key) |
| 706 | + |
| 707 | + covariance_type = f.metadata()["covariance-type"] |
| 708 | + |
| 709 | + return cls.from_squeezed(**data, covariance_type=covariance_type) |
| 710 | + |
666 | 711 | @property |
667 | 712 | def log_weights(self) -> jax.Array: |
668 | 713 | """Log weights (~jax.ndarray)""" |
|
0 commit comments