diff --git a/biahub/core/__init__.py b/biahub/core/__init__.py new file mode 100644 index 00000000..5238b2e1 --- /dev/null +++ b/biahub/core/__init__.py @@ -0,0 +1,6 @@ +""" +Core modules for geometric operations. + +Provides Transform (homogeneous matrix operations) and Graph/GraphMatcher +(point matching for bead-based registration). +""" diff --git a/biahub/core/graph_matching.py b/biahub/core/graph_matching.py new file mode 100644 index 00000000..faf70ba0 --- /dev/null +++ b/biahub/core/graph_matching.py @@ -0,0 +1,774 @@ +""" +Graph-based point matching for bead registration. + +Provides two classes: +- ``Graph``: Geometric graph built from 2D/3D point clouds with local feature + extraction (edge distances, angles, PCA-based descriptors). +- ``GraphMatcher``: Matches nodes between two graphs using either the Hungarian + algorithm (cost matrix from position + edge consistency) or scikit-image + descriptor matching, with geometric consistency filtering. + +Typical usage:: + + mov_graph = Graph.from_nodes(mov_peaks, mode='knn', k=5) + ref_graph = Graph.from_nodes(ref_peaks, mode='knn', k=5) + matcher = GraphMatcher(algorithm='hungarian', cross_check=True) + matches = matcher.match(mov_graph, ref_graph) + matches = matcher.filter_matches(matches, mov_graph, ref_graph, direction_threshold=50) +""" + +from collections import defaultdict +from functools import cached_property +from typing import Literal, Optional + +import click +import numpy as np + +from numpy.typing import NDArray +from scipy.optimize import linear_sum_assignment +from scipy.spatial.distance import cdist +from skimage.feature import match_descriptors +from sklearn.neighbors import NearestNeighbors, radius_neighbors_graph + +# ============================================================ +# GRAPH CLASS +# ============================================================ + + +class Graph: + """ + Geometric graph for 2D/3D point registration with local feature extraction. 
+ + Parameters + ---------- + nodes : NDArray[np.floating] + (N, D) array of nodes (points) + edges : list[tuple[int, int]] + List of edges (pairs of node indices) + mode : Literal["knn", "radius", "full"], default="knn" + Mode for building edges: + - "knn": k-nearest neighbors + - "radius": radius neighbors + - "full": all-to-all + k : int, default=5 + Number of nearest neighbors for k-nearest neighbors mode + radius : float, default=30.0 + Radius for radius neighbors mode + + Examples + -------- + >>> # Build a graph from nodes with k-nearest neighbors + >>> graph = Graph.from_nodes(nodes, mode='knn', k=5) + + >>> # Build a graph from nodes with radius neighbors + >>> graph = Graph.from_nodes(nodes, mode='radius', radius=30.0) + + >>> # Build a graph from nodes with all-to-all edges + >>> graph = Graph.from_nodes(nodes, mode='full') + """ + + def __init__( + self, + nodes: NDArray[np.floating], + edges: list[tuple[int, int]], + ): + self.nodes = np.asarray(nodes, dtype=np.float32) + self._edges = edges + + if self.nodes.ndim != 2: + raise ValueError(f"nodes must be 2D array, got shape {self.nodes.shape}") + if self.dim not in (2, 3): + raise ValueError(f"nodes must be 2D or 3D points, got dim={self.dim}") + + @classmethod + def from_nodes( + cls, + nodes: NDArray[np.floating], + mode: Literal["knn", "radius", "full"] = "knn", + k: int = 5, + radius: float = 30.0, + ) -> "Graph": + """Build a graph from nodes with automatic edge construction.""" + edges = cls._build_edges(nodes, mode=mode, k=k, radius=radius) + return cls(nodes, edges) + + @staticmethod + def _build_edges( + points: NDArray[np.floating], + mode: Literal["knn", "radius", "full"] = "knn", + k: int = 5, + radius: float = 30.0, + ) -> list[tuple[int, int]]: + """Build edges for a point set using various strategies.""" + n = len(points) + if n <= 1: + return [] + + if mode == "knn": + k_eff = min(k + 1, n) + nbrs = NearestNeighbors(n_neighbors=k_eff).fit(points) + _, indices = nbrs.kneighbors(points) + 
edges = [(i, j) for i in range(n) for j in indices[i] if i != j] + + elif mode == "radius": + graph = radius_neighbors_graph( + points, radius=radius, mode='connectivity', include_self=False + ) + if graph.nnz == 0: + return [] + edges = [(i, j) for i in range(n) for j in graph[i].nonzero()[1]] + + elif mode == "full": + edges = [(i, j) for i in range(n) for j in range(n) if i != j] + + else: + raise ValueError(f"Unknown mode: {mode}") + + return edges + + @property + def n_nodes(self) -> int: + return len(self.nodes) + + @property + def dim(self) -> int: + return self.nodes.shape[1] + + @property + def edges(self) -> list[tuple[int, int]]: + return self._edges + + @cached_property + def neighbor_map(self) -> dict[int, list[int]]: + """Adjacency list representation.""" + neighbors = defaultdict(list) + for i, j in self._edges: + neighbors[i].append(j) + return dict(neighbors) + + @cached_property + def edge_distances(self) -> dict[tuple[int, int], float]: + """Distance for each edge (bidirectional).""" + distances = {} + for i, j in self._edges: + vec = self.nodes[j] - self.nodes[i] + d = float(np.linalg.norm(vec)) + distances[(i, j)] = distances[(j, i)] = d + return distances + + @cached_property + def edge_angles(self) -> dict[tuple[int, int], float]: + """Angle for each edge in radians (2D only).""" + if self.dim != 2: + return {} + + angles = {} + for i, j in self._edges: + vec = self.nodes[j] - self.nodes[i] + angle = float(np.arctan2(vec[1], vec[0])) + angles[(i, j)] = angles[(j, i)] = angle + return angles + + @cached_property + def edge_descriptors(self) -> NDArray[np.floating]: + """ + Local edge statistics for each node. 
+ + Returns (N, 4): [mean_length, std_length, mean_angle, std_angle] + """ + desc = np.zeros((self.n_nodes, 4), dtype=np.float32) + + for i in range(self.n_nodes): + neighbors = self.neighbor_map.get(i, []) + if not neighbors: + continue + + lengths = np.array([self.edge_distances[(i, j)] for j in neighbors]) + desc[i, 0] = np.mean(lengths) + desc[i, 1] = np.std(lengths) + + if self.dim == 2 and self.edge_angles: + angles = np.array([self.edge_angles[(i, j)] for j in neighbors]) + desc[i, 2] = np.mean(angles) + desc[i, 3] = np.std(angles) + + return desc + + @cached_property + def pca_features(self) -> tuple[NDArray, NDArray]: + """ + PCA-based features for local neighborhoods. + + Returns + ------- + directions : (N, D) array of dominant directions + anisotropy : (N,) array of anisotropy ratios + """ + n = self.n_nodes + d = self.dim + directions = np.zeros((n, d), dtype=np.float32) + anisotropy = np.zeros(n, dtype=np.float32) + + for i in range(n): + neighbors = self.neighbor_map.get(i, []) + if not neighbors: + directions[i] = np.nan + anisotropy[i] = np.nan + continue + + local_points = self.nodes[neighbors].copy() + local_points -= local_points.mean(axis=0) + + _, S, Vt = np.linalg.svd(local_points, full_matrices=False) + + directions[i] = Vt[0] if Vt.shape[0] > 0 else np.zeros(d) + anisotropy[i] = S[0] / (S[-1] + 1e-5) if len(S) >= 2 else 0.0 + + return directions, anisotropy + + def get_neighbors(self, node_idx: int) -> list[int]: + """Get neighbor indices for a specific node.""" + return self.neighbor_map.get(node_idx, []) + + def __repr__(self) -> str: + return f"Graph(n_nodes={self.n_nodes}, n_edges={len(self.edges)}, dim={self.dim})" + + +# ============================================================ +# GRAPH MATCHER +# ============================================================ + + +class GraphMatcher: + """ + Matches nodes between two geometric graphs for point registration. 
+ + Supports two matching algorithms: + - 'hungarian': Graph-based matching using cost matrix + Hungarian algorithm + - 'descriptor': Feature-based matching using skimage.feature.match_descriptors + + Parameters + ---------- + algorithm : {'hungarian', 'descriptor'}, default='hungarian' + Matching algorithm to use + weights : dict[str, float], optional + Weights for cost components (hungarian only): + - 'dist': Direct position distance (default: 0.5) + - 'edge_length': Local edge length consistency (default: 1.0) + - 'edge_angle': Local edge angle consistency (default: 1.0) + - 'pca_dir': PCA direction similarity (default: 0.0) + - 'pca_aniso': PCA anisotropy similarity (default: 0.0) + - 'edge_descriptor': Edge descriptor distance (default: 0.0) + distance_metric : str, default='euclidean' + Metric for position distances + normalize : bool, default=False + Whether to normalize each cost component to [0, 1] + cost_threshold : float, default=0.9 + Quantile threshold for accepting matches (0.0-1.0) + cross_check : bool, default=False + Whether to require bidirectional consistency + max_ratio : float, optional + Maximum ratio between best and second-best match (Lowe's ratio test) + metric : str, default='euclidean' + Distance metric for descriptor matching (descriptor algorithm only) + verbose : bool, default=False + Print matching statistics + + Examples + -------- + >>> # Hungarian matching (graph-based) + >>> matcher = GraphMatcher( + ... algorithm='hungarian', + ... weights={'dist': 0.5, 'edge_length': 1.0}, + ... cross_check=True + ... ) + >>> matches = matcher.match(moving_graph, ref_graph) + + >>> # Descriptor matching (feature-based) + >>> matcher = GraphMatcher( + ... algorithm='descriptor', + ... cross_check=True, + ... max_ratio=0.8 + ... 
) + >>> matches = matcher.match(moving_graph, ref_graph) + """ + + def __init__( + self, + algorithm: Literal['hungarian', 'descriptor'] = 'hungarian', + weights: Optional[dict[str, float]] = None, + distance_metric: str = 'euclidean', + normalize: bool = False, + cost_threshold: float = 0.9, + cross_check: bool = False, + max_ratio: Optional[float] = None, + metric: str = 'euclidean', # for descriptor matching + verbose: bool = False, + ): + self.algorithm = algorithm + + # Hungarian-specific parameters + default_weights = { + "dist": 0.5, + "edge_length": 1.0, + "edge_angle": 1.0, + "pca_dir": 0.0, + "pca_aniso": 0.0, + "edge_descriptor": 0.0, + } + self.weights = {**default_weights, **(weights or {})} + self.distance_metric = distance_metric + self.normalize = normalize + self.cost_threshold = cost_threshold + + # Common parameters + self.cross_check = cross_check + self.max_ratio = max_ratio + self.verbose = verbose + + # Descriptor matching parameters + self.metric = metric + + def match( + self, + moving: Graph, + reference: Graph, + verbose: Optional[bool] = None, + ) -> NDArray[np.integer]: + """ + Find correspondences between two graphs. 
+ + Parameters + ---------- + moving : Graph + Moving/source graph + reference : Graph + Reference/target graph + verbose : bool, optional + Override instance verbose setting + + Returns + ------- + NDArray + (N_matches, 2) array of [moving_idx, reference_idx] pairs + """ + verbose = verbose if verbose is not None else self.verbose + + # Validate + if moving.dim != reference.dim: + raise ValueError( + f"Dimension mismatch: moving={moving.dim}D, reference={reference.dim}D" + ) + + if moving.n_nodes == 0 or reference.n_nodes == 0: + if verbose: + click.echo("Warning: One or both graphs are empty") + return np.array([]).reshape(0, 2).astype(np.int32) + + # Dispatch to appropriate algorithm + if self.algorithm == 'hungarian': + return self._match_hungarian(moving, reference, verbose) + elif self.algorithm == 'descriptor': + return self._match_descriptor(moving, reference, verbose) + else: + raise ValueError(f"Unknown algorithm: {self.algorithm}") + + # ============================================================ + # HUNGARIAN MATCHING + # ============================================================ + + def _match_hungarian( + self, + moving: Graph, + reference: Graph, + verbose: bool, + ) -> NDArray[np.integer]: + """Hungarian algorithm matching.""" + if self.cross_check: + return self._match_hungarian_cross_check(moving, reference, verbose) + else: + return self._match_hungarian_single(moving, reference, verbose) + + def _match_hungarian_single( + self, + moving: Graph, + reference: Graph, + verbose: bool, + ) -> NDArray[np.integer]: + """Hungarian matching in one direction.""" + C = self.compute_cost_matrix(moving, reference) + matches = self._solve_assignment(C, verbose) + return matches + + def _match_hungarian_cross_check( + self, + moving: Graph, + reference: Graph, + verbose: bool, + ) -> NDArray[np.integer]: + """Hungarian matching with bidirectional consistency.""" + if verbose: + click.echo("Computing forward matches (A → B)...") + + C_ab = 
self.compute_cost_matrix(moving, reference) + matches_ab = self._solve_assignment(C_ab, False) + + if verbose: + click.echo(f"Forward: {len(matches_ab)} matches") + click.echo("Computing backward matches (B → A)...") + + C_ba = self.compute_cost_matrix(reference, moving) + matches_ba = self._solve_assignment(C_ba, False) + + if verbose: + click.echo(f"Backward: {len(matches_ba)} matches") + + # Keep only symmetric matches + reverse_map = {(j, i) for i, j in matches_ba} + matches = np.array( + [[i, j] for i, j in matches_ab if (i, j) in reverse_map], dtype=np.int32 + ) + + if verbose: + click.echo(f"Cross-check: {len(matches)} symmetric matches") + + return matches + + def compute_cost_matrix( + self, + moving: Graph, + reference: Graph, + ) -> NDArray[np.floating]: + """ + Compute full cost matrix between two graphs. + + Parameters + ---------- + moving : Graph + Moving graph with N nodes + reference : Graph + Reference graph with M nodes + + Returns + ------- + NDArray + (N, M) cost matrix where C[i,j] = cost of matching moving[i] to reference[j] + """ + n, m = moving.n_nodes, reference.n_nodes + C_total = np.zeros((n, m), dtype=np.float32) + w = self.weights + + # Position distance + if w["dist"] > 0: + C_dist = cdist(moving.nodes, reference.nodes, metric=self.distance_metric) + if self.normalize: + max_val = C_dist.max() + if max_val > 0: + C_dist = C_dist / max_val + C_total += w["dist"] * C_dist + + # Edge consistency + if w["edge_length"] > 0: + C_edge_len = self._compute_edge_consistency_cost( + moving, reference, attr_type='distance', default_cost=1e6 + ) + if self.normalize: + max_val = C_edge_len.max() + if max_val > 0: + C_edge_len = C_edge_len / max_val + C_total += w["edge_length"] * C_edge_len + + if w["edge_angle"] > 0 and moving.dim == 2: + C_edge_ang = self._compute_edge_consistency_cost( + moving, reference, attr_type='angle', default_cost=np.pi + ) + if self.normalize: + C_edge_ang = C_edge_ang / np.pi + C_total += w["edge_angle"] * C_edge_ang + 
+ # PCA features + if w["pca_dir"] > 0 or w["pca_aniso"] > 0: + mov_dirs, mov_aniso = moving.pca_features + ref_dirs, ref_aniso = reference.pca_features + + if w["pca_dir"] > 0: + dot = np.clip(np.dot(mov_dirs, ref_dirs.T), -1.0, 1.0) + C_dir = 1 - np.abs(dot) + if self.normalize: + max_val = C_dir.max() + if max_val > 0: + C_dir = C_dir / max_val + C_total += w["pca_dir"] * C_dir + + if w["pca_aniso"] > 0: + C_aniso = np.abs(mov_aniso[:, None] - ref_aniso[None, :]) + if self.normalize: + max_val = C_aniso.max() + if max_val > 0: + C_aniso = C_aniso / max_val + C_total += w["pca_aniso"] * C_aniso + + # Edge descriptors + if w["edge_descriptor"] > 0: + mov_desc = moving.edge_descriptors + ref_desc = reference.edge_descriptors + C_desc = cdist(mov_desc, ref_desc) + if self.normalize: + max_val = C_desc.max() + if max_val > 0: + C_desc = C_desc / max_val + C_total += w["edge_descriptor"] * C_desc + + return C_total + + def _compute_edge_consistency_cost( + self, + moving: Graph, + reference: Graph, + attr_type: str, + default_cost: float, + ) -> NDArray[np.floating]: + """Compute cost based on local edge attribute consistency.""" + n, m = moving.n_nodes, reference.n_nodes + cost_matrix = np.full((n, m), default_cost, dtype=np.float32) + + if attr_type == 'distance': + mov_attrs = moving.edge_distances + ref_attrs = reference.edge_distances + elif attr_type == 'angle': + mov_attrs = moving.edge_angles + ref_attrs = reference.edge_angles + if not mov_attrs or not ref_attrs: + return cost_matrix + else: + raise ValueError(f"Unknown attr_type: {attr_type}") + + mov_neighbors = moving.neighbor_map + ref_neighbors = reference.neighbor_map + + for i in range(n): + i_neighbors = mov_neighbors.get(i, []) + if not i_neighbors: + continue + + for j in range(m): + j_neighbors = ref_neighbors.get(j, []) + if not j_neighbors: + continue + + C_local = np.full( + (len(i_neighbors), len(j_neighbors)), default_cost, dtype=np.float32 + ) + + for ii, ni in enumerate(i_neighbors): + for 
jj, nj in enumerate(j_neighbors): + mov_edge = (i, ni) + ref_edge = (j, nj) + + if mov_edge in mov_attrs and ref_edge in ref_attrs: + C_local[ii, jj] = abs(mov_attrs[mov_edge] - ref_attrs[ref_edge]) + + row_ind, col_ind = linear_sum_assignment(C_local) + matched_costs = C_local[row_ind, col_ind] + cost_matrix[i, j] = ( + matched_costs.mean() if len(matched_costs) > 0 else default_cost + ) + + return cost_matrix + + def _solve_assignment( + self, + C: NDArray[np.floating], + verbose: bool, + ) -> NDArray[np.integer]: + """Solve assignment problem with padding for unequal sizes.""" + n_A, n_B = C.shape + n = max(n_A, n_B) + + # Pad to square + dummy_cost = 1e6 + C_padded = np.full((n, n), dummy_cost, dtype=np.float32) + C_padded[:n_A, :n_B] = C + + # Solve + row_ind, col_ind = linear_sum_assignment(C_padded) + + # Filter matches + cost_thresh = np.quantile(C, self.cost_threshold) + matches = [] + + for i, j in zip(row_ind, col_ind): + if i >= n_A or j >= n_B: + continue + + if C[i, j] >= cost_thresh: + continue + + if self.max_ratio is not None: + costs_i = C[i, :] + sorted_costs = np.sort(costs_i) + if len(sorted_costs) > 1: + second_best = sorted_costs[1] + ratio = C[i, j] / (second_best + 1e-10) + if ratio > self.max_ratio: + continue + + matches.append((i, j)) + + if verbose: + click.echo(f"Found {len(matches)} matches (cost_threshold={cost_thresh:.3f})") + + return np.array(matches, dtype=np.int32).reshape(-1, 2) + + # ============================================================ + # DESCRIPTOR MATCHING + # ============================================================ + + def _match_descriptor( + self, + moving: Graph, + reference: Graph, + verbose: bool, + ) -> NDArray[np.integer]: + """ + Feature-based matching using edge descriptors. + + Uses skimage.feature.match_descriptors with edge_descriptors + as the feature vectors. 
+ """ + # Get descriptors + mov_desc = moving.nodes + ref_desc = reference.nodes + + if verbose: + click.echo( + f"Matching {mov_desc.shape[0]} moving descriptors to {ref_desc.shape[0]} reference descriptors" + ) + + # Use skimage's match_descriptors + matches = match_descriptors( + mov_desc, + ref_desc, + metric=self.metric, + cross_check=self.cross_check, + max_ratio=self.max_ratio if self.max_ratio is not None else 1.0, + ) + + if verbose: + click.echo(f"Found {len(matches)} descriptor matches") + + return matches.astype(np.int32) + + def filter_matches( + self, + matches: NDArray[np.integer], + moving: Graph, + reference: Graph, + angle_threshold: Optional[float] = 0, + direction_threshold: Optional[float] = 0, + min_distance_quantile: float = 0.01, + max_distance_quantile: float = 0.95, + verbose: Optional[bool] = None, + ) -> NDArray[np.integer]: + """ + Filter matches based on geometric consistency. + + Parameters + ---------- + matches : NDArray + (N, 2) array of matches + moving : Graph + Moving graph + reference : Graph + Reference graph + angle_threshold : float, optional + Maximum deviation from dominant angle (degrees, 2D only). + If None, skip 2D angle filtering. + direction_threshold : float, optional + Maximum angular deviation from dominant direction (degrees, 2D/3D). + Uses dot product between normalized vectors. + If None, skip direction filtering. 
+ min_distance_quantile : float + Lower quantile cutoff for distances + max_distance_quantile : float + Upper quantile cutoff for distances + verbose : bool, optional + Override instance verbose setting + + Returns + ------- + NDArray + (K, 2) filtered matches where K <= N + """ + verbose = verbose if verbose is not None else self.verbose + + if len(matches) == 0: + return matches + + # Distance filtering + if min_distance_quantile != 0 or max_distance_quantile != 0: + dist = np.linalg.norm( + moving.nodes[matches[:, 0]] - reference.nodes[matches[:, 1]], axis=1 + ) + + low = np.quantile(dist, min_distance_quantile) + high = np.quantile(dist, max_distance_quantile) + + if verbose: + click.echo( + f"Distance filtering: quantiles [{min_distance_quantile}, {max_distance_quantile}]" + ) + click.echo(f"Distance range: [{low:.3f}, {high:.3f}]") + + keep = (dist >= low) & (dist <= high) + matches = matches[keep] + + if verbose: + click.echo(f"Matches after distance filtering: {len(matches)}") + + # Direction filtering (2D/3D) - NEW + if direction_threshold != 0: + vectors = reference.nodes[matches[:, 1]] - moving.nodes[matches[:, 0]] + + # Normalize vectors + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + unit_vectors = vectors / (norms + 1e-10) + + # Find dominant direction using circular/spherical mean + mean_direction = unit_vectors.mean(axis=0) + mean_direction = mean_direction / (np.linalg.norm(mean_direction) + 1e-10) + + # Compute angular deviation from dominant direction + dot_products = np.clip(unit_vectors @ mean_direction, -1.0, 1.0) + angles_rad = np.arccos(dot_products) + angles_deg = np.degrees(angles_rad) + + keep = angles_deg <= direction_threshold + matches = matches[keep] + + if verbose: + click.echo(f"Dominant direction: {mean_direction}") + click.echo(f"Direction threshold: {direction_threshold}°") + click.echo(f"Matches after direction filtering: {len(matches)}") + + # Angle filtering (2D only, legacy) + if angle_threshold != 0 and moving.dim 
== 2: + vectors = reference.nodes[matches[:, 1]] - moving.nodes[matches[:, 0]] + angles_rad = np.arctan2(vectors[:, 1], vectors[:, 0]) + angles_deg = np.degrees(angles_rad) + + bins = np.linspace(-180, 180, 36) + hist, bin_edges = np.histogram(angles_deg, bins=bins) + dominant_bin_index = np.argmax(hist) + dominant_angle = ( + bin_edges[dominant_bin_index] + bin_edges[dominant_bin_index + 1] + ) / 2 + + keep = np.abs(angles_deg - dominant_angle) <= angle_threshold + matches = matches[keep] + + if verbose: + click.echo(f"Dominant angle: {dominant_angle:.1f}°") + click.echo(f"Matches after 2D angle filtering: {len(matches)}") + + return matches diff --git a/biahub/core/transform.py b/biahub/core/transform.py new file mode 100644 index 00000000..247b08a8 --- /dev/null +++ b/biahub/core/transform.py @@ -0,0 +1,551 @@ +""" +Geometric transform class for 2D/3D volumes. + +This module provides an immutable Transform class that wraps homogeneous +transformation matrices and provides methods for application, inversion, +composition, and conversion between different representations. + +Coordinate convention: ZYX ordering for 3D, YX for 2D. +""" + +from __future__ import annotations +from typing import Literal + +import numpy as np + +from numpy.typing import NDArray +from scipy.ndimage import affine_transform + +TransformType = Literal["affine", "similarity", "euclidean", "rigid"] +Backend = Literal["scipy", "ants"] + + +class Transform: + """ + Geometric transform for 2D/3D volumes. + + Wraps a homogeneous transformation matrix (3×3 for 2D, 4×4 for 3D) + and provides methods for application, inversion, composition, and conversion. + + This class is immutable - all operations return new Transform instances. + + Coordinate convention: ZYX ordering for 3D, YX for 2D. + + Parameters + ---------- + matrix : np.ndarray + Homogeneous transformation matrix. Shape (3, 3) for 2D or (4, 4) for 3D. + transform_type : TransformType + Type of transform. 
Affects estimation constraints but not application. + - "affine": Full affine (12 DOF in 3D) + - "similarity": Rotation + translation + uniform scale (7 DOF in 3D) + - "euclidean" / "rigid": Rotation + translation only (6 DOF in 3D) + + Examples + -------- + >>> t = Transform.identity(ndim=3) + >>> t_shifted = Transform.from_translation([0, 10, 20]) # ZYX + >>> composed = t_shifted @ t + >>> inverted = t_shifted.invert() + >>> points_transformed = t_shifted.apply_points(points) + >>> volume_transformed = t_shifted.apply(volume, output_shape=(64, 128, 128)) + """ + + def __init__( + self, + matrix: NDArray[np.floating], + transform_type: TransformType = "affine", + ): + matrix = np.asarray(matrix, dtype=np.float64) + + if matrix.shape == (3, 3): + self._ndim = 2 + elif matrix.shape == (4, 4): + self._ndim = 3 + else: + raise ValueError( + f"Matrix must be (3, 3) for 2D or (4, 4) for 3D, got {matrix.shape}" + ) + + # Store as immutable + self._matrix = matrix + self._matrix.flags.writeable = False + self._type = transform_type + + # ==================== Properties ==================== + + @property + def matrix(self) -> NDArray[np.floating]: + """The homogeneous transformation matrix (copy).""" + return self._matrix.copy() + + @property + def ndim(self) -> int: + """Number of spatial dimensions (2 or 3).""" + return self._ndim + + @property + def transform_type(self) -> TransformType: + """Type of transform.""" + return self._type + + @property + def translation(self) -> NDArray[np.floating]: + """Translation component. 
ZYX for 3D, YX for 2D.""" + return self._matrix[:-1, -1].copy() + + @property + def linear(self) -> NDArray[np.floating]: + """Linear component (rotation, scale, shear).""" + return self._matrix[:-1, :-1].copy() + + @property + def is_identity(self) -> bool: + """Check if this is an identity transform.""" + return np.allclose(self._matrix, np.eye(self._ndim + 1)) + + # ==================== Constructors ==================== + + @classmethod + def identity(cls, ndim: int = 3) -> Transform: + """ + Create an identity transform. + + Parameters + ---------- + ndim : int + Number of dimensions (2 or 3). + + Returns + ------- + Transform + Identity transform. + """ + if ndim == 2: + matrix = np.eye(3, dtype=np.float64) + elif ndim == 3: + matrix = np.eye(4, dtype=np.float64) + else: + raise ValueError(f"ndim must be 2 or 3, got {ndim}") + + return cls(matrix, transform_type="affine") + + @classmethod + def from_translation(cls, offset: NDArray[np.floating]) -> Transform: + """ + Create a pure translation transform. + + Parameters + ---------- + offset : array-like + Translation vector. ZYX for 3D, YX for 2D. + + Returns + ------- + Transform + Translation transform. + + Examples + -------- + >>> t = Transform.from_translation([5, 10, 15]) # 3D: Z=5, Y=10, X=15 + >>> t = Transform.from_translation([10, 15]) # 2D: Y=10, X=15 + """ + offset = np.asarray(offset, dtype=np.float64) + ndim = len(offset) + + if ndim == 2: + matrix = np.eye(3, dtype=np.float64) + matrix[:2, 2] = offset + elif ndim == 3: + matrix = np.eye(4, dtype=np.float64) + matrix[:3, 3] = offset + else: + raise ValueError(f"offset must be 2D or 3D, got shape {offset.shape}") + + return cls(matrix, transform_type="euclidean") + + @classmethod + def from_skimage( + cls, + skimage_transform, + ndim: int = 3, + ) -> Transform: + """ + Create Transform from a scikit-image transform. 
+ + Parameters + ---------- + skimage_transform : skimage.transform._geometric.GeometricTransform + A scikit-image transform (EuclideanTransform, SimilarityTransform, + AffineTransform, etc.) + ndim : int + Number of dimensions for the output transform. If the skimage + transform is 2D and ndim=3, it will be embedded in 3D (YX plane). + + Returns + ------- + Transform + New Transform instance. + + Examples + -------- + >>> from skimage.transform import SimilarityTransform + >>> sk_tform = SimilarityTransform(scale=0.9, rotation=0.1, translation=(10, 20)) + >>> t = Transform.from_skimage(sk_tform, ndim=3) + """ + params = skimage_transform.params + + # Determine transform type from skimage class name + class_name = type(skimage_transform).__name__.lower() + if "euclidean" in class_name: + transform_type = "euclidean" + elif "similarity" in class_name: + transform_type = "similarity" + elif "affine" in class_name: + transform_type = "affine" + else: + transform_type = "affine" + + if params.shape == (3, 3): + # 2D transform + if ndim == 2: + return cls(params, transform_type=transform_type) + elif ndim == 3: + # Embed 2D transform in 3D (YX plane, Z unchanged) + matrix_3d = np.eye(4, dtype=np.float64) + matrix_3d[1:3, 1:3] = params[:2, :2] # YX linear part + matrix_3d[1:3, 3] = params[:2, 2] # YX translation + return cls(matrix_3d, transform_type=transform_type) + elif params.shape == (4, 4): + # 3D transform + if ndim == 3: + return cls(params, transform_type=transform_type) + else: + raise ValueError("Cannot convert 3D skimage transform to 2D") + else: + raise ValueError(f"Unexpected skimage transform shape: {params.shape}") + + # ==================== Algebraic Operations ==================== + + def invert(self) -> Transform: + """ + Return the inverse transform. + + Returns + ------- + Transform + New transform that undoes this one. 
+ """ + return Transform( + np.linalg.inv(self._matrix), + transform_type=self._type, + ) + + def compose(self, other: Transform) -> Transform: + """ + Compose this transform with another: self @ other. + + The result applies `other` first, then `self`. + For points: composed.apply_points(p) == self.apply_points(other.apply_points(p)) + + Parameters + ---------- + other : Transform + Transform to compose with. + + Returns + ------- + Transform + New composed transform. + """ + if self.ndim != other.ndim: + raise ValueError(f"Cannot compose {self.ndim}D and {other.ndim}D transforms") + + # Result type: most general of the two + type_hierarchy = ["euclidean", "rigid", "similarity", "affine"] + self_idx = type_hierarchy.index(self._type) + other_idx = type_hierarchy.index(other._type) + result_type = type_hierarchy[max(self_idx, other_idx)] + + return Transform( + self._matrix @ other._matrix, + transform_type=result_type, + ) + + def __matmul__(self, other: Transform) -> Transform: + """Compose transforms using @ operator.""" + return self.compose(other) + + # ==================== Application Methods ==================== + + def apply_points(self, points: NDArray[np.floating]) -> NDArray[np.floating]: + """ + Apply transform to a point cloud. + + Parameters + ---------- + points : NDArray + (N, D) array of points where D is ndim (2 or 3). + Coordinates in ZYX order for 3D, YX for 2D. + + Returns + ------- + NDArray + (N, D) transformed points. 
+ + Examples + -------- + >>> t = Transform.from_translation([0, 5, 10]) + >>> points = np.array([[0, 0, 0], [1, 1, 1]]) + >>> transformed = t.apply_points(points) + """ + points = np.asarray(points, dtype=np.float64) + + if points.ndim != 2: + raise ValueError(f"points must be 2D array (N, D), got shape {points.shape}") + + if points.shape[1] != self._ndim: + raise ValueError(f"points must have {self._ndim} columns, got {points.shape[1]}") + + # Convert to homogeneous coordinates + ones = np.ones((points.shape[0], 1)) + points_homogeneous = np.hstack([points, ones]) + + # Apply transform + transformed = (self._matrix @ points_homogeneous.T).T + + # Convert back from homogeneous + return transformed[:, :-1] + + def apply( + self, + moving: NDArray, + reference: NDArray | None = None, + order: int = 1, + mode: str = "constant", + cval: float = 0.0, + backend: Backend = "scipy", + ) -> NDArray: + """ + Apply transform to align moving image with reference space. + + Parameters + ---------- + moving : NDArray + Image to transform (ZYX for 3D, YX for 2D). + reference : NDArray, optional + Reference image defining output space. If None, uses moving image shape. + order : int, default=1 + Interpolation order (0=nearest, 1=linear, 3=cubic). + mode : str, default='constant' + How to handle boundaries ('constant', 'edge', 'reflect', 'wrap'). + cval : float, default=0.0 + Fill value for constant mode. + backend : {'scipy', 'ants'}, default='scipy' + Backend to use for transformation. + + Returns + ------- + NDArray + Transformed image in reference space. 
+ + Examples + -------- + >>> t = Transform.from_translation([0, 5, 10]) + >>> aligned = t.apply(moving, reference=fixed) + """ + moving = np.asarray(moving) + + # Validate dimensions + if moving.ndim != self._ndim: + raise ValueError(f"Expected {self._ndim}D array, got {moving.ndim}D") + + # Determine output shape + output_shape = reference.shape if reference is not None else moving.shape + + if backend == "scipy": + return self._apply_scipy(moving, output_shape, order, mode, cval) + elif backend == "ants": + return self._apply_ants(moving, reference) + else: + raise ValueError(f"Unknown backend: {backend}") + + def _apply_scipy( + self, + moving: NDArray, + output_shape: tuple, + order: int, + mode: str, + cval: float, + ) -> NDArray: + """Apply transform using scipy.ndimage.""" + # scipy applies inverse transform: output[i] = input[inv_matrix @ i] + inv_matrix = np.linalg.inv(self._matrix) + affine_matrix = inv_matrix[:-1, :-1] + offset = inv_matrix[:-1, -1] + + return affine_transform( + moving, + matrix=affine_matrix, + offset=offset, + output_shape=output_shape, + order=order, + mode=mode, + cval=cval, + ) + + def _apply_ants(self, moving: NDArray, reference: NDArray | None) -> NDArray: + """Apply transform using ANTs.""" + try: + import ants + except ImportError: + raise ImportError( + "ANTsPy is required for backend='ants'. 
Install with: pip install antspyx" + ) + + if self._ndim != 3: + raise NotImplementedError("ANTs backend only supports 3D transforms") + + # Convert to ANTs images + moving_ants = ants.from_numpy(moving.astype(np.float32)) + + if reference is not None: + reference_ants = ants.from_numpy(reference.astype(np.float32)) + else: + reference_ants = moving_ants + + # Convert transform to ANTs + transform_ants = self.to_ants() + + # Apply + result_ants = transform_ants.apply_to_image(moving_ants, reference=reference_ants) + + return result_ants.numpy() + + # ==================== ANTs Conversion ==================== + def to_ants(self): + """ + Convert to ANTs transform. + + Returns + ------- + ants.ANTsTransform + ANTs transform object. + + Notes + ----- + Requires ANTsPy to be installed. + Works for both 2D and 3D transforms. + """ + try: + import ants + except ImportError: + raise ImportError( + "ANTsPy is required for to_ants(). Install with: pip install antspyx" + ) + if self._ndim not in (2, 3): + raise ValueError(f"Unsupported ndim: {self._ndim}") + T_ants_style = self._matrix[:, :-1].ravel() + T_ants_style[-self._ndim :] = self._matrix[: self._ndim, -1] + T_ants = ants.new_ants_transform( + transform_type="AffineTransform", + dimension=self._ndim, + ) + T_ants.set_parameters(T_ants_style) + return T_ants + + @classmethod + def from_ants(cls, T_ants) -> Transform: + """ + Create Transform from ANTs transform. + + Parameters + ---------- + T_ants : ants.ANTsTransform + ANTs transform object. + + Returns + ------- + Transform + New Transform instance. + + Notes + ----- + Works for both 2D and 3D ANTs transforms. + Based on conversion from: + https://sourceforge.net/p/advants/discussion/840261/thread/9fbbaab7/ + """ + + params = T_ants.parameters + fixed_params = T_ants.fixed_parameters + if len(params) == 6: + ndim = 2 + elif len(params) == 12: + ndim = 3 + else: + raise ValueError( + f"Unexpected ANTs parameter count: {len(params)}. Expected 6 (2D) or 12 (3D)." 
+ ) + + T_numpy = params.reshape((ndim, ndim + 1), order="F") + T_numpy[:, :ndim] = T_numpy[:, :ndim].transpose() + T_numpy = np.vstack((T_numpy, np.array([0] * ndim + [1]))) + T_numpy[:ndim, -1] += (np.eye(ndim) - T_numpy[:ndim, :ndim]) @ fixed_params + + return cls(T_numpy) + + # ==================== Serialization ==================== + + def to_list(self) -> list[list[float]]: + """Convert matrix to nested list (for JSON/YAML serialization).""" + return self._matrix.tolist() + + @classmethod + def from_list( + cls, + matrix_list: list[list[float]], + transform_type: TransformType = "affine", + ) -> Transform: + """Create Transform from nested list.""" + return cls(np.array(matrix_list), transform_type=transform_type) + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "matrix": self.to_list(), + "transform_type": self._type, + "ndim": self._ndim, + } + + @classmethod + def from_dict(cls, data: dict) -> Transform: + """Create Transform from dictionary.""" + return cls( + np.array(data["matrix"]), + transform_type=data.get("transform_type", "affine"), + ) + + # ==================== String Representations ==================== + + def __repr__(self) -> str: + return ( + f"Transform(ndim={self._ndim}, type='{self._type}', " + f"translation={self.translation.round(3).tolist()})" + ) + + def __str__(self) -> str: + matrix_str = np.array2string(self._matrix, precision=4, suppress_small=True) + return f"Transform({self._type}, {self._ndim}D)\n{matrix_str}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Transform): + return NotImplemented + return ( + self._ndim == other._ndim + and self._type == other._type + and np.allclose(self._matrix, other._matrix) + ) + + def __hash__(self) -> int: + # For use in sets/dicts - based on rounded matrix values + return hash((self._ndim, self._type, tuple(self._matrix.round(6).ravel()))) diff --git a/biahub/estimate_registration.py b/biahub/estimate_registration.py index 
f6722db3..f22a3aa9 100644 --- a/biahub/estimate_registration.py +++ b/biahub/estimate_registration.py @@ -1,58 +1,37 @@ -import os -import shutil - -from datetime import datetime from pathlib import Path -from typing import Literal, Union import ants import click -import dask.array as da import napari import numpy as np -import submitit from iohub import open_ome_zarr -from matplotlib import pyplot as plt from numpy.typing import ArrayLike -from scipy.interpolate import interp1d -from scipy.optimize import linear_sum_assignment -from scipy.spatial.distance import cdist -from skimage.feature import match_descriptors -from skimage.transform import AffineTransform, EuclideanTransform, SimilarityTransform -from sklearn.neighbors import NearestNeighbors, radius_neighbors_graph +from skimage.transform import EuclideanTransform, SimilarityTransform from waveorder.focus import focus_from_transverse_band -from biahub.characterize_psf import detect_peaks from biahub.cli.parsing import ( config_filepath, local, output_filepath, sbatch_filepath, - sbatch_to_submitit, source_position_dirpaths, target_position_dirpaths, ) -from biahub.cli.slurm import wait_for_jobs_to_finish from biahub.cli.utils import ( - _check_nan_n_zeros, - estimate_resources, model_to_yaml, yaml_to_model, ) -from biahub.optimize_registration import _optimize_registration -from biahub.register import ( +from biahub.registration.utils import ( convert_transform_to_ants, convert_transform_to_numpy, + evaluate_transforms, get_3D_fliplr_matrix, get_3D_rescaling_matrix, get_3D_rotation_matrix, + plot_translations, ) from biahub.settings import ( - AffineTransformSettings, - AntsRegistrationSettings, - BeadsMatchSettings, - DetectPeaksSettings, EstimateRegistrationSettings, RegistrationSettings, StabilizationSettings, @@ -77,376 +56,6 @@ ] -def validate_transforms( - transforms: list[ArrayLike], - shape_zyx: tuple[int, int, int], - window_size: int = 10, - tolerance: float = 100.0, - verbose: bool = False, -) -> 
list[ArrayLike]: - """ - Validate that a provided list of transforms do not deviate beyond the tolerance threshold - relative to the average transform within a given window size. - - Parameters - ---------- - transforms : list[ArrayLike] - List of affine transformation matrices (4x4), one for each timepoint. - shape_zyx : tuple[int, int, int] - Shape of the source (i.e. moving) volume (Z, Y, X). - window_size : int - Size of the moving window for smoothing transformations. - tolerance : float - Maximum allowed difference between consecutive transformations for validation. - verbose : bool - If True, prints detailed logs of the validation process. - - Returns - ------- - list[ArrayLike] - List of affine transformation matrices with invalid or inconsistent values replaced by None. - """ - valid_transforms = [] - reference_transform = None - - for i, transform in enumerate(transforms): - if transform is not None: - if len(valid_transforms) < window_size: - # Bootstrap the buffer without validating yet - valid_transforms.append(transform) - reference_transform = np.mean(valid_transforms, axis=0) - if verbose: - click.echo( - f"[Bootstrap] Accepting transform at timepoint {i} (no validation)" - ) - elif check_transforms_difference( - transform, reference_transform, shape_zyx, tolerance, verbose - ): - valid_transforms.append(transform) - if len(valid_transforms) > window_size: - valid_transforms.pop(0) - reference_transform = np.mean(valid_transforms, axis=0) - if verbose: - click.echo(f"Transform at timepoint {i} is valid") - else: - transforms[i] = None - if verbose: - click.echo( - f"Transform at timepoint {i} is invalid and will be interpolated" - ) - else: - transforms[i] = None - if verbose: - click.echo(f"Transform at timepoint {i} is None and will be interpolated") - - return transforms - - -def interpolate_transforms( - transforms: list[ArrayLike], - window_size: int = 3, - interpolation_type: Literal["linear", "cubic"] = "linear", - verbose: bool = False, -): 
- """ - Interpolate missing transforms (None) in a list of affine transformation matrices. - - Parameters - ---------- - transforms : list[ArrayLike] - List of affine transformation matrices (4x4), one for each timepoint. - window_size : int - Local window radius for interpolation. If 0, global interpolation is used. - interpolation_type : Literal["linear", "cubic"] - Interpolation type. - verbose : bool - If True, prints detailed logs of the interpolation process. - - Returns - ------- - list[ArrayLike] - List of affine transformation matrices with missing values filled via linear interpolation. - """ - n = len(transforms) - valid_transform_indices = [i for i, t in enumerate(transforms) if t is not None] - valid_transforms = [np.array(transforms[i]) for i in valid_transform_indices] - - if not valid_transform_indices or len(valid_transform_indices) < 2: - raise ValueError("At least two valid transforms are required for interpolation.") - - missing_indices = [i for i in range(n) if transforms[i] is None] - - if not missing_indices: - return transforms # nothing to do - if verbose: - click.echo(f"Interpolating missing transforms at timepoints: {missing_indices}") - - if window_size > 0: - for idx in missing_indices: - # Define local window - start = max(0, idx - window_size) - end = min(n, idx + window_size + 1) - - local_x = [] - local_y = [] - - for j in range(start, end): - if j in valid_transform_indices: - local_x.append(j) - local_y.append(np.array(transforms[j])) - - if len(local_x) < 2: - # Not enough neighbors for interpolation. 
Assign to closes valid transform - closest_valid_idx = valid_transform_indices[ - np.argmin(np.abs(np.asarray(valid_transform_indices) - idx)) - ] - transforms[idx] = transforms[closest_valid_idx] - if verbose: - click.echo( - f"Not enough interpolation neighbors were found for timepoint {idx} using closest valid transform at timepoint {closest_valid_idx}" - ) - continue - - f = interp1d( - local_x, local_y, axis=0, kind=interpolation_type, fill_value='extrapolate' - ) - transforms[idx] = f(idx).tolist() - if verbose: - click.echo(f"Interpolated timepoint {idx} using neighbors: {local_x}") - - else: - # Global interpolation using all valid transforms - f = interp1d( - valid_transform_indices, - valid_transforms, - axis=0, - kind='linear', - fill_value='extrapolate', - ) - transforms = [ - f(i).tolist() if transforms[i] is None else transforms[i] for i in range(n) - ] - - return transforms - - -def check_transforms_difference( - tform1: ArrayLike, - tform2: ArrayLike, - shape_zyx: tuple[int, int, int], - threshold: float = 5.0, - verbose: bool = False, -): - """ - Evaluate the difference between two affine transforms by calculating the - Mean Squared Error (MSE) of a grid of points transformed by each matrix. - - Parameters - ---------- - tform1 : ArrayLike - First affine transform (4x4 matrix). - tform2 : ArrayLike - Second affine transform (4x4 matrix). - shape_zyx : tuple[int, int, int] - Shape of the source (i.e. moving) volume (Z, Y, X). - threshold : float - The maximum allowed MSE difference. - verbose : bool - Flag to print the MSE difference. - - Returns - ------- - bool - True if the MSE difference is within the threshold, False otherwise. 
- """ - tform1 = np.array(tform1) - tform2 = np.array(tform2) - (Z, Y, X) = shape_zyx - - zz, yy, xx = np.meshgrid( - np.linspace(0, Z - 1, 10), np.linspace(0, Y - 1, 10), np.linspace(0, X - 1, 10) - ) - - grid_points = np.vstack([zz.ravel(), yy.ravel(), xx.ravel(), np.ones(zz.size)]).T - - points_tform1 = np.dot(tform1, grid_points.T).T - points_tform2 = np.dot(tform2, grid_points.T).T - - differences = np.linalg.norm(points_tform1[:, :3] - points_tform2[:, :3], axis=1) - mse = np.mean(differences) - - if verbose: - click.echo(f'MSE of transformed points: {mse:.2f}; threshold: {threshold:.2f}') - return mse <= threshold - - -def evaluate_transforms( - transforms: ArrayLike, - shape_zyx: tuple[int, int, int], - validation_window_size: int = 10, - validation_tolerance: float = 100.0, - interpolation_window_size: int = 3, - interpolation_type: Literal["linear", "cubic"] = "linear", - verbose: bool = False, -) -> ArrayLike: - """ - Evaluate a list of affine transformation matrices. - Transform matrices are checked for deviation from the average within a given window size. - If a transform is found to lead to shift larger than the given tolerance, - that transform will be replaced by interpolation of valid transforms within a given window size. - - Parameters - ---------- - transforms : ArrayLike - List of affine transformation matrices (4x4), one for each timepoint. - shape_zyx : tuple[int, int, int] - Shape of the source (i.e. moving) volume (Z, Y, X). - validation_window_size : int - Size of the moving window for smoothing transformations. - validation_tolerance : float - Maximum allowed difference between consecutive transformations for validation. - interpolation_window_size : int - Size of the local window for interpolation. - interpolation_type : Literal["linear", "cubic"] - Interpolation type. - verbose : bool - If True, prints detailed logs of the evaluation and validation process. 
- - Returns - ------- - list[ArrayLike] - List of affine transformation matrices with missing values filled via linear interpolation. - """ - - if not isinstance(transforms, list): - transforms = transforms.tolist() - if len(transforms) < validation_window_size: - raise Warning( - f"Not enough transforms for validation and interpolation. " - f"Required: {validation_window_size}, " - f"Provided: {len(transforms)}" - ) - else: - transforms = validate_transforms( - transforms=transforms, - window_size=validation_window_size, - tolerance=validation_tolerance, - shape_zyx=shape_zyx, - verbose=verbose, - ) - - if len(transforms) < interpolation_window_size: - raise Warning( - f"Not enough transforms for interpolation. " - f"Required: {interpolation_window_size}, " - f"Provided: {len(transforms)}" - ) - else: - transforms = interpolate_transforms( - transforms=transforms, - window_size=interpolation_window_size, - interpolation_type=interpolation_type, - verbose=verbose, - ) - return transforms - - -def save_transforms( - model: Union[AffineTransformSettings, StabilizationSettings, RegistrationSettings], - transforms: list[ArrayLike], - output_filepath_settings: Path, - output_filepath_plot: Path = None, - verbose: bool = False, -): - """ - Save the transforms to a yaml file and plot the translations. - - Parameters - ---------- - model : Union[AffineTransformSettings, StabilizationSettings, RegistrationSettings] - Model to save the transforms to. - transforms : list[ArrayLike] - List of affine transformation matrices (4x4), one for each timepoint. - output_filepath_settings : Path - Path to the output settings file. - output_filepath_plot : Path - Path to the output plot file. - verbose : bool - If True, prints detailed logs of the saving process. - - Returns - ------- - None - - Notes - ----- - The transforms are saved to a yaml file and a plot of the translations is saved to a png file. 
- The plot is saved in the same directory as the settings file and is named "translations.png". - - """ - if transforms is None or len(transforms) == 0: - raise ValueError("Transforms are empty") - - if not isinstance(transforms, list): - transforms = transforms.tolist() - - model.affine_transform_zyx_list = transforms - - if output_filepath_settings.suffix not in [".yml", ".yaml"]: - output_filepath_settings = output_filepath_settings.with_suffix(".yml") - - output_filepath_settings.parent.mkdir(parents=True, exist_ok=True) - model_to_yaml(model, output_filepath_settings) - - if verbose and output_filepath_plot is not None: - if output_filepath_plot.suffix not in [".png"]: - output_filepath_plot = output_filepath_plot.with_suffix(".png") - output_filepath_plot.parent.mkdir(parents=True, exist_ok=True) - - plot_translations(np.asarray(transforms), output_filepath_plot) - - -def plot_translations( - transforms_zyx: ArrayLike, - output_filepath: Path, -): - """ - Plot the translations of a list of affine transformation matrices. - - Parameters - ---------- - transforms_zyx : ArrayLike - List of affine transformation matrices (4x4), one for each timepoint. - output_filepath : Path - Path to the output plot file. - Returns - ------- - None - - Notes - ----- - The plot is saved as a png file. - The plot is saved in the same directory as the output file. - The plot is saved as a png file. 
- """ - transforms_zyx = np.asarray(transforms_zyx) - os.makedirs(output_filepath.parent, exist_ok=True) - - z_transforms = transforms_zyx[:, 0, 3] - y_transforms = transforms_zyx[:, 1, 3] - x_transforms = transforms_zyx[:, 2, 3] - _, axs = plt.subplots(3, 1, figsize=(10, 10)) - - axs[0].plot(z_transforms) - axs[0].set_title("Z-Translation") - axs[1].plot(x_transforms) - axs[1].set_title("X-Translation") - axs[2].plot(y_transforms) - axs[2].set_title("Y-Translation") - plt.savefig(output_filepath, dpi=300, bbox_inches='tight') - plt.close() - - def user_assisted_registration( source_channel_volume: ArrayLike, source_channel_name: str, @@ -753,1269 +362,6 @@ def lambda_callback(layer, event): return [tform.tolist()] -def shrink_slice(s: slice, shrink_fraction: float = 0.1, min_width: int = 5) -> slice: - """ - Shrink a slice by a fraction of its length. - - Parameters - ---------- - s : slice - The slice to shrink. - shrink_fraction : float - The fraction of the slice to shrink. - min_width : int - The minimum width of the slice. - - Returns - ------- - slice - The shrunk slice. - Notes - ----- - If the slice is too small, return the original slice. - - """ - start = s.start or 0 - stop = s.stop or 0 - length = stop - start - if length <= min_width: - return slice(start, stop) - - shrink = int(length * shrink_fraction) - new_start = start + shrink - new_stop = stop - shrink - if new_stop <= new_start: - return slice(start, stop) - return slice(new_start, new_stop) - - -def ants_registration( - source_data_tczyx: da.Array, - target_data_tczyx: da.Array, - source_channel_index: int | list[int], - target_channel_index: int, - ants_registration_settings: AntsRegistrationSettings, - affine_transform_settings: AffineTransformSettings, - verbose: bool = False, - output_folder_path: Path = None, - cluster: str = 'local', - sbatch_filepath: Path = None, -) -> list[ArrayLike]: - """ - Perform ants registration of two volumetric image channels. 
- - This function calculates timepoint-specific affine transformations to align a source channel - to a target channel in 4D (T, Z, Y, X) data. It validates, smooths, and interpolates transformations - across timepoints for consistent registration. - - Parameters - ---------- - source_data_tczyx : da.Array - 4D array (T, C, Z, Y, X) of the source channel (Dask array). - target_data_tczyx : da.Array - 4D array (T, C, Z, Y, X) of the target channel (Dask array). - source_channel_index : int | list[int] - Index of the source channel. - target_channel_index : int - Index of the target channel. - ants_registration_settings : AntsRegistrationSettings - Settings for the ants registration. - affine_transform_settings : AffineTransformSettings - Settings for the affine transform. - verbose : bool - If True, prints detailed logs of the registration process. - output_folder_path : Path - Path to the output folder. - cluster : str - Cluster to use. - sbatch_filepath : Path - Path to the sbatch file. - - Returns - ------- - list[ArrayLike] - List of affine transformation matrices (4x4), one for each timepoint. - Invalid or missing transformations are interpolated. - - Notes - ----- - Each timepoint is processed in parallel using submitit executor. - Use verbose=True for detailed logging during registration. The verbose output will be saved at the same level as the output zarr. 
- """ - T, C, Z, Y, X = source_data_tczyx.shape - initial_tform = np.asarray(affine_transform_settings.approx_transform) - - num_cpus, gb_ram_per_cpu = estimate_resources( - shape=(T, 2, Z, Y, X), ram_multiplier=16, max_num_cpus=16 - ) - - # Prepare SLURM arguments - slurm_args = { - "slurm_job_name": "estimate_registration", - "slurm_mem_per_cpu": f"{gb_ram_per_cpu}G", - "slurm_cpus_per_task": num_cpus, - "slurm_array_parallelism": 100, - "slurm_time": 30, - "slurm_partition": "preempted", - } - - if sbatch_filepath: - slurm_args.update(sbatch_to_submitit(sbatch_filepath)) - - output_folder_path.mkdir(parents=True, exist_ok=True) - slurm_out_path = output_folder_path / "slurm_output" - slurm_out_path.mkdir(parents=True, exist_ok=True) - - # Submitit executor - executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster) - executor.update_parameters(**slurm_args) - - click.echo(f"Submitting SLURM estimate regstration jobs with resources: {slurm_args}") - output_transforms_path = output_folder_path / "xyz_transforms" - output_transforms_path.mkdir(parents=True, exist_ok=True) - - click.echo('Computing registration transforms...') - # NOTE: ants is mulitthreaded so no need for multiprocessing here - # Submit jobs - jobs = [] - with submitit.helpers.clean_env(), executor.batch(): - for t in range(T): - job = executor.submit( - _optimize_registration, - source_data_tczyx[t], - target_data_tczyx[t], - initial_tform=initial_tform, - source_channel_index=source_channel_index, - target_channel_index=target_channel_index, - crop=True, - target_mask_radius=0.8, - clip=True, - sobel_fitler=ants_registration_settings.sobel_filter, - verbose=verbose, - slurm=True, - output_folder_path=output_transforms_path, - t_idx=t, - ) - jobs.append(job) - - # Save job IDs - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - log_path = slurm_out_path / f"job_ids_{timestamp}.log" - with open(log_path, "w") as log_file: - for job in jobs: - log_file.write(f"{job.job_id}\n") - 
- wait_for_jobs_to_finish(jobs) - - # Load the transforms - transforms = [] - for t in range(T): - file_path = output_transforms_path / f"{t}.npy" - if not os.path.exists(file_path): - transforms.append(None) - click.echo(f"Transform for timepoint {t} not found.") - else: - T_zyx_shift = np.load(file_path).tolist() - transforms.append(T_zyx_shift) - - if len(transforms) != T: - raise ValueError( - f"Number of transforms {len(transforms)} does not match number of timepoints {T}" - ) - - # Remove the output temporary folder - shutil.rmtree(output_transforms_path) - - return transforms - - -def beads_based_registration( - source_channel_tzyx: da.Array, - target_channel_tzyx: da.Array, - beads_match_settings: BeadsMatchSettings = None, - affine_transform_settings: AffineTransformSettings = None, - verbose: bool = False, - cluster: bool = False, - sbatch_filepath: Path = None, - output_folder_path: Path = None, -) -> list[ArrayLike]: - """ - Perform beads-based temporal registration of 4D data using affine transformations. - - This function calculates timepoint-specific affine transformations to align a source channel - to a target channel in 4D (T, Z, Y, X) data. It validates, smooths, and interpolates transformations - across timepoints for consistent registration. - - Parameters - ---------- - source_channel_tzyx : da.Array - 4D array (T, Z, Y, X) of the source channel (Dask array). - target_channel_tzyx : da.Array - 4D array (T, Z, Y, X) of the target channel (Dask array). - beads_match_settings : BeadsMatchSettings - Settings for the beads match. - affine_transform_settings : AffineTransformSettings - Settings for the affine transform. - verbose : bool - If True, prints detailed logs of the registration process. - cluster : bool - If True, uses the cluster. - sbatch_filepath : Path - Path to the sbatch file. - output_folder_path : Path - Path to the output folder. 
- - Returns - ------- - list[ArrayLike] - List of affine transformation matrices (4x4), one for each timepoint. - Invalid or missing transformations are interpolated. - - Notes - ----- - Each timepoint is processed in parallel using submitit executor. - Use verbose=True for detailed logging during registration. The verbose output will be saved at the same level as the output zarr. - """ - - (T, Z, Y, X) = source_channel_tzyx.shape - - num_cpus, gb_ram_per_cpu = estimate_resources( - shape=(T, 2, Z, Y, X), ram_multiplier=5, max_num_cpus=16 - ) - - # Prepare SLURM arguments - slurm_args = { - "slurm_job_name": "estimate_registration", - "slurm_mem_per_cpu": f"{gb_ram_per_cpu}G", - "slurm_cpus_per_task": num_cpus, - "slurm_array_parallelism": 100, - "slurm_time": 30, - "slurm_partition": "preempted", - "slurm_use_srun": False, - } - - if sbatch_filepath: - slurm_args.update(sbatch_to_submitit(sbatch_filepath)) - - output_folder_path.mkdir(parents=True, exist_ok=True) - slurm_out_path = output_folder_path / "slurm_output" - slurm_out_path.mkdir(parents=True, exist_ok=True) - - # Submitit executor - executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster) - executor.update_parameters(**slurm_args) - - click.echo(f"Submitting SLURM focus estimation jobs with resources: {slurm_args}") - output_transforms_path = output_folder_path / "xyz_transforms" - output_transforms_path.mkdir(parents=True, exist_ok=True) - - # Submit jobs - jobs = [] - with submitit.helpers.clean_env(), executor.batch(): - for t in range(T): - job = executor.submit( - estimate_transform_from_beads, - source_channel_tzyx=source_channel_tzyx, - target_channel_tzyx=target_channel_tzyx, - beads_match_settings=beads_match_settings, - affine_transform_settings=affine_transform_settings, - verbose=verbose, - slurm=True, - output_folder_path=output_transforms_path, - t_idx=t, - ) - jobs.append(job) - - # Save job IDs - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - log_path = 
slurm_out_path / f"job_ids_{timestamp}.log" - with open(log_path, "w") as log_file: - for job in jobs: - log_file.write(f"{job.job_id}\n") - - wait_for_jobs_to_finish(jobs) - - transforms = [] - for t in range(T): - file_path = output_transforms_path / f"{t}.npy" - if not os.path.exists(file_path): - transforms.append(None) - else: - T_zyx_shift = np.load(file_path).tolist() - transforms.append(T_zyx_shift) - - # Remove the output temporary folder - shutil.rmtree(output_transforms_path) - - return transforms - - -def get_local_pca_features( - points: ArrayLike, - edges: list[tuple[int, int]], -) -> tuple[ArrayLike, ArrayLike]: - """ - Compute dominant direction and anisotropy for each point using PCA, - using neighborhoods defined by existing graph edges. - - Parameters - ---------- - points : ArrayLike - (n, 2) array of points. - edges : list[tuple[int, int]] - List of edges (i, j) in the graph. - - Returns - ------- - tuple[ArrayLike, ArrayLike] - - directions : (n, 3) array of dominant directions. - - anisotropy : (n,) array of anisotropy. - - Notes - ----- - The PCA features are computed as the dominant direction and anisotropy of the local neighborhood of each point. - The direction is the first principal component of the local neighborhood. - The anisotropy is the ratio of the first to third principal component of the local neighborhood. 
- """ - n = len(points) - directions = np.zeros((n, 3)) - anisotropy = np.zeros(n) - - # Build neighbor list from edges - from collections import defaultdict - - neighbor_map = defaultdict(list) - for i, j in edges: - neighbor_map[i].append(j) - - for i in range(n): - neighbors = neighbor_map[i] - if not neighbors: - directions[i] = np.nan - anisotropy[i] = np.nan - continue - - local_points = points[neighbors].astype(np.float32) - local_points -= local_points.mean(axis=0) - _, S, Vt = np.linalg.svd(local_points, full_matrices=False) - - directions[i] = Vt[0] if Vt.shape[0] > 0 else np.zeros(3) - anisotropy[i] = S[0] / (S[2] + 1e-5) if len(S) >= 3 else 0.0 - - return directions, anisotropy - - -def get_edge_descriptors( - points: ArrayLike, - edges: list[tuple[int, int]], -) -> ArrayLike: - """ - Compute edge descriptors for a set of points. - - Parameters - ---------- - points : ArrayLike - (n, 2) array of points. - edges : list[tuple[int, int]] - List of edges (i, j) in the graph. - - Returns - ------- - ArrayLike - (n, 4) array of edge descriptors. - Each row contains: - - mean length - - std length - - mean angle - - std angle - - Notes - ----- - The edge descriptors are computed as the mean and standard deviation of the lengths and angles of the edges. - """ - n = len(points) - desc = np.zeros((n, 4)) - for i in range(n): - neighbors = [j for a, j in edges if a == i] - if not neighbors: - continue - vectors = points[neighbors] - points[i] - lengths = np.linalg.norm(vectors, axis=1) - angles = np.arctan2(vectors[:, 1], vectors[:, 0]) - desc[i, 0] = np.mean(lengths) - desc[i, 1] = np.std(lengths) - desc[i, 2] = np.mean(angles) - desc[i, 3] = np.std(angles) - return desc - - -def get_edge_attrs( - points: ArrayLike, - edges: list[tuple[int, int]], -) -> tuple[dict[tuple[int, int], float], dict[tuple[int, int], float]]: - """ - Compute edge distances and angles for a set of points. - - Parameters - ---------- - points : ArrayLike - (n, 2) array of points. 
- edges : list[tuple[int, int]] - List of edges (i, j) in the graph. - - Returns - ------- - tuple[dict[tuple[int, int], float], dict[tuple[int, int], float]] - - distances : dict[tuple[int, int], float] - - angles : dict[tuple[int, int], float] - - """ - distances, angles = {}, {} - for i, j in edges: - vec = points[j] - points[i] - d = np.linalg.norm(vec) - angle = np.arctan2(vec[1], vec[0]) - distances[(i, j)] = distances[(j, i)] = d - angles[(i, j)] = angles[(j, i)] = angle - return distances, angles - - -def match_hungarian_local_cost( - i: int, - j: int, - s_neighbors: list[int], - t_neighbors: list[int], - source_attrs: dict[tuple[int, int], float], - target_attrs: dict[tuple[int, int], float], - default_cost: float, -) -> float: - """ - Match neighbor edges between two graphs using the Hungarian algorithm for local cost estimation. - The cost is the mean of the absolute differences between the source and target edge attributes. - - Parameters - ---------- - i : int - Index of the source edge. - j : int - Index of the target edge. - s_neighbors : list[int] - List of source neighbors. - t_neighbors : list[int] - List of target neighbors. - source_attrs : dict[tuple[int, int], float] - Dictionary of source edge attributes. - target_attrs : dict[tuple[int, int], float] - Dictionary of target edge attributes. 
- """ - C = np.full((len(s_neighbors), len(t_neighbors)), default_cost) - - # compute cost matrix - for ii, sn in enumerate(s_neighbors): - # get target neighbors - for jj, tn in enumerate(t_neighbors): - s_edge = (i, sn) - t_edge = (j, tn) - if s_edge in source_attrs and t_edge in target_attrs: - C[ii, jj] = abs(source_attrs[s_edge] - target_attrs[t_edge]) - - # use hungarian algorithm to find the best match - row_ind, col_ind = linear_sum_assignment(C) - # get the mean of the matched costs - matched_costs = C[row_ind, col_ind] - # return the mean of the matched costs - - return matched_costs.mean() if len(matched_costs) > 0 else default_cost - - -def compute_edge_consistency_cost( - n: int, - m: int, - source_attrs: dict[tuple[int, int], float], - target_attrs: dict[tuple[int, int], float], - source_edges: list[tuple[int, int]], - target_edges: list[tuple[int, int]], - default: float = 1e6, - hungarian: bool = True, -) -> ArrayLike: - """ - Compute the cost matrix for matching edges between two graphs. - - Parameters - ---------- - n : int - Number of source edges. - m : int - Number of target edges. - source_attrs : dict[tuple[int, int], float] - Dictionary of source edge attributes. - target_attrs : dict[tuple[int, int], float] - Dictionary of target edge attributes. - source_edges : list[tuple[int, int]] - List of edges (i, j) in source graph. - target_edges : list[tuple[int, int]] - List of edges (i, j) in target graph. - default : float - Default value for the cost matrix. - hungarian : bool - Whether to use the Hungarian algorithm for local cost estimation. - If False, the cost matrix is computed as the mean of the absolute differences between the source and target edge attributes. - If True, the cost matrix is computed as the mean of the absolute differences between the source and target edge attributes using the Hungarian algorithm. - - Returns - ------- - ArrayLike - Cost matrix of shape (n, m). 
- - Notes - ----- - The cost matrix is computed as the mean of the absolute differences between the source and target edge attributes. - """ - cost_matrix = np.full((n, m), default) - for i in range(n): - # get source neighbors - s_neighbors = [j for a, j in source_edges if a == i] - for j in range(m): - # get target neighbors - t_neighbors = [k for a, k in target_edges if a == j] - if hungarian: - # hungarian algorithm based cost estimation - cost_matrix[i, j] = match_hungarian_local_cost( - i, j, s_neighbors, t_neighbors, source_attrs, target_attrs, default - ) - else: - # position based cost estimation (mean of the absolute differences between the source and target edge attributes) - common_len = min(len(s_neighbors), len(t_neighbors)) - diffs = [] - for k in range(common_len): - s_edge = (i, s_neighbors[k]) - t_edge = (j, t_neighbors[k]) - if s_edge in source_attrs and t_edge in target_attrs: - v1 = source_attrs[s_edge] - v2 = target_attrs[t_edge] - diff = np.abs(v1 - v2) - diffs.append(diff) - cost_matrix[i, j] = np.mean(diffs) if diffs else default - - return cost_matrix - - -def compute_cost_matrix( - source_peaks: ArrayLike, - target_peaks: ArrayLike, - source_edges: list[tuple[int, int]], - target_edges: list[tuple[int, int]], - weights: dict[str, float] = None, - distance_metric: str = 'euclidean', - normalize: bool = False, -) -> ArrayLike: - """ - Compute a cost matrix for matching peaks between two graphs based on: - - Euclidean or other distance between peaks - - Consistency in edge distances - - Consistency in edge angles - - PCA features - - Edge descriptors - - Parameters - ---------- - source_peaks : ArrayLike - (n, 2) array of source node coordinates. - target_peaks : ArrayLike - (m, 2) array of target node coordinates. - source_edges : list[tuple[int, int]] - List of edges (i, j) in source graph. - target_edges : list[tuple[int, int]] - List of edges (i, j) in target graph. - weights : dict[str, float] - Weights for different cost components. 
- distance_metric : str - Metric for direct point-to-point distances. - normalize : bool - Whether to normalize the cost matrix. - - Notes - ----- - The cost matrix is computed as the sum of the weighted costs for each component. - The weights are defined in the `weights` parameter. - The default weights are: - - dist: 0.5 - - edge_angle: 1.0 - - edge_length: 1.0 - - pca_dir: 0.0 - - pca_aniso: 0.0 - - edge_descriptor: 0.0 - - Returns - ------- - ArrayLike - Cost matrix of shape (n, m). - """ - n, m = len(source_peaks), len(target_peaks) - C_total = np.zeros((n, m)) - - # --- Default weights --- - default_weights = { - "dist": 0.5, - "edge_angle": 1.0, - "edge_length": 1.0, - "pca_dir": 0.0, - "pca_aniso": 0.0, - "edge_descriptor": 0.0, - } - if weights is None: - weights = default_weights - else: - weights = {**default_weights, **weights} # override defaults - - # --- Base distance cost --- - if weights["dist"] > 0: - C_dist = cdist(source_peaks, target_peaks, metric=distance_metric) - if normalize: - C_dist /= C_dist.max() - C_total += weights["dist"] * C_dist - - # --- Edge angle and length costs --- - source_dists, source_angles = get_edge_attrs(source_peaks, source_edges) - target_dists, target_angles = get_edge_attrs(target_peaks, target_edges) - - if weights["edge_length"] > 0: - C_edge_len = compute_edge_consistency_cost( - n=n, - m=m, - source_attrs=source_dists, - target_attrs=target_dists, - source_edges=source_edges, - target_edges=target_edges, - default=1e6, - ) - if normalize: - C_edge_len /= C_edge_len.max() - C_total += weights["edge_length"] * C_edge_len - - if weights["edge_angle"] > 0: - C_edge_ang = compute_edge_consistency_cost( - n=n, - m=m, - source_attrs=source_angles, - target_attrs=target_angles, - source_edges=source_edges, - target_edges=target_edges, - default=np.pi, - ) - if normalize: - C_edge_ang /= np.pi - C_total += weights["edge_angle"] * C_edge_ang - - # --- PCA features --- - if weights["pca_dir"] > 0 or weights["pca_aniso"] > 
0: - dirs_s, aniso_s = get_local_pca_features(source_peaks, source_edges) - dirs_t, aniso_t = get_local_pca_features(target_peaks, target_edges) - - if weights["pca_dir"] > 0: - dot = np.clip(np.dot(dirs_s, dirs_t.T), -1.0, 1.0) - C_dir = 1 - np.abs(dot) - if normalize: - C_dir /= C_dir.max() - C_total += weights["pca_dir"] * C_dir - - if weights["pca_aniso"] > 0: - C_aniso = np.abs(aniso_s[:, None] - aniso_t[None, :]) - if normalize: - C_aniso /= C_aniso.max() - C_total += weights["pca_aniso"] * C_aniso - # --- Edge descriptors --- - if weights["edge_descriptor"] > 0: - desc_s = get_edge_descriptors(source_peaks, source_edges) - desc_t = get_edge_descriptors(target_peaks, target_edges) - C_desc = cdist(desc_s, desc_t) - if normalize: - C_desc /= C_desc.max() - C_total += weights["edge_descriptor"] * C_desc - - return C_total - - -def build_edge_graph( - points: ArrayLike, - mode: Literal["knn", "radius", "full"] = "knn", - k: int = 5, - radius: float = 30.0, -) -> list[tuple[int, int]]: - """ - Build a set of edges for a graph based on a given strategy. - - Parameters - ---------- - points : ArrayLike - (N, 3) array of 3D point coordinates. - mode : Literal["knn", "radius", "full"] - Mode for building the edge graph. - k : int - Number of neighbors if mode == "knn". - radius : float - Distance threshold if mode == "radius". - - Returns - ------- - list[tuple[int, int]] - List of (i, j) index pairs representing edges. 
- """ - n = len(points) - if n <= 1: - return [] - - if mode == "knn": - k_eff = min(k + 1, n) - nbrs = NearestNeighbors(n_neighbors=k_eff).fit(points) - _, indices = nbrs.kneighbors(points) - edges = [(i, j) for i in range(n) for j in indices[i] if i != j] - - elif mode == "radius": - graph = radius_neighbors_graph( - points, radius=radius, mode='connectivity', include_self=False - ) - if graph.nnz == 0: - return [] - edges = [(i, j) for i in range(n) for j in graph[i].nonzero()[1]] - - elif mode == "full": - edges = [(i, j) for i in range(n) for j in range(n) if i != j] - - return edges - - -def match_hungarian_global_cost( - C: ArrayLike, - cost_threshold: float = 1e5, - dummy_cost: float = 1e6, - max_ratio: float = None, -) -> ArrayLike: - """ - Runs Hungarian matching with padding for unequal-sized graphs, - optionally applying max_ratio filtering similar to match_descriptors. - - Parameters - ---------- - C : ArrayLike - Cost matrix of shape (n_A, n_B). - cost_threshold : float - Maximum cost to consider a valid match. - dummy_cost : float - Cost assigned to dummy nodes (must be > cost_threshold). - max_ratio : float, optional - Maximum allowed ratio between best and second-best cost. - - Returns - ------- - ArrayLike - Array of shape (N_matches, 2) with valid (A_idx, B_idx) pairs. 
- """ - n_A, n_B = C.shape - n = max(n_A, n_B) - - # Pad cost matrix to square shape - C_padded = np.full((n, n), fill_value=dummy_cost) - C_padded[:n_A, :n_B] = C - - # Solve the assignment problem - row_ind, col_ind = linear_sum_assignment(C_padded) - - matches = [] - for i, j in zip(row_ind, col_ind): - if i >= n_A or j >= n_B: - continue # matched with dummy - if C[i, j] >= cost_threshold: - continue # too costly - - if max_ratio is not None: - # Find second-best match for i - costs_i = C[i, :] - sorted_costs = np.sort(costs_i) - if len(sorted_costs) > 1: - second_best = sorted_costs[1] - ratio = C[i, j] / (second_best + 1e-10) # avoid division by zero - if ratio > max_ratio: - continue # reject if not sufficiently better - # else (only one candidate) => accept by default - - matches.append((i, j)) - - return np.array(matches) - - -def detect_bead_peaks( - source_channel_zyx: da.Array, - target_channel_zyx: da.Array, - source_peaks_settings: DetectPeaksSettings, - target_peaks_settings: DetectPeaksSettings, - verbose: bool = False, - filter_dirty_peaks: bool = False, -) -> tuple[ArrayLike, ArrayLike]: - """ - Detect peaks in source and target channels using the detect_peaks function. - - Parameters - ---------- - source_channel_zyx : da.Array - (T, Z, Y, X) array of the source channel (Dask array). - target_channel_zyx : da.Array - (T, Z, Y, X) array of the target channel (Dask array). - source_peaks_settings : DetectPeaksSettings - Settings for the source peaks. - target_peaks_settings : DetectPeaksSettings - Settings for the target peaks. - verbose : bool - If True, prints detailed logs during the process. - filter_dirty_peaks : bool - If True, filters the dirty peaks. - Returns - ------- - tuple[ArrayLike, ArrayLike] - Tuple of (source_peaks, target_peaks). 
- """ - if verbose: - click.echo('Detecting beads in source dataset') - - source_peaks = detect_peaks( - source_channel_zyx, - block_size=source_peaks_settings.block_size, - threshold_abs=source_peaks_settings.threshold_abs, - nms_distance=source_peaks_settings.nms_distance, - min_distance=source_peaks_settings.min_distance, - verbose=verbose, - ) - if verbose: - click.echo('Detecting beads in target dataset') - - target_peaks = detect_peaks( - target_channel_zyx, - block_size=target_peaks_settings.block_size, - threshold_abs=target_peaks_settings.threshold_abs, - nms_distance=target_peaks_settings.nms_distance, - min_distance=target_peaks_settings.min_distance, - verbose=verbose, - ) - if verbose: - click.echo(f'Total of peaks in source dataset: {len(source_peaks)}') - click.echo(f'Total of peaks in target dataset: {len(target_peaks)}') - - if len(source_peaks) < 2 or len(target_peaks) < 2: - click.echo('Not enough beads detected') - return - if filter_dirty_peaks: - print("Filtering dirty peaks") - with open_ome_zarr( - Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/dirty_on_mantis/lf_mask_2025_05_01_A549_DENV_sensor_DENV_T_9_0.zarr/C/1/000000" - ) - ) as dirty_mask_ds: - dirty_mask_load = np.asarray(dirty_mask_ds.data[0, 0]) - - # filter the dirty peaks - # Keep only peaks whose (y, x) column is clean across all Z slices - target_peaks_filtered = [] - for peak in target_peaks: - z, y, x = peak.astype(int) - if ( - 0 <= y < dirty_mask_load.shape[1] - and 0 <= x < dirty_mask_load.shape[2] - and not dirty_mask_load[:, y, x].any() # True if all Z are clean at (y, x) - ): - target_peaks_filtered.append(peak) - target_peaks = np.array(target_peaks_filtered) - return source_peaks, target_peaks - - -def get_matches_from_hungarian( - source_peaks: ArrayLike, - target_peaks: ArrayLike, - beads_match_settings: BeadsMatchSettings, - verbose: bool = False, -) -> ArrayLike: - """ - Get matches from beads using the hungarian algorithm. 
- Parameters - ---------- - source_peaks : ArrayLike - (n, 2) array of source peaks. - target_peaks : ArrayLike - (m, 2) array of target peaks. - beads_match_settings : BeadsMatchSettings - Settings for the beads match. - verbose : bool - If True, prints detailed logs during the process. - - Returns - ------- - ArrayLike - (n, 2) array of matches. - """ - hungarian_settings = beads_match_settings.hungarian_match_settings - cost_settings = hungarian_settings.cost_matrix_settings - edge_settings = hungarian_settings.edge_graph_settings - source_edges = build_edge_graph( - source_peaks, mode=edge_settings.method, k=edge_settings.k, radius=edge_settings.radius - ) - target_edges = build_edge_graph( - target_peaks, mode=edge_settings.method, k=edge_settings.k, radius=edge_settings.radius - ) - - if hungarian_settings.cross_check: - # Step 1: A → B - C_ab = compute_cost_matrix( - source_peaks, - target_peaks, - source_edges, - target_edges, - weights=cost_settings.weights, - distance_metric=hungarian_settings.distance_metric, - normalize=cost_settings.normalize, - ) - - matches_ab = match_hungarian_global_cost( - C_ab, - cost_threshold=np.quantile(C_ab, hungarian_settings.cost_threshold), - max_ratio=hungarian_settings.max_ratio, - ) - - # Step 2: B → A (swap arguments) - C_ba = compute_cost_matrix( - target_peaks, - source_peaks, - target_edges, - source_edges, - weights=cost_settings.weights, - distance_metric=hungarian_settings.distance_metric, - normalize=cost_settings.normalize, - ) - - matches_ba = match_hungarian_global_cost( - C_ba, - cost_threshold=np.quantile(C_ba, hungarian_settings.cost_threshold), - max_ratio=hungarian_settings.max_ratio, - ) - - # Step 3: Invert matches_ba to compare - reverse_map = {(j, i) for i, j in matches_ba} - - # Step 4: Keep only symmetric matches - matches = np.array([[i, j] for i, j in matches_ab if (i, j) in reverse_map]) - else: - # without cross-check - - C = compute_cost_matrix( - source_peaks, - target_peaks, - source_edges, 
- target_edges, - weights=cost_settings.weights, - distance_metric=hungarian_settings.distance_metric, - normalize=cost_settings.normalize, - ) - - matches = match_hungarian_global_cost( - C, - cost_threshold=np.quantile(C, hungarian_settings.cost_threshold), - max_ratio=hungarian_settings.max_ratio, - ) - return matches - - -def get_matches_from_beads( - source_peaks: ArrayLike, - target_peaks: ArrayLike, - beads_match_settings: BeadsMatchSettings, - verbose: bool = False, -) -> ArrayLike: - """ - Get matches from beads using the hungarian algorithm. - - Parameters - ---------- - source_peaks : ArrayLike - (n, 2) array of source peaks. - target_peaks : ArrayLike - (m, 2) array of target peaks. - beads_match_settings : BeadsMatchSettings - Settings for the beads match. - verbose : bool - If True, prints detailed logs during the process. - - Returns - ------- - ArrayLike - (n, 2) array of matches. - """ - if verbose: - click.echo(f'Getting matches from beads with settings: {beads_match_settings}') - - if beads_match_settings.algorithm == 'match_descriptor': - match_descriptor_settings = beads_match_settings.match_descriptor_settings - matches = match_descriptors( - source_peaks, - target_peaks, - metric=match_descriptor_settings.distance_metric, - max_ratio=match_descriptor_settings.max_ratio, - cross_check=match_descriptor_settings.cross_check, - ) - - elif beads_match_settings.algorithm == 'hungarian': - matches = get_matches_from_hungarian( - source_peaks=source_peaks, - target_peaks=target_peaks, - beads_match_settings=beads_match_settings, - verbose=verbose, - ) - - if verbose: - click.echo(f'Total of matches: {len(matches)}') - - return matches - - -def filter_matches( - matches: ArrayLike, - source_peaks: ArrayLike, - target_peaks: ArrayLike, - angle_threshold: float = 30, - distance_threshold: float = 0.95, - verbose: bool = False, -) -> ArrayLike: - """ - Filter matches based on the angle and distance thresholds. 
- - Parameters - ---------- - matches : ArrayLike - (n, 2) array of matches. - source_peaks : ArrayLike - (n, 2) array of source peaks. - target_peaks : ArrayLike - (n, 2) array of target peaks. - angle_threshold : float - Angle threshold in degrees. - distance_threshold : float - Distance threshold. - verbose : bool - If True, prints detailed logs during the process. - - Returns - ------- - ArrayLike - (n, 2) array of filtered matches. - - Notes - ----- - Uses the angle and distance thresholds to filter matches. - The angle threshold is the maximum allowed angle between the source and target peaks. - The distance threshold is the maximum allowed distance between the source and target peaks. - The dominant angle is the angle that appears most frequently in the matches. - """ - if distance_threshold: - click.echo(f'Filtering matches with distance threshold: {distance_threshold}') - dist = np.linalg.norm( - source_peaks[matches[:, 0]] - target_peaks[matches[:, 1]], axis=1 - ) - matches = matches[dist < np.quantile(dist, distance_threshold), :] - - if verbose: - click.echo(f'Total of matches after distance filtering: {len(matches)}') - - if angle_threshold: - click.echo(f'Filtering matches with angle threshold: {angle_threshold}') - vectors = target_peaks[matches[:, 1]] - source_peaks[matches[:, 0]] - angles_rad = np.arctan2(vectors[:, 1], vectors[:, 0]) - angles_deg = np.degrees(angles_rad) - - bins = np.linspace(-180, 180, 36) # 10-degree bins - hist, bin_edges = np.histogram(angles_deg, bins=bins) - - dominant_bin_index = np.argmax(hist) - dominant_angle = ( - bin_edges[dominant_bin_index] + bin_edges[dominant_bin_index + 1] - ) / 2 - - filtered_indices = np.where(np.abs(angles_deg - dominant_angle) <= angle_threshold)[0] - - matches = matches[filtered_indices] - - if verbose: - click.echo(f'Total of matches after angle filtering: {len(matches)}') - - return matches - - -def estimate_transform( - matches: ArrayLike, - source_peaks: ArrayLike, - target_peaks: 
ArrayLike, - affine_transform_settings: AffineTransformSettings, - verbose: bool = False, -) -> ArrayLike: - """ - Estimate the affine transformation matrix between source and target channels - based on detected bead matches at a specific timepoint. - - Parameters - ---------- - matches : ArrayLike - (n, 2) array of matches. - source_peaks : ArrayLike - (n, 2) array of source peaks. - target_peaks : ArrayLike - (n, 2) array of target peaks. - affine_transform_settings : AffineTransformSettings - Settings for the affine transform. - verbose : bool - If True, prints detailed logs during the process. - - Returns - ------- - ArrayLike - (4, 4) array of the affine transformation matrix. - """ - if verbose: - click.echo(f"Estimating transform with settings: {affine_transform_settings}") - - if affine_transform_settings.transform_type == 'affine': - tform = AffineTransform(dimensionality=3) - - elif affine_transform_settings.transform_type == 'euclidean': - tform = EuclideanTransform(dimensionality=3) - - elif affine_transform_settings.transform_type == 'similarity': - tform = SimilarityTransform(dimensionality=3) - - else: - raise ValueError(f'Unknown transform type: {affine_transform_settings.transform_type}') - - tform.estimate(source_peaks[matches[:, 0]], target_peaks[matches[:, 1]]) - - return tform - - -def estimate_transform_from_beads( - t_idx: int, - source_channel_tzyx: da.Array, - target_channel_tzyx: da.Array, - beads_match_settings: BeadsMatchSettings, - affine_transform_settings: AffineTransformSettings, - verbose: bool = False, - slurm: bool = False, - output_folder_path: Path = None, -) -> list | None: - """ - Calculate the affine transformation matrix between source and target channels - based on detected bead matches at a specific timepoint. - - This function detects beads in both source and target datasets, matches them, - and computes an affine transformation to align the two channels. 
It applies - various filtering steps, including angle-based filtering, to improve match quality. - - Parameters - ---------- - t_idx : int - Timepoint index to process. - source_channel_tzyx : da.Array - 4D array (T, Z, Y, X) of the source channel (Dask array). - target_channel_tzyx : da.Array - 4D array (T, Z, Y, X) of the target channel (Dask array). - beads_match_settings : BeadsMatchSettings - Settings for the beads match. - affine_transform_settings : AffineTransformSettings - Settings for the affine transform. - verbose : bool - If True, prints detailed logs during the process. - slurm : bool - If True, uses SLURM for parallel processing. - output_folder_path : Path - Path to save the output. - - Returns - ------- - list | None - A 4x4 affine transformation matrix as a nested list if successful, - or None if no valid transformation could be calculated. - - Notes - ----- - Uses ANTsPy for initial transformation application and bead detection. - Peaks (beads) are detected using a block-based algorithm with thresholds for source and target datasets. - Bead matches are filtered based on distance and angular deviation from the dominant direction. - If fewer than three matches are found after filtering, the function returns None. 
- """ - - click.echo(f'Processing timepoint: {t_idx}') - - source_channel_zyx = np.asarray(source_channel_tzyx[t_idx]).astype(np.float32) - target_channel_zyx = np.asarray(target_channel_tzyx[t_idx]).astype(np.float32) - - if _check_nan_n_zeros(source_channel_zyx) or _check_nan_n_zeros(target_channel_zyx): - click.echo(f'Beads data is missing at timepoint {t_idx}') - return - - approx_tform = np.asarray(affine_transform_settings.approx_transform) - source_data_ants = ants.from_numpy(source_channel_zyx) - target_data_ants = ants.from_numpy(target_channel_zyx) - source_data_reg = ( - convert_transform_to_ants(approx_tform) - .apply_to_image(source_data_ants, reference=target_data_ants) - .numpy() - ) - - source_peaks, target_peaks = detect_bead_peaks( - source_channel_zyx=source_data_reg, - target_channel_zyx=target_channel_zyx, - source_peaks_settings=beads_match_settings.source_peaks_settings, - target_peaks_settings=beads_match_settings.target_peaks_settings, - verbose=verbose, - ) - - matches = get_matches_from_beads( - source_peaks=source_peaks, - target_peaks=target_peaks, - beads_match_settings=beads_match_settings, - verbose=verbose, - ) - - matches = filter_matches( - matches=matches, - source_peaks=source_peaks, - target_peaks=target_peaks, - angle_threshold=beads_match_settings.filter_angle_threshold, - distance_threshold=beads_match_settings.filter_distance_threshold, - ) - - if len(matches) < 3: - click.echo( - f'Source and target beads were not matches successfully for timepoint {t_idx}' - ) - return - - tform = estimate_transform( - matches=matches, - source_peaks=source_peaks, - target_peaks=target_peaks, - affine_transform_settings=affine_transform_settings, - verbose=verbose, - ) - compount_tform = np.asarray(approx_tform) @ tform.inverse.params - - if verbose: - click.echo(f'Matches: {matches}') - click.echo(f"tform.params: {tform.params}") - click.echo(f"tform.inverse.params: {tform.inverse.params}") - click.echo(f"compount_tform: 
{compount_tform}") - - if slurm: - print(f"Saving transform to {output_folder_path}") - output_folder_path.mkdir(parents=True, exist_ok=True) - np.save(output_folder_path / f"{t_idx}.npy", compount_tform) - - return compount_tform.tolist() - - def estimate_registration( source_position_dirpaths: list[str], target_position_dirpaths: list[str], @@ -2098,23 +444,31 @@ def estimate_registration( eval_transform_settings = settings.eval_transform_settings if settings.estimation_method == "beads": - transforms = beads_based_registration( - source_channel_tzyx=source_channel_data, - target_channel_tzyx=target_channel_data, + from biahub.registration.beads import estimate_tczyx + + transforms = estimate_tczyx( + mov_tczyx=source_data, + ref_tczyx=target_data, + mov_channel_index=source_channel_index, + ref_channel_index=target_channel_index, beads_match_settings=settings.beads_match_settings, affine_transform_settings=settings.affine_transform_settings, verbose=settings.verbose, cluster=cluster, sbatch_filepath=sbatch_filepath, output_folder_path=output_dir, + ref_voxel_size=target_channel_voxel_size, + mov_voxel_size=source_channel_voxel_size, ) elif settings.estimation_method == "ants": - transforms = ants_registration( - source_data_tczyx=source_data, - target_data_tczyx=target_data, - source_channel_index=source_channel_index, - target_channel_index=target_channel_index, + from biahub.registration.ants import estimate_tczyx + + transforms = estimate_tczyx( + mov_tczyx=source_data, + ref_tczyx=target_data, + mov_channel_index=source_channel_index, + ref_channel_index=target_channel_index, ants_registration_settings=settings.ants_registration_settings, affine_transform_settings=settings.affine_transform_settings, sbatch_filepath=sbatch_filepath, @@ -2173,9 +527,10 @@ def estimate_registration( ) model = StabilizationSettings( - stabilization_estimation_channel='', - stabilization_type='xyz', - stabilization_channels=registration_source_channels, + 
stabilization_estimation_channel=target_channel_name, + stabilization_type="affine", + stabilization_method=settings.estimation_method, + stabilization_channels=[source_channel_name, target_channel_name], affine_transform_zyx_list=transforms, time_indices='all', output_voxel_size=voxel_size, diff --git a/biahub/estimate_stabilization.py b/biahub/estimate_stabilization.py index 01c11ef2..4c1490b2 100644 --- a/biahub/estimate_stabilization.py +++ b/biahub/estimate_stabilization.py @@ -1,5 +1,4 @@ import itertools -import os import shutil from datetime import datetime @@ -30,14 +29,12 @@ ) from biahub.cli.slurm import wait_for_jobs_to_finish from biahub.cli.utils import estimate_resources, yaml_to_model -from biahub.estimate_registration import ( - estimate_transform_from_beads, +from biahub.registration.utils import ( evaluate_transforms, + match_shape, save_transforms, ) from biahub.settings import ( - AffineTransformSettings, - BeadsMatchSettings, EstimateStabilizationSettings, FocusFindingSettings, PhaseCrossCorrSettings, @@ -76,112 +73,6 @@ def remove_beads_fov_from_path_list( return position_dirpaths -def pad_to_shape( - arr: ArrayLike, - shape: Tuple[int, ...], - mode: str, - verbose: bool = False, - **kwargs, -) -> ArrayLike: - """ - Pad or crop array to match provided shape. - - Parameters - ---------- - arr : ArrayLike - Input array. - shape : Tuple[int] - Output shape. - mode : str - Padding mode (see np.pad). - verbose : bool - If True, print verbose output. - kwargs : dict - Additional keyword arguments for np.pad. - - Returns - ------- - ArrayLike - Padded array. 
- """ - assert arr.ndim == len(shape) - - dif = tuple(s - a for s, a in zip(shape, arr.shape)) - assert all(d >= 0 for d in dif) - - pad_width = [[s // 2, s - s // 2] for s in dif] - - if verbose: - click.echo( - f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}" - ) - - return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs) - - -def center_crop( - arr: ArrayLike, - shape: Tuple[int, ...], - verbose: bool = False, -) -> ArrayLike: - """ - Crop the center of `arr` to match provided shape. - - Parameters - ---------- - arr : ArrayLike - Input array. - shape : Tuple[int, ...] - """ - assert arr.ndim == len(shape) - - starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape)) - - assert all(s >= 0 for s in starts) - - slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape)) - if verbose: - click.echo( - f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}" - ) - return arr[slicing] - - -def match_shape( - img: ArrayLike, - shape: Tuple[int, ...], - verbose: bool = False, -) -> ArrayLike: - """ - Pad or crop array to match provided shape. - - Parameters - ---------- - img : ArrayLike - Input array. - shape : Tuple[int, ...] - verbose : bool - If True, print verbose output. - - Returns - ------- - ArrayLike - Padded or cropped array. 
- """ - - if np.any(shape > img.shape): - padded_shape = np.maximum(img.shape, shape) - img = pad_to_shape(img, padded_shape, mode="reflect") - - if np.any(shape < img.shape): - img = center_crop(img, shape) - - if verbose: - click.echo(f"matched shape: input shape {img.shape}, output shape {shape}") - - return img - - def plot_cross_correlation( corr, title="Cross-Correlation", @@ -803,132 +694,6 @@ def estimate_xyz_stabilization_pcc( return fov_transforms -def estimate_xyz_stabilization_with_beads( - channel_tzyx: da.Array, - beads_match_settings: BeadsMatchSettings, - affine_transform_settings: AffineTransformSettings, - verbose: bool = False, - cluster: str = "local", - sbatch_filepath: Optional[Path] = None, - output_folder_path: Path = None, -) -> list[ArrayLike]: - """ - Estimate the xyz stabilization for a single position. - - Parameters - ---------- - channel_tzyx : da.Array - Source channel data. - beads_match_settings : BeadsMatchSettings - Settings for the beads match. - affine_transform_settings : AffineTransformSettings - Settings for the affine transform. - verbose : bool - If True, print verbose output. - cluster : str - Cluster to use. - sbatch_filepath : Path - Path to the sbatch file. - output_folder_path : Path - Path to the output folder. - - Returns - ------- - list[ArrayLike] - List of the xyz stabilization for each timepoint. - """ - - (T, Z, Y, X) = channel_tzyx.shape - - if beads_match_settings.t_reference == "first": - target_channel_tzyx = np.broadcast_to(channel_tzyx[0], (T, Z, Y, X)).copy() - elif beads_match_settings.t_reference == "previous": - target_channel_tzyx = np.roll(channel_tzyx, shift=-1, axis=0) - target_channel_tzyx[0] = channel_tzyx[0] - - else: - raise ValueError("Invalid reference. 
Please use 'first' or 'previous as reference") - - # Compute transformations in parallel - - num_cpus, gb_ram_per_cpu = estimate_resources( - shape=(T, 1, Z, Y, X), ram_multiplier=5, max_num_cpus=16 - ) - - # Prepare SLURM arguments - slurm_args = { - "slurm_job_name": "estimate_focus_z", - "slurm_mem_per_cpu": f"{gb_ram_per_cpu}G", - "slurm_cpus_per_task": num_cpus, - "slurm_array_parallelism": 100, - "slurm_time": 30, - "slurm_partition": "preempted", - } - - if sbatch_filepath: - slurm_args.update(sbatch_to_submitit(sbatch_filepath)) - - output_folder_path.mkdir(parents=True, exist_ok=True) - slurm_out_path = output_folder_path / "slurm_output" - slurm_out_path.mkdir(parents=True, exist_ok=True) - - # Submitit executor - executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster) - executor.update_parameters(**slurm_args) - - click.echo(f"Submitting SLURM focus estimation jobs with resources: {slurm_args}") - output_transforms_path = output_folder_path / "xyz_transforms" - output_transforms_path.mkdir(parents=True, exist_ok=True) - - # Submit jobs - jobs = [] - with submitit.helpers.clean_env(), executor.batch(): - for t in range(1, T, 1): - job = executor.submit( - estimate_transform_from_beads, - source_channel_tzyx=channel_tzyx, - target_channel_tzyx=target_channel_tzyx, - verbose=verbose, - beads_match_settings=beads_match_settings, - affine_transform_settings=affine_transform_settings, - slurm=True, - output_folder_path=output_transforms_path, - t_idx=t, - ) - jobs.append(job) - - # Save job IDs - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - log_path = slurm_out_path / f"job_ids_{timestamp}.log" - with open(log_path, "w") as log_file: - for job in jobs: - log_file.write(f"{job.job_id}\n") - - wait_for_jobs_to_finish(jobs) - - # Load the transforms - transforms = [np.eye(4).tolist()] - for t in range(1, T): - file_path = output_transforms_path / f"{t}.npy" - if not os.path.exists(file_path): - transforms.append(None) - 
click.echo(f"Transform for timepoint {t} not found.") - else: - T_zyx_shift = np.load(file_path).tolist() - transforms.append(T_zyx_shift) - - # Check if the number of transforms matches the number of timepoints - if len(transforms) != T: - raise ValueError( - f"Number of transforms {len(transforms)} does not match number of timepoints {T}" - ) - - # Remove the output folder - shutil.rmtree(output_transforms_path) - - return transforms - - def estimate_xy_stabilization_per_position( input_position_dirpath: Path, output_folder_path: Path, @@ -1633,21 +1398,26 @@ def estimate_stabilization( click.echo( f"Error estimating {stabilization_type} stabilization parameters: {e}" ) - elif stabilization_method == "beads": + from biahub.registration.beads import estimate_tczyx + click.echo("Estimating xyz stabilization parameters with beads") with open_ome_zarr(input_position_dirpaths[0], mode="r") as beads_position: source_channels = beads_position.channel_names source_channel_index = source_channels.index(stabilization_estimation_channel) - channel_tzyx = beads_position.data.dask_array()[:, source_channel_index] + channel_tczyx = beads_position.data.dask_array() - xyz_transforms = estimate_xyz_stabilization_with_beads( - channel_tzyx=channel_tzyx, + xyz_transforms = estimate_tczyx( + mov_tczyx=channel_tczyx, + ref_tczyx=channel_tczyx, + mov_channel_index=source_channel_index, + ref_channel_index=source_channel_index, beads_match_settings=settings.beads_match_settings, affine_transform_settings=settings.affine_transform_settings, verbose=verbose, output_folder_path=output_dirpath, + mode="stabilization", cluster=cluster, sbatch_filepath=sbatch_filepath, ) diff --git a/biahub/registration/ants.py b/biahub/registration/ants.py new file mode 100644 index 00000000..230a1389 --- /dev/null +++ b/biahub/registration/ants.py @@ -0,0 +1,532 @@ +""" +ANTs-based intensity registration module. 
+ +Provides functions for registering volumetric imaging data using the ANTsPy +library's optimization-based registration. This complements bead-based +registration by directly optimizing image similarity metrics. + +Pipeline overview +----------------- +1. **Preprocessing** (`preprocess_czyx`): Apply initial transform, crop to + overlapping region (LIR), clip intensities, and optionally apply Sobel filter. +2. **Registration** (`estimate`): Run ANTs optimization (Similarity transform) + on preprocessed volumes. +3. **Postprocessing** (`postprocess_transform`): Compose the initial transform + with the ANTs correction, accounting for any crop offsets. +4. **Batch processing** (`estimate_tczyx`): Submit per-timepoint registration + jobs to SLURM via submitit. + +Key conventions +--------------- +- Coordinates are in ZYX order for 3D data. +- "mov" / "moving" refers to the source channel being aligned. +- "ref" / "reference" refers to the fixed target channel. +- Transforms are 4x4 homogeneous matrices stored as Transform objects. +""" + +from datetime import datetime +from pathlib import Path + +import ants +import click +import dask.array as da +import numpy as np +import submitit + +from skimage import filters + +from biahub.cli.parsing import ( + sbatch_to_submitit, +) +from biahub.cli.slurm import wait_for_jobs_to_finish +from biahub.cli.utils import _check_nan_n_zeros, estimate_resources +from biahub.core.transform import Transform +from biahub.registration.utils import ( + find_lir, + load_transforms, +) +from biahub.settings import ( + AffineTransformSettings, + AntsRegistrationSettings, +) + + +def estimate( + ref: np.ndarray, + mov: np.ndarray, + verbose: bool = False, + ants_kwargs: dict = None, +) -> tuple[Transform, Transform]: + """ + Estimate affine transformation using ANTs registration. + + Works for both 2D (Y, X) and 3D (Z, Y, X) arrays. 
+ + Parameters + ---------- + ref : np.ndarray + Reference image (2D or 3D) + mov : np.ndarray + Moving image (2D or 3D) + verbose : bool + Print optimization progress + ants_kwargs : dict, optional + Additional ANTs parameters + + Returns + ------- + fwd_transform : Transform + Forward transformation (mov → ref) + inv_transform : Transform + Inverse transformation (ref → mov) + """ + if ref.ndim not in (2, 3) or mov.ndim not in (2, 3): + raise ValueError( + f"Images must be 2D or 3D, got ref.ndim={ref.ndim}, mov.ndim={mov.ndim}" + ) + + if ref.ndim != mov.ndim: + raise ValueError(f"Dimension mismatch: ref.ndim={ref.ndim}, mov.ndim={mov.ndim}") + + if ants_kwargs is None: + ants_kwargs = { + "type_of_transform": "Similarity", + "aff_shrink_factors": (6, 3, 1), + "aff_iterations": (2100, 1200, 50), + "aff_smoothing_sigmas": (2, 1, 0), + } + + mov_ants = ants.from_numpy(mov) + ref_ants = ants.from_numpy(ref) + + if verbose: + click.echo(f"Optimizing registration parameters using ANTs with kwargs: {ants_kwargs}") + + reg = ants.registration( + fixed=ref_ants, + moving=mov_ants, + **ants_kwargs, + verbose=verbose, + ) + + fwd_transform_mat = ants.read_transform(reg["fwdtransforms"][0]) + inv_transform_mat = ants.read_transform(reg["invtransforms"][0]) + + fwd_transform = Transform.from_ants(fwd_transform_mat) + inv_transform = Transform.from_ants(inv_transform_mat) + + if fwd_transform.matrix is None or inv_transform.matrix is None: + raise ValueError("Failed to estimate registration transform.") + + return fwd_transform, inv_transform + + +def preprocess_czyx( + mov_czyx: np.ndarray, + ref_czyx: np.ndarray, + initial_tform: Transform, + mov_channel_index: int | list = 0, + ref_channel_index: int = 0, + crop: bool = False, + ref_mask_radius: float | None = None, + clip: bool = False, + sobel_filter: bool = False, + verbose: bool = False, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Optimize the affine transform between source and target channels using ANTs 
library. + + Parameters + ---------- + mov_czyx : np.ndarray + Source channel data in CZYX format. + ref_czyx : np.ndarray + Target channel data in CZYX format. + initial_tform : np.ndarray + Approximate estimate of the affine transform matrix, often obtained through manual registration, see `estimate-registration`. + mov_channel_index : int | list, optional + Index or list of indices of mov channels to be used for registration, by default 0. + ref_channel_index : int, optional + Index of the reference channel to be used for registration, by default 0. + crop : bool, optional + Whether to crop the moving and reference channels to the overlapping region as determined by the LIR algorithm, by default False. + ref_mask_radius : float | None, optional + Radius of the circular mask which will be applied to the reference channel. By default None in which case no masking will be applied. + clip : bool, optional + Whether to clip the moving and reference channels to reasonable (hardcoded) values, by default False. + sobel_filter : bool, optional + Whether to apply Sobel filter to the moving and reference channels, by default False. + verbose : bool, optional + Whether to print verbose output during registration, by default False. + + Returns + ------- + Transform | None + Optimized affine transform matrix or None if the input data contains NaN or zeros. + + Notes + ----- + This function applies an initial affine transform to the source channels, crops them to the overlapping region with the target channel, + clips the values, applies Sobel filtering if specified, and then optimizes the registration parameters using ANTs library. + + This function currently assumes that target channel is phase and source channels are fluorescence. + If multiple source channels are provided, they will be summed, after clipping, filtering, and cropping, if enabled. 
+ """ + + mov_czyx = np.asarray(mov_czyx).astype(np.float32) + ref_czyx = np.asarray(ref_czyx).astype(np.float32) + + if ref_mask_radius is not None and not (0 < ref_mask_radius <= 1): + raise ValueError( + "ref_mask_radius must be given as a fraction of image width, i.e. (0, 1]." + ) + + if _check_nan_n_zeros(mov_czyx) or _check_nan_n_zeros(ref_czyx): + raise ValueError("Input data contains NaN or zeros.") + t_form_ants = initial_tform.to_ants() + + ref_zyx = ref_czyx[ref_channel_index] + if ref_zyx.ndim != 3: + raise ValueError(f"Expected 3D reference channel, got shape {ref_zyx.shape}") + ref_ants_pre_crop = ants.from_numpy(ref_zyx) + + if not isinstance(mov_channel_index, list): + mov_channel_index = [mov_channel_index] + + mov_channels = [] + for idx in mov_channel_index: + if verbose: + click.echo(f"Applying initial transform to moving channel {idx}...") + # Cropping, clipping, and filtering are applied after registration with initial_tform + _mov_channel = np.asarray(mov_czyx[idx]).astype(np.float32) + if _mov_channel.ndim != 3: + raise ValueError(f"Expected 3D moving channel, got shape {_mov_channel.shape}") + mov_channel = t_form_ants.apply_to_image( + ants.from_numpy(_mov_channel), reference=ref_ants_pre_crop + ).numpy() + if mov_channel.ndim != 3: + raise ValueError( + f"apply_to_image returned non-3D array: {mov_channel.shape}.\n" + "This is likely caused by mismatched input shape or invalid transform/reference." + ) + mov_channels.append(mov_channel) + + _offset = np.zeros(3, dtype=np.float32) + if crop: + if verbose: + click.echo( + "Estimating crop for moving and reference channels to overlapping region..." 
+ ) + mask = (ref_zyx != 0) & (mov_channels[0] != 0) + + # Can be refactored with code in cropping PR #88 + if ref_mask_radius is not None: + ref_mask = np.zeros(ref_zyx.shape[-2:], dtype=bool) + + y, x = np.ogrid[: ref_mask.shape[-2], : ref_mask.shape[-1]] + center = (ref_mask.shape[-2] // 2, ref_mask.shape[-1] // 2) + radius = int(ref_mask_radius * min(center)) + + ref_mask[(x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2] = True + mask *= ref_mask + + z_slice, y_slice, x_slice = find_lir(mask.astype(np.uint8)) + click.echo( + f"Cropping to region z={z_slice.start}:{z_slice.stop}, " + f"y={y_slice.start}:{y_slice.stop}, " + f"x={x_slice.start}:{x_slice.stop}" + ) + + _offset = np.asarray( + [_s.start for _s in (z_slice, y_slice, x_slice)], dtype=np.float32 + ) + ref_zyx = ref_zyx[z_slice, y_slice, x_slice] + mov_channels = [_channel[z_slice, y_slice, x_slice] for _channel in mov_channels] + + # TODO: hardcoded clipping limits + if clip: + if verbose: + click.echo("Clipping moving and reference channels to reasonable values...") + ref_zyx = np.clip(ref_zyx, 0, 0.5) + mov_channels = [ + np.clip(_channel, 110, np.quantile(_channel, 0.99)) for _channel in mov_channels + ] + + if sobel_filter: + if verbose: + click.echo("Applying Sobel filter to moving and reference channels...") + ref_zyx = filters.sobel(ref_zyx) + mov_channels = [filters.sobel(_channel) for _channel in mov_channels] + + mov_zyx = np.sum(mov_channels, axis=0) + + return ref_zyx, mov_zyx, _offset + + +def estimate_czyx( + mov_czyx: np.ndarray, + ref_czyx: np.ndarray, + initial_tform: np.ndarray, + mov_channel_index: int | list = 0, + ref_channel_index: int = 0, + crop: bool = False, + ref_mask_radius: float | None = None, + clip: bool = False, + sobel_filter: bool = False, + verbose: bool = False, + t_idx: int = 0, + output_folder_path: str | None = None, +) -> Transform: + """ + Optimize the affine transform between source and target channels using ANTs library. 
+ + Parameters + ---------- + mov_czyx : np.ndarray + Moving channel data in CZYX format. + ref_czyx : np.ndarray + Reference channel data in CZYX format. + initial_tform : np.ndarray + Approximate estimate of the affine transform matrix, often obtained through manual registration, see `estimate-registration`. + mov_channel_index : int | list, optional + Index or list of indices of moving channels to be used for registration, by default 0. + ref_channel_index : int, optional + Index of the reference channel to be used for registration, by default 0. + crop : bool, optional + Whether to crop the moving and reference channels to the overlapping region as determined by the LIR algorithm, by default False. + ref_mask_radius : float | None, optional + Radius of the circular mask which will be applied to the reference channel. By default None in which case no masking will be applied. + clip : bool, optional + Whether to clip the moving and reference channels to reasonable (hardcoded) values, by default False. + sobel_filter : bool, optional + Whether to apply Sobel filter to the moving and reference channels, by default False. + verbose : bool, optional + Whether to print verbose output during registration, by default False. + t_idx : int, optional + Time index for the registration, by default 0. + output_folder_path : str | None, optional + Path to the folder where the output transform will be saved, by default None. + + Returns + ------- + Transform | None + Optimized affine transform matrix or None if the input data contains NaN or zeros. + + Notes + ----- + This function applies an initial affine transform to the moving channels, crops them to the overlapping region with the reference channel, + clips the values, applies Sobel filtering if specified, and then optimizes the registration parameters using ANTs library. + + This function currently assumes that reference channel is phase and moving channels are fluorescence. 
+ If multiple moving channels are provided, they will be summed, after clipping, filtering, and cropping, if enabled. + """ + initial_tform = Transform(matrix=initial_tform) + + ref_zyx, mov_zyx, preprocess_offset = preprocess_czyx( + mov_czyx=mov_czyx, + ref_czyx=ref_czyx, + initial_tform=initial_tform, + mov_channel_index=mov_channel_index, + ref_channel_index=ref_channel_index, + crop=crop, + clip=clip, + ref_mask_radius=ref_mask_radius, + sobel_filter=sobel_filter, + verbose=verbose, + ) + + fwd_transform, inv_transform = estimate( + ref=ref_zyx, + mov=mov_zyx, + verbose=verbose, + ) + + composed_transform = postprocess_transform( + initial_transform=initial_tform, + fwd_transform=fwd_transform, + preprocess_offset=preprocess_offset, + ) + if verbose: + click.echo(f"Initial transform: {initial_tform}") + click.echo(f"Forward transform: {fwd_transform}") + click.echo(f"Inverse transform: {inv_transform}") + click.echo(f"Composed transform: {composed_transform}") + + if composed_transform is None: + raise ValueError("Failed to estimate registration transform for timepoint.") + + if output_folder_path: + output_folder_path.mkdir(parents=True, exist_ok=True) + if verbose: + click.echo( + f"Saving registration transform for timepoint {t_idx} to {output_folder_path}" + ) + + np.save(output_folder_path / f"{t_idx}.npy", composed_transform.matrix) + + return composed_transform + + +def postprocess_transform( + initial_transform: Transform, + fwd_transform: Transform, + preprocess_offset: np.ndarray, +) -> Transform: + """ + Compose the initial and ANTs-estimated transforms, accounting for crop offset. + + The ANTs registration operates on cropped volumes. 
This function shifts + into the cropped ROI, applies the ANTs correction, shifts back, and + composes the result with the initial transform: + composed = initial @ shift_to_roi @ fwd_transform @ shift_back + + Parameters + ---------- + initial_transform : Transform + The approximate transform applied before ANTs registration. + fwd_transform : Transform + The forward correction estimated by ANTs on the cropped volumes. + preprocess_offset : np.ndarray + (3,) ZYX offset of the crop ROI origin within the full volume. + Zero if no cropping was applied. + + Returns + ------- + Transform + The final composed transform mapping original moving -> reference space. + """ + + shift_to_roi = np.eye(4) + shift_to_roi[:3, -1] = preprocess_offset + + shift_back = np.eye(4) + shift_back[:3, -1] = -preprocess_offset + + composed_matrix = ( + initial_transform.matrix @ shift_to_roi @ fwd_transform.matrix @ shift_back + ) + + return Transform(matrix=composed_matrix) + + +def estimate_tczyx( + mov_tczyx: da.Array, + ref_tczyx: da.Array, + mov_channel_index: int | list[int], + ref_channel_index: int, + ants_registration_settings: AntsRegistrationSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + output_folder_path: Path = None, + cluster: str = 'local', + sbatch_filepath: Path = None, +) -> list[Transform]: + """ + Perform ants registration of two volumetric image channels. + + This function calculates timepoint-specific affine transformations to align a moving channel + to a target channel in 4D (T, Z, Y, X) data. It validates, smooths, and interpolates transformations + across timepoints for consistent registration. + + Parameters + ---------- + mov_tczyx : da.Array + 4D array (T, C, Z, Y, X) of the moving channel (Dask array). + ref_tczyx : da.Array + 4D array (T, C, Z, Y, X) of the reference channel (Dask array). + mov_channel_index : int | list[int] + Index of the moving channel. + ref_channel_index : int + Index of the reference channel. 
+    ants_registration_settings : AntsRegistrationSettings
+        Settings for the ANTs registration.
+    affine_transform_settings : AffineTransformSettings
+        Settings for the affine transform.
+    verbose : bool, optional
+        Whether to print verbose output during registration, by default False.
+    output_folder_path : str | None, optional
+        Path to the folder where the output transform will be saved, by default None.
+    cluster : str, optional
+        Cluster to use, by default 'local'.
+    sbatch_filepath : Path, optional
+        Path to an sbatch file with custom SLURM parameters, by default None.
+
+    Returns
+    -------
+    list[Transform]
+        List of affine transformation matrices (4x4), one for each timepoint.
+        Invalid or missing transformations are interpolated.
+
+    Notes
+    -----
+    Each timepoint is processed in parallel using submitit executor.
+    Use verbose=True for detailed logging during registration. The verbose output will be saved at the same level as the output zarr.
+    """
+    T, C, Z, Y, X = mov_tczyx.shape
+    initial_tform = np.asarray(affine_transform_settings.approx_transform)
+    click.echo(f"Initial transform: {initial_tform}")
+
+    # Resources are sized for 2 channels (moving + reference) per timepoint.
+    num_cpus, gb_ram_per_cpu = estimate_resources(
+        shape=(T, 2, Z, Y, X), ram_multiplier=16, max_num_cpus=16
+    )
+
+    # Prepare SLURM arguments
+    slurm_args = {
+        "slurm_job_name": "estimate_registration_ants",
+        "slurm_mem_per_cpu": f"{gb_ram_per_cpu}G",
+        "slurm_cpus_per_task": num_cpus,
+        "slurm_array_parallelism": 100,
+        "slurm_time": 30,
+        "slurm_partition": "preempted",
+    }
+
+    if sbatch_filepath:
+        slurm_args.update(sbatch_to_submitit(sbatch_filepath))
+
+    output_folder_path.mkdir(parents=True, exist_ok=True)
+    slurm_out_path = output_folder_path / "slurm_output"
+    slurm_out_path.mkdir(parents=True, exist_ok=True)
+
+    # Submitit executor
+    executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster)
+    executor.update_parameters(**slurm_args)
+
+    click.echo(f"Submitting SLURM estimate registration jobs with resources: {slurm_args}")
+    output_transforms_path = output_folder_path / "xyz_transforms"
+    output_transforms_path.mkdir(parents=True,
exist_ok=True)
+
+    click.echo('Computing registration transforms...')
+    # NOTE: ANTs is multithreaded, so no need for multiprocessing here
+    # Submit jobs
+    jobs = []
+    with submitit.helpers.clean_env(), executor.batch():
+        for t in range(T):
+            job = executor.submit(
+                estimate_czyx,
+                mov_czyx=mov_tczyx[t],
+                ref_czyx=ref_tczyx[t],
+                initial_tform=initial_tform,
+                mov_channel_index=mov_channel_index,
+                ref_channel_index=ref_channel_index,
+                crop=ants_registration_settings.crop,
+                ref_mask_radius=ants_registration_settings.ref_mask_radius,
+                clip=ants_registration_settings.clip,
+                sobel_filter=ants_registration_settings.sobel_filter,
+                verbose=verbose,
+                t_idx=t,
+                output_folder_path=output_transforms_path,
+            )
+            jobs.append(job)
+
+    # Save job IDs
+    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+    log_path = slurm_out_path / f"job_ids_{timestamp}.log"
+    with open(log_path, "w") as log_file:
+        for job in jobs:
+            log_file.write(f"{job.job_id}\n")
+
+    wait_for_jobs_to_finish(jobs)
+
+    transforms = load_transforms(output_transforms_path, T, verbose)
+    if len(transforms) != T:
+        raise ValueError(
+            f"Number of transforms {len(transforms)} does not match number of timepoints {T}"
+        )
+    return transforms
diff --git a/biahub/registration/beads.py b/biahub/registration/beads.py
new file mode 100644
index 00000000..0f4e13e5
--- /dev/null
+++ b/biahub/registration/beads.py
@@ -0,0 +1,1120 @@
+"""
+Beads-based registration module.
+
+Provides functions for registering volumetric imaging data by detecting fluorescent
+bead landmarks in moving and reference channels, matching them using graph-based
+algorithms, and estimating affine transformations.
+
+Pipeline overview
+-----------------
+1. **Peak detection** (`peaks_from_beads`): Detect bead positions in both channels.
+2. **Matching** (`matches_from_beads`): Find bead correspondences via graph matching
+   (Hungarian or descriptor-based) with geometric consistency filtering.
+3.
**Transform estimation** (`transform_from_matches`): Fit an affine/euclidean/similarity + transform from matched bead pairs. +4. **Iterative refinement** (`optimize_transform`, `estimate`): Compose the approximate + transform with bead-based corrections, re-detect peaks, and score until convergence. +5. **Parameter tuning** (`optimize_matches`): Grid search over matching settings to find + the combination that maximizes registration quality. + +Key conventions +--------------- +- Coordinates are in ZYX order for 3D data. +- "mov" / "moving" refers to the source channel being aligned. +- "ref" / "reference" refers to the fixed target channel. +- Transforms map from moving space to reference space (forward direction). +""" + +from datetime import datetime +from itertools import product +from pathlib import Path +from typing import Literal + +import ants +import click +import dask.array as da +import numpy as np +import submitit + +from iohub import open_ome_zarr +from numpy.typing import ArrayLike +from scipy.spatial import cKDTree +from skimage.transform import AffineTransform, EuclideanTransform, SimilarityTransform + +from biahub.characterize_psf import detect_peaks +from biahub.cli.parsing import ( + sbatch_to_submitit, +) +from biahub.cli.slurm import wait_for_jobs_to_finish +from biahub.cli.utils import ( + _check_nan_n_zeros, + estimate_resources, +) +from biahub.core.graph_matching import Graph, GraphMatcher +from biahub.core.transform import Transform +from biahub.registration.utils import get_aprox_transform, load_transforms +from biahub.settings import AffineTransformSettings, BeadsMatchSettings, DetectPeaksSettings + + +def optimize_matches( + mov: ArrayLike, + ref: ArrayLike, + approx_transform: Transform, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + param_grid: dict = None, + verbose: bool = False, +) -> BeadsMatchSettings: + """ + Optimize BeadsMatchSettings by grid search over matching and filter 
parameters. + + For each parameter combination: detects peaks in approximately registered space, + matches them, estimates a correction transform, composes it with the approx transform, + applies to the full volume via ANTs, re-detects peaks, and scores the overlap. + + Parameters + ---------- + mov : ArrayLike + Original (unregistered) moving volume (Z, Y, X). + ref : ArrayLike + Reference volume (Z, Y, X). + approx_transform : Transform + Initial approximate transform to compose with. + beads_match_settings : BeadsMatchSettings + Initial matching settings to use as baseline. + affine_transform_settings : AffineTransformSettings + Settings for the affine transform estimation. + param_grid : dict, optional + Dictionary of parameter names to lists of values to search. + Supported keys: 'min_distance_quantile', 'max_distance_quantile', + 'direction_threshold', 'cost_threshold', 'max_ratio', 'k', + 'weights_dist', 'weights_edge_angle', 'weights_edge_length', + 'weights_pca_dir', 'weights_pca_aniso', 'weights_edge_descriptor'. + verbose : bool + If True, prints logs for each trial. + + Returns + ------- + BeadsMatchSettings + The settings that produced the best overlap score. + """ + if param_grid is None: + param_grid = { + 'min_distance_quantile': [0, 0.01], + 'max_distance_quantile': [0, 0.99], + 'direction_threshold': [0, 50], + 'k': [5, 10], + } + + score_radius = beads_match_settings.qc_settings.score_centroid_mask_radius + + # Convert volumes to ANTs images once (reused across all trials) + mov_ants = ants.from_numpy(mov.astype(np.float32)) + ref_ants = ants.from_numpy(ref.astype(np.float32)) + + # Apply approximate transform to moving volume and detect peaks once. + # These peaks are reused for all parameter combinations in the grid search. 
+ click.echo("Detecting peaks in approximately registered space for grid search...") + mov_reg_approx = ( + approx_transform.to_ants().apply_to_image(mov_ants, reference=ref_ants).numpy() + ) + mov_peaks, ref_peaks = peaks_from_beads( + mov=mov_reg_approx, + ref=ref, + mov_peaks_settings=beads_match_settings.source_peaks_settings, + ref_peaks_settings=beads_match_settings.target_peaks_settings, + verbose=False, + ) + if mov_peaks is None or ref_peaks is None or len(mov_peaks) < 2 or len(ref_peaks) < 2: + click.echo("Not enough peaks detected for optimization, returning original settings.") + return beads_match_settings + + click.echo( + f"Starting grid search: {len(mov_peaks)} mov peaks, {len(ref_peaks)} ref peaks, " + f"{np.prod([len(v) for v in param_grid.values()])} parameter combinations." + ) + + ndim = mov_peaks.shape[1] + best_score = -1.0 + best_settings = beads_match_settings + + grid_keys = list(param_grid.keys()) + grid_values = [param_grid[k] for k in grid_keys] + + def apply_trial_params(trial_settings, trial_params): + """Apply parameter values from a grid search trial to a BeadsMatchSettings copy.""" + fm = trial_settings.filter_matches_settings + hm = trial_settings.hungarian_match_settings + w = hm.cost_matrix_settings.weights + param_map = { + 'min_distance_quantile': lambda v: setattr(fm, 'min_distance_quantile', v), + 'max_distance_quantile': lambda v: setattr(fm, 'max_distance_quantile', v), + 'direction_threshold': lambda v: setattr(fm, 'direction_threshold', v), + 'cost_threshold': lambda v: setattr(hm, 'cost_threshold', v), + 'max_ratio': lambda v: setattr(hm, 'max_ratio', v), + 'k': lambda v: setattr(hm.edge_graph_settings, 'k', v), + 'weights_dist': lambda v: w.__setitem__('dist', v), + 'weights_edge_angle': lambda v: w.__setitem__('edge_angle', v), + 'weights_edge_length': lambda v: w.__setitem__('edge_length', v), + 'weights_pca_dir': lambda v: w.__setitem__('pca_dir', v), + 'weights_pca_aniso': lambda v: w.__setitem__('pca_aniso', v), + 
'weights_edge_descriptor': lambda v: w.__setitem__('edge_descriptor', v), + } + for key, val in trial_params.items(): + if key in param_map: + param_map[key](val) + + for combo in product(*grid_values): + trial_params = dict(zip(grid_keys, combo)) + trial_settings = beads_match_settings.model_copy(deep=True) + apply_trial_params(trial_settings, trial_params) + + try: + matches = matches_from_beads( + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + beads_match_settings=trial_settings, + verbose=False, + ) + + if len(matches) < 3: + continue + + fwd_transform, inv_transform = transform_from_matches( + matches=matches, + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + affine_transform_settings=affine_transform_settings, + ndim=ndim, + verbose=False, + ) + + # Compose approx_transform with correction and apply to full volume + composed_transform = approx_transform @ inv_transform + mov_reg_optimized = ( + composed_transform.to_ants() + .apply_to_image(mov_ants, reference=ref_ants) + .numpy() + ) + + # Re-detect peaks and score + mov_peaks_opt, ref_peaks_opt = peaks_from_beads( + mov=mov_reg_optimized, + ref=ref, + mov_peaks_settings=beads_match_settings.source_peaks_settings, + ref_peaks_settings=beads_match_settings.target_peaks_settings, + verbose=False, + ) + if mov_peaks_opt is None or ref_peaks_opt is None: + continue + + score = overlap_score( + mov_peaks=mov_peaks_opt, + ref_peaks=ref_peaks_opt, + radius=score_radius, + verbose=False, + ) + + if np.isnan(score): + continue + + if verbose: + click.echo(f" {trial_params} -> matches={len(matches)}, score={score:.4f}") + + if score > best_score: + best_score = score + best_settings = trial_settings + + except Exception as e: + if verbose: + click.echo(f" {trial_params} -> failed: {e}") + continue + + if verbose: + click.echo(f"Best score: {best_score:.4f}") + click.echo(f"Best settings: {best_settings}") + + return best_settings + + +def overlap_score( + mov_peaks: ArrayLike, + ref_peaks: ArrayLike, + radius: int = 6, + 
verbose: bool = False, +) -> float: + """ + Compute the overlap fraction between two sets of bead peaks. + + For each reference peak, checks whether any moving peak falls within a + spherical neighborhood of the given radius (using a KDTree). The score is + the fraction of reference peaks that have at least one nearby moving peak, + normalized by the smaller peak set size. + + Parameters + ---------- + mov_peaks : ArrayLike + (N_mov, D) array of moving bead coordinates (z, y, x). + ref_peaks : ArrayLike + (N_ref, D) array of reference bead coordinates (z, y, x). + radius : int + Spherical neighborhood radius in voxels for overlap counting. + verbose : bool + If True, prints peak counts and overlap statistics. + + Returns + ------- + float + Overlap fraction in [0, 1]. Returns np.nan if either peak set is empty. + """ + + if len(mov_peaks) == 0 or len(ref_peaks) == 0: + click.echo("No peaks found, returning nan metrics") + return np.nan + + # ---- Overlap counting using KDTree ---- + mov_tree = cKDTree(mov_peaks) + + ref_peaks_mask = np.zeros(len(ref_peaks), dtype=bool) + mov_peaks_mask = np.zeros(len(mov_peaks), dtype=bool) + + for i, p in enumerate(ref_peaks): + idx = mov_tree.query_ball_point(p, r=radius) + if idx: + ref_peaks_mask[i] = True + mov_peaks_mask[idx] = True + + peaks_overlap_count = int(ref_peaks_mask.sum()) + + # ---- Overlap fraction ---- + peaks_overlap_fraction = peaks_overlap_count / max(min(len(mov_peaks), len(ref_peaks)), 1) + + if verbose: + click.echo(f"Mov peaks: {len(mov_peaks)}") + click.echo(f"Ref peaks: {len(ref_peaks)}") + click.echo(f"Peaks overlap count: {peaks_overlap_count}") + click.echo(f"Peaks overlap fraction: {peaks_overlap_fraction}") + + return peaks_overlap_fraction + + +def estimate_tczyx( + mov_tczyx: da.Array, + ref_tczyx: da.Array, + mov_channel_index: int, + ref_channel_index: int = None, + beads_match_settings: BeadsMatchSettings = None, + affine_transform_settings: AffineTransformSettings = None, + verbose: bool = 
False,
+    cluster: str = 'local',
+    sbatch_filepath: Path = None,
+    output_folder_path: Path = None,
+    ref_voxel_size: tuple[float, float, float] = (0.174, 0.1494, 0.1494),
+    mov_voxel_size: tuple[float, float, float] = (0.174, 0.1494, 0.1494),
+    mode: Literal["registration", "stabilization"] = "registration",
+) -> list[Transform]:
+    """
+    Estimate beads-based registration transforms for all timepoints.
+
+    Orchestrates the full registration pipeline: computes the approximate transform
+    (if needed), then estimates per-timepoint transforms either sequentially with
+    propagation or independently via SLURM, depending on settings.
+
+    Parameters
+    ----------
+    mov_tczyx : da.Array
+        Moving data (T, C, Z, Y, X).
+    ref_tczyx : da.Array
+        Reference data (T, C, Z, Y, X).
+    mov_channel_index : int
+        Channel index in the moving data containing beads.
+    ref_channel_index : int, optional
+        Channel index in the reference data. Ignored in stabilization mode.
+    beads_match_settings : BeadsMatchSettings
+        Settings for bead detection, matching, filtering, and QC.
+    affine_transform_settings : AffineTransformSettings
+        Settings for transform type, initial approx transform, and propagation.
+    verbose : bool
+        If True, prints detailed logs.
+    cluster : str, optional
+        Submitit cluster backend ('local', 'slurm', ...), by default 'local'.
+    sbatch_filepath : Path, optional
+        Path to sbatch file for custom SLURM parameters.
+    output_folder_path : Path
+        Directory to save per-timepoint transforms and logs.
+    ref_voxel_size : tuple[float, float, float]
+        Reference voxel size (Z, Y, X) in microns.
+    mov_voxel_size : tuple[float, float, float]
+        Moving voxel size (Z, Y, X) in microns.
+    mode : {"registration", "stabilization"}
+        "registration": align two different channels.
+        "stabilization": align one channel to itself over time.
+
+    Returns
+    -------
+    list[Transform]
+        One 4x4 affine transform per timepoint.
+ """ + mov_tzyx = mov_tczyx[:, mov_channel_index] + if mode == "stabilization": + ref_tzyx = mov_tzyx + elif mode == "registration": + ref_tzyx = ref_tczyx[:, ref_channel_index] + + output_transforms_path = output_folder_path / "xyz_transforms" + output_transforms_path.mkdir(parents=True, exist_ok=True) + + if affine_transform_settings.compute_approx_transform: + approx_transform = get_aprox_transform( + mov_shape=mov_tzyx.shape[-3:], + ref_shape=ref_tzyx.shape[-3:], + pre_affine_90degree_rotation=-1, + pre_affine_fliplr=False, + verbose=verbose, + ref_voxel_size=ref_voxel_size, + mov_voxel_size=mov_voxel_size, + ) + click.echo("Computed approx transform: ", approx_transform) + affine_transform_settings.approx_transform = approx_transform.to_list() + + if affine_transform_settings.use_prev_t_transform: + estimate_with_propagation( + mov_tzyx=mov_tzyx, + ref_tzyx=ref_tzyx, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + output_folder_path=output_transforms_path, + mode=mode, + ) + else: + estimate_independently( + mov_tzyx=mov_tzyx, + ref_tzyx=ref_tzyx, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + output_folder_path=output_transforms_path, + cluster=cluster, + sbatch_filepath=sbatch_filepath, + mode=mode, + ) + + transforms = load_transforms(output_transforms_path, mov_tzyx.shape[0], verbose) + + return transforms + + +def estimate_with_propagation( + mov_tzyx: da.Array, + ref_tzyx: da.Array, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + output_folder_path: Path = None, + mode: Literal["registration", "stabilization"] = "registration", +) -> None: + """ + Estimate transforms sequentially, propagating each result to the next timepoint. + + Processes timepoints in order (t=0, 1, 2, ...). 
After each timepoint, the + estimated transform is used as the approximate transform for the next timepoint. + This is useful when drift is gradual and cumulative, as each timepoint starts + from a better initial guess. + + Parameters + ---------- + mov_tzyx : da.Array + Moving volume (T, Z, Y, X). + ref_tzyx : da.Array + Reference volume (T, Z, Y, X). + beads_match_settings : BeadsMatchSettings + Settings for bead detection, matching, and filtering. + affine_transform_settings : AffineTransformSettings + Settings for transform type and initial approximate transform. + Modified in-place: approx_transform is updated after each timepoint. + verbose : bool + If True, prints progress for each timepoint. + output_folder_path : Path + Directory to save per-timepoint transform .npy files. + mode : {"registration", "stabilization"} + "registration": align moving to reference channel. + "stabilization": align moving channel to itself over time. + """ + initial_transform = affine_transform_settings.approx_transform + T, _, _, _ = mov_tzyx.shape + for t in range(T): + if mode == "stabilization" and t == 0: + continue + if np.sum(mov_tzyx[t]) == 0 or np.sum(ref_tzyx[t]) == 0: + click.echo(f"Timepoint {t} has no data, skipping") + else: + approx_transform = estimate_tzyx( + t_idx=t, + mov_tzyx=mov_tzyx, + ref_tzyx=ref_tzyx, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + output_folder_path=output_folder_path, + mode=mode, + user_transform=initial_transform, + ) + + if approx_transform is not None: + affine_transform_settings.approx_transform = approx_transform.to_list() + else: + affine_transform_settings.approx_transform = initial_transform + + +def estimate_independently( + mov_tzyx: da.Array, + ref_tzyx: da.Array, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + output_folder_path: Path = None, + cluster: str = 'local', + 
sbatch_filepath: Path = None, + mode: Literal["registration", "stabilization"] = "registration", +) -> None: + """ + Estimate transforms for all timepoints independently via SLURM. + + Each timepoint is submitted as an independent job using submitit. All jobs + use the same approximate transform as their starting point (no propagation). + Suitable for large datasets where timepoints can be processed in parallel. + + Parameters + ---------- + mov_tzyx : da.Array + Moving volume (T, Z, Y, X). + ref_tzyx : da.Array + Reference volume (T, Z, Y, X). + beads_match_settings : BeadsMatchSettings + Settings for bead detection, matching, and filtering. + affine_transform_settings : AffineTransformSettings + Settings for transform type and initial approximate transform. + verbose : bool + If True, prints progress for each timepoint. + output_folder_path : Path + Directory to save per-timepoint transform .npy files. + cluster : str + Submitit cluster backend ('local', 'slurm', etc.). + sbatch_filepath : Path, optional + Path to sbatch file for custom SLURM parameters. + mode : {"registration", "stabilization"} + "registration": align moving to reference channel. + "stabilization": align moving channel to itself over time. 
+ """ + T, Z, Y, X = mov_tzyx.shape + num_cpus, gb_ram_per_cpu = estimate_resources( + shape=(T, 2, Z, Y, X), ram_multiplier=5, max_num_cpus=16 + ) + + # Prepare SLURM arguments + slurm_args = { + "slurm_job_name": "estimate_registration", + "slurm_mem_per_cpu": f"{gb_ram_per_cpu}G", + "slurm_cpus_per_task": num_cpus, + "slurm_array_parallelism": 100, + "slurm_time": 30, + "slurm_partition": "preempted", + "slurm_use_srun": False, + } + + if sbatch_filepath: + slurm_args.update(sbatch_to_submitit(sbatch_filepath)) + + slurm_out_path = output_folder_path.parent / "slurm_output" + slurm_out_path.mkdir(parents=True, exist_ok=True) + + # Submitit executor + executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster) + executor.update_parameters(**slurm_args) + click.echo(f"Submitting SLURM focus estimation jobs with resources: {slurm_args}") + + # Submit jobs + jobs = [] + with submitit.helpers.clean_env(), executor.batch(): + for t in range(T): + job = executor.submit( + estimate_tzyx, + t_idx=t, + mov_tzyx=mov_tzyx, + ref_tzyx=ref_tzyx, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + output_folder_path=output_folder_path, + mode=mode, + ) + jobs.append(job) + + # Save job IDs + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + log_path = slurm_out_path / f"job_ids_{timestamp}.log" + with open(log_path, "w") as log_file: + for job in jobs: + log_file.write(f"{job.job_id}\n") + + wait_for_jobs_to_finish(jobs) + + +def peaks_from_beads( + mov: da.Array, + ref: da.Array, + mov_peaks_settings: DetectPeaksSettings, + ref_peaks_settings: DetectPeaksSettings, + verbose: bool = False, + mask_path: Path = None, +) -> tuple[ArrayLike, ArrayLike]: + """ + Detect peaks in moving and reference channels using the detect_peaks function. + + Parameters + ---------- + mov : da.Array + (Z, Y, X) array of the moving channel (Dask array). 
+ ref : da.Array + (Z, Y, X) array of the reference channel (Dask array). + mov_peaks_settings : DetectPeaksSettings + Settings for the moving peaks. + ref_peaks_settings : DetectPeaksSettings + Settings for the reference peaks. + verbose : bool + If True, prints detailed logs during the process. + mask_path : Path + Path to the mask file. + Returns + ------- + tuple[ArrayLike, ArrayLike] + Tuple of (mov_peaks, ref_peaks). + """ + if verbose: + click.echo('Detecting beads in moving dataset') + # TODO: detecte peaks in the zyx space, use skimage.feature.peak_local_max for 2D + mov_peaks = detect_peaks( + mov, + block_size=mov_peaks_settings.block_size, + threshold_abs=mov_peaks_settings.threshold_abs, + nms_distance=mov_peaks_settings.nms_distance, + min_distance=mov_peaks_settings.min_distance, + verbose=verbose, + ) + if verbose: + click.echo('Detecting beads in reference dataset') + # TODO: detecte peaks in the zyx space, use skimage.feature.peak_local_max for 2D + ref_peaks = detect_peaks( + ref, + block_size=ref_peaks_settings.block_size, + threshold_abs=ref_peaks_settings.threshold_abs, + nms_distance=ref_peaks_settings.nms_distance, + min_distance=ref_peaks_settings.min_distance, + verbose=verbose, + ) + if verbose: + click.echo(f'Total of peaks in moving dataset: {len(mov_peaks)}') + click.echo(f'Total of peaks in reference dataset: {len(ref_peaks)}') + + if len(mov_peaks) < 2 or len(ref_peaks) < 2: + click.echo('Not enough beads detected') + return + if mask_path is not None: + click.echo("Filtering peaks with mask") + with open_ome_zarr(mask_path) as mask_ds: + mask_load = np.asarray(mask_ds.data[0, 0]) + + # filter the peaks with the mask + # Keep only peaks whose (y, x) column is clean across all Z slices + ref_peaks_filtered = [] + for peak in ref_peaks: + z, y, x = peak.astype(int) + if ( + 0 <= y < mask_load.shape[1] + and 0 <= x < mask_load.shape[2] + and not mask_load[:, y, x].any() # True if all Z are clean at (y, x) + ): + 
ref_peaks_filtered.append(peak) + ref_peaks = np.array(ref_peaks_filtered) + return mov_peaks, ref_peaks + + +def matches_from_beads( + mov_peaks: ArrayLike, + ref_peaks: ArrayLike, + beads_match_settings: BeadsMatchSettings, + verbose: bool = False, +) -> ArrayLike: + """ + Find bead correspondences between moving and reference peak sets. + + Supports two matching algorithms: + - "hungarian": Builds k-NN graphs for both peak sets, computes a cost matrix + based on position distance and edge consistency, then solves the assignment + problem with the Hungarian algorithm. + - "match_descriptor": Uses scikit-image's descriptor matching on peak positions. + + After matching, applies geometric consistency filters (distance quantiles, + direction threshold, angle threshold) to remove outliers. + + Parameters + ---------- + mov_peaks : ArrayLike + (N, D) array of moving peak coordinates (D = 2 or 3). + ref_peaks : ArrayLike + (M, D) array of reference peak coordinates. + beads_match_settings : BeadsMatchSettings + Settings controlling the matching algorithm, graph construction, + cost matrix weights, and post-match filtering. + verbose : bool + If True, prints matching settings and match count. + + Returns + ------- + ArrayLike + (K, 2) array of matched index pairs [mov_idx, ref_idx]. 
+ """ + if verbose: + click.echo(f'Getting matches from beads with settings: {beads_match_settings}') + + if beads_match_settings.algorithm == 'match_descriptor': + mov_graph = Graph.from_nodes(mov_peaks) + ref_graph = Graph.from_nodes(ref_peaks) + + match_descriptor_settings = beads_match_settings.match_descriptor_settings + matcher = GraphMatcher( + algorithm='descriptor', + cross_check=match_descriptor_settings.cross_check, + max_ratio=match_descriptor_settings.max_ratio, + metric=match_descriptor_settings.distance_metric, + verbose=verbose, + ) + + matches = matcher.match(mov_graph, ref_graph) + + elif beads_match_settings.algorithm == 'hungarian': + hungarian_match_settings = beads_match_settings.hungarian_match_settings + mov_graph = Graph.from_nodes( + mov_peaks, mode='knn', k=hungarian_match_settings.edge_graph_settings.k + ) + ref_graph = Graph.from_nodes( + ref_peaks, mode='knn', k=hungarian_match_settings.edge_graph_settings.k + ) + + matcher = GraphMatcher( + algorithm='hungarian', + weights=hungarian_match_settings.cost_matrix_settings.weights, + cost_threshold=hungarian_match_settings.cost_threshold, + cross_check=hungarian_match_settings.cross_check, + max_ratio=hungarian_match_settings.max_ratio, + verbose=verbose, + ) + + matches = matcher.match(mov_graph, ref_graph) + + # Filter as part of the pipeline + matches = matcher.filter_matches( + matches, + mov_graph, + ref_graph, + angle_threshold=beads_match_settings.filter_matches_settings.angle_threshold, + min_distance_quantile=beads_match_settings.filter_matches_settings.min_distance_quantile, + max_distance_quantile=beads_match_settings.filter_matches_settings.max_distance_quantile, + direction_threshold=beads_match_settings.filter_matches_settings.direction_threshold, + ) + + if verbose: + click.echo(f'Total of matches: {len(matches)}') + + return matches + + +def transform_from_matches( + matches: ArrayLike, + mov_peaks: ArrayLike, + ref_peaks: ArrayLike, + affine_transform_settings: 
AffineTransformSettings, + ndim: int = 3, + verbose: bool = False, +) -> tuple[Transform, Transform]: + """ + Estimate the affine transformation matrix between source and target channels + based on detected bead matches at a specific timepoint. + + Parameters + ---------- + matches : ArrayLike + (n, 2) array of matches. + mov_peaks : ArrayLike + (n, 2) array of moving peaks. + ref_peaks : ArrayLike + (n, 2) array of reference peaks. + affine_transform_settings : AffineTransformSettings + Settings for the affine transform. + ndim: int + Number of dimensions. + verbose : bool + If True, prints detailed logs during the process. + + Returns + ------- + tuple[Transform, Transform] + Tuple of forward and inverse transforms. + """ + if verbose: + click.echo(f"Estimating transform with settings: {affine_transform_settings}") + # Detect dimensionality from peaks + if ndim not in (2, 3): + raise ValueError(f"Peaks must be 2D or 3D, got {ndim}D") + + # Create appropriate transform + if affine_transform_settings.transform_type == 'affine': + transform = AffineTransform(dimensionality=ndim) + elif affine_transform_settings.transform_type == 'euclidean': + transform = EuclideanTransform(dimensionality=ndim) + elif affine_transform_settings.transform_type == 'similarity': + transform = SimilarityTransform(dimensionality=ndim) + else: + raise ValueError(f'Unknown transform type: {affine_transform_settings.transform_type}') + + # Fit transform + transform.estimate(mov_peaks[matches[:, 0]], ref_peaks[matches[:, 1]]) + + inv_transform = Transform(matrix=transform.inverse.params) + fwd_transform = Transform(matrix=transform.params) + + return fwd_transform, inv_transform + + +def estimate_tzyx( + t_idx: int, + mov_tzyx: da.Array, + ref_tzyx: da.Array, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + output_folder_path: Path = None, + mode: Literal["registration", "stabilization"] = "registration", + 
user_transform: Transform = None, +) -> Transform: + """ + Estimate the affine transform for a single timepoint. + + Extracts the 3D volumes for the given timepoint, sets up the reference + depending on the mode (registration vs stabilization), and delegates to + `estimate()` for iterative bead-based transform estimation. + + Parameters + ---------- + t_idx : int + Timepoint index to process. + mov_tzyx : da.Array + Moving volume (T, Z, Y, X). + ref_tzyx : da.Array + Reference volume (T, Z, Y, X). Ignored in stabilization mode. + beads_match_settings : BeadsMatchSettings + Settings for bead detection, matching, and filtering. + affine_transform_settings : AffineTransformSettings + Settings for transform type and initial approximate transform. + verbose : bool + If True, prints detailed logs during the process. + output_folder_path : Path, optional + Directory to save the transform as ``{t_idx}.npy``. + mode : {"registration", "stabilization"} + "registration": align moving to reference (different channels). + "stabilization": align moving channel to itself over time, + using t_reference setting ("first" or "previous"). + user_transform : Transform, optional + Alternative initial transform to compete with the default on iteration 0. + + Returns + ------- + Transform or None + The estimated 4x4 affine transform, or None if estimation failed. + """ + click.echo("........................................................................") + click.echo(f'Processing timepoint: {t_idx}') + + (T, Z, Y, X) = mov_tzyx.shape + + if mode == "stabilization": + click.echo("Performing stabilization, aka registration over time in the same file.") + if affine_transform_settings.t_reference == "first": + ref_tzyx = np.broadcast_to(mov_tzyx[0], (T, Z, Y, X)).copy() + elif affine_transform_settings.t_reference == "previous": + ref_tzyx = np.roll(mov_tzyx, shift=-1, axis=0) + ref_tzyx[0] = mov_tzyx[0] + else: + raise ValueError( + "Invalid reference. 
Please use 'first' or 'previous' as reference." + ) + elif mode == "registration": + click.echo("Performing registration between different files") + mov_zyx = np.asarray(mov_tzyx[t_idx]).astype(np.float32) + ref_zyx = np.asarray(ref_tzyx[t_idx]).astype(np.float32) + + if output_folder_path: + output_folder_path.mkdir(parents=True, exist_ok=True) + output_filepath = output_folder_path / f"{t_idx}.npy" + else: + output_filepath = None + + transform = estimate( + mov=mov_zyx, + ref=ref_zyx, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + output_filepath=output_filepath, + user_transform=user_transform, + ) + return transform + + +def optimize_transform( + transform: Transform, + mov: da.Array, + ref: da.Array, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + debug: bool = False, +) -> tuple[Transform, float]: + """ + Refine a transform by bead matching and evaluate registration quality. + + Applies the current transform to the moving volume, detects beads in both + the registered moving and reference volumes, matches them, estimates a + correction transform, and composes it with the input transform. Returns + the better of the two (original vs corrected) based on overlap score. + + Parameters + ---------- + transform : Transform + Current transform to refine (maps moving -> reference space). + mov : ArrayLike + Original (unregistered) moving volume (Z, Y, X). + ref : ArrayLike + Reference volume (Z, Y, X). + beads_match_settings : BeadsMatchSettings + Settings controlling peak detection, matching, and filtering. + affine_transform_settings : AffineTransformSettings + Settings for the transform type (affine/euclidean/similarity). + verbose : bool + If True, prints quality scores before and after optimization. + debug : bool + If True, prints detailed intermediate results (peaks, matches, transforms). 
+ + Returns + ------- + tuple[Transform, float] + The best transform and its overlap score. + Returns (None, -1) if not enough peaks or matches are found. + """ + mov_ants = ants.from_numpy(mov) + ref_ants = ants.from_numpy(ref) + + # Step 1: Score the current transform by applying it and measuring peak overlap + if debug: + click.echo("Step 1: Scoring current transform (before bead matching)...") + mov_reg_approx = transform.to_ants().apply_to_image(mov_ants, reference=ref_ants).numpy() + mov_peaks, ref_peaks = peaks_from_beads( + mov=mov_reg_approx, + ref=ref, + mov_peaks_settings=beads_match_settings.source_peaks_settings, + ref_peaks_settings=beads_match_settings.target_peaks_settings, + verbose=debug, + ) + if mov_peaks is None or ref_peaks is None: + return None, -1 + + quality_score_approx = overlap_score( + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + radius=beads_match_settings.qc_settings.score_centroid_mask_radius, + verbose=debug, + ) + + # Step 2: Match beads and estimate a correction transform + if debug: + click.echo("Step 2: Matching beads to estimate correction transform...") + matches = matches_from_beads( + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + beads_match_settings=beads_match_settings, + verbose=debug, + ) + + if len(matches) < 3: + click.echo('Not enough matches found, returning the current transform') + return None, -1 + + fwd_transform, inv_transform = transform_from_matches( + matches=matches, + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + affine_transform_settings=affine_transform_settings, + ndim=mov.ndim, + verbose=debug, + ) + composed_transform = transform @ inv_transform + + # Step 3: Score the composed (corrected) transform + if debug: + click.echo("Step 3: Scoring composed transform (after bead matching)...") + mov_reg_optimized = ( + composed_transform.to_ants().apply_to_image(mov_ants, reference=ref_ants).numpy() + ) + mov_peaks_optimized, ref_peaks_optimized = peaks_from_beads( + mov=mov_reg_optimized, + ref=ref, + 
mov_peaks_settings=beads_match_settings.source_peaks_settings, + ref_peaks_settings=beads_match_settings.target_peaks_settings, + verbose=debug, + ) + + quality_score_optimized = overlap_score( + mov_peaks=mov_peaks_optimized, + ref_peaks=ref_peaks_optimized, + radius=beads_match_settings.qc_settings.score_centroid_mask_radius, + verbose=debug, + ) + if debug: + click.echo(f'Bead matches: {matches}') + click.echo(f"Forward transform: {fwd_transform}") + click.echo(f"Inverse transform: {inv_transform}") + click.echo(f"Composed transform: {composed_transform}") + + if verbose: + click.echo(f"Quality score before beads matching: {quality_score_approx}") + click.echo(f"Quality score after beads matching: {quality_score_optimized}") + + if quality_score_optimized >= quality_score_approx: + return composed_transform, quality_score_optimized + else: + return transform, quality_score_approx + + +def estimate( + mov: da.Array, + ref: da.Array, + beads_match_settings: BeadsMatchSettings, + affine_transform_settings: AffineTransformSettings, + verbose: bool = False, + output_filepath: Path = None, + user_transform: Transform = None, + debug: bool = False, +) -> Transform: + """ + Estimate the best affine transformation between moving and reference volumes. + + Iteratively refines the transform by detecting beads, matching them, estimating + a correction, and scoring the result. Supports an optional user-provided + transform that competes with the computed one on the first iteration. + + Works for both 2D (Y, X) and 3D (Z, Y, X) arrays. + + Parameters + ---------- + mov : ArrayLike + Moving channel volume (Z, Y, X) or (Y, X). + ref : ArrayLike + Reference channel volume (Z, Y, X) or (Y, X). + beads_match_settings : BeadsMatchSettings + Settings for bead detection, matching, filtering, and QC iterations. + affine_transform_settings : AffineTransformSettings + Settings for transform type and initial approximate transform. 
+ verbose : bool + If True, prints the best transform and score at the end. + output_filepath : Path, optional + If provided, saves the best transform matrix as a .npy file. + user_transform : Transform, optional + An alternative initial transform (e.g. from a previous timepoint). + Tested on the first iteration; used if it scores better. + debug : bool + If True, passes debug flag to optimize_transform for detailed logging. + + Returns + ------- + Transform + The best transform found across all iterations. Falls back to the + initial approximate transform if no valid optimization was found. + """ + + if _check_nan_n_zeros(mov) or _check_nan_n_zeros(ref): + click.echo('Skipping: moving or reference data contains only NaN/zeros.') + return + + initial_transform = Transform( + matrix=np.asarray(affine_transform_settings.approx_transform) + ) + transform = initial_transform + + current_iterations = 0 + qc_iterations = beads_match_settings.qc_settings.iterations + transform_iter_dict = {} + + while current_iterations < qc_iterations: + click.echo( + f"Iteration {current_iterations + 1}/{qc_iterations}: " + "optimizing transform via bead matching..." 
+ ) + optimized_transform, quality_score_optimized = optimize_transform( + transform=transform, + mov=mov, + ref=ref, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + debug=debug, + ) + transform_iter_dict[current_iterations] = { + "transform": optimized_transform, + "quality_score": quality_score_optimized, + } + if quality_score_optimized == 1: + break + transform = optimized_transform + + if user_transform is not None and current_iterations == 0: + click.echo("Optimizing user transform:") + user_transform = Transform(matrix=np.asarray(user_transform)) + optimized_transform_user, quality_score_optimized_user = optimize_transform( + transform=user_transform, + mov=mov, + ref=ref, + beads_match_settings=beads_match_settings, + affine_transform_settings=affine_transform_settings, + verbose=verbose, + debug=debug, + ) + + if quality_score_optimized < quality_score_optimized_user: + + transform_iter_dict[current_iterations] = { + "transform": optimized_transform_user, + "quality_score": quality_score_optimized_user, + } + if quality_score_optimized_user == 1: + break + transform = optimized_transform_user + + if transform is None: + break + current_iterations += 1 + + # get highest quality score + best_quality_score = max(transform_iter_dict.values(), key=lambda x: x["quality_score"]) + best_transform = best_quality_score["transform"] + + if best_transform is None: + best_transform = initial_transform + if verbose: + click.echo(f"Best transform: {best_transform}") + click.echo(f"Best quality score: {best_quality_score['quality_score']}") + if output_filepath: + click.echo(f"Saving transform to {output_filepath}") + np.save(output_filepath, best_transform.to_list()) + + return best_transform diff --git a/biahub/registration/utils.py b/biahub/registration/utils.py new file mode 100644 index 00000000..c678747c --- /dev/null +++ b/biahub/registration/utils.py @@ -0,0 +1,961 @@ +""" +Registration 
utility functions. + +Helper functions shared across registration methods (beads, ANTs, phase correlation). +Includes: +- Transform validation and interpolation across timepoints. +- Approximate transform computation from voxel sizes and rotation/flip parameters. +- Transform I/O (save/load from disk, convert between numpy and ANTs formats). +- Volume utilities: LIR cropping, affine application, padding, and shape matching. + +Key conventions +--------------- +- Coordinates and shapes are in ZYX order for 3D data. +- Transforms are 4x4 homogeneous matrices. +""" + +import os + +from pathlib import Path +from typing import Literal, Tuple, Union + +import ants +import click +import largestinteriorrectangle as lir +import numpy as np +import scipy + +from matplotlib import pyplot as plt +from numpy.typing import ArrayLike +from scipy.interpolate import interp1d + +from biahub.cli.utils import ( + model_to_yaml, +) +from biahub.core.transform import Transform +from biahub.settings import ( + AffineTransformSettings, + RegistrationSettings, + StabilizationSettings, +) + +# TODO: see if at some point these globals should be hidden or exposed. 
NA_DETECTION_SOURCE = 1.35
NA_DETECTION_TARGET = 1.35
WAVELENGTH_EMISSION_SOURCE_CHANNEL = 0.45  # in um
WAVELENGTH_EMISSION_TARGET_CHANNEL = 0.6  # in um
FOCUS_SLICE_ROI_WIDTH = 150  # size of central ROI used to find focal slice


def get_aprox_transform(
    mov_shape: tuple[int, int, int],
    ref_shape: tuple[int, int, int],
    pre_affine_90degree_rotation: int = -1,
    pre_affine_fliplr: bool = False,
    verbose: bool = False,
    ref_voxel_size: tuple[float, float, float] = (0.174, 0.1494, 0.1494),
    mov_voxel_size: tuple[float, float, float] = (0.174, 0.1494, 0.1494),
) -> Transform:
    """
    Build an approximate moving->reference transform from voxel sizes,
    an optional 90-degree rotation, and an optional left-right flip.

    The compound matrix is the inverse of (scaling @ rotation @ flip),
    composed from the anisotropy ratio between moving and reference voxel
    sizes and the requested pre-alignment operations.

    Parameters
    ----------
    mov_shape : tuple[int, int, int]
        Moving volume shape (Z, Y, X).
    ref_shape : tuple[int, int, int]
        Reference volume shape (Z, Y, X).
    pre_affine_90degree_rotation : int
        Number of 90-degree rotations to apply (sign sets direction).
    pre_affine_fliplr : bool
        If True, apply a left-right flip before rotation/scaling.
    verbose : bool
        Currently unused in this function body.
    ref_voxel_size, mov_voxel_size : tuple[float, float, float]
        Voxel sizes (Z, Y, X) in microns.

    Returns
    -------
    Transform
        Approximate 4x4 affine transform.

    Notes
    -----
    NOTE(review): ``get_3D_rescaling_matrix``, ``get_3D_rotation_matrix`` and
    ``get_3D_fliplr_matrix`` are not present in this module's visible import
    block — confirm they are imported elsewhere in the file.
    """

    mov_Z, mov_Y, mov_X = mov_shape
    ref_Z, ref_Y, ref_X = ref_shape

    # Calculate scaling factors for displaying data source_channel_voxel_size, target_channel_voxel_size
    scaling_factor_z = mov_voxel_size[-3] / ref_voxel_size[-3]
    scaling_factor_yx = mov_voxel_size[-1] / ref_voxel_size[-1]
    click.echo(
        f"Z scaling factor: {scaling_factor_z:.3f}; XY scaling factor: {scaling_factor_yx:.3f}\n"
    )

    scaling_affine = get_3D_rescaling_matrix(
        (ref_Z, ref_Y, ref_X),
        (scaling_factor_z, scaling_factor_yx, scaling_factor_yx),
        (ref_Z, ref_Y, ref_X),
    )
    rotate90_affine = get_3D_rotation_matrix(
        (mov_Z, mov_Y, mov_X),
        90 * pre_affine_90degree_rotation,
        (ref_Z, ref_Y, ref_X),
    )

    # Apply flip transformation if requested (flip happens first)
    if pre_affine_fliplr:
        fliplr_affine = get_3D_fliplr_matrix(
            (mov_Z, mov_Y, mov_X),
            (ref_Z, ref_Y, ref_X),
        )
    else:
        fliplr_affine = np.eye(4)

    compound_affine = np.linalg.inv(scaling_affine @ rotate90_affine @ fliplr_affine)

    return Transform(matrix=compound_affine)


def validate_transforms(
    transforms: list[ArrayLike],
    shape_zyx: tuple[int, int, int],
    window_size: int = 10,
    tolerance: float = 100.0,
    verbose: bool = False,
) -> list[ArrayLike]:
    """
    Validate that a provided list of transforms do not deviate beyond the tolerance threshold
    relative to the average transform within a given window size.

    The first ``window_size`` non-None transforms are accepted without
    validation to bootstrap the running average. After that, each transform
    is compared to the windowed mean via ``check_transforms_difference``.

    Parameters
    ----------
    transforms : list[ArrayLike]
        List of affine transformation matrices (4x4), one for each timepoint.
        NOTE: mutated in place — invalid entries are set to None — and also returned.
    shape_zyx : tuple[int, int, int]
        Shape of the source (i.e. moving) volume (Z, Y, X).
    window_size : int
        Size of the moving window for smoothing transformations.
    tolerance : float
        Maximum allowed difference between consecutive transformations for validation.
    verbose : bool
        If True, prints detailed logs of the validation process.

    Returns
    -------
    list[ArrayLike]
        List of affine transformation matrices with invalid or inconsistent values replaced by None.
    """
    valid_transforms = []
    reference_transform = None

    for i, transform in enumerate(transforms):
        if transform is not None:
            if len(valid_transforms) < window_size:
                # Bootstrap the buffer without validating yet
                valid_transforms.append(transform)
                reference_transform = np.mean(valid_transforms, axis=0)
                if verbose:
                    click.echo(
                        f"[Bootstrap] Accepting transform at timepoint {i} (no validation)"
                    )
            elif check_transforms_difference(
                transform, reference_transform, shape_zyx, tolerance, verbose
            ):
                valid_transforms.append(transform)
                # Keep only the most recent window_size transforms in the buffer
                if len(valid_transforms) > window_size:
                    valid_transforms.pop(0)
                reference_transform = np.mean(valid_transforms, axis=0)
                if verbose:
                    click.echo(f"Transform at timepoint {i} is valid")
            else:
                transforms[i] = None
                if verbose:
                    click.echo(
                        f"Transform at timepoint {i} is invalid and will be interpolated"
                    )
        else:
            transforms[i] = None
            if verbose:
                click.echo(f"Transform at timepoint {i} is None and will be interpolated")

    return transforms


def interpolate_transforms(
    transforms: list[ArrayLike],
    window_size: int = 3,
    interpolation_type: Literal["linear", "cubic"] = "linear",
    verbose: bool = False,
):
    """
    Interpolate missing transforms (None) in a list of affine transformation matrices.
+ + Parameters + ---------- + transforms : list[ArrayLike] + List of affine transformation matrices (4x4), one for each timepoint. + window_size : int + Local window radius for interpolation. If 0, global interpolation is used. + interpolation_type : Literal["linear", "cubic"] + Interpolation type. + verbose : bool + If True, prints detailed logs of the interpolation process. + + Returns + ------- + list[ArrayLike] + List of affine transformation matrices with missing values filled via linear interpolation. + """ + n = len(transforms) + valid_transform_indices = [i for i, t in enumerate(transforms) if t is not None] + valid_transforms = [np.array(transforms[i]) for i in valid_transform_indices] + + if not valid_transform_indices or len(valid_transform_indices) < 2: + raise ValueError("At least two valid transforms are required for interpolation.") + + missing_indices = [i for i in range(n) if transforms[i] is None] + + if not missing_indices: + return transforms # nothing to do + if verbose: + click.echo(f"Interpolating missing transforms at timepoints: {missing_indices}") + + if window_size > 0: + for idx in missing_indices: + # Define local window + start = max(0, idx - window_size) + end = min(n, idx + window_size + 1) + + local_x = [] + local_y = [] + + for j in range(start, end): + if j in valid_transform_indices: + local_x.append(j) + local_y.append(np.array(transforms[j])) + + if len(local_x) < 2: + # Not enough neighbors for interpolation. 
Assign to closes valid transform + closest_valid_idx = valid_transform_indices[ + np.argmin(np.abs(np.asarray(valid_transform_indices) - idx)) + ] + transforms[idx] = transforms[closest_valid_idx] + if verbose: + click.echo( + f"Not enough interpolation neighbors were found for timepoint {idx} using closest valid transform at timepoint {closest_valid_idx}" + ) + continue + + f = interp1d( + local_x, local_y, axis=0, kind=interpolation_type, fill_value='extrapolate' + ) + transforms[idx] = f(idx).tolist() + if verbose: + click.echo(f"Interpolated timepoint {idx} using neighbors: {local_x}") + + else: + # Global interpolation using all valid transforms + f = interp1d( + valid_transform_indices, + valid_transforms, + axis=0, + kind='linear', + fill_value='extrapolate', + ) + transforms = [ + f(i).tolist() if transforms[i] is None else transforms[i] for i in range(n) + ] + + return transforms + + +def check_transforms_difference( + tform1: ArrayLike, + tform2: ArrayLike, + shape_zyx: tuple[int, int, int], + threshold: float = 5.0, + verbose: bool = False, +): + """ + Evaluate the difference between two affine transforms by calculating the + Mean Squared Error (MSE) of a grid of points transformed by each matrix. + + Parameters + ---------- + tform1 : ArrayLike + First affine transform (4x4 matrix). + tform2 : ArrayLike + Second affine transform (4x4 matrix). + shape_zyx : tuple[int, int, int] + Shape of the source (i.e. moving) volume (Z, Y, X). + threshold : float + The maximum allowed MSE difference. + verbose : bool + Flag to print the MSE difference. + + Returns + ------- + bool + True if the MSE difference is within the threshold, False otherwise. 
+ """ + tform1 = np.array(tform1) + tform2 = np.array(tform2) + (Z, Y, X) = shape_zyx + + zz, yy, xx = np.meshgrid( + np.linspace(0, Z - 1, 10), np.linspace(0, Y - 1, 10), np.linspace(0, X - 1, 10) + ) + + grid_points = np.vstack([zz.ravel(), yy.ravel(), xx.ravel(), np.ones(zz.size)]).T + + points_tform1 = np.dot(tform1, grid_points.T).T + points_tform2 = np.dot(tform2, grid_points.T).T + + differences = np.linalg.norm(points_tform1[:, :3] - points_tform2[:, :3], axis=1) + mse = np.mean(differences) + + if verbose: + click.echo(f'MSE of transformed points: {mse:.2f}; threshold: {threshold:.2f}') + return mse <= threshold + + +def evaluate_transforms( + transforms: ArrayLike, + shape_zyx: tuple[int, int, int], + validation_window_size: int = 10, + validation_tolerance: float = 100.0, + interpolation_window_size: int = 3, + interpolation_type: Literal["linear", "cubic"] = "linear", + verbose: bool = False, +) -> ArrayLike: + """ + Evaluate a list of affine transformation matrices. + Transform matrices are checked for deviation from the average within a given window size. + If a transform is found to lead to shift larger than the given tolerance, + that transform will be replaced by interpolation of valid transforms within a given window size. + + Parameters + ---------- + transforms : ArrayLike + List of affine transformation matrices (4x4), one for each timepoint. + shape_zyx : tuple[int, int, int] + Shape of the source (i.e. moving) volume (Z, Y, X). + validation_window_size : int + Size of the moving window for smoothing transformations. + validation_tolerance : float + Maximum allowed difference between consecutive transformations for validation. + interpolation_window_size : int + Size of the local window for interpolation. + interpolation_type : Literal["linear", "cubic"] + Interpolation type. + verbose : bool + If True, prints detailed logs of the evaluation and validation process. 
+ + Returns + ------- + list[ArrayLike] + List of affine transformation matrices with missing values filled via linear interpolation. + """ + + if not isinstance(transforms, list): + transforms = transforms.tolist() + if len(transforms) < validation_window_size: + raise Warning( + f"Not enough transforms for validation and interpolation. " + f"Required: {validation_window_size}, " + f"Provided: {len(transforms)}" + ) + else: + transforms = validate_transforms( + transforms=transforms, + window_size=validation_window_size, + tolerance=validation_tolerance, + shape_zyx=shape_zyx, + verbose=verbose, + ) + + if len(transforms) < interpolation_window_size: + raise Warning( + f"Not enough transforms for interpolation. " + f"Required: {interpolation_window_size}, " + f"Provided: {len(transforms)}" + ) + else: + transforms = interpolate_transforms( + transforms=transforms, + window_size=interpolation_window_size, + interpolation_type=interpolation_type, + verbose=verbose, + ) + return transforms + + +def save_transforms( + model: Union[AffineTransformSettings, StabilizationSettings, RegistrationSettings], + transforms: list[ArrayLike], + output_filepath_settings: Path, + output_filepath_plot: Path = None, + verbose: bool = False, +): + """ + Save the transforms to a yaml file and plot the translations. + + Parameters + ---------- + model : Union[AffineTransformSettings, StabilizationSettings, RegistrationSettings] + Model to save the transforms to. + transforms : list[ArrayLike] + List of affine transformation matrices (4x4), one for each timepoint. + output_filepath_settings : Path + Path to the output settings file. + output_filepath_plot : Path + Path to the output plot file. + verbose : bool + If True, prints detailed logs of the saving process. + + Returns + ------- + None + + Notes + ----- + The transforms are saved to a yaml file and a plot of the translations is saved to a png file. 
+ The plot is saved in the same directory as the settings file and is named "translations.png". + + """ + if transforms is None or len(transforms) == 0: + raise ValueError("Transforms are empty") + + if not isinstance(transforms, list): + transforms = transforms.tolist() + + model.affine_transform_zyx_list = transforms + + if output_filepath_settings.suffix not in [".yml", ".yaml"]: + output_filepath_settings = output_filepath_settings.with_suffix(".yml") + + output_filepath_settings.parent.mkdir(parents=True, exist_ok=True) + model_to_yaml(model, output_filepath_settings) + + if verbose and output_filepath_plot is not None: + if output_filepath_plot.suffix not in [".png"]: + output_filepath_plot = output_filepath_plot.with_suffix(".png") + output_filepath_plot.parent.mkdir(parents=True, exist_ok=True) + + plot_translations(np.asarray(transforms), output_filepath_plot) + + +def plot_translations( + transforms_zyx: ArrayLike, + output_filepath: Path, +): + """ + Plot the translations of a list of affine transformation matrices. + + Parameters + ---------- + transforms_zyx : ArrayLike + List of affine transformation matrices (4x4), one for each timepoint. + output_filepath : Path + Path to the output plot file. + Returns + ------- + None + + Notes + ----- + The plot is saved as a png file. + The plot is saved in the same directory as the output file. + The plot is saved as a png file. 
+ """ + transforms_zyx = np.asarray(transforms_zyx) + os.makedirs(output_filepath.parent, exist_ok=True) + + z_transforms = transforms_zyx[:, 0, 3] + y_transforms = transforms_zyx[:, 1, 3] + x_transforms = transforms_zyx[:, 2, 3] + _, axs = plt.subplots(3, 1, figsize=(10, 10)) + + axs[0].plot(z_transforms) + axs[0].set_title("Z-Translation") + axs[1].plot(x_transforms) + axs[1].set_title("X-Translation") + axs[2].plot(y_transforms) + axs[2].set_title("Y-Translation") + plt.savefig(output_filepath, dpi=300, bbox_inches='tight') + plt.close() + + +def convert_transform_to_ants(T_numpy: np.ndarray): + """Homogeneous 3D transformation matrix from numpy to ants + + Parameters + ---------- + numpy_transform :4x4 homogenous matrix + + Returns + ------- + Ants transformation matrix object + """ + assert T_numpy.shape == (4, 4) + + T_ants_style = T_numpy[:, :-1].ravel() + T_ants_style[-3:] = T_numpy[:3, -1] + T_ants = ants.new_ants_transform( + transform_type="AffineTransform", + ) + T_ants.set_parameters(T_ants_style) + + return T_ants + + +def convert_transform_to_numpy(T_ants): + """ + Convert the ants transformation matrix to numpy 3D homogenous transform + + Modified from Jordao's dexp code + + Parameters + ---------- + T_ants : Ants transfromation matrix object + + Returns + ------- + np.array + Converted Ants to numpy array + + """ + + T_numpy = T_ants.parameters.reshape((3, 4), order="F") + T_numpy[:, :3] = T_numpy[:, :3].transpose() + T_numpy = np.vstack((T_numpy, np.array([0, 0, 0, 1]))) + + # Reference: + # https://sourceforge.net/p/advants/discussion/840261/thread/9fbbaab7/ + # https://github.com/netstim/leaddbs/blob/a2bb3e663cf7fceb2067ac887866124be54aca7d/helpers/ea_antsmat2mat.m + # T = original translation offset from A + # T = T + (I - A) @ centering + + T_numpy[:3, -1] += (np.eye(3) - T_numpy[:3, :3]) @ T_ants.fixed_parameters + + return T_numpy + + +def find_lir(registered_zyx: np.ndarray, plot: bool = False) -> Tuple: + registered_zyx = 
np.asarray(registered_zyx, dtype=bool) + + # Find the lir in YX at Z//2 + registered_yx = registered_zyx[registered_zyx.shape[0] // 2].copy() + coords_yx = lir.lir(registered_yx) + coords_yx = list(map(int, coords_yx)) + + x, y, width, height = coords_yx + x_start, x_stop = x, x + width + y_start, y_stop = y, y + height + x_slice = slice(x_start, x_stop) + y_slice = slice(y_start, y_stop) + + # Iterate over ZX and ZY slices to find optimal Z cropping params + _coords = [] + for _x in (x_start, x_start + (x_stop - x_start) // 2, x_stop - 1): + registered_zy = registered_zyx[:, y_slice, _x].copy() + coords_zy = lir.lir(registered_zy) + _, z, _, depth = coords_zy + z_start, z_stop = z, z + depth + _coords.append((z_start, z_stop)) + for _y in (y_start, y_start + (y_stop - y_start) // 2, y_stop - 1): + registered_zx = registered_zyx[:, _y, x_slice].copy() + coords_zx = lir.lir(registered_zx) + _, z, _, depth = coords_zx + z_start, z_stop = z, z + depth + _coords.append((z_start, z_stop)) + + _coords = np.asarray(_coords) + z_start = int(_coords.max(axis=0)[0]) + z_stop = int(_coords.min(axis=0)[1]) + z_slice = slice(z_start, z_stop) + + if plot: + xy_corners = ((x, y), (x + width, y), (x + width, y + height), (x, y + height)) + rectangle_yx = plt.Polygon( + xy_corners, + closed=True, + fill=None, + edgecolor="r", + ) + # Add the rectangle to the plot + _, ax = plt.subplots(nrows=1, ncols=2) + ax[0].imshow(registered_yx) + ax[0].add_patch(rectangle_yx) + + zx_corners = ((x, z), (x + width, z), (x + width, z + depth), (x, z + depth)) + rectangle_zx = plt.Polygon( + zx_corners, + closed=True, + fill=None, + edgecolor="r", + ) + ax[1].imshow(registered_zx) + ax[1].add_patch(rectangle_zx) + plt.savefig("./lir.png") + + return (z_slice, y_slice, x_slice) + + +def find_overlapping_volume( + input_zyx_shape: Tuple, + target_zyx_shape: Tuple, + transformation_matrix: np.ndarray, + method: str = "LIR", + plot: bool = False, +) -> Tuple: + """ + Find the overlapping rectangular 
volume after registration of two 3D datasets + + Parameters + ---------- + input_zyx_shape : Tuple + shape of input array + target_zyx_shape : Tuple + shape of target array + transformation_matrix : np.ndarray + affine transformation matrix + method : str, optional + method of finding the overlapping volume, by default 'LIR' + + Returns + ------- + Tuple + ZYX slices of the overlapping volume after registration + + """ + + # Make dummy volumes + moving_volume = np.ones(tuple(input_zyx_shape), dtype=np.float32) + fixed_volume = np.ones(tuple(target_zyx_shape), dtype=np.float32) + + # Convert to ants objects + fixed_volume_ants = ants.from_numpy(fixed_volume) + moving_volume_ants = ants.from_numpy(moving_volume) + + tform_ants = convert_transform_to_ants(transformation_matrix) + + # Now apply the transform using this grid + registered_volume = tform_ants.apply_to_image( + moving_volume_ants, reference=fixed_volume_ants + ).numpy() + if method == "LIR": + click.echo("Starting Largest interior rectangle (LIR) search") + mask = (registered_volume > 0) & (fixed_volume > 0) + z_slice, y_slice, x_slice = find_lir(mask, plot=plot) + + else: + raise ValueError(f"Unknown method {method}") + + return (z_slice, y_slice, x_slice) + + +def rescale_voxel_size(affine_matrix, input_scale): + return np.linalg.norm(affine_matrix, axis=1) * input_scale + + +def load_transforms(transforms_path: Path, T: int, verbose: bool = False) -> list[ArrayLike]: + # Load the transforms + transforms = [] + for t in range(T): + file_path = transforms_path / f"{t}.npy" + if not os.path.exists(file_path): + transforms.append(None) + if verbose: + click.echo(f"Transform for timepoint {t} not found.") + + else: + matrix = np.load(file_path) + transforms.append(matrix.tolist()) + + if verbose: + click.echo(f"Transform for timepoint {t}: {matrix}") + + return transforms + + +def get_3D_rescaling_matrix(start_shape_zyx, scaling_factor_zyx=(1, 1, 1), end_shape_zyx=None): + center_Y_start, center_X_start = 
np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + scaling_matrix = np.array( + [ + [scaling_factor_zyx[-3], 0, 0, 0], + [ + 0, + scaling_factor_zyx[-2], + 0, + -center_Y_start * scaling_factor_zyx[-2] + center_Y_end, + ], + [ + 0, + 0, + scaling_factor_zyx[-1], + -center_X_start * scaling_factor_zyx[-1] + center_X_end, + ], + [0, 0, 0, 1], + ] + ) + return scaling_matrix + + +def get_3D_rotation_matrix( + start_shape_zyx: Tuple, angle: float = 0.0, end_shape_zyx: Tuple = None +) -> np.ndarray: + """ + Rotate Transformation Matrix + + Parameters + ---------- + start_shape_zyx : Tuple + Shape of the input + angle : float, optional + Angles of rotation in degrees + end_shape_zyx : Tuple, optional + Shape of output space + + Returns + ------- + np.ndarray + Rotation matrix + """ + # TODO: make this 3D? + center_Y_start, center_X_start = np.array(start_shape_zyx)[-2:] / 2 + if end_shape_zyx is None: + center_Y_end, center_X_end = (center_Y_start, center_X_start) + else: + center_Y_end, center_X_end = np.array(end_shape_zyx)[-2:] / 2 + + theta = np.radians(angle) + + rotation_matrix = np.array( + [ + [1, 0, 0, 0], + [ + 0, + np.cos(theta), + -np.sin(theta), + -center_Y_start * np.cos(theta) + + np.sin(theta) * center_X_start + + center_Y_end, + ], + [ + 0, + np.sin(theta), + np.cos(theta), + -center_Y_start * np.sin(theta) + - center_X_start * np.cos(theta) + + center_X_end, + ], + [0, 0, 0, 1], + ] + ) + return rotation_matrix + + +def get_3D_fliplr_matrix(start_shape_zyx: tuple, end_shape_zyx: tuple = None) -> np.ndarray: + """ + Get 3D left-right flip transformation matrix. + + Parameters + ---------- + start_shape_zyx : tuple + Shape of the source volume (Z, Y, X). + end_shape_zyx : tuple, optional + Shape of the target volume (Z, Y, X). If None, uses start_shape_zyx. 
+ + Returns + ------- + np.ndarray + 4x4 transformation matrix for left-right flip. + """ + center_X_start = start_shape_zyx[-1] / 2 + if end_shape_zyx is None: + center_X_end = center_X_start + else: + center_X_end = end_shape_zyx[-1] / 2 + + # Flip matrix: reflects across X axis and translates to maintain center + flip_matrix = np.array( + [ + [1, 0, 0, 0], # Z unchanged + [0, 1, 0, 0], # Y unchanged + [0, 0, -1, 2 * center_X_end], # X flipped and translated + [0, 0, 0, 1], # Homogeneous coordinate + ] + ) + return flip_matrix + + +def apply_affine_transform( + zyx_data: np.ndarray, + matrix: np.ndarray, + output_shape_zyx: Tuple, + method="ants", + interpolation: str = "linear", + crop_output_slicing: bool = None, +) -> np.ndarray: + """_summary_ + + Parameters + ---------- + zyx_data : np.ndarray + 3D input array to be transformed + matrix : np.ndarray + 3D Homogenous transformation matrix + output_shape_zyx : Tuple + output target zyx shape + method : str, optional + method to use for transformation, by default 'ants' + interpolation: str, optional + interpolation mode for ants, by default "linear" + crop_output : bool, optional + crop the output to the largest interior rectangle, by default False + + Returns + ------- + np.ndarray + registered zyx data + """ + + Z, Y, X = output_shape_zyx + if crop_output_slicing is not None: + Z_slice, Y_slice, X_slice = crop_output_slicing + Z = Z_slice.stop - Z_slice.start + Y = Y_slice.stop - Y_slice.start + X = X_slice.stop - X_slice.start + + # TODO: based on the signature of this function, it should not be called on 4D array + if zyx_data.ndim == 4: + registered_czyx = np.zeros((zyx_data.shape[0], Z, Y, X), dtype=np.float32) + for c in range(zyx_data.shape[0]): + registered_czyx[c] = apply_affine_transform( + zyx_data[c], + matrix, + output_shape_zyx, + method=method, + interpolation=interpolation, + crop_output_slicing=crop_output_slicing, + ) + return registered_czyx + else: + # Convert nans to 0 + zyx_data = 
np.nan_to_num(zyx_data, nan=0) + + # NOTE: default set to ANTS apply_affine method until we decide we get a benefit from using cupy + # The ants method on CPU is 10x faster than scipy on CPU. Cupy method has not been bencharked vs ANTs + + if method == "ants": + # The output has to be a ANTImage Object + empty_target_array = np.zeros((output_shape_zyx), dtype=np.float32) + target_zyx_ants = ants.from_numpy(empty_target_array) + + T_ants = convert_transform_to_ants(matrix) + + zyx_data_ants = ants.from_numpy(zyx_data.astype(np.float32)) + registered_zyx = T_ants.apply_to_image( + zyx_data_ants, reference=target_zyx_ants, interpolation=interpolation + ).numpy() + + elif method == "scipy": + + registered_zyx = scipy.ndimage.affine_transform(zyx_data, matrix, output_shape_zyx) + + else: + raise ValueError(f"Unknown method {method}") + + # Crop the output to the largest interior rectangle + if crop_output_slicing is not None: + registered_zyx = registered_zyx[Z_slice, Y_slice, X_slice] + + return registered_zyx + + +def pad_to_shape( + arr: ArrayLike, + shape: Tuple[int, ...], + mode: str, + verbose: bool = False, + **kwargs, +) -> ArrayLike: + """ + Pad or crop array to match provided shape. + + Parameters + ---------- + arr : ArrayLike + Input array. + shape : Tuple[int] + Output shape. + mode : str + Padding mode (see np.pad). + verbose : bool + If True, print verbose output. + kwargs : dict + Additional keyword arguments for np.pad. + + Returns + ------- + ArrayLike + Padded array. 
+ """ + assert arr.ndim == len(shape) + + dif = tuple(s - a for s, a in zip(shape, arr.shape)) + assert all(d >= 0 for d in dif) + + pad_width = [[s // 2, s - s // 2] for s in dif] + + if verbose: + click.echo( + f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}" + ) + + return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs) + + +def center_crop( + arr: ArrayLike, + shape: Tuple[int, ...], + verbose: bool = False, +) -> ArrayLike: + """ + Crop the center of `arr` to match provided shape. + + Parameters + ---------- + arr : ArrayLike + Input array. + shape : Tuple[int, ...] + """ + assert arr.ndim == len(shape) + + starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape)) + + assert all(s >= 0 for s in starts) + + slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape)) + if verbose: + click.echo( + f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}" + ) + return arr[slicing] + + +def match_shape( + img: ArrayLike, + shape: Tuple[int, ...], + verbose: bool = False, +) -> ArrayLike: + """ + Pad or crop array to match provided shape. + + Parameters + ---------- + img : ArrayLike + Input array. + shape : Tuple[int, ...] + verbose : bool + If True, print verbose output. + + Returns + ------- + ArrayLike + Padded or cropped array. 
+ """ + + if np.any(shape > img.shape): + padded_shape = np.maximum(img.shape, shape) + img = pad_to_shape(img, padded_shape, mode="reflect") + + if np.any(shape < img.shape): + img = center_crop(img, shape) + + if verbose: + click.echo(f"matched shape: input shape {img.shape}, output shape {shape}") + + return img diff --git a/biahub/settings.py b/biahub/settings.py index f536635e..b577581d 100644 --- a/biahub/settings.py +++ b/biahub/settings.py @@ -121,9 +121,21 @@ class MatchDescriptorSettings(MyBaseModel): cross_check: bool = False +class FilterMatchesSettings(MyBaseModel): + angle_threshold: float = 0 + direction_threshold: float = 0 + min_distance_quantile: float = 0.01 + max_distance_quantile: float = 0.95 + + +class QCBeadsRegistrationSettings(MyBaseModel): + iterations: int = 2 + score_threshold: float = 0.40 + score_centroid_mask_radius: int = 6 + + class BeadsMatchSettings(MyBaseModel): algorithm: Literal["hungarian", "match_descriptor"] = "hungarian" - t_reference: Literal["first", "previous"] = "first" source_peaks_settings: Optional[DetectPeaksSettings] = Field( default_factory=DetectPeaksSettings ) @@ -132,8 +144,8 @@ class BeadsMatchSettings(MyBaseModel): ) match_descriptor_settings: MatchDescriptorSettings = MatchDescriptorSettings() hungarian_match_settings: HungarianMatchSettings = HungarianMatchSettings() - filter_distance_threshold: float = 0.95 - filter_angle_threshold: float = 0 + filter_matches_settings: FilterMatchesSettings = FilterMatchesSettings() + qc_settings: QCBeadsRegistrationSettings = QCBeadsRegistrationSettings() class PhaseCrossCorrSettings(MyBaseModel): @@ -172,8 +184,11 @@ class EvalTransformSettings(MyBaseModel): class AffineTransformSettings(MyBaseModel): + t_reference: Literal["first", "previous"] = "first" transform_type: Literal["euclidean", "similarity", "affine"] = "euclidean" approx_transform: list = np.eye(4).tolist() + use_prev_t_transform: bool = True + compute_approx_transform: bool = False 
@field_validator("approx_transform") @classmethod @@ -537,10 +552,10 @@ def validate_slice_lengths(self): class StabilizationSettings(MyBaseModel): stabilization_estimation_channel: str - stabilization_type: Literal["z", "xy", "xyz"] - stabilization_method: Literal["beads", "phase-cross-corr", "focus-finding"] = ( - "focus-finding" - ) + stabilization_type: Literal["z", "xy", "xyz", "affine"] + stabilization_method: Literal[ + "beads", "phase-cross-corr", "focus-finding", "manual", "ants", "beads" + ] = "focus-finding" stabilization_channels: list affine_transform_zyx_list: list time_indices: Union[NonNegativeInt, list[NonNegativeInt], Literal["all"]] = "all" diff --git a/biahub/stabilize.py b/biahub/stabilize.py index 46212a13..e530c0e0 100644 --- a/biahub/stabilize.py +++ b/biahub/stabilize.py @@ -7,6 +7,7 @@ import submitit from iohub.ngff import open_ome_zarr +from iohub.ngff.utils import create_empty_plate from scipy.linalg import svd from scipy.spatial.transform import Rotation as R @@ -23,7 +24,6 @@ ) from biahub.cli.utils import ( copy_n_paste_czyx, - create_empty_hcs_zarr, estimate_resources, process_single_position_v2, yaml_to_model, @@ -214,7 +214,7 @@ def stabilize( } # Create the output zarr mirroring input_position_dirpaths - create_empty_hcs_zarr( + create_empty_plate( store_path=output_dirpath, position_keys=[p.parts[-3:] for p in input_position_dirpaths], **output_metadata, diff --git a/scripts/debug_beads_registration.py b/scripts/debug_beads_registration.py index ac126282..3f99a3ee 100644 --- a/scripts/debug_beads_registration.py +++ b/scripts/debug_beads_registration.py @@ -1,36 +1,61 @@ -# %% +""" +Debug script for beads-based registration. + +Interactive notebook-style script for developing and debugging the beads +registration pipeline. Walks through each step: +1. Load data (label-free reference + light-sheet moving volumes) +2. Apply approximate transform and visualize alignment +3. Detect bead peaks in both channels +4. 
Match beads and visualize correspondences +5. Estimate correction transform from matches +6. Run the full iterative optimization pipeline (optimize_transform) + +Usage: Run cell-by-cell in an IDE with interactive Python support (e.g. VSCode). +Set ``visualize = True`` to open napari viewers at each stage. +""" + +# %% Imports import ants import numpy as np from pathlib import Path from iohub import open_ome_zarr -from biahub.register import convert_transform_to_ants +from biahub.core.transform import Transform import napari -from biahub.estimate_registration import ( - filter_matches, - detect_bead_peaks, - get_matches_from_beads, - estimate_transform) from biahub.settings import EstimateRegistrationSettings -import numpy as np - - -# %%% +from biahub.registration.beads import ( + transform_from_matches, + matches_from_beads, + peaks_from_beads, + optimize_transform, +) -dataset = '2024_11_21_A549_TOMM20_DENV' +# %% Dataset configuration +dataset = '2024_12_11_A549_LAMP1_DENV' fov = 'C/1/000000' -root_path = Path(f'/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/{dataset}/1-preprocess/') -t_idx = 76 +root_path = Path(f'/hpc/projects/intracellular_dashboard/organelle_dynamics/{dataset}/1-preprocess/') +t_idx = 3 lf_data_path = root_path / f"label-free/0-reconstruct/{dataset}.zarr" / fov ls_data_path = root_path / f"light-sheet/raw/0-deskew/{dataset}.zarr" / fov visualize = True -# %% +# %% Registration settings config_dict = { "target_channel_name": "Phase3D", "source_channel_name": "GFP EX488 EM525-45", - "beads_match_settings": { + "beads_match_settings": { "algorithm": "hungarian", + "qc_settings": { + "iterations": 2, + "score_threshold": 0.40, + "score_centroid_mask_radius": 6 + }, + "filter_matches_settings": { + "min_distance_quantile": 0.05, + "max_distance_quantile": 0.99, + "angle_threshold": 0, + "direction_threshold": 50 + }, "source_peaks_settings": { "threshold_abs": 110, "nms_distance": 16, @@ -43,14 +68,9 @@ "min_distance": 0, 
"block_size": [8, 8, 8] }, - "match_descriptor_settings": { - "distance_metric": "euclidean", - "max_ratio": 1, - "cross_check": True - }, "hungarian_match_settings": { "distance_metric": "euclidean", - "cost_threshold": 0.1, + "cost_threshold": 0.05, "cross_check": True, "max_ratio": 1, "edge_graph_settings": { @@ -62,18 +82,18 @@ "weights": { "dist": 0.5, "edge_angle": 1, - "edge_length": 1, - "pca_dir": 0, - "pca_aniso": 0, + "edge_length": 1.0, + "pca_dir": 0.0, + "pca_aniso": 0.0, "edge_descriptor": 0 } } }, - "filter_distance_threshold": 0.95, - "filter_angle_threshold": 0, }, "affine_transform_settings": { - "transform_type": "similarity", + "use_prev_t_transform": True, + "transform_type": "affine", + "compute_approx_transform": False, "approx_transform": [ [1, 0, 0, 0], [0, 0, -1.288, 1960], @@ -86,184 +106,158 @@ } config = EstimateRegistrationSettings(**config_dict) -# %% Load data -with open_ome_zarr(lf_data_path) as target_ds: - target_channel_name = target_ds.channel_names - target_channel_index = target_ds.channel_names.index(config.target_channel_name) - target_data = np.asarray(target_ds.data[t_idx, target_channel_index]) # take phase channel - target_scale = target_ds.scale - -with open_ome_zarr(ls_data_path) as source_ds: - source_channel_name = source_ds.channel_names - source_channel_index = source_channel_name.index(config.source_channel_name) - source_data = np.asarray(source_ds.data[t_idx, source_channel_index]) # take mCherry channel or the GFP channel (depending where the beads are) - source_scale = source_ds.scale - -# Register LS data with approx tranform -source_data_ants = ants.from_numpy(source_data) -target_data_ants = ants.from_numpy(target_data) - -source_data_reg_ants = convert_transform_to_ants(np.asarray(config.affine_transform_settings.approx_transform)).apply_to_image( - source_data_ants, reference=target_data_ants +# %% Load reference (label-free) and moving (light-sheet) volumes for a single timepoint +with 
open_ome_zarr(lf_data_path) as ref_ds: + ref_channel_name = ref_ds.channel_names + ref_channel_index = ref_ds.channel_names.index(config.target_channel_name) + ref_data = np.asarray(ref_ds.data[t_idx, ref_channel_index]) + ref_scale = ref_ds.scale + +with open_ome_zarr(ls_data_path) as mov_ds: + mov_channel_name = mov_ds.channel_names + mov_channel_index = mov_channel_name.index(config.source_channel_name) + mov_data = np.asarray(mov_ds.data[t_idx, mov_channel_index]) + mov_scale = mov_ds.scale + +# Convert to ANTs images (reused throughout the script) +mov_data_ants = ants.from_numpy(mov_data) +ref_data_ants = ants.from_numpy(ref_data) + +# %% Compute approximate transform from voxel sizes (if not provided in config) +if config.affine_transform_settings.compute_approx_transform: + from biahub.registration.utils import get_aprox_transform + + approx_transform = get_aprox_transform( + mov_shape=mov_data.shape[-3:], + ref_shape=ref_data.shape[-3:], + pre_affine_90degree_rotation=-1, + pre_affine_fliplr=False, + verbose=True, + ref_voxel_size=ref_scale, + mov_voxel_size=mov_scale, + ) + config.affine_transform_settings.approx_transform = approx_transform.to_list() +# %% Apply approximate transform to moving volume and visualize overlay +initial_transform = Transform( + matrix=np.asarray(config.affine_transform_settings.approx_transform) ) -source_data_reg = source_data_reg_ants.numpy() - +mov_data_reg_ants = initial_transform.to_ants().apply_to_image( + mov_data_ants, reference=ref_data_ants +) +mov_data_reg = mov_data_reg_ants.numpy() # %% if visualize: viewer = napari.Viewer() - viewer.add_image(target_data, name='LF') - viewer.add_image(source_data_reg, name='LS') - -# %% Detect peaks in LS data + viewer.add_image(ref_data, name='LF (reference)') + viewer.add_image(mov_data_reg, name='LS (approx registered)') -source_peaks, target_peaks = detect_bead_peaks( - source_data_reg, - target_data, +# %% Detect bead peaks in both the approximately registered moving and the 
reference volumes +mov_peaks, ref_peaks = peaks_from_beads( + mov_data_reg, + ref_data, config.beads_match_settings.source_peaks_settings, config.beads_match_settings.target_peaks_settings, verbose=True, - filter_dirty_peaks=False ) -#%% + +# %% Visualize detected moving peaks overlaid on the registered LS volume if visualize: viewer = napari.Viewer() - viewer.add_image(source_data_reg, name='LS') + viewer.add_image(mov_data_reg, name='LS (approx registered)') viewer.add_points( - source_peaks, name='peaks local max LS', size=20, symbol='disc', face_color='magenta' + mov_peaks, name='LS peaks', size=20, symbol='disc', face_color='magenta' ) -# %% + +# %% Visualize detected reference peaks overlaid on the LF volume if visualize: viewer = napari.Viewer() - viewer.add_image(target_data, name='LF') + viewer.add_image(ref_data, name='LF (reference)') viewer.add_points( - target_peaks, name='peaks local max LF', size=20, symbol='disc', face_color='green' + ref_peaks, name='LF peaks', size=20, symbol='disc', face_color='green' ) -# %% -matches = get_matches_from_beads( - source_peaks, - target_peaks, +# %% Match beads between moving and reference peak sets +matches = matches_from_beads( + mov_peaks, + ref_peaks, config.beads_match_settings, verbose=True ) - -# %% -matches = filter_matches( - matches, - source_peaks, - target_peaks, - angle_threshold=config.beads_match_settings.filter_angle_threshold, - distance_threshold=config.beads_match_settings.filter_distance_threshold, - verbose=True -) -# %% +# %% Visualize matched bead pairs as 3D lines connecting corresponding peaks if visualize: - # visualize matches viewer = napari.Viewer() - viewer.add_image(target_data, name='LF', contrast_limits=(0.5, 1.0),blending='additive') + viewer.add_image(ref_data, name='LF', contrast_limits=(0.5, 1.0), blending='additive') viewer.add_points( - target_peaks, name='LF peaks', size=12, symbol='ring', face_color='yellow',blending='additive' + ref_peaks, name='LF peaks', size=12, 
symbol='ring', + face_color='yellow', blending='additive' + ) + viewer.add_image( + mov_data_reg, name='LS', contrast_limits=(110, 230), + blending='additive', colormap='green' ) - viewer.add_image(source_data_reg, name='LS', contrast_limits=(110, 230), blending='additive', colormap='green') viewer.add_points( - source_peaks, name='LS peaks', size=12, symbol='ring', face_color='red',blending='additive' + mov_peaks, name='LS peaks', size=12, symbol='ring', + face_color='red', blending='additive' ) - # Project in 3D to be able to view the lines viewer.add_shapes( - data=[np.asarray([source_peaks[m[0]], target_peaks[m[1]]]) for m in matches], + data=[np.asarray([mov_peaks[m[0]], ref_peaks[m[1]]]) for m in matches], shape_type='line', edge_width=5, blending='additive', ) viewer.dims.ndisplay = 3 - -# %% Register LS data using compount transform - -tform = estimate_transform( - matches, - source_peaks, - target_peaks, - config.affine_transform_settings, - verbose=True +# %% Estimate correction transform from matches and compose with approx transform +fwd_transform, inv_transform = transform_from_matches( + matches=matches, + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + affine_transform_settings=config.affine_transform_settings, + ndim=mov_data_reg.ndim, ) - -compount_tform = np.asarray(config.affine_transform_settings.approx_transform) @ tform.inverse.params -compount_tform_ants = convert_transform_to_ants(compount_tform) -source_data_reg_2 = compount_tform_ants.apply_to_image( - source_data_ants, reference=target_data_ants +# Compose: apply approx transform first, then the bead-based correction +composed_transform = initial_transform @ inv_transform +mov_data_reg_2 = composed_transform.to_ants().apply_to_image( + mov_data_ants, reference=ref_data_ants ).numpy() +# %% Compare approx-only vs bead-corrected registration if visualize: viewer = napari.Viewer() - viewer.add_image(target_data, name='LF', contrast_limits=(-0.5, 1.0)) + viewer.add_image(ref_data, name='LF 
(reference)', contrast_limits=(-0.5, 1.0)) viewer.add_image( - source_data_reg, - name='LS approx registered', - contrast_limits=(110, 230), - blending='additive', - colormap='green' - ) - viewer.add_image( - source_data_reg_2, - name='LS registered', + mov_data_reg, + name='LS approx registered', contrast_limits=(110, 230), blending='additive', - colormap='magenta' - ) - -# %% - -fov_cell = 'C/2/000000' -root_path = Path(f'/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/{dataset}/1-preprocess/') -lf_data_sample_path = root_path / f"label-free/0-reconstruct/{dataset}.zarr" / fov_cell -ls_data_sample_path = root_path / f"light-sheet/raw/0-deskew/{dataset}.zarr" / fov_cell - -with open_ome_zarr(lf_data_sample_path) as target_ds: - target_channel_name = target_ds.channel_names - target_channel_index = target_ds.channel_names.index(config.target_channel_name) - target_sample_data = np.asarray(target_ds.data[t_idx, target_channel_index]) # take phase channel - target_scale = target_ds.scale - -with open_ome_zarr(ls_data_sample_path) as source_ds: - source_channel_name = source_ds.channel_names - source_channel_index = source_channel_name.index(config.source_channel_name) - source_sample_data = np.asarray(source_ds.data[t_idx, source_channel_index]) # take mCherry channel or the GFP channel (depending where the beads are) - source_scale = source_ds.scale - -# Register LS data with approx tranform -source_data_sample_ants = ants.from_numpy(source_sample_data) -target_data_sample_ants = ants.from_numpy(target_sample_data) - -source_data_sample_reg_ants = convert_transform_to_ants(np.asarray(config.affine_transform_settings.approx_transform)).apply_to_image( - source_data_sample_ants, reference=target_data_sample_ants -) -source_data_sample_reg = source_data_sample_reg_ants.numpy() - -source_data_sample_reg_2 = compount_tform_ants.apply_to_image( - source_data_sample_ants, reference=target_data_sample_ants -).numpy() - -if visualize: - viewer = 
napari.Viewer() - viewer.add_image(target_sample_data, name='LF', contrast_limits=(-0.5, 1.0)) - viewer.add_image( - source_data_sample_reg, - name='LS approx registered', - contrast_limits=(110, 230), - blending='additive', - colormap='green' + colormap='green' ) viewer.add_image( - source_data_sample_reg_2, - name='LS registered', + mov_data_reg_2, + name='LS bead-corrected', contrast_limits=(110, 230), blending='additive', colormap='magenta' ) +# %% Inspect matches and composed transform +matches # %% +composed_transform +# %% Run the full iterative optimization pipeline (detect -> match -> correct -> score) +optimized_transform, quality_score_optimized = optimize_transform( + transform=initial_transform, + mov=mov_data, + ref=ref_data, + beads_match_settings=config.beads_match_settings, + affine_transform_settings=config.affine_transform_settings, + verbose=True, + debug=True, +) +print(f"Optimized transform:\n{optimized_transform}") +print(f"Quality score: {quality_score_optimized:.4f}") +# %% \ No newline at end of file diff --git a/scripts/debug_beads_stabilization.py b/scripts/debug_beads_stabilization.py index 33d58f4f..88096f1f 100644 --- a/scripts/debug_beads_stabilization.py +++ b/scripts/debug_beads_stabilization.py @@ -3,58 +3,41 @@ import numpy as np from pathlib import Path from iohub import open_ome_zarr -from biahub.characterize_psf import detect_peaks -from biahub.register import convert_transform_to_ants +from biahub.core.transform import convert_transform_to_ants, Transform import napari -from skimage.transform import AffineTransform -from skimage.feature import match_descriptors -from biahub.estimate_registration import ( - build_edge_graph, - compute_cost_matrix, - match_hungarian_global_cost, - filter_matches, - detect_bead_peaks, - get_matches_from_beads, - estimate_transform) -from biahub.settings import EstimateStabilizationSettings +from biahub.settings import EstimateRegistrationSettings import numpy as np +from 
biahub.registration.beads import transform_from_matches, matches_from_beads, peaks_from_beads, optimize_transform # %%% -dataset = '2025_05_21_A549_MAP1LC3B_RPL36_GFP_sensor_ZIKV_DENV' -fov = 'A/3/000000' -root_path = Path(f'/hpc/projects/intracellular_dashboard/viral-sensor/{dataset}/1-preprocess/') - -t_target = 0 -t_source = 10 +dataset = '2024_12_11_A549_LAMP1_DENV' +fov = 'C/1/000000' +root_path = Path(f'/hpc/projects/intracellular_dashboard/organelle_dynamics/{dataset}/1-preprocess/') lf_data_path = root_path / f"label-free/0-reconstruct/{dataset}.zarr" / fov visualize = True +t_ref = 3 +t_mov = 4 -# %% Load data -with open_ome_zarr(lf_data_path) as target_ds: - target_channel_name = target_ds.channel_names - target_channel_index = target_ds.channel_names.index(config['target_channel_name']) - target_data = np.asarray(target_ds.data[t_target, target_channel_index]) # take phase channel - target_scale = target_ds.scale - -with open_ome_zarr(lf_data_path) as source_ds: - source_channel_name = source_ds.channel_names - source_channel_index = source_channel_name.index(config['source_channel_name']) - source_data = np.asarray(source_ds.data[t_source, source_channel_index]) # take mCherry channel or the GFP channel (depending where the beads are) - source_scale = source_ds.scale - -#%% -## If stabilizing LF beads, peaks threshold_abs == 0.8, if LS beads, peaks threshold_abs == 110 +# %% config_dict = { - "stabilization_estimation_channel": "Phase3D", - "stabilization_channels": ["Phase3D"], - "stabilization_type": "xyz", - "stabilization_method": "beads", - "beads_match_settings": { + "target_channel_name": "Phase3D", + "source_channel_name": "Phase3D", + "beads_match_settings": { "algorithm": "hungarian", - "t_reference": "first", + "qc_settings": { + "iterations": 2, + "score_threshold": 0.40, + "score_centroid_mask_radius": 6 + }, + "filter_matches_settings": { + "min_distance_quantile": 0.0, + "max_distance_quantile": 0.0, + "angle_threshold": 0, + 
"direction_threshold": 0 + }, "source_peaks_settings": { "threshold_abs": 0.8, "nms_distance": 16, @@ -67,109 +50,115 @@ "min_distance": 0, "block_size": [8, 8, 8] }, - "match_descriptor_settings": { - "distance_metric": "euclidean", - "max_ratio": 1, - "cross_check": True - }, "hungarian_match_settings": { "distance_metric": "euclidean", - "cost_threshold": 0.1, + "cost_threshold": 0.05, "cross_check": True, "max_ratio": 1, "edge_graph_settings": { "method": "knn", - "k": 5 + "k": 10 }, "cost_matrix_settings": { "normalize": False, "weights": { - "dist": 0.5, - "edge_angle": 1.0, + "dist": 1, + "edge_angle": 1, "edge_length": 1.0, "pca_dir": 0.0, "pca_aniso": 0.0, - "edge_descriptor": 0.0 + "edge_descriptor": 0 } } }, - "filter_distance_threshold": 0.95, - "filter_angle_threshold": 0 }, "affine_transform_settings": { - "transform_type": "euclidean", + "use_prev_t_transform": True, + "transform_type": "affine", + "approx_transform": [ + [1, 0, 0, 0], + [0, 0, -1.288, 1960], + [0, 1.288, 0, -460], + [0.0, 0.0, 0.0, 1.0] + ] }, "verbose": True } -config = EstimateStabilizationSettings(**config_dict) +config = EstimateRegistrationSettings(**config_dict) + +# %% Load data +with open_ome_zarr(lf_data_path) as ref_ds: + ref_channel_name = ref_ds.channel_names + ref_channel_index = ref_ds.channel_names.index(config.target_channel_name) + ref_data = np.asarray(ref_ds.data[t_ref, ref_channel_index]) # take phase channel + ref_scale = ref_ds.scale + + mov_data = np.asarray(ref_ds.data[t_mov, ref_channel_index]) # take phase channel + mov_scale = ref_ds.scale +# + + +# qc, measure matches vectors directions, lenght, angle .. statiscs, then evaluate the mean, std, min, max, etc. 
+# to determine filter thresholds # %% if visualize: viewer = napari.Viewer() -    viewer.add_image(target_data, name='Target data') -    viewer.add_image(source_data, name='Source data') +    viewer.add_image(ref_data, name='LF') +    viewer.add_image(mov_data, name='LS') # %% Detect peaks in LS data -source_peaks, target_peaks = detect_bead_peaks( -    source_data, -    target_data, +mov_peaks, ref_peaks = peaks_from_beads( +    mov_data, +    ref_data, config.beads_match_settings.source_peaks_settings, config.beads_match_settings.target_peaks_settings, verbose=True, -    filter_dirty_peaks=True ) + #%% if visualize: viewer = napari.Viewer() -    viewer.add_image(source_data, name='Source data') +    viewer.add_image(mov_data, name='LS') viewer.add_points( -        source_peaks, name='peaks local max Source', size=20, symbol='disc', face_color='magenta' +        mov_peaks, name='peaks local max LS', size=20, symbol='disc', face_color='magenta' ) # %% if visualize: viewer = napari.Viewer() -    viewer.add_image(target_data, name='Target data') +    viewer.add_image(ref_data, name='LF') viewer.add_points( -        target_peaks, name='peaks local max Target', size=20, symbol='disc', face_color='green' +        ref_peaks, name='peaks local max LF', size=20, symbol='disc', face_color='green' ) # %% -matches = get_matches_from_beads( -    source_peaks, -    target_peaks, + +matches = matches_from_beads( +    mov_peaks, +    ref_peaks, config.beads_match_settings, verbose=True ) - -# %% -matches = filter_matches( -    matches, -    source_peaks, -    target_peaks, -    angle_threshold=config.beads_match_settings.filter_angle_threshold, -    distance_threshold=config.beads_match_settings.filter_distance_threshold, -    verbose=True -) # %% if visualize: # visualize matches viewer = napari.Viewer() -    viewer.add_image(target_data, name='Target data', contrast_limits=(0.5, 1.0),blending='additive') +    viewer.add_image(ref_data, name='LF', contrast_limits=(0.5, 1.0),blending='additive') viewer.add_points( -        target_peaks, name='Target peaks', size=12, symbol='ring',
face_color='yellow',blending='additive' + ref_peaks, name='LF peaks', size=12, symbol='ring', face_color='yellow',blending='additive' ) - viewer.add_image(source_data, name='Source data', contrast_limits=(110, 230), blending='additive', colormap='green') + viewer.add_image(mov_data, name='LS', contrast_limits=(110, 230), blending='additive', colormap='green') viewer.add_points( - source_peaks, name='Source peaks', size=12, symbol='ring', face_color='red',blending='additive' + mov_peaks, name='LS peaks', size=12, symbol='ring', face_color='red',blending='additive' ) # Project in 3D to be able to view the lines viewer.add_shapes( - data=[np.asarray([source_peaks[m[0]], target_peaks[m[1]]]) for m in matches], + data=[np.asarray([mov_peaks[m[0]], ref_peaks[m[1]]]) for m in matches], shape_type='line', edge_width=5, blending='additive', @@ -178,39 +167,39 @@ # %% Register LS data using compount transform +from biahub.core.transform import Transform +initial_transform = Transform( + matrix=np.asarray(config.affine_transform_settings.approx_transform) + ) -tform = estimate_transform( - matches, - source_peaks, - target_peaks, - config.affine_transform_settings, - verbose=True +fwd_transform, inv_transform = transform_from_matches( + matches=matches, + mov_peaks=mov_peaks, + ref_peaks=ref_peaks, + affine_transform_settings=config.affine_transform_settings, + ndim=mov_data.ndim, ) - -compount_tform = tform.inverse.params -source_data_ants = ants.from_numpy(source_data) -target_data_ants = ants.from_numpy(target_data) -compount_tform_ants = convert_transform_to_ants(compount_tform) -source_data_reg_2 = compount_tform_ants.apply_to_image( - source_data_ants, reference=target_data_ants +mov_data_ants = ants.from_numpy(mov_data) +ref_data_ants = ants.from_numpy(ref_data) +mov_data_reg = inv_transform.to_ants().apply_to_image( + mov_data_ants, reference=ref_data_ants ).numpy() if visualize: viewer = napari.Viewer() - viewer.add_image(target_data, name='Target data', 
contrast_limits=(-0.5, 1.0)) + viewer.add_image(ref_data, name='LF', contrast_limits=(-0.5, 1.0)) viewer.add_image( - source_data, - name='Source data', + mov_data, + name='LS approx registered', contrast_limits=(110, 230), blending='additive', colormap='green' ) viewer.add_image( - source_data_reg_2, - name='Source data registered', + mov_data_reg, + name='LS registered', contrast_limits=(110, 230), blending='additive', colormap='magenta' ) -# %% diff --git a/settings/example_estimate_registration_settings_beads.yml b/settings/example_estimate_registration_settings_beads.yml index fa7f9dd5..6e4bb6a8 100644 --- a/settings/example_estimate_registration_settings_beads.yml +++ b/settings/example_estimate_registration_settings_beads.yml @@ -12,7 +12,15 @@ estimation_method: beads beads_match_settings: algorithm: hungarian # "hungarian", "match_descriptor" - t_reference: "first" # Reference timepoint for matching + qc_settings: + iterations: 1 # Number of quality control iterations + score_threshold: 0.40 # Quality control score threshold + score_centroid_mask_radius: 6 # Radius of the centroid mask for quality control + filter_matches_settings: + min_distance_quantile: 0.01 # Filter matches based on distance threshold + max_distance_quantile: 0.99 # Filter matches based on distance threshold + angle_threshold: 0 # Filter matches based on angle threshold + direction_threshold: 0 # Filter matches based on direction threshold source_peaks_settings: threshold_abs: 110 # intensity threshold for peak detection nms_distance: 16 # non-maximum suppression radius @@ -51,12 +59,12 @@ beads_match_settings: pca_dir: 0.0 # Principal Component Analysis direction weight pca_aniso: 0.0 # Principal Component Analysis anisotropy weight edge_descriptor: 0.0 # Edge descriptor weight - filter_distance_threshold: 0.95 # Filter matches based on distance threshold - filter_angle_threshold: 0 # Filter matches based on angle threshold # Settings for affine transformation estimation (applies to 
all stabilization methods) affine_transform_settings: - transform_type: similarity # Type of transform to estimate "euclidean", "similarity", or "affine" + use_prev_t_transform: True # Use the previous timepoint transform as the initial transform + compute_approx_transform: False # Compute the approximate transform using the beads match settings + transform_type: affine # Type of transform to estimate "euclidean", "similarity", or "affine" approx_transform: # Optional initial guess for the 4x4 affine transformation matrix (if None, will be estimated) - - 1 - 0 diff --git a/settings/example_estimate_stabilization_settings_xyz_beads.yml b/settings/example_estimate_stabilization_settings_xyz_beads.yml index dcb8053f..54c25ee7 100644 --- a/settings/example_estimate_stabilization_settings_xyz_beads.yml +++ b/settings/example_estimate_stabilization_settings_xyz_beads.yml @@ -18,8 +18,16 @@ stabilization_type: "xyz" stabilization_method: "beads" beads_match_settings: + qc_settings: + iterations: 1 # Number of quality control iterations + score_threshold: 0.40 # Quality control score threshold + score_centroid_mask_radius: 6 # Radius of the centroid mask for quality control algorithm: hungarian # "hungarian", "match_descriptor" - t_reference: "first" # Reference timepoint for matching + filter_matches_settings: + min_distance_quantile: 0.01 # Filter matches based on distance threshold + max_distance_quantile: 0.99 # Filter matches based on distance threshold + angle_threshold: 0 # Filter matches based on angle threshold + direction_threshold: 0 # Filter matches based on direction threshold source_peaks_settings: threshold_abs: 110 # intensity threshold for peak detection nms_distance: 16 # non-maximum suppression radius @@ -57,12 +65,11 @@ beads_match_settings: pca_dir: 0.0 # Principal Component Analysis direction weight (0.0 means no direction) pca_aniso: 0.0 # Principal Component Analysis anisotropy weight (0.0 means no anisotropy) edge_descriptor: 0.0 # Edge descriptor 
weight (0.0 means no descriptor) - filter_distance_threshold: 0.95 # Filter matches based on distance threshold - filter_angle_threshold: 0 # Filter matches based on dominant angle direction (in degrees) # Settings for affine transformation estimation (applies to all stabilization methods) affine_transform_settings: + t_reference: "first" # Reference timepoint for transform transform_type: euclidean # Type of transform to estimate "euclidean", "similarity", or "affine" # Validation and interpolation settings for transform smoothing