Skip to content

Commit 3820460

Browse files
committed
fix jax top_k
1 parent 2cfbd4f commit 3820460

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

array_api_compat/jax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def top_k(
4343
slice = slice_start + (s_[:k],)
4444
topk_indices = indices_array[slice]
4545

46-
topk_indices = topk_indices.astype(np.int_)
46+
topk_indices = topk_indices.astype(int_)
4747
topk_values = take_along_axis(arr, topk_indices, axis=axis)
4848
return (topk_values, topk_indices)
4949

0 commit comments

Comments
 (0)