Skip to content

Commit 1d4c135

Browse files
timtreisclaude
andcommitted
Remove device parameter, let JAX handle device selection
JAX selects the appropriate device based on its install (CPU/GPU) and runtime context managers. The explicit device arg added unnecessary complexity with no benefit over JAX's built-in device management. Removes device from: align_obs, align_images, AlignBackend protocol, StAlignBackend, MoscotBackend, and require_jax. Simplifies require_jax to a pure import guard. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2a02b72 commit 1d4c135

5 files changed

Lines changed: 15 additions & 59 deletions

File tree

src/squidpy/experimental/tl/_align/_api.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def align_obs(
4848
*,
4949
output_mode: Literal["affine", "obs", "return"] = "obs",
5050
key_added: str | None = None,
51-
device: Literal["cpu", "gpu"] | None = None,
5251
inplace: bool = True,
5352
**flavour_kwargs: Any,
5453
) -> AnnData | SpatialData | AlignResult | None:
@@ -83,9 +82,6 @@ def align_obs(
8382
Name for the aligned table when ``output_mode='obs'`` and inputs are
8483
SpatialData. Defaults to ``'{adata_query_name}_aligned'``.
8584
Rejected with any other ``output_mode``.
86-
device
87-
``'cpu'``/``'gpu'`` to force a JAX device, or ``None`` to let JAX
88-
pick the default. Only consulted by JAX-backed flavours.
8985
inplace
9086
If ``True``, mutate the query container; otherwise return a copy.
9187
Only affects SpatialData inputs -- for plain AnnData with
@@ -99,7 +95,7 @@ def align_obs(
9995

10096
pair = resolve_obs_pair(data_ref, data_query, adata_ref_name, adata_query_name)
10197
backend = get_backend(flavour)
102-
result = backend.align_obs(pair, device=device, **flavour_kwargs)
98+
result = backend.align_obs(pair, **flavour_kwargs)
10399

104100
# Auto-generate key_added for SpatialData obs writeback.
105101
if key_added is None and output_mode == "obs" and pair.query_element_key is not None:
@@ -118,7 +114,6 @@ def align_images(
118114
scale_ref: str | Literal["auto"] = "auto",
119115
scale_query: str | Literal["auto"] = "auto",
120116
output_mode: Literal["affine", "return"] = "affine",
121-
device: Literal["cpu", "gpu"] | None = None,
122117
inplace: bool = True,
123118
**flavour_kwargs: Any,
124119
) -> SpatialData | AlignResult | None:
@@ -145,7 +140,7 @@ def align_images(
145140
``'affine'`` registers the fit on the query image element so all of
146141
its scales inherit the transformation; ``'return'`` returns the raw
147142
:class:`AlignResult`.
148-
device, inplace, flavour_kwargs
143+
inplace, flavour_kwargs
149144
See :func:`align_obs`.
150145
"""
151146
validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images")
@@ -160,7 +155,7 @@ def align_images(
160155
scale_query=scale_query,
161156
)
162157
backend = get_backend(flavour)
163-
result = backend.align_images(pair, device=device, **flavour_kwargs)
158+
result = backend.align_images(pair, **flavour_kwargs)
164159

165160
return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace)
166161

src/squidpy/experimental/tl/_align/_backends/_base.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
5+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
66

77
if TYPE_CHECKING:
88
from squidpy.experimental.tl._align._types import AlignPair, AlignResult
@@ -24,15 +24,11 @@ class AlignBackend(Protocol):
2424
def align_obs(
2525
self,
2626
pair: AlignPair,
27-
*,
28-
device: Literal["cpu", "gpu"] | None = None,
2927
**kwargs: Any,
3028
) -> AlignResult: ...
3129

3230
def align_images(
3331
self,
3432
pair: AlignPair,
35-
*,
36-
device: Literal["cpu", "gpu"] | None = None,
3733
**kwargs: Any,
3834
) -> AlignResult: ...

src/squidpy/experimental/tl/_align/_backends/_moscot.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from __future__ import annotations
1010

11-
from typing import TYPE_CHECKING, Any, Literal
11+
from typing import TYPE_CHECKING, Any
1212

1313
if TYPE_CHECKING:
1414
from squidpy.experimental.tl._align._types import AlignPair, AlignResult
@@ -21,13 +21,11 @@ class MoscotBackend:
2121
def align_obs(
2222
self,
2323
pair: AlignPair,
24-
*,
25-
device: Literal["cpu", "gpu"] | None = None,
2624
**kwargs: Any,
2725
) -> AlignResult:
2826
from squidpy.experimental.tl._align._jax import require_jax
2927

30-
require_jax(device)
28+
require_jax()
3129
raise NotImplementedError(
3230
"moscot backend `align_obs`: TODO. Skeleton landed; the moscot "
3331
"solver will replace this body in a follow-up PR."
@@ -36,8 +34,6 @@ def align_obs(
3634
def align_images(
3735
self,
3836
pair: AlignPair,
39-
*,
40-
device: Literal["cpu", "gpu"] | None = None,
4137
**kwargs: Any,
4238
) -> AlignResult:
4339
raise NotImplementedError("moscot does not implement image alignment; use `flavour='stalign'`.")

src/squidpy/experimental/tl/_align/_backends/_stalign.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""STalign backend.
22
3-
Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Özleyen)
3+
Wraps the JAX LDDMM solver lifted from scverse/squidpy#1150 (Selman Ozleyen)
44
into the :class:`AlignBackend` Protocol. Only ``align_obs`` is implemented
55
today; ``align_images`` raises until upstream support exists.
66
"""
77

88
from __future__ import annotations
99

10-
from typing import TYPE_CHECKING, Any, Literal
10+
from typing import TYPE_CHECKING, Any
1111

1212
import numpy as np
1313

@@ -23,7 +23,6 @@ def align_obs(
2323
self,
2424
pair: AlignPair,
2525
*,
26-
device: Literal["cpu", "gpu"] | None = None,
2726
config: Any | None = None,
2827
landmarks_source: np.ndarray | None = None,
2928
landmarks_target: np.ndarray | None = None,
@@ -39,7 +38,7 @@ def align_obs(
3938
# `ModuleNotFoundError: import of jax halted; None in sys.modules`
4039
# instead of the clean `ImportError("JAX is required ...")` from
4140
# _jax.require_jax.
42-
require_jax(device)
41+
require_jax()
4342

4443
from squidpy.experimental.tl._align._backends._stalign_tools import stalign_points
4544
from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement
@@ -94,8 +93,6 @@ def align_obs(
9493
def align_images(
9594
self,
9695
pair: AlignPair,
97-
*,
98-
device: Literal["cpu", "gpu"] | None = None,
9996
**kwargs: Any,
10097
) -> AlignResult:
10198
raise NotImplementedError(
Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Lazy JAX import + device selection for JAX-backed alignment backends.
1+
"""Lazy JAX import guard for JAX-backed alignment backends.
22
33
JAX is an optional dependency. Importing this module is cheap; calling
44
:func:`require_jax` is what actually pulls JAX in, and only the
@@ -7,50 +7,22 @@
77

88
from __future__ import annotations
99

10-
from typing import TYPE_CHECKING, Any, Literal
10+
from typing import Any
1111

12-
if TYPE_CHECKING:
13-
Device = Any # jax.Device, but importing it eagerly defeats the purpose
12+
_INSTALL_HINT = 'JAX is required for the requested align_* flavour. Install with `pip install "squidpy[jax]"`.'
1413

1514

16-
_INSTALL_HINT = (
17-
"JAX is required for the requested align_* flavour. "
18-
"Install with `pip install jax` (CPU) or follow the JAX install guide for GPU."
19-
)
20-
21-
22-
def require_jax(device: Literal["cpu", "gpu"] | None = None) -> tuple[Any, Any]:
23-
"""Import JAX lazily and return ``(jax, device)``.
24-
25-
Parameters
26-
----------
27-
device
28-
``"cpu"``/``"gpu"`` to force a platform, or ``None`` to use whatever
29-
JAX picks as the default.
30-
31-
Returns
32-
-------
33-
jax_module
34-
The imported :mod:`jax` module.
35-
device
36-
A :class:`jax.Device` of the requested platform.
15+
def require_jax() -> Any:
16+
"""Import JAX lazily and return the module.
3717
3818
Raises
3919
------
4020
ImportError
4121
If JAX is not installed.
42-
RuntimeError
43-
If the requested device platform is not available on this host.
4422
"""
4523
try:
4624
import jax
4725
except ImportError as e:
4826
raise ImportError(_INSTALL_HINT) from e
4927

50-
if device is None:
51-
return jax, jax.devices()[0]
52-
53-
matching = [d for d in jax.devices() if d.platform == device]
54-
if not matching:
55-
raise RuntimeError(f"No JAX device of kind {device!r} available; have {[d.platform for d in jax.devices()]}.")
56-
return jax, matching[0]
28+
return jax

0 commit comments

Comments
 (0)