Skip to content

Commit bcb75a9

Browse files
committed
Add to_device() method
1 parent a69ebcd commit bcb75a9

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

gmmx/gmm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,17 @@ def n_parameters(self) -> int:
665665
- 1
666666
)
667667

668+
def to_device(self, device: Any) -> GaussianMixtureModelJax:
669+
"""Move model to device"""
670+
671+
def move_array_to_device(node): # type: ignore [no-untyped-def]
672+
if isinstance(node, jax.Array):
673+
return jax.device_put(node, device=device)
674+
675+
return node
676+
677+
return jax.tree_util.tree_map(move_array_to_device, self) # type: ignore [no-any-return]
678+
668679
def write(self, filename: str) -> None:
669680
"""Save the model parameters to a file in safetensors format."""
670681
from safetensors.flax import save_file

tests/test_gmm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def test_fit(gmm_jax):
149149
assert_allclose(result.gmm.weights_numpy, [0.2, 0.8], rtol=0.05)
150150

151151

152+
def test_move_to_device(gmm_jax):
153+
gmm_jax_moved = gmm_jax.to_device(jax.devices()[0])
154+
assert_allclose(gmm_jax.means_numpy, gmm_jax_moved.means_numpy)
155+
assert gmm_jax_moved.means.devices() == {jax.devices()[0]}
156+
157+
152158
def test_io(gmm_jax, tmpdir):
153159
filename = tmpdir / "model.safetensors"
154160

0 commit comments

Comments
 (0)