Skip to content

Commit 537cfdd

Browse files
committed
Fix device creation in from_k_means
1 parent 208b069 commit 537cfdd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

gmmx/gmm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,9 @@ def from_k_means(
609609
) -> GaussianMixtureModelJax:
610610
"""Init from k-means clustering
611611
612+
From k-means only supports creation on the CPU. You can move
613+
the whole model using `.to_device()` after.
614+
612615
Parameters
613616
----------
614617
x : jax.array
@@ -631,7 +634,7 @@ def from_k_means(
631634

632635
n_samples = x.shape[Axis.batch]
633636

634-
resp = jnp.zeros((n_samples, n_components), device="cpu")
637+
resp = jnp.zeros((n_samples, n_components), device=jax.devices("cpu")[0])
635638

636639
kwargs.setdefault("n_init", 10) # type: ignore [arg-type]
637640
label = KMeans(n_clusters=n_components, **kwargs).fit(x).labels_

0 commit comments

Comments
 (0)