Skip to content

Commit da0b29f

Browse files
committed
WIP: Add top_k compatibility
This references the PR data-apis/array-api-tests#274.
1 parent 51daace commit da0b29f

File tree

8 files changed

+159
-5
lines changed

8 files changed

+159
-5
lines changed
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Array API Tests (JAX)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-jax:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: jax
10+
# See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes
11+
extra-env-vars: |
12+
JAX_ENABLE_X64=1
13+
ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64

.github/workflows/array-api-tests.yml

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ on:
3333
description: "Multiline string of environment variables to set for the test run."
3434

3535
env:
36-
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
36+
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} -k top_k --hypothesis-disable-deadline"
3737

3838
jobs:
3939
tests:
@@ -50,9 +50,10 @@ jobs:
5050
- name: Checkout array-api-tests
5151
uses: actions/checkout@v4
5252
with:
53-
repository: data-apis/array-api-tests
53+
repository: JuliaPoo/array-api-tests
5454
submodules: 'true'
5555
path: array-api-tests
56+
ref: ci-wip-topk-tests
5657
- name: Set up Python ${{ matrix.python-version }}
5758
uses: actions/setup-python@v5
5859
with:
@@ -77,6 +78,7 @@ jobs:
7778
# This enables the NEP 50 type promotion behavior (without it a lot of
7879
# tests fail on bad scalar type promotion behavior)
7980
NPY_PROMOTION_STATE: weak
81+
ARRAY_API_TESTS_VERSION: draft
8082
run: |
8183
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
8284
cd ${GITHUB_WORKSPACE}/array-api-tests

array_api_compat/dask/array/_aliases.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,28 @@ def asarray(
150150

151151
return da.asarray(obj, dtype=dtype, **kwargs)
152152

153+
154+
def top_k(
155+
x: Array,
156+
k: int,
157+
/,
158+
axis: Optional[int] = None,
159+
*,
160+
largest: bool = True,
161+
) -> tuple[Array, Array]:
162+
163+
if not largest:
164+
k = -k
165+
166+
# For now, perform the computation twice,
167+
# since an equivalent to numpy's `take_along_axis`
168+
# does not exist.
169+
# See https://github.com/dask/dask/issues/3663.
170+
args = da.argtopk(x, k, axis=axis).compute()
171+
vals = da.topk(x, k, axis=axis).compute()
172+
return vals, args
173+
174+
153175
from dask.array import (
154176
# Element wise aliases
155177
arccos as acos,
@@ -178,6 +200,7 @@ def asarray(
178200
'bitwise_right_shift', 'concat', 'pow',
179201
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
180202
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
181-
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
203+
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type',
204+
'top_k']
182205

183206
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

array_api_compat/jax/__init__.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from jax.numpy import (
2+
# Constants
3+
e,
4+
inf,
5+
nan,
6+
pi,
7+
newaxis,
8+
# Dtypes
9+
bool,
10+
float32,
11+
float64,
12+
int8,
13+
int16,
14+
int32,
15+
int64,
16+
uint8,
17+
uint16,
18+
uint32,
19+
uint64,
20+
complex64,
21+
complex128,
22+
iinfo,
23+
finfo,
24+
can_cast,
25+
result_type,
26+
# functions
27+
zeros,
28+
all,
29+
any,
30+
isnan,
31+
isfinite,
32+
reshape
33+
)
34+
from jax.numpy import (
35+
asarray,
36+
s_,
37+
int_,
38+
argpartition,
39+
take_along_axis
40+
)
41+
42+
43+
def top_k(
44+
x,
45+
k,
46+
/,
47+
axis=None,
48+
*,
49+
largest=True,
50+
):
51+
# The largest keyword can't be implemented with `jax.lax.top_k`
52+
# efficiently so am using `jax.numpy` for now
53+
if k <= 0:
54+
raise ValueError(f'k(={k}) provided must be positive.')
55+
56+
positive_axis: int
57+
_arr = asarray(x)
58+
if axis is None:
59+
arr = _arr.ravel()
60+
positive_axis = 0
61+
else:
62+
arr = _arr
63+
positive_axis = axis if axis > 0 else axis % arr.ndim
64+
65+
slice_start = (s_[:],) * positive_axis
66+
if largest:
67+
indices_array = argpartition(arr, -k, axis=axis)
68+
slice = slice_start + (s_[-k:],)
69+
topk_indices = indices_array[slice]
70+
else:
71+
indices_array = argpartition(arr, k-1, axis=axis)
72+
slice = slice_start + (s_[:k],)
73+
topk_indices = indices_array[slice]
74+
75+
topk_indices = topk_indices.astype(int_)
76+
topk_values = take_along_axis(arr, topk_indices, axis=axis)
77+
return (topk_values, topk_indices)
78+
79+
80+
__all__ = ['top_k', 'e', 'inf', 'nan', 'pi', 'newaxis', 'bool',
81+
'float32', 'float64', 'int8', 'int16', 'int32',
82+
'int64', 'uint8', 'uint16', 'uint32', 'uint64',
83+
'complex64', 'complex128', 'iinfo', 'finfo',
84+
'can_cast', 'result_type', 'zeros', 'all', 'isnan',
85+
'isfinite', 'reshape', 'any']

array_api_compat/numpy/_aliases.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,35 @@
6161
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363

64+
65+
def top_k(a, k, /, axis=-1, *, largest=True):
66+
if k <= 0:
67+
raise ValueError(f'k(={k}) provided must be positive.')
68+
69+
positive_axis: int
70+
_arr = np.asanyarray(a)
71+
if axis is None:
72+
arr = _arr.ravel()
73+
positive_axis = 0
74+
else:
75+
arr = _arr
76+
positive_axis = axis if axis > 0 else axis % arr.ndim
77+
78+
slice_start = (np.s_[:],) * positive_axis
79+
if largest:
80+
indices_array = np.argpartition(arr, -k, axis=axis)
81+
slice = slice_start + (np.s_[-k:],)
82+
topk_indices = indices_array[slice]
83+
else:
84+
indices_array = np.argpartition(arr, k-1, axis=axis)
85+
slice = slice_start + (np.s_[:k],)
86+
topk_indices = indices_array[slice]
87+
88+
topk_values = np.take_along_axis(arr, topk_indices, axis=axis)
89+
90+
return (topk_values, topk_indices)
91+
92+
6493
def _supports_buffer_protocol(obj):
6594
try:
6695
memoryview(obj)
@@ -126,6 +155,6 @@ def asarray(
126155
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
127156
'acosh', 'asin', 'asinh', 'atan', 'atan2',
128157
'atanh', 'bitwise_left_shift', 'bitwise_invert',
129-
'bitwise_right_shift', 'concat', 'pow']
158+
'bitwise_right_shift', 'concat', 'pow', 'top_k']
130159

131160
_all_ignore = ['np', 'get_xp']

array_api_compat/torch/_aliases.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
700700
axis = 0
701701
return torch.index_select(x, axis, indices, **kwargs)
702702

703+
top_k = torch.topk
704+
703705
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
704706
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
705707
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
@@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
713715
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
714716
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
715717
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
716-
'take']
718+
'take', 'top_k']
717719

718720
_all_ignore = ['torch', 'get_xp']

jax-skips.txt

Whitespace-only changes.

jax-xfails.txt

Whitespace-only changes.

0 commit comments

Comments
 (0)