Backend-native Implementation #2071
Conversation
…a into ig/array_api_continue import merge
OK, I just went over general code style, nothing JAX-related
src/anndata/_core/merge.py
```python
# Force to NumPy (materializes JAX/Cubed); fine for small tests,
# but may be slow or fail on large/lazy arrays
```
This code doesn’t just run for tests though. Also are you sure that this is a good idea for arrays with pandas dtypes?
Yeah, I was initially forcing everything to NumPy, but that’s no longer the case. I’ve updated it so it should preserve arrays with pandas dtypes.
src/anndata/_core/merge.py
```python
return False


def _to_numpy_if_array_api(x):
```
There should be no second copy of this that’s slightly different, only one!
```python
dest = self._adata_ref._X
# Handles read-only NumPy views from backend arrays like JAX by
# making a writable copy so in-place assignment on views can succeed.
if isinstance(dest, np.ndarray) and not dest.flags.writeable:
    dest = np.array(dest, copy=True)  # make a fresh, writable buffer
self._adata_ref._X = dest
```
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle.
src/anndata/_core/merge.py
hasattr(x, "dtype") and is_extension_array_dtype(x.dtype) | ||
): | ||
return x | ||
return np.asarray(x) |
OK, nice, this is the right direction, no doubt! So what we want here is probably not to rely on `asarray` but on dlpack to do the conversion. In short:

- We should have a check in `_apply_to_array` to see if something is array-api compatible but not a NumPy ndarray.
- If this is the case, dlpack into NumPy and recursively call `_apply_to_array`.
- Then use dlpack to take the output of the recursive call back to the original type before we went to NumPy.

Does that make sense?
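A minimal sketch of that round-trip, assuming the `_apply_to_array` method discussed above and the `array_api_compat` helpers; the exact integration point and signature are illustrative, not the final API:

```python
import numpy as np
from array_api_compat import array_namespace, is_array_api_obj

def _apply_to_array(self, el, *, axis, fill_value=None):
    # Hypothetical shape of the method; the dlpack round-trip is the point.
    if is_array_api_obj(el) and not isinstance(el, np.ndarray):
        xp = array_namespace(el)
        el_np = np.from_dlpack(el)  # dlpack into numpy (zero-copy on CPU)
        out_np = self._apply_to_array(el_np, axis=axis, fill_value=fill_value)
        return xp.from_dlpack(out_np)  # dlpack back to the original backend
    ...  # existing numpy handling
```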
I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.
… with copying introduced as an extra precaution
src/anndata/_core/merge.py
```python
def _dlpack_from_numpy(x_np, original_xp):
    # cubed and other backends to be added later as elif branches
    if original_xp.__name__.startswith("jax"):
        return jax.dlpack.from_dlpack(x_np)
```
src/anndata/_core/merge.py
```python
T = TypeVar("T")

with suppress(ImportError):
```
src/anndata/_core/merge.py
```python
# Use the backend of the first array as the reference
ref = arrays[0]
xp = get_namespace(ref)

# Convert all arrays to the same backend as `ref`
arrays = [ref] + [_same_backend(ref, x, copy=True)[1] for x in arrays[1:]]

# Concatenate with the backend’s API
value = xp.concatenate(
```
This condition was previously hit by the fact that none of the above checks involving `any` were `True`. Instead of changing this last default condition, I would create a new branch here specifically for the array-api, check that they all have the same backend, and then concatenate. If they don't have the same backend, you just proceed to the `np` condition (which will presumably fail). I wouldn't worry about mixing different backends, especially with the array-api, for now. If we use `cubed`, `dlpack` won't work there anyway.
src/anndata/_core/merge.py
```python
# fallback for known backends that put it elsewhere (JAX and later others)
if original_xp.__name__.startswith("jax"):
    import jax.dlpack
```
`original_xp` should have a `from_dlpack` method! Does my comment here (383c445#r2291442475) not apply?
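If so, the jax special case collapses to a one-liner; a sketch assuming `original_xp` is an `array_api_compat` namespace:

```python
def _dlpack_from_numpy(x_np, original_xp):
    # from_dlpack is a standard array-api creation function,
    # so no backend-specific branches should be needed
    return original_xp.from_dlpack(x_np)
```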
```python
@pytest.mark.parametrize("xp", [np, jnp])  # xp = array namespace
def test_write_large_categorical(tmp_path, diskfmt):
```
I would revert this - it's just for generating categories, which get pushed into pandas. I don't think this triggers any internal array-api code.
```python
    return indexer.data

elif has_xp(indexer):
```
This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not, I would explain why.
src/anndata/_core/merge.py
```python
    return pd.api.extensions.take(
        el, indexer, axis=axis, allow_fill=True, fill_value=fill_value
    )
if _is_pandas(el):
```
```diff
-if _is_pandas(el):
+if isinstance(el, np.ndarray):
```

I would have thought that `el` is a numpy array given that the old function name was `_apply_to_array`, no?
src/anndata/_core/merge.py
```python
    # convert back from numpy to the original backend; reindexing is hard on JAX and others
    return _dlpack_from_numpy(out_np, xp)

# numpy case
```
I think the logic here is a little confused. This function used to be for `numpy`, but you added the `_is_pandas` check above, which I don't think applies here. But the logic you've written from "# numpy case" down works great for non-numpy array-api compatible arrays as well!

So I would leave the numpy case as before (i.e., remove `_is_pandas` and check `isinstance(el, np.ndarray)`), and then, in the case it is not a numpy array, use the logic under "# numpy case"! You can then get rid of the `if not isinstance(el, np.ndarray) and _is_array_api_compatible(el):` branch. One reason I cautioned against falling back to `numpy` behavior is that some things like `jax` arrays that are API compatible might be on the GPU! You can't transfer a JAX array on the GPU to numpy :/
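A sketch of that structure; the helper below is a hypothetical stand-in for the method's core take/reindex step, with the real fill-value handling elided:

```python
import numpy as np
from array_api_compat import array_namespace, is_array_api_obj

def _take(el, indexer, *, axis):
    if isinstance(el, np.ndarray):
        return np.take(el, indexer, axis=axis)  # numpy case, as before
    if is_array_api_obj(el):
        xp = array_namespace(el)
        # same logic, but the array never leaves its backend/device
        return xp.take(el, xp.asarray(indexer), axis=axis)
    msg = f"unsupported array type: {type(el)}"
    raise TypeError(msg)
```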
```python
xp = aac.array_namespace(old)
# skip early if numpy
if xp.__name__.startswith("numpy"):
    return old[new]
```
Why special-case this? If `isinstance(old, numpy.ndarray)`, then `singledispatch` handles this for us. If something is resolved to use `numpy` by `array_api_compat` (I don't see why this would ever be the case here, and I don't see anything in the docs about falling back to `numpy`)
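For reference, a minimal sketch of why numpy arrays never reach this code path under `singledispatch` (the `_resolve_idx` name and its signature are simplified here, not copied from the codebase):

```python
from functools import singledispatch

import numpy as np

@singledispatch
def _resolve_idx(old, new):
    # generic fallback: only non-numpy (e.g. array-api) inputs land here
    return old[new]

@_resolve_idx.register
def _(old: np.ndarray, new):
    # numpy arrays dispatch here, so no xp.__name__ check is needed
    return old[new]
```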
```python
if where_fn is not None:
    old = where_fn(old)[0]
else:
    # if no where function is found, fall back to NumPy
    old = np.where(np.asarray(old))[0]
```
How could `where_fn` ever be `None`? https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html#where is part of the array-api.
```python
return old[new]

# handle boolean mask; i.e. checking whether old is a boolean array
if hasattr(old, "dtype") and str(old.dtype) in ("bool", "bool_", "boolean"):
```
If `old` is not array-api compatible, shouldn't we error out? `old.dtype` would exist, and https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html#isdtype could handle checking for `bool`; no need for strings or anything.
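A sketch of that dtype check via the standard API; note that the standard's one-argument counterpart of `np.where` is `nonzero`:

```python
from array_api_compat import array_namespace

xp = array_namespace(old)  # raises TypeError if old is not array-api compatible
if xp.isdtype(old.dtype, "bool"):
    old = xp.nonzero(old)[0]  # boolean mask -> integer indices
```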
```python
# if new is a slice object, convert it into a range of indices using arange
if isinstance(new, slice):
    # try to get arange from the backend
    arange_fn = getattr(xp, "arange", None)
```
Why `getattr` (same with `where`)? Shouldn't `xp.arange` exist: https://data-apis.org/array-api/latest/API_specification/generated/array_api.arange.html#arange?
```python
# Skip if it's already NumPy
if xp.__name__.startswith("numpy"):
    return x
```
How could this condition be reached if we have the above `isinstance(x, np.ndarray | ...)` check?
```python
def to_numpy_if_array_api(x):
    if isinstance(
        x,
        np.ndarray
        | np.generic
        | pd.DataFrame
        | pd.Series
        | pd.Index
        | ExtensionArray
        | DaskArray
        | sp.spmatrix
        | AnnData,
    ):
        return x
```
Given this check, I would think this function should be called `to_writeable`, no?
```python
if not isinstance(elem, AnnData):
    elem = normalize_nested(elem)
```
Wouldn't we want to also normalize `AnnData` objects so that their sub-elements are also corrected?
```python
# use first as a reference to check if all of the arrays are the same type
xp = get_namespace(arrays[0])

if not all(get_namespace(a) is xp or a.shape == 0 for a in arrays):
    msg = "Cannot concatenate array-api arrays from different backends."
    raise ValueError(msg)
```
I think the check around making sure they have the same namespace is encapsulated in https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace since it can take in an array of array objects - you could use that instead of just using the first one
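A sketch of that replacement, assuming the same `arrays` list as in the snippet above:

```python
from array_api_compat import array_namespace

try:
    # validates that all inputs share one backend in a single call
    xp = array_namespace(*arrays)
except TypeError as e:
    msg = "Cannot concatenate array-api arrays from different backends."
    raise ValueError(msg) from e
```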
```python
if _is_pandas(el) or isinstance(el, pd.DataFrame):
    return self._apply_to_df_like(el, axis=axis, fill_value=fill_value)
```
Why did you add this?
```python
# Check: is array-api compatible, but not NumPy
if not isinstance(el, np.ndarray) and _is_array_api_compatible(el):
    # Convert to NumPy via DLPack
    el_np = _dlpack_to_numpy(el)
    # Recursively call this same function
    out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value)
    # convert back from numpy to the original backend; reindexing is hard on JAX and others
    return _dlpack_from_numpy(out_np, xp)
```
Why convert to numpy here? Just let the array pass through, that's the whole point of using the array api :) If a jax array is on the GPU you're going to bring it to the CPU here, but why?
Apologies if it wasn't clear in my previous comment - we should only go to numpy if the existing array is on the CPU and we couldn't come up with a generic way of doing this via the array-api. But you made a way of doing this via the array-api which is great, so no need to convert to numpy!
Apologies if this is going in circles, but I'm kind of reviewing changes locally and losing the big picture sometimes!
First step in getting anndata concat and test generation to work properly with JAX (and potentially Cubed) without just converting everything into NumPy.
Random data creation and shape handling use `xp.asarray` so arrays stay in their original backend where possible. I also updated concat paths to actually check types before converting, added helpers for sparse detection and array-api checks, and made sure backend arrays only get turned into NumPy when absolutely necessary. This fixes a bunch of concat-related test failures.

It's still not perfect. Some pandas calls in concat still force conversion to NumPy, so the data gets copied instead of being used directly. Cubed support is only a placeholder right now. Type detection might still be a bit too broad, which can lead to extra conversions. It works for NumPy and JAX in tests, but I haven't tried other backends.
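For illustration, the backend-preserving test pattern described above might look like this (the test name and shapes are hypothetical, not taken from the PR):

```python
import numpy as np
import pytest

jnp = pytest.importorskip("jax.numpy")

@pytest.mark.parametrize("xp", [np, jnp])
def test_concat_keeps_backend(xp):
    rng = np.random.default_rng(0)
    x = xp.asarray(rng.normal(size=(4, 3)))  # data lands in xp's backend
    y = xp.asarray(rng.normal(size=(4, 3)))
    out = xp.concatenate([x, y], axis=0)
    assert out.shape == (8, 3)
```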