Draft

34 commits
21d5882
chore: add tests
ilan-gold Jul 29, 2025
7cd6d69
fix: allow creating objects with array-api
ilan-gold Jul 29, 2025
eff9dde
chore: add indexing test
ilan-gold Jul 29, 2025
b230d11
fix: add xp pass-through
ilan-gold Jul 29, 2025
692a270
chore: add more indexing methods
ilan-gold Jul 29, 2025
b52e5b7
Merge branch 'main' into ig/array_api_starter
ilan-gold Jul 29, 2025
806becf
initial backend step
amalia-k510 Aug 7, 2025
98d249a
concat lazy error fix
amalia-k510 Aug 9, 2025
cdc4fdd
comment for merge fix
amalia-k510 Aug 9, 2025
030f985
dask array fix
amalia-k510 Aug 9, 2025
5aff4d6
fix test_concatenate_roundtrip[inner-np_array-pandas-concat-lazy-conc…
amalia-k510 Aug 9, 2025
ba74743
fixed all concat errors
amalia-k510 Aug 9, 2025
a749637
comments fix
amalia-k510 Aug 9, 2025
c716fb5
quick fix but not ideal as it converts jax and other arrays to numpy
amalia-k510 Aug 9, 2025
951c026
Merge branch 'main' into ig/array_api_continue
amalia-k510 Aug 11, 2025
d8adf27
comment fix
amalia-k510 Aug 11, 2025
0e410a4
Merge branch 'ig/array_api_continue' of github.com:amalia-k510/anndat…
amalia-k510 Aug 11, 2025
8ea34e8
comment fix
amalia-k510 Aug 11, 2025
96992d5
extra tests and function changes
amalia-k510 Aug 14, 2025
460428a
dlpack introduction and trying to make gen_adata fully backend native…
amalia-k510 Aug 21, 2025
dd3b867
removed unnecessary function
amalia-k510 Aug 21, 2025
743ebb3
minor fixes
amalia-k510 Aug 21, 2025
5a6c825
minor fix
amalia-k510 Aug 21, 2025
90d9e6a
begin merge modification
amalia-k510 Aug 25, 2025
787feb0
concat on jax arrays is introduced
amalia-k510 Aug 26, 2025
dea107d
precommit fixes
amalia-k510 Aug 26, 2025
cdd3747
merge quick fix
amalia-k510 Aug 26, 2025
b04c8ef
minor fixes
amalia-k510 Aug 26, 2025
383c445
writer and reindexer introduced + jax in the tests, still need to fix…
amalia-k510 Aug 28, 2025
eef9015
just concat errors left
amalia-k510 Sep 4, 2025
2d2275b
test fixes for jax
amalia-k510 Sep 11, 2025
a2a8606
indexer implementation and comments addressed
amalia-k510 Sep 15, 2025
dcbd235
test_double_index_jax fixed
amalia-k510 Sep 15, 2025
c227265
register default case, merge issues, and gpu/cpu array transfer addre…
amalia-k510 Oct 13, 2025
89 changes: 62 additions & 27 deletions src/anndata/_core/merge.py
@@ -19,8 +19,7 @@

# Enable DLPack interop for JAX, CuPy, etc., only if installed
with suppress(ImportError):
import jax
import jax.dlpack
pass
import pandas as pd
import pandas.api.types as pdf
import scipy
@@ -123,6 +122,7 @@ def not_missing(v) -> bool:


def _same_backend(x, y, *, copy: bool = True):
# TODO: generalize this so it can also convert one array to another backend
# for the merge implementation
# Makes sure two arrays are from the same array backend.
# If not, uses from_dlpack() to convert `y` to `x`'s backend.
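
A minimal sketch of the coercion the comment above describes, assuming array_api_compat is available; the real _same_backend body is collapsed in this diff, so the names below are illustrative only:

import numpy as np
from array_api_compat import array_namespace


def _same_backend_sketch(x, y):
    """Return `y` in `x`'s array backend, converting via DLPack if needed."""
    xp_x = array_namespace(x)
    xp_y = array_namespace(y)
    if xp_x is xp_y:
        return y
    # DLPack lets the consumer library adopt the producer's buffer,
    # zero-copy where devices and dtypes allow it.
    return xp_x.from_dlpack(y)
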
@@ -619,8 +619,6 @@ def _is_array_api_compatible(x):


def _dlpack_to_numpy(x):
###TODO: FIX
# Convert array-api-compatible x to NumPy using DLPack.
try:
return np.from_dlpack(x)
except TypeError as e:
@@ -629,13 +627,22 @@ def _dlpack_to_numpy(x):


def _dlpack_from_numpy(x_np, original_xp):
###TODO: FIX
# cubed and other array later elif
# TODO: cubed and other array later elif
if hasattr(original_xp, "from_dlpack"):
try:
return original_xp.from_dlpack(x_np)
except Exception as e:
msg = f"Failed to call from_dlpack on backend {original_xp.__name__}: {e}"
raise TypeError(msg) from e

# fallback for known backends that put it elsewhere (JAX and later others)
if original_xp.__name__.startswith("jax"):
import jax.dlpack
Contributor comment:
original_xp should have a from_dlpack method! Does my comment here 383c445#r2291442475 not apply?


return jax.dlpack.from_dlpack(x_np)
else:
msg = f"DLPack back-conversion not implemented for {original_xp.__name__}"
raise TypeError(msg)

msg = f"DLPack back-conversion not implemented for backend {original_xp.__name__}"
raise TypeError(msg)
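
A minimal sketch of the simplification the reviewer suggests, assuming original_xp is an array-API namespace (e.g. the result of array_api_compat.array_namespace); this is not the code in the PR, just an illustration:

import numpy as np


def _dlpack_from_numpy_sketch(x_np: np.ndarray, original_xp):
    # rely on the namespace's own from_dlpack instead of special-casing JAX
    if not hasattr(original_xp, "from_dlpack"):
        msg = f"DLPack back-conversion not implemented for backend {original_xp.__name__}"
        raise TypeError(msg)
    try:
        return original_xp.from_dlpack(x_np)
    except TypeError as e:
        msg = f"Failed to call from_dlpack on backend {original_xp.__name__}: {e}"
        raise TypeError(msg) from e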


#####################
@@ -698,7 +705,7 @@ def apply(self, el, *, axis, fill_value=None): # noqa: PLR0911
elif isinstance(el, CupyArray):
return self._apply_to_cupy_array(el, axis=axis, fill_value=fill_value)
else:
return self._apply_to_array(el, axis=axis, fill_value=fill_value)
return self._apply_to_array_api(el, axis=axis, fill_value=fill_value)

def _apply_to_df_like(self, el: pd.DataFrame | Dataset2D, *, axis, fill_value=None):
if fill_value is None:
@@ -750,7 +757,7 @@ def _apply_to_cupy_array(self, el, *, axis, fill_value=None):

return out

def _apply_to_array(self, el, *, axis, fill_value=None):
def _apply_to_array_api(self, el, *, axis, fill_value=None):
if fill_value is None:
fill_value = default_fill_value([el])

@@ -779,11 +786,9 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
el_np = _dlpack_to_numpy(el)

# Recursively call this same function
out_np = self._apply_to_array(el_np, axis=axis, fill_value=fill_value)
out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value)

# TODO: Fix it as moving it back and forth is not ideal, but it allows us to
# keep the same interface for all backends.
# Convert result back to original backend
# reverting back to numpy as it is hard to reindex on JAX and others
return _dlpack_from_numpy(out_np, xp)

# numpy case
Comment from @ilan-gold (Contributor), Sep 16, 2025:
I think the logic here is a little confused. This function used to be for numpy, but you added the _is_pandas check above, which I don't think applies here. But the logic you've written "# numpy case" and down works great for non-numpy array-api compatible arrays as well!

So I would leave the numpy case as before (i.e., remove _is_pandas and check isinstance(el, np.ndarray)), and then in the case it is not a numpy array, use this logic under "# numpy case"! You can then get rid of the if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): branch. One reason I cautioned against falling back to numpy behavior is that some things like jax arrays that are API compatible might be on the GPU! You can't transfer a JAX array on the GPU to numpy :/
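
A minimal sketch of the restructuring described in this comment; _apply_to_numpy and _reindex_in_namespace are hypothetical stand-ins for the existing NumPy logic, not anndata internals:

import numpy as np
from array_api_compat import array_namespace


def _apply_to_array_api_sketch(self, el, *, axis, fill_value=None):
    if isinstance(el, np.ndarray):
        # keep the pre-existing NumPy path exactly as before
        return self._apply_to_numpy(el, axis=axis, fill_value=fill_value)
    # any other array-API compatible array (JAX, CuPy, ...): run the same
    # "# numpy case" logic, but in el's own namespace so a GPU-backed array
    # is never forced through np.asarray or a DLPack round-trip
    xp = array_namespace(el)
    return _reindex_in_namespace(el, xp, axis=axis, fill_value=fill_value)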

@@ -1588,18 +1593,48 @@ def concat_dataset2d_on_annot_axis(
return ds_concat_2d


def _to_numpy_if_array_api(x):
if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
return x
try:
import array_api_compat as aac

# If this succeeds, it's an array-API array (e.g. JAX, cubed, cupy, dask)
aac.array_namespace(x)
return np.asarray(x)
except TypeError:
# Not an array-API object (or lib not available) = return unchanged
return x
# def _to_numpy_if_immutable(x):
# print("Initial x:", type(x))
# if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
# print("x is already a supported mutable type:", type(x))
# return x
# try:
# # checking for mutability
# print("Trying np.asarray(x)...")
# x_array = np.asarray(x)
# print("np.asarray(x) succeeded:", type(x_array))
# if x_array.size > 0:
# try:
# orig = x_array[0]
# print("Checking mutability: trying in-place write")
# x_array[0] = orig # test no-op write
# print("In-place mutation succeeded, x is mutable:", type(x))
# return x # mutation worked, so keep original
# except (ValueError, TypeError) as e:
# print("In-place mutation failed:", type(x), "|", repr(e))
# # pass

# except ValueError:
# print("Trying np.from_dlpack(x)...")
# result = np.from_dlpack(x)
# print("np.from_dlpack(x) succeeded:", type(result))
# # pass
# # if it is not mutable, we convert to numpy
# try:
# # trying convert via from_dlpack first
# return np.from_dlpack(x)
# except TypeError:
# try:
# # fallback to asarray if from_dlpack not possible
# print("Trying fallback np.asarray(x)...")
# result = np.asarray(x)
# print("Fallback np.asarray(x) succeeded:", type(result))
# return result
# except ValueError:
# # Not an array-API object (or lib not available) = return unchanged
# print("Final fallback np.asarray(x) failed:", type(x), "|", repr(e))
# print("Returning x as-is:", type(x))
# return x


def concat( # noqa: PLR0912, PLR0913, PLR0915
32 changes: 31 additions & 1 deletion src/anndata/_core/views.py
@@ -57,10 +57,22 @@ def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]):

`adata.attr[key1][key2][keyn]...`
"""
# from anndata._core.merge import _to_numpy_if_immutable

new = adata_view.copy()
attr = getattr(new, attr_name)
container = reduce(lambda d, k: d[k], keys, attr)
# Traverse to the parent container
parent = reduce(lambda d, k: d[k], keys[:-1], attr)
# key = keys[-1]

# Get the actual object we want to mutate
container = parent

# Yield it (not yet converted)
yield container

# # After yield, check if immutable and convert to mutable before reinserting
# parent[key] = _to_numpy_if_immutable(container)
adata_view._init_as_actual(new)


@@ -73,6 +85,24 @@ class _SetItemMixin:

_view_args: ElementRef | None

# def __setitem__(self, idx: Any, value: Any):
# # from anndata._core.merge import _to_numpy_if_immutable

# if self._view_args is None:
# super().__setitem__(idx, value)
# else:
# warnings.warn(
# f"Trying to modify attribute `.{self._view_args.attrname}` of view, "
# "initializing view as actual.",
# ImplicitModificationWarning,
# stacklevel=2,
# )
# with view_update(*self._view_args) as container:
# arr = _to_numpy_if_immutable(container)
# arr[idx] = value
# # manually assign back into the parent dict
# parent = reduce(lambda d, k: d[k], self._view_args[:-1])
# parent[self._view_args[-1]] = arr
def __setitem__(self, idx: Any, value: Any):
if self._view_args is None:
super().__setitem__(idx, value)
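
For orientation, a hedged sketch of how _SetItemMixin routes writes through view_update; the attribute and key names here are assumptions for illustration only:

view = adata[0:5]                # a view of some AnnData
view.obsm["X_pca"][0, 0] = 1.0   # _SetItemMixin.__setitem__ fires here:
# it emits ImplicitModificationWarning, copies the data inside view_update,
# applies the assignment to the yielded container, and finally calls
# _init_as_actual(new) so `view` becomes a standalone AnnData object
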
12 changes: 9 additions & 3 deletions src/anndata/_io/specs/registry.py
@@ -66,9 +66,15 @@ def to_numpy_if_array_api(x):
try:
import array_api_compat as aac

# If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
aac.array_namespace(x)
return np.asarray(x)
xp = aac.array_namespace(x)
# Already a NumPy array
if xp.__name__.startswith("numpy"):
return x
# # If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
# aac.array_namespace(x)
# return np.asarray(x)
if hasattr(x, "__dlpack__"):
return np.from_dlpack(x)
except (ImportError, AttributeError, TypeError):
# Not an array-API object (or not supported), so return unchanged
return x
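
A short usage sketch for the updated helper (assuming JAX is installed and the array lives on CPU so np.from_dlpack can adopt it; the import path is the module shown in this diff):

import jax.numpy as jnp
import numpy as np

from anndata._io.specs.registry import to_numpy_if_array_api

x = jnp.arange(4, dtype=jnp.float32)
assert isinstance(to_numpy_if_array_api(x), np.ndarray)  # converted via DLPack

y = np.arange(4)
assert to_numpy_if_array_api(y) is y  # NumPy input is returned unchanged
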
16 changes: 16 additions & 0 deletions src/anndata/tests/helpers.py
@@ -1008,6 +1008,22 @@ def _(a):
return a.map_blocks(asarray, dtype=a.dtype, meta=np.ndarray)


@singledispatch
def as_dense_jax_array(a):
# for cases where jax does not support sparse arrays
return jnp.asarray(asarray(a))


@as_dense_jax_array.register(CSMatrix)
def _(a):
return jnp.array(a.toarray())


@as_dense_jax_array.register(DaskArray)
def _(a):
return jnp.array(a.compute()) # fallback for lazy arrays


@singledispatch
def as_sparse_dask_array(a) -> DaskArray:
import dask.array as da
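
A short usage sketch for the new test helper (assuming JAX and SciPy are installed, and that CSMatrix covers scipy.sparse.csr_matrix as elsewhere in the test helpers):

import jax.numpy as jnp
import numpy as np
from scipy.sparse import csr_matrix

from anndata.tests.helpers import as_dense_jax_array

dense = as_dense_jax_array(np.eye(3))                     # generic path: jnp.asarray(asarray(a))
from_sparse = as_dense_jax_array(csr_matrix(np.eye(3)))   # CSMatrix path: .toarray(), then jnp.array
assert isinstance(dense, jnp.ndarray) and isinstance(from_sparse, jnp.ndarray)
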
44 changes: 29 additions & 15 deletions tests/test_readwrite.py
@@ -10,6 +10,7 @@
from typing import TYPE_CHECKING

import h5py
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pytest
@@ -33,6 +34,7 @@
from anndata.tests.helpers import (
GEN_ADATA_NO_XARRAY_ARGS,
as_dense_dask_array,
as_dense_jax_array,
assert_equal,
gen_adata,
)
@@ -120,7 +122,9 @@ def dtype(request):
# ------------------------------------------------------------------------------


@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
)
def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2):
pth1 = tmp_path / f"first.{diskfmt}"
write1 = lambda x: getattr(x, f"write_{diskfmt}")(pth1)
@@ -160,7 +164,9 @@ async def _do_test():


@pytest.mark.parametrize("storage", ["h5ad", "zarr"])
@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
)
def test_readwrite_kitchensink(tmp_path, storage, typ, backing_h5ad, dataset_kwargs):
X = typ(X_list)
adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -206,7 +212,9 @@ def test_readwrite_kitchensink(tmp_path, storage, typ, backing_h5ad, dataset_kwa
assert_equal(adata, adata_src)


@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
)
def test_readwrite_maintain_X_dtype(typ, backing_h5ad):
X = typ(X_list).astype("int8")
adata_src = ad.AnnData(X)
@@ -239,7 +247,9 @@ def test_maintain_layers(rw):
assert not np.any((orig.layers["sparse"] != curr.layers["sparse"]).toarray())


@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
)
def test_readwrite_h5ad_one_dimension(typ, backing_h5ad):
X = typ(X_list)
adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -250,7 +260,9 @@ def test_readwrite_h5ad_one_dimension(typ, backing_h5ad):
assert_equal(adata, adata_one)


@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
)
def test_readwrite_backed(typ, backing_h5ad):
X = typ(X_list)
adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -266,7 +278,7 @@ def test_readwrite_backed(typ, backing_h5ad):


@pytest.mark.parametrize(
"typ", [np.array, csr_matrix, csc_matrix, csr_array, csc_array]
"typ", [np.array, jnp.array, csr_matrix, csc_matrix, csr_array, csc_array]
)
def test_readwrite_equivalent_h5ad_zarr(tmp_path, typ):
h5ad_pth = tmp_path / "adata.h5ad"
@@ -455,7 +467,7 @@ def test_changed_obs_var_names(tmp_path, diskfmt):


@pytest.mark.skipif(not find_spec("loompy"), reason="Loompy is not installed")
@pytest.mark.parametrize("typ", [np.array, csr_matrix])
@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
@pytest.mark.parametrize("obsm_mapping", [{}, dict(X_composed=["oanno3", "oanno4"])])
@pytest.mark.parametrize("varm_mapping", [{}, dict(X_composed2=["vanno3", "vanno4"])])
def test_readwrite_loom(typ, obsm_mapping, varm_mapping, tmp_path):
@@ -572,14 +584,14 @@ def test_read_tsv_iter():
assert adata.X.tolist() == X_list


@pytest.mark.parametrize("typ", [np.array, csr_matrix])
@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
def test_write_csv(typ, tmp_path):
X = typ(X_list)
adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
adata.write_csvs(tmp_path / "test_csv_dir", skip_data=False)


@pytest.mark.parametrize("typ", [np.array, csr_matrix])
@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
def test_write_csv_view(typ, tmp_path):
# https://github.com/scverse/anndata/issues/401
import hashlib
@@ -624,8 +636,9 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]:
pytest.param(ad.read_zarr, ad.io.write_zarr, "test_empty.zarr"),
],
)
def test_readwrite_empty(read, write, name, tmp_path):
adata = ad.AnnData(uns=dict(empty=np.array([], dtype=float)))
@pytest.mark.parametrize("xp", [np, jnp]) # xp = array namespace
def test_readwrite_empty(read, write, name, tmp_path, xp):
adata = ad.AnnData(uns=dict(empty=xp.array([], dtype=float)))
write(tmp_path / name, adata)
ad_read = read(tmp_path / name)
assert ad_read.uns["empty"].shape == (0,)
@@ -708,10 +721,11 @@ def test_dataframe_reserved_columns(tmp_path, diskfmt, colname, attr):
getattr(to_write, f"write_{diskfmt}")(adata_pth)


def test_write_large_categorical(tmp_path, diskfmt):
@pytest.mark.parametrize("xp", [np, jnp]) # xp = array namespace
Contributor comment:
I would revert this - it's just for generating categories which gets pushed into pandas. I don't think this triggers any internal array-api code

def test_write_large_categorical(tmp_path, diskfmt, xp):
M = 30_000
N = 1000
ls = np.array(list(ascii_letters))
ls = xp.array(list(ascii_letters))

def random_cats(n):
cats = {
@@ -722,14 +736,14 @@ def random_cats(n):
cats |= random_cats(n - len(cats))
return cats

cats = np.array(sorted(random_cats(10_000)))
cats = xp.array(sorted(random_cats(10_000)))
adata_pth = tmp_path / f"adata.{diskfmt}"
n_cats = len(np.unique(cats))
orig = ad.AnnData(
csr_matrix(([1], ([0], [0])), shape=(M, N)),
obs=dict(
cat1=cats[np.random.choice(n_cats, M)],
cat2=pd.Categorical.from_codes(np.random.choice(n_cats, M), cats),
cat2=pd.Categorical.from_codes(xp.random.choice(n_cats, M), cats),
),
)
getattr(orig, f"write_{diskfmt}")(adata_pth)