diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 24f7cf6e..e8561431 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -9,12 +9,14 @@ from ._make_tiles import make_tiles, make_tiles_from_spots from ._qc_image import qc_image from ._qc_metrics import QCMetric +from ._stitched_labels import make_stitched_labels __all__ = [ "BackgroundDetectionParams", "FelzenszwalbParams", "QCMetric", "WekaParams", + "make_stitched_labels", "detect_tissue", "make_tiles", "make_tiles_from_spots", diff --git a/src/squidpy/experimental/im/_stitched_labels.py b/src/squidpy/experimental/im/_stitched_labels.py new file mode 100644 index 00000000..b22738ab --- /dev/null +++ b/src/squidpy/experimental/im/_stitched_labels.py @@ -0,0 +1,421 @@ +"""Materialise a stitched labels element from a stitch_tile_cuts result. + +Companion to :func:`squidpy.experimental.tl.stitch_tile_cuts`. Takes the +piece-to-group mapping from ``stitch_group_id`` in the QC table and writes +a new labels element where stitched pieces share a single ID. The original +labels element is untouched. +""" + +from __future__ import annotations + +import copy as _copy +from collections.abc import Callable + +import anndata as ad +import dask.array as da +import numpy as np +import pandas as pd +import spatialdata as sd +import xarray as xr +from scipy.ndimage import binary_closing +from skimage.measure import regionprops +from skimage.morphology import disk as morph_disk +from spatialdata._logging import logger as logg +from spatialdata.models import Labels2DModel, TableModel +from spatialdata.transformations import get_transformation + +from squidpy.experimental.tl._tiling_qc import resolve_labels_array + +__all__ = ["make_stitched_labels"] + + +def _build_lookup(adata_obs: pd.DataFrame, dtype: np.dtype) -> np.ndarray: + """Build an int->int LUT from ``label_id`` to ``stitch_group_id``. + + LUT covers ``[0, max_label_id]``; unmapped indices keep their own value + (identity), so background (0) and any cells absent from the QC table are + preserved. + + Raises + ------ + ValueError + If ``stitch_group_id`` (or ``label_id``) values exceed the labels' + dtype range -- silent truncation here would alias unrelated cells. + """ + label_ids = adata_obs["label_id"].astype(np.int64).to_numpy() + group_ids = adata_obs["stitch_group_id"].astype(np.int64).to_numpy() + if np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + worst = max(int(label_ids.max(initial=0)), int(group_ids.max(initial=0))) + if worst > info.max: + raise ValueError( + f"label_id / stitch_group_id values up to {worst} exceed the labels " + f"dtype range {dtype} (max {info.max}); cannot build a safe LUT." + ) + max_id = int(label_ids.max(initial=0)) + lut = np.arange(max_id + 1, dtype=dtype) + lut[label_ids] = group_ids.astype(dtype) + return lut + + +def _apply_lut(labels_da: xr.DataArray, lut: np.ndarray) -> da.Array | np.ndarray: + """Lazily remap a labels DataArray via ``np.take`` over its dask blocks. + + Returns a bare array (dask or numpy) rather than a DataArray so the caller + can re-parse via Labels2DModel without colliding transformation metadata. + """ + src = labels_da.data + if isinstance(src, da.Array): + return src.map_blocks(lambda block, _lut=lut: _lut[block], dtype=lut.dtype) + return lut[np.asarray(src)] + + +def _join_stitched_labels( + labels_arr: da.Array | np.ndarray, + stitched_group_ids: set[int], + close_radius: int = 3, +) -> np.ndarray: + """Morphologically close gaps between pieces of each stitched group. + + The basic LUT remap leaves stitched groups as multi-component regions + (the cut stripe between pieces stays at 0). This pass fills only those + background pixels that lie inside the closed convex-ish hull of each + stitched group, so each group becomes a single connected component. + Other cells' pixels are never overwritten. + + Forces materialisation of the labels array; cost is O(image_size) for + the regionprops pass plus O(stitched_groups x bbox) for the per-group + closing. Both are bounded for typical workloads. + """ + if hasattr(labels_arr, "compute"): + out = np.asarray(labels_arr.compute()) + else: + out = np.asarray(labels_arr).copy() + if not stitched_group_ids: + return out + + # Single regionprops pass to locate every label's bbox. + bboxes = {int(r.label): r.bbox for r in regionprops(out)} + pad = close_radius + 2 + H, W = out.shape[-2], out.shape[-1] + structure = morph_disk(close_radius) + + for gid in stitched_group_ids: + bbox = bboxes.get(int(gid)) + if bbox is None: + continue + r0, c0, r1, c1 = bbox + r0 = max(r0 - pad, 0) + c0 = max(c0 - pad, 0) + r1 = min(r1 + pad, H) + c1 = min(c1 + pad, W) + crop = out[r0:r1, c0:c1] + mask = crop == int(gid) + if not mask.any(): + continue + closed = binary_closing(mask, structure=structure) + # Only fill genuine background pixels -- never overwrite another cell. + fill = closed & ~mask & (crop == 0) + if fill.any(): + crop[fill] = int(gid) + out[r0:r1, c0:c1] = crop + return out + + +_BUILTIN_STRATEGIES: dict[str, Callable[[pd.Series], object]] = { + "sum": lambda s: s.sum(), + "min": lambda s: s.min(), + "max": lambda s: s.max(), + "mean": lambda s: s.mean(), + "median": lambda s: s.median(), + "first": lambda s: s.iloc[0], +} + +# Vectorised counterparts: ``f(block) -> 1-D array of length n_cols``. Used +# in :func:`_aggregate_X` to avoid an O(groups*cols) Python loop when the +# user passes a built-in strategy name. Callable strategies fall back to +# the per-column path. +_BUILTIN_X_REDUCERS: dict[str, Callable[[np.ndarray], np.ndarray]] = { + "sum": lambda b: b.sum(axis=0), + "min": lambda b: b.min(axis=0), + "max": lambda b: b.max(axis=0), + "mean": lambda b: b.mean(axis=0), + "median": lambda b: np.median(b, axis=0), + "first": lambda b: b[0], +} + +# Columns whose value is shared across all members of a stitch group; we always +# take the first member's value rather than aggregating. +_GROUP_INVARIANT_COLS = frozenset({"stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence", "region"}) + + +def _resolve_strategy(strategy: str | Callable[[pd.Series], object]) -> Callable[[pd.Series], object]: + if callable(strategy): + return strategy + if strategy not in _BUILTIN_STRATEGIES: + raise ValueError( + f"Unknown merge_strategy {strategy!r}. Use one of {sorted(_BUILTIN_STRATEGIES)} or pass a callable." + ) + return _BUILTIN_STRATEGIES[strategy] + + +def _aggregate_X( + X, + group_indices: list[np.ndarray], + strategy: str | Callable[[pd.Series], object], +) -> np.ndarray: + """Aggregate ``X`` row-blocks into one row per group, column-wise. + + Built-in string strategies use vectorised numpy reductions. Callable + strategies fall back to a per-column ``pd.Series`` loop -- slow on wide + matrices but expressive. Sparse input is densified once up front. + """ + if hasattr(X, "toarray"): + X = X.toarray() + X = np.asarray(X) + n_groups = len(group_indices) + out = np.empty((n_groups, X.shape[1]), dtype=X.dtype) + reducer = None + if isinstance(strategy, str) and strategy in _BUILTIN_X_REDUCERS: + reducer = _BUILTIN_X_REDUCERS[strategy] + + if reducer is not None: + for i, idx in enumerate(group_indices): + out[i] = reducer(X[idx]) + else: + # Resolve callable lazily so non-builtin strings raise upstream. + strategy_fn = _resolve_strategy(strategy) + for i, idx in enumerate(group_indices): + block = X[idx] + for c in range(X.shape[1]): + out[i, c] = strategy_fn(pd.Series(block[:, c])) + return out + + +def _collapse_groups( + adata: ad.AnnData, + new_labels_key: str, + merge_strategy: str | Callable[[pd.Series], object], +) -> ad.AnnData: + """Collapse each stitch group into a single row. + + Output has one row per unique ``stitch_group_id``: unstitched cells (their + own group) keep their row unchanged, stitched groups (size 2-4) collapse + via ``merge_strategy``. ``.obs`` columns, ``.uns``, ``.var`` and ``.X`` + are preserved/aggregated; ``spatialdata_attrs`` and the ``region`` column + are rewritten to point at the new labels element. + + Aggregation rules: + - ``label_id``: rewritten to the group id (matches new labels element). + - ``stitch_group_id``, ``is_stitched``, ``n_pieces``, ``stitch_confidence``, + ``region``: members agree -> first value. + - Other numeric obs columns and all of ``X``: ``merge_strategy`` (default + ``"sum"``). Built-ins: ``sum``, ``min``, ``max``, ``mean``, ``median``, + ``first``. A callable receives a :class:`pandas.Series` and returns a + scalar; it's applied column-wise to both ``.obs`` and ``.X``. + - Non-numeric obs columns: ``"first"`` regardless of ``merge_strategy`` + (sum/mean don't make sense for strings/categoricals). + + Note: ``merge_strategy="sum"`` is the right default for additive features + (area, intensity, count) but wrong for centroids, scores, fractions. + Override accordingly for those. + + .. warning:: + ``.obsm``, ``.obsp``, ``.layers`` are passed through but not + aggregated. If their row dimensions become inconsistent with the new + ``n_obs``, downstream tools may complain. Drop them if not needed. + """ + obs = adata.obs + if "stitch_group_id" not in obs.columns: + raise ValueError("AnnData missing 'stitch_group_id'; run stitch_tile_cuts first.") + if "label_id" not in obs.columns: + raise ValueError("AnnData missing 'label_id'.") + + strategy_fn = _resolve_strategy(merge_strategy) + group_ids = obs["stitch_group_id"].astype(int).to_numpy() + unique_groups = np.unique(group_ids) + # Positional indices per group (preserves order of first appearance via sort). + indices_by_group: list[np.ndarray] = [np.where(group_ids == g)[0] for g in unique_groups] + + # ---- Aggregate obs ---- + out_rows: dict[str, list] = {col: [] for col in obs.columns} + for gid, idx in zip(unique_groups, indices_by_group, strict=True): + members = obs.iloc[idx] + for col in obs.columns: + series = members[col] + if col == "label_id": + out_rows[col].append(int(gid)) + elif col in _GROUP_INVARIANT_COLS or not pd.api.types.is_numeric_dtype(series): + out_rows[col].append(series.iloc[0]) + else: + out_rows[col].append(strategy_fn(series)) + + new_obs = pd.DataFrame(out_rows) + # Preserve dtypes where possible (groupby+agg sometimes loses categorical). + for col in new_obs.columns: + try: + new_obs[col] = new_obs[col].astype(obs[col].dtype) + except (TypeError, ValueError): + pass + # Update the region column to point at the new labels element. + if "region" in new_obs.columns: + new_obs["region"] = pd.Categorical([new_labels_key] * len(new_obs)) + new_obs.index = [f"group_{gid}" for gid in unique_groups] + + # ---- Aggregate X ---- + if adata.X is not None and adata.X.shape[1] > 0: + new_X = _aggregate_X(adata.X, indices_by_group, merge_strategy) + else: + new_X = np.empty((len(unique_groups), 0), dtype=np.float32) + + # ---- Preserve var / uns / pass-through obsm-style fields ---- + new_uns = _copy.deepcopy(dict(adata.uns)) + new_uns["spatialdata_attrs"] = { + "region": new_labels_key, + "region_key": "region", + "instance_key": "label_id", + } + out = ad.AnnData(X=new_X, obs=new_obs, var=adata.var.copy(), uns=new_uns) + + # Warn if there are row-dimensioned fields we didn't aggregate; user can + # decide whether to drop them. + skipped = [name for name in ("obsm", "obsp", "layers") if getattr(adata, name, None)] + if skipped: + logg.warning( + f"AnnData has {skipped}; these were not aggregated and the " + "resulting table omits them. Pass them through manually if needed." + ) + + return out + + +def make_stitched_labels( + sdata: sd.SpatialData, + labels_key: str, + qc_table_key: str | None = None, + labels_key_added: str | None = None, + table_key_added: str | None = None, + write_table: bool = True, + merge_strategy: str | Callable[[pd.Series], object] = "sum", + join_labels: bool = False, + join_close_radius: int = 3, + inplace: bool = True, +) -> dict[str, object] | None: + """Materialise a stitched labels element from a stitch_tile_cuts result. + + Reads the ``stitch_group_id`` mapping in the QC table, builds a lazy + int->int LUT, and registers a new labels element where each stitched + group shares a single ID. The original labels element is **not** + modified. + + Optionally also writes a companion AnnData (``write_table=True``) with one + row per unique ``stitch_group_id`` -- unstitched cells keep their row + unchanged, stitched groups (size 2-4) collapse via ``merge_strategy``. + + Parameters + ---------- + sdata + :class:`~spatialdata.SpatialData` with a labels element and a QC + table that has been processed by + :func:`squidpy.experimental.tl.stitch_tile_cuts`. + labels_key + Key in ``sdata.labels`` of the original labels element. + qc_table_key + Key of the QC table. Defaults to ``"{labels_key}_qc"``. + labels_key_added + Key for the new labels element. Defaults to + ``"{labels_key}_stitched"``. Existing element at this key is + overwritten with a warning. + table_key_added + Key for the optional collapsed AnnData (one row per unique + ``stitch_group_id``). Defaults to ``"{labels_key_added}_table"`` + (must differ from the labels element key -- SpatialData requires + unique names across element types). + write_table + If ``True``, also write the collapsed AnnData to + ``sdata.tables[table_key_added]``. + merge_strategy + How to aggregate numeric ``.obs`` columns and ``.X`` across the + 2-4 pieces of each stitched cell. String options: + ``"sum"`` (default), ``"min"``, ``"max"``, ``"mean"``, ``"median"``, + ``"first"``. Callable: receives a :class:`pandas.Series` (one + column of one group's members) and returns a scalar; applied + column-wise. + + ``"sum"`` is the right default for additive features (area, + intensity); for centroids, scores, or fractions, override with + ``"mean"`` or pass a callable. + + Two classes of columns are **always** taken from the first member + regardless of ``merge_strategy`` (including callables): + + - Group-invariant columns -- ``stitch_group_id``, ``is_stitched``, + ``n_pieces``, ``stitch_confidence``, ``region`` -- because every + member of a group already shares the same value. + - Non-numeric columns (strings, categoricals, booleans) -- because + ``sum`` / ``mean`` / etc. don't have a meaningful interpretation. + join_labels + If ``True``, morphologically close the gap between pieces of each + stitched group so the resulting labels are single connected + components instead of multi-component regions sharing an ID. Only + background pixels inside each group's closed hull are filled; + other cells are never overwritten. **Forces materialisation of + the labels array** -- cost is O(image_size) plus O(stitched x bbox). + Default ``False`` preserves the original gap pixels. + join_close_radius + Radius (px) of the disk structuring element used when + ``join_labels=True``. Default ``3`` matches the closing radius + used during scoring; raise it if pieces remain disconnected after + joining. + inplace + If ``True`` (default), write the new labels element (and table when + ``write_table=True``) into ``sdata``. If ``False``, return the + materialised objects in a dict ``{"labels": ..., "table": ...}`` + without mutating ``sdata``; ``"table"`` is ``None`` when + ``write_table=False``. + """ + if labels_key not in sdata.labels: + raise ValueError(f"Labels key '{labels_key}' not found in sdata.labels.") + table_key = qc_table_key if qc_table_key is not None else f"{labels_key}_qc" + if table_key not in sdata.tables: + raise ValueError(f"QC table '{table_key}' not found in sdata.tables.") + adata = sdata.tables[table_key] + if "stitch_group_id" not in adata.obs.columns: + raise ValueError( + f"QC table '{table_key}' is missing 'stitch_group_id'; run squidpy.experimental.tl.stitch_tile_cuts first." + ) + + qc_params = adata.uns.get("tiling_qc", {}) + scale = qc_params.get("scale") + labels_da = resolve_labels_array(sdata, labels_key, scale) + + lut = _build_lookup(adata.obs, labels_da.dtype) + new_data = _apply_lut(labels_da, lut) + if join_labels: + stitched_gids = adata.obs.loc[adata.obs["is_stitched"].astype(bool), "stitch_group_id"].astype(int).unique() + new_data = _join_stitched_labels(new_data, set(int(g) for g in stitched_gids), close_radius=join_close_radius) + + out_key = labels_key_added if labels_key_added is not None else f"{labels_key}_stitched" + new_labels = Labels2DModel.parse( + data=new_data, + dims=("y", "x"), + transformations=get_transformation(sdata.labels[labels_key], get_all=True), + ) + new_table = None + if write_table: + collapsed = _collapse_groups(adata, out_key, merge_strategy) + new_table = TableModel.parse(collapsed) + + if not inplace: + return {"labels": new_labels, "table": new_table} + + if out_key in sdata.labels: + logg.warning(f"Overwriting existing labels element '{out_key}'.") + sdata.labels[out_key] = new_labels + + if new_table is not None: + tbl_key = table_key_added if table_key_added is not None else f"{out_key}_table" + if tbl_key in sdata.tables: + logg.warning(f"Overwriting existing table '{tbl_key}'.") + sdata.tables[tbl_key] = new_table + return None diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py index e4e52977..56037011 100644 --- a/src/squidpy/experimental/tl/__init__.py +++ b/src/squidpy/experimental/tl/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations from ._tiling_qc import calculate_tiling_qc +from ._tiling_stitch import stitch_tile_cuts -__all__ = ["calculate_tiling_qc"] +__all__ = ["calculate_tiling_qc", "stitch_tile_cuts"] diff --git a/src/squidpy/experimental/tl/_tiling_qc.py b/src/squidpy/experimental/tl/_tiling_qc.py index 16c5cd74..a372a6a4 100644 --- a/src/squidpy/experimental/tl/_tiling_qc.py +++ b/src/squidpy/experimental/tl/_tiling_qc.py @@ -56,6 +56,24 @@ __all__ = ["calculate_tiling_qc"] + +def resolve_labels_array(sdata: sd.SpatialData, labels_key: str, scale: str | None) -> xr.DataArray: + """Resolve a labels element to its 2-D :class:`~xarray.DataArray`. + + Single-scale elements pass through; multi-scale (:class:`xarray.DataTree`) + elements require a ``scale`` selector. + + Shared by :func:`calculate_tiling_qc` and the downstream stitch helpers + so the multi-scale branch stays consistent. + """ + node = sdata.labels[labels_key] + if isinstance(node, xr.DataTree): + if scale is None: + raise ValueError(f"Labels '{labels_key}' is multi-scale; pass `scale` (e.g. 'scale0').") + return node[scale].ds["image"] + return node + + # Minimum cell area in pixels - smaller cells produce noisy contours _MIN_CELL_AREA = 20 @@ -531,13 +549,7 @@ def calculate_tiling_qc( if n_neighbors < 1: raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}.") - labels_node = sdata.labels[labels_key] - if isinstance(labels_node, xr.DataTree): - if scale is None: - raise ValueError("When using multi-scale labels, please specify the scale.") - labels_da = labels_node[scale].ds["image"] - else: - labels_da = labels_node + labels_da = resolve_labels_array(sdata, labels_key, scale) cell_info = _compute_centroids_for_labels(sdata, labels_key, labels_da, scale) if not cell_info: @@ -663,6 +675,21 @@ def _process_one(spec): if inplace: table_key = adata_key_added if adata_key_added is not None else f"{labels_key}_qc" + # If a previous QC table at this key carried stitch results, those are + # stale relative to the new outlier set. The new AnnData doesn't + # include them; warn the user so they know to re-run stitch_tile_cuts. + if table_key in sdata.tables: + existing = sdata.tables[table_key] + stitch_cols = [ + c + for c in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence") + if c in existing.obs.columns + ] + if stitch_cols: + logg.warning( + f"Previous QC table at '{table_key}' carried stitch columns " + f"({stitch_cols}); re-run squidpy.experimental.tl.stitch_tile_cuts." + ) sdata.tables[table_key] = TableModel.parse(adata) return None return adata diff --git a/src/squidpy/experimental/tl/_tiling_stitch.py b/src/squidpy/experimental/tl/_tiling_stitch.py new file mode 100644 index 00000000..a8944d58 --- /dev/null +++ b/src/squidpy/experimental/tl/_tiling_stitch.py @@ -0,0 +1,812 @@ +"""Stitching of tile-cut cells flagged by :func:`calculate_tiling_qc`. + +When segmentation is run tile-by-tile (Cellpose, Stardist, Mesmer, ...) cells +that straddle tile boundaries get cut into 2-4 pieces with characteristic +straight, axis-aligned cut edges. :func:`calculate_tiling_qc` flags these +as ``is_outlier=True``. This module pairs facing cut edges across boundaries +and assigns each candidate pair a heuristic geometric score in [0, 1]. + +The score is the arithmetic mean of four dataset-independent features -- +``iou``, ``endpoint_match``, ``merge_compactness``, ``merge_solidity`` -- +computed from the cut-edge geometry and the union mask after closing the +seam gap. No model is fitted or shipped; the formula is documented inline +and recorded in ``.uns["tiling_stitch"]["score_formula"]``. Users should +tune ``min_confidence`` for their data; ``0.7`` is a reasonable default. + +The labels element is **never** modified here -- only ``.obs`` columns are +written. Materialising a stitched labels element is opt-in via +:func:`squidpy.experimental.im.make_stitched_labels`. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np +import spatialdata as sd +import xarray as xr +from scipy.ndimage import binary_closing +from skimage.measure import find_contours, regionprops +from skimage.measure import label as cc_label +from skimage.morphology import disk as morph_disk +from spatialdata._logging import logger as logg + +from ._tiling_qc import resolve_labels_array + +if TYPE_CHECKING: + from collections.abc import Iterable + + import anndata as ad + +__all__ = ["stitch_tile_cuts"] + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +# Sub-pixel tolerance for "lies on a bbox edge". ``find_contours`` returns +# half-pixel coordinates, so a real edge run sits within ~0.5 px of the line. +_DISTANCE_TOL = 0.75 + +# Cut-edge length must exceed both an absolute floor (filters tiny cells) and +# this fraction of the cell's equivalent diameter (filters arc-tops on +# naturally curved cells, where the bbox-edge contact is a single pixel). +_MIN_EDGE_LENGTH = 5.0 +_MIN_EDGE_LENGTH_RATIO = 0.4 + +# Density check: of the integer parallel-axis positions within a candidate +# run, what fraction has at least one near-edge contour point? A single- +# point arc-top fails this; a real chord across the cut passes trivially. +_MIN_EDGE_COVERAGE = 0.5 + +# Loose IoU floor used at candidate enumeration -- selection happens at the +# calibrated score stage. Keeping this loose lets the model see borderline +# negatives during scoring rather than excluding them upstream. +_CANDIDATE_MIN_IOU = 0.2 + +# Morphological closing radius used to bridge the gap when materialising the +# union mask for shape-quality features. Larger than ``max_gap`` to be +# robust to small cells where the gap is a meaningful fraction of cell size. +_CLOSE_RADIUS = 3 + +_METHOD_KEY = "tiling_stitch" +_STITCH_COLUMNS = ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence") + +# Features combined into ``stitch_confidence`` (arithmetic mean). All four +# are dataset-independent geometry / shape signals in [0, 1]. ``gap_score`` +# is also computed but only used as a hard filter (already inside +# ``max_gap`` by construction); it does not enter the score. +_SCORE_FEATURES: tuple[str, ...] = ( + "iou", + "endpoint_match", + "merge_compactness", + "merge_solidity", +) +_SCORE_FORMULA = "arithmetic_mean(iou, endpoint_match, merge_compactness, merge_solidity)" + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _CutEdge: + """A candidate cut edge on a single cell's bbox. + + Attributes + ---------- + cell_id + Label ID of the piece carrying this edge. + axis + ``"h"`` (horizontal cut: edge is a horizontal line, cell sits above + or below it) or ``"v"`` (vertical cut). + coord + Position of the cut line: y-coord for ``"h"``, x-coord for ``"v"``. + extent + ``(min, max)`` along the parallel axis -- the chord at the cut line. + normal_dir + ``+1`` if the cell's centroid sits at greater coord than the cut + line, ``-1`` otherwise. Used to enforce facing pairs. + length + Euclidean length of the run (``extent[1] - extent[0]``). + """ + + cell_id: int + axis: str + coord: float + extent: tuple[float, float] + normal_dir: int + length: float + + +@dataclass(frozen=True) +class _StitchPair: + """A scored candidate pairing of two cut edges across a tile boundary. + + Confidence is the calibrated logistic-regression probability; feature + components are kept for diagnostics and for ``min``-based group + confidence aggregation. + """ + + cell_a: int + cell_b: int + axis: str + confidence: float + iou: float + endpoint_match: float + gap_score: float + merge_solidity: float + merge_compactness: float + edge_a: _CutEdge | None = field(default=None, repr=False) + edge_b: _CutEdge | None = field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Stage 1: cut-edge extraction +# --------------------------------------------------------------------------- + + +def _read_bbox_slice(labels_da: xr.DataArray | np.ndarray, y0: int, y1: int, x0: int, x1: int) -> np.ndarray: + """Read a 2-D bbox slice from numpy or xarray, squeezing singleton dims.""" + if isinstance(labels_da, np.ndarray): + return labels_da[y0:y1, x0:x1] + arr = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values + while arr.ndim > 2: + arr = arr.squeeze(0) + return arr + + +def _compute_outlier_bboxes( + labels_da: xr.DataArray | np.ndarray, + outlier_ids: Iterable[int], + chunk_size: int = 4096, +) -> dict[int, tuple[int, int, int, int]]: + """Compute global bboxes for the outlier subset in a single chunked pass. + + Returns mapping ``label_id -> (min_row, min_col, max_row, max_col)``. + Works on numpy or dask-backed xarray; for xarray the array is read in + ``chunk_size`` x ``chunk_size`` tiles so memory is bounded. + """ + outlier_set = {int(x) for x in outlier_ids} + bboxes: dict[int, tuple[int, int, int, int]] = {} + + if isinstance(labels_da, np.ndarray): + for region in regionprops(labels_da): + if region.label in outlier_set: + bboxes[region.label] = region.bbox + return bboxes + + H = int(labels_da.sizes.get("y", labels_da.shape[-2])) + W = int(labels_da.sizes.get("x", labels_da.shape[-1])) + # TODO: faster path -- pre-mask each chunk with np.where(np.isin(chunk, + # outlier_set), chunk, 0) before regionprops, so non-outlier cells are + # skipped instead of scanned. Worth doing if outlier fraction is < ~5%. + for y0 in range(0, H, chunk_size): + y1 = min(y0 + chunk_size, H) + for x0 in range(0, W, chunk_size): + x1 = min(x0 + chunk_size, W) + chunk = _read_bbox_slice(labels_da, y0, y1, x0, x1) + for region in regionprops(chunk): + lid = int(region.label) + if lid not in outlier_set: + continue + r0, c0, r1, c1 = region.bbox + r0 += y0 + c0 += x0 + r1 += y0 + c1 += x0 + prev = bboxes.get(lid) + if prev is None: + bboxes[lid] = (r0, c0, r1, c1) + else: + bboxes[lid] = (min(prev[0], r0), min(prev[1], c0), max(prev[2], r1), max(prev[3], c1)) + return bboxes + + +def _bbox_edge_run( + contour: np.ndarray, + perp_axis: int, + target: float, + distance_tol: float = _DISTANCE_TOL, + min_coverage: float = _MIN_EDGE_COVERAGE, +) -> tuple[float, float, float] | None: + """Find the extent of contour points lying near a single bbox edge. + + A genuine cut edge has many contour points clustered at the bbox boundary, + spanning a long parallel-axis range with high integer-position coverage. + A naturally curved cell only touches its bbox at a single point, which + fails either the count, length, or coverage check. + + Returns ``(ext_lo, ext_hi, length)`` if a substantial run is found. + """ + parallel_axis = 1 - perp_axis + near = np.abs(contour[:, perp_axis] - target) <= distance_tol + if near.sum() < 3: + return None + parallel_vals = contour[near, parallel_axis] + ext_lo = float(parallel_vals.min()) + ext_hi = float(parallel_vals.max()) + length = ext_hi - ext_lo + if length <= 0: + return None + width = max(int(np.ceil(length)), 1) + bins = np.zeros(width + 1, dtype=bool) + bins[np.clip((parallel_vals - ext_lo).astype(int), 0, width)] = True + coverage = float(bins.sum()) / (width + 1) + if coverage < min_coverage: + return None + return ext_lo, ext_hi, length + + +def _extract_cut_edges( + labels_da: xr.DataArray | np.ndarray, + outlier_ids: Iterable[int], + bboxes: dict[int, tuple[int, int, int, int]] | None = None, + distance_tol: float = _DISTANCE_TOL, + min_edge_length: float = _MIN_EDGE_LENGTH, + min_edge_length_ratio: float = _MIN_EDGE_LENGTH_RATIO, + min_edge_coverage: float = _MIN_EDGE_COVERAGE, +) -> list[_CutEdge]: + """Extract cardinal-aligned bbox-edge runs (cut-edge candidates) per outlier. + + For each outlier cell: + 1. Crop labels to its bbox + 1 px pad, build a binary mask. + 2. Trace its contour with :func:`skimage.measure.find_contours`. + 3. Check each of the 4 bbox-edge lines for a substantial straight run. + + A piece cut at a tile boundary always has its cut on a bbox edge -- the + piece terminates exactly at the cut. Curved cells only touch the bbox + at a single contour point, which the density check rejects. + + Cells at a 4-tile corner produce 2 perpendicular edges; mid-stripe pieces + can produce 2 parallel edges. + """ + outlier_list = [int(x) for x in outlier_ids] + if bboxes is None: + bboxes = _compute_outlier_bboxes(labels_da, outlier_list) + + edges: list[_CutEdge] = [] + for lid in outlier_list: + bbox = bboxes.get(lid) + if bbox is None: + continue + min_r, min_c, max_r, max_c = bbox + + crop_arr = _read_bbox_slice(labels_da, min_r, max_r, min_c, max_c) + mask = (crop_arr == lid).astype(np.float32) + if not mask.any(): + continue + mask = np.pad(mask, 1, mode="constant", constant_values=0) + contours = find_contours(mask, 0.5) + if not contours: + continue + contour = max(contours, key=len) + contour_global = contour.copy() + contour_global[:, 0] += min_r - 1 + contour_global[:, 1] += min_c - 1 + + # Local centroid from the mask (avoids a second regionprops call). + ys, xs = np.where(mask) + cy = float(ys.mean()) + min_r - 1 + cx = float(xs.mean()) + min_c - 1 + area = float(mask.sum()) + eq_diameter = float(np.sqrt(4 * area / np.pi)) + min_len = max(min_edge_length, min_edge_length_ratio * eq_diameter) + + # find_contours places level set 0.5 outside the integer pixel boundary. + bbox_targets = [ + ("h", float(min_r) - 0.5), + ("h", float(max_r) - 0.5), + ("v", float(min_c) - 0.5), + ("v", float(max_c) - 0.5), + ] + for axis, target in bbox_targets: + perp_axis = 0 if axis == "h" else 1 + run = _bbox_edge_run(contour_global, perp_axis, target, distance_tol, min_edge_coverage) + if run is None: + continue + ext_lo, ext_hi, length = run + if length < min_len: + continue + cell_coord = cy if axis == "h" else cx + normal = 1 if cell_coord > target else -1 + edges.append( + _CutEdge( + cell_id=lid, + axis=axis, + coord=target, + extent=(ext_lo, ext_hi), + normal_dir=normal, + length=float(length), + ) + ) + + return edges + + +# --------------------------------------------------------------------------- +# Stage 2: pair candidate enumeration + features +# --------------------------------------------------------------------------- + + +def _extent_overlap(a: tuple[float, float], b: tuple[float, float]) -> float: + return max(0.0, min(a[1], b[1]) - max(a[0], b[0])) + + +def _merge_shape_features( + labels_da: xr.DataArray | np.ndarray, + cell_ids: Iterable[int], + bboxes: dict[int, tuple[int, int, int, int]], + close_radius: int = _CLOSE_RADIUS, +) -> dict[str, float]: + """Materialise the union of given pieces, close the gap, and return shape stats. + + Solidity (area / convex_hull_area) and compactness (4*pi*A / P^2) drop + sharply when two unrelated cells are joined -- the union is concave at the + join. ``merge_compactness`` is the strongest single feature in the + calibration model. + """ + cell_list = [int(c) for c in cell_ids] + if not cell_list: + return {"merge_solidity": 0.0, "merge_compactness": 0.0} + + # Union bbox + padding to give morphological closing room. + rs = [bboxes[c][0] for c in cell_list if c in bboxes] + cs = [bboxes[c][1] for c in cell_list if c in bboxes] + re = [bboxes[c][2] for c in cell_list if c in bboxes] + ce = [bboxes[c][3] for c in cell_list if c in bboxes] + if not rs: + return {"merge_solidity": 0.0, "merge_compactness": 0.0} + pad = close_radius + 2 + H = labels_da.shape[-2] if hasattr(labels_da, "shape") else int(labels_da.sizes["y"]) + W = labels_da.shape[-1] if hasattr(labels_da, "shape") else int(labels_da.sizes["x"]) + r0 = max(min(rs) - pad, 0) + c0 = max(min(cs) - pad, 0) + r1 = min(max(re) + pad, H) + c1 = min(max(ce) + pad, W) + + crop = _read_bbox_slice(labels_da, r0, r1, c0, c1) + mask = np.isin(crop, cell_list) + if not mask.any(): + return {"merge_solidity": 0.0, "merge_compactness": 0.0} + + closed = binary_closing(mask, structure=morph_disk(close_radius)) + cc = cc_label(closed, connectivity=2) + if cc.max() == 0: + return {"merge_solidity": 0.0, "merge_compactness": 0.0} + sizes = np.bincount(cc.ravel()) + sizes[0] = 0 + biggest = int(sizes.argmax()) + region = regionprops((cc == biggest).astype(np.uint8))[0] + perimeter = max(region.perimeter, 1.0) + compactness = float(min(4 * np.pi * region.area / (perimeter * perimeter), 1.0)) + return {"merge_solidity": float(region.solidity), "merge_compactness": compactness} + + +def _pair_geometry_features(e: _CutEdge, c: _CutEdge, max_gap: float) -> dict[str, float] | None: + """Compute geometry-only features for a candidate pair, returning ``None`` + if the pair fails the basic facing/overlap/IoU filters. + """ + if c.normal_dir == e.normal_dir: + return None + # Facing: cell with +1 normal must sit at greater coord than cell with -1. + if (e.coord - c.coord) * e.normal_dir < -1e-6: + return None + overlap = _extent_overlap(e.extent, c.extent) + if overlap <= 0: + return None + union = e.length + c.length - overlap + iou = overlap / union if union > 0 else 0.0 + if iou < _CANDIDATE_MIN_IOU: + return None + gap = abs(e.coord - c.coord) + if gap > max_gap: + return None + endpoint_dist = abs(e.extent[0] - c.extent[0]) + abs(e.extent[1] - c.extent[1]) + max_len = max(e.length, c.length) + endpoint_match = max(0.0, 1.0 - endpoint_dist / max_len) if max_len > 0 else 0.0 + gap_score = 1.0 - gap / max_gap + return { + "iou": float(iou), + "endpoint_match": float(endpoint_match), + "gap_score": float(gap_score), + } + + +def _enumerate_pair_candidates( + edges: list[_CutEdge], + max_gap: float, +) -> list[tuple[_CutEdge, _CutEdge, dict[str, float]]]: + """Find all (e, c) pairs of facing cut edges with their geometry features. + + Returns one entry per surviving candidate. No selection / scoring yet. + """ + out: list[tuple[_CutEdge, _CutEdge, dict[str, float]]] = [] + by_axis: dict[str, list[_CutEdge]] = {"h": [], "v": []} + for e in edges: + by_axis[e.axis].append(e) + + for axis_edges in by_axis.values(): + axis_edges.sort(key=lambda e: e.coord) + coords = np.array([e.coord for e in axis_edges]) + for i, e in enumerate(axis_edges): + lo = int(np.searchsorted(coords, e.coord - max_gap, side="left")) + hi = int(np.searchsorted(coords, e.coord + max_gap, side="right")) + for j in range(lo, hi): + if j <= i: + continue # symmetry: emit each unordered pair once + c = axis_edges[j] + if c.cell_id == e.cell_id: + continue + feats = _pair_geometry_features(e, c, max_gap) + if feats is None: + continue + out.append((e, c, feats)) + return out + + +# --------------------------------------------------------------------------- +# Stage 4: scoring (arithmetic mean of geometry + shape features) +# --------------------------------------------------------------------------- + + +def _score_pair_features(features: dict[str, float]) -> float: + """Return the heuristic stitch score in [0, 1]. + + Arithmetic mean of the four features in :data:`_SCORE_FEATURES`. The + score is dataset-independent and not a calibrated probability -- users + pick ``min_confidence`` based on their false-merge tolerance. + """ + return float(sum(features[name] for name in _SCORE_FEATURES) / len(_SCORE_FEATURES)) + + +def _score_pairs( + candidates: list[tuple[_CutEdge, _CutEdge, dict[str, float]]], + labels_da: xr.DataArray | np.ndarray, + bboxes: dict[int, tuple[int, int, int, int]], + min_confidence: float, +) -> list[_StitchPair]: + """Compute shape features per candidate, score, and apply confidence filter. + + Greedy: each cell keeps its highest-confidence pairing per axis. Mid-stripe + cells (two parallel cuts) retain one pair per axis. + """ + scored: list[_StitchPair] = [] + for e, c, geom in candidates: + shape = _merge_shape_features(labels_da, [e.cell_id, c.cell_id], bboxes) + feats = {**geom, **shape} + confidence = _score_pair_features(feats) + if confidence < min_confidence: + continue + # Canonicalise so cell_a < cell_b for deterministic union-find. + if e.cell_id < c.cell_id: + ea, eb = e, c + else: + ea, eb = c, e + scored.append( + _StitchPair( + cell_a=ea.cell_id, + cell_b=eb.cell_id, + axis=e.axis, + confidence=confidence, + iou=feats["iou"], + endpoint_match=feats["endpoint_match"], + gap_score=feats["gap_score"], + merge_solidity=feats["merge_solidity"], + merge_compactness=feats["merge_compactness"], + edge_a=ea, + edge_b=eb, + ) + ) + + # Deduplicate to one entry per (cell_a, cell_b, axis), keeping max confidence. + by_pair: dict[tuple[int, int, str], _StitchPair] = {} + for p in scored: + k = (p.cell_a, p.cell_b, p.axis) + if k not in by_pair or by_pair[k].confidence < p.confidence: + by_pair[k] = p + return sorted(by_pair.values(), key=lambda p: (-p.confidence, p.cell_a, p.cell_b)) + + +# --------------------------------------------------------------------------- +# Stage 5: group assembly via union-find + validation +# --------------------------------------------------------------------------- + + +class _UnionFind: + """Union-find with smallest-label-as-root for deterministic group IDs.""" + + def __init__(self) -> None: + self.parent: dict[int, int] = {} + + def find(self, x: int) -> int: + self.parent.setdefault(x, x) + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra, rb = self.find(a), self.find(b) + if ra == rb: + return + if ra < rb: + self.parent[rb] = ra + else: + self.parent[ra] = rb + + +def _validate_corner_junction( + pairs_in_group: list[_StitchPair], + max_gap: float, +) -> bool: + """For 4-piece groups, require the cut edges' endpoints to converge near + a single junction point. + + Two perpendicular cuts (one h, one v) define the junction; the four + pieces' edges should share endpoints there. If the spread is greater + than ``max_gap``, the group geometry is implausible. + """ + h_edges = [p.edge_a for p in pairs_in_group if p.axis == "h"] + [p.edge_b for p in pairs_in_group if p.axis == "h"] + v_edges = [p.edge_a for p in pairs_in_group if p.axis == "v"] + [p.edge_b for p in pairs_in_group if p.axis == "v"] + if not h_edges or not v_edges: + return True # not a corner; nothing to validate + + # Junction y is roughly the mean h-edge coord; junction x is mean v-edge coord. + junction_y = float(np.mean([e.coord for e in h_edges if e is not None])) + junction_x = float(np.mean([e.coord for e in v_edges if e is not None])) + + # Each h edge's extent should reach to junction_x within max_gap; each v + # edge's extent should reach to junction_y within max_gap. + for e in h_edges: + if e is None: + continue + if min(abs(e.extent[0] - junction_x), abs(e.extent[1] - junction_x)) > max_gap: + return False + for e in v_edges: + if e is None: + continue + if min(abs(e.extent[0] - junction_y), abs(e.extent[1] - junction_y)) > max_gap: + return False + return True + + +def _assemble_groups( + pairs: list[_StitchPair], + candidate_ids: Iterable[int], + max_group_size: int, + max_gap: float, +) -> tuple[dict[int, int], dict[int, float]]: + """Build stitch groups via union-find with size + corner validation. + + Returns + ------- + groups + ``cell_id -> group_id`` (group_id == own cell_id for unstitched). + confidences + ``cell_id -> stitch_confidence`` -- min over pairwise confidences in + the cell's group; ``1.0`` for confirmed-solo (no surviving pair). + """ + uf = _UnionFind() + candidate_list = [int(c) for c in candidate_ids] + for cid in candidate_list: + uf.find(cid) + for p in pairs: + uf.union(p.cell_a, p.cell_b) + + # Collect group members + the pairs internal to each group. + members: dict[int, list[int]] = {} + for cid in candidate_list: + members.setdefault(uf.find(cid), []).append(cid) + pairs_by_group: dict[int, list[_StitchPair]] = {} + for p in pairs: + root = uf.find(p.cell_a) + pairs_by_group.setdefault(root, []).append(p) + + groups: dict[int, int] = {} + confidences: dict[int, float] = {} + + for root, mem in members.items(): + size = len(mem) + group_pairs = pairs_by_group.get(root, []) + + # Size cap: collapse oversized groups back to singletons. + if size > max_group_size: + for m in mem: + groups[m] = m + confidences[m] = 1.0 + continue + + # Corner validation for 4-piece groups. + if size == 4 and not _validate_corner_junction(group_pairs, max_gap): + for m in mem: + groups[m] = m + confidences[m] = 1.0 + continue + + if size == 1: + groups[mem[0]] = mem[0] + confidences[mem[0]] = 1.0 + continue + + # Group confidence = min over pairwise confidences (weakest link). + group_conf = float(min(p.confidence for p in group_pairs)) + for m in mem: + groups[m] = root + confidences[m] = group_conf + + return groups, confidences + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def stitch_tile_cuts( + sdata: sd.SpatialData, + labels_key: str, + qc_table_key: str | None = None, + min_confidence: float = 0.7, + max_gap: float = 3.0, + max_group_size: int = 4, + inplace: bool = True, +) -> ad.AnnData | None: + """Stitch tile-cut cells flagged by :func:`calculate_tiling_qc`. + + Reads ``is_outlier=True`` cells from the QC table, pairs facing cut edges + across tile boundaries, scores each pair via a transparent geometric + composite, and assembles high-confidence pairs into stitch groups via + union-find. + + The score per pair is the arithmetic mean of four features in [0, 1]: + ``iou`` (1-D extent overlap), ``endpoint_match`` (chord endpoints + coincide), ``merge_compactness`` (``4*pi*A / P^2`` of the closed union + mask), and ``merge_solidity`` (union area / convex hull area). No + coefficients are fitted or shipped; the formula is recorded in + ``.uns["tiling_stitch"]`` so users can audit and re-derive offline. + + The labels element is **never modified** -- only ``.obs`` columns are + written. Materialising a stitched labels element is opt-in via + :func:`squidpy.experimental.im.make_stitched_labels`. + + Parameters + ---------- + sdata + :class:`~spatialdata.SpatialData` with a labels element and a QC + table from :func:`calculate_tiling_qc`. + labels_key + Key in ``sdata.labels``. + qc_table_key + Key of the QC table. Defaults to ``"{labels_key}_qc"``. + min_confidence + Threshold on ``stitch_confidence`` (arithmetic mean of the four + geometric features). ``0.7`` (default) is a starting point; + raise it for stricter precision, lower for recall. Tune for + your data -- the score is heuristic, not a calibrated probability. + max_gap + Maximum perpendicular distance (px) between facing cut edges to be + considered a candidate pair. + max_group_size + Cap on group size; oversized groups (likely false merges) collapse + to singletons. + inplace + If ``True``, write back into ``sdata.tables[qc_table_key]``. + Otherwise return the modified AnnData. + + Returns + ------- + The QC :class:`~anndata.AnnData` with four new ``.obs`` columns when + ``inplace=False``, otherwise ``None``. + """ + if labels_key not in sdata.labels: + raise ValueError(f"Labels key '{labels_key}' not found in sdata.labels.") + if min_confidence < 0 or min_confidence > 1: + raise ValueError(f"min_confidence must be in [0, 1], got {min_confidence}.") + if max_gap < 0: + raise ValueError(f"max_gap must be non-negative, got {max_gap}.") + if max_group_size < 1: + raise ValueError(f"max_group_size must be >= 1, got {max_group_size}.") + + table_key = qc_table_key if qc_table_key is not None else f"{labels_key}_qc" + if table_key not in sdata.tables: + raise ValueError(f"QC table '{table_key}' not found. Run calculate_tiling_qc first.") + adata = sdata.tables[table_key].copy() + + if "is_outlier" not in adata.obs.columns: + raise ValueError(f"QC table '{table_key}' is missing 'is_outlier'; re-run calculate_tiling_qc.") + if "label_id" not in adata.obs.columns: + raise ValueError(f"QC table '{table_key}' is missing 'label_id'.") + + existing = [c for c in _STITCH_COLUMNS if c in adata.obs.columns] + if existing: + logg.warning(f"Overwriting existing stitch columns: {existing}.") + adata.obs.drop(columns=existing, inplace=True) + + # Resolve which labels DataArray was used at QC time (multi-scale aware). + qc_params = adata.uns.get("tiling_qc", {}) + scale = qc_params.get("scale") + labels_da = resolve_labels_array(sdata, labels_key, scale) + + label_ids = adata.obs["label_id"].astype(int).to_numpy() + is_outlier = adata.obs["is_outlier"].to_numpy(dtype=bool) + outlier_ids = label_ids[is_outlier].tolist() + + n_outliers = len(outlier_ids) + logg.info(f"Stitching {n_outliers} outlier cells (out of {len(label_ids)} total).") + + if n_outliers == 0: + logg.warning("No outliers flagged; nothing to stitch.") + groups: dict[int, int] = {} + confidences: dict[int, float] = {} + edges: list[_CutEdge] = [] + pairs: list[_StitchPair] = [] + else: + bboxes = _compute_outlier_bboxes(labels_da, outlier_ids) + missing = [lid for lid in outlier_ids if lid not in bboxes] + if missing: + logg.warning( + f"{len(missing)} outlier label_id(s) flagged in the QC table do not appear " + f"in '{labels_key}' (e.g. {missing[:5]}); they will not be stitched." + ) + edges = _extract_cut_edges(labels_da, outlier_ids, bboxes=bboxes) + candidates = _enumerate_pair_candidates(edges, max_gap=max_gap) + pairs = _score_pairs(candidates, labels_da, bboxes, min_confidence=min_confidence) + groups, confidences = _assemble_groups(pairs, outlier_ids, max_group_size=max_group_size, max_gap=max_gap) + + # Write .obs columns with three states distinguished by stitch_confidence: + # - non-outlier cell -> own label_id, False, 1, NaN (not evaluated) + # - outlier solo -> own label_id, False, 1, 1.0 (checked, no partner) + # - outlier stitched -> shared root, True, n, calibrated P + n = len(label_ids) + stitch_group_id = label_ids.copy() + is_stitched = np.zeros(n, dtype=bool) + n_pieces = np.ones(n, dtype=np.int32) + stitch_confidence = np.full(n, np.nan, dtype=np.float64) + + group_sizes: dict[int, int] = {} + if outlier_ids: + for root in groups.values(): + group_sizes[root] = group_sizes.get(root, 0) + 1 + + id_to_idx = {int(lid): i for i, lid in enumerate(label_ids)} + for cid, root in groups.items(): + i = id_to_idx[int(cid)] + stitch_group_id[i] = int(root) + size = group_sizes[root] + n_pieces[i] = size + is_stitched[i] = size > 1 + stitch_confidence[i] = float(confidences.get(cid, 1.0)) + + adata.obs["stitch_group_id"] = stitch_group_id + adata.obs["is_stitched"] = is_stitched + adata.obs["n_pieces"] = n_pieces + adata.obs["stitch_confidence"] = stitch_confidence + + n_groups = sum(1 for s in group_sizes.values() if s > 1) + n_stitched = int(is_stitched.sum()) + # Use string keys so the dict round-trips through zarr-backed .uns cleanly. + pieces_dist: dict[str, int] = {} + for s in group_sizes.values(): + if s > 1: + key = str(int(s)) + pieces_dist[key] = pieces_dist.get(key, 0) + 1 + + adata.uns[_METHOD_KEY] = { + "min_confidence": float(min_confidence), + "max_gap": float(max_gap), + "max_group_size": int(max_group_size), + "n_outliers": int(n_outliers), + "n_candidate_pairs": int(len(pairs)), + "n_stitched_groups": int(n_groups), + "n_stitched_cells": int(n_stitched), + "n_pieces_distribution": pieces_dist, + "score_features": list(_SCORE_FEATURES), + "score_formula": _SCORE_FORMULA, + } + + if not inplace: + return adata + sdata.tables[table_key] = adata + return None diff --git a/tests/_images/StitchVisual_seam_before_after.png b/tests/_images/StitchVisual_seam_before_after.png new file mode 100644 index 00000000..00bde1ad Binary files /dev/null and b/tests/_images/StitchVisual_seam_before_after.png differ diff --git a/tests/_images/StitchVisual_seam_join_labels.png b/tests/_images/StitchVisual_seam_join_labels.png new file mode 100644 index 00000000..b60e69b1 Binary files /dev/null and b/tests/_images/StitchVisual_seam_join_labels.png differ diff --git a/tests/experimental/test_stitched_labels.py b/tests/experimental/test_stitched_labels.py new file mode 100644 index 00000000..f041ae0b --- /dev/null +++ b/tests/experimental/test_stitched_labels.py @@ -0,0 +1,302 @@ +"""Tests for sq.experimental.im.make_stitched_labels.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import squidpy as sq + + +def _qc_and_stitch(sdata, **stitch_kwargs): + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", **stitch_kwargs) + + +class TestMakeStitchedLabels: + def test_creates_new_labels_element(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + assert "labels_stitched" not in sdata.labels + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + assert "labels_stitched" in sdata.labels + + def test_original_labels_unchanged(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + original_arr = np.asarray(sdata.labels["labels"].values).copy() + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + after_arr = np.asarray(sdata.labels["labels"].values) + np.testing.assert_array_equal(original_arr, after_arr) + + def test_remap_unifies_stitched_pieces(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + old_arr = np.asarray(sdata.labels["labels"].values) + + # Pick one stitched group with >= 2 pieces + gid = int(stitched["stitch_group_id"].iloc[0]) + pieces = stitched.loc[stitched["stitch_group_id"] == gid, "label_id"].astype(int).tolist() + assert len(pieces) >= 2 + + # All original pixels of those pieces should now carry the group id + for piece_id in pieces: + mask = old_arr == piece_id + assert mask.any() + assert (new_arr[mask] == gid).all(), f"piece {piece_id} not remapped to {gid}" + + def test_unstitched_pieces_keep_their_id(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + old_arr = np.asarray(sdata.labels["labels"].values) + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + # Pixels with label 0 (background) stay 0 + bg = old_arr == 0 + assert (new_arr[bg] == 0).all() + # Cells whose group_id == label_id are unchanged in the remap + adata = sdata.tables["labels_qc"] + unstitched = adata.obs[adata.obs["stitch_group_id"].astype(int) == adata.obs["label_id"].astype(int)] + # Spot-check the first 5 unstitched + for lid in unstitched["label_id"].astype(int).iloc[:5]: + mask = old_arr == lid + if mask.any(): + assert (new_arr[mask] == lid).all() + + def test_collapsed_table_one_row_per_group(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + assert "labels_stitched_table" in sdata.tables + agg = sdata.tables["labels_stitched_table"] + adata = sdata.tables["labels_qc"] + # Output has one row per unique stitch_group_id (unstitched cells stay + # as singleton groups, stitched groups collapse to one row). + n_groups = adata.obs["stitch_group_id"].nunique() + assert agg.n_obs == n_groups + for col in ("label_id", "stitch_group_id", "n_pieces", "is_stitched", "stitch_confidence"): + assert col in agg.obs.columns + + def test_collapsed_table_includes_unstitched_cells(self, sdata_tile_boundary): + """Both stitched (collapsed) and unstitched (passthrough) rows present.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + # At least some unstitched cells should be in the output. + assert (~agg.obs["is_stitched"].astype(bool)).sum() > 0, "expected unstitched rows" + # The is_stitched column flags which rows are collapsed groups. + if agg.obs["is_stitched"].astype(bool).sum() > 0: + assert (agg.obs.loc[agg.obs["is_stitched"].astype(bool), "n_pieces"] >= 2).all() + + def test_merge_strategy_sum_aggregates_numeric_columns(self, sdata_tile_boundary): + """For a stitched group, a synthetic numeric column should sum across pieces.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_area"] = 100.0 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="sum") + agg = sdata.tables["labels_stitched_table"] + stitched = agg.obs[agg.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched groups in this realisation") + # Each stitched group has n_pieces members each contributing 100. + np.testing.assert_array_equal( + stitched["fake_area"].to_numpy(), + stitched["n_pieces"].to_numpy() * 100.0, + ) + + def test_merge_strategy_mean_aggregates_numeric_columns(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_intensity"] = 42.0 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="mean") + agg = sdata.tables["labels_stitched_table"] + stitched = agg.obs[agg.obs["is_stitched"].astype(bool)] + if len(stitched) > 0: + np.testing.assert_allclose(stitched["fake_intensity"].to_numpy(), 42.0) + + def test_merge_strategy_callable(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + adata.obs["fake_count"] = 1 + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels( + sdata, + labels_key="labels", + write_table=True, + merge_strategy=lambda s: len(s), + ) + agg = sdata.tables["labels_stitched_table"] + # Callable returns len of group, so fake_count == n_pieces post-merge. + np.testing.assert_array_equal( + agg.obs["fake_count"].to_numpy(), + agg.obs["n_pieces"].to_numpy(), + ) + + def test_invalid_merge_strategy_raises(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + with pytest.raises(ValueError, match="Unknown merge_strategy"): + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", merge_strategy="bogus") + + def test_group_invariant_columns_take_first(self, sdata_tile_boundary): + """is_stitched, n_pieces, stitch_confidence are not affected by sum strategy.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata_orig = sdata.tables["labels_qc"] + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, merge_strategy="sum") + agg = sdata.tables["labels_stitched_table"] + # n_pieces should be in {1, 2, 3, 4} -- if "sum" had been applied to it, + # a 4-piece group would show n_pieces = 16. + assert (agg.obs["n_pieces"].astype(int) <= 4).all() + # Members of a stitch group share is_stitched value; collapsed row should match. + stitched = adata_orig.obs[adata_orig.obs["is_stitched"].astype(bool)] + for gid in stitched["stitch_group_id"].astype(int).unique(): + row = agg.obs[agg.obs["stitch_group_id"].astype(int) == gid] + assert len(row) == 1 + assert bool(row["is_stitched"].iloc[0]) is True + + def test_aggregated_table_preserves_qc_columns_and_uns(self, sdata_tile_boundary): + """The reduced table must keep the QC table's obs columns and uns + instead of constructing a fresh AnnData from scratch.""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + # User adds a custom obs column to simulate downstream annotation. + adata.obs["my_custom_flag"] = True + sdata.tables["labels_qc"] = adata + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + # Original QC obs columns survive + for col in ( + "max_straight_edge_ratio", + "cardinal_alignment_score", + "cut_score", + "smoothed_cut_score", + "is_outlier", + "nhood_outlier_fraction", + "centroid_y", + "centroid_x", + "my_custom_flag", + ): + assert col in agg.obs.columns, f"missing preserved column: {col}" + # Uns surfaces survive (tiling_qc params, tiling_stitch params) + assert "tiling_qc" in agg.uns + assert "tiling_stitch" in agg.uns + # spatialdata_attrs now points at the stitched labels element + attrs = agg.uns["spatialdata_attrs"] + assert attrs["region"] == "labels_stitched" + assert attrs["instance_key"] == "label_id" + + def test_aggregated_table_label_id_matches_new_element_ids(self, sdata_tile_boundary): + """label_id values in the table must equal the IDs in the new labels + element (the stitch_group_id values become the new instance keys).""" + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True) + agg = sdata.tables["labels_stitched_table"] + new_arr = np.asarray(sdata.labels["labels_stitched"].values) + unique_in_image = set(np.unique(new_arr).tolist()) - {0} + unique_in_table = set(agg.obs["label_id"].astype(int).tolist()) + # Every row in the table must reference an existing instance in the labels element. + assert unique_in_table.issubset(unique_in_image), f"orphan rows: {unique_in_table - unique_in_image}" + + def test_errors_when_stitch_not_run(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200) + with pytest.raises(ValueError, match="stitch_group_id"): + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + + def test_errors_on_missing_labels_key(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + with pytest.raises(ValueError, match="not found"): + sq.experimental.im.make_stitched_labels(sdata, labels_key="bogus") + + def test_idempotent(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + first = np.asarray(sdata.labels["labels_stitched"].values).copy() + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + second = np.asarray(sdata.labels["labels_stitched"].values) + np.testing.assert_array_equal(first, second) + + def test_join_labels_false_keeps_multi_component(self, sdata_tile_boundary): + from skimage.measure import label as cc_label + + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=False) + arr = np.asarray(sdata.labels["labels_stitched"].values) + # At least one stitched group should have >1 connected component + # (the unjoined behaviour leaves the cut stripe as background). + any_multi = False + for gid in stitched["stitch_group_id"].astype(int).unique()[:5]: + mask = arr == gid + if mask.any(): + ncc = int(cc_label(mask).max()) + if ncc > 1: + any_multi = True + break + assert any_multi, "expected at least one multi-component stitched group with join_labels=False" + + def test_join_labels_true_unifies_components(self, sdata_tile_boundary): + from skimage.measure import label as cc_label + + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + if len(stitched) == 0: + pytest.skip("no stitched cells in this fixture realisation") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=True) + arr = np.asarray(sdata.labels["labels_stitched"].values) + for gid in stitched["stitch_group_id"].astype(int).unique(): + mask = arr == gid + if not mask.any(): + continue + ncc = int(cc_label(mask).max()) + assert ncc == 1, f"group {gid} still has {ncc} components after join_labels=True" + + def test_join_labels_does_not_overwrite_other_cells(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _qc_and_stitch(sdata, min_confidence=0.5) + adata = sdata.tables["labels_qc"] + # Snapshot every non-stitched cell's pixel set before joining, then + # confirm none of those pixels changed identity afterwards. + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=False) + before_arr = np.asarray(sdata.labels["labels_stitched"].values).copy() + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", join_labels=True) + after_arr = np.asarray(sdata.labels["labels_stitched"].values) + non_stitched_gids = ( + adata.obs.loc[~adata.obs["is_stitched"].astype(bool), "stitch_group_id"].astype(int).unique() + ) + for gid in non_stitched_gids[:20]: + before_mask = before_arr == gid + if not before_mask.any(): + continue + # Non-stitched cells must keep all their original pixels. + assert (after_arr[before_mask] == gid).all(), f"non-stitched cell {gid} was overwritten" diff --git a/tests/experimental/test_tiling_stitch.py b/tests/experimental/test_tiling_stitch.py new file mode 100644 index 00000000..5d938b99 --- /dev/null +++ b/tests/experimental/test_tiling_stitch.py @@ -0,0 +1,403 @@ +"""Tests for sq.experimental.tl.stitch_tile_cuts.""" + +from __future__ import annotations + +import dask.array as da +import matplotlib.pyplot as plt +import numpy as np +import pytest +import xarray as xr +from spatialdata import SpatialData +from spatialdata.models import Labels2DModel + +import squidpy as sq +from tests.conftest import PlotTester, PlotTesterMeta + + +def _run_qc_and_stitch(sdata, **stitch_kwargs): + """Run QC + stitch on the fixture sdata; return the resulting AnnData.""" + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", **stitch_kwargs) + return sdata.tables["labels_qc"] + + +# --------------------------------------------------------------------------- +# Smoke + column contract +# --------------------------------------------------------------------------- + + +class TestStitchObsContract: + """The 4 .obs columns and the NaN-vs-1.0 confidence convention.""" + + def test_columns_present(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata) + for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"): + assert col in adata.obs.columns, f"missing {col}" + + def test_non_outliers_have_nan_confidence(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata) + non_outliers = ~adata.obs["is_outlier"].astype(bool) + # At least some non-outliers should exist + assert non_outliers.sum() > 0 + # All non-outliers must have NaN confidence + assert adata.obs.loc[non_outliers, "stitch_confidence"].isna().all() + # And they keep their own label_id as group + assert (adata.obs.loc[non_outliers, "stitch_group_id"] == adata.obs.loc[non_outliers, "label_id"]).all() + # And n_pieces == 1, is_stitched False + assert (adata.obs.loc[non_outliers, "n_pieces"] == 1).all() + assert (~adata.obs.loc[non_outliers, "is_stitched"].astype(bool)).all() + + def test_solo_outliers_have_1p0_confidence(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata) + solo_outliers = adata.obs["is_outlier"].astype(bool) & ~adata.obs["is_stitched"].astype(bool) + # If any solo outliers exist, they get confidence 1.0 (checked, no partner) + if solo_outliers.sum() > 0: + assert (adata.obs.loc[solo_outliers, "stitch_confidence"] == 1.0).all() + + def test_stitched_have_calibrated_confidence(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.5) + stitched = adata.obs["is_stitched"].astype(bool) + if stitched.sum() > 0: + confs = adata.obs.loc[stitched, "stitch_confidence"] + # all in [min_confidence, 1.0] + assert (confs >= 0.5).all() + assert (confs <= 1.0).all() + # n_pieces between 2 and max_group_size (default 4) + sizes = adata.obs.loc[stitched, "n_pieces"] + assert (sizes >= 2).all() + assert (sizes <= 4).all() + + def test_group_id_shared_within_group(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.5) + stitched = adata.obs[adata.obs["is_stitched"].astype(bool)] + for gid, members in stitched.groupby("stitch_group_id"): + n = members["n_pieces"].iloc[0] + assert len(members) == n, f"group {gid}: {len(members)} rows but n_pieces={n}" + + +# --------------------------------------------------------------------------- +# Audit trail +# --------------------------------------------------------------------------- + + +class TestUnsMetadata: + def test_uns_records_params_and_score_formula(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.7, max_gap=4.0, max_group_size=4) + assert "tiling_stitch" in adata.uns + meta = adata.uns["tiling_stitch"] + assert meta["min_confidence"] == 0.7 + assert meta["max_gap"] == 4.0 + assert meta["max_group_size"] == 4 + # No fitted coefficients -- transparent formula instead. + assert "model_coefficients" not in meta + assert "model_intercept" not in meta + assert "score_formula" in meta + assert set(meta["score_features"]) == { + "iou", + "endpoint_match", + "merge_compactness", + "merge_solidity", + } + + +# --------------------------------------------------------------------------- +# Behaviour vs ground truth +# --------------------------------------------------------------------------- + + +class TestRecoveryVsGroundTruth: + def test_stitches_some_cuts(self, sdata_tile_boundary): + sdata, gt = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.5) + # Of the cells the fixture marks as cut pieces, some should end up stitched. + cut_mask = adata.obs["label_id"].isin(gt.cut_cell_ids) + n_cut_in_stitched = (cut_mask & adata.obs["is_stitched"].astype(bool)).sum() + assert n_cut_in_stitched > 0, "expected at least some cut pieces to be stitched" + + def test_no_intact_cells_get_stitched_at_high_threshold(self, sdata_tile_boundary): + sdata, gt = sdata_tile_boundary + adata = _run_qc_and_stitch(sdata, min_confidence=0.9) + # At threshold 0.9 (high precision), intact cells should not falsely merge. + intact_mask = adata.obs["label_id"].isin(gt.intact_cell_ids) + # Allow up to a handful of false merges given the dense ellipse fixture. + n_false = (intact_mask & adata.obs["is_stitched"].astype(bool)).sum() + assert n_false <= 5, f"too many intact cells flagged stitched: {n_false}" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + def test_missing_qc_table_errors(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + with pytest.raises(ValueError, match="QC table"): + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels") + + def test_missing_labels_key_errors(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + with pytest.raises(ValueError, match="not found in sdata.labels"): + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="bogus") + + def test_invalid_min_confidence(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + with pytest.raises(ValueError, match="min_confidence"): + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", min_confidence=1.5) + + +# --------------------------------------------------------------------------- +# Idempotency + inplace +# --------------------------------------------------------------------------- + + +class TestIdempotencyAndInplace: + def test_rerun_overwrites_with_warning(self, sdata_tile_boundary, caplog): + sdata, _ = sdata_tile_boundary + _run_qc_and_stitch(sdata) + # Second run should warn about overwrite, leave column count unchanged. + n_cols_before = len(sdata.tables["labels_qc"].obs.columns) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels") + n_cols_after = len(sdata.tables["labels_qc"].obs.columns) + assert n_cols_before == n_cols_after + + def test_inplace_false_returns_adata_without_writing(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200) + n_cols_before = len(sdata.tables["labels_qc"].obs.columns) + result = sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", inplace=False) + n_cols_after = len(sdata.tables["labels_qc"].obs.columns) + assert result is not None + assert "stitch_group_id" in result.obs.columns + # In-place table unchanged + assert n_cols_before == n_cols_after + + +# --------------------------------------------------------------------------- +# QC re-run drops stitch columns +# --------------------------------------------------------------------------- + + +class TestResolveLabelsArray: + """resolve_labels_array unit-tests; verifies the multi-scale branch.""" + + def test_single_scale_passthrough(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + from squidpy.experimental.tl._tiling_qc import resolve_labels_array + + out = resolve_labels_array(sdata, "labels", scale=None) + assert isinstance(out, xr.DataArray) + + def test_multiscale_requires_scale(self): + from squidpy.experimental.tl._tiling_qc import resolve_labels_array + + labels = np.zeros((64, 64), dtype=np.int32) + labels[10:20, 10:20] = 1 + labels_xr = xr.DataArray(da.from_array(labels, chunks=(32, 32)), dims=("y", "x")) + sdata = SpatialData(labels={"labels_ms": Labels2DModel.parse(labels_xr, scale_factors=[2])}) + # No scale -> error. + with pytest.raises(ValueError, match="multi-scale"): + resolve_labels_array(sdata, "labels_ms", scale=None) + # With scale -> returns the resolved DataArray. + out = resolve_labels_array(sdata, "labels_ms", scale="scale0") + assert isinstance(out, xr.DataArray) + + +class TestMultiScaleEndToEnd: + """QC -> stitch on a multiscale labels element.""" + + def _make_sdata(self) -> SpatialData: + from tests.experimental.conftest import make_tile_boundary_sdata + + sdata, _ = make_tile_boundary_sdata() + # Wrap the existing single-scale labels element as multiscale. + labels_arr = np.asarray(sdata.labels["labels"].values) + labels_xr = xr.DataArray(da.from_array(labels_arr, chunks=(200, 200)), dims=("y", "x")) + ms = Labels2DModel.parse(labels_xr, scale_factors=[2]) + sdata = SpatialData( + images={"image": sdata.images["image"]}, + labels={"labels": ms}, + ) + return sdata + + def test_stitch_runs_on_multiscale(self): + sdata = self._make_sdata() + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + scale="scale0", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels") + adata = sdata.tables["labels_qc"] + for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"): + assert col in adata.obs.columns + + def test_make_stitched_labels_runs_on_multiscale(self): + sdata = self._make_sdata() + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + scale="scale0", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels") + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + assert "labels_stitched" in sdata.labels + + +class TestInplaceFalseMakeStitchedLabels: + def test_inplace_false_returns_dict_without_writing(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _run_qc_and_stitch(sdata) + assert "labels_stitched" not in sdata.labels + assert "labels_stitched_table" not in sdata.tables + result = sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=True, inplace=False) + # Nothing written + assert "labels_stitched" not in sdata.labels + assert "labels_stitched_table" not in sdata.tables + # Result has both objects + assert isinstance(result, dict) + assert result["labels"] is not None + assert result["table"] is not None + + def test_inplace_false_no_table_when_disabled(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + _run_qc_and_stitch(sdata) + result = sq.experimental.im.make_stitched_labels(sdata, labels_key="labels", write_table=False, inplace=False) + assert result["labels"] is not None + assert result["table"] is None + + +class TestQCRerunDropsStitch: + def test_qc_rerun_removes_stitch_columns(self, sdata_tile_boundary, caplog): + sdata, _ = sdata_tile_boundary + _run_qc_and_stitch(sdata) + # Re-running QC should produce a table without stitch columns. + sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200) + adata = sdata.tables["labels_qc"] + for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"): + assert col not in adata.obs.columns + + +# --------------------------------------------------------------------------- +# Visual: before/after stitching, zoomed on a tile seam +# --------------------------------------------------------------------------- + + +def _label_to_rgb(arr: np.ndarray, seed: int = 0) -> np.ndarray: + """Render a label image with a stable random colour per label, background black.""" + rng = np.random.default_rng(seed) + n = int(arr.max()) + 1 + colors = rng.random((n, 3)) + colors[0] = 0.0 # background + return colors[arr] + + +class TestStitchVisual(PlotTester, metaclass=PlotTesterMeta): + """Visual baselines comparing labels before/after stitch_tile_cuts + + make_stitched_labels, zoomed in on a tile seam. Baselines live in + ``tests/_images/StitchVisual_*.png`` and are downloaded from CI artifacts; + they are not generated locally. + """ + + # Hardcoded crop window around the seam at y=200 (fixture tile_borders_y[0]). + _ZOOM = (150, 250, 250, 350) # (y0, y1, x0, x1) + _SEAM_Y = 200 + + def test_plot_seam_before_after(self, sdata_tile_boundary): + sdata, _ = sdata_tile_boundary + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", min_confidence=0.5) + sq.experimental.im.make_stitched_labels(sdata, labels_key="labels") + + y0, y1, x0, x1 = self._ZOOM + before = np.asarray(sdata.labels["labels"].values)[y0:y1, x0:x1] + after = np.asarray(sdata.labels["labels_stitched"].values)[y0:y1, x0:x1] + + # Same colour seed for both panels: unstitched cells look identical, + # only stitched pieces change colour. + before_rgb = _label_to_rgb(before, seed=0) + after_rgb = _label_to_rgb(after, seed=0) + + fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + for ax, rgb, title in zip(axes, [before_rgb, after_rgb], ["Before", "After"], strict=True): + ax.imshow(rgb, interpolation="nearest") + ax.axhline(self._SEAM_Y - y0, color="white", linestyle="--", linewidth=1.0) + ax.set_title(title) + ax.set_xticks([]) + ax.set_yticks([]) + fig.tight_layout() + + def test_plot_seam_join_labels(self, sdata_tile_boundary): + """Side-by-side: After (join_labels=False) vs After (join_labels=True). + + join_labels=False leaves stitched cells as multi-component regions + (the cut stripe stays at 0); join_labels=True morphologically closes + the gap so the stitched cell is a single connected component. + """ + sdata, _ = sdata_tile_boundary + sq.experimental.tl.calculate_tiling_qc( + sdata, + labels_key="labels", + tile_size=200, + nmads_cut=1.0, + nmads_smoothed=1.5, + ) + sq.experimental.tl.stitch_tile_cuts(sdata, labels_key="labels", min_confidence=0.5) + + # Run twice with different join settings into separate output keys. + sq.experimental.im.make_stitched_labels( + sdata, + labels_key="labels", + labels_key_added="labels_stitched_split", + write_table=False, + join_labels=False, + ) + sq.experimental.im.make_stitched_labels( + sdata, + labels_key="labels", + labels_key_added="labels_stitched_joined", + write_table=False, + join_labels=True, + ) + + y0, y1, x0, x1 = self._ZOOM + split = np.asarray(sdata.labels["labels_stitched_split"].values)[y0:y1, x0:x1] + joined = np.asarray(sdata.labels["labels_stitched_joined"].values)[y0:y1, x0:x1] + + split_rgb = _label_to_rgb(split, seed=0) + joined_rgb = _label_to_rgb(joined, seed=0) + + fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + for ax, rgb, title in zip( + axes, [split_rgb, joined_rgb], ["join_labels=False", "join_labels=True"], strict=True + ): + ax.imshow(rgb, interpolation="nearest") + ax.axhline(self._SEAM_Y - y0, color="white", linestyle="--", linewidth=1.0) + ax.set_title(title) + ax.set_xticks([]) + ax.set_yticks([]) + fig.tight_layout()