From a2957bfd669d65ba1ce0cc80e8867cc4367987fa Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 21 May 2026 17:55:15 +0000 Subject: [PATCH] feat: save merge detection model preds --- .../merge_proofreading/merge_inference.py | 82 +++++++++++-------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index b07ddf8e..73bb319d 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -19,6 +19,7 @@ import networkx as nx import numpy as np import os +import pandas as pd import torch from neuron_proofreader.machine_learning.point_cloud_models import ( @@ -47,7 +48,7 @@ def __init__( # Instance attributes self.dataset = dataset self.device = device - self.node_preds = np.ones((len(dataset.graph.node_xyz))) * 1e-2 + self.node_preds = np.ones((len(dataset.node_xyz))) * 1e-2 self.patch_shape = dataset.patch_shape self.remove_detected_sites = remove_detected_sites self.threshold = threshold @@ -111,7 +112,7 @@ def filter_with_nms(self, merge_sites, likelihoods): while merge_sites: # Local max root = merge_sites.pop() - xyz_root = self.dataset.graph.node_xyz[root] + xyz_root = self.dataset.node_xyz[root] if root in merge_sites_set: filtered_merge_sites.add(root) merge_sites_set.remove(root) @@ -125,7 +126,7 @@ def filter_with_nms(self, merge_sites, likelihoods): # Visit node i, dist_i = queue.pop() if i in merge_sites_set: - xyz_i = self.dataset.graph.node_xyz[i] + xyz_i = self.dataset.node_xyz[i] iou = img_util.compute_iou3d( xyz_i, xyz_root, self.patch_shape, self.patch_shape ) @@ -134,8 +135,8 @@ def filter_with_nms(self, merge_sites, likelihoods): self.node_preds[i] = 1e-2 # Populate queue - for j in self.dataset.graph.neighbors(i): - dist_j = dist_i + self.dataset.graph.dist(i, j) + for j in self.dataset.neighbors(i): + dist_j = dist_i + self.dataset.dist(i, j) if j not in visited and dist_j < self.patch_shape[0]: queue.append((j, dist_j)) visited.add(j) @@ -145,26 +146,26 @@ def remove_merge_sites(self, merge_site_nodes, max_depth=10): rm_nodes = set() for root in tqdm(merge_site_nodes, desc="Remove Merge Sites"): # Extract neighborhood - root = self.dataset.graph.find_nearby_branching_node(root) - nbhd = self.dataset.graph.nodes_within_distance(root, max_depth) + root = self.dataset.find_nearby_branching_node(root) + nbhd = self.dataset.nodes_within_distance(root, max_depth) # Check for branching node in neighborhood for i in list(nbhd): - if i != root and self.dataset.graph.degree[i] >= 3: - nbhd_i = self.dataset.graph.nodes_within_distance(root, 8) + if i != root and self.dataset.degree[i] >= 3: + nbhd_i = self.dataset.nodes_within_distance(root, 8) nbhd.extend(nbhd_i) # Add nodes to removal list rm_nodes.update(set(nbhd)) # Update graph - self.dataset.graph.remove_nodes(rm_nodes) + self.dataset.remove_nodes(rm_nodes) print("# Nodes Deleted:", len(rm_nodes)) # --- Helpers --- def get_detected_sites(self, threshold): nodes = np.where(self.node_preds >= threshold)[0] - return [self.dataset.graph.node_xyz[i] for i in nodes] + return [self.dataset.node_xyz[i] for i in nodes] def save_parameters(self, output_dir): json_path = os.path.join(output_dir, "detection_parameters.json") @@ -185,7 +186,7 @@ def save_results( self.save_sites(output_dir) if save_fragments: fragments_path = os.path.join(output_dir, "fragments.zip") - self.dataset.graph.to_zipped_swcs(fragments_path) + self.dataset.to_zipped_swcs(fragments_path) # Upload results to S3 (if applicable) if output_prefix_s3: @@ -193,9 +194,18 @@ def save_results( util.upload_dir_to_s3(output_dir, bucket_name, prefix) def save_sites(self, output_dir): + # Save model predictions + df = pd.DataFrame(columns=["World", "Segment_ID", "Prediction"]) + df["World"] = self.dataset.node_xyz + df["Prediction"] = self.node_preds + df["Segment_ID"] = [ + self.dataset.node_segment_id(i) for i in self.dataset.nodes + ] + df.to_csv(os.path.join(output_dir, "model_predictions.csv")) + # Get predicted merge sites nodes = np.where(self.node_preds >= self.threshold)[0] - detected_sites = [self.dataset.graph.node_xyz[i] for i in nodes] + detected_sites = [self.dataset.node_xyz[i] for i in nodes] print("# Sites Saved:", len(nodes)) # Save predicted merge sites @@ -213,14 +223,14 @@ def save_train_dataset(self, output_dir): roots = list() visited_ids = set() for i in np.where(self.node_preds >= self.threshold)[0]: - cc_id = self.dataset.graph.node_component_id[i] + cc_id = self.dataset.node_component_id[i] if cc_id not in visited_ids: roots.append([i]) visited_ids.add(cc_id) # Save fragments zip_path = os.path.join(output_dir, "fragments.zip") - self.dataset.graph._batch_to_zipped_swcs(roots, zip_path, False) + self.dataset._batch_to_zipped_swcs(roots, zip_path, False) self.save_sites(output_dir) print("# Fragments Saved:", len(roots)) @@ -279,7 +289,7 @@ def __iter__(self): # Search graph visited_ids = set() for u in self.graph.leaf_nodes(): - component_id = self.graph.node_component_id[u] + component_id = self.node_component_id[u] if component_id not in visited_ids and component_id in valid_ids: visited_ids.add(component_id) yield from self._generate_batches_from_component(u) @@ -311,11 +321,11 @@ def find_fragments_to_search(self): # Check if path length satisfies threshold if length > self.min_size: - component_ids.add(self.graph.node_component_id[node]) + component_ids.add(self.node_component_id[node]) return component_ids def get_patch_centers(self, nodes): - patch_centers = [self.graph.node_voxel(i) for i in nodes] + patch_centers = [self.node_voxel(i) for i in nodes] return np.array(patch_centers, dtype=int) def get_label_mask(self, nodes, img_shape, offset): @@ -331,8 +341,8 @@ def get_label_mask(self, nodes, img_shape, offset): # Annotate mask subgraph = self.get_contained_subgraph(nodes, img_shape, offset) for i, j in subgraph.edges: - voxel_i = self.graph.node_voxel(i) - offset - voxel_j = self.graph.node_voxel(j) - offset + voxel_i = self.node_voxel(i) - offset + voxel_j = self.node_voxel(j) - offset voxels = geometry_util.make_digital_line(voxel_i, voxel_j) img_util.annotate_voxels(segment_mask, voxels) return segment_mask @@ -344,13 +354,13 @@ def get_contained_subgraph(self, nodes, img_shape, offset): while queue: # Visit node i = queue.pop() - voxel_i = self.graph.node_voxel(i) - offset + voxel_i = self.node_voxel(i) - offset if not img_util.is_contained(voxel_i, img_shape, buffer=1): continue # Update queue - for j in self.graph.neighbors(i): - voxel_j = self.graph.node_voxel(j) - offset + for j in self.neighbors(i): + voxel_j = self.node_voxel(j) - offset if img_util.is_contained(voxel_j, img_shape): subgraph.add_edge(i, j) if j not in visited: @@ -359,7 +369,7 @@ def get_contained_subgraph(self, nodes, img_shape, offset): return subgraph def is_contained(self, node): - voxel = self.graph.node_voxel(node) + voxel = self.node_voxel(node) shape = self.img_reader.shape()[2::] buffer = np.max(self.patch_shape) + 1 return img_util.is_contained(voxel, shape, buffer=buffer) @@ -380,7 +390,7 @@ def read_superchunk(self, nodes): def is_near_leaf(self, node, threshold=20): # Check if node is branching - if self.graph.degree[node] > 2: + if self.degree[node] > 2: return False # Search neighborhood @@ -389,12 +399,12 @@ def is_near_leaf(self, node, threshold=20): while len(queue) > 0: # Visit node i, dist_i = queue.pop() - if self.graph.degree[i] == 1: + if self.degree[i] == 1: return True # Update queue - for j in self.graph.neighbors(i): - dist_j = dist_i + self.graph.dist(i, j) + for j in self.neighbors(i): + dist_j = dist_i + self.dist(i, j) if j not in visited and dist_j < threshold: queue.append((j, dist_j)) visited.add(j) @@ -489,7 +499,7 @@ def _generate_batch_nodes(self, root): nodes = list() for i, j in nx.dfs_edges(self.graph, source=root): # Check if starting new batch - self.distance_traversed += self.graph.dist(i, j) + self.distance_traversed += self.dist(i, j) if len(nodes) == 0: if self.is_node_valid(i): root = i @@ -499,7 +509,7 @@ def _generate_batch_nodes(self, root): continue # Check whether to yield batch - is_node_far = self.graph.dist(root, j) > 512 + is_node_far = self.dist(root, j) > 512 is_batch_full = len(nodes) == self.batch_size if is_node_far or is_batch_full: # Yield nodes in batch @@ -509,8 +519,8 @@ def _generate_batch_nodes(self, root): nodes = list() # Visit j - is_next = self.graph.dist(last_node, j) >= self.step_size - 2 - is_branching = self.graph.degree[j] >= 3 + is_next = self.dist(last_node, j) >= self.step_size - 2 + is_branching = self.degree[j] >= 3 if (is_next or is_branching) and self.is_node_valid(j): last_node = j nodes.append(j) @@ -560,6 +570,9 @@ def _get_multimodal_batch(self, nodes, img, offset): return nodes, batch # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) + def estimate_iterations(self): """ Estimates the number of iterations required to search graph. @@ -573,7 +586,7 @@ def estimate_iterations(self): total_cable_length = 0 n_fragments = 0 for nodes in map(list, nx.connected_components(self.graph)): - cable_length = self.graph.cable_length(root=nodes[0]) + cable_length = self.cable_length(root=nodes[0]) if cable_length > self.min_size: total_cable_length += cable_length n_fragments += 1 @@ -651,6 +664,9 @@ def _generate_batch_nodes(self, root): root = j # --- Helpers --- + def __getattr__(self, name): + return getattr(self.graph, name) + def estimate_iterations(self): """ Estimates the number of iterations required to search graph.