diff --git a/pyproject.toml b/pyproject.toml index e479e266c..fb23a2a6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ test-min = [ "pyarrow<21", # https://github.com/scikit-hep/awkward/issues/3579 "anndata[dask]", ] -test = [ "anndata[test-min,lazy]" ] +test = [ "anndata[test-min,lazy]", "jax", "jaxlib" ] # TODO: remove jax? own extra? gpu = [ "cupy" ] cu12 = [ "cupy-cuda12x" ] cu11 = [ "cupy-cuda11x" ] diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 5ed271add..a9177f78f 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -10,7 +10,7 @@ import pandas as pd from scipy.sparse import issparse -from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray +from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp from .xarray import Dataset2D if TYPE_CHECKING: @@ -115,8 +115,11 @@ def name_idx(i): if isinstance(indexer.data, DaskArray): return indexer.data.compute() return indexer.data + elif has_xp(indexer): + msg = "Need to implement array api-based indexing" + raise NotImplementedError(msg) msg = f"Unknown indexer {indexer!r} of type {type(indexer)}" - raise IndexError() + raise IndexError(msg) def _fix_slice_bounds(s: slice, length: int) -> slice: diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index e3e00ba7d..42afc50b8 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -10,7 +10,7 @@ from anndata.compat import CSArray, CSMatrix from .._warnings import ImplicitModificationWarning -from ..compat import XDataset +from ..compat import XDataset, has_xp from ..utils import ( ensure_df_homogeneous, join_english, @@ -67,6 +67,8 @@ def coerce_array( return np.array(value) except (ValueError, TypeError) as _e: e = _e + if has_xp(value): + return value # if value isn’t the right type or convertible, raise an error msg = f"{name} needs to be of one of {join_english(map(str, array_data_structure_types))}, not {type(value)}." if e is not None: diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index ac9a0dd0f..a7f74ef27 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -21,6 +21,7 @@ CupyCSRMatrix, DaskArray, ZappyArray, + has_xp, ) from .access import ElementRef from .xarray import Dataset2D @@ -292,6 +293,9 @@ def __setattr__(self, key: str, value: Any): @singledispatch def as_view(obj, view_args): + if has_xp(obj): + # TODO: Determine if we need some sort of specific view object for array-api + return obj msg = f"No view type has been registered for {type(obj)}" raise NotImplementedError(msg) diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 6eb4da48b..c49ca2a62 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import scipy +from array_api_compat import get_namespace as array_api_get_namespace from packaging.version import Version from zarr import Array as ZarrArray # noqa: F401 from zarr import Group as ZarrGroup @@ -415,3 +416,11 @@ def _map_cat_to_str(cat: pd.Categorical) -> pd.Categorical: return cat.map(str, na_action="ignore") else: return cat.map(str) + + +def has_xp(mod): + try: + array_api_get_namespace(mod) + return True + except TypeError: + return False diff --git a/tests/test_obsmvarm.py b/tests/test_obsmvarm.py index e2513e3a8..02bf136df 100644 --- a/tests/test_obsmvarm.py +++ b/tests/test_obsmvarm.py @@ -144,6 +144,13 @@ def test_setting_daskarray(adata: AnnData): assert h == joblib.hash(adata) +def test_setting_jax(adata: AnnData): + import jax.numpy as jnp + + adata.obsm["jax"] = jnp.ones((adata.shape[0], 10)) + assert isinstance(adata.obsm["jax"], jnp.ndarray) + + def test_shape_error(adata: AnnData): with pytest.raises( ValueError, diff --git a/tests/test_views.py b/tests/test_views.py index 29a02b503..4e12f251b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -824,6 +824,48 @@ def test_index_float_sequence_raises_error(index): gen_adata((10, 10))[index] +def test_jax_indexer(): + import jax.numpy as jnp + + index = np.array([0, 3, 6]) + index_jax = jnp.array(index) + adata = gen_adata((10, 10)) + assert_equal(adata[index], adata[index_jax]) + + +@pytest.mark.parametrize( + "index", + [ + np.array([0, 3, 6]), + slice(3), + Ellipsis, + (np.array([0, 3, 6]), np.array([1, 4, 7])), + ( + np.array([([True] * 5) + ([False] * 5)]), + np.array([([True] * 5) + ([False] * 5)]), + ), + ( + np.array([0, 3, 6]), + np.array([([True] * 5) + ([False] * 5)]), + ), + ], + ids=[ + "integer-array", + "slice", + "ellipsis", + "two-axis-integer-arrays", + "two-axis-boolean-arrays", + "mixed-array-type", + ], +) +def test_index_into_jax(index): + import jax.numpy as jnp + + adata = ad.AnnData(X=np.ones((10, 10))) + adata_as_jax = ad.AnnData(X=jnp.ones((10, 10))) + assert_equal(adata[index], adata_as_jax[index]) + + # @pytest.mark.parametrize("dim", ["obs", "var"]) # @pytest.mark.parametrize( # ("idx", "pat"), diff --git a/tests/test_x.py b/tests/test_x.py index d7da59a0c..c1ebfe840 100644 --- a/tests/test_x.py +++ b/tests/test_x.py @@ -192,3 +192,9 @@ def test_fail_on_non_csr_csc_matrix(): match=r"Only CSR and CSC.*", ): ad.AnnData(X=X) + + +def test_create_anndata_with_jax(): + import jax.numpy as jnp + + ad.AnnData(X=jnp.ones((10, 10)))