Skip to content

Commit 208b069

Browse files
authored
Update gmm.py
1 parent bcb75a9 commit 208b069

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gmmx/gmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def from_k_means(
631631

632632
n_samples = x.shape[Axis.batch]
633633

634-
resp = jnp.zeros((n_samples, n_components))
634+
resp = jnp.zeros((n_samples, n_components), device="cpu")
635635

636636
kwargs.setdefault("n_init", 10) # type: ignore [arg-type]
637637
label = KMeans(n_clusters=n_components, **kwargs).fit(x).labels_

0 commit comments

Comments
 (0)