
Conversation

amalia-k510
Contributor

First step in getting anndata concat and test generation to work properly with JAX (and potentially Cubed) without just converting everything to NumPy.

Random data creation and shape handling use xp.asarray so arrays stay in their original backend where possible. I also updated concat paths to actually check types before converting, added helpers for sparse detection and array API checks, and made sure backend arrays only get turned into NumPy when absolutely necessary. This fixes a bunch of concat-related test failures.

It’s still not perfect. Some pandas calls in concat still force conversion to NumPy, so the data gets copied instead of being used directly. Cubed support is only a placeholder right now. Type detection might still be a bit too broad, which can lead to extra conversions. It works for NumPy and JAX in tests, but I haven’t tried other backends.

@flying-sheep flying-sheep (Member) left a comment

OK, I just went over general code style, nothing JAX-related

Comment on lines 662 to 663
# Force to NumPy (materializes JAX/Cubed); fine for small tests,
# but may be slow or fail on large/lazy arrays
Member

This code doesn’t just run for tests though. Also are you sure that this is a good idea for arrays with pandas dtypes?

Contributor Author

Yeah, I was initially forcing everything to NumPy, but that’s no longer the case. I’ve updated it so that it should preserve arrays with pandas dtypes.

return False


def _to_numpy_if_array_api(x):
Member

There should be no second copy of this that’s slightly different, only one!

Comment on lines +670 to +675
dest = self._adata_ref._X
# Handles read-only NumPy views from backend arrays like JAX by
# making a writable copy so in-place assignment on views can succeed.
if isinstance(dest, np.ndarray) and not dest.flags.writeable:
    dest = np.array(dest, copy=True)  # make a fresh, writable buffer
self._adata_ref._X = dest
Contributor

I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle

    hasattr(x, "dtype") and is_extension_array_dtype(x.dtype)
):
    return x
return np.asarray(x)
Contributor

Ok nice this is the right direction no doubt! So what we want here probably is not to rely on asarray but dlpack to do the conversion. In short:

  1. We should have a check in _apply_to_array to see if something is array-api compatible but not a numpy ndarray.
  2. If this case is true, dlpack into numpy and recursively call _apply_to_array.
  3. Then use dlpack to take the output of the recursive call back to the original type we had before we went to numpy.

Does that make sense?
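A minimal sketch of those three steps (names like _is_array_api_compatible and reindex_numpy are hypothetical stand-ins for the PR's helpers, not anndata's actual API):

import numpy as np
from array_api_compat import array_namespace

def _apply_to_array(el, **kwargs):
    # 1. array-api compatible, but not a numpy ndarray?
    if _is_array_api_compatible(el) and not isinstance(el, np.ndarray):
        xp = array_namespace(el)                   # remember the original backend
        el_np = np.from_dlpack(el)                 # 2. dlpack into numpy
        out_np = _apply_to_array(el_np, **kwargs)  # recurse on the numpy array
        return xp.from_dlpack(out_np)              # 3. dlpack back to the original type
    return reindex_numpy(el, **kwargs)             # existing numpy path (hypothetical)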

Contributor

I think this is a nice paradigm to follow for situations where we have an existing numpy or cupy implementation and it isn't clear how to use the array-api to achieve our aims. We should still try to use it as much as possible so that we can eventually remove numpy codepaths where possible, but this is a nice first step.

def _dlpack_from_numpy(x_np, original_xp):
    # cubed and other array backends to be added later as elif branches
    if original_xp.__name__.startswith("jax"):
        import jax.dlpack
        return jax.dlpack.from_dlpack(x_np)
Contributor


T = TypeVar("T")

with suppress(ImportError):
Contributor

Comment on lines 1086 to 1094
# Use the backend of the first array as the reference
ref = arrays[0]
xp = get_namespace(ref)

# Convert all arrays to the same backend as `ref`
arrays = [ref] + [_same_backend(ref, x, copy=True)[1] for x in arrays[1:]]

# Concatenate with the backend’s API
value = xp.concatenate(
Contributor

@ilan-gold ilan-gold Sep 12, 2025

This condition was previously hit when none of the above checks involving `any` were True. Instead of changing this last default condition, I would create a new branch here specifically for the array-api, check that they all have the same backend, and then concatenate. If they don't have the same backend, just proceed to the np condition (which will presumably fail). I wouldn't worry about mixing different backends, especially with the array-api, for now. If we use cubed, dlpack won't work there anyway.
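Roughly, as a sketch (get_namespace, arrays, and axis are taken from the surrounding diff):

import numpy as np
from array_api_compat import is_array_api_obj

if all(is_array_api_obj(a) and not isinstance(a, np.ndarray) for a in arrays):
    namespaces = {get_namespace(a) for a in arrays}
    if len(namespaces) == 1:
        xp = namespaces.pop()
        value = xp.concat(arrays, axis=axis)  # the array API spells it `concat`
    else:
        # mixed backends: fall through to the np default, which will presumably fail
        value = np.concatenate(arrays, axis=axis)
else:
    value = np.concatenate(arrays, axis=axis)  # existing default condition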


# fallback for known backends that put it elsewhere (JAX and later others)
if original_xp.__name__.startswith("jax"):
    import jax.dlpack
Contributor

original_xp should have a from_dlpack method! Does my comment here 383c445#r2291442475 not apply?
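i.e., presumably just:

# assuming original_xp is an array-api namespace exposing from_dlpack
return original_xp.from_dlpack(x_np)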



def test_write_large_categorical(tmp_path, diskfmt):
@pytest.mark.parametrize("xp", [np, jnp]) # xp = array namespace
Contributor

I would revert this - it's just for generating categories which gets pushed into pandas. I don't think this triggers any internal array-api code

return indexer.data

elif has_xp(indexer):
Contributor

This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not, I would explain why.

return pd.api.extensions.take(
    el, indexer, axis=axis, allow_fill=True, fill_value=fill_value
)
if _is_pandas(el):
Contributor

Suggested change
- if _is_pandas(el):
+ if isinstance(el, np.ndarray):

I would have thought that el is a numpy array given that the old function name was _apply_to_array, no?

# convert back from numpy to the original backend, as it is hard to reindex on JAX and others
return _dlpack_from_numpy(out_np, xp)

# numpy case
Contributor

@ilan-gold ilan-gold Sep 16, 2025

I think the logic here is a little confused. This function used to be for numpy, but you added the _is_pandas check above, which I don't think applies here. But the logic you've written from "# numpy case" down works great for non-numpy array-api compatible arrays as well!

So I would leave the numpy case as before (i.e., remove _is_pandas and check isinstance(el, np.ndarray)), and then in the case it is not a numpy array, use this logic under "# numpy case"! You can then get rid of the if not isinstance(el, np.ndarray) and _is_array_api_compatible(el): branch. One reason I cautioned against falling back to numpy behavior is that some things like jax arrays that are API compatible might be on the GPU! You can't transfer a JAX array on the GPU to numpy :/
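A loose sketch of that structure (helper names here are hypothetical):

import numpy as np

def _apply_to_array(self, el, *, axis, fill_value=None):
    if isinstance(el, np.ndarray):
        # numpy case, unchanged from before
        return self._take_and_fill(np, el, axis=axis, fill_value=fill_value)
    # array-api compatible array (possibly on GPU): run the same
    # "# numpy case" logic through its own namespace, no numpy round trip
    xp = get_namespace(el)
    return self._take_and_fill(xp, el, axis=axis, fill_value=fill_value)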

@amalia-k510 amalia-k510 requested a review from ilan-gold October 13, 2025 12:01
Comment on lines +563 to +566
xp = aac.array_namespace(old)
# skip early if numpy
if xp.__name__.startswith("numpy"):
    return old[new]
Contributor

Why special case this? If isinstance(old, numpy.ndarray), then singledispatch handles this for us. And I don't see why array_api_compat would ever resolve something to numpy here; I don't see anything in the docs about falling back to numpy.
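For illustration, roughly (the dispatcher name is hypothetical):

from functools import singledispatch
import numpy as np

@singledispatch
def _resolve_idx(old, new):
    ...  # generic / array-api branch

@_resolve_idx.register
def _(old: np.ndarray, new):
    # numpy arrays dispatch here, so the generic branch never needs
    # an xp.__name__.startswith("numpy") fast path
    return old[new]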

Comment on lines +573 to +577
if where_fn is not None:
    old = where_fn(old)[0]
else:
    # if no where function is found, fallback to NumPy
    old = np.where(np.asarray(old))[0]
Contributor

return old[new]

# handle boolean mask; i.e. checking whether old is a boolean array
if hasattr(old, "dtype") and str(old.dtype) in ("bool", "bool_", "boolean"):
Contributor

If old is not array-api compatible, shouldn't we error out? old.dtype would exist and https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html#isdtype could handle checking for bool, no need for strings or anything
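e.g., a sketch:

# dtype check via the array API instead of string comparison
if xp.isdtype(old.dtype, "bool"):
    ...  # boolean-mask handling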

# if new is a slice object, convert it into a range of indices using arange
if isinstance(new, slice):
    # try to get arange from the backend
    arange_fn = getattr(xp, "arange", None)
Contributor

Comment on lines 72 to 74
# Skip if it's already NumPy
if xp.__name__.startswith("numpy"):
    return x
Contributor

How could this condition be reached if we have the above isinstance(x, np.ndarray | ...) check?

Comment on lines +50 to +63
def to_numpy_if_array_api(x):
    if isinstance(
        x,
        np.ndarray
        | np.generic
        | pd.DataFrame
        | pd.Series
        | pd.Index
        | ExtensionArray
        | DaskArray
        | sp.spmatrix
        | AnnData,
    ):
        return x
Contributor

Given this check, I would think this function should be called to_writeable, no?

Comment on lines +461 to +462
if not isinstance(elem, AnnData):
    elem = normalize_nested(elem)
Contributor

Wouldn't we want to also normalize AnnData objects so that their sub-elements are also corrected?

Comment on lines +1045 to +1050
# use first as a reference to check if all of the arrays are the same type
xp = get_namespace(arrays[0])

if not all(get_namespace(a) is xp or a.shape == 0 for a in arrays):
    msg = "Cannot concatenate array-api arrays from different backends."
    raise ValueError(msg)
Contributor

I think the check making sure they have the same namespace is encapsulated in https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace since it can take multiple array objects - you could use that instead of just using the first one.
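e.g., roughly:

import array_api_compat as aac

# array_namespace accepts several arrays and raises if they resolve to
# different namespaces, replacing the manual check above
xp = aac.array_namespace(*arrays)  # TypeError if backends are mixed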

Comment on lines +690 to +691
if _is_pandas(el) or isinstance(el, pd.DataFrame):
    return self._apply_to_df_like(el, axis=axis, fill_value=fill_value)
Contributor

Why did you add this?

Comment on lines +770 to +777
# Check: is array-api compatible, but not NumPy
if not isinstance(el, np.ndarray) and _is_array_api_compatible(el):
    # Convert to NumPy via DLPack
    el_np = _dlpack_to_numpy(el)
    # Recursively call this same function
    out_np = self._apply_to_array_api(el_np, axis=axis, fill_value=fill_value)
    # convert back from numpy to the original backend, as it is hard
    # to reindex on JAX and others
    return _dlpack_from_numpy(out_np, xp)
Contributor

Why convert to numpy here? Just let the array pass through, that's the whole point of using the array api :) If a jax array is on the GPU you're going to bring it to the CPU here, but why?

Contributor

@ilan-gold ilan-gold Oct 13, 2025

Apologies if it wasn't clear in my previous comment - we should only go to numpy if the existing array is on the CPU and we couldn't come up with a generic way of doing this via the array-api. But you made a way of doing this via the array-api which is great, so no need to convert to numpy!

@ilan-gold ilan-gold (Contributor) left a comment

Apologies if this is going in circles, but I'm kind of reviewing the changes locally, and losing the big picture sometimes!
