Draft
Changes from all commits
34 commits
21d5882
chore: add tests
ilan-gold Jul 29, 2025
7cd6d69
fix: allow creating objects with array-api
ilan-gold Jul 29, 2025
eff9dde
chore: add indexing test
ilan-gold Jul 29, 2025
b230d11
fix: add xp pass-through
ilan-gold Jul 29, 2025
692a270
chore: add more indexing methods
ilan-gold Jul 29, 2025
b52e5b7
Merge branch 'main' into ig/array_api_starter
ilan-gold Jul 29, 2025
806becf
initial backend step
amalia-k510 Aug 7, 2025
98d249a
concat lazy error fix
amalia-k510 Aug 9, 2025
cdc4fdd
comment for merge fix
amalia-k510 Aug 9, 2025
030f985
dask array fix
amalia-k510 Aug 9, 2025
5aff4d6
fix test_concatenate_roundtrip[inner-np_array-pandas-concat-lazy-conc…
amalia-k510 Aug 9, 2025
ba74743
fixed all concat errors
amalia-k510 Aug 9, 2025
a749637
comments fix
amalia-k510 Aug 9, 2025
c716fb5
quick fix but not ideal as it converts jax and other arrays to numpy
amalia-k510 Aug 9, 2025
951c026
Merge branch 'main' into ig/array_api_continue
amalia-k510 Aug 11, 2025
d8adf27
comment fix
amalia-k510 Aug 11, 2025
0e410a4
Merge branch 'ig/array_api_continue' of github.com:amalia-k510/anndat…
amalia-k510 Aug 11, 2025
8ea34e8
comment fix
amalia-k510 Aug 11, 2025
96992d5
extra tests and function changes
amalia-k510 Aug 14, 2025
460428a
dlpack introduction and trying to make gen_adata fully backend native…
amalia-k510 Aug 21, 2025
dd3b867
removed unnecessary function
amalia-k510 Aug 21, 2025
743ebb3
minor fixes
amalia-k510 Aug 21, 2025
5a6c825
minor fix
amalia-k510 Aug 21, 2025
90d9e6a
begin merge modification
amalia-k510 Aug 25, 2025
787feb0
concat on jax arrays is introduced
amalia-k510 Aug 26, 2025
dea107d
precommit fixes
amalia-k510 Aug 26, 2025
cdd3747
merge quick fix
amalia-k510 Aug 26, 2025
b04c8ef
minor fixes
amalia-k510 Aug 26, 2025
383c445
writer and reindexer introduced + jax in the tests, still need to fix…
amalia-k510 Aug 28, 2025
eef9015
just concat errors left
amalia-k510 Sep 4, 2025
2d2275b
test fixes for jax
amalia-k510 Sep 11, 2025
a2a8606
indexer implementation and comments addressed
amalia-k510 Sep 15, 2025
dcbd235
test_double_index_jax fixed
amalia-k510 Sep 15, 2025
c227265
register default case, merge issues, and gpu/cpu array transfer addre…
amalia-k510 Oct 13, 2025
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -101,7 +101,10 @@ test-min = [
"pyarrow<21", # https://github.com/scikit-hep/awkward/issues/3579
"anndata[dask]",
]
test = [ "anndata[test-min,lazy]" ]
test = [
"anndata[test-min,lazy]",
"anndata[jax]",
] # JAX is defined as a separate extra to avoid forcing installation for all users
gpu = [ "cupy" ]
cu12 = [ "cupy-cuda12x" ]
cu11 = [ "cupy-cuda11x" ]
Expand All @@ -110,6 +113,10 @@ lazy = [ "xarray>=2025.06.1", "aiohttp", "requests", "anndata[dask]" ]
# https://github.com/dask/dask/issues/11290
# https://github.com/dask/dask/issues/11752
dask = [ "dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,<2025.2.0" ]
jax = [
"jax>=0.6.0",
"jaxlib>=0.6.0",
]

[tool.hatch.version]
source = "vcs"
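As a usage note (the commands below are assumed, not part of the diff), the new extra keeps JAX optional: users who want it install it explicitly, and everyone else is unaffected.

# After `pip install "anndata[jax]"` (or `anndata[test]`, which now pulls it in),
# the optional dependency can be probed without hard-failing:
try:
    import jax
    print("jax", jax.__version__)  # pyproject pins jax>=0.6.0
except ImportError:
    print("JAX extra not installed; core anndata functionality is unaffected")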
6 changes: 6 additions & 0 deletions src/anndata/_core/anndata.py
@@ -667,6 +667,12 @@ def X(self, value: XDataType | None): # noqa: PLR0912
ImplicitModificationWarning,
stacklevel=2,
)
dest = self._adata_ref._X
# Handles read-only NumPy views from backend arrays like JAX by
# making a writable copy so in-place assignment on views can succeed.
if isinstance(dest, np.ndarray) and not dest.flags.writeable:
dest = np.array(dest, copy=True) # make a fresh, writable buffer
self._adata_ref._X = dest
Comment on lines +670 to +675
Contributor
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle
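For context on this suggestion, assignment into a read-only NumPy array already fails loudly on its own; a minimal illustration of the error that would propagate if the copy were dropped:

import numpy as np

x = np.zeros(3)
x.flags.writeable = False
x[0] = 1.0  # raises ValueError: assignment destination is read-only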

self._adata_ref._X[oidx, vidx] = value
else:
self._X = value
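The hunk above exists because NumPy views exported from immutable backends are flagged read-only. A minimal sketch, assuming JAX is installed (exact export behavior may vary by version):

import jax.numpy as jnp
import numpy as np

jx = jnp.arange(4)
view = np.asarray(jx)              # typically a zero-copy, read-only view of the JAX buffer
print(view.flags.writeable)        # False: JAX arrays are immutable
fresh = np.array(view, copy=True)  # the workaround above: take a writable copy
fresh[0] = 42                      # in-place assignment now succeeds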
69 changes: 66 additions & 3 deletions src/anndata/_core/index.py
@@ -10,7 +10,7 @@
import pandas as pd
from scipy.sparse import issparse

from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp
from .xarray import Dataset2D

if TYPE_CHECKING:
@@ -39,6 +39,7 @@ def _normalize_index( # noqa: PLR0911, PLR0912
indexer: Index1D, index: pd.Index
) -> Index1DNorm | int | np.integer:
# TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
# guard against ambiguous numeric indexes
if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
raise TypeError(msg)
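A brief illustration (hypothetical, not from the PR) of the ambiguity this guard avoids: with a non-range numeric index, a value like 20 could be read as a label or as a position.

import pandas as pd

idx = pd.Index([10, 20, 30])
idx.get_loc(20)  # 1 -> label-based lookup
idx[1]           # 20 -> positional lookup; mixing the two silently changes meaning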
@@ -50,6 +51,7 @@ def name_idx(i):
i = index.get_loc(i)
return i

# convert the slice's start and stop to integer positions if they are strings
if isinstance(indexer, slice):
start = name_idx(indexer.start)
stop = name_idx(indexer.stop)
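For context, a tiny example (labels hypothetical) of what name_idx does with string slice bounds:

import pandas as pd

index = pd.Index(["gene1", "gene2", "gene3"])
# name_idx resolves a string bound via index.get_loc, so
start = index.get_loc("gene2")  # 1
# slice("gene2", "gene3") is thus normalized to integer positions before slicing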
@@ -65,17 +67,21 @@
elif isinstance(
indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray
):
# flatten to 1D if the indexer is an accidental 2D column/row vector
# densify sparse indexers if needed
if hasattr(indexer, "shape") and (
(indexer.shape == (index.shape[0], 1))
or (indexer.shape == (1, index.shape[0]))
):
if isinstance(indexer, CSMatrix | CSArray):
indexer = indexer.toarray()
indexer = np.ravel(indexer)
# convert anything else to a NumPy array
if not isinstance(indexer, np.ndarray | pd.Index):
indexer = np.array(indexer)
if len(indexer) == 0:
indexer = indexer.astype(int)
# if it is a floating-point array, ensure all values are integral and cast to int
if isinstance(indexer, np.ndarray) and np.issubdtype(
indexer.dtype, np.floating
):
@@ -94,7 +100,7 @@
)
raise IndexError(msg)
return indexer
else: # indexer should be string array
else:
positions = index.get_indexer(indexer)
if np.any(positions < 0):
not_found = indexer[positions < 0]
@@ -108,8 +114,65 @@
if isinstance(indexer.data, DaskArray):
return indexer.data.compute()
return indexer.data

elif has_xp(indexer):
Contributor
This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not possible, I would explain why.

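To make this suggestion concrete, a hypothetical helper (not in the PR) that would let the numpy and array-api branches share one code path:

import numpy as np

def get_namespace(arr):
    # NumPy >= 2.0 ndarrays also implement __array_namespace__, so both the
    # numpy case and the array-api case could resolve a namespace the same way
    if hasattr(arr, "__array_namespace__"):
        return arr.__array_namespace__()
    return np  # fallback for older numpy or plain array-likes

# e.g. xp = get_namespace(indexer); then use xp.any / xp.astype / xp.reshape
# uniformly instead of duplicating the numpy logic.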
# get the array's namespace (array API standard)
xp = indexer.__array_namespace__()

# flatten accidental 2D row/column vectors to 1D
if hasattr(indexer, "shape") and (
indexer.shape == (index.shape[0], 1) or indexer.shape == (1, index.shape[0])
):
# jax.numpy provides ravel; other namespaces (e.g. cubed) may not
indexer = xp.ravel(indexer)

# fetch the indexer's dtype if it has one
dtype = getattr(indexer, "dtype", None)

# boolean mask case (e.g. a JAX boolean array)
if xp.issubdtype(dtype, xp.bool_):
if indexer.shape != index.shape:
msg = (
f"Boolean index does not match AnnData’s shape along this dimension. "
f"Boolean index has shape {indexer.shape}, expected {index.shape}"
)
raise IndexError(msg)
return indexer

# integer indexers can be used as-is
elif xp.issubdtype(dtype, xp.integer):
return indexer
# floating-point case: allow only integral values (e.g. 0.0, 1.0, 2.0)
elif xp.issubdtype(dtype, xp.floating):
indexer_int = xp.astype(indexer, xp.int32)  # JAX defaults to int32
# round-trip through the original float dtype to detect fractional values
is_fractional = xp.not_equal(indexer, xp.astype(indexer_int, dtype))
if xp.any(is_fractional):
msg = f"Indexer {indexer!r} has non-integer floating point values."
raise IndexError(msg)
return indexer_int

else:
try:
values = indexer.tolist()  # fall back to a Python list for label lookup
except Exception as e:
msg = f"Could not convert {indexer!r} to list for string lookup."
raise IndexError(msg) from e
positions = index.get_indexer(values)
if np.any(positions < 0):
not_found = [
v for v, p in zip(values, positions, strict=False) if p < 0
]
msg = (
f"Values {not_found}, from {values}, "
"are not valid obs/ var names or indices."
)
raise KeyError(msg)
return positions

msg = f"Unknown indexer {indexer!r} of type {type(indexer)}"
raise IndexError()
raise IndexError(msg)


def _fix_slice_bounds(s: slice, length: int) -> slice:
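For reference, an illustrative, self-contained sketch of the array-api branch above (not the PR's code; has_xp is presumably a predicate for __array_namespace__ support, and recent JAX arrays are assumed to expose jax.numpy as their namespace):

import jax.numpy as jnp
import pandas as pd

def normalize_xp_indexer(indexer, index: pd.Index):
    # simplified mirror of the branch above, for illustration only
    xp = indexer.__array_namespace__()
    if indexer.ndim == 2 and 1 in indexer.shape:
        indexer = xp.reshape(indexer, (-1,))   # flatten row/column vectors
    if xp.isdtype(indexer.dtype, "bool"):      # boolean mask: shapes must match
        if indexer.shape != index.shape:
            raise IndexError("boolean index shape mismatch")
        return indexer
    if xp.isdtype(indexer.dtype, "integral"):  # integer positions pass through
        return indexer
    raise IndexError(f"unsupported indexer dtype: {indexer.dtype}")

idx = pd.Index(["a", "b", "c"])
print(normalize_xp_indexer(jnp.asarray([True, False, True]), idx))
print(normalize_xp_indexer(jnp.asarray([0, 2]), idx))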