We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 208b069 commit 537cfddCopy full SHA for 537cfdd
gmmx/gmm.py
@@ -609,6 +609,9 @@ def from_k_means(
609
) -> GaussianMixtureModelJax:
610
"""Init from k-means clustering
611
612
+ From k-means only supports creation on the CPU. You can move
613
+ the whole model using `.to_device()` after.
614
+
615
Parameters
616
----------
617
x : jax.array
@@ -631,7 +634,7 @@ def from_k_means(
631
634
632
635
n_samples = x.shape[Axis.batch]
633
636
- resp = jnp.zeros((n_samples, n_components), device="cpu")
637
+ resp = jnp.zeros((n_samples, n_components), device=jax.devices("cpu")[0])
638
639
kwargs.setdefault("n_init", 10) # type: ignore [arg-type]
640
label = KMeans(n_clusters=n_components, **kwargs).fit(x).labels_
0 commit comments