File tree Expand file tree Collapse file tree 2 files changed +17
-0
lines changed Expand file tree Collapse file tree 2 files changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -665,6 +665,17 @@ def n_parameters(self) -> int:
665665 - 1
666666 )
667667
668+ def to_device (self , device : Any ) -> GaussianMixtureModelJax :
669+ """Move model to device"""
670+
671+ def move_array_to_device (node ): # type: ignore [no-untyped-def]
672+ if isinstance (node , jax .Array ):
673+ return jax .device_put (node , device = device )
674+
675+ return node
676+
677+ return jax .tree_util .tree_map (move_array_to_device , self ) # type: ignore [no-any-return]
678+
668679 def write (self , filename : str ) -> None :
669680 """Save the model parameters to a file in safetensors format."""
670681 from safetensors .flax import save_file
Original file line number Diff line number Diff line change @@ -149,6 +149,12 @@ def test_fit(gmm_jax):
149149 assert_allclose (result .gmm .weights_numpy , [0.2 , 0.8 ], rtol = 0.05 )
150150
151151
152+ def test_move_to_device (gmm_jax ):
153+ gmm_jax_moved = gmm_jax .to_device (jax .devices ()[0 ])
154+ assert_allclose (gmm_jax .means_numpy , gmm_jax_moved .means_numpy )
155+ assert gmm_jax_moved .means .devices () == {jax .devices ()[0 ]}
156+
157+
152158def test_io (gmm_jax , tmpdir ):
153159 filename = tmpdir / "model.safetensors"
154160
You can’t perform that action at this time.
0 commit comments