Skip to content

Commit 1920dd6

Browse files
Backport PR scverse#2053 on branch 0.12.x (fix: unbound dask) (scverse#2120)
Co-authored-by: Ilan Gold <ilanbassgold@gmail.com>
1 parent f9d3116 commit 1920dd6

5 files changed

Lines changed: 41 additions & 19 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ cu11 = [ "cupy-cuda11x" ]
109109
lazy = [ "xarray>=2025.06.1", "aiohttp", "requests", "anndata[dask]" ]
110110
# https://github.com/dask/dask/issues/11290
111111
# https://github.com/dask/dask/issues/11752
112-
dask = [ "dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,<2025.2.0" ]
112+
dask = [
113+
"dask[array]>=2023.5.1,!=2024.8.*,!=2024.9.*,!=2025.2.*,!=2025.3.*,!=2025.4.*,!=2025.5.*,!=2025.6.*,!=2025.7.*,!=2025.8.*",
114+
]
113115

114116
[tool.hatch.version]
115117
source = "vcs"

src/anndata/_core/merge.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,9 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
617617
sub_el = _subset(el, make_slice(indexer, axis, len(shape)))
618618

619619
if any(indexer == -1):
620+
# TODO: Remove this condition once https://github.com/dask/dask/pull/12078 is released
621+
if isinstance(sub_el._meta, CSArray | CSMatrix) and np.isscalar(fill_value):
622+
fill_value = np.array([[fill_value]])
620623
sub_el[make_slice(indexer == -1, axis, len(shape))] = fill_value
621624

622625
return sub_el

tests/test_concatenate.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -957,17 +957,20 @@ def test_nan_merge(axis_name, join_type, array_type):
957957
alt_axis, alt_axis_name = merge._resolve_axis(1 - axis)
958958
mapping_attr = f"{alt_axis_name}m"
959959
adata_shape = (20, 10)
960-
961-
arr = array_type(
962-
sparse.random(adata_shape[alt_axis], 10, density=0.1, format="csr")
963-
)
964-
arr_nan = arr.copy()
960+
# TODO: Revert to https://github.com/scverse/anndata/blob/71fdf821919fc5ff3c864dc74c4432c370573984/tests/test_concatenate.py#L961-L970 after https://github.com/scipy/scipy/pull/23626.
961+
# The need for this handling arose as a result of
962+
# https://github.com/dask/dask/pull/11755/files#diff-65211e64fa680da306e9612b92c60f557365507d46486325f0e7e04359bce64fR456-R459
963+
sparse_arr = sparse.random(adata_shape[alt_axis], 10, density=0.1, format="csr")
964+
sparse_arr_nan = sparse_arr.copy()
965965
with warnings.catch_warnings():
966966
warnings.simplefilter("ignore", category=sparse.SparseEfficiencyWarning)
967967
for _ in range(10):
968-
arr_nan[np.random.choice(arr.shape[0]), np.random.choice(arr.shape[1])] = (
969-
np.nan
970-
)
968+
sparse_arr_nan[
969+
np.random.choice(sparse_arr.shape[0]),
970+
np.random.choice(sparse_arr.shape[1]),
971+
] = np.nan
972+
arr = array_type(sparse_arr)
973+
arr_nan = array_type(sparse_arr_nan)
971974

972975
_data = {"X": sparse.csr_matrix(adata_shape), mapping_attr: {"arr": arr_nan}}
973976
orig1 = AnnData(**_data)
@@ -1811,7 +1814,7 @@ def test_concat_dask_sparse_matches_memory(join_type, merge_strategy):
18111814
X = sparse.random(50, 20, density=0.5, format="csr")
18121815
X_dask = da.from_array(X, chunks=(5, 20))
18131816
var_names_1 = [f"gene_{i}" for i in range(20)]
1814-
var_names_2 = [f"gene_{i}{'_foo' if (i % 2) else ''}" for i in range(20, 40)]
1817+
var_names_2 = [f"gene_{i}{'_foo' if (i % 2) else ''}" for i in range(20)]
18151818

18161819
ad1 = AnnData(X=X, var=pd.DataFrame(index=var_names_1))
18171820
ad2 = AnnData(X=X, var=pd.DataFrame(index=var_names_2))
@@ -1821,7 +1824,6 @@ def test_concat_dask_sparse_matches_memory(join_type, merge_strategy):
18211824

18221825
res_in_memory = concat([ad1, ad2], join=join_type, merge=merge_strategy)
18231826
res_dask = concat([ad1_dask, ad2_dask], join=join_type, merge=merge_strategy)
1824-
18251827
assert_equal(res_in_memory, res_dask)
18261828

18271829

tests/test_dask_view_mem.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ def _alloc_cache():
6969
# if we put a 2 factor on 2**19
7070
# the results seems more accurate with the experimental results
7171
# For example from dask.random we allocate 1mb
72+
# As of 2025.09.* dask, this needs a bit more than the previous 1.5mb.
73+
# TODO: Why?
7274
@pytest.mark.usefixtures("_alloc_cache")
73-
@pytest.mark.limit_memory("1.5 MB")
75+
@pytest.mark.limit_memory("1.7 MB")
7476
def test_size_of_view(mapping_name, give_chunks):
7577
import dask.array as da
7678

tests/test_views.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from contextlib import ExitStack
3+
from contextlib import ExitStack, nullcontext
44
from copy import deepcopy
55
from operator import mul
66
from typing import TYPE_CHECKING
@@ -22,7 +22,7 @@
2222
SparseCSRArrayView,
2323
SparseCSRMatrixView,
2424
)
25-
from anndata.compat import CupyCSCMatrix, DaskArray
25+
from anndata.compat import CSArray, CupyCSCMatrix, DaskArray
2626
from anndata.tests.helpers import (
2727
BASE_MATRIX_PARAMS,
2828
CUPY_MATRIX_PARAMS,
@@ -189,13 +189,26 @@ def test_modify_view_component(matrix_type, mapping_name, request):
189189
with pytest.warns(ad.ImplicitModificationWarning, match=rf".*\.{mapping_name}.*"):
190190
m[0, 0] = 100
191191
assert not subset.is_view
192-
assert getattr(subset, mapping_name)["m"][0, 0] == 100
192+
# TODO: Remove `raises` after https://github.com/scipy/scipy/pull/23626.
193+
import dask
193194

194-
assert init_hash == hash_func(adata)
195+
is_dask_with_broken_view_setting = (
196+
"sparse_dask" in request.node.callspec.id
197+
and Version(dask.__version__) >= Version("2025.02.0")
198+
)
199+
is_sparse_array_in_lower_dask_version = (
200+
not is_dask_with_broken_view_setting
201+
and isinstance(m, DaskArray)
202+
and isinstance(m._meta, CSArray)
203+
)
204+
with (
205+
pytest.raises(ValueError, match=r"shape mismatch")
206+
if is_sparse_array_in_lower_dask_version or is_dask_with_broken_view_setting
207+
else nullcontext()
208+
):
209+
assert getattr(subset, mapping_name)["m"][0, 0] == 100
195210

196-
if "sparse_array_dask_array" in request.node.callspec.id:
197-
msg = "sparse arrays in dask are generally expected to fail but in this case they do not"
198-
pytest.fail(msg)
211+
assert init_hash == hash_func(adata)
199212

200213

201214
@pytest.mark.parametrize("attr", ["obsm", "varm"])

0 commit comments

Comments
 (0)