Backend-native Implementation #2071
File 1 (likely src/anndata/_core/index.py, inferred from the relative imports and _fix_slice_bounds):

@@ -10,7 +10,7 @@
 import pandas as pd
 from scipy.sparse import issparse

-from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
+from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp
 from .xarray import Dataset2D

 if TYPE_CHECKING:
@@ -108,8 +108,11 @@ def name_idx(i):
         if isinstance(indexer.data, DaskArray):
             return indexer.data.compute()
         return indexer.data
+    elif has_xp(indexer):
+        msg = "Need to implement array api-based indexing"
+        raise NotImplementedError(msg)
     msg = f"Unknown indexer {indexer!r} of type {type(indexer)}"
-    raise IndexError()
+    raise IndexError(msg)


 def _fix_slice_bounds(s: slice, length: int) -> slice:

Review comment on the has_xp branch: This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not, I would explain why.
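As a rough illustration of the reviewer's suggestion, the numpy and array-API branches could share one conversion path. This is a hypothetical sketch, not code from the PR; it assumes array-api-compat is available and that coercing the indexer to numpy is acceptable (device arrays such as CuPy disallow implicit host conversion and would need an explicit copy):

    import numpy as np

    def _normalize_unknown_indexer(indexer):
        """Hypothetical merged branch: treat array-API indexers like numpy ones."""
        if not isinstance(indexer, np.ndarray):
            try:
                import array_api_compat as aac

                aac.array_namespace(indexer)   # raises TypeError if not an array-API array
                indexer = np.asarray(indexer)  # works for e.g. JAX; CuPy would need .get()
            except (ImportError, TypeError):
                msg = f"Unknown indexer {indexer!r} of type {type(indexer)}"
                raise IndexError(msg) from None
        # ...from here, reuse the existing numpy handling on `indexer`
        return indexer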
File 2 (likely src/anndata/_core/merge.py, inferred from the context lines):

@@ -17,6 +17,7 @@
 import scipy
 from natsort import natsorted
 from packaging.version import Version
+from pandas.api.types import is_extension_array_dtype
 from scipy import sparse

 from anndata._core.file_backing import to_memory
@@ -538,6 +539,15 @@ def resolve_merge_strategy(
     return strategy


+def safe_to_numpy(x):
+    """Convert to numpy array, handling JAX/Cupy arrays."""
+    if isinstance(x, pd.Series | pd.Index) or (
+        hasattr(x, "dtype") and is_extension_array_dtype(x.dtype)
+    ):
+        return x
+    return np.asarray(x)
+
+
 #####################
 # Concatenation
 #####################
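For reference, this is how safe_to_numpy behaves on pandas extension data versus plain sequences (a standalone copy of the function above, so the snippet runs on its own):

    import numpy as np
    import pandas as pd
    from pandas.api.types import is_extension_array_dtype

    def safe_to_numpy(x):  # copied from the diff above for a self-contained demo
        if isinstance(x, pd.Series | pd.Index) or (
            hasattr(x, "dtype") and is_extension_array_dtype(x.dtype)
        ):
            return x
        return np.asarray(x)

    cat = pd.Categorical(["a", "b", "a"])  # extension array: passed through untouched
    assert safe_to_numpy(cat) is cat
    assert isinstance(safe_to_numpy([1, 2, 3]), np.ndarray)  # plain data: coerced

One caveat worth noting: np.asarray works for JAX arrays, but CuPy arrays disallow implicit host conversion and would raise TypeError here despite the docstring.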
@@ -658,7 +668,11 @@ def _apply_to_array(self, el, *, axis, fill_value=None):
         indexer = self.idx

         # Indexes real fast, and does outer indexing
+        # Fall back to numpy:
+        # keep pandas EAs/Series/Index as-is; only normalize non-pandas arrays
+        el = safe_to_numpy(el)
+
         return pd.api.extensions.take(
             el, indexer, axis=axis, allow_fill=True, fill_value=fill_value
         )
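The allow_fill=True path is what produces fill values for outer joins. A minimal standalone example of pd.api.extensions.take:

    import numpy as np
    import pandas as pd

    arr = np.array([10.0, 20.0, 30.0])
    # With allow_fill=True, index -1 marks a missing position and is replaced by
    # fill_value, which is how rows absent from one input get NaN during outer concat.
    print(pd.api.extensions.take(arr, [0, -1, 2], allow_fill=True, fill_value=np.nan))
    # -> [10. nan 30.]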
@@ -1399,6 +1413,20 @@ def concat_dataset2d_on_annot_axis(
     return ds_concat_2d


+def _to_numpy_if_array_api(x):
+    if isinstance(x, np.ndarray | pd.DataFrame | pd.Series | DaskArray):
+        return x
+    try:
+        import array_api_compat as aac
+
+        # If this succeeds, it's an array-API array (e.g. JAX, cubed, cupy, dask)
+        aac.array_namespace(x)
+        return np.asarray(x)
+    except (ImportError, TypeError):
+        # Not an array-API object (or the library is not available): return unchanged
+        return x
+
+
 def concat(  # noqa: PLR0912, PLR0913, PLR0915
     adatas: Collection[AnnData] | Mapping[str, AnnData],
     *,
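The detection hinges on array_api_compat.array_namespace, which returns the array's namespace for array-API inputs and raises TypeError for everything else. A small demo, assuming array-api-compat is installed:

    import numpy as np
    import array_api_compat as aac

    print(aac.array_namespace(np.ones(3)))  # numpy's array-API-compatible namespace

    try:
        aac.array_namespace([1, 2, 3])      # a plain list is not an array-API array
    except TypeError:
        print("plain list: left unchanged by _to_numpy_if_array_api")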
@@ -1785,6 +1813,7 @@ def concat(  # noqa: PLR0912, PLR0913, PLR0915
             "not concatenating `.raw` attributes."
         )
         warn(msg, UserWarning, stacklevel=2)
+
     return AnnData(
         **{
             "X": X,
File 3 (likely src/anndata/_io/specs/registry.py, inferred from write_elem and the IOSpec TODOs):

@@ -8,6 +8,12 @@
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Generic, TypeVar

+import numpy as np
+import pandas as pd
+from pandas.api.extensions import ExtensionArray
+from scipy import sparse as sp
+
+from anndata import AnnData
 from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
 from anndata._types import Read, ReadLazy, _ReadInternal, _ReadLazyInternal
 from anndata.compat import DaskArray, ZarrGroup, _read_attr, is_zarr_v2
@@ -34,6 +40,48 @@
 LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray


+def is_sparse_like(x):
+    try:
+        return sp.issparse(x)
+    except AttributeError:
+        return False
+
+
+def to_numpy_if_array_api(x):
+    if isinstance(
+        x,
+        np.ndarray
+        | np.generic
+        | pd.DataFrame
+        | pd.Series
+        | pd.Index
+        | ExtensionArray
+        | DaskArray
+        | sp.spmatrix
+        | AnnData,
+    ):
+        return x
+
+    # Try array-API detection only for unknown leaves
+    try:
+        import array_api_compat as aac
+
+        # If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …)
+        aac.array_namespace(x)
+        return np.asarray(x)
+    except (ImportError, AttributeError, TypeError):
+        # Not an array-API object (or not supported), so return unchanged
+        return x
+
+
+def normalize_nested(obj):
+    if isinstance(obj, dict):
+        return {k: normalize_nested(v) for k, v in obj.items()}
+    if isinstance(obj, list | tuple):
+        return type(obj)(normalize_nested(v) for v in obj)
+    return to_numpy_if_array_api(obj)
+
+
 # TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-"
 # TODO: Should filetype be included in the IOSpec if it changes the encoding? Or does the intent that these things be "the same" overrule that?
 @dataclass(frozen=True)

A review thread on to_numpy_if_array_api was marked as resolved by amalia-k510.

Review comment on lines +50 to +63: Given this check, I would think this function should be called …
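To make the recursion concrete, here is a self-contained sketch of normalize_nested with a stand-in leaf type playing the role of an array-API array, so it runs without any accelerator library installed:

    import numpy as np

    class FakeDeviceArray:
        """Stand-in for an array-API array (e.g. a JAX array) in this demo."""
        def __init__(self, data):
            self.data = data

        def __array__(self, dtype=None, copy=None):  # lets np.asarray() convert it
            return np.asarray(self.data, dtype=dtype)

    def to_numpy_leaf(x):  # stand-in for to_numpy_if_array_api above
        return np.asarray(x) if isinstance(x, FakeDeviceArray) else x

    def normalize_nested(obj):  # same recursion as in the diff
        if isinstance(obj, dict):
            return {k: normalize_nested(v) for k, v in obj.items()}
        if isinstance(obj, list | tuple):
            return type(obj)(normalize_nested(v) for v in obj)
        return to_numpy_leaf(obj)

    uns = {"params": {"w": FakeDeviceArray([1.0, 2.0])}, "names": ("a", "b")}
    out = normalize_nested(uns)
    print(type(out["params"]["w"]))  # <class 'numpy.ndarray'>; dict/tuple structure preserved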
@@ -386,6 +434,10 @@ def write_elem(
         elif k in store:
             del store[k]

+        # Normalize array-API (e.g., JAX/CuPy) payloads buried in mappings/lists
+        if not isinstance(elem, AnnData):
+            elem = normalize_nested(elem)
+
         write_func = self.find_write_func(dest_type, elem, modifiers)

         if self.callback is None:

Review comment on lines +461 to +462: Wouldn't we want to also normalize …
Review comment: I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle.
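For overall context, a hypothetical end-to-end use of the normalization this PR adds to write_elem. This assumes JAX is installed and this branch is in effect; example.zarr is an arbitrary output path:

    import anndata as ad
    import jax.numpy as jnp
    import numpy as np

    adata = ad.AnnData(X=np.zeros((2, 2), dtype=np.float32))
    adata.uns["embedding"] = jnp.arange(4.0)  # array-API payload buried in .uns
    # With this PR, write_elem normalizes the JAX array to numpy before writing.
    adata.write_zarr("example.zarr")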