Skip to content

Commit 605b8e4

Browse files
author
anna-grim
committed
feat: leaf-branch proposals
1 parent 389a958 commit 605b8e4

4 files changed

Lines changed: 143 additions & 31 deletions

File tree

src/neuron_proofreader/proposal_graph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,23 @@ def add_proposal(self, i, j):
193193
self.node_proposals[j].add(i)
194194
self.proposals.add(proposal)
195195

196-
def generate_proposals(self, search_radius):
196+
def generate_proposals(self, search_radius, allow_nonleaf_targets=False):
197197
"""
198198
Generates proposals from leaf nodes.
199199
200200
Parameters
201201
----------
202202
search_radius : float
203203
Search radius used to generate proposals.
204+
allow_nonleaf_targets : bool, optional
205+
Indication of whether to generate proposals between leaf and nodes
206+
with degree 2. Default is False.
204207
"""
205208
# Proposal generation
206-
proposals = self.proposal_generator(search_radius)
209+
proposals = self.proposal_generator(
210+
search_radius, allow_nonleaf_targets=allow_nonleaf_targets
211+
)
212+
207213
self.search_radius = search_radius
208214
self.store_proposals(proposals)
209215
self.trim_proposals()

src/neuron_proofreader/skeleton_graph.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020
import zipfile
2121

22-
from neuron_proofreader.utils import graph_util as gutil, img_util, util
22+
from neuron_proofreader.utils import geometry_util, graph_util, img_util, util
2323

2424

2525
class SkeletonGraph(nx.Graph):
@@ -82,7 +82,7 @@ def __init__(
8282

8383
# Graph Loader
8484
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
85-
self.graph_loader = gutil.GraphLoader(
85+
self.graph_loader = graph_util.GraphLoader(
8686
anisotropy=anisotropy,
8787
min_size=min_size,
8888
node_spacing=node_spacing,
@@ -686,15 +686,15 @@ def nodes_with_segment_id(self, segment_id):
686686
)
687687
return nodes
688688

689-
def nodes_within_distance(self, root, max_depth):
689+
def nodes_within_distance(self, root, max_dist):
690690
"""
691691
Gets nodes connected to the given root node up to a certain depth.
692692
693693
Parameters
694694
----------
695695
root : int
696696
Node ID and root of search.
697-
max_depth : float
697+
max_dist : float
698698
Maximum distance (in microns) between returned nodes and root.
699699
700700
Returns
@@ -711,7 +711,7 @@ def nodes_within_distance(self, root, max_depth):
711711
# Populate queue
712712
for j in self.neighbors(i):
713713
dist_j = dist_i + self.dist(i, j)
714-
if dist_j < max_depth and j not in visited:
714+
if dist_j < max_dist and j not in visited:
715715
queue.append((j, dist_j))
716716
visited.add(j)
717717
return list(visited)
@@ -750,6 +750,25 @@ def path_from_leaf(self, leaf, max_depth=np.inf):
750750
path.append(j)
751751
return path
752752

753+
def path_length(self, path):
754+
"""
755+
Computes the length of the given path.
756+
757+
Parameters
758+
----------
759+
path : List[int]
760+
List of nodes that forms a path.
761+
762+
Returns
763+
-------
764+
Length of the given path.
765+
"""
766+
if len(path) > 1:
767+
diffs = self.node_xyz[path[1:]] - self.node_xyz[path[:-1]]
768+
return np.sqrt(np.sum(diffs ** 2))
769+
else:
770+
return 0
771+
753772
def rooted_subgraph(self, root, radius):
754773
"""
755774
Gets a rooted subgraph with the given radius (in microns).
@@ -795,25 +814,6 @@ def rooted_subgraph(self, root, radius):
795814
subgraph.node_xyz = self.node_xyz[idxs]
796815
return subgraph
797816

798-
def path_length(self, path):
799-
"""
800-
Computes the length of the given path.
801-
802-
Parameters
803-
----------
804-
path : List[int]
805-
List of nodes that forms a path.
806-
807-
Returns
808-
-------
809-
Length of the given path.
810-
"""
811-
if len(path) > 1:
812-
diffs = self.node_xyz[path[1:]] - self.node_xyz[path[:-1]]
813-
return np.sqrt(np.sum(diffs ** 2))
814-
else:
815-
return 0
816-
817817
def set_kdtree(self):
818818
"""
819819
Initializes KD-Tree from node xyz coordinates.
@@ -852,3 +852,20 @@ def summary(self, prefix=""):
852852
summary.append(f"# Edges: {n_edges}")
853853
summary.append(f"Memory Consumption: {memory:.2f} GBs")
854854
return "\n".join(summary)
855+
856+
def tangent_from_leaf(self, leaf, max_depth=np.inf):
857+
"""
858+
Computes the tangent vector of the path emanating from the given leaf.
859+
860+
Parameters
861+
----------
862+
leaf : int
863+
Node ID and starting point of path extracted.
864+
max_depth : float
865+
Maximum depth (in microns) of path extracted.
866+
867+
Returns
868+
-------
869+
"""
870+
path = self.path_from_leaf(leaf, max_depth=max_depth)
871+
return geometry_util.tangent(self.node_xyz[np.array(path)])

src/neuron_proofreader/split_proofreading/proposal_generation.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def __call__(self, initial_radius, allow_nonleaf_targets=False):
7171
with degree 2. Default is False.
7272
"""
7373
# Initializations
74-
self.set_kdtree()
7574
self.allow_nonleaf_targets = allow_nonleaf_targets
75+
self.set_kdtree()
7676
iterator = self.graph.leaf_nodes()
7777
if self.graph.verbose:
7878
iterator = tqdm(iterator, desc="Proposal Generation")
@@ -141,7 +141,7 @@ def find_node_candidates(self, leaf, radius):
141141
"""
142142
node_candidates = list()
143143
for node in self.get_nearby_nodes(leaf, radius):
144-
# UPDATE - check whether to move to leaf if allowing leaf-branch connections
144+
node = self.adjust_position(leaf, node, radius)
145145
if self.is_valid_proposal(leaf, node):
146146
node_candidates.append(node)
147147
return node_candidates
@@ -178,6 +178,73 @@ def get_nearby_nodes(self, leaf, radius):
178178
return [val["node"] for val in pts_dict.values()]
179179

180180
# --- Helpers ---
181+
def adjust_position(self, leaf, node, radius):
182+
"""
183+
Adjusts a candidate node by searching for a nearby leaf, or falling
184+
back to angular alignment if none is found.
185+
186+
Parameters
187+
----------
188+
leaf : int
189+
Node ID of the leaf node.
190+
node : int
191+
Node ID of the candidate node.
192+
radius : float
193+
Maximum path distance (in microns) for searching neighboring
194+
nodes.
195+
196+
Returns
197+
-------
198+
node : int
199+
Node ID of candidate.
200+
"""
201+
# Check if node is close to another leaf
202+
queue = [(node, 0)]
203+
visited = {node}
204+
while queue:
205+
# Visit node
206+
i, dist_i = queue.pop()
207+
if self.graph.degree[i] == 1:
208+
return i
209+
210+
# Update queue
211+
for j in self.graph.neighbors(i):
212+
dist_j = dist_i + self.graph.dist(i, j)
213+
if dist_j < radius and j not in visited:
214+
queue.append((j, dist_j))
215+
visited.add(j)
216+
return self.maximize_branch_angle(leaf, node)
217+
218+
def maximize_branch_angle(self, leaf, node):
219+
"""
220+
Selects a nearby node that maximizes the angular alignment with a
221+
leaf's tangent.
222+
223+
Parameters
224+
----------
225+
leaf : int
226+
Node ID of the leaf node.
227+
node : int
228+
Node ID of the candidate node. The search for better candidates is
229+
restricted to nodes within a fixed distance of this node.
230+
231+
Returns
232+
-------
233+
node : int
234+
Node ID of the node that maximizes angular alignment with the leaf
235+
tangent.
236+
"""
237+
max_inner_product = 0
238+
leaf_tangent = self.graph.tangent_from_leaf(leaf, max_depth=20)
239+
for i in self.graph.nodes_within_distance(node, 30):
240+
pts = [self.graph.node_xyz[i], self.graph.node_xyz[leaf]]
241+
proposal_tangent = geometry_util.tangent(pts)
242+
inner_product = abs(np.sum(leaf_tangent * proposal_tangent))
243+
if inner_product > max_inner_product:
244+
max_inner_product = inner_product
245+
node = i
246+
return node
247+
181248
def is_valid_proposal(self, leaf, i):
182249
"""
183250
Determines whether a pair of nodes satisfies the following:
@@ -226,10 +293,8 @@ def query_nbhd(self, node, radius):
226293
else:
227294
nodes = list()
228295
for idx in self.kdtree.query_ball_point(xyz, radius):
229-
xyz = self.kdtree.data[idx]
230-
node = self.graph.closest_node(xyz)
296+
node = self.graph.closest_node(self.kdtree.data[idx])
231297
nodes.append(node)
232-
assert self.graph.degree[node] == 1
233298
return nodes
234299

235300
def select_closest_components(self, pts_dict):

src/neuron_proofreader/utils/ml_util.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def build_network(self, input_dim, output_dim, n_layers):
6363

6464
@staticmethod
6565
def _init_weights(m):
66+
"""
67+
Initializes weights for linear layers using Kaiming initialization.
68+
69+
Parameters
70+
----------
71+
m : torch.nn.Module
72+
Module to initialize.
73+
"""
6674
if isinstance(m, nn.Linear):
6775
nn.init.kaiming_normal_(m.weight, nonlinearity="leaky_relu")
6876
if m.bias is not None:
@@ -116,8 +124,24 @@ def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1):
116124

117125
# --- Data Structures ---
118126
class TensorDict(dict):
127+
"""
128+
A class for model inputs in a dictionary.
129+
"""
119130

120131
def to(self, device):
132+
"""
133+
Moves dictionary values to the specified GPU device.
134+
135+
Parameters
136+
----------
137+
device : str
138+
Name of GPU device to move inputs to.
139+
140+
Returns
141+
-------
142+
TensorDict
143+
Dictionary with values moved to the specified GPU device.
144+
"""
121145
return TensorDict({k: self.move(v, device) for k, v in self.items()})
122146

123147
def move(self, v, device):

0 commit comments

Comments
 (0)