Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
21d5882
chore: add tests
ilan-gold Jul 29, 2025
7cd6d69
fix: allow creating objects with array-api
ilan-gold Jul 29, 2025
eff9dde
chore: add indexing test
ilan-gold Jul 29, 2025
b230d11
fix: add xp pass-through
ilan-gold Jul 29, 2025
692a270
chore: add more indexing methods
ilan-gold Jul 29, 2025
b52e5b7
Merge branch 'main' into ig/array_api_starter
ilan-gold Jul 29, 2025
806becf
initial backend step
amalia-k510 Aug 7, 2025
98d249a
concat lazy error fix
amalia-k510 Aug 9, 2025
cdc4fdd
comment for merge fix
amalia-k510 Aug 9, 2025
030f985
dask array fix
amalia-k510 Aug 9, 2025
5aff4d6
fix test_concatenate_roundtrip[inner-np_array-pandas-concat-lazy-conc…
amalia-k510 Aug 9, 2025
ba74743
fixed all contact errors
amalia-k510 Aug 9, 2025
a749637
comments fix
amalia-k510 Aug 9, 2025
c716fb5
quick fix but not ideal as it converts jax and other arrays to numpy
amalia-k510 Aug 9, 2025
951c026
Merge branch 'main' into ig/array_api_continue
amalia-k510 Aug 11, 2025
d8adf27
comment fix
amalia-k510 Aug 11, 2025
0e410a4
Merge branch 'ig/array_api_continue' of github.com:amalia-k510/anndat…
amalia-k510 Aug 11, 2025
8ea34e8
comment fix
amalia-k510 Aug 11, 2025
96992d5
extra tests and function changes
amalia-k510 Aug 14, 2025
460428a
dlpack introduction and trying to make gen_adata fully backend native…
amalia-k510 Aug 21, 2025
dd3b867
removed unnecessary function
amalia-k510 Aug 21, 2025
743ebb3
minor fixes
amalia-k510 Aug 21, 2025
5a6c825
minor fix
amalia-k510 Aug 21, 2025
90d9e6a
begin merge modification
amalia-k510 Aug 25, 2025
787feb0
concat on jax arrays is introduced
amalia-k510 Aug 26, 2025
dea107d
precommit fixes
amalia-k510 Aug 26, 2025
cdd3747
merge quick fix
amalia-k510 Aug 26, 2025
b04c8ef
minor fixes
amalia-k510 Aug 26, 2025
383c445
writer and reindexer introduced + jax in the tests, still need to fix…
amalia-k510 Aug 28, 2025
eef9015
just concat errors left
amalia-k510 Sep 4, 2025
2d2275b
test fixes for jax
amalia-k510 Sep 11, 2025
a2a8606
indexer implementation and comments addressed
amalia-k510 Sep 15, 2025
dcbd235
test_double_index_jax fixed
amalia-k510 Sep 15, 2025
c227265
register default case, merge issues, and gpu/cpu array transfer addre…
amalia-k510 Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def name_idx(i):
elif isinstance(indexer, XDataArray):
if isinstance(indexer.data, DaskArray):
return indexer.data.compute()
return indexer.dat
return indexer.data

elif has_xp(indexer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not I would explain why.

# getting array's namespace
Expand Down
48 changes: 3 additions & 45 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,8 @@ def apply(self, el, *, axis, fill_value=None): # noqa: PLR0911
"""
if self.no_change and (axis_len(el, axis) == len(self.old_idx)):
return el
if _is_pandas(el) or isinstance(el, pd.DataFrame):
return self._apply_to_df_like(el, axis=axis, fill_value=fill_value)
Comment on lines +690 to +691
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add this?

if isinstance(el, pd.DataFrame | Dataset2D):
return self._apply_to_df_like(el, axis=axis, fill_value=fill_value)
elif isinstance(el, CSMatrix | CSArray | CupySparseMatrix):
Expand Down Expand Up @@ -753,16 +755,8 @@ def _apply_to_cupy_array(self, el, *, axis, fill_value=None):
def _apply_to_array_api(self, el, *, axis, fill_value=None):
if fill_value is None:
fill_value = default_fill_value([el])

indexer = self.idx

if _is_pandas(el):
# using the behavior that already exists in pandas for Series/Index
return pd.api.extensions.take(
el, indexer, axis=axis, allow_fill=True, fill_value=fill_value
)
# e.g., numpy, jax.numpy, cubed, etc.
xp = get_namespace(el)
indexer = xp.asarray(self.idx)

# Handling edge case to mimic pandas behavior
if el.shape[axis] == 0:
Expand All @@ -777,15 +771,11 @@ def _apply_to_array_api(self, el, *, axis, fill_value=None):
if not isinstance(el, np.ndarray) and _is_array_api_compatible(el):
# Convert to NumPy via DLPack
el_np = _dlpack_to_numpy(el)

# Recursively call this same function
out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value)

# reverting back to numpy as it is hard to reindex on JAX and others
return _dlpack_from_numpy(out_np, xp)

# numpy case
indexer = xp.asarray(indexer)
# marking which positions are missing, so we could use fill_value
missing_mask = indexer == -1
safe_indexer = xp.where(missing_mask, 0, indexer)
Expand Down Expand Up @@ -910,38 +900,6 @@ def merge_indices(inds: Iterable[pd.Index], join: Join_T) -> pd.Index:
raise ValueError(msg)


# def default_fill_value(els):
# # Given some arrays, returns what the default fill value should be.

# if any(
# isinstance(el, pd.DataFrame | pd.Series)
# for el in els
# if el is not None and el is not MissingVal
# ):
# return pd.NA

# if any(
# isinstance(el, CSMatrix | CSArray)
# or (isinstance(el, DaskArray) and isinstance(el._meta, CSMatrix | CSArray))
# for el in els
# ):
# return 0

# # Pick the namespace of the first valid element
# for el in els:
# if el is not None and el is not MissingVal:
# xp = get_namespace(el)
# # Not all backends have `nan` defined (e.g. integer dtypes)
# try:
# return xp.nan
# except AttributeError:
# # Fall back to 0 if no NaN in this backend
# return xp.asarray(0).item()

# # Fallback if list was empty or only MissingVal
# return np.nan


def default_fill_value(els):
"""Given some arrays, returns what the default fill value should be.

Expand Down
53 changes: 36 additions & 17 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from functools import reduce, singledispatch, wraps
from typing import TYPE_CHECKING, Literal

import jax.numpy as jnp
import numpy as np
import pandas as pd
from pandas.api.types import is_bool_dtype
Expand Down Expand Up @@ -116,7 +115,6 @@ def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):

def _replace_field(container, idx, value):
import awkward as ak
import numpy as np

# JAX-style immutable array
if hasattr(container, "at"):
Expand Down Expand Up @@ -556,7 +554,42 @@ def _resolve_idxs(

@singledispatch
def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm:
raise NotImplementedError
import array_api_compat as aac

from ..compat import has_xp

# handling array-API–compatible arrays (e.g. JAX, etc)
if has_xp(old):
xp = aac.array_namespace(old)
# skip early if numpy
if xp.__name__.startswith("numpy"):
return old[new]
Comment on lines +563 to +566
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why special case this? If isinstance(old, numpy.ndarray) then singledispatch handles this for us. If something is resolved to use numpy by array_api_compat (I don't see why this would ever be the case here, and I don't see anything in the docs about falling back to numpy)


# handle boolean mask; i.e. checking whether old is a boolean array
if hasattr(old, "dtype") and str(old.dtype) in ("bool", "bool_", "boolean"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If old is not array-api compatible, shouldn't we error out? old.dtype would exist and https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html#isdtype could handle checking for bool, no need for strings or anything

# retrieving the where function (like np.where or jnp.where)
where_fn = getattr(xp, "where", None)
# if where exists, use it to convert the boolean mask to integer indices
if where_fn is not None:
old = where_fn(old)[0]
else:
# if no where function is found, fallback to NumPy
old = np.where(np.asarray(old))[0]
Comment on lines +573 to +577
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# if new is a slice object, converting it into a range of indices using arange
if isinstance(new, slice):
# trying to get arange from the backend
arange_fn = getattr(xp, "arange", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# if arange exists, apply it to the resolved slice range
if arange_fn is not None:
new = arange_fn(*new.indices(old.shape[0]))
else:
new = np.arange(*new.indices(old.shape[0]))

return old[new]

msg = f"_resolve_idx not implemented for type {type(old)}"
raise NotImplementedError(msg)


@_resolve_idx.register(np.ndarray)
Expand All @@ -582,20 +615,6 @@ def _resolve_idx_slice(
return np.arange(*old.indices(l))[new]


@_resolve_idx.register(jnp.ndarray)
def _resolve_idx_jnp(
old: jnp.ndarray, new: Index1DNorm, l: Literal[0, 1]
) -> jnp.ndarray:
# Boolean mask + index
if old.dtype == jnp.bool_:
old = jnp.where(old)[0]

if isinstance(new, slice):
new = jnp.arange(*new.indices(old.shape[0]))

return old[new]


def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice:
r = range(*old.indices(l))[new]
# Convert back to slice
Expand Down
29 changes: 23 additions & 6 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,38 @@ def to_numpy_if_array_api(x):

# Try array-API detection only for unknown leaves
try:
# importing the array API compatibility layer
import array_api_compat as aac

# getting the array namespace
xp = aac.array_namespace(x)
# Already a NumPy array
# Skip if it's already NumPy
if xp.__name__.startswith("numpy"):
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could this condition be reached if we have the above isinstance(x, np.ndarray | ...) check?

# # If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
# aac.array_namespace(x)
# return np.asarray(x)
# If the array has a `.device` attribute, check if it's on GPU
if hasattr(x, "device"):
device = x.device
# if the device has a `.type` field, use it
if hasattr(device, "type"):
if device.type != "cpu":
# move the array to CPU if it's on GPU
x = xp.to_device(
x, "cpu"
) # not sure about this, would we want to move the entire array to CPU?
# otherwise, if the device is not a string, check if it's not "cpu"
elif str(device) != "cpu":
x = xp.to_device(x, "cpu")

# Convert to NumPy using DLPack (safe now)
if hasattr(x, "__to_dlpack__"):
return np.from_dlpack(x)
except (ImportError, AttributeError, TypeError):
# Not an array-API object (or not supported), so return unchanged

except (ImportError, AttributeError, TypeError, RuntimeError):
# Could not detect or convert – return unchanged
return x

return x


def normalize_nested(obj):
if isinstance(obj, dict):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,12 +1833,11 @@ def test_error_on_mixed_device():
concat(p)


@pytest.mark.xfail(
condition=lambda: array_type is jnp.asarray,
reason="concat across different array backends is not supported",
)
def test_concat_on_var_outer_join(array_type):
# https://github.com/scverse/anndata/issues/1286
if array_type is jnp.asarray:
pytest.xfail("concat across different array backends is not supported")

a = AnnData(
obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(10)]),
var=pd.DataFrame(index=[f"gene_{i:02d}" for i in range(10)]),
Expand Down