Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3217285
Add fields for 3d and qv filtering
enric-bazz May 5, 2026
ebf5105
Adapt for 3d and qv filtering, add column remapping for robustness
enric-bazz May 5, 2026
f99b3fb
Add arguments for 3d, qv filtering; pass spatialdata save flag to wit…
enric-bazz May 5, 2026
6e9a264
Support 3d input and qv filtering, initialize spatialdata loader, oth…
enric-bazz May 5, 2026
e0424e8
Add spatialdata loader module, expose it, add all optional dependenci…
enric-bazz May 5, 2026
e4e8846
Add all export modules from v2-incremental
enric-bazz May 5, 2026
27e6816
Update with optional dependencies
enric-bazz May 5, 2026
2049276
Restore debugging API
enric-bazz May 5, 2026
0bcff40
Adjust parameters registering, align spatialdata writers arguments
enric-bazz May 5, 2026
344365e
Return separate shape elements on input boundaries
enric-bazz May 5, 2026
8948d7a
Fix table parsing on input boundaries and improve code behavior
enric-bazz May 5, 2026
960fbf2
Fix regex construction for pattern filtering
enric-bazz May 5, 2026
93d29b3
Improve code robustness on dataframe joins
enric-bazz May 6, 2026
bdec108
Clean code for convex hull and delaunay boundaries on cells (no fragm…
enric-bazz May 6, 2026
c26dd83
Clean export module to minimal APIs for spatialdata writing
enric-bazz May 6, 2026
413c217
Remove scripts dir from v2-incremental branch
enric-bazz May 6, 2026
59d91fb
Remove sopa support from modules and dependencies
enric-bazz May 6, 2026
733c24b
Remove additional modules from v2-incremental
enric-bazz May 6, 2026
149af1d
Lower bound PCA dimensionality to number of genes
enric-bazz May 6, 2026
29e8028
Move required functions within spatialdata loader module
enric-bazz May 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,38 @@ dependencies = [
"opencv-python",
"pandas",
"polars",
"pqdm",
"pyarrow",
"rtree",
"scanpy",
"scipy",
"shapely",
"scikit-image",
"scikit-learn",
"tifffile",
"torch_geometric",
"zarr",
]

[project.optional-dependencies]
spatialdata = [
"spatialdata>=0.7.2",
"spatialdata-io>=0.6.0",
]

spatialdata-io = [
"spatialdata-io>=0.6.0",
]

spatialdata-all = [
"spatialdata>=0.7.2",
"spatialdata-io>=0.6.0",
"sopa>=2.0.0",
]

plot = [
"matplotlib>=3.7",
"uniplot>=0.10.0",
]

[build-system]
Expand All @@ -39,4 +63,4 @@ build-backend = "hatchling.build"
packages = ["src/segger"]

[project.scripts]
segger = "segger.cli.main:app"
segger = "segger.cli.main:app"
57 changes: 57 additions & 0 deletions src/segger/cli/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@
help="Related to loss function parameters.",
sort_key=7,
)
group_quality = Group(
name="Quality Filtering",
help="Related to transcript quality filtering.",
sort_key=8,
)
group_3d = Group(
name="3D Support",
help="Related to 3D coordinate handling.",
sort_key=9,
)

def _resolve_use_3d_flag(use_3d: Literal["auto", "true", "false"]) -> bool | str:
if use_3d == "auto":
return "auto"
return use_3d == "true"

app_segment = App(name="segment", help="Run cell segmentation on spatial transcriptomics data.")

Expand Down Expand Up @@ -293,16 +308,53 @@ def segment(
"save_anndata",
group=group_io,
)] = registry.get_default("save_anndata"),

save_spatialdata: Annotated[bool, registry.get_parameter(
"save_spatialdata",
group=group_io, # might change
)] = registry.get_default("save_spatialdata"),

boundary_method: Annotated[
Literal["convex_hull", "delaunay", "skip"],
registry.get_parameter(
"boundary_method",
group=group_io, # might change
)] = registry.get_default("boundary_method"),

debug: Annotated[bool, Parameter(
help="Whether to save additional debug information (trainer, predictions).",
)] = "none",

# Quality filtering
min_qv: Annotated[float | None, Parameter(
help="Minimum transcript quality threshold. Set to 0 to disable.",
validator=validators.Number(gte=0),
group=group_quality,
)] = 20.0,

# 3D support
use_3d: Annotated[
Literal["auto", "true", "false"],
Parameter(
help="Use 3D coordinates for graph construction ('false' default).",
group=group_3d,
),
] = "false",
):
"""Run cell segmentation on spatial transcriptomics data."""

# Setup logger and debug directory
logger = logging.getLogger(__name__)

use_3d_value = _resolve_use_3d_flag(use_3d)

output_directory = Path(output_directory)
if output_directory.exists() and not output_directory.is_dir():
raise ValueError(
f"Output path exists and is not a directory: {output_directory}"
)
output_directory.mkdir(parents=True, exist_ok=True)

# Remove SLURM environment autodetect
from lightning.pytorch.plugins.environments import SLURMEnvironment
SLURMEnvironment.detect = lambda: False
Expand All @@ -328,6 +380,8 @@ def segment(
tiling_margin_prediction=tiling_margin_prediction,
tiling_nodes_per_tile=max_nodes_per_tile,
edges_per_batch=max_edges_per_batch,
use_3d=use_3d_value,
min_qv=min_qv,
)

# Setup Lightning Model
Expand Down Expand Up @@ -364,8 +418,11 @@ def segment(

csvlogger = CSVLogger(output_directory)
writer = ISTSegmentationWriter(
input_directory,
output_directory,
save_anndata=save_anndata,
save_spatialdata=save_spatialdata,
boundary_method=boundary_method,
debug=debug,
)
trainer = Trainer(
Expand Down
69 changes: 62 additions & 7 deletions src/segger/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from lightning.pytorch import LightningDataModule
from torchvision.transforms import Compose
from dataclasses import dataclass
from typing import Literal
from typing import Literal, Optional
from pathlib import Path
import polars as pl
import torch
import gc
import os
import numpy as np

from .tile_dataset import (
Expand Down Expand Up @@ -143,6 +144,8 @@ class ISTDataModule(LightningDataModule):
prediction_graph_mode: Literal["nucleus", "cell", "uniform"] = "cell"
prediction_graph_max_k: int = 3
prediction_graph_buffer_ratio: float = 0.05
use_3d: bool | Literal["auto"] = False
min_qv: Optional[float] = 20.0
tiling_mode: Literal["adaptive", "square"] = "adaptive" # TODO: Remove (benchmarking only)
tiling_margin_training: float = 20.
tiling_margin_prediction: float = 20.
Expand All @@ -166,11 +169,54 @@ def load(self):
tx_fields = StandardTranscriptFields()
bd_fields = StandardBoundaryFields()

# Load standardized IST data
self.logger.debug(f"Loading standardized IST data from {self.input_directory}...")
pp = get_preprocessor(self.input_directory)
tx = self.tx = pp.transcripts
bd = self.bd = pp.boundaries
# Load standardized IST data (raw platform directory or SpatialData .zarr)
input_path = Path(self.input_directory)
tx = None
bd = None

try:
from ..io.spatialdata_loader import (
is_spatialdata_path,
load_from_spatialdata,
)
has_spatialdata_loader = True
except Exception:
has_spatialdata_loader = False

if has_spatialdata_loader and is_spatialdata_path(input_path):
tx_lf, bd = load_from_spatialdata(
input_path,
boundary_type="all",
normalize=True,
)
tx = tx_lf.collect() if isinstance(tx_lf, pl.LazyFrame) else tx_lf

# Keep behavior consistent with raw Xenium filtering when quality exists.
quality_col = getattr(tx_fields, "quality", "qv")
if (
self.min_qv is not None
and self.min_qv > 0
and quality_col in tx.columns
):
tx = tx.filter(pl.col(quality_col) >= self.min_qv)
else:
pp = get_preprocessor(
self.input_directory,
min_qv=self.min_qv,
include_z=(self.use_3d is not False),
)
tx = pp.transcripts
bd = pp.boundaries

self.tx = tx
self.bd = bd

if bd is None or len(bd) == 0:
raise ValueError(
"No boundary shapes found in input data. "
"Segger requires cell/nucleus polygons in raw input or SpatialData shapes."
)

# Mask transcripts to reference segmentation
if self.segmentation_graph_mode == "nucleus":
Expand All @@ -187,8 +233,16 @@ def load(self):
f"Unrecognized segmentation graph mode: "
f"'{self.segmentation_graph_mode}'."
)
tx_mask = pl.col(tx_fields.compartment).is_in(compartments)
bd_mask = bd[bd_fields.boundary_type] == boundary_type

if tx_fields.compartment in tx.columns:
tx_mask = pl.col(tx_fields.compartment).is_in(compartments)
else:
tx_mask = pl.col(tx_fields.cell_id).is_not_null()

if bd_fields.boundary_type in bd.columns:
bd_mask = bd[bd_fields.boundary_type] == boundary_type
else:
bd_mask = np.ones(len(bd), dtype=bool)

# Generate reference AnnData
self.logger.debug("Generating reference AnnData object...")
Expand Down Expand Up @@ -222,6 +276,7 @@ def load(self):
prediction_graph_mode=self.prediction_graph_mode,
prediction_graph_max_k=self.prediction_graph_max_k,
prediction_graph_buffer_ratio=self.prediction_graph_buffer_ratio,
use_3d=self.use_3d,
)

# Tile graph dataset
Expand Down
7 changes: 5 additions & 2 deletions src/segger/data/utils/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def setup_anndata(
ad.obs
.join(
(
boundaries
boundaries.drop_duplicates(subset=bd_fields.id) # some data oddly has duplicate boundary entries on the same cell id
.reset_index(names=bd_fields.index)
.set_index(bd_fields.id, verify_integrity=True)
.get(bd_fields.index)
Expand Down Expand Up @@ -195,7 +195,10 @@ def setup_anndata(
# Build gene embedding on filtered dataset
C = np.corrcoef(ad[ad.obs['filtered']].layers['norm'].todense().T)
C = np.nan_to_num(C, 0, posinf=True, neginf=True)
model = sklearn.decomposition.PCA(n_components=cells_embedding_size)
model = sklearn.decomposition.PCA(n_components=min(cells_embedding_size, ad.var.shape[0]))
if ad.var.shape[0] < cells_embedding_size:
import warnings
warnings.warn('cell embedding size is larger than input feature space, falling back to that size.')
ad.varm['X_corr'] = model.fit_transform(C)

# Build PCs on filtered cells and project all cells
Expand Down
17 changes: 15 additions & 2 deletions src/segger/data/utils/heterodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def setup_heterodata(
prediction_graph_mode: Literal["nucleus", "cell", "uniform"],
prediction_graph_max_k: int,
prediction_graph_buffer_ratio: float,
use_3d: bool | Literal["auto"] = False,
cells_embedding_key: str = 'X_pca',
cells_clusters_column: str = 'phenograph_cluster',
cells_encoding_column: str = 'cell_encoding',
Expand All @@ -44,6 +45,11 @@ def setup_heterodata(
tx_fields.cell_cluster,
tx_fields.gene_cluster,
]

transcripts = transcripts.with_columns(
pl.col(tx_fields.feature).cast(pl.Utf8)
)

# Update transcripts with fields for training

transcripts = (
Expand All @@ -55,9 +61,14 @@ def setup_heterodata(
pl.from_pandas(
adata.var[[genes_encoding_column, genes_clusters_column]],
include_index=True
),
).rename({
pl.from_pandas(
adata.var[[genes_encoding_column, genes_clusters_column]],
include_index=True
).columns[0]: tx_fields.feature
}),
left_on=tx_fields.feature,
right_on=adata.var.index.name if adata.var.index.name else 'None',
right_on=tx_fields.feature,
)
.rename(
{
Expand Down Expand Up @@ -135,6 +146,7 @@ def setup_heterodata(
transcripts,
max_k=transcripts_graph_max_k,
max_dist=transcripts_graph_max_dist,
use_3d=use_3d,
)

# Reference segmentation graph
Expand All @@ -150,6 +162,7 @@ def setup_heterodata(
max_k=prediction_graph_max_k,
buffer_ratio=prediction_graph_buffer_ratio,
mode=prediction_graph_mode,
use_3d=use_3d if prediction_graph_mode == "uniform" else False,
)

return data
36 changes: 33 additions & 3 deletions src/segger/data/utils/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def edge_index_to_knn(
def kdtree_neighbors(
points: np.ndarray,
max_k: int,
max_dist: float,
max_dist: float = np.inf,
query: np.ndarray | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Wrapper for KDTree kNN and conversion to edge_index COO format.
Expand All @@ -148,11 +148,25 @@ def setup_transcripts_graph(
tx: pl.DataFrame,
max_k: int,
max_dist: float,
use_3d: bool | Literal["auto"] = False,
) -> torch.Tensor:
"""TODO: Add description.
"""
tx_fields = TrainingTranscriptFields()
points = tx[[tx_fields.x, tx_fields.y]].to_numpy()
coord_cols = [tx_fields.x, tx_fields.y]
has_z = tx_fields.z in tx.columns

if use_3d == "auto":
use_3d = has_z and tx[tx_fields.z].null_count() < len(tx)
elif use_3d is True and not has_z:
raise ValueError(
f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. "
f"Available columns: {tx.columns}"
)
if use_3d and has_z:
coord_cols.append(tx_fields.z)

points = tx[coord_cols].to_numpy()
edge_index, _ = kdtree_neighbors(
points=points,
max_k=max_k,
Expand Down Expand Up @@ -184,6 +198,7 @@ def setup_prediction_graph(
max_k: int,
buffer_ratio: float,
mode: Literal['nucleus', 'cell', 'uniform'] = 'cell',
use_3d: bool | Literal["auto"] = False,
) -> torch.Tensor:
"""TODO: Add description.
"""
Expand All @@ -192,12 +207,27 @@ def setup_prediction_graph(

# Uniform kNN graph
if mode == "uniform":
points = tx[[tx_fields.x, tx_fields.y]].to_numpy()
coord_cols = [tx_fields.x, tx_fields.y]
has_z = tx_fields.z in tx.columns
if use_3d == "auto":
use_3d = has_z and tx[tx_fields.z].null_count() < len(tx)
elif use_3d is True and not has_z:
raise ValueError(
f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. "
f"Available columns: {tx.columns}"
)
if use_3d and has_z:
coord_cols.append(tx_fields.z)

points = tx[coord_cols].to_numpy()
query = bd.geometry.centroid.get_coordinates().values
if use_3d and len(coord_cols) == 3:
query = np.hstack([query, np.zeros((len(query), 1))])
edge_index, _ = kdtree_neighbors(
points=points,
query=query,
max_k=max_k,
max_dist=np.inf,
)
return edge_index

Expand Down
Loading