-
Notifications
You must be signed in to change notification settings - Fork 175
Backend-native Implementation #2071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 18 commits
21d5882
7cd6d69
eff9dde
b230d11
692a270
b52e5b7
806becf
98d249a
cdc4fdd
030f985
5aff4d6
ba74743
a749637
c716fb5
951c026
d8adf27
0e410a4
8ea34e8
96992d5
460428a
dd3b867
743ebb3
5a6c825
90d9e6a
787feb0
dea107d
cdd3747
b04c8ef
383c445
eef9015
2d2275b
a2a8606
dcbd235
c227265
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -658,7 +658,12 @@ def _apply_to_array(self, el, *, axis, fill_value=None): | |
|
||
indexer = self.idx | ||
|
||
# Indexes real fast, and does outer indexing | ||
# Fallback to numpy: keep pandas | ||
# Force to NumPy (materializes JAX/Cubed); fine for small tests, | ||
# but may be slow or fail on large/lazy arrays | ||
|
||
if not isinstance(el, np.ndarray): | ||
el = np.asarray(el) # fine for jax-in-cpu tests | ||
|
||
return pd.api.extensions.take( | ||
el, indexer, axis=axis, allow_fill=True, fill_value=fill_value | ||
) | ||
|
@@ -1399,6 +1404,37 @@ def concat_dataset2d_on_annot_axis( | |
return ds_concat_2d | ||
|
||
|
||
def _is_sparse(x): | ||
try: | ||
return scipy.sparse.issparse(x) | ||
except TypeError: | ||
return False | ||
amalia-k510 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
|
||
def _to_numpy_if_array_api(x):
    """Materialize an array-API array (e.g. JAX, cubed, CuPy) as NumPy.

    Objects pandas already understands — NumPy arrays, pandas containers,
    dask arrays, and scipy sparse — pass through unchanged, as does
    anything the array-API probe rejects.
    """
    if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray) or _is_sparse(
        x
    ):
        return x
    try:
        import array_api_compat as aac

        # If this succeeds, it's an array-API array (e.g. JAX, cubed, cupy, dask)
        aac.array_namespace(x)
        return np.asarray(x)
    except (ImportError, TypeError):
        # Not an array-API object, or array_api_compat is not installed →
        # return unchanged. (ImportError must be caught here, otherwise a
        # missing optional dependency crashes instead of falling through.)
        return x
|
||
|
||
def _normalize_nested(obj):
    """Recursively normalize array-API leaves nested in dicts/lists/tuples.

    Containers are rebuilt with their original type; every leaf is passed
    through ``_to_numpy_if_array_api``.
    """
    if isinstance(obj, dict):
        return {k: _normalize_nested(v) for k, v in obj.items()}
    if isinstance(obj, list | tuple):
        values = [_normalize_nested(v) for v in obj]
        # namedtuples reject a single iterable argument; expand positionally
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*values)
        return type(obj)(values)
    return _to_numpy_if_array_api(obj)
|
||
|
||
def concat( # noqa: PLR0912, PLR0913, PLR0915 | ||
adatas: Collection[AnnData] | Mapping[str, AnnData], | ||
*, | ||
|
@@ -1759,6 +1795,11 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 | |
) | ||
uns = uns_merge([a.uns for a in adatas]) | ||
|
||
# TODO: try pandas extension arrays after concat errors are fixed | ||
# converting to numpy since pandas does not support array-API arrays | ||
# normalizes uns (handles JAX / array-API arrays nested in dicts/lists) | ||
uns = _normalize_nested(uns) | ||
|
||
raw = None | ||
has_raw = [a.raw is not None for a in adatas] | ||
if all(has_raw): | ||
|
@@ -1785,6 +1826,13 @@ def concat( # noqa: PLR0912, PLR0913, PLR0915 | |
"not concatenating `.raw` attributes." | ||
) | ||
warn(msg, UserWarning, stacklevel=2) | ||
|
||
layers = _normalize_nested(layers) | ||
concat_mapping = _normalize_nested(concat_mapping) | ||
alt_mapping = _normalize_nested(alt_mapping) | ||
concat_pairwise = _normalize_nested(concat_pairwise) | ||
alt_pairwise = _normalize_nested(alt_pairwise) | ||
|
||
return AnnData( | ||
**{ | ||
"X": X, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,12 @@ | |
from types import MappingProxyType | ||
from typing import TYPE_CHECKING, Generic, TypeVar | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from pandas.api.extensions import ExtensionArray | ||
from scipy import sparse as sp | ||
|
||
from anndata import AnnData | ||
from anndata._io.utils import report_read_key_on_error, report_write_key_on_error | ||
from anndata._types import Read, ReadLazy, _ReadInternal, _ReadLazyInternal | ||
from anndata.compat import DaskArray, ZarrGroup, _read_attr, is_zarr_v2 | ||
|
@@ -34,6 +40,48 @@ | |
LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray | ||
|
||
|
||
def is_sparse_like(x):
    """Return ``True`` if *x* is a scipy sparse matrix/array, else ``False``.

    Safe to call on arbitrary objects: any failure while probing is
    treated as "not sparse". Catches both ``AttributeError`` and
    ``TypeError`` so this helper stays consistent with the sibling
    ``_is_sparse`` used by ``concat``.
    """
    try:
        return sp.issparse(x)
    except (AttributeError, TypeError):
        return False
|
||
|
||
def to_numpy_if_array_api(x):
    """Materialize an array-API array (e.g. JAX, CuPy, torch) as NumPy.

    Types the writers already understand — NumPy, pandas, extension
    arrays, dask, scipy sparse, and ``AnnData`` — pass through unchanged,
    as does anything the array-API probe rejects.
    """
    # `is_sparse_like` covers scipy's newer sparray hierarchy
    # (csr_array, …) as well as spmatrix subclasses, which a bare
    # `sp.spmatrix` isinstance check would miss.
    if is_sparse_like(x) or isinstance(
        x,
        np.ndarray
        | np.generic
        | pd.DataFrame
        | pd.Series
        | pd.Index
        | ExtensionArray
        | DaskArray
        | AnnData,
    ):
        return x

    # Try array-API detection only for unknown leaves
    try:
        import array_api_compat as aac

        # If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
        aac.array_namespace(x)
        return np.asarray(x)
    except (ImportError, AttributeError, TypeError):
        # Not an array-API object (or not supported), so return unchanged
        return x
|
||
|
||
def normalize_nested(obj):
    """Recursively normalize array-API leaves nested in dicts/lists/tuples.

    Containers are rebuilt with their original type; every leaf is passed
    through ``to_numpy_if_array_api``.
    """
    if isinstance(obj, dict):
        return {k: normalize_nested(v) for k, v in obj.items()}
    if isinstance(obj, list | tuple):
        values = [normalize_nested(v) for v in obj]
        # namedtuples reject a single iterable argument; expand positionally
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*values)
        return type(obj)(values)
    return to_numpy_if_array_api(obj)
|
||
|
||
# TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-" | ||
# TODO: Should filetype be included in the IOSpec if it changes the encoding? Or does the intent that these things be "the same" overrule that? | ||
@dataclass(frozen=True) | ||
|
@@ -386,6 +434,10 @@ def write_elem( | |
elif k in store: | ||
del store[k] | ||
|
||
# Normalize array-API (e.g., JAX/CuPy) payloads buried in mappings/lists | ||
if not isinstance(elem, AnnData): | ||
elem = normalize_nested(elem) | ||
Comment on lines
+461
to
+462
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't we want to also normalize |
||
|
||
write_func = self.find_write_func(dest_type, elem, modifiers) | ||
|
||
if self.callback is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not I would explain why.