Skip to content
Draft
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
21d5882
chore: add tests
ilan-gold Jul 29, 2025
7cd6d69
fix: allow creating objects with array-api
ilan-gold Jul 29, 2025
eff9dde
chore: add indexing test
ilan-gold Jul 29, 2025
b230d11
fix: add xp pass-through
ilan-gold Jul 29, 2025
692a270
chore: add more indexing methods
ilan-gold Jul 29, 2025
b52e5b7
Merge branch 'main' into ig/array_api_starter
ilan-gold Jul 29, 2025
806becf
initial backend step
amalia-k510 Aug 7, 2025
98d249a
concat lazy error fix
amalia-k510 Aug 9, 2025
cdc4fdd
comment for merge fix
amalia-k510 Aug 9, 2025
030f985
dask array fix
amalia-k510 Aug 9, 2025
5aff4d6
fix test_concatenate_roundtrip[inner-np_array-pandas-concat-lazy-conc…
amalia-k510 Aug 9, 2025
ba74743
fixed all contact errors
amalia-k510 Aug 9, 2025
a749637
comments fix
amalia-k510 Aug 9, 2025
c716fb5
quick fix but not ideal as it converts jax and other arrays to numpy
amalia-k510 Aug 9, 2025
951c026
Merge branch 'main' into ig/array_api_continue
amalia-k510 Aug 11, 2025
d8adf27
comment fix
amalia-k510 Aug 11, 2025
0e410a4
Merge branch 'ig/array_api_continue' of github.com:amalia-k510/anndat…
amalia-k510 Aug 11, 2025
8ea34e8
comment fix
amalia-k510 Aug 11, 2025
96992d5
extra tests and function changes
amalia-k510 Aug 14, 2025
460428a
dlpack introduction and trying to make gen_adata fully backend native…
amalia-k510 Aug 21, 2025
dd3b867
removed unnecessary function
amalia-k510 Aug 21, 2025
743ebb3
minor fixes
amalia-k510 Aug 21, 2025
5a6c825
minor fix
amalia-k510 Aug 21, 2025
90d9e6a
begin merge modification
amalia-k510 Aug 25, 2025
787feb0
concat on jax arrays is introduced
amalia-k510 Aug 26, 2025
dea107d
precommit fixes
amalia-k510 Aug 26, 2025
cdd3747
merge quick fix
amalia-k510 Aug 26, 2025
b04c8ef
minor fixes
amalia-k510 Aug 26, 2025
383c445
writer and reindexer introduced + jax in the tests, still need to fix…
amalia-k510 Aug 28, 2025
eef9015
just concat errors left
amalia-k510 Sep 4, 2025
2d2275b
test fixes for jax
amalia-k510 Sep 11, 2025
a2a8606
indexer implementation and comments addressed
amalia-k510 Sep 15, 2025
dcbd235
test_double_index_jax fixed
amalia-k510 Sep 15, 2025
c227265
register default case, merge issues, and gpu/cpu array transfer addre…
amalia-k510 Oct 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ test-min = [
"pyarrow<21", # https://github.com/scikit-hep/awkward/issues/3579
"anndata[dask]",
]
test = [ "anndata[test-min,lazy]" ]
test = [ "anndata[test-min,lazy]", "jax", "jaxlib" ] # TODO: remove jax? own extra?
gpu = [ "cupy" ]
cu12 = [ "cupy-cuda12x" ]
cu11 = [ "cupy-cuda11x" ]
Expand Down
6 changes: 6 additions & 0 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,12 @@ def X(self, value: XDataType | None): # noqa: PLR0912
ImplicitModificationWarning,
stacklevel=2,
)
dest = self._adata_ref._X
# Handles read-only NumPy views from backend arrays like JAX by
# making a writable copy so in-place assignment on views can succeed.
if isinstance(dest, np.ndarray) and not dest.flags.writeable:
dest = np.array(dest, copy=True) # make a fresh, writable buffer
self._adata_ref._X = dest
Comment on lines +670 to +675
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle

self._adata_ref._X[oidx, vidx] = value
else:
self._X = value
Expand Down
7 changes: 5 additions & 2 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
from scipy.sparse import issparse

from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp
from .xarray import Dataset2D

if TYPE_CHECKING:
Expand Down Expand Up @@ -108,8 +108,11 @@ def name_idx(i):
if isinstance(indexer.data, DaskArray):
return indexer.data.compute()
return indexer.data
elif has_xp(indexer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not I would explain why.

msg = "Need to implement array api-based indexing"
raise NotImplementedError(msg)
msg = f"Unknown indexer {indexer!r} of type {type(indexer)}"
raise IndexError()
raise IndexError(msg)


def _fix_slice_bounds(s: slice, length: int) -> slice:
Expand Down
31 changes: 30 additions & 1 deletion src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import scipy
from natsort import natsorted
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype
from scipy import sparse

from anndata._core.file_backing import to_memory
Expand Down Expand Up @@ -538,6 +539,15 @@ def resolve_merge_strategy(
return strategy


def safe_to_numpy(x):
"""Convert to numpy array, handling JAX/Cupy arrays."""
if isinstance(x, pd.Series | pd.Index) or (
hasattr(x, "dtype") and is_extension_array_dtype(x.dtype)
):
return x
return np.asarray(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok nice this is the right direction no doubt! So what we want here probably is not to rely on asarray but dlpack to do the conversion. In short:

  1. We should have a check in _apply_to_array to see if something is array-api compatible but not a numpy ndarray.
  2. If this case is true, dlpack into numpy, recursively call _apply_to_array
  3. Then use dlpack to take the output of the recursive call to the original type before we went to numpy.

Does that make sense?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.



#####################
# Concatenation
#####################
Expand Down Expand Up @@ -658,7 +668,11 @@ def _apply_to_array(self, el, *, axis, fill_value=None):

indexer = self.idx

# Indexes real fast, and does outer indexing
# Fallback to numpy: keep pandas
# keep pandas EAs/Series/Index as-is; only normalize non-pandas arrays

el = safe_to_numpy(el)

return pd.api.extensions.take(
el, indexer, axis=axis, allow_fill=True, fill_value=fill_value
)
Expand Down Expand Up @@ -1399,6 +1413,20 @@ def concat_dataset2d_on_annot_axis(
return ds_concat_2d


def _to_numpy_if_array_api(x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there should be no second copy of this that’s slightly different, only one!

if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
return x
try:
import array_api_compat as aac

# If this succeeds, it's an array-API array (e.g. JAX, cubed, cupy, dask)
aac.array_namespace(x)
return np.asarray(x)
except TypeError:
# Not an array-API object (or lib not available) → return unchanged
return x


def concat( # noqa: PLR0912, PLR0913, PLR0915
adatas: Collection[AnnData] | Mapping[str, AnnData],
*,
Expand Down Expand Up @@ -1785,6 +1813,7 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915
"not concatenating `.raw` attributes."
)
warn(msg, UserWarning, stacklevel=2)

return AnnData(
**{
"X": X,
Expand Down
4 changes: 3 additions & 1 deletion src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from anndata.compat import CSArray, CSMatrix

from .._warnings import ImplicitModificationWarning
from ..compat import XDataset
from ..compat import XDataset, has_xp
from ..utils import (
ensure_df_homogeneous,
join_english,
Expand Down Expand Up @@ -67,6 +67,8 @@ def coerce_array(
return np.array(value)
except (ValueError, TypeError) as _e:
e = _e
if has_xp(value):
return value
# if value isn’t the right type or convertible, raise an error
msg = f"{name} needs to be of one of {join_english(map(str, array_data_structure_types))}, not {type(value)}."
if e is not None:
Expand Down
4 changes: 4 additions & 0 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CupyCSRMatrix,
DaskArray,
ZappyArray,
has_xp,
)
from .access import ElementRef
from .xarray import Dataset2D
Expand Down Expand Up @@ -296,6 +297,9 @@ def __setattr__(self, key: str, value: Any):

@singledispatch
def as_view(obj, view_args):
if has_xp(obj):
# TODO: Determine if we need some sort of specific view object for array-api
return obj
msg = f"No view type has been registered for {type(obj)}"
raise NotImplementedError(msg)

Expand Down
52 changes: 52 additions & 0 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np
import pandas as pd
from pandas.api.extensions import ExtensionArray
from scipy import sparse as sp

from anndata import AnnData
from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
from anndata._types import Read, ReadLazy, _ReadInternal, _ReadLazyInternal
from anndata.compat import DaskArray, ZarrGroup, _read_attr, is_zarr_v2
Expand All @@ -34,6 +40,48 @@
LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray


def is_sparse_like(x):
try:
return sp.issparse(x)
except AttributeError:
return False


def to_numpy_if_array_api(x):
if isinstance(
x,
np.ndarray
| np.generic
| pd.DataFrame
| pd.Series
| pd.Index
| ExtensionArray
| DaskArray
| sp.spmatrix
| AnnData,
):
return x
Comment on lines +50 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given this check, I would think this function should be called to_writeable, no?


# Try array-API detection only for unknown leaves
try:
import array_api_compat as aac

# If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
aac.array_namespace(x)
return np.asarray(x)
except (ImportError, AttributeError, TypeError):
# Not an array-API object (or not supported), so return unchanged
return x


def normalize_nested(obj):
if isinstance(obj, dict):
return {k: normalize_nested(v) for k, v in obj.items()}
if isinstance(obj, list | tuple):
return type(obj)(normalize_nested(v) for v in obj)
return to_numpy_if_array_api(obj)


# TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-"
# TODO: Should filetype be included in the IOSpec if it changes the encoding? Or does the intent that these things be "the same" overrule that?
@dataclass(frozen=True)
Expand Down Expand Up @@ -386,6 +434,10 @@ def write_elem(
elif k in store:
del store[k]

# Normalize array-API (e.g., JAX/CuPy) payloads buried in mappings/lists
if not isinstance(elem, AnnData):
elem = normalize_nested(elem)
Comment on lines +461 to +462
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't we want to also normalize AnnData objects so they're sub elements are also corrected?


write_func = self.find_write_func(dest_type, elem, modifiers)

if self.callback is None:
Expand Down
9 changes: 9 additions & 0 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import scipy
from array_api_compat import get_namespace as array_api_get_namespace
from numpy.typing import NDArray
from packaging.version import Version
from zarr import Array as ZarrArray # noqa: F401
Expand Down Expand Up @@ -436,3 +437,11 @@ def _map_cat_to_str(cat: pd.Categorical) -> pd.Categorical:
return cat.map(str, na_action="ignore")
else:
return cat.map(str)


def has_xp(mod):
try:
array_api_get_namespace(mod)
return True
except TypeError:
return False
Loading
Loading