Skip to content

Commit e837ce6

Browse files
timtreispre-commit-ci[bot]selmanozleyen
authored
Add helper functions to tile image for inference (#1053)
* mvp * mvp * for notebook * added small plotting func * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor * added images from runner * made logging robust to different loggers * restored logging * restored logging * removed dead code * edge cases that caused scverse CI to fail * bump * make_tile_grid function * mvp * tiling functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mypy fixes * bugfix * updated tests for correct plot * images from runner * bugfix + more tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * images from runner * cleaned up * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated docstrings * adressed reviewer feedback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed dead code * renamed enum --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selman Özleyen <32667648+selmanozleyen@users.noreply.github.com>
1 parent cacd943 commit e837ce6

20 files changed

Lines changed: 1096 additions & 78 deletions

src/squidpy/_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from multiprocessing import Manager, cpu_count
1212
from queue import Queue
1313
from threading import Thread
14-
from typing import TYPE_CHECKING, Any
14+
from typing import TYPE_CHECKING, Any, Literal
1515

1616
import joblib as jl
1717
import numba
1818
import numpy as np
19+
import xarray as xr
1920
from spatialdata.models import Image2DModel, Labels2DModel
2021

2122
__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]
@@ -372,3 +373,17 @@ def _yx_from_shape(shape: tuple[int, ...]) -> tuple[int, int]:
372373
return shape[1], shape[2]
373374

374375
raise ValueError(f"Unsupported shape {shape}. Expected (y, x) or (c, y, x).")
376+
377+
378+
def _ensure_dim_order(img_da: xr.DataArray, order: Literal["cyx", "yxc"] = "yxc") -> xr.DataArray:
379+
"""
380+
Ensure dims are in the requested order and that a 'c' dim exists.
381+
Only supports images with dims subset of {'y','x','c'}.
382+
"""
383+
dims = list(img_da.dims)
384+
if "y" not in dims or "x" not in dims:
385+
raise ValueError(f'Expected dims to include "y" and "x". Found dims={dims}')
386+
if "c" not in dims:
387+
img_da = img_da.expand_dims({"c": [0]})
388+
# After possible expand, just transpose to target
389+
return img_da.transpose(*tuple(order))

src/squidpy/experimental/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from __future__ import annotations
88

9-
from . import im
10-
from .im._detect_tissue import detect_tissue
9+
from . import im, pl
1110

12-
__all__ = ["detect_tissue", "im"]
11+
__all__ = ["im", "pl"]

src/squidpy/experimental/im/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,12 @@
55
FelzenszwalbParams,
66
detect_tissue,
77
)
8+
from ._make_tiles import make_tiles, make_tiles_from_spots
89

9-
__all__ = ["detect_tissue", "BackgroundDetectionParams", "FelzenszwalbParams"]
10+
__all__ = [
11+
"BackgroundDetectionParams",
12+
"FelzenszwalbParams",
13+
"detect_tissue",
14+
"make_tiles",
15+
"make_tiles_from_spots",
16+
]

src/squidpy/experimental/im/_detect_tissue.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,27 @@
44
from dataclasses import dataclass
55
from typing import Literal
66

7+
import dask.array as da
78
import numpy as np
89
import spatialdata as sd
910
import xarray as xr
11+
from dask.base import is_dask_collection
1012
from dask_image.ndinterp import affine_transform as da_affine
1113
from skimage import measure
1214
from skimage.filters import gaussian, threshold_otsu
1315
from skimage.morphology import binary_closing, disk, remove_small_holes
1416
from skimage.segmentation import felzenszwalb
1517
from skimage.util import img_as_float
16-
from spatialdata._logging import logger as logg
18+
from spatialdata._logging import logger
1719
from spatialdata.models import Labels2DModel
1820
from spatialdata.transformations import get_transformation
1921

20-
from squidpy._utils import _get_scale_factors, _yx_from_shape
22+
from squidpy._utils import _ensure_dim_order, _get_scale_factors, _yx_from_shape
2123

22-
from ._utils import _flatten_channels, _get_image_data
24+
from ._utils import _flatten_channels, _get_element_data
2325

2426

25-
class DETECT_TISSUE_METHOD(enum.Enum):
27+
class DetectTissueMethod(enum.Enum):
2628
OTSU = enum.auto()
2729
FELZENSZWALB = enum.auto()
2830

@@ -70,7 +72,7 @@ def detect_tissue(
7072
image_key: str,
7173
*,
7274
scale: str = "auto",
73-
method: DETECT_TISSUE_METHOD | str = DETECT_TISSUE_METHOD.OTSU,
75+
method: DetectTissueMethod | str = DetectTissueMethod.OTSU,
7476
channel_format: Literal["infer", "rgb", "rgba", "multichannel"] = "infer",
7577
background_detection_params: BackgroundDetectionParams | None = None,
7678
corners_are_background: bool = True,
@@ -98,8 +100,8 @@ def detect_tissue(
98100
method
99101
Tissue detection method. Valid options are:
100102
101-
- `DETECT_TISSUE_METHOD.OTSU` or `"otsu"` - Otsu thresholding with background detection.
102-
- `DETECT_TISSUE_METHOD.FELZENSZWALB` or `"felzenszwalb"` - Felzenszwalb superpixel segmentation.
103+
- `DetectTissueMethod.OTSU` or `"otsu"` - Otsu thresholding with background detection.
104+
- `DetectTissueMethod.FELZENSZWALB` or `"felzenszwalb"` - Felzenszwalb superpixel segmentation.
103105
104106
channel_format
105107
Expected format of image channels. Valid options are:
@@ -155,7 +157,7 @@ def detect_tissue(
155157
# Normalize method
156158
if isinstance(method, str):
157159
try:
158-
method = DETECT_TISSUE_METHOD[method.upper()]
160+
method = DetectTissueMethod[method.upper()]
159161
except KeyError as e:
160162
raise ValueError('method must be "otsu" or "felzenszwalb"') from e
161163

@@ -170,7 +172,9 @@ def detect_tissue(
170172
manual_scale = scale.lower() != "auto"
171173

172174
# Load smallest available or explicit scale
173-
img_src = _get_image_data(sdata, image_key, scale=scale if manual_scale else "auto")
175+
img_node = sdata.images[image_key]
176+
img_da = _get_element_data(img_node, scale if manual_scale else "auto", "image", image_key)
177+
img_src = _ensure_dim_order(img_da, "yxc")
174178
src_h, src_w = _yx_from_shape(img_src.shape)
175179
n_src_px = src_h * src_w
176180

@@ -180,13 +184,13 @@ def detect_tissue(
180184
# Decide working resolution
181185
need_downscale = (not manual_scale) and (n_src_px > auto_max_pixels)
182186
if need_downscale:
183-
logg.info("Downscaling for faster computation.")
187+
logger.info("Downscaling for faster computation.")
184188
img_grey = _downscale_with_dask(img_grey=img_grey_da, target_pixels=auto_max_pixels)
185189
else:
186190
img_grey = img_grey_da.values # may compute
187191

188192
# First-pass foreground
189-
if method == DETECT_TISSUE_METHOD.OTSU:
193+
if method == DetectTissueMethod.OTSU:
190194
img_fg_mask_bool = _segment_otsu(img_grey=img_grey, params=bgp)
191195
else:
192196
p = felzenszwalb_params or FelzenszwalbParams()
@@ -225,13 +229,9 @@ def detect_tissue(
225229
return None
226230

227231
# If dask-backed, return a NumPy array to honor the signature
228-
try:
229-
import dask.array as da # noqa: F401
232+
if is_dask_collection(img_fg_labels_up):
233+
return np.asarray(img_fg_labels_up.compute())
230234

231-
if hasattr(img_fg_labels_up, "compute"):
232-
return np.asarray(img_fg_labels_up.compute())
233-
except (ImportError, AttributeError, TypeError):
234-
pass
235235
return np.asarray(img_fg_labels_up)
236236

237237

@@ -241,8 +241,6 @@ def _affine_upscale_nearest(labels: np.ndarray, scale_matrix: np.ndarray, target
241241
Nearest-neighbor affine upscaling using dask-image. Returns dask array if available, else NumPy.
242242
"""
243243
try:
244-
import dask.array as da
245-
246244
lbl_da = da.from_array(labels, chunks="auto")
247245
result = da_affine(
248246
lbl_da,
@@ -256,6 +254,7 @@ def _affine_upscale_nearest(labels: np.ndarray, scale_matrix: np.ndarray, target
256254
)
257255

258256
return np.asarray(result)
257+
259258
except (ImportError, AttributeError, TypeError):
260259
sy = target_shape[0] / labels.shape[0]
261260
sx = target_shape[1] / labels.shape[1]
@@ -311,7 +310,7 @@ def _downscale_with_dask(img_grey: xr.DataArray, target_pixels: int) -> np.ndarr
311310

312311
fy = max(1, int(np.ceil(h / target_h)))
313312
fx = max(1, int(np.ceil(w / target_w)))
314-
logg.info(f"Downscaling from {h}×{w} with coarsen={fy}×{fx} to ≤{target_pixels} px.")
313+
logger.info(f"Downscaling from {h}×{w} with coarsen={fy}×{fx} to ≤{target_pixels} px.")
315314

316315
da_small = _ensure_dask(img_grey).coarsen(y=fy, x=fx, boundary="trim").mean()
317316
return np.asarray(_dask_compute(da_small))
@@ -322,9 +321,7 @@ def _ensure_dask(da: xr.DataArray) -> xr.DataArray:
322321
Ensure DataArray is dask-backed. If not, chunk to reasonable tiles.
323322
"""
324323
try:
325-
import dask.array as dask_array
326-
327-
if hasattr(da, "data") and isinstance(da.data, dask_array.Array):
324+
if hasattr(da, "data") and isinstance(da.data, da.Array):
328325
return da
329326
return da.chunk({"y": 2048, "x": 2048})
330327
except (ImportError, AttributeError):
@@ -336,10 +333,9 @@ def _dask_compute(img_da: xr.DataArray) -> np.ndarray:
336333
Compute an xarray DataArray (possibly dask-backed) to a NumPy array with a ProgressBar if available.
337334
"""
338335
try:
339-
import dask.array as dask_array
340336
from dask.diagnostics import ProgressBar
341337

342-
if hasattr(img_da, "data") and isinstance(img_da.data, dask_array.Array):
338+
if hasattr(img_da, "data") and isinstance(img_da.data, da.Array):
343339
with ProgressBar():
344340
computed = img_da.data.compute()
345341
return np.asarray(computed)

0 commit comments

Comments
 (0)