Backend-native Implementation #2071
Draft: amalia-k510 wants to merge 34 commits into scverse:main from amalia-k510:ig/array_api_continue
34 commits
21d5882 chore: add tests (ilan-gold)
7cd6d69 fix: allow creating objects with array-api (ilan-gold)
eff9dde chore: add indexing test (ilan-gold)
b230d11 fix: add xp pass-through (ilan-gold)
692a270 chore: add more indexing methods (ilan-gold)
b52e5b7 Merge branch 'main' into ig/array_api_starter (ilan-gold)
806becf initial backend step (amalia-k510)
98d249a concat lazy error fix (amalia-k510)
cdc4fdd comment for merge fix (amalia-k510)
030f985 dask array fix (amalia-k510)
5aff4d6 fix test_concatenate_roundtrip[inner-np_array-pandas-concat-lazy-conc… (amalia-k510)
ba74743 fixed all contact errors (amalia-k510)
a749637 comments fix (amalia-k510)
c716fb5 quick fix but not ideal as it converts jax and other arrays to numpy (amalia-k510)
951c026 Merge branch 'main' into ig/array_api_continue (amalia-k510)
d8adf27 comment fix (amalia-k510)
0e410a4 Merge branch 'ig/array_api_continue' of github.com:amalia-k510/anndat… (amalia-k510)
8ea34e8 comment fix (amalia-k510)
96992d5 extra tests and function changes (amalia-k510)
460428a dlpack introduction and trying to make gen_adata fully backend native… (amalia-k510)
dd3b867 removed unnecessary function (amalia-k510)
743ebb3 minor fixes (amalia-k510)
5a6c825 minor fix (amalia-k510)
90d9e6a begin merge modification (amalia-k510)
787feb0 concat on jax arrays is introduced (amalia-k510)
dea107d precommit fixes (amalia-k510)
cdd3747 merge quick fix (amalia-k510)
b04c8ef minor fixes (amalia-k510)
383c445 writer and reindexer introduced + jax in the tests, still need to fix… (amalia-k510)
eef9015 just concat errors left (amalia-k510)
2d2275b test fixes for jax (amalia-k510)
a2a8606 indexer implementation and comments addressed (amalia-k510)
dcbd235 test_double_index_jax fixed (amalia-k510)
c227265 register default case, merge issues, and gpu/cpu array transfer addre… (amalia-k510)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
import pandas as pd | ||
from scipy.sparse import issparse | ||
|
||
- from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray | ||
+ from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray, has_xp | ||
from .xarray import Dataset2D | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -39,6 +39,7 @@ def _normalize_index( # noqa: PLR0911, PLR0912 | |
indexer: Index1D, index: pd.Index | ||
) -> Index1DNorm | int | np.integer: | ||
# TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. | ||
# protect around weird numeric index | ||
if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): | ||
msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" | ||
raise TypeError(msg) | ||
|
@@ -50,6 +51,7 @@ def name_idx(i): | |
i = index.get_loc(i) | ||
return i | ||
|
||
# convert the slice's start and stop to integer positions if they are strings | ||
if isinstance(indexer, slice): | ||
start = name_idx(indexer.start) | ||
stop = name_idx(indexer.stop) | ||
|
@@ -65,17 +67,21 @@ def name_idx(i): | |
elif isinstance( | ||
indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray | ||
): | ||
# convert to 1D if it's accidentally a 2D column/row vector | ||
# convert sparse into dense arrays if needed | ||
if hasattr(indexer, "shape") and ( | ||
(indexer.shape == (index.shape[0], 1)) | ||
or (indexer.shape == (1, index.shape[0])) | ||
): | ||
if isinstance(indexer, CSMatrix | CSArray): | ||
indexer = indexer.toarray() | ||
indexer = np.ravel(indexer) | ||
# if it is something else, convert it to numpy | ||
if not isinstance(indexer, np.ndarray | pd.Index): | ||
indexer = np.array(indexer) | ||
if len(indexer) == 0: | ||
indexer = indexer.astype(int) | ||
# if it is a float array or something along those lines, convert it to integers | ||
if isinstance(indexer, np.ndarray) and np.issubdtype( | ||
indexer.dtype, np.floating | ||
): | ||
|
@@ -94,7 +100,7 @@ def name_idx(i): | |
) | ||
raise IndexError(msg) | ||
return indexer | ||
- else: # indexer should be string array | ||
+ else: | ||
positions = index.get_indexer(indexer) | ||
if np.any(positions < 0): | ||
not_found = indexer[positions < 0] | ||
|
@@ -108,8 +114,65 @@ def name_idx(i): | |
if isinstance(indexer.data, DaskArray): | ||
return indexer.data.compute() | ||
return indexer.data | ||
|
||
elif has_xp(indexer): | ||
Review comment: This looks nearly identical to the numpy case. In a case like this, I think you were/would be right to just merge the two if possible. If it's not, I would explain why. |
||
# getting array's namespace | ||
xp = indexer.__array_namespace__() | ||
|
||
# Flatten to 1D | ||
if hasattr(indexer, "shape") and ( | ||
indexer.shape == (index.shape[0], 1) or indexer.shape == (1, index.shape[0]) | ||
): | ||
indexer = xp.ravel( | ||
indexer | ||
) # flattening to 1D, jax.numpy has it, not sure about cubed | ||
|
||
# Get dtype in array-api-style | ||
dtype = getattr(indexer, "dtype", None) | ||
|
||
# if we have like a jax boolean mask array | ||
if xp.issubdtype(dtype, xp.bool_): | ||
if indexer.shape != index.shape: | ||
msg = ( | ||
f"Boolean index does not match AnnData’s shape along this dimension. " | ||
f"Boolean index has shape {indexer.shape}, expected {index.shape}" | ||
) | ||
raise IndexError(msg) | ||
return indexer | ||
|
||
# all good, you can return it | ||
elif xp.issubdtype(dtype, xp.integer): | ||
return indexer | ||
# float number case | ||
elif xp.issubdtype(dtype, xp.floating): | ||
indexer_int = xp.astype(indexer, xp.int32) # jax default to it | ||
# If all floats were “safe” (like 0.0, 1.0, 2.0), return them cast to integers. | ||
is_fractional = xp.not_equal(indexer, xp.astype(indexer_int, xp.floating)) | ||
if xp.any(is_fractional): | ||
msg = f"Indexer {indexer!r} has non-integer floating point values." | ||
raise IndexError(msg) | ||
return indexer_int | ||
|
||
else: | ||
try: | ||
values = indexer.tolist() # converting to the list | ||
except Exception: | ||
msg = f"Could not convert {indexer!r} to list for string lookup." | ||
raise IndexError(msg) | ||
positions = index.get_indexer(values) | ||
if np.any(positions < 0): | ||
not_found = [ | ||
v for v, p in zip(values, positions, strict=False) if p < 0 | ||
] | ||
msg = ( | ||
f"Values {not_found}, from {values}, " | ||
"are not valid obs/ var names or indices." | ||
) | ||
raise KeyError(msg) | ||
return positions | ||
|
||
msg = f"Unknown indexer {indexer!r} of type {type(indexer)}" | ||
- raise IndexError() | ||
+ raise IndexError(msg) | ||
|
||
|
||
def _fix_slice_bounds(s: slice, length: int) -> slice: | ||
|
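The array-API branch in the diff above calls `xp.ravel`, `xp.issubdtype`, `xp.bool_`, and `xp.floating`. These are NumPy-style names rather than part of the array API standard, so they are not guaranteed to exist in every namespace. Below is a minimal sketch of how the boolean, integer, and float cases could share one code path for NumPy and other array-API arrays, in the spirit of the review comment about merging the two branches. The helper names (`_namespace`, `_is_kind`, `normalize_array_indexer`) are hypothetical and not part of anndata; the sketch assumes only `reshape`, `isdtype`, `floor`, `isfinite`, `any`, `all`, and `astype` from the standard, with fallbacks for older NumPy.

```python
import numpy as np


def _namespace(x):
    """Return the array's own namespace, falling back to NumPy (e.g. NumPy < 2.0)."""
    get_ns = getattr(x, "__array_namespace__", None)
    return get_ns() if get_ns is not None else np


def _is_kind(xp, dtype, kind):
    """Dtype-kind check via the standard `isdtype`, with a NumPy < 2.0 fallback."""
    if hasattr(xp, "isdtype"):
        return xp.isdtype(dtype, kind)
    legacy = {"bool": np.bool_, "integral": np.integer, "real floating": np.floating}
    return np.issubdtype(dtype, legacy[kind])


def normalize_array_indexer(indexer, n):
    """Hypothetical helper: validate a bool/int/float array indexer for an axis of length n."""
    xp = _namespace(indexer)
    # Collapse an accidental (n, 1) or (1, n) vector to 1D.
    if indexer.shape in ((n, 1), (1, n)):
        indexer = xp.reshape(indexer, (-1,))
    dtype = indexer.dtype
    if _is_kind(xp, dtype, "bool"):
        if indexer.shape != (n,):
            msg = f"Boolean index has shape {indexer.shape}, expected {(n,)}"
            raise IndexError(msg)
        return indexer
    if _is_kind(xp, dtype, "integral"):
        return indexer
    if _is_kind(xp, dtype, "real floating"):
        # Reject NaN, infinities, and fractional values before casting.
        not_whole = xp.floor(indexer) != indexer
        if bool(xp.any(not_whole)) or not bool(xp.all(xp.isfinite(indexer))):
            msg = f"Indexer {indexer!r} has non-integer floating point values."
            raise IndexError(msg)
        # Some backends may downcast int64 depending on their configuration.
        cast = getattr(xp, "astype", None)
        if cast is not None:
            return cast(indexer, xp.int64)
        return indexer.astype(xp.int64)
    raise IndexError(f"Unsupported indexer dtype {dtype!r}")


whole_floats = normalize_array_indexer(np.array([0.0, 2.0]), 3)
column_mask = normalize_array_indexer(np.array([[True], [False], [True]]), 3)
```

Because the dtype checks go through the namespace rather than `isinstance(indexer, np.ndarray)`, the same function would accept a NumPy array or any array exposing `__array_namespace__`. String lookups would still need the pandas `index.get_indexer` path.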
I would actually just let the error be thrown in this case. If something isn't writeable, I don't think that's our responsibility to handle
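The code this comment refers to is not shown here. As an illustration only, assuming it concerns assigning into an array that the backend marks as read-only, a setter that lets the backend's own error surface could look like the sketch below. `set_values` is a hypothetical name, and NumPy stands in for whichever backend is in use.

```python
import numpy as np


def set_values(target, idx, values):
    """Hypothetical setter: assign in place and let any backend error propagate."""
    target[idx] = values
    return target


read_only = np.arange(4)
read_only.flags.writeable = False  # simulate a non-writeable array

try:
    set_values(read_only, 0, 99)
except ValueError as err:
    # NumPy reports the problem itself; no extra handling is added here.
    message = str(err)
```

The caller sees the backend's own exception, so the library does not need to copy, convert, or special-case arrays that cannot be written to.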