Backend-native Implementation #2071
@@ -19,8 +19,7 @@
 # Enable DLPack interop for JAX, CuPy, etc., only if installed
 with suppress(ImportError):
-    import jax
-    import jax.dlpack
+    pass
 import pandas as pd
 import pandas.api.types as pdf
 import scipy
@@ -123,6 +122,7 @@ def not_missing(v) -> bool:
 def _same_backend(x, y, *, copy: bool = True):
+    # TODO: convert it so that I could also use it to convert one array to another
+    # for merge implementation
     # Makes sure two arrays are from the same array backend.
     # If not, uses from_dlpack() to convert `y` to `x`'s backend.
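The body of `_same_backend` is not shown in this hunk, but the comments describe the intent: detect each array's backend and convert `y` over DLPack when the two differ. A minimal sketch of that idea (`to_same_backend` is a hypothetical name, and the use of `array_api_compat` here is an assumption for illustration):

```python
import array_api_compat as aac

def to_same_backend(x, y):
    """Hypothetical sketch: move `y` into `x`'s array backend via DLPack."""
    xp_x = aac.array_namespace(x)  # backend namespace of x (assumed detection route)
    xp_y = aac.array_namespace(y)
    if xp_x is xp_y:
        return y                   # already on the same backend
    return xp_x.from_dlpack(y)     # zero-copy handoff where the backends allow it
```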
@@ -619,8 +619,6 @@ def _is_array_api_compatible(x):
 def _dlpack_to_numpy(x):
-    ###TODO: FIX
     # Convert array-api-compatible x to NumPy using DLPack.
     try:
         return np.from_dlpack(x)
     except TypeError as e:
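The `except` body is cut off by the hunk boundary. A plausible completion of the pattern (the error message text is an assumption, not the PR's actual code):

```python
import numpy as np

def dlpack_to_numpy(x):
    # Convert an array-API-compatible array to NumPy using DLPack.
    try:
        return np.from_dlpack(x)
    except TypeError as e:
        # assumed continuation: surface which type failed the conversion
        msg = f"Could not convert {type(x).__name__} to NumPy via DLPack"
        raise TypeError(msg) from e
```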
@@ -629,13 +627,22 @@ def _dlpack_to_numpy(x):
 def _dlpack_from_numpy(x_np, original_xp):
-    ###TODO: FIX
-    # cubed and other array later elif
+    # TODO: cubed and other array later elif
     if hasattr(original_xp, "from_dlpack"):
-        return original_xp.from_dlpack(x_np)
-    else:
-        msg = f"DLPack back-conversion not implemented for {original_xp.__name__}"
-        raise TypeError(msg)
+        try:
+            return original_xp.from_dlpack(x_np)
+        except Exception as e:
+            msg = f"Failed to call from_dlpack on backend {original_xp.__name__}: {e}"
+            raise TypeError(msg) from e
+
+    # fallback for known backends that put it elsewhere (JAX and later others)
+    if original_xp.__name__.startswith("jax"):
+        import jax.dlpack
+
+        return jax.dlpack.from_dlpack(x_np)
+
+    msg = f"DLPack back-conversion not implemented for backend {original_xp.__name__}"
+    raise TypeError(msg)


 #####################
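Taken together, the two helpers implement a NumPy round trip over DLPack: pull the array into NumPy, do the work there, and hand the result back to the original backend. A minimal sketch of that flow, assuming a recent NumPy and JAX (both implement the DLPack protocol):

```python
import numpy as np
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(6.0)                    # array on the JAX backend
x_np = np.from_dlpack(x)               # JAX -> NumPy (zero-copy on CPU)
out_np = x_np.reshape(2, 3) * 2        # NumPy-only work happens here
out = jax.dlpack.from_dlpack(out_np)   # NumPy -> JAX for the caller
print(type(out))                       # a jax.Array
```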
@@ -698,7 +705,7 @@ def apply(self, el, *, axis, fill_value=None):  # noqa: PLR0911
         elif isinstance(el, CupyArray):
             return self._apply_to_cupy_array(el, axis=axis, fill_value=fill_value)
         else:
-            return self._apply_to_array(el, axis=axis, fill_value=fill_value)
+            return self._apply_to_array_api(el, axis=axis, fill_value=fill_value)

     def _apply_to_df_like(self, el: pd.DataFrame | Dataset2D, *, axis, fill_value=None):
         if fill_value is None:
@@ -750,7 +757,7 @@ def _apply_to_cupy_array(self, el, *, axis, fill_value=None):
         return out

-    def _apply_to_array(self, el, *, axis, fill_value=None):
+    def _apply_to_array_api(self, el, *, axis, fill_value=None):
         if fill_value is None:
             fill_value = default_fill_value([el])
@@ -779,11 +786,9 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
         el_np = _dlpack_to_numpy(el)

         # Recursively call this same function
-        out_np = self._apply_to_array(el_np, axis=axis, fill_value=fill_value)
+        out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value)

-        # TODO: Fix it as moving it back and forth is not ideal, but it allows us to
-        # keep the same interface for all backends.
-        # Convert result back to original backend
+        # reverting back to numpy as it is hard to reindex on JAX and others
         return _dlpack_from_numpy(out_np, xp)

     # numpy case
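The "hard to reindex on JAX" comment refers to JAX arrays being immutable: NumPy-style in-place assignment, which reindexing machinery relies on, is unavailable there. A quick illustration:

```python
import numpy as np
import jax.numpy as jnp

a = np.zeros(3)
a[0] = 1.0                  # fine: NumPy arrays support in-place writes

b = jnp.zeros(3)
try:
    b[0] = 1.0              # raises: JAX arrays are immutable
except TypeError:
    b = b.at[0].set(1.0)    # functional update is required instead
```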
@@ -1588,18 +1593,48 @@ def concat_dataset2d_on_annot_axis(
     return ds_concat_2d

 def _to_numpy_if_array_api(x):
     if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
         return x
     try:
         import array_api_compat as aac

         # If this succeeds, it's an array-API array (e.g. JAX, cubed, cupy, dask)
         aac.array_namespace(x)
         return np.asarray(x)
     except TypeError:
         # Not an array-API object (or lib not available) = return unchanged
         return x
+# def _to_numpy_if_immutable(x):
+#     print("Initial x:", type(x))
+#     if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
+#         print("x is already a supported mutable type:", type(x))
+#         return x
+#     try:
+#         # checking for mutability
+#         print("Trying np.asarray(x)...")
+#         x_array = np.asarray(x)
+#         print("np.asarray(x) succeeded:", type(x_array))
+#         if x_array.size > 0:
+#             try:
+#                 orig = x_array[0]
+#                 print("Checking mutability: trying in-place write")
+#                 x_array[0] = orig  # test no-op write
+#                 print("In-place mutation succeeded, x is mutable:", type(x))
+#                 return x  # mutation worked, so keep original
+#             except (ValueError, TypeError) as e:
+#                 print("In-place mutation failed:", type(x), "|", repr(e))
+#                 # pass
+
+#     except ValueError:
+#         print("Trying np.from_dlpack(x)...")
+#         result = np.from_dlpack(x)
+#         print("np.from_dlpack(x) succeeded:", type(result))
+#         # pass
+#     # if it is not mutable, we convert to numpy
+#     try:
+#         # trying convert via from_dlpack first
+#         return np.from_dlpack(x)
+#     except TypeError:
+#         try:
+#             # fallback to asarray if from_dlpack not possible
+#             print("Trying fallback np.asarray(x)...")
+#             result = np.asarray(x)
+#             print("Fallback np.asarray(x) succeeded:", type(result))
+#             return result
+#         except ValueError:
+#             # Not an array-API object (or lib not available) = return unchanged
+#             print("Final fallback np.asarray(x) failed:", type(x), "|", repr(e))
+#             print("Returning x as-is:", type(x))
+#             return x


 def concat(  # noqa: PLR0912, PLR0913, PLR0915
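`_to_numpy_if_array_api` leans on `array_api_compat.array_namespace` as its detection mechanism: the call raises `TypeError` for anything that is not an array-API array. A small standalone illustration of that behavior (`backend_of` is a hypothetical helper name):

```python
import numpy as np
import array_api_compat as aac

def backend_of(x):
    """Return the array-API namespace name of x, or None for non-array inputs."""
    try:
        return aac.array_namespace(x).__name__
    except TypeError:
        return None

print(backend_of(np.arange(3)))  # a NumPy-compatible namespace
print(backend_of([1, 2, 3]))     # None: plain lists are not array-API arrays
```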
@@ -10,6 +10,7 @@
 from typing import TYPE_CHECKING

 import h5py
+import jax.numpy as jnp
 import numpy as np
 import pandas as pd
 import pytest
@@ -33,6 +34,7 @@
 from anndata.tests.helpers import (
     GEN_ADATA_NO_XARRAY_ARGS,
     as_dense_dask_array,
+    as_dense_jax_array,
     assert_equal,
     gen_adata,
 )
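The new `as_dense_jax_array` helper is imported but not defined in this diff; presumably it mirrors `as_dense_dask_array`. A hypothetical sketch of what such a helper could look like (the body is an assumption, not the PR's actual code):

```python
import jax.numpy as jnp
import numpy as np
from scipy import sparse

def as_dense_jax_array(x):
    # hypothetical: densify sparse input, then hand the buffer to JAX
    if sparse.issparse(x):
        x = x.toarray()
    return jnp.asarray(np.asarray(x))
```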
@@ -120,7 +122,9 @@ def dtype(request):
 # ------------------------------------------------------------------------------


-@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
+@pytest.mark.parametrize(
+    "typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
+)
 def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2):
     pth1 = tmp_path / f"first.{diskfmt}"
     write1 = lambda x: getattr(x, f"write_{diskfmt}")(pth1)
@@ -160,7 +164,9 @@ async def _do_test():


 @pytest.mark.parametrize("storage", ["h5ad", "zarr"])
-@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
+@pytest.mark.parametrize(
+    "typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
+)
 def test_readwrite_kitchensink(tmp_path, storage, typ, backing_h5ad, dataset_kwargs):
     X = typ(X_list)
     adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -206,7 +212,9 @@ def test_readwrite_kitchensink(tmp_path, storage, typ, backing_h5ad, dataset_kwargs):
     assert_equal(adata, adata_src)


-@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
+@pytest.mark.parametrize(
+    "typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
+)
 def test_readwrite_maintain_X_dtype(typ, backing_h5ad):
     X = typ(X_list).astype("int8")
     adata_src = ad.AnnData(X)
@@ -239,7 +247,9 @@ def test_maintain_layers(rw):
     assert not np.any((orig.layers["sparse"] != curr.layers["sparse"]).toarray())


-@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
+@pytest.mark.parametrize(
+    "typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
+)
 def test_readwrite_h5ad_one_dimension(typ, backing_h5ad):
     X = typ(X_list)
     adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -250,7 +260,9 @@ def test_readwrite_h5ad_one_dimension(typ, backing_h5ad):
     assert_equal(adata, adata_one)


-@pytest.mark.parametrize("typ", [np.array, csr_matrix, csr_array, as_dense_dask_array])
+@pytest.mark.parametrize(
+    "typ", [np.array, csr_matrix, csr_array, as_dense_dask_array, as_dense_jax_array]
+)
 def test_readwrite_backed(typ, backing_h5ad):
     X = typ(X_list)
     adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
@@ -266,7 +278,7 @@ def test_readwrite_backed(typ, backing_h5ad):


 @pytest.mark.parametrize(
-    "typ", [np.array, csr_matrix, csc_matrix, csr_array, csc_array]
+    "typ", [np.array, jnp.array, csr_matrix, csc_matrix, csr_array, csc_array]
 )
 def test_readwrite_equivalent_h5ad_zarr(tmp_path, typ):
     h5ad_pth = tmp_path / "adata.h5ad"
@@ -455,7 +467,7 @@ def test_changed_obs_var_names(tmp_path, diskfmt):


 @pytest.mark.skipif(not find_spec("loompy"), reason="Loompy is not installed")
-@pytest.mark.parametrize("typ", [np.array, csr_matrix])
+@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
 @pytest.mark.parametrize("obsm_mapping", [{}, dict(X_composed=["oanno3", "oanno4"])])
 @pytest.mark.parametrize("varm_mapping", [{}, dict(X_composed2=["vanno3", "vanno4"])])
 def test_readwrite_loom(typ, obsm_mapping, varm_mapping, tmp_path):
@@ -572,14 +584,14 @@ def test_read_tsv_iter():
     assert adata.X.tolist() == X_list


-@pytest.mark.parametrize("typ", [np.array, csr_matrix])
+@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
 def test_write_csv(typ, tmp_path):
     X = typ(X_list)
     adata = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict)
     adata.write_csvs(tmp_path / "test_csv_dir", skip_data=False)


-@pytest.mark.parametrize("typ", [np.array, csr_matrix])
+@pytest.mark.parametrize("typ", [np.array, jnp.array, csr_matrix])
 def test_write_csv_view(typ, tmp_path):
     # https://github.com/scverse/anndata/issues/401
     import hashlib
@@ -624,8 +636,9 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]:
         pytest.param(ad.read_zarr, ad.io.write_zarr, "test_empty.zarr"),
     ],
 )
-def test_readwrite_empty(read, write, name, tmp_path):
-    adata = ad.AnnData(uns=dict(empty=np.array([], dtype=float)))
+@pytest.mark.parametrize("xp", [np, jnp])  # xp = array namespace
+def test_readwrite_empty(read, write, name, tmp_path, xp):
+    adata = ad.AnnData(uns=dict(empty=xp.array([], dtype=float)))
     write(tmp_path / name, adata)
     ad_read = read(tmp_path / name)
     assert ad_read.uns["empty"].shape == (0,)
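The `xp` parametrization runs one test body once per array namespace, so the same assertions cover both backends. The pattern in isolation:

```python
import numpy as np
import jax.numpy as jnp
import pytest

@pytest.mark.parametrize("xp", [np, jnp])  # xp = array namespace
def test_empty_shape(xp):
    # identical assertion against either backend
    assert xp.array([], dtype=float).shape == (0,)
```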
@@ -708,10 +721,11 @@ def test_dataframe_reserved_columns(tmp_path, diskfmt, colname, attr):
     getattr(to_write, f"write_{diskfmt}")(adata_pth)


-def test_write_large_categorical(tmp_path, diskfmt):
+@pytest.mark.parametrize("xp", [np, jnp])  # xp = array namespace
+def test_write_large_categorical(tmp_path, diskfmt, xp):
     M = 30_000
     N = 1000
-    ls = np.array(list(ascii_letters))
+    ls = xp.array(list(ascii_letters))

     def random_cats(n):
         cats = {

Review comment (on the `@pytest.mark.parametrize("xp", ...)` line): I would revert this - it's just for generating categories which gets pushed into pandas. I don't think this triggers any internal array-api code.
@@ -722,14 +736,14 @@ def random_cats(n):
         cats |= random_cats(n - len(cats))
         return cats

-    cats = np.array(sorted(random_cats(10_000)))
+    cats = xp.array(sorted(random_cats(10_000)))
     adata_pth = tmp_path / f"adata.{diskfmt}"
     n_cats = len(np.unique(cats))
     orig = ad.AnnData(
         csr_matrix(([1], ([0], [0])), shape=(M, N)),
         obs=dict(
             cat1=cats[np.random.choice(n_cats, M)],
-            cat2=pd.Categorical.from_codes(np.random.choice(n_cats, M), cats),
+            cat2=pd.Categorical.from_codes(xp.random.choice(n_cats, M), cats),
         ),
     )
     getattr(orig, f"write_{diskfmt}")(adata_pth)
Review comment: `original_xp` should have a `from_dlpack` method! Does my comment here 383c445#r2291442475 not apply?
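For reference, a quick check of the reviewer's point that the JAX namespace already satisfies the `hasattr(original_xp, "from_dlpack")` branch, assuming a recent JAX:

```python
import numpy as np
import jax.numpy as jnp

assert hasattr(jnp, "from_dlpack")                   # JAX exposes from_dlpack directly
x = jnp.from_dlpack(np.arange(3, dtype="float32"))   # generic branch covers NumPy -> JAX
print(type(x))                                       # jax.Array
```

If that holds, the dedicated `original_xp.__name__.startswith("jax")` fallback in `_dlpack_from_numpy` would be redundant.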