We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bcb75a9 commit 208b069Copy full SHA for 208b069
gmmx/gmm.py
@@ -631,7 +631,7 @@ def from_k_means(
631
632
n_samples = x.shape[Axis.batch]
633
634
- resp = jnp.zeros((n_samples, n_components))
+ resp = jnp.zeros((n_samples, n_components), device="cpu")
635
636
kwargs.setdefault("n_init", 10) # type: ignore [arg-type]
637
label = KMeans(n_clusters=n_components, **kwargs).fit(x).labels_
0 commit comments