@@ -33,16 +33,19 @@ def _silhouette_reduce(
33
33
"""
34
34
# accumulate distances from each sample to each cluster
35
35
D_chunk_len = D_chunk .shape [0 ]
36
- clust_dists = jnp .zeros ((D_chunk_len , len (label_freqs )), dtype = D_chunk .dtype )
37
36
38
- def _bincount (i , _data ):
39
- clust_dists , D_chunk , labels , label_freqs = _data
40
- clust_dists = clust_dists .at [i ].set (jnp .bincount (labels , weights = D_chunk [i ], length = label_freqs .shape [0 ]))
41
- return clust_dists , D_chunk , labels , label_freqs
37
+ # If running into memory issues, use fori_loop instead of vmap
38
+ # clust_dists = jnp.zeros((D_chunk_len, len(label_freqs)), dtype=D_chunk.dtype)
39
+ # def _bincount(i, _data):
40
+ # clust_dists, D_chunk, labels, label_freqs = _data
41
+ # clust_dists = clust_dists.at[i].set(jnp.bincount(labels, weights=D_chunk[i], length=label_freqs.shape[0]))
42
+ # return clust_dists, D_chunk, labels, label_freqs
42
43
43
- clust_dists = jax .lax .fori_loop (
44
- 0 , D_chunk_len , lambda i , _data : _bincount (i , _data ), (clust_dists , D_chunk , labels , label_freqs )
45
- )[0 ]
44
+ # clust_dists = jax.lax.fori_loop(
45
+ # 0, D_chunk_len, lambda i, _data: _bincount(i, _data), (clust_dists, D_chunk, labels, label_freqs)
46
+ # )[0]
47
+
48
+ clust_dists = jax .vmap (partial (jnp .bincount , length = label_freqs .shape [0 ]), in_axes = (None , 0 ))(labels , D_chunk )
46
49
47
50
# intra_index selects intra-cluster distances within clust_dists
48
51
intra_index = (jnp .arange (D_chunk_len ), jax .lax .dynamic_slice (labels , (start ,), (D_chunk_len ,)))
0 commit comments