-
Notifications
You must be signed in to change notification settings - Fork 175
Backend-native Implementation #2071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
21d5882
7cd6d69
eff9dde
b230d11
692a270
b52e5b7
806becf
98d249a
cdc4fdd
030f985
5aff4d6
ba74743
a749637
c716fb5
951c026
d8adf27
0e410a4
8ea34e8
96992d5
460428a
dd3b867
743ebb3
5a6c825
90d9e6a
787feb0
dea107d
cdd3747
b04c8ef
383c445
eef9015
2d2275b
a2a8606
dcbd235
c227265
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -687,6 +687,8 @@ def apply(self, el, *, axis, fill_value=None): # noqa: PLR0911 | |
""" | ||
if self.no_change and (axis_len(el, axis) == len(self.old_idx)): | ||
return el | ||
if _is_pandas(el) or isinstance(el, pd.DataFrame): | ||
return self._apply_to_df_like(el, axis=axis, fill_value=fill_value) | ||
Comment on lines
+690
to
+691
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did you add this? |
||
if isinstance(el, pd.DataFrame | Dataset2D): | ||
return self._apply_to_df_like(el, axis=axis, fill_value=fill_value) | ||
elif isinstance(el, CSMatrix | CSArray | CupySparseMatrix): | ||
|
@@ -753,16 +755,8 @@ def _apply_to_cupy_array(self, el, *, axis, fill_value=None): | |
def _apply_to_array_api(self, el, *, axis, fill_value=None): | ||
if fill_value is None: | ||
fill_value = default_fill_value([el]) | ||
|
||
indexer = self.idx | ||
|
||
if _is_pandas(el): | ||
# using the behavior that already exists in pandas for Series/Index | ||
return pd.api.extensions.take( | ||
el, indexer, axis=axis, allow_fill=True, fill_value=fill_value | ||
) | ||
# e.g., numpy, jax.numpy, cubed, etc. | ||
xp = get_namespace(el) | ||
indexer = xp.asarray(self.idx) | ||
|
||
# Handling edge case to mimic pandas behavior | ||
if el.shape[axis] == 0: | ||
|
@@ -777,15 +771,11 @@ def _apply_to_array_api(self, el, *, axis, fill_value=None): | |
if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): | ||
# Convert to NumPy via DLPack | ||
el_np = _dlpack_to_numpy(el) | ||
|
||
# Recursively call this same function | ||
out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value) | ||
|
||
# reverting back to numpy as it is hard to reindex on JAX and others | ||
return _dlpack_from_numpy(out_np, xp) | ||
|
||
# numpy case | ||
indexer = xp.asarray(indexer) | ||
# marking which positions are missing, so we could use fill_value | ||
missing_mask = indexer == -1 | ||
safe_indexer = xp.where(missing_mask, 0, indexer) | ||
|
@@ -910,38 +900,6 @@ def merge_indices(inds: Iterable[pd.Index], join: Join_T) -> pd.Index: | |
raise ValueError(msg) | ||
|
||
|
||
# def default_fill_value(els): | ||
# # Given some arrays, returns what the default fill value should be. | ||
|
||
# if any( | ||
# isinstance(el, pd.DataFrame | pd.Series) | ||
# for el in els | ||
# if el is not None and el is not MissingVal | ||
# ): | ||
# return pd.NA | ||
|
||
# if any( | ||
# isinstance(el, CSMatrix | CSArray) | ||
# or (isinstance(el, DaskArray) and isinstance(el._meta, CSMatrix | CSArray)) | ||
# for el in els | ||
# ): | ||
# return 0 | ||
|
||
# # Pick the namespace of the first valid element | ||
# for el in els: | ||
# if el is not None and el is not MissingVal: | ||
# xp = get_namespace(el) | ||
# # Not all backends have `nan` defined (e.g. integer dtypes) | ||
# try: | ||
# return xp.nan | ||
# except AttributeError: | ||
# # Fall back to 0 if no NaN in this backend | ||
# return xp.asarray(0).item() | ||
|
||
# # Fallback if list was empty or only MissingVal | ||
# return np.nan | ||
|
||
|
||
def default_fill_value(els): | ||
"""Given some arrays, returns what the default fill value should be. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
from functools import reduce, singledispatch, wraps | ||
from typing import TYPE_CHECKING, Literal | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from pandas.api.types import is_bool_dtype | ||
|
@@ -116,7 +115,6 @@ def view_update(adata_view: AnnData, attr_name: str, keys: tuple[str, ...]): | |
|
||
def _replace_field(container, idx, value): | ||
import awkward as ak | ||
import numpy as np | ||
|
||
# JAX-style immutable array | ||
if hasattr(container, "at"): | ||
|
@@ -556,7 +554,42 @@ def _resolve_idxs( | |
|
||
@singledispatch | ||
def _resolve_idx(old: Index1DNorm, new: Index1DNorm, l: Literal[0, 1]) -> Index1DNorm: | ||
raise NotImplementedError | ||
import array_api_compat as aac | ||
|
||
from ..compat import has_xp | ||
|
||
# handling array-API–compatible arrays (e.g. JAX, etc) | ||
if has_xp(old): | ||
xp = aac.array_namespace(old) | ||
# skip early if numpy | ||
if xp.__name__.startswith("numpy"): | ||
return old[new] | ||
Comment on lines
+563
to
+566
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why special case this? If |
||
|
||
# handle boolean mask; i.e. checking whether old is a boolean array | ||
if hasattr(old, "dtype") and str(old.dtype) in ("bool", "bool_", "boolean"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If |
||
# retrieving the where function (like np.where or jnp.where) | ||
where_fn = getattr(xp, "where", None) | ||
# if where exists, use it to convert the boolean mask to integer indices | ||
if where_fn is not None: | ||
old = where_fn(old)[0] | ||
else: | ||
# if no where function is found, fallback to NumPy | ||
old = np.where(np.asarray(old))[0] | ||
Comment on lines
+573
to
+577
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How could |
||
|
||
# if new is a slice object, converting it into a range of indices using arange | ||
if isinstance(new, slice): | ||
# trying to get arange from the backend | ||
arange_fn = getattr(xp, "arange", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why |
||
# if arange exists, apply it to the resolved slice range | ||
if arange_fn is not None: | ||
new = arange_fn(*new.indices(old.shape[0])) | ||
else: | ||
new = np.arange(*new.indices(old.shape[0])) | ||
|
||
return old[new] | ||
|
||
msg = f"_resolve_idx not implemented for type {type(old)}" | ||
raise NotImplementedError(msg) | ||
|
||
|
||
@_resolve_idx.register(np.ndarray) | ||
|
@@ -582,20 +615,6 @@ def _resolve_idx_slice( | |
return np.arange(*old.indices(l))[new] | ||
|
||
|
||
@_resolve_idx.register(jnp.ndarray) | ||
def _resolve_idx_jnp( | ||
old: jnp.ndarray, new: Index1DNorm, l: Literal[0, 1] | ||
) -> jnp.ndarray: | ||
# Boolean mask + index | ||
if old.dtype == jnp.bool_: | ||
old = jnp.where(old)[0] | ||
|
||
if isinstance(new, slice): | ||
new = jnp.arange(*new.indices(old.shape[0])) | ||
|
||
return old[new] | ||
|
||
|
||
def _resolve_idx_slice_slice(old: slice, new: slice, l: Literal[0, 1]) -> slice: | ||
r = range(*old.indices(l))[new] | ||
# Convert back to slice | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,21 +64,38 @@ def to_numpy_if_array_api(x): | |
|
||
# Try array-API detection only for unknown leaves | ||
try: | ||
# importing the array API compatibility layer | ||
import array_api_compat as aac | ||
|
||
# getting the array namespace | ||
xp = aac.array_namespace(x) | ||
# Already a NumPy array | ||
# Skip if it's already NumPy | ||
if xp.__name__.startswith("numpy"): | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How could this condition be reached if we have the above |
||
# # If this succeeds, it's an array-API array (e.g. JAX, CuPy, torch, …) | ||
# aac.array_namespace(x) | ||
# return np.asarray(x) | ||
# If the array has a `.device` attribute, check if it's on GPU | ||
if hasattr(x, "device"): | ||
device = x.device | ||
# if the device has a `.type` field, use it | ||
if hasattr(device, "type"): | ||
if device.type != "cpu": | ||
# move the array to CPU if it's on GPU | ||
x = xp.to_device( | ||
x, "cpu" | ||
) # not sure about this, would we want to move the entire array to CPU? | ||
# otherwise, if the device is not a string, check if it's not "cpu" | ||
elif str(device) != "cpu": | ||
x = xp.to_device(x, "cpu") | ||
|
||
# Convert to NumPy using DLPack (safe now) | ||
if hasattr(x, "__to_dlpack__"): | ||
return np.from_dlpack(x) | ||
except (ImportError, AttributeError, TypeError): | ||
# Not an array-API object (or not supported), so return unchanged | ||
|
||
except (ImportError, AttributeError, TypeError, RuntimeError): | ||
# Could not detect or convert – return unchanged | ||
return x | ||
|
||
return x | ||
|
||
|
||
def normalize_nested(obj): | ||
if isinstance(obj, dict): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not possible, I would explain why.