diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py index dff7e05..8b95e03 100644 --- a/src/segger/cli/segment.py +++ b/src/segger/cli/segment.py @@ -154,7 +154,7 @@ def segment( group=group_prediction, )] = registry.get_default("prediction_graph_max_k"), - prediction_expansion_ratio: Annotated[float | None, registry.get_parameter( + prediction_graph_buffer_ratio: Annotated[float | None, registry.get_parameter( "prediction_graph_buffer_ratio", validator=validators.Number(gt=0), group=group_prediction, @@ -323,7 +323,7 @@ def segment( transcripts_graph_max_dist=transcripts_max_dist, prediction_graph_mode=prediction_mode, prediction_graph_max_k=prediction_max_k, - prediction_graph_buffer_ratio=prediction_expansion_ratio, + prediction_graph_buffer_ratio=prediction_graph_buffer_ratio, tiling_margin_training=tiling_margin_training, tiling_margin_prediction=tiling_margin_prediction, tiling_nodes_per_tile=max_nodes_per_tile, diff --git a/src/segger/data/tiling.py b/src/segger/data/tiling.py index e4e3d09..59f5e16 100644 --- a/src/segger/data/tiling.py +++ b/src/segger/data/tiling.py @@ -199,7 +199,7 @@ def __init__( ): # Calculate QuadTree on points and set as tiles points = points_to_geoseries(positions, backend='cuspatial') - _, quadtree = get_quadtree_index( + _, quadtree, _ = get_quadtree_index( points, max_tile_size, with_bounds=True, diff --git a/src/segger/geometry/quadtree.py b/src/segger/geometry/quadtree.py index c8f8ee4..a2d0114 100644 --- a/src/segger/geometry/quadtree.py +++ b/src/segger/geometry/quadtree.py @@ -25,13 +25,14 @@ def get_quadtree_kwargs( A dictionary of keyword arguments including x_min, x_max, y_min, y_max, scale, and max_depth. """ - # Calculate bounds - x_min = float(points.points.x.min()) - x_max = float(points.points.x.max()) - y_min = float(points.points.y.min()) - y_max = float(points.points.y.max()) - - # Get hyperparams for quadtree + # Calculate bounds | Optimisation: Use interleaved view, without copying data + xy = cp.asarray(points.points.xy).reshape(-1, 2) # zero-copy view + x_min = float(xy[:, 0].min()) + x_max = float(xy[:, 0].max()) + y_min = float(xy[:, 1].min()) + y_max = float(xy[:, 1].max()) + + # Get hyperparams for quadtree extent = max(x_max - x_min, y_max - y_min) max_depth = 1 while extent // (1 << max_depth) > 0: @@ -140,7 +141,7 @@ def get_quadtree_index( points: cuspatial.GeoSeries, max_size: int, with_bounds: bool = True, -) -> tuple[cudf.Series, cudf.DataFrame]: +) -> tuple[cudf.Series, cudf.DataFrame, dict]: """Build a cuSpatial quadtree from 2D point data. Parameters @@ -190,7 +191,7 @@ def get_quadtree_index( y_max=y_max, ) - return indices, quadtree + return indices, quadtree, kwargs def quadtree_to_geoseries( diff --git a/src/segger/geometry/query.py b/src/segger/geometry/query.py index 022803c..31fc46a 100644 --- a/src/segger/geometry/query.py +++ b/src/segger/geometry/query.py @@ -3,6 +3,7 @@ import numpy as np import cuspatial import cudf +import cupy as cp from .conversion import ( polygons_to_geoseries, @@ -47,14 +48,15 @@ def _points_in_polygons_contains( mapping each contained point to its containing polygon. """ # Setup inputs for spatial join + cp.get_default_memory_pool().free_all_blocks() + cp.get_default_pinned_memory_pool().free_all_blocks() if max_size is None: max_size = 10000 if len(points) > 5e7 else 1000 # heuristic - point_indices, quadtree = get_quadtree_index( + point_indices, quadtree, kwargs = get_quadtree_index( points, max_size, with_bounds=False ) - kwargs = get_quadtree_kwargs(points) # Perform spatial join in batches batch_idx = np.linspace(0, len(polygons), (batches or 1) + 1, dtype=int) diff --git a/src/segger/io/fields.py b/src/segger/io/fields.py index 40bd6be..8dc30a9 100644 --- a/src/segger/io/fields.py +++ b/src/segger/io/fields.py @@ -23,6 +23,26 @@ class XeniumTranscriptFields: 'UnassignedCodeword_*', ] +@dataclass +class XeniumTranscriptFieldsV1: + filename: str = 'transcripts.parquet' + x: str = 'x_location' + y: str = 'y_location' + feature: str = 'feature_name' + cell_id: str = 'cell_id' + null_cell_id: str = "-1" + compartment: str = 'overlaps_nucleus' + nucleus_value: int = 1 + quality: str = 'qv' + filter_substrings = [ + 'NegControlProbe_*', + 'antisense_*', + 'NegControlCodeword*', + 'BLANK_*', + 'DeprecatedCodeword_*', + 'UnassignedCodeword_*', + ] + @dataclass class XeniumBoundaryFields: cell_filename: str = 'cell_boundaries.parquet' diff --git a/src/segger/io/preprocessor.py b/src/segger/io/preprocessor.py index 597a818..8320230 100644 --- a/src/segger/io/preprocessor.py +++ b/src/segger/io/preprocessor.py @@ -7,6 +7,7 @@ import geopandas as gpd import polars as pl import pandas as pd +import json import warnings import logging import sys @@ -21,7 +22,8 @@ MerscopeBoundaryFields, StandardTranscriptFields, StandardBoundaryFields, - XeniumTranscriptFields, + XeniumTranscriptFields, + XeniumTranscriptFieldsV1, XeniumBoundaryFields, CosMxTranscriptFields, CosMxBoundaryFields, @@ -372,16 +374,48 @@ class XeniumPreprocessor(ISTPreprocessor): """ Preprocessor for 10x Genomics Xenium datasets. """ + + tx_fields = XeniumTranscriptFields() + bd_fields = XeniumBoundaryFields() + sw_version = lambda version: version[0] > 1 + @staticmethod - def _validate_directory(data_dir: Path): + def _get_analysis_sw_version(data_dir: Path) -> str: + """ + Get 10x xenium analysis software version. Example experiment.xenium file: + { + ..., + "analysis_sw_version": "xenium-3.3.1.1" + } + Return: + version : list of ints representing major, minor, and patch version numbers (e.g. [3, 3, 1, 1]) + """ + # get version + path_meta = data_dir / "experiment.xenium" + with open(path_meta) as f: + meta = json.load(f) + # version can be xenium-x.y.z or Xenium-x.y.z, ... + version = meta["analysis_sw_version"].split("-")[-1].split(".") + version = [int(v) for v in version] + return version + + @classmethod + def _validate_directory(cls, data_dir: Path): + + # Apply xenium software version 2 or higher (when cell id "Unassigned" was introduced. Previously -1) + version = XeniumPreprocessor._get_analysis_sw_version(data_dir) + if not cls.sw_version(version): + raise IOError( + f"Xenium analysis software version must be 2.0.0 or higher, " + f"but found version {'.'.join(version)}." + ) + # Check required files/directories - bd_fields = XeniumBoundaryFields() - tx_fields = XeniumTranscriptFields() for pat in [ - tx_fields.filename, - bd_fields.cell_filename, - bd_fields.nucleus_filename, + cls.tx_fields.filename, + cls.bd_fields.cell_filename, + cls.bd_fields.nucleus_filename, ]: num_matches = len(list(data_dir.glob(pat))) if not num_matches == 1: @@ -394,7 +428,7 @@ def _validate_directory(data_dir: Path): def transcripts(self) -> pl.DataFrame: # Field names - raw = XeniumTranscriptFields() + raw = self.tx_fields std = StandardTranscriptFields() return ( @@ -405,6 +439,11 @@ def transcripts(self) -> pl.DataFrame: ) # Add numeric index at beginning .with_row_index(name=std.row_index) + # Cast binary columns to string (Some Xenium parquet stores these as binary) + .with_columns( + pl.col(raw.feature).cast(pl.Utf8), + pl.col(raw.cell_id).cast(pl.Utf8), + ) # Filter data .filter(pl.col(raw.quality) >= 20) .filter(pl.col(raw.feature).str.contains( @@ -437,15 +476,16 @@ def transcripts(self) -> pl.DataFrame: .collect() ) - @staticmethod + @classmethod def _get_boundaries( + cls, filepath: Path, boundary_type: str ) -> gpd.GeoDataFrame: # TODO: Add documentation # Field names - raw = XeniumBoundaryFields() + raw = cls.bd_fields std = StandardBoundaryFields() # Read in flat vertices and convert to geometries @@ -463,7 +503,7 @@ def _get_boundaries( @cached_property def boundaries(self) -> gpd.GeoDataFrame: # TODO: Add documentation - raw = XeniumBoundaryFields() + raw = self.bd_fields std = StandardBoundaryFields() # Join boundary datasets @@ -496,14 +536,24 @@ def boundaries(self) -> gpd.GeoDataFrame: cells.reset_index(drop=False, names=std.id), nuclei.reset_index(drop=False, names=std.id), ]) - # Convert index to string type (to join on AnnData) - bd.index = bd[std.id] + '_' + bd[std.boundary_type].map({ + # cell_id is string in later 10x versions, but int in earlier versions. + bd.index = bd[std.id].astype(str) + '_' + bd[std.boundary_type].map({ std.nucleus_value: '0', std.cell_value: '1', }) return bd +@register_preprocessor("10x_xenium_v1") +class XeniumPreprocessorV1(XeniumPreprocessor): + """ + Preprocessor for 10x Genomics Xenium datasets analyzed with software version 1.x. + """ + + tx_fields = XeniumTranscriptFieldsV1() + bd_fields = XeniumBoundaryFields() + sw_version = lambda version: version[0] == 1 + @register_preprocessor("vizgen_merscope") class MerscopePreprocessor(ISTPreprocessor):