diff --git a/pyproject.toml b/pyproject.toml
index 097a80e..972d3b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,9 @@ dependencies = [
     "opencv-python",
     "pandas",
     "polars",
+    "pqdm",
     "pyarrow",
+    "rtree",
     "scanpy",
     "scipy",
     "shapely",
@@ -29,6 +31,28 @@ dependencies = [
     "scikit-learn",
     "tifffile",
     "torch_geometric",
+    "zarr",
+]
+
+[project.optional-dependencies]
+spatialdata = [
+    "spatialdata>=0.7.2",
+    "spatialdata-io>=0.6.0",
+]
+
+spatialdata-io = [
+    "spatialdata-io>=0.6.0",
+]
+
+spatialdata-all = [
+    "spatialdata>=0.7.2",
+    "spatialdata-io>=0.6.0",
+    "sopa>=2.0.0",
+]
+
+plot = [
+    "matplotlib>=3.7",
+    "uniplot>=0.10.0",
 ]
 
 [build-system]
@@ -39,4 +63,4 @@ build-backend = "hatchling.build"
 packages = ["src/segger"]
 
 [project.scripts]
-segger = "segger.cli.main:app"
\ No newline at end of file
+segger = "segger.cli.main:app"
diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py
index dff7e05..6a1e0aa 100644
--- a/src/segger/cli/segment.py
+++ b/src/segger/cli/segment.py
@@ -58,6 +58,21 @@
     help="Related to loss function parameters.",
     sort_key=7,
 )
+group_quality = Group(
+    name="Quality Filtering",
+    help="Related to transcript quality filtering.",
+    sort_key=8,
+)
+group_3d = Group(
+    name="3D Support",
+    help="Related to 3D coordinate handling.",
+    sort_key=9,
+)
+
+def _resolve_use_3d_flag(use_3d: Literal["auto", "true", "false"]) -> bool | str:
+    if use_3d == "auto":
+        return "auto"
+    return use_3d == "true"
 
 app_segment = App(name="segment", help="Run cell segmentation on spatial transcriptomics data.")
 
@@ -293,16 +308,53 @@ def segment(
         "save_anndata",
         group=group_io,
     )] = registry.get_default("save_anndata"),
+
+    save_spatialdata: Annotated[bool, registry.get_parameter(
+        "save_spatialdata",
+        group=group_io,  # might change
+    )] = registry.get_default("save_spatialdata"),
+
+    boundary_method: Annotated[
+        Literal["convex_hull", "delaunay", "skip"],
+        registry.get_parameter(
+            "boundary_method",
+            group=group_io,  # might change
+        )] = registry.get_default("boundary_method"),
+
     debug: Annotated[bool, Parameter(
         help="Whether to save additional debug information (trainer, predictions).",
     )] = "none",
+
+    # Quality filtering
+    min_qv: Annotated[float | None, Parameter(
+        help="Minimum transcript quality threshold. Set to 0 to disable.",
+        validator=validators.Number(gte=0),
+        group=group_quality,
+    )] = 20.0,
+
+    # 3D support
+    use_3d: Annotated[
+        Literal["auto", "true", "false"],
+        Parameter(
+            help="Use 3D coordinates for graph construction (default: 'false').",
+            group=group_3d,
+        ),
+    ] = "false",
 ):
     """Run cell segmentation on spatial transcriptomics data."""
     # Setup logger and debug directory
     logger = logging.getLogger(__name__)
 
+    use_3d_value = _resolve_use_3d_flag(use_3d)
+
+    output_directory = Path(output_directory)
+    if output_directory.exists() and not output_directory.is_dir():
+        raise ValueError(
+            f"Output path exists and is not a directory: {output_directory}"
+        )
+    output_directory.mkdir(parents=True, exist_ok=True)
+
     # Remove SLURM environment autodetect
     from lightning.pytorch.plugins.environments import SLURMEnvironment
     SLURMEnvironment.detect = lambda: False
@@ -328,6 +380,8 @@ def segment(
         tiling_margin_prediction=tiling_margin_prediction,
         tiling_nodes_per_tile=max_nodes_per_tile,
         edges_per_batch=max_edges_per_batch,
+        use_3d=use_3d_value,
+        min_qv=min_qv,
     )
 
     # Setup Lightning Model
@@ -364,8 +418,11 @@ def segment(
     csvlogger = CSVLogger(output_directory)
 
     writer = ISTSegmentationWriter(
+        input_directory,
         output_directory,
         save_anndata=save_anndata,
+        save_spatialdata=save_spatialdata,
+        boundary_method=boundary_method,
         debug=debug,
     )
     trainer = Trainer(
diff --git a/src/segger/data/data_module.py b/src/segger/data/data_module.py
index efcdde2..512c699 100644
--- a/src/segger/data/data_module.py
+++ b/src/segger/data/data_module.py
@@ -5,11 +5,12 @@
 from lightning.pytorch import LightningDataModule
 from torchvision.transforms import Compose
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Optional
 from pathlib import Path
 import polars as pl
 import torch
 import gc
+import os
 import numpy as np
 
 from .tile_dataset import (
@@ -143,6 +144,8 @@ class ISTDataModule(LightningDataModule):
     prediction_graph_mode: Literal["nucleus", "cell", "uniform"] = "cell"
     prediction_graph_max_k: int = 3
     prediction_graph_buffer_ratio: float = 0.05
+    use_3d: bool | Literal["auto"] = False
+    min_qv: Optional[float] = 20.0
     tiling_mode: Literal["adaptive", "square"] = "adaptive"  # TODO: Remove (benchmarking only)
     tiling_margin_training: float = 20.
     tiling_margin_prediction: float = 20.
@@ -166,11 +169,54 @@ def load(self):
         tx_fields = StandardTranscriptFields()
         bd_fields = StandardBoundaryFields()
 
-        # Load standardized IST data
         self.logger.debug(f"Loading standardized IST data from {self.input_directory}...")
-        pp = get_preprocessor(self.input_directory)
-        tx = self.tx = pp.transcripts
-        bd = self.bd = pp.boundaries
+        # Load standardized IST data (raw platform directory or SpatialData .zarr)
+        input_path = Path(self.input_directory)
+        tx = None
+        bd = None
+
+        try:
+            from ..io.spatialdata_loader import (
+                is_spatialdata_path,
+                load_from_spatialdata,
+            )
+            has_spatialdata_loader = True
+        except Exception:
+            has_spatialdata_loader = False
+
+        if has_spatialdata_loader and is_spatialdata_path(input_path):
+            tx_lf, bd = load_from_spatialdata(
+                input_path,
+                boundary_type="all",
+                normalize=True,
+            )
+            tx = tx_lf.collect() if isinstance(tx_lf, pl.LazyFrame) else tx_lf
+
+            # Keep behavior consistent with raw Xenium filtering when quality exists.
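+            # (Assumes the standardized quality column from
+            # StandardTranscriptFields, "qv" by default; SpatialData inputs
+            # without it are left unfiltered.)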
+            quality_col = getattr(tx_fields, "quality", "qv")
+            if (
+                self.min_qv is not None
+                and self.min_qv > 0
+                and quality_col in tx.columns
+            ):
+                tx = tx.filter(pl.col(quality_col) >= self.min_qv)
+        else:
+            pp = get_preprocessor(
+                self.input_directory,
+                min_qv=self.min_qv,
+                include_z=(self.use_3d is not False),
+            )
+            tx = pp.transcripts
+            bd = pp.boundaries
+
+        self.tx = tx
+        self.bd = bd
+
+        if bd is None or len(bd) == 0:
+            raise ValueError(
+                "No boundary shapes found in input data. "
+                "Segger requires cell/nucleus polygons in raw input or SpatialData shapes."
+            )
 
         # Mask transcripts to reference segmentation
         if self.segmentation_graph_mode == "nucleus":
@@ -187,8 +233,16 @@ def load(self):
             f"Unrecognized segmentation graph mode: "
             f"'{self.segmentation_graph_mode}'."
         )
-        tx_mask = pl.col(tx_fields.compartment).is_in(compartments)
-        bd_mask = bd[bd_fields.boundary_type] == boundary_type
+
+        if tx_fields.compartment in tx.columns:
+            tx_mask = pl.col(tx_fields.compartment).is_in(compartments)
+        else:
+            tx_mask = pl.col(tx_fields.cell_id).is_not_null()
+
+        if bd_fields.boundary_type in bd.columns:
+            bd_mask = bd[bd_fields.boundary_type] == boundary_type
+        else:
+            bd_mask = np.ones(len(bd), dtype=bool)
 
         # Generate reference AnnData
         self.logger.debug("Generating reference AnnData object...")
@@ -222,6 +276,7 @@ def load(self):
             prediction_graph_mode=self.prediction_graph_mode,
             prediction_graph_max_k=self.prediction_graph_max_k,
             prediction_graph_buffer_ratio=self.prediction_graph_buffer_ratio,
+            use_3d=self.use_3d,
         )
 
         # Tile graph dataset
diff --git a/src/segger/data/utils/anndata.py b/src/segger/data/utils/anndata.py
index 93db4c4..725f1ef 100644
--- a/src/segger/data/utils/anndata.py
+++ b/src/segger/data/utils/anndata.py
@@ -159,7 +159,7 @@ def setup_anndata(
         ad.obs
         .join(
             (
-                boundaries
+                boundaries.drop_duplicates(subset=bd_fields.id)  # some data oddly has duplicate boundary entries for the same cell ID
                 .reset_index(names=bd_fields.index)
                 .set_index(bd_fields.id, verify_integrity=True)
                 .get(bd_fields.index)
@@ -195,7 +195,10 @@ def setup_anndata(
     # Build gene embedding on filtered dataset
     C = np.corrcoef(ad[ad.obs['filtered']].layers['norm'].todense().T)
     C = np.nan_to_num(C, 0, posinf=True, neginf=True)
-    model = sklearn.decomposition.PCA(n_components=cells_embedding_size)
+    model = sklearn.decomposition.PCA(n_components=min(cells_embedding_size, ad.var.shape[0]))
+    if ad.var.shape[0] < cells_embedding_size:
+        import warnings
+        warnings.warn('cells_embedding_size exceeds the number of genes; clamping the embedding to the gene count.')
     ad.varm['X_corr'] = model.fit_transform(C)
 
     # Build PCs on filtered cells and project all cells
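Reviewer note: the PCA clamp above can be sanity-checked with a standalone sketch like the following (sizes illustrative; not part of the patch):

    import numpy as np
    import sklearn.decomposition

    n_genes = 10                   # stand-in for ad.var.shape[0]
    cells_embedding_size = 50      # requested embedding size
    n_components = min(cells_embedding_size, n_genes)
    model = sklearn.decomposition.PCA(n_components=n_components)
    C = np.corrcoef(np.random.rand(25, n_genes).T)   # (n_genes, n_genes)
    assert model.fit_transform(C).shape == (n_genes, n_components)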
diff --git a/src/segger/data/utils/heterodata.py b/src/segger/data/utils/heterodata.py
index 40f43d9..9254609 100644
--- a/src/segger/data/utils/heterodata.py
+++ b/src/segger/data/utils/heterodata.py
@@ -24,6 +24,7 @@ def setup_heterodata(
     prediction_graph_mode: Literal["nucleus", "cell", "uniform"],
     prediction_graph_max_k: int,
     prediction_graph_buffer_ratio: float,
+    use_3d: bool | Literal["auto"] = False,
     cells_embedding_key: str = 'X_pca',
     cells_clusters_column: str = 'phenograph_cluster',
     cells_encoding_column: str = 'cell_encoding',
@@ -44,6 +45,11 @@
         tx_fields.cell_cluster,
         tx_fields.gene_cluster,
     ]
+
+    transcripts = transcripts.with_columns(
+        pl.col(tx_fields.feature).cast(pl.Utf8)
+    )
+
 
     # Update transcripts with fields for training
     transcripts = (
@@ -55,9 +61,14 @@
             pl.from_pandas(
                 adata.var[[genes_encoding_column, genes_clusters_column]],
                 include_index=True
-            ),
+            ).rename({
+                pl.from_pandas(
+                    adata.var[[genes_encoding_column, genes_clusters_column]],
+                    include_index=True
+                ).columns[0]: tx_fields.feature
+            }),
             left_on=tx_fields.feature,
-            right_on=adata.var.index.name if adata.var.index.name else 'None',
+            right_on=tx_fields.feature,
         )
         .rename(
             {
@@ -135,6 +146,7 @@
         transcripts,
         max_k=transcripts_graph_max_k,
         max_dist=transcripts_graph_max_dist,
+        use_3d=use_3d,
     )
 
     # Reference segmentation graph
@@ -150,6 +162,7 @@
         max_k=prediction_graph_max_k,
         buffer_ratio=prediction_graph_buffer_ratio,
         mode=prediction_graph_mode,
+        use_3d=use_3d if prediction_graph_mode == "uniform" else False,
     )
 
     return data
diff --git a/src/segger/data/utils/neighbors.py b/src/segger/data/utils/neighbors.py
index ce7ab3e..3c9d654 100644
--- a/src/segger/data/utils/neighbors.py
+++ b/src/segger/data/utils/neighbors.py
@@ -122,7 +122,7 @@ def edge_index_to_knn(
 def kdtree_neighbors(
     points: np.ndarray,
     max_k: int,
-    max_dist: float,
+    max_dist: float = np.inf,
     query: np.ndarray | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Wrapper for KDTree kNN and conversion to edge_index COO format.
@@ -148,11 +148,25 @@
 def setup_transcripts_graph(
     tx: pl.DataFrame,
     max_k: int,
     max_dist: float,
+    use_3d: bool | Literal["auto"] = False,
 ) -> torch.Tensor:
     """TODO: Add description.
     """
     tx_fields = TrainingTranscriptFields()
-    points = tx[[tx_fields.x, tx_fields.y]].to_numpy()
+    coord_cols = [tx_fields.x, tx_fields.y]
+    has_z = tx_fields.z in tx.columns
+
+    if use_3d == "auto":
+        use_3d = has_z and tx[tx_fields.z].null_count() < len(tx)
+    elif use_3d is True and not has_z:
+        raise ValueError(
+            f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. "
+            f"Available columns: {tx.columns}"
+        )
+    if use_3d and has_z:
+        coord_cols.append(tx_fields.z)
+
+    points = tx[coord_cols].to_numpy()
     edge_index, _ = kdtree_neighbors(
         points=points,
         max_k=max_k,
@@ -184,6 +198,7 @@ def setup_prediction_graph(
     max_k: int,
     buffer_ratio: float,
     mode: Literal['nucleus', 'cell', 'uniform'] = 'cell',
+    use_3d: bool | Literal["auto"] = False,
 ) -> torch.Tensor:
     """TODO: Add description.
     """
@@ -192,12 +207,27 @@
 
     # Uniform kNN graph
     if mode == "uniform":
-        points = tx[[tx_fields.x, tx_fields.y]].to_numpy()
+        coord_cols = [tx_fields.x, tx_fields.y]
+        has_z = tx_fields.z in tx.columns
+        if use_3d == "auto":
+            use_3d = has_z and tx[tx_fields.z].null_count() < len(tx)
+        elif use_3d is True and not has_z:
+            raise ValueError(
+                f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. "
+                f"Available columns: {tx.columns}"
+            )
+        if use_3d and has_z:
+            coord_cols.append(tx_fields.z)
+
+        points = tx[coord_cols].to_numpy()
         query = bd.geometry.centroid.get_coordinates().values
+        if use_3d and len(coord_cols) == 3:
+            query = np.hstack([query, np.zeros((len(query), 1))])
         edge_index, _ = kdtree_neighbors(
             points=points,
             query=query,
             max_k=max_k,
+            max_dist=np.inf,
         )
         return edge_index
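Reviewer note: a minimal sketch of the z-padding used for the "uniform" prediction graph above (names illustrative; boundary centroids are 2D, so z=0 is assumed for the query points):

    import numpy as np
    from scipy.spatial import KDTree

    points = np.random.rand(100, 3)                        # transcript x/y/z
    query = np.random.rand(5, 2)                           # 2D centroids
    query = np.hstack([query, np.zeros((len(query), 1))])  # pad with z=0
    dist, idx = KDTree(points).query(query, k=3)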
diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py
index 5db1498..ae22d34 100644
--- a/src/segger/data/writer.py
+++ b/src/segger/data/writer.py
@@ -12,6 +12,7 @@
 from ..io import TrainingTranscriptFields, TrainingBoundaryFields
 from . import ISTDataModule
 from .utils.anndata import anndata_from_transcripts
+from ..export.spatialdata_writer import SpatialDataWriter
 
 class ISTSegmentationWriter(BasePredictionWriter):
     """TODO: Description
@@ -23,14 +24,20 @@ class ISTSegmentationWriter(BasePredictionWriter):
     """
     def __init__(
         self,
+        input_directory: Path,
         output_directory: Path,
         save_anndata: bool = True,
+        save_spatialdata: bool = True,
+        boundary_method: str = "convex_hull",
         debug: bool = False
     ):
         # "write" callback at the end of prediction epoch
         super().__init__(write_interval="epoch")
+        self.input_directory = Path(input_directory)
         self.output_directory = Path(output_directory)
         self.save_anndata = save_anndata
+        self.save_spatialdata = save_spatialdata
+        self.boundary_method = boundary_method
         self.segger_logger = logging.getLogger(__name__)
 
         # setup debugging
@@ -124,6 +131,24 @@ def write_anndata(
             )
             adata.write_h5ad(self.output_directory / 'segger_anndata.h5ad')
 
+        if self.save_spatialdata:
+            writer = SpatialDataWriter(
+                include_boundaries=True,
+                boundary_method=self.boundary_method,
+                boundary_n_jobs=4,
+            )
+            tx, _ = _resolve_transcripts_and_boundaries(self.input_directory)
+            output_path = writer.write(
+                predictions=segmentation,
+                output_dir=self.output_directory,
+                transcripts=tx,
+                # boundaries=bd,
+                output_name="segger_segmentation.zarr",
+            )
+            self.segger_logger.info(f"Written SpatialData output: {output_path}")
+
     @classmethod
     def assign_transcripts_to_cells(
         cls,
@@ -286,3 +311,47 @@ def on_fit_end(self, trainer, pl_module):
         self.segger_logger.debug(f"Saving trainer state to {self.path_debug / 'trainer_state_final.ckpt'}")
         trainer.save_checkpoint(self.path_debug / "trainer_state_final.ckpt")
 
+def _is_spatialdata_path(path: Path | str) -> bool:
+    try:
+        from ..io.spatialdata_loader import is_spatialdata_path as _impl
+        return _impl(path)
+    except Exception:
+        p = Path(path)
+        return (
+            p.suffix == ".zarr"
+            or (p / ".zgroup").exists()
+            or (p / "zarr.json").exists()
+            or (p / "points").exists()
+            or (p / "shapes").exists()
+        )
+
+
+def _resolve_transcripts_and_boundaries(source_path):
+    """Resolve transcripts/boundaries from raw or SpatialData input.
+
+    SpatialData element names are hardcoded to Xenium conventions.
+    """
+    if _is_spatialdata_path(source_path):
+        try:
+            from ..io.spatialdata_loader import load_from_spatialdata
+        except Exception as exc:
+            raise ImportError(
+                "SpatialData input requested, but spatialdata support is unavailable. "
+                "Install with: pip install segger[spatialdata]"
+            ) from exc
+        tx, bd = load_from_spatialdata(
+            source_path,
+            points_key="transcripts",
+            cell_shapes_key="cell_boundaries",
+            nucleus_shapes_key="nucleus_boundaries",
+            boundary_type="all",
+        )
+        return (tx.collect() if isinstance(tx, pl.LazyFrame) else tx), bd
+
+    from ..io import get_preprocessor
+    pp = get_preprocessor(source_path)
+    tx = pp.transcripts
+    if isinstance(tx, pl.LazyFrame):
+        tx = tx.collect()
+    try:
+        bd = pp.boundaries
+    except Exception:
+        bd = None
+    return tx, bd
\ No newline at end of file
diff --git a/src/segger/export/__init__.py b/src/segger/export/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/segger/export/boundary.py b/src/segger/export/boundary.py
new file mode 100644
index 0000000..82eca0f
--- /dev/null
+++ b/src/segger/export/boundary.py
@@ -0,0 +1,525 @@
+"""Delaunay triangulation-based cell boundary generation.
+
+This module provides boundary extraction using Delaunay triangulation with
+iterative edge refinement and cycle detection, which produces more accurate
+cell boundaries than simple convex hulls.
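+
+A minimal usage sketch (synthetic points; the exact geometry depends on the
+refinement passes):
+
+>>> import numpy as np
+>>> import pandas as pd
+>>> from segger.export.boundary import generate_boundary
+>>> rng = np.random.default_rng(0)
+>>> df = pd.DataFrame(rng.random((30, 2)), columns=["x", "y"])
+>>> poly = generate_boundary(df)  # Polygon, MultiPolygon, or None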
+""" + +from typing import Iterable, Tuple, Union +from concurrent.futures import ThreadPoolExecutor +import geopandas as gpd +import numpy as np +import pandas as pd +import polars as pl +import rtree.index +from scipy.spatial import Delaunay +from shapely.geometry import MultiPolygon, Polygon +from tqdm import tqdm + + +def vector_angle(v1: np.ndarray, v2: np.ndarray) -> float: + """Calculate angle between two vectors in degrees. + + Parameters + ---------- + v1 : np.ndarray + First vector. + v2 : np.ndarray + Second vector. + + Returns + ------- + float + Angle in degrees. + """ + dot_product = np.dot(v1, v2) + magnitude_v1 = np.linalg.norm(v1) + magnitude_v2 = np.linalg.norm(v2) + cos_angle = np.clip(dot_product / (magnitude_v1 * magnitude_v2 + 1e-8), -1.0, 1.0) + return np.degrees(np.arccos(cos_angle)) + + +def triangle_angles_from_points( + points: np.ndarray, + triangles: np.ndarray, +) -> np.ndarray: + """Calculate angles for all triangles in a Delaunay triangulation. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + triangles : np.ndarray + Triangle vertex indices, shape (M, 3). + + Returns + ------- + np.ndarray + Angles for each triangle vertex, shape (M, 3). + """ + # Vectorized angle computation for all triangles + p1 = points[triangles[:, 0]] + p2 = points[triangles[:, 1]] + p3 = points[triangles[:, 2]] + + v1 = p2 - p1 + v2 = p3 - p1 + v3 = p3 - p2 + + def _angles(u: np.ndarray, v: np.ndarray) -> np.ndarray: + dot = (u * v).sum(axis=1) + denom = (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1)) + 1e-8 + cos = np.clip(dot / denom, -1.0, 1.0) + return np.degrees(np.arccos(cos)) + + a = _angles(v1, v2) + b = _angles(-v1, v3) + c = _angles(-v2, -v3) + return np.stack([a, b, c], axis=1) + + +def dfs(v: int, graph: dict, path: list, colors: dict) -> None: + """Depth-first search for cycle detection. + + Parameters + ---------- + v : int + Current vertex. + graph : dict + Adjacency list representation of graph. + path : list + Current path being built. + colors : dict + Vertex visit status (0=unvisited, 1=visited). + """ + colors[v] = 1 + path.append(v) + for d in graph[v]: + if colors[d] == 0: + dfs(d, graph, path, colors) + + +class BoundaryIdentification: + """Delaunay triangulation-based polygon boundary extraction. + + This class implements a two-phase iterative algorithm for extracting + cell boundaries from transcript point clouds: + + 1. Phase 1: Remove long boundary edges (> 2 * d_max) + 2. Phase 2: Remove boundary edges with extreme angles + + Parameters + ---------- + data : np.ndarray + 2D point coordinates, shape (N, 2). 
+ """ + + def __init__(self, data: np.ndarray): + self.graph = None + self.edges = {} + self.d = Delaunay(data) + self.d_max = self.calculate_d_max(self.d.points) + self.generate_edges() + + def generate_edges(self) -> None: + """Generate edge dictionary from Delaunay triangulation.""" + d = self.d + edges = {} + angles = triangle_angles_from_points(d.points, d.simplices) + + for index, simplex in enumerate(d.simplices): + for p in range(3): + edge = tuple(sorted((simplex[p], simplex[(p + 1) % 3]))) + if edge not in edges: + edges[edge] = {"simplices": {}} + edges[edge]["simplices"][index] = angles[index][(p + 2) % 3] + + edges_coordinates = d.points[np.array(list(edges.keys()))] + edges_length = np.sqrt( + (edges_coordinates[:, 1, 0] - edges_coordinates[:, 0, 0]) ** 2 + + (edges_coordinates[:, 1, 1] - edges_coordinates[:, 0, 1]) ** 2 + ) + + for edge, coords, length in zip(edges, edges_coordinates, edges_length): + edges[edge]["coords"] = coords + edges[edge]["length"] = length + + self.edges = edges + + def calculate_part_1(self, plot: bool = False) -> None: + """Phase 1: Remove long boundary edges iteratively. + + Removes edges longer than 2 * d_max from the boundary. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). + """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + + flag = True + while flag: + flag = False + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if edges[current_edge]["length"] > 2 * d_max: + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + flag = True + else: + next_boundary_edges.append(current_edge) + + boundary_edges = next_boundary_edges + + def calculate_part_2(self, plot: bool = False) -> None: + """Phase 2: Remove boundary edges with extreme angles. + + Removes edges where the opposite angle is too large, indicating + a concave region that should be excluded. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). 
+ """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + boundary_edges_length = len(boundary_edges) + next_boundary_edges = [] + + while len(next_boundary_edges) != boundary_edges_length: + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + # Remove if edge is long with large angle, or if angle is very obtuse + if ( + edges[current_edge]["length"] > 1.5 * d_max + and edges[current_edge]["simplices"][simplex_id] > 90 + ) or edges[current_edge]["simplices"][simplex_id] > 180 - 180 / 16: + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + else: + next_boundary_edges.append(current_edge) + + boundary_edges_length = len(boundary_edges) + boundary_edges = next_boundary_edges + + def find_cycles(self) -> Union[Polygon, MultiPolygon, None]: + """Find boundary cycles and convert to Shapely geometry. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Polygon if single cycle, MultiPolygon if multiple, None on error. + """ + e = self.edges + boundary_edges = [edge for edge in e if len(e[edge]["simplices"]) < 2] + self.graph = self.generate_graph(boundary_edges) + cycles = self.get_cycles(self.graph) + + try: + if len(cycles) == 1: + geom = Polygon(self.d.points[cycles[0]]) + else: + geom = MultiPolygon( + [Polygon(self.d.points[c]) for c in cycles if len(c) >= 3] + ) + except Exception: + return None + + return geom + + @staticmethod + def calculate_d_max(points: np.ndarray) -> float: + """Calculate maximum nearest-neighbor distance. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + + Returns + ------- + float + Maximum nearest-neighbor distance. + """ + index = rtree.index.Index() + for i, p in enumerate(points): + index.insert(i, p[[0, 1, 0, 1]]) + + short_edges = [] + for i, p in enumerate(points): + res = list(index.nearest(p[[0, 1, 0, 1]], 2))[-1] + short_edges.append([i, res]) + + nearest_points = points[short_edges] + nearest_dists = np.sqrt( + (nearest_points[:, 0, 0] - nearest_points[:, 1, 0]) ** 2 + + (nearest_points[:, 0, 1] - nearest_points[:, 1, 1]) ** 2 + ) + return nearest_dists.max() + + @staticmethod + def get_edges_from_simplex(simplex: np.ndarray) -> list: + """Extract edge tuples from a triangle simplex. + + Parameters + ---------- + simplex : np.ndarray + Triangle vertex indices, shape (3,). + + Returns + ------- + list + List of edge tuples. + """ + edges = [] + for p in range(3): + edges.append(tuple(sorted((simplex[p], simplex[(p + 1) % 3])))) + return edges + + @staticmethod + def generate_graph(edges: list) -> dict: + """Generate adjacency list from edge list. + + Parameters + ---------- + edges : list + List of edge tuples. + + Returns + ------- + dict + Adjacency list representation. 
+ """ + vertices = set() + for edge in edges: + vertices.add(edge[0]) + vertices.add(edge[1]) + + vertices = sorted(list(vertices)) + graph = {v: [] for v in vertices} + + for e in edges: + graph[e[0]].append(e[1]) + graph[e[1]].append(e[0]) + + return graph + + @staticmethod + def get_cycles(graph: dict) -> list: + """Find all connected components (cycles) in boundary graph. + + Parameters + ---------- + graph : dict + Adjacency list representation. + + Returns + ------- + list + List of cycles (each cycle is a list of vertex indices). + """ + colors = {v: 0 for v in graph} + cycles = [] + + for v in graph.keys(): + if colors[v] == 0: + cycle = [] + dfs(v, graph, cycle, colors) + cycles.append(cycle) + + return cycles + + +def generate_boundary( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", +) -> Union[Polygon, MultiPolygon, None]: + """Generate boundary polygon for a single cell's transcripts. + + Uses Delaunay triangulation with iterative edge refinement to produce + more accurate boundaries than simple convex hulls. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with x, y coordinates. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Cell boundary geometry, or None if insufficient points. + """ + # Convert Polars to pandas if needed + if isinstance(df, pl.DataFrame): + df = df.to_pandas() + + if len(df) < 3: + return None + + bi = BoundaryIdentification(df[[x, y]].values) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + return bi.find_cycles() + + +def generate_boundaries( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", + cell_id: str = "seg_cell_id", + n_jobs: int = 1, + chunksize: int = 8, + progress: bool = True, +) -> gpd.GeoDataFrame: + """Generate boundaries for all cells in a segmentation result. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with cell assignments. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + cell_id : str + Column name for cell ID. + + Returns + ------- + gpd.GeoDataFrame + GeoDataFrame with cell_id, length, and geometry columns. 
+ """ + def iter_groups() -> Tuple[Iterable[Tuple[object, np.ndarray]], int]: + if isinstance(df, pl.DataFrame): + grouped = df.group_by(cell_id).agg( + [ + pl.col(x).list().alias("_x"), + pl.col(y).list().alias("_y"), + ] + ) + total = grouped.height + + def _gen(): + for cid, xs, ys in grouped.iter_rows(): + yield cid, np.column_stack((xs, ys)) + + return _gen(), total + + group_df = df.groupby(cell_id) + total = group_df.ngroups + + def _gen(): + for cid, t in group_df: + yield cid, t[[x, y]].to_numpy() + + return _gen(), total + + def _compute_one(item: Tuple[object, np.ndarray]) -> Tuple[object, int, Union[Polygon, MultiPolygon, None]]: + cid, points = item + n_points = points.shape[0] + if n_points < 3: + return cid, n_points, None + try: + bi = BoundaryIdentification(points) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + geom = bi.find_cycles() + except Exception: + geom = None + return cid, n_points, geom + + group_iter, total = iter_groups() + res = [] + + if n_jobs and n_jobs > 1: + with ThreadPoolExecutor(max_workers=n_jobs) as ex: + iterator = ex.map(_compute_one, group_iter, chunksize=chunksize) + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for cid, length, geom in iterator: + res.append({"cell_id": cid, "length": length, "geom": geom}) + else: + iterator = group_iter + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for item in iterator: + cid, length, geom = _compute_one(item) + res.append({"cell_id": cid, "length": length, "geom": geom}) + + return gpd.GeoDataFrame( + data=[[b["cell_id"], b["length"]] for b in res], + geometry=[b["geom"] for b in res], + columns=["cell_id", "length"], + ) + + +def extract_largest_polygon( + geom: Union[Polygon, MultiPolygon, None], +) -> Union[Polygon, None]: + """Extract the largest polygon from a geometry. + + Parameters + ---------- + geom : Union[Polygon, MultiPolygon, None] + Input geometry. + + Returns + ------- + Union[Polygon, None] + Largest polygon, or None if input is None. + """ + if geom is None: + return None + if getattr(geom, "is_empty", False): + return None + if isinstance(geom, MultiPolygon): + candidates = [p for p in geom.geoms if p is not None and not p.is_empty] + if not candidates: + return None + return max(candidates, key=lambda p: p.area) + return geom diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py new file mode 100644 index 0000000..761d3bb --- /dev/null +++ b/src/segger/export/spatialdata_writer.py @@ -0,0 +1,797 @@ +"""Write segmentation results as SpatialData Zarr stores. + +This writer creates SpatialData-compatible Zarr stores containing: +- points["transcripts"]: Transcripts with segger_cell_id column +- shapes["cells"]: Cell boundaries (optional, can be input or generated) +- tables["cell_table"]: AnnData table with cell x gene counts (optional) + +NO images are included (per requirements). + +Usage +----- +>>> from segger.export.spatialdata_writer import SpatialDataWriter +>>> writer = SpatialDataWriter() +>>> output_path = writer.write( +... predictions=predictions, +... transcripts=transcripts, +... output_dir=Path("output/"), +... boundaries=boundaries, # Optional +... 
) + +Installation +------------ +Requires the spatialdata optional dependency: + pip install segger[spatialdata] +""" + +from __future__ import annotations + +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional, Union + +import numpy as np +import pandas as pd +import polars as pl +from anndata import AnnData +from scipy import sparse as sp + + +from segger.utils.optional_deps import ( + require_spatialdata, +) +if TYPE_CHECKING: + import geopandas as gpd + from spatialdata import SpatialData + + +# @register_writer(OutputFormat.SPATIALDATA) +class SpatialDataWriter: + """Write segmentation results as SpatialData Zarr store. + + Creates a SpatialData object with: + - points["transcripts"]: Transcripts with cell assignments + - shapes["cells"]: Cell boundaries (if provided or generated) + + Parameters + ---------- + include_boundaries + Whether to include cell shapes in output. Default True. + boundary_method + How to generate boundaries if not provided: + - "input": Use input boundaries if available + - "convex_hull": Generate convex hull per cell + - "delaunay": Delaunay triangulation-based boundary extraction + - "skip": Don't include shapes + boundary_n_jobs + Parallel workers for Delaunay boundary generation (threads). + points_key + Key for transcripts in sdata.points. Default "transcripts". + shapes_key + Key for cell shapes in sdata.shapes. Default "cells". + include_table + Whether to include AnnData table in sdata.tables. Default True. + table_key + Key for AnnData table in sdata.tables. Default "cell_table". + table_region_key + Column in shapes that identifies cells. Default "cell_id". + """ + + def __init__( + self, + include_boundaries: bool = True, + boundary_method: Literal["input", "convex_hull", "delaunay", "skip"] = "convex_hull", + boundary_n_jobs: int = 1, + points_key: str = "transcripts", + shapes_key: str = "cells", + include_table: bool = True, + table_key: str = "cells_table", + fragment_table_key: str = "fragments_table", + table_region_key: str = "cell_id", + ): + require_spatialdata() + + self.include_boundaries = include_boundaries + self.boundary_method = boundary_method + self.boundary_n_jobs = boundary_n_jobs + self.points_key = points_key + self.shapes_key = shapes_key + self.include_table = include_table + self.table_key = table_key + self.table_region_key = table_region_key + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + transcripts: Optional[pl.DataFrame] = None, + boundaries: Optional["gpd.GeoDataFrame"] = None, + output_name: str = "segmentation.zarr", + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + feature_column: str = "feature_name", + x_column: str = "x", + y_column: str = "y", + z_column: Optional[str] = "z", + overwrite: bool = True, + **kwargs, + ) -> Path: + """Write segmentation results to SpatialData Zarr store. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Output directory. + transcripts + Original transcripts DataFrame. Required for SPATIALDATA format. + boundaries + Cell boundaries GeoDataFrame. Optional. + output_name + Output Zarr store name. Default "segmentation.zarr". + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + feature_column + Column name for gene/feature in transcripts. 
+ x_column + Column name for x-coordinate. + y_column + Column name for y-coordinate. + z_column + Column name for z-coordinate (optional). + overwrite + Whether to overwrite existing Zarr store. + + Returns + ------- + Path + Path to the written .zarr store. + + Raises + ------ + ValueError + If transcripts are not provided. + """ + if transcripts is None: + raise ValueError( + "SpatialData format requires transcripts DataFrame. " + "Pass 'transcripts' parameter to write()." + ) + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / output_name + + # Check if exists + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Output path exists: {output_path}. " + "Use overwrite=True to replace." + ) + + # Merge predictions with transcripts + merged = self._merge_predictions( + predictions=predictions, + transcripts=transcripts, + row_index_column=row_index_column, + cell_id_column=cell_id_column, + similarity_column=similarity_column, + ) + + # Create SpatialData object + sdata = self._create_spatialdata( + transcripts=merged, + boundaries=boundaries, + x_column=x_column, + y_column=y_column, + z_column=z_column, + cell_id_column=cell_id_column, + feature_column=feature_column, + ) + + # Write to Zarr + self._write_spatialdata_zarr( + sdata=sdata, + output_path=output_path, + overwrite=overwrite, + ) + + return output_path + + def _merge_predictions( + self, + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str, + cell_id_column: str, + similarity_column: str, + ) -> pl.DataFrame: + """Merge predictions with transcripts.""" + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned with -1 + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(-1) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged + + def _create_spatialdata( + self, + transcripts: pl.DataFrame, + boundaries: Optional["gpd.GeoDataFrame"], + x_column: str, + y_column: str, + z_column: Optional[str], + cell_id_column: str, + feature_column: str, + ) -> "SpatialData": + """Create SpatialData object from transcripts and boundaries.""" + import spatialdata + from spatialdata.models import PointsModel, ShapesModel, TableModel + import dask.dataframe as dd + + identity = self._identity_transform() + transformations = {"global": identity} if identity is not None else None + + # Convert transcripts to pandas for SpatialData + tx_pd = transcripts.to_pandas() + + # SOPA expects "cell_id" assignment in points. 
+ if cell_id_column in tx_pd.columns and "cell_id" not in tx_pd.columns: + tx_pd['cell_id']= tx_pd[cell_id_column] + #NOTE: having both 'cell_id' and 'segger_cell_id' creates confusion + # tx_pd = tx_pd.rename(columns={cell_id_column: "cell_id"}) + # this would be better but fails as later code still relies on cell_id_column + + # Check for z-coordinate + has_z = z_column and z_column in tx_pd.columns + + # Create points element + # SpatialData expects coordinates in specific columns + coords_cols = [x_column, y_column] + if has_z: + coords_cols.append(z_column) + + # Ensure coordinates are float + for col in coords_cols: + if col in tx_pd.columns: + tx_pd[col] = tx_pd[col].astype(float) + + # Create Dask DataFrame for points + tx_pd[feature_column] = tx_pd[feature_column].astype("category") + tx_dask = dd.from_pandas(tx_pd) + + # Points element + points_parse_kwargs = { + "coordinates": { + "x": x_column, + "y": y_column, + **({"z": z_column} if has_z else {}), + }, + "instance_key": cell_id_column, # or 'cell_id' which is hard-coded now + "feature_key": feature_column, + } + if transformations is not None: + points_parse_kwargs["transformations"] = transformations + + points = PointsModel.parse(tx_dask, **points_parse_kwargs) + points_elements = {self.points_key: points} + + # Shapes + def _ensure_cell_id(gdf): + if gdf is None: + return None + if "cell_id" in gdf.columns: + return gdf + if cell_id_column in gdf.columns: + gdf = gdf.copy() + gdf["cell_id"] = gdf[cell_id_column] + return gdf + gdf = gdf.reset_index(drop=False) + if "cell_id" not in gdf.columns and len(gdf.columns) > 0: + gdf["cell_id"] = gdf[gdf.columns[0]] + return gdf + + + def _parse_shapes(shapes): + if shapes is None or len(shapes) == 0: + return None + kwargs = {"transformations": transformations} if transformations is not None else {} + return ShapesModel.parse(shapes, **kwargs) + + shapes_elements = {} + + shape_specs = [(self.shapes_key, tx_pd)] + + for shape_key, shape_tx_pd in shape_specs: + shapes = self._get_generated_boundaries(shape_tx_pd, x_column, y_column, cell_id_column) + shapes = _ensure_cell_id(shapes) + parsed = _parse_shapes(shapes) + if parsed is not None: + shapes_elements[shape_key] = parsed + + # Optional AnnData table + tables_elements = {} + if self.include_table: + region = self.shapes_key if self.shapes_key in shapes_elements else None + instance_key = self.table_region_key if region is not None else None + table = build_anndata_table( + transcripts=transcripts, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=-1, + region=None, + region_key=None, + obs_index_as_str=True, + ) + if region is not None: + table.obs["region"] = region + if instance_key and instance_key not in table.obs.columns: + table.obs[instance_key] = table.obs.index.astype(str) + try: + table = TableModel.parse( + table, + region=region, + region_key="region", + instance_key=instance_key or "instance_id", + ) + except Exception: + pass + tables_elements[self.table_key] = table + + for name, table in tables_elements.items(): + if 'spatialdata_attrs' not in table.uns.keys(): + warnings.warn( + f"Table {name} does not contain the `uns['spatialdata_attrs']` field as no shapes element is associated." 
+ ) + + # Create SpatialData (prefer modern constructor methods, keep fallback on single elemnts) + sdata = self._build_spatialdata( + spatialdata=spatialdata, + points_elements=points_elements, + shapes_elements=shapes_elements, + tables_elements=tables_elements, + ) + + return sdata + + def _identity_transform(self): + """Return SpatialData identity transform when available.""" + try: + from spatialdata.transformations import Identity + return Identity() + except Exception: + return None + + def _build_spatialdata(self, spatialdata, points_elements: dict, shapes_elements: dict, tables_elements: dict): + """Build a SpatialData object across SpatialData API variants.""" + + if hasattr(spatialdata.SpatialData, "init_from_elements"): + return spatialdata.SpatialData.init_from_elements(points_elements | shapes_elements | tables_elements) + else: + return spatialdata.SpatialData( + points=points_elements, + shapes=shapes_elements, + tables=tables_elements, + ) + + + def _build_table_element( + self, + TableModel, + transcripts: pl.DataFrame, + var_transcripts: pl.DataFrame, + region: Optional[str], + cell_id_column: str, + feature_column: str, + x_column: str, + y_column: str, + z_column: Optional[str], + ): + """Build a SpatialData table and attach region metadata when available.""" + table = build_anndata_table( + transcripts=transcripts, + var_transcripts=var_transcripts, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=-1, + region=None, + region_key=None, + obs_index_as_str=True, + ) + if region is None: + return TableModel.validate(table) + + instance_key = self.table_region_key + table.obs["region"] = region + if instance_key and instance_key not in table.obs.columns: + table.obs[instance_key] = table.obs.index.astype(str) + try: + return TableModel.parse( + table, + region=region, + region_key="region", + instance_key=instance_key or "instance_id", + ) + except Exception as e: + warnings.warn(f"TableModel.parse failed: {e}") + return table + + def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> None: + """Write SpatialData object with compatibility fallback.""" + try: + sdata.write(output_path, overwrite=overwrite) + return + except TypeError: + pass + + if output_path.exists(): + import shutil + shutil.rmtree(output_path) + sdata.write(output_path) + + + + def _get_input_boundaries(self, cell_tx_pd, cell_id_column, boundaries, bd_type): + + selected_ids = cell_tx_pd[cell_id_column].dropna().unique() + if len(selected_ids) == 0 or boundaries is None: + if boundaries is None: + warnings.warn("No input boundaries were found. Skipping boundary generation.") + return None + + boundaries_filtered = boundaries.loc[boundaries['boundary_type'] == bd_type] + boundaries_gdf = boundaries_filtered[boundaries_filtered["cell_id"].isin(selected_ids)].copy() + + return boundaries_gdf if not boundaries_gdf.empty else None + + + + def _get_generated_boundaries( + self, + transcripts: pd.DataFrame, + x_column: str, + y_column: str, + cell_id_column: str, + ) -> Optional[gpd.GeoDataFrame]: + """Generate cell boundaries based on the selected boundary method. 
+ Args + transcripts: dataframe of group transcripts (cells or fragments) + x_column, y_column: transcripts 2D coordinates + cell_id_column: cell ID + """ + import geopandas as gpd + + assigned = transcripts[transcripts[cell_id_column] != -1].copy() + if assigned.empty: + return None + + if self.boundary_method == "convex_hull": + from shapely.geometry import MultiPoint + + hulls, cell_ids = [], [] + + for cell_id, group in assigned.groupby(cell_id_column): + if len(group) < 3: + continue + points = list(zip(group[x_column], group[y_column])) + hull = MultiPoint(points).convex_hull + if hull.is_empty or hull.geom_type != "Polygon": + continue + hulls.append(hull) + cell_ids.append(cell_id) + + if not hulls: + return None + return gpd.GeoDataFrame({"cell_id": cell_ids}, geometry=hulls) + + elif self.boundary_method == "delaunay": + from segger.export.boundary import generate_boundaries + warnings.filterwarnings('ignore', 'GeoSeries.notna', UserWarning) + + boundaries_gdf = generate_boundaries( + assigned, + x=x_column, + y=y_column, + cell_id=cell_id_column, + n_jobs=self.boundary_n_jobs, + ) + boundaries_gdf = boundaries_gdf[ + boundaries_gdf.geometry.notna() & ~boundaries_gdf.geometry.is_empty + ] + if len(boundaries_gdf) == 0: + return None + return boundaries_gdf + + return None + + +def write_spatialdata( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + output_dir: Path, + boundaries: Optional["gpd.GeoDataFrame"] = None, + output_name: str = "segmentation.zarr", + **kwargs, +) -> Path: + """Convenience function to write SpatialData output. + + Parameters + ---------- + predictions + Segmentation predictions. + transcripts + Original transcripts. + output_dir + Output directory. + boundaries + Cell boundaries (optional). + output_name + Output filename. + **kwargs + Additional arguments passed to SpatialDataWriter.write(). + + Returns + ------- + Path + Path to written .zarr store. + + Examples + -------- + >>> path = write_spatialdata( + ... predictions=preds, + ... transcripts=tx, + ... output_dir=Path("output/"), + ... ) + """ + writer = SpatialDataWriter() + return writer.write( + predictions=predictions, + output_dir=output_dir, + transcripts=transcripts, + boundaries=boundaries, + output_name=output_name, + **kwargs, + ) + + +### APIs from other exporting formats in v2-incremental ### + +### ANNDATA EXPORT ### + +def build_anndata_table( + transcripts: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + unassigned_value: Union[int, str, None] = -1, + region: Optional[str] = None, + region_key: Optional[str] = None, + obs_index_as_str: bool = False, +) -> AnnData: + """Build AnnData from assigned transcripts. + + Parameters + ---------- + transcripts + Transcript DataFrame with segmentation assignments. + cell_id_column + Column with assigned cell IDs. + feature_column + Column with gene/feature names. + x_column, y_column, z_column + Coordinate columns (optional). If present, centroids are stored in + ``obsm["X_spatial"]``. + unassigned_value + Marker for unassigned transcripts (filtered out). + region, region_key + SpatialData table linkage metadata. + obs_index_as_str + If True, cast cell IDs to string for ``obs`` index. 
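+
+    Examples
+    --------
+    A small sketch (column names follow the defaults above):
+
+    >>> import polars as pl
+    >>> tx = pl.DataFrame({
+    ...     "segger_cell_id": [1, 1, 2, -1],
+    ...     "feature_name": ["A", "B", "A", "B"],
+    ...     "x": [0.0, 1.0, 2.0, 3.0],
+    ...     "y": [0.0, 1.0, 2.0, 3.0],
+    ... })
+    >>> adata = build_anndata_table(tx)
+    >>> adata.shape
+    (2, 2)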
+ """ + if cell_id_column not in transcripts.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + if feature_column not in transcripts.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) + if unassigned_value is not None: + col_dtype = transcripts.schema.get(cell_id_column) + try: + compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() + filter_expr = pl.col(cell_id_column) != compare_value + except Exception: + filter_expr = ( + pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) + ) + assigned = assigned.filter(filter_expr) + + # Gene list from all transcripts (even if no assignments) + var_idx = ( + transcripts + .select(feature_column) + .unique() + .sort(feature_column) + .get_column(feature_column) + .to_list() + ) + + if assigned.height == 0: + obs_index = pd.Index([], name=cell_id_column) + if obs_index_as_str: + var_index = pd.Index([str(v) for v in var_idx], name=feature_column) + else: + var_index = pd.Index(var_idx, name=feature_column) + X = sp.csr_matrix((0, len(var_index))) + adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + return adata + + feature_idx = ( + assigned + .select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned + .select(cell_id_column) + .unique() + .sort(cell_id_column) + .with_row_index(name="_cid") + ) + + mapped = ( + assigned + .join(feature_idx, on=feature_column) + .join(cell_idx, on=cell_id_column) + ) + counts = ( + mapped + .group_by(["_cid", "_fid"]) + .agg(pl.len().alias("_count")) + ) + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + n_cells = cell_idx.height + n_genes = feature_idx.height + X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() + + obs_ids = cell_idx.get_column(cell_id_column).to_list() + var_ids = feature_idx.get_column(feature_column).to_list() + if obs_index_as_str: + obs_ids = [str(v) for v in obs_ids] + var_ids = [str(v) for v in var_ids] + + adata = AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), + var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), + ) + + # Add centroid coordinates if present + if x_column in assigned.columns and y_column in assigned.columns: + coords_cols = [x_column, y_column] + if z_column and z_column in assigned.columns: + coords_cols.append(z_column) + centroids = ( + assigned + .group_by(cell_id_column) + .agg([pl.col(c).mean().alias(c) for c in coords_cols]) + ) + centroids_pd = ( + centroids + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() + + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + + return adata + +### MERGED EXPORT ### + +def merge_predictions_with_transcripts( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + unassigned_marker: Union[int, str, None] = -1, +) -> pl.DataFrame: + """Merge predictions with 
transcripts (functional interface). + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + transcripts + Original transcripts DataFrame. + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + unassigned_marker + Value for unassigned transcripts. + + Returns + ------- + pl.DataFrame + Merged DataFrame with all original columns plus predictions. + + Examples + -------- + >>> merged = merge_predictions_with_transcripts(predictions, transcripts) + >>> print(merged.columns) + ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] + """ + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned + if unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(unassigned_marker) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged diff --git a/src/segger/io/__init__.py b/src/segger/io/__init__.py index 1f1ad20..a913449 100644 --- a/src/segger/io/__init__.py +++ b/src/segger/io/__init__.py @@ -1,7 +1,81 @@ -from .preprocessor import get_preprocessor -from .fields import ( - StandardBoundaryFields, - TrainingBoundaryFields, - StandardTranscriptFields, - TrainingTranscriptFields, -) \ No newline at end of file +"""Input/output modules for spatial transcriptomics data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +import importlib + +__all__ = [ + # Preprocessors + "get_preprocessor", + # Fields + "StandardBoundaryFields", + "TrainingBoundaryFields", + "StandardTranscriptFields", + "TrainingTranscriptFields", + # SpatialData (optional) + "SpatialDataLoader", + "load_from_spatialdata", + "is_spatialdata_path", +] + +if TYPE_CHECKING: # pragma: no cover + from .fields import ( + StandardBoundaryFields, + TrainingBoundaryFields, + StandardTranscriptFields, + TrainingTranscriptFields, + ) + from .preprocessor import get_preprocessor + from .spatialdata_loader import ( + SpatialDataLoader, + load_from_spatialdata, + is_spatialdata_path, + ) + + +def __getattr__(name: str): + if name in { + "StandardBoundaryFields", + "TrainingBoundaryFields", + "StandardTranscriptFields", + "TrainingTranscriptFields", + }: + from .fields import ( + StandardBoundaryFields, + TrainingBoundaryFields, + StandardTranscriptFields, + TrainingTranscriptFields, + ) + return locals()[name] + + if name == "get_preprocessor": + from .preprocessor import get_preprocessor + return get_preprocessor + + if name in { + "SpatialDataLoader", + "load_from_spatialdata", + "is_spatialdata_path", + }: + try: + from .spatialdata_loader import ( + SpatialDataLoader, + load_from_spatialdata, + is_spatialdata_path, + ) + except Exception: + return None + return locals()[name] + + if name in { + "fields", + "preprocessor", + "spatialdata_loader", + }: + try: + return importlib.import_module(f"{__name__}.{name}") + except Exception as exc: + raise ImportError(f"Failed to import module '{name}'.") from exc + + raise 
AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/segger/io/fields.py b/src/segger/io/fields.py index 40bd6be..ff1ed35 100644 --- a/src/segger/io/fields.py +++ b/src/segger/io/fields.py @@ -8,6 +8,7 @@ class XeniumTranscriptFields: filename: str = 'transcripts.parquet' x: str = 'x_location' y: str = 'y_location' + z: str = 'z_location' feature: str = 'feature_name' cell_id: str = 'cell_id' null_cell_id: str = 'UNASSIGNED' @@ -38,6 +39,7 @@ class MerscopeTranscriptFields: filename: str = 'detected_transcripts.csv' x: str = 'global_x' y: str = 'global_y' + z: str = 'global_z' feature: str = 'gene' cell_id: str = 'cell_id' @@ -54,6 +56,7 @@ class CosMxTranscriptFields: filename: str = '*_tx_file.csv' x: str = 'x_global_px' y: str = 'y_global_px' + z: str = 'z' feature: str = 'target' cell_id: str = 'cell' compartment: str = 'CellComp' @@ -87,8 +90,10 @@ class StandardTranscriptFields: row_index: str = 'row_index' x: str = 'x' y: str = 'y' + z: str = 'z' feature: str = 'feature_name' cell_id: str = 'cell_id' + quality: str = 'qv' compartment: str = 'cell_compartment' extracellular_value: int = 0 cytoplasmic_value: int = 1 diff --git a/src/segger/io/filtering.py b/src/segger/io/filtering.py new file mode 100644 index 0000000..abb796d --- /dev/null +++ b/src/segger/io/filtering.py @@ -0,0 +1,79 @@ +"""Shared transcript filtering utilities for I/O readers.""" + +from __future__ import annotations + +import re +from typing import Collection, Sequence + +import polars as pl + +from .fields import CosMxTranscriptFields, MerscopeTranscriptFields, XeniumTranscriptFields + + +_PLATFORM_ALIASES: dict[str, str] = { + "10x_xenium": "xenium", + "nanostring_cosmx": "cosmx", + "vizgen_merscope": "merscope", +} + + +def normalize_platform_name(platform: str | None) -> str | None: + """Normalize platform aliases to canonical names.""" + if platform is None: + return None + lowered = str(platform).strip().lower() + return _PLATFORM_ALIASES.get(lowered, lowered) + + +def infer_platform_from_columns(columns: Collection[str]) -> str | None: + """Infer source platform from transcript table columns.""" + cols = set(columns) + + # CosMx marker columns are highly specific. + if "CellComp" in cols or {"x_global_px", "y_global_px"}.issubset(cols): + return "cosmx" + + # Xenium marker columns. + if "overlaps_nucleus" in cols or "qv" in cols: + return "xenium" + if {"x_location", "y_location", "feature_name"}.issubset(cols): + return "xenium" + + # MERSCOPE marker columns. 
diff --git a/src/segger/io/preprocessor.py b/src/segger/io/preprocessor.py
index 597a818..23634ed 100644
--- a/src/segger/io/preprocessor.py
+++ b/src/segger/io/preprocessor.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from abc import ABC, abstractmethod
 from anndata import AnnData
-from typing import Literal
+from typing import Literal, Optional
 from pathlib import Path
 import geopandas as gpd
 import polars as pl
@@ -34,6 +34,14 @@
 # Register of available ISTPreprocessor subclasses keyed by platform name.
 PREPROCESSORS = {}
+
+def _lazyframe_column_names(lf: pl.LazyFrame) -> list[str]:
+    """Return column names for a LazyFrame across Polars versions."""
+    try:
+        return lf.collect_schema().names()
+    except AttributeError:
+        return lf.columns
+
 def register_preprocessor(name):
     """
     Decorator to register a preprocessor class under a given platform name.
@@ -60,7 +68,14 @@ class ISTPreprocessor(ABC):
     transcript and boundary GeoDataFrames for the given platform.
     """
-    def __init__(self, data_dir: Path):
+    DEFAULT_MIN_QV: Optional[float] = None
+
+    def __init__(
+        self,
+        data_dir: Path,
+        min_qv: Optional[float] = None,
+        include_z: bool = True,
+    ):
         """
         Parameters
         ----------
@@ -70,6 +85,8 @@ def __init__(self, data_dir: Path):
         data_dir = Path(data_dir)
         type(self)._validate_directory(data_dir)
         self.data_dir = data_dir
+        self.min_qv = self.DEFAULT_MIN_QV if min_qv is None else min_qv
+        self.include_z = include_z
 
     @staticmethod
     @abstractmethod
@@ -280,7 +297,7 @@ def transcripts(self) -> pl.DataFrame:
         raw = CosMxTranscriptFields()
         std = StandardTranscriptFields()
 
-        return (
+        lf = (
             # Read in lazily
             pl.scan_csv(next(self.data_dir.glob(raw.filename)))
             .with_row_index(name=std.row_index)
@@ -310,13 +327,22 @@ def transcripts(self) -> pl.DataFrame:
                 .otherwise(None)
                 .alias(std.cell_id)
             )
-            # Map to standard field names
-            .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature})
-
-            # Subset to necessary fields
-            .select([std.row_index, std.x, std.y, std.feature, std.cell_id,
-                     std.compartment])
+        )
+        rename_map = {raw.x: std.x, raw.y: std.y, raw.feature: std.feature}
+        select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment]
+        if self.include_z:
+            schema_names = _lazyframe_column_names(lf)
+            if raw.z in schema_names:
+                rename_map[raw.z] = std.z
+                select_cols.append(std.z)
+
+        return (
+            lf
+            # Map to standard field names
+            .rename(rename_map)
+            # Subset to necessary fields
+            .select(select_cols)
             # Add numeric index
             .with_row_index()
             .collect()
@@ -372,6 +398,8 @@ class XeniumPreprocessor(ISTPreprocessor):
     """
     Preprocessor for 10x Genomics Xenium datasets.
     """
+    DEFAULT_MIN_QV: float = 20.0
+
     @staticmethod
     def _validate_directory(data_dir: Path):
@@ -397,7 +425,7 @@ def transcripts(self) -> pl.DataFrame:
         raw = XeniumTranscriptFields()
         std = StandardTranscriptFields()
 
-        return (
+        lf = (
             # Read in lazily
             pl.scan_parquet(
                 self.data_dir / raw.filename,
             )
@@ -405,8 +433,12 @@ def transcripts(self) -> pl.DataFrame:
             # Add numeric index at beginning
             .with_row_index(name=std.row_index)
-            # Filter data
-            .filter(pl.col(raw.quality) >= 20)
+        )
+        if self.min_qv is not None and self.min_qv > 0:
+            lf = lf.filter(pl.col(raw.quality) >= self.min_qv)
+
+        lf = (
+            lf
             .filter(pl.col(raw.feature).str.contains(
                 '|'.join(raw.filter_substrings)).not_()
             )
@@ -415,7 +447,7 @@ def transcripts(self) -> pl.DataFrame:
                 pl.when(pl.col(raw.compartment) == raw.nucleus_value)
                 .then(std.nucleus_value)
                 .when(
-                    (pl.col(raw.compartment) != raw.nucleus_value) & 
+                    (pl.col(raw.compartment) != raw.nucleus_value) &
                     (pl.col(raw.cell_id) != raw.null_cell_id)
                 )
                 .then(std.cytoplasmic_value)
@@ -428,12 +460,22 @@ def transcripts(self) -> pl.DataFrame:
                 .replace(raw.null_cell_id, None)
                 .alias(std.cell_id)
             )
+        )
+
+        rename_map = {raw.x: std.x, raw.y: std.y, raw.feature: std.feature}
+        select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment]
+        if self.include_z:
+            schema_names = _lazyframe_column_names(lf)
+            if raw.z in schema_names:
+                rename_map[raw.z] = std.z
+                select_cols.append(std.z)
+
+        return (
+            lf
             # Map to standard field names
-            .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature})
-
-            # Subset to necessary fields
-            .select([std.row_index, std.x, std.y, std.feature, std.cell_id,
-                     std.compartment])
+            .rename(rename_map)
+            # Subset to necessary fields
+            .select(select_cols)
             .collect()
         )
@@ -540,7 +582,9 @@ def _infer_platform(data_dir: Path) -> str:
 
 def get_preprocessor(
     data_dir: Path,
-    platform: str | None = None
+    platform: str | None = None,
+    min_qv: Optional[float] = None,
+    include_z: bool = True,
 ) -> ISTPreprocessor:
     data_dir = Path(data_dir)
     if platform is None:
@@ -551,4 +595,4 @@ def get_preprocessor(
             f"Available: {list(PREPROCESSORS)}"
         )
     cls = PREPROCESSORS[platform.lower()]
-    return cls(data_dir)
+    return cls(data_dir, min_qv=min_qv, include_z=include_z)
diff --git a/src/segger/io/spatialdata_loader.py b/src/segger/io/spatialdata_loader.py
new file mode 100644
index 0000000..895716f
--- /dev/null
+++ b/src/segger/io/spatialdata_loader.py
@@ -0,0 +1,474 @@
+"""Load transcript and boundary data from SpatialData .zarr stores.
+
+This loader normalizes heterogeneous SpatialData point/shape schemas to
+Segger's internal fields so the same downstream data module can run on both
+vendor raw inputs and SpatialData inputs.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Literal, Optional
+
+import geopandas as gpd
+import pandas as pd
+import polars as pl
+
+from segger.io.fields import StandardBoundaryFields, StandardTranscriptFields
+from segger.utils.optional_deps import (
+    SPATIALDATA_IO_AVAILABLE,
+    require_spatialdata,
+    warn_spatialdata_io_unavailable,
+)
+
+
+_COMMON_POINTS_KEYS = [
+    "transcripts",
+    "molecules",
+    "points",
+    "spots",
+    "tx",
+]
+
+_COMMON_CELL_SHAPES_KEYS = [
+    "cells",
+    "cell_boundaries",
+    "cell_shapes",
+    "cell_polygons",
+    "boundaries",
+]
+
+_COMMON_NUCLEUS_SHAPES_KEYS = [
+    "nuclei",
+    "nucleus_boundaries",
+    "nucleus_shapes",
+    "nucleus_polygons",
+    "nuclei_boundaries",
+]
+
+
+def _lazyframe_column_names(lf: pl.LazyFrame) -> list[str]:
+    """Return column names for a LazyFrame across Polars versions."""
+    try:
+        return lf.collect_schema().names()
+    except AttributeError:
+        return lf.columns
+
+
+def _safe_to_geodataframe(data: object) -> gpd.GeoDataFrame:
+    """Best-effort conversion of a SpatialData shapes element to GeoDataFrame."""
+    if isinstance(data, gpd.GeoDataFrame):
+        return data.copy()
+    if hasattr(data, "compute"):
+        data = data.compute()
+    if hasattr(data, "to_geopandas"):
+        return data.to_geopandas().copy()
+    if hasattr(data, "to_pandas"):
+        df = data.to_pandas()
+    elif isinstance(data, pd.DataFrame):
+        df = data
+    else:
+        df = pd.DataFrame(data)
+
+    if "geometry" in df.columns:
+        return gpd.GeoDataFrame(df, geometry="geometry")
+
+    raise ValueError(
+        "Could not convert shapes element to GeoDataFrame: no geometry column found."
+    )
+
+
+def _largest_polygon(geom):
+    """Convert MultiPolygon/GeometryCollection to a single polygon when possible."""
+    if geom is None or geom.is_empty:
+        return geom
+    gtype = geom.geom_type
+    if gtype == "Polygon":
+        return geom
+    if gtype == "MultiPolygon":
+        parts = list(geom.geoms)
+        if not parts:
+            return geom
+        return max(parts, key=lambda p: p.area)
+    if gtype == "GeometryCollection":
+        parts = [g for g in geom.geoms if g.geom_type == "Polygon"]
+        if parts:
+            return max(parts, key=lambda p: p.area)
+    return geom
+
+
+class SpatialDataLoader:
+    """Load and normalize points/shapes from a SpatialData Zarr store."""
+
+    def __init__(
+        self,
+        path: Path | str,
+        points_key: Optional[str] = None,
+        cell_shapes_key: Optional[str] = None,
+        nucleus_shapes_key: Optional[str] = None,
+        coordinate_system: str = "global",
+    ):
+        require_spatialdata()
+        if not SPATIALDATA_IO_AVAILABLE:
+            warn_spatialdata_io_unavailable(
+                "Platform-specific SpatialData readers (Xenium/MERSCOPE/CosMX)"
+            )
+
+        self._path = Path(path)
+        self._points_key = points_key
+        self._cell_shapes_key = cell_shapes_key
+        self._nucleus_shapes_key = nucleus_shapes_key
+        self._coordinate_system = coordinate_system
+        self._sdata = None
+
+        if not self._path.exists():
+            raise FileNotFoundError(f"SpatialData store not found: {self._path}")
+
+    @property
+    def sdata(self):
+        if self._sdata is None:
+            spatialdata = require_spatialdata()
+            if hasattr(spatialdata, "read_zarr"):
+                self._sdata = spatialdata.read_zarr(str(self._path))
+            else:
+                # Fallback for API variants
+                self._sdata = spatialdata.SpatialData.read(str(self._path))
+        return self._sdata
+
+    @property
+    def points_key(self) -> str:
+        if self._points_key is None:
+            self._points_key = self._detect_points_key()
+        return self._points_key
+
+    @property
+    def cell_shapes_key(self) -> Optional[str]:
+        if self._cell_shapes_key is None:
+            self._cell_shapes_key = self._detect_shapes_key(_COMMON_CELL_SHAPES_KEYS)
+        return self._cell_shapes_key
+
+    @property
+    def nucleus_shapes_key(self) -> Optional[str]:
+        if self._nucleus_shapes_key is None:
+            self._nucleus_shapes_key = self._detect_shapes_key(_COMMON_NUCLEUS_SHAPES_KEYS)
+        return self._nucleus_shapes_key
+
+    def _detect_points_key(self) -> str:
+        available = list(self.sdata.points.keys())
+        if not available:
+            raise ValueError(
+                f"No points elements found in SpatialData store: {self._path}"
+            )
+
+        for key in _COMMON_POINTS_KEYS:
+            if key in available:
+                return key
+
+        lowered = {k.lower(): k for k in available}
+        for pattern in ("transcript", "molecule", "spot", "point"):
+            for lk, orig in lowered.items():
+                if pattern in lk:
+                    return orig
+
+        return available[0]
+
+    def _detect_shapes_key(self, preferred: list[str]) -> Optional[str]:
+        available = list(self.sdata.shapes.keys())
+        if not available:
+            return None
+
+        for key in preferred:
+            if key in available:
+                return key
+
+        lowered = {k.lower(): k for k in available}
+        # Fuzzy fallback for newer naming conventions
+        for pattern in ("cell", "nucleus", "nuclei", "boundar", "polygon", "shape"):
+            for lk, orig in lowered.items():
+                if pattern in lk:
+                    return orig
+
+        return available[0]
+
+    @staticmethod
+    def _detect_column(
+        columns: set[str],
+        candidates: list[str],
+        optional: bool = False,
+    ) -> Optional[str]:
+        for candidate in candidates:
+            if candidate in columns:
+                return candidate
+        if optional:
+            return None
+        raise ValueError(
+            f"Could not detect required column. Tried {candidates}. "
+            f"Available columns: {sorted(columns)}"
+        )
+
+    def _points_to_pandas(self, points_obj) -> pd.DataFrame:
+        if hasattr(points_obj, "compute"):
+            points_obj = points_obj.compute()
+        if isinstance(points_obj, pd.DataFrame):
+            df = points_obj.copy()
+        elif hasattr(points_obj, "to_pandas"):
+            df = points_obj.to_pandas()
+        else:
+            df = pd.DataFrame(points_obj)
+
+        # Recover coordinates from geometry when needed
+        if "geometry" in df.columns and ("x" not in df.columns or "y" not in df.columns):
+            geom = df["geometry"]
+            if len(geom) > 0:
+                try:
+                    df = df.copy()
+                    df["x"] = geom.x
+                    df["y"] = geom.y
+                except Exception:
+                    pass
+
+        return df
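Aside (not part of the patch): a minimal construction sketch, assuming the optional spatialdata extra is installed; the store path and element keys below are hypothetical. Passing explicit keys bypasses the detection heuristics above:

    from segger.io.spatialdata_loader import SpatialDataLoader

    loader = SpatialDataLoader(
        "sample.zarr",
        points_key="transcripts",
        cell_shapes_key="cell_boundaries",
        nucleus_shapes_key="nucleus_boundaries",
    )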
+    def transcripts(
+        self,
+        normalize: bool = True,
+        gene_column: Optional[str] = None,
+        quality_column: Optional[str] = None,
+    ) -> pl.LazyFrame:
+        """Load transcripts from SpatialData and normalize to standard fields."""
+        std = StandardTranscriptFields()
+        points_obj = self.sdata.points[self.points_key]
+        df = self._points_to_pandas(points_obj)
+
+        lf = pl.from_pandas(df).lazy().with_row_index(name=std.row_index)
+        if not normalize:
+            return lf
+
+        columns = set(df.columns)
+
+        x_col = self._detect_column(columns, ["x", "x_location", "global_x", "x_global_px"])
+        y_col = self._detect_column(columns, ["y", "y_location", "global_y", "y_global_px"])
+        z_col = self._detect_column(
+            columns,
+            ["z", "z_location", "global_z", "z_global_px"],
+            optional=True,
+        )
+
+        if gene_column is None:
+            gene_column = self._detect_column(
+                columns,
+                ["feature_name", "gene", "target", "gene_name", "feature"],
+            )
+
+        if quality_column is None:
+            quality_column = self._detect_column(
+                columns,
+                ["qv", "quality", "quality_score", "score"],
+                optional=True,
+            )
+
+        cell_id_col = self._detect_column(
+            columns,
+            ["cell_id", "cell", "segger_cell_id", "segmentation_cell_id", "instance_id"],
+            optional=True,
+        )
+
+        compartment_col = self._detect_column(
+            columns,
+            ["cell_compartment", "overlaps_nucleus", "compartment", "CellComp"],
+            optional=True,
+        )
+
+        rename_map = {
+            x_col: std.x,
+            y_col: std.y,
+            gene_column: std.feature,
+        }
+        if z_col:
+            rename_map[z_col] = std.z
+        if cell_id_col:
+            rename_map[cell_id_col] = std.cell_id
+        quality_field = getattr(std, "quality", None)
+        if quality_column and quality_field:
+            rename_map[quality_column] = quality_field
+
+        lf = lf.rename({k: v for k, v in rename_map.items() if k != v})
+
+        if cell_id_col is None:
+            # The compartment expressions and final select below reference the
+            # standard cell-id column, so materialize it as all-null when the
+            # store has no cell assignment at all.
+            lf = lf.with_columns(pl.lit(None, dtype=pl.String).alias(std.cell_id))
+
+        # Normalize/derive compartment labels for segmentation masking.
+        if compartment_col:
+            # Handle common formats: bool overlaps_nucleus, numeric labels, strings.
+            source_col = compartment_col
+            if source_col in rename_map:
+                source_col = rename_map[source_col]
+            try:
+                source_dtype = lf.collect_schema().get(source_col)
+            except AttributeError:
+                # Older Polars without LazyFrame.collect_schema
+                source_dtype = lf.schema.get(source_col)
+
+            if source_dtype == pl.Boolean:
+                lf = lf.with_columns(
+                    pl.when(pl.col(source_col))
+                    .then(std.nucleus_value)
+                    .when(pl.col(std.cell_id).is_not_null())
+                    .then(std.cytoplasmic_value)
+                    .otherwise(std.extracellular_value)
+                    .alias(std.compartment)
+                )
+            else:
+                as_str = pl.col(source_col).cast(pl.Utf8).str.to_lowercase()
+                lf = lf.with_columns(
+                    pl.when(
+                        as_str.is_in(["1", "true", "t", "nucleus", "nuclear"])
+                    )
+                    .then(std.nucleus_value)
+                    .when(
+                        as_str.is_in(["2", "cytoplasm", "cytoplasmic", "membrane"])
+                    )
+                    .then(std.cytoplasmic_value)
+                    .when(pl.col(std.cell_id).is_not_null())
+                    .then(std.cytoplasmic_value)
+                    .otherwise(std.extracellular_value)
+                    .alias(std.compartment)
+                )
+        else:
+            lf = lf.with_columns(
+                pl.when(pl.col(std.cell_id).is_not_null())
+                .then(std.nucleus_value)
+                .otherwise(std.extracellular_value)
+                .alias(std.compartment)
+            )
+
+        select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment]
+        schema_names = _lazyframe_column_names(lf)
+
+        if z_col and std.z in schema_names:
+            select_cols.append(std.z)
+
+        quality_field = getattr(std, "quality", None)
+        if quality_field and quality_field in schema_names:
+            select_cols.append(quality_field)
+
+        return lf.select(select_cols)
+    def _normalize_boundary_ids(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
+        std = StandardBoundaryFields()
+        if std.id in gdf.columns:
+            return gdf
+
+        for candidate in (
+            "cell_id",
+            "cell",
+            "instance_id",
+            "segger_cell_id",
+            "id",
+            "label",
+            "EntityID",
+        ):
+            if candidate in gdf.columns:
+                gdf = gdf.copy()
+                gdf[std.id] = gdf[candidate]
+                return gdf
+
+        gdf = gdf.reset_index(drop=False)
+        index_col = gdf.columns[0]
+        gdf[std.id] = gdf[index_col]
+        return gdf
+
+    def _prepare_shapes(
+        self,
+        shape_key: str,
+        boundary_label: str,
+    ) -> gpd.GeoDataFrame:
+        std = StandardBoundaryFields()
+        raw = self.sdata.shapes[shape_key]
+        gdf = _safe_to_geodataframe(raw)
+        gdf = self._normalize_boundary_ids(gdf)
+
+        gdf = gdf[gdf.geometry.notnull()].copy()
+        if not gdf.empty:
+            try:
+                gdf["geometry"] = gdf.geometry.buffer(0)
+            except Exception:
+                pass
+        gdf["geometry"] = gdf.geometry.apply(_largest_polygon)
+        gdf = gdf[gdf.geometry.notnull()].copy()
+        gdf = gdf[~gdf.geometry.is_empty].copy()
+
+        gdf[std.boundary_type] = boundary_label
+        return gdf
+
+    def boundaries(
+        self,
+        boundary_type: Literal["cell", "nucleus", "all"] = "all",
+    ) -> Optional[gpd.GeoDataFrame]:
+        """Load boundaries from SpatialData and normalize to standard fields."""
+        std = StandardBoundaryFields()
+
+        parts: list[gpd.GeoDataFrame] = []
+        if boundary_type in {"cell", "all"} and self.cell_shapes_key is not None:
+            parts.append(self._prepare_shapes(self.cell_shapes_key, std.cell_value))
+
+        if boundary_type in {"nucleus", "all"} and self.nucleus_shapes_key is not None:
+            parts.append(self._prepare_shapes(self.nucleus_shapes_key, std.nucleus_value))
+
+        if not parts:
+            # Fallback: if no specific key detected but shapes exist, use first key as cell shapes.
+            available = list(self.sdata.shapes.keys())
+            if not available:
+                return None
+            parts.append(self._prepare_shapes(available[0], std.cell_value))
+
+        result = gpd.GeoDataFrame(
+            pd.concat(parts, ignore_index=True),
+            geometry="geometry",
+            crs=parts[0].crs if parts and hasattr(parts[0], "crs") else None,
+        )
+
+        # Compute contains_nucleus when possible
+        if std.contains_nucleus not in result.columns:
+            if std.boundary_type in result.columns:
+                nucleus_ids = set(
+                    result.loc[
+                        result[std.boundary_type] == std.nucleus_value,
+                        std.id,
+                    ].astype(str)
+                )
+                result[std.contains_nucleus] = result[std.id].astype(str).isin(nucleus_ids)
+                result.loc[
+                    result[std.boundary_type] == std.nucleus_value,
+                    std.contains_nucleus,
+                ] = True
+            else:
+                result[std.contains_nucleus] = False
+
+        return result
+
+
+def load_from_spatialdata(
+    path: Path | str,
+    points_key: Optional[str] = None,
+    cell_shapes_key: Optional[str] = None,
+    nucleus_shapes_key: Optional[str] = None,
+    boundary_type: Literal["cell", "nucleus", "all"] = "all",
+    normalize: bool = True,
+) -> tuple[pl.LazyFrame, Optional[gpd.GeoDataFrame]]:
+    """Convenience loader for SpatialData .zarr stores."""
+    loader = SpatialDataLoader(
+        path=path,
+        points_key=points_key,
+        cell_shapes_key=cell_shapes_key,
+        nucleus_shapes_key=nucleus_shapes_key,
+    )
+    tx = loader.transcripts(normalize=normalize)
+    bd = loader.boundaries(boundary_type=boundary_type)
+    return tx, bd
+
+
+def is_spatialdata_path(path: Path | str) -> bool:
+    """Check whether a path looks like a SpatialData zarr store."""
+    p = Path(path)
+    return (
+        p.suffix == ".zarr"
+        or (p / ".zgroup").exists()
+        or (p / "zarr.json").exists()
+        or (p / "points").exists()
+        or (p / "shapes").exists()
+        or (p / "tables").exists()
+    )
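End-to-end sketch combining the two helpers above (not part of the patch; the path is hypothetical):

    from segger.io.spatialdata_loader import is_spatialdata_path, load_from_spatialdata

    path = "sample.zarr"
    if is_spatialdata_path(path):
        tx_lf, bd = load_from_spatialdata(path, boundary_type="all")
        tx = tx_lf.collect()  # polars DataFrame; bd is a GeoDataFrame or None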
diff --git a/src/segger/utils/__init__.py b/src/segger/utils/__init__.py
new file mode 100644
index 0000000..78a5eb6
--- /dev/null
+++ b/src/segger/utils/__init__.py
@@ -0,0 +1,61 @@
+"""Utility modules for Segger."""
+import logging
+import sys
+
+
+def setup_logging(level: str = "WARNING", log_file: str | None = None):
+    """Configure root logging with a stdout handler and an optional file handler."""
+    fmt = "%(asctime)s | %(levelname)-8s | %(name)s:%(lineno)d - %(message)s"
+    datefmt = "%Y-%m-%d %H:%M:%S"
+
+    handlers = [logging.StreamHandler(sys.stdout)]
+    if log_file:
+        handlers.append(logging.FileHandler(log_file))
+
+    logging.basicConfig(
+        level=getattr(logging, level.upper()),
+        format=fmt,
+        datefmt=datefmt,
+        handlers=handlers,
+        force=True,  # override any previously set handlers
+    )
+
+
+from segger.utils.optional_deps import (
+    # Availability flags
+    SPATIALDATA_AVAILABLE,
+    SPATIALDATA_IO_AVAILABLE,
+    # Import functions (raise ImportError if missing)
+    require_spatialdata,
+    require_spatialdata_io,
+    # Decorators for functions requiring optional deps
+    requires_spatialdata,
+    requires_spatialdata_io,
+    # Warning functions for soft failures
+    warn_spatialdata_unavailable,
+    warn_spatialdata_io_unavailable,
+    warn_rapids_unavailable,
+    # RAPIDS helpers
+    require_rapids,
+    # Version utilities
+    get_spatialdata_version,
+    check_spatialdata_version,
+)
+
+__all__ = [
+    "setup_logging",
+    # Availability flags
+    "SPATIALDATA_AVAILABLE",
+    "SPATIALDATA_IO_AVAILABLE",
+    # Import functions
+    "require_spatialdata",
+    "require_spatialdata_io",
+    # Decorators
+    "requires_spatialdata",
+    "requires_spatialdata_io",
+    # Warning functions
+    "warn_spatialdata_unavailable",
+    "warn_spatialdata_io_unavailable",
+    "warn_rapids_unavailable",
+    "require_rapids",
+    # Version utilities
+    "get_spatialdata_version",
+    "check_spatialdata_version",
+]
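Typical call (the log file name is hypothetical):

    from segger.utils import setup_logging

    setup_logging(level="INFO", log_file="segger.log")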
diff --git a/src/segger/utils/optional_deps.py b/src/segger/utils/optional_deps.py
new file mode 100644
index 0000000..72b9553
--- /dev/null
+++ b/src/segger/utils/optional_deps.py
@@ -0,0 +1,309 @@
+"""Optional dependency handling with informative warnings.
+
+This module provides lazy import wrappers for optional dependencies
+(spatialdata, spatialdata-io) with clear installation instructions
+when the dependencies are not available.
+
+Usage
+-----
+Check availability:
+    >>> from segger.utils.optional_deps import SPATIALDATA_AVAILABLE
+    >>> if SPATIALDATA_AVAILABLE:
+    ...     import spatialdata
+
+Require and get import (raises ImportError with instructions if missing):
+    >>> from segger.utils.optional_deps import require_spatialdata
+    >>> spatialdata = require_spatialdata()
+
+Decorator for functions requiring optional deps:
+    >>> from segger.utils.optional_deps import requires_spatialdata
+    >>> @requires_spatialdata
+    ... def my_function():
+    ...     import spatialdata
+    ...     return spatialdata.SpatialData()
+"""
+
+from __future__ import annotations
+
+import functools
+import importlib
+import importlib.util
+import warnings
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
+
+if TYPE_CHECKING:
+    import types
+
+# Type variable for decorator
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+# -----------------------------------------------------------------------------
+# Availability flags
+# -----------------------------------------------------------------------------
+
+def _check_spatialdata() -> bool:
+    """Check if spatialdata is available."""
+    try:
+        return importlib.util.find_spec("spatialdata") is not None
+    except Exception:
+        return False
+
+
+def _check_spatialdata_io() -> bool:
+    """Check if spatialdata-io is available."""
+    try:
+        return importlib.util.find_spec("spatialdata_io") is not None
+    except Exception:
+        return False
+
+
+# Availability flags (evaluated once at import time)
+SPATIALDATA_AVAILABLE: bool = _check_spatialdata()
+SPATIALDATA_IO_AVAILABLE: bool = _check_spatialdata_io()
+
+
+# -----------------------------------------------------------------------------
+# Installation instructions
+# -----------------------------------------------------------------------------
+
+SPATIALDATA_INSTALL_MSG = """
+spatialdata is not installed. This package is required for SpatialData I/O support.
+
+To install spatialdata support:
+    pip install segger[spatialdata]
+
+Or install spatialdata directly:
+    pip install spatialdata>=0.7.2
+"""
+
+SPATIALDATA_IO_INSTALL_MSG = """
+spatialdata-io is not installed. This package is required for reading platform-specific
+SpatialData formats (Xenium, MERSCOPE, CosMX).
+
+To install spatialdata-io support:
+    pip install segger[spatialdata-io]
+
+For full SpatialData support:
+    pip install segger[spatialdata]
+
+Or install spatialdata-io directly:
+    pip install spatialdata-io>=0.6.0
+"""
+
+
+# -----------------------------------------------------------------------------
+# Import functions with error messages
+# -----------------------------------------------------------------------------
+
+def require_spatialdata() -> "types.ModuleType":
+    """Import and return spatialdata, raising ImportError if not available.
+
+    Returns
+    -------
+    types.ModuleType
+        The spatialdata module.
+
+    Raises
+    ------
+    ImportError
+        If spatialdata is not installed, with installation instructions.
+    """
+    if not SPATIALDATA_AVAILABLE:
+        raise ImportError(SPATIALDATA_INSTALL_MSG)
+    import spatialdata
+    return spatialdata
+
+
+def require_spatialdata_io() -> "types.ModuleType":
+    """Import and return spatialdata_io, raising ImportError if not available.
+
+    Returns
+    -------
+    types.ModuleType
+        The spatialdata_io module.
+
+    Raises
+    ------
+    ImportError
+        If spatialdata-io is not installed, with installation instructions.
+    """
+    if not SPATIALDATA_IO_AVAILABLE:
+        raise ImportError(SPATIALDATA_IO_INSTALL_MSG)
+    import spatialdata_io
+    return spatialdata_io
+
+
+# -----------------------------------------------------------------------------
+# Decorators for requiring optional dependencies
+# -----------------------------------------------------------------------------
+
+def requires_spatialdata(func: F) -> F:
+    """Decorator that raises ImportError if spatialdata is not available.
+
+    Parameters
+    ----------
+    func
+        Function that requires spatialdata.
+
+    Returns
+    -------
+    F
+        Wrapped function that checks for spatialdata before execution.
+
+    Examples
+    --------
+    >>> @requires_spatialdata
+    ... def load_from_zarr(path):
+    ...     import spatialdata
+    ...     return spatialdata.read_zarr(path)
+    """
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        require_spatialdata()
+        return func(*args, **kwargs)
+    return wrapper  # type: ignore[return-value]
+
+
+def requires_spatialdata_io(func: F) -> F:
+    """Decorator that raises ImportError if spatialdata-io is not available.
+
+    Parameters
+    ----------
+    func
+        Function that requires spatialdata-io.
+
+    Returns
+    -------
+    F
+        Wrapped function that checks for spatialdata-io before execution.
+    """
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        require_spatialdata_io()
+        return func(*args, **kwargs)
+    return wrapper  # type: ignore[return-value]
+
+
+# -----------------------------------------------------------------------------
+# Warning functions for soft failures
+# -----------------------------------------------------------------------------
+
+def warn_spatialdata_unavailable(feature: str = "SpatialData support") -> None:
+    """Emit a warning that spatialdata is not available.
+
+    Parameters
+    ----------
+    feature
+        Description of the feature requiring spatialdata.
+    """
+    warnings.warn(
+        f"{feature} requires spatialdata. "
+        "Install with: pip install segger[spatialdata]",
+        UserWarning,
+        stacklevel=2,
+    )
+
+
+def warn_spatialdata_io_unavailable(feature: str = "Platform-specific SpatialData readers") -> None:
+    """Emit a warning that spatialdata-io is not available.
+
+    Parameters
+    ----------
+    feature
+        Description of the feature requiring spatialdata-io.
+    """
+    warnings.warn(
+        f"{feature} requires spatialdata-io. "
+        "Install with: pip install segger[spatialdata-io]",
+        UserWarning,
+        stacklevel=2,
+    )
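Aside (not part of the patch): a sketch of the intended soft-failure pattern; the surrounding function is illustrative, not from this diff:

    from segger.utils.optional_deps import SPATIALDATA_AVAILABLE, warn_spatialdata_unavailable

    def export_spatialdata(result) -> None:
        if not SPATIALDATA_AVAILABLE:
            warn_spatialdata_unavailable("SpatialData export")
            return  # skip the optional output instead of raising
        ...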
+
+
+# -----------------------------------------------------------------------------
+# RAPIDS helpers
+# -----------------------------------------------------------------------------
+
+# Referenced by require_rapids/warn_rapids_unavailable below; without this
+# constant both functions raise NameError at call time. The exact wording is a
+# placeholder.
+RAPIDS_INSTALL_MSG = """
+RAPIDS GPU packages are not installed.
+See https://docs.rapids.ai/install for the installer matching your CUDA version.
+"""
+
+
+def _import_optional_packages(packages: list[str]) -> tuple[dict[str, "types.ModuleType"], list[str]]:
+    """Import optional packages and return (modules, missing)."""
+    modules: dict[str, "types.ModuleType"] = {}
+    missing: list[str] = []
+    for package in packages:
+        try:
+            modules[package] = importlib.import_module(package)
+        except Exception:
+            missing.append(package)
+    return modules, missing
+
+
+def require_rapids(
+    packages: list[str] | None = None,
+    feature: str = "Segger",
+) -> dict[str, "types.ModuleType"]:
+    """Import RAPIDS-related packages or raise with installation instructions."""
+    package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"]
+    modules, missing = _import_optional_packages(package_list)
+    if missing:
+        missing_list = ", ".join(missing)
+        raise ImportError(
+            f"{feature} requires RAPIDS GPU packages: {missing_list}. "
+            + RAPIDS_INSTALL_MSG.strip()
+        )
+    return modules
+
+
+def warn_rapids_unavailable(
+    feature: str = "Segger",
+    packages: list[str] | None = None,
+) -> bool:
+    """Warn if RAPIDS-related packages are unavailable. Returns True if present."""
+    package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"]
+    _, missing = _import_optional_packages(package_list)
+    if not missing:
+        return True
+    missing_list = ", ".join(missing)
+    warnings.warn(
+        f"{feature} requires RAPIDS GPU packages ({missing_list}). "
+        + RAPIDS_INSTALL_MSG.strip(),
+        UserWarning,
+        stacklevel=2,
+    )
+    return False
+
+
+# -----------------------------------------------------------------------------
+# Version checking
+# -----------------------------------------------------------------------------
+
+def get_spatialdata_version() -> str | None:
+    """Get the installed spatialdata version, or None if not installed."""
+    if not SPATIALDATA_AVAILABLE:
+        return None
+    try:
+        import spatialdata
+        return getattr(spatialdata, "__version__", "unknown")
+    except Exception:
+        return None
+
+
+def check_spatialdata_version(min_version: str = "0.7.2") -> bool:
+    """Check if spatialdata version meets minimum requirement.
+
+    Parameters
+    ----------
+    min_version
+        Minimum required version string.
+
+    Returns
+    -------
+    bool
+        True if version is sufficient, False otherwise.
+    """
+    version = get_spatialdata_version()
+    if version is None or version == "unknown":
+        return False
+
+    try:
+        from packaging.version import Version
+        return Version(version) >= Version(min_version)
+    except ImportError:
+        # Fallback to simple string comparison
+        return version >= min_version