Skip to content

Commit 2cfbd4f

Browse files
committed
fix jax top_k
1 parent 966ae70 commit 2cfbd4f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

array_api_compat/jax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def top_k(
4343
slice = slice_start + (s_[:k],)
4444
topk_indices = indices_array[slice]
4545

46+
topk_indices = topk_indices.astype(np.int_)
4647
topk_values = take_along_axis(arr, topk_indices, axis=axis)
47-
4848
return (topk_values, topk_indices)
4949

5050

0 commit comments

Comments
 (0)