44from dataclasses import dataclass
55from typing import Literal
66
7+ import dask .array as da
78import numpy as np
89import spatialdata as sd
910import xarray as xr
11+ from dask .base import is_dask_collection
1012from dask_image .ndinterp import affine_transform as da_affine
1113from skimage import measure
1214from skimage .filters import gaussian , threshold_otsu
1315from skimage .morphology import binary_closing , disk , remove_small_holes
1416from skimage .segmentation import felzenszwalb
1517from skimage .util import img_as_float
16- from spatialdata ._logging import logger as logg
18+ from spatialdata ._logging import logger
1719from spatialdata .models import Labels2DModel
1820from spatialdata .transformations import get_transformation
1921
20- from squidpy ._utils import _get_scale_factors , _yx_from_shape
22+ from squidpy ._utils import _ensure_dim_order , _get_scale_factors , _yx_from_shape
2123
22- from ._utils import _flatten_channels , _get_image_data
24+ from ._utils import _flatten_channels , _get_element_data
2325
2426
25- class DETECT_TISSUE_METHOD (enum .Enum ):
27+ class DetectTissueMethod (enum .Enum ):
2628 OTSU = enum .auto ()
2729 FELZENSZWALB = enum .auto ()
2830
@@ -70,7 +72,7 @@ def detect_tissue(
7072 image_key : str ,
7173 * ,
7274 scale : str = "auto" ,
73- method : DETECT_TISSUE_METHOD | str = DETECT_TISSUE_METHOD .OTSU ,
75+ method : DetectTissueMethod | str = DetectTissueMethod .OTSU ,
7476 channel_format : Literal ["infer" , "rgb" , "rgba" , "multichannel" ] = "infer" ,
7577 background_detection_params : BackgroundDetectionParams | None = None ,
7678 corners_are_background : bool = True ,
@@ -98,8 +100,8 @@ def detect_tissue(
98100 method
99101 Tissue detection method. Valid options are:
100102
101- - `DETECT_TISSUE_METHOD .OTSU` or `"otsu"` - Otsu thresholding with background detection.
102- - `DETECT_TISSUE_METHOD .FELZENSZWALB` or `"felzenszwalb"` - Felzenszwalb superpixel segmentation.
103+ - `DetectTissueMethod .OTSU` or `"otsu"` - Otsu thresholding with background detection.
104+ - `DetectTissueMethod .FELZENSZWALB` or `"felzenszwalb"` - Felzenszwalb superpixel segmentation.
103105
104106 channel_format
105107 Expected format of image channels. Valid options are:
@@ -155,7 +157,7 @@ def detect_tissue(
155157 # Normalize method
156158 if isinstance (method , str ):
157159 try :
158- method = DETECT_TISSUE_METHOD [method .upper ()]
160+ method = DetectTissueMethod [method .upper ()]
159161 except KeyError as e :
160162 raise ValueError ('method must be "otsu" or "felzenszwalb"' ) from e
161163
@@ -170,7 +172,9 @@ def detect_tissue(
170172 manual_scale = scale .lower () != "auto"
171173
172174 # Load smallest available or explicit scale
173- img_src = _get_image_data (sdata , image_key , scale = scale if manual_scale else "auto" )
175+ img_node = sdata .images [image_key ]
176+ img_da = _get_element_data (img_node , scale if manual_scale else "auto" , "image" , image_key )
177+ img_src = _ensure_dim_order (img_da , "yxc" )
174178 src_h , src_w = _yx_from_shape (img_src .shape )
175179 n_src_px = src_h * src_w
176180
@@ -180,13 +184,13 @@ def detect_tissue(
180184 # Decide working resolution
181185 need_downscale = (not manual_scale ) and (n_src_px > auto_max_pixels )
182186 if need_downscale :
183- logg .info ("Downscaling for faster computation." )
187+ logger .info ("Downscaling for faster computation." )
184188 img_grey = _downscale_with_dask (img_grey = img_grey_da , target_pixels = auto_max_pixels )
185189 else :
186190 img_grey = img_grey_da .values # may compute
187191
188192 # First-pass foreground
189- if method == DETECT_TISSUE_METHOD .OTSU :
193+ if method == DetectTissueMethod .OTSU :
190194 img_fg_mask_bool = _segment_otsu (img_grey = img_grey , params = bgp )
191195 else :
192196 p = felzenszwalb_params or FelzenszwalbParams ()
@@ -225,13 +229,9 @@ def detect_tissue(
225229 return None
226230
227231 # If dask-backed, return a NumPy array to honor the signature
228- try :
229- import dask . array as da # noqa: F401
232+ if is_dask_collection ( img_fg_labels_up ) :
233+ return np . asarray ( img_fg_labels_up . compute ())
230234
231- if hasattr (img_fg_labels_up , "compute" ):
232- return np .asarray (img_fg_labels_up .compute ())
233- except (ImportError , AttributeError , TypeError ):
234- pass
235235 return np .asarray (img_fg_labels_up )
236236
237237
@@ -241,8 +241,6 @@ def _affine_upscale_nearest(labels: np.ndarray, scale_matrix: np.ndarray, target
241241 Nearest-neighbor affine upscaling using dask-image. Returns dask array if available, else NumPy.
242242 """
243243 try :
244- import dask .array as da
245-
246244 lbl_da = da .from_array (labels , chunks = "auto" )
247245 result = da_affine (
248246 lbl_da ,
@@ -256,6 +254,7 @@ def _affine_upscale_nearest(labels: np.ndarray, scale_matrix: np.ndarray, target
256254 )
257255
258256 return np .asarray (result )
257+
259258 except (ImportError , AttributeError , TypeError ):
260259 sy = target_shape [0 ] / labels .shape [0 ]
261260 sx = target_shape [1 ] / labels .shape [1 ]
@@ -311,7 +310,7 @@ def _downscale_with_dask(img_grey: xr.DataArray, target_pixels: int) -> np.ndarr
311310
312311 fy = max (1 , int (np .ceil (h / target_h )))
313312 fx = max (1 , int (np .ceil (w / target_w )))
314- logg .info (f"Downscaling from { h } ×{ w } with coarsen={ fy } ×{ fx } to ≤{ target_pixels } px." )
313+ logger .info (f"Downscaling from { h } ×{ w } with coarsen={ fy } ×{ fx } to ≤{ target_pixels } px." )
315314
316315 da_small = _ensure_dask (img_grey ).coarsen (y = fy , x = fx , boundary = "trim" ).mean ()
317316 return np .asarray (_dask_compute (da_small ))
@@ -322,9 +321,7 @@ def _ensure_dask(da: xr.DataArray) -> xr.DataArray:
322321 Ensure DataArray is dask-backed. If not, chunk to reasonable tiles.
323322 """
324323 try :
325- import dask .array as dask_array
326-
327- if hasattr (da , "data" ) and isinstance (da .data , dask_array .Array ):
324+ if hasattr (da , "data" ) and isinstance (da .data , da .Array ):
328325 return da
329326 return da .chunk ({"y" : 2048 , "x" : 2048 })
330327 except (ImportError , AttributeError ):
@@ -336,10 +333,9 @@ def _dask_compute(img_da: xr.DataArray) -> np.ndarray:
336333 Compute an xarray DataArray (possibly dask-backed) to a NumPy array with a ProgressBar if available.
337334 """
338335 try :
339- import dask .array as dask_array
340336 from dask .diagnostics import ProgressBar
341337
342- if hasattr (img_da , "data" ) and isinstance (img_da .data , dask_array .Array ):
338+ if hasattr (img_da , "data" ) and isinstance (img_da .data , da .Array ):
343339 with ProgressBar ():
344340 computed = img_da .data .compute ()
345341 return np .asarray (computed )
0 commit comments