Skip to content

Commit b80fd0c

Browse files
committed
update backends
1 parent 8e5a9b8 commit b80fd0c

23 files changed

Lines changed: 449 additions & 77 deletions

hatch.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ download = "python ./.scripts/ci/download_data.py {args}"
3737

3838
[[envs.hatch-test.matrix]]
3939
deps = ["stable"]
40-
python = ["3.11", "3.12", "3.13"]
40+
python = ["3.12", "3.13"]
4141

4242
[[envs.hatch-test.matrix]]
4343
deps = ["pre"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ dependencies = [
6565
"scikit-image>=0.25",
6666
# due to https://github.com/scikit-image/scikit-image/issues/6850 breaks rescale ufunc
6767
"scikit-learn>=0.24",
68-
"scverse-backends",
68+
"scverse-backends>=0.0.2,<0.1",
6969
"spatialdata>=0.7.1",
7070
"spatialdata-plot",
7171
"statsmodels>=0.12",

src/squidpy/_backends/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@
2525
},
2626
)
2727

28-
dispatch = _dispatcher.dispatch
28+
backend_dispatch = _dispatcher.backend_dispatch
2929
settings = _dispatcher.settings
3030
get_backend = _dispatcher.get_backend
3131
available_backend_names = _dispatcher.available_backend_names
3232
discover = _dispatcher.discover
3333

3434
__all__ = [
3535
"available_backend_names",
36+
"backend_dispatch",
3637
"discover",
37-
"dispatch",
3838
"get_backend",
3939
"settings",
4040
]

src/squidpy/_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,45 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
228228
return wrapper
229229

230230

231+
_JOBLIB_BACKENDS = frozenset({"dask", "loky", "multiprocessing", "ray", "sequential", "threading"})
232+
233+
234+
def _is_joblib_backend(backend: Any) -> bool:
235+
return isinstance(backend, str) and (backend in _JOBLIB_BACKENDS or backend in jl.parallel.BACKENDS)
236+
237+
238+
def _deprecate_backend_as_parallel_backend(func: Callable[..., Any]) -> Callable[..., Any]:
239+
@functools.wraps(func)
240+
def wrapper(*args: Any, **kwargs: Any) -> Any:
241+
if _is_joblib_backend(kwargs.get("backend")):
242+
if "parallel_backend" in kwargs:
243+
raise TypeError("Pass only one of `backend` or `parallel_backend` for joblib parallelism.")
244+
warnings.warn(
245+
"Using `backend` for joblib parallelism is deprecated. Use `parallel_backend` instead.",
246+
FutureWarning,
247+
stacklevel=2,
248+
)
249+
kwargs["parallel_backend"] = kwargs.pop("backend")
250+
return func(*args, **kwargs)
251+
252+
return wrapper
253+
254+
255+
def _deprecate_legacy_joblib_backend(func: Callable[..., Any]) -> Callable[..., Any]:
256+
@functools.wraps(func)
257+
def wrapper(*args: Any, **kwargs: Any) -> Any:
258+
if _is_joblib_backend(kwargs.get("backend")):
259+
warnings.warn(
260+
"Using `backend` for joblib parallelism is deprecated and has no effect.",
261+
FutureWarning,
262+
stacklevel=2,
263+
)
264+
kwargs.pop("backend")
265+
return func(*args, **kwargs)
266+
267+
return wrapper
268+
269+
231270
def _get_n_cores(n_cores: int | None) -> int:
232271
"""
233272
Make number of cores a positive integer.

src/squidpy/gr/_build.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
get_model,
3939
)
4040

41+
from squidpy._backends import backend_dispatch
4142
from squidpy._constants._constants import CoordType, Transform
4243
from squidpy._constants._pkg_constants import Key
4344
from squidpy._docs import d, inject_docs
@@ -61,6 +62,7 @@ class SpatialNeighborsResult(NamedTuple):
6162

6263
@d.dedent
6364
@inject_docs(t=Transform, c=CoordType)
65+
@backend_dispatch
6466
def spatial_neighbors(
6567
adata: AnnData | SpatialData,
6668
spatial_key: str = Key.obsm.spatial,

src/squidpy/gr/_ligrec.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from functools import partial
99
from itertools import product
1010
from types import MappingProxyType
11-
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
11+
from typing import TYPE_CHECKING, Any, Literal
1212

1313
import numpy as np
1414
import pandas as pd
@@ -17,11 +17,11 @@
1717
from scipy.sparse import csc_matrix
1818
from spatialdata import SpatialData
1919

20-
from squidpy._backends import dispatch
20+
from squidpy._backends import backend_dispatch
2121
from squidpy._constants._constants import ComplexPolicy, CorrAxis
2222
from squidpy._constants._pkg_constants import Key
2323
from squidpy._docs import d, inject_docs
24-
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
24+
from squidpy._utils import NDArrayA, Signal, SigQueue, _deprecate_backend_as_parallel_backend, _get_n_cores, parallelize
2525
from squidpy._validators import assert_positive, check_tuple_needles
2626
from squidpy.gr._utils import (
2727
_assert_categorical_obs,
@@ -31,11 +31,11 @@
3131

3232
__all__ = ["ligrec", "PermutationTest"]
3333

34-
StrSeq: TypeAlias = Sequence[str]
35-
SeqTuple: TypeAlias = Sequence[tuple[str, str]]
36-
Interaction_t: TypeAlias = pd.DataFrame | Mapping[str, StrSeq] | StrSeq | tuple[StrSeq, StrSeq] | SeqTuple
34+
type StrSeq = Sequence[str]
35+
type SeqTuple = Sequence[tuple[str, str]]
36+
type Interaction_t = pd.DataFrame | Mapping[str, StrSeq] | StrSeq | tuple[StrSeq, StrSeq] | SeqTuple
3737

38-
Cluster_t: TypeAlias = StrSeq | tuple[StrSeq, StrSeq] | SeqTuple
38+
type Cluster_t = StrSeq | tuple[StrSeq, StrSeq] | SeqTuple
3939

4040
SOURCE = "source"
4141
TARGET = "target"
@@ -314,6 +314,7 @@ def prepare(
314314
@d.get_sections(base="PT_test", sections=["Parameters"])
315315
@d.dedent
316316
@inject_docs(src=SOURCE, tgt=TARGET, fa=CorrAxis)
317+
@_deprecate_backend_as_parallel_backend
317318
def test(
318319
self,
319320
cluster_key: str,
@@ -327,6 +328,7 @@ def test(
327328
copy: bool = False,
328329
key_added: str | None = None,
329330
numba_parallel: bool | None = None,
331+
parallel_backend: str = "loky",
330332
**kwargs: Any,
331333
) -> Mapping[str, pd.DataFrame] | None:
332334
"""
@@ -357,6 +359,8 @@ def test(
357359
If `None`, ``'{{cluster_key}}_ligrec'`` will be used.
358360
%(numba_parallel)s
359361
%(parallelize)s
362+
parallel_backend
363+
Which joblib backend to use for permutation parallelism.
360364
361365
Returns
362366
-------
@@ -423,6 +427,7 @@ def test(
423427
seed=seed,
424428
n_jobs=n_jobs,
425429
numba_parallel=numba_parallel,
430+
parallel_backend=parallel_backend,
426431
**kwargs,
427432
)
428433
index = pd.MultiIndex.from_frame(interactions, names=[SOURCE, TARGET])
@@ -630,7 +635,8 @@ def prepare(
630635

631636

632637
@d.dedent
633-
@dispatch
638+
@_deprecate_backend_as_parallel_backend
639+
@backend_dispatch
634640
def ligrec(
635641
adata: AnnData | SpatialData,
636642
cluster_key: str,
@@ -643,6 +649,7 @@ def ligrec(
643649
copy: bool = False,
644650
key_added: str | None = None,
645651
gene_symbols: str | None = None,
652+
parallel_backend: str = "loky",
646653
**kwargs: Any,
647654
) -> Mapping[str, pd.DataFrame] | None:
648655
"""
@@ -655,6 +662,8 @@ def ligrec(
655662
%(PT_test.parameters)s
656663
gene_symbols
657664
Key in :attr:`anndata.AnnData.var` to use instead of :attr:`anndata.AnnData.var_names`.
665+
parallel_backend
666+
Which joblib backend to use for permutation parallelism.
658667
659668
Returns
660669
-------
@@ -673,6 +682,7 @@ def ligrec(
673682
corr_axis=corr_axis,
674683
copy=copy,
675684
key_added=key_added,
685+
parallel_backend=parallel_backend,
676686
**kwargs,
677687
)
678688
)
@@ -688,6 +698,7 @@ def _analysis(
688698
seed: int | None = None,
689699
n_jobs: int = 1,
690700
numba_parallel: bool | None = None,
701+
parallel_backend: str = "loky",
691702
**kwargs: Any,
692703
) -> TempResult:
693704
"""
@@ -711,8 +722,10 @@ def _analysis(
711722
Number of parallel jobs to launch.
712723
numba_parallel
713724
Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically.
725+
parallel_backend
726+
Which joblib backend to use for permutation parallelism.
714727
kwargs
715-
Keyword arguments for :func:`squidpy._utils.parallelize`, such as ``n_jobs`` or ``backend``.
728+
Additional keyword arguments for :func:`squidpy._utils.parallelize`.
716729
717730
Returns
718731
-------
@@ -756,6 +769,7 @@ def extractor(res: Sequence[TempResult]) -> TempResult:
756769
_analysis_helper,
757770
np.arange(n_perms, dtype=np.int32).tolist(),
758771
n_jobs=n_jobs,
772+
backend=parallel_backend,
759773
unit="permutation",
760774
extractor=extractor,
761775
**kwargs,

src/squidpy/gr/_nhood.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
from scanpy import logging as logg
1818
from spatialdata import SpatialData
1919

20+
from squidpy._backends import backend_dispatch
2021
from squidpy._constants._constants import Centrality
2122
from squidpy._constants._pkg_constants import Key
2223
from squidpy._docs import d, inject_docs
23-
from squidpy._utils import NDArrayA, Signal, SigQueue, _get_n_cores, parallelize
24+
from squidpy._utils import (
25+
NDArrayA,
26+
Signal,
27+
SigQueue,
28+
_deprecate_backend_as_parallel_backend,
29+
_get_n_cores,
30+
parallelize,
31+
)
2432
from squidpy._validators import assert_positive
2533
from squidpy.gr._utils import (
2634
_assert_categorical_obs,
@@ -134,6 +142,8 @@ def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA,
134142

135143
@d.get_sections(base="nhood_ench", sections=["Parameters"])
136144
@d.dedent
145+
@_deprecate_backend_as_parallel_backend
146+
@backend_dispatch
137147
def nhood_enrichment(
138148
adata: AnnData | SpatialData,
139149
cluster_key: str,
@@ -144,7 +154,7 @@ def nhood_enrichment(
144154
seed: int | None = None,
145155
copy: bool = False,
146156
n_jobs: int | None = None,
147-
backend: str = "loky",
157+
parallel_backend: str = "loky",
148158
show_progress_bar: bool = True,
149159
) -> NhoodEnrichmentResult | None:
150160
"""
@@ -161,6 +171,8 @@ def nhood_enrichment(
161171
%(seed)s
162172
%(copy)s
163173
%(parallelize)s
174+
parallel_backend
175+
Which joblib backend to use for parallel neighborhood enrichment.
164176
165177
Returns
166178
-------
@@ -203,7 +215,7 @@ def nhood_enrichment(
203215
collection=np.arange(n_perms).tolist(),
204216
extractor=np.vstack,
205217
n_jobs=n_jobs,
206-
backend=backend,
218+
backend=parallel_backend,
207219
show_progress_bar=show_progress_bar,
208220
)(
209221
callback=_test,
@@ -230,14 +242,16 @@ def nhood_enrichment(
230242

231243
@d.dedent
232244
@inject_docs(c=Centrality)
245+
@_deprecate_backend_as_parallel_backend
246+
@backend_dispatch
233247
def centrality_scores(
234248
adata: AnnData | SpatialData,
235249
cluster_key: str,
236250
score: str | Iterable[str] | None = None,
237251
connectivity_key: str | None = None,
238252
copy: bool = False,
239253
n_jobs: int | None = None,
240-
backend: str = "loky",
254+
parallel_backend: str = "loky",
241255
show_progress_bar: bool = False,
242256
) -> pd.DataFrame | None:
243257
"""
@@ -260,6 +274,8 @@ def centrality_scores(
260274
%(conn_key)s
261275
%(copy)s
262276
%(parallelize)s
277+
parallel_backend
278+
Which joblib backend to use for parallel centrality computation.
263279
264280
Returns
265281
-------
@@ -307,7 +323,7 @@ def centrality_scores(
307323
collection=cat,
308324
extractor=pd.concat,
309325
n_jobs=n_jobs,
310-
backend=backend,
326+
backend=parallel_backend,
311327
show_progress_bar=show_progress_bar,
312328
)(clusters=clusters, fun=v, method=k)
313329
res_list.append(df)
@@ -326,6 +342,7 @@ def centrality_scores(
326342

327343

328344
@d.dedent
345+
@backend_dispatch
329346
def interaction_matrix(
330347
adata: AnnData | SpatialData,
331348
cluster_key: str,

src/squidpy/gr/_niche.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from spatialdata import SpatialData
1919
from spatialdata._logging import logger as logg
2020

21+
from squidpy._backends import backend_dispatch
2122
from squidpy._constants._constants import NicheDefinitions
2223
from squidpy._docs import d, inject_docs
2324
from squidpy._validators import assert_isinstance, assert_key_in_adata, assert_one_of
@@ -27,6 +28,7 @@
2728

2829
@d.dedent
2930
@inject_docs(fla=NicheDefinitions)
31+
@backend_dispatch
3032
def calculate_niche(
3133
data: AnnData | SpatialData,
3234
flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"],

0 commit comments

Comments
 (0)