Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ test-min = [
"pyarrow<21", # https://github.com/scikit-hep/awkward/issues/3579
"anndata[dask]",
]
test = [ "anndata[test-min,lazy]" ]
test = [ "anndata[test-min,lazy]", "jax", "jaxlib" ] # TODO: remove jax? own extra?
gpu = [ "cupy" ]
cu12 = [ "cupy-cuda12x" ]
cu11 = [ "cupy-cuda11x" ]
Expand Down
7 changes: 5 additions & 2 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
from scipy.sparse import issparse

from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp
from .xarray import Dataset2D

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,8 +115,11 @@ def name_idx(i):
if isinstance(indexer.data, DaskArray):
return indexer.data.compute()
return indexer.data
elif has_xp(indexer):
msg = "Need to implement array api-based indexing"
raise NotImplementedError(msg)
msg = f"Unknown indexer {indexer!r} of type {type(indexer)}"
raise IndexError()
raise IndexError(msg)


def _fix_slice_bounds(s: slice, length: int) -> slice:
Expand Down
4 changes: 3 additions & 1 deletion src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from anndata.compat import CSArray, CSMatrix

from .._warnings import ImplicitModificationWarning
from ..compat import XDataset
from ..compat import XDataset, has_xp
from ..utils import (
ensure_df_homogeneous,
join_english,
Expand Down Expand Up @@ -67,6 +67,8 @@ def coerce_array(
return np.array(value)
except (ValueError, TypeError) as _e:
e = _e
if has_xp(value):
return value
# if value isn’t the right type or convertible, raise an error
msg = f"{name} needs to be of one of {join_english(map(str, array_data_structure_types))}, not {type(value)}."
if e is not None:
Expand Down
4 changes: 4 additions & 0 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CupyCSRMatrix,
DaskArray,
ZappyArray,
has_xp,
)
from .access import ElementRef
from .xarray import Dataset2D
Expand Down Expand Up @@ -292,6 +293,9 @@ def __setattr__(self, key: str, value: Any):

@singledispatch
def as_view(obj, view_args):
if has_xp(obj):
# TODO: Determine if we need some sort of specific view object for array-api
return obj
msg = f"No view type has been registered for {type(obj)}"
raise NotImplementedError(msg)

Expand Down
9 changes: 9 additions & 0 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import scipy
from array_api_compat import get_namespace as array_api_get_namespace
from packaging.version import Version
from zarr import Array as ZarrArray # noqa: F401
from zarr import Group as ZarrGroup
Expand Down Expand Up @@ -415,3 +416,11 @@ def _map_cat_to_str(cat: pd.Categorical) -> pd.Categorical:
return cat.map(str, na_action="ignore")
else:
return cat.map(str)


def has_xp(mod):
try:
array_api_get_namespace(mod)
return True
except TypeError:
return False
7 changes: 7 additions & 0 deletions tests/test_obsmvarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def test_setting_daskarray(adata: AnnData):
assert h == joblib.hash(adata)


def test_setting_jax(adata: AnnData):
import jax.numpy as jnp

adata.obsm["jax"] = jnp.ones((adata.shape[0], 10))
assert isinstance(adata.obsm["jax"], jnp.ndarray)


def test_shape_error(adata: AnnData):
with pytest.raises(
ValueError,
Expand Down
42 changes: 42 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,48 @@ def test_index_float_sequence_raises_error(index):
gen_adata((10, 10))[index]


def test_jax_indexer():
import jax.numpy as jnp

index = np.array([0, 3, 6])
index_jax = jnp.array(index)
adata = gen_adata((10, 10))
assert_equal(adata[index], adata[index_jax])


@pytest.mark.parametrize(
"index",
[
np.array([0, 3, 6]),
slice(3),
Ellipsis,
(np.array([0, 3, 6]), np.array([1, 4, 7])),
(
np.array([([True] * 5) + ([False] * 5)]),
np.array([([True] * 5) + ([False] * 5)]),
),
(
np.array([0, 3, 6]),
np.array([([True] * 5) + ([False] * 5)]),
),
],
ids=[
"integer-array",
"slice",
"ellipsis",
"two-axis-integer-arrays",
"two-axis-boolean-arrays",
"mixed-array-type",
],
)
def test_index_into_jax(index):
import jax.numpy as jnp

adata = ad.AnnData(X=np.ones((10, 10)))
adata_as_jax = ad.AnnData(X=jnp.ones((10, 10)))
assert_equal(adata[index], adata_as_jax[index])


# @pytest.mark.parametrize("dim", ["obs", "var"])
# @pytest.mark.parametrize(
# ("idx", "pat"),
Expand Down
6 changes: 6 additions & 0 deletions tests/test_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,9 @@ def test_fail_on_non_csr_csc_matrix():
match=r"Only CSR and CSC.*",
):
ad.AnnData(X=X)


def test_create_anndata_with_jax():
import jax.numpy as jnp

ad.AnnData(X=jnp.ones((10, 10)))
Loading