Skip to content

Commit 863665e

Browse files
author
anna-grim
committed
refactor: updated proposal trimming
1 parent efb46be commit 863665e

10 files changed

Lines changed: 136 additions & 188 deletions

File tree

src/neuron_proofreader/merge_proofreading/merge_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __getitem__(self, idx):
317317
"""
318318
# Get example
319319
brain_id, subgraph, label = self.get_site(idx)
320-
voxel = subgraph.get_voxel(0)
320+
voxel = subgraph.node_voxel(0)
321321

322322
# Extract subgraph and image patches centered at site
323323
img_patch = self.get_img_patch(brain_id, voxel)
@@ -509,7 +509,7 @@ def get_segment_mask(self, brain_id, center, subgraph):
509509
segment_mask = np.zeros(self.patch_shape)
510510

511511
# Annotate fragment
512-
center = subgraph.get_voxel(0)
512+
center = subgraph.node_voxel(0)
513513
offset = img_util.get_offset(center, self.patch_shape)
514514
for node1, node2 in subgraph.edges:
515515
# Get local voxel coordinates

src/neuron_proofreader/merge_proofreading/merge_inference.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __iter__(self):
251251

252252
# Search graph
253253
visited_ids = set()
254-
for u in self.graph.get_leafs():
254+
for u in self.graph.leaf_nodes():
255255
component_id = self.graph.node_component_id[u]
256256
if component_id not in visited_ids and component_id in valid_ids:
257257
visited_ids.add(component_id)
@@ -288,7 +288,7 @@ def find_fragments_to_search(self):
288288
return component_ids
289289

290290
def get_patch_centers(self, nodes):
291-
patch_centers = [self.graph.get_voxel(i) for i in nodes]
291+
patch_centers = [self.graph.node_voxel(i) for i in nodes]
292292
return np.array(patch_centers, dtype=int)
293293

294294
def get_label_mask(self, nodes, img_shape, offset):
@@ -304,8 +304,8 @@ def get_label_mask(self, nodes, img_shape, offset):
304304
# Annotate mask
305305
subgraph = self.get_contained_subgraph(nodes, img_shape, offset)
306306
for i, j in subgraph.edges:
307-
voxel_i = self.graph.get_voxel(i) - offset
308-
voxel_j = self.graph.get_voxel(j) - offset
307+
voxel_i = self.graph.node_voxel(i) - offset
308+
voxel_j = self.graph.node_voxel(j) - offset
309309
voxels = geometry_util.make_digital_line(voxel_i, voxel_j)
310310
img_util.annotate_voxels(segment_mask, voxels)
311311
return segment_mask
@@ -317,13 +317,13 @@ def get_contained_subgraph(self, nodes, img_shape, offset):
317317
while queue:
318318
# Visit node
319319
i = queue.pop()
320-
voxel_i = self.graph.get_voxel(i) - offset
320+
voxel_i = self.graph.node_voxel(i) - offset
321321
if not img_util.is_contained(voxel_i, img_shape, buffer=1):
322322
continue
323323

324324
# Update queue
325325
for j in self.graph.neighbors(i):
326-
voxel_j = self.graph.get_voxel(j) - offset
326+
voxel_j = self.graph.node_voxel(j) - offset
327327
if img_util.is_contained(voxel_j, img_shape):
328328
subgraph.add_edge(i, j)
329329
if j not in visited:
@@ -332,7 +332,7 @@ def get_contained_subgraph(self, nodes, img_shape, offset):
332332
return subgraph
333333

334334
def is_contained(self, node):
335-
voxel = self.graph.get_voxel(node)
335+
voxel = self.graph.node_voxel(node)
336336
shape = self.img_reader.shape()[2::]
337337
buffer = np.max(self.patch_shape) + 1
338338
return img_util.is_contained(voxel, shape, buffer=buffer)
@@ -601,7 +601,7 @@ def _generate_batch_nodes(self, root):
601601
if len(patch_centers) == 0 and self.graph.degree[i] > 2:
602602
root = i
603603
nodes.append(i)
604-
patch_centers.append(self.graph.get_voxel(i))
604+
patch_centers.append(self.graph.node_voxel(i))
605605

606606
# Check whether to yield batch
607607
is_node_far = self.graph.dist(root, j) > 256
@@ -619,7 +619,7 @@ def _generate_batch_nodes(self, root):
619619
# Visit j
620620
if self.graph.degree[j] > 2:
621621
nodes.append(j)
622-
patch_centers.append(self.graph.get_voxel(j))
622+
patch_centers.append(self.graph.node_voxel(j))
623623
if len(patch_centers) == 1:
624624
root = j
625625

src/neuron_proofreader/new_proposal_graph.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,8 @@ def __init__(
104104
self.merged_ids = set()
105105
self.n_merges_blocked = 0
106106
self.n_proposals_blocked = 0
107-
self.node_proposals = defaultdict(set)
108-
self.proposals = set()
109107

108+
self.reset_proposals()
110109
self.proposal_generator = ProposalGenerator(
111110
self,
112111
max_proposals_per_leaf=max_proposals_per_leaf,
@@ -163,6 +162,17 @@ def connect_soma_fragments(self, soma_centroids):
163162
results.append(f"# Soma Fragments Merged: {merge_cnt}")
164163
return "\n".join(results)
165164

165+
def relabel_nodes(self):
166+
# Call parent class
167+
old_proposals = self.list_proposals()
168+
old_to_new = super().relabel_nodes()
169+
170+
# Update proposals
171+
self.reset_proposals()
172+
for proposal in old_proposals:
173+
i, j = proposal
174+
self.add_proposal(int(old_to_new[i]), int(old_to_new[j]))
175+
166176
# --- Proposal Operations ---
167177
def add_proposal(self, i, j):
168178
"""
@@ -198,15 +208,16 @@ def generate_proposals(self, search_radius):
198208
# Proposal generation
199209
proposals = self.proposal_generator(search_radius)
200210
self.store_proposals(proposals)
201-
#self.trim_proposals() TEMP
211+
self.trim_proposals()
212+
self.relabel_nodes()
202213

203214
# Set groundtruth
204215
if self.gt_path:
205216
gt_graph = SkeletonGraph(anisotropy=self.anisotropy)
206217
gt_graph.load(self.gt_path)
207218
self.gt_accepts = groundtruth_generation.run(gt_graph, self)
208219

209-
def get_sorted_proposals(self):
220+
def sorted_proposals(self):
210221
"""
211222
Return proposals sorted by physical length.
212223
@@ -324,6 +335,10 @@ def remove_proposal(self, proposal):
324335
self.node_proposals[j].remove(i)
325336
self.proposals.remove(proposal)
326337

338+
def reset_proposals(self):
339+
self.node_proposals = defaultdict(set)
340+
self.proposals = set()
341+
327342
def store_proposals(self, proposals):
328343
self.node_proposals = defaultdict(set)
329344
for proposal in proposals:

src/neuron_proofreader/skeleton_graph.py

Lines changed: 73 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def reassign_component_ids(self):
225225
component_id_to_swc_id = dict()
226226
for i, nodes in enumerate(nx.connected_components(self)):
227227
nodes = np.array(list(nodes), dtype=int)
228-
component_id_to_swc_id[i + 1] = self.get_swc_id(nodes[0])
228+
component_id_to_swc_id[i + 1] = self.node_swc_id(nodes[0])
229229
self.node_component_id[nodes] = i + 1
230230
self.component_id_to_swc_id = component_id_to_swc_id
231231

@@ -240,12 +240,11 @@ def relabel_nodes(self):
240240
# Set edge ids
241241
old_to_new = dict(zip(old_node_ids, new_node_ids))
242242
old_edge_ids = list(self.edges)
243-
edge_attrs = {(i, j): data for i, j, data in self.edges(data=True)}
244243

245244
# Reset graph
246245
self.clear()
247246
for (i, j) in old_edge_ids:
248-
self.add_edge(old_to_new[i], old_to_new[j], **edge_attrs[(i, j)])
247+
self.add_edge(old_to_new[i], old_to_new[j])
249248

250249
# Update attributes
251250
self.node_radius = self.node_radius[old_node_ids]
@@ -255,6 +254,7 @@ def relabel_nodes(self):
255254
self.reassign_component_ids()
256255
if self.kdtree:
257256
self.set_kdtree()
257+
return old_to_new
258258

259259
def remove_nodes(self, nodes, relabel_nodes=True):
260260
"""
@@ -272,7 +272,7 @@ def remove_nodes(self, nodes, relabel_nodes=True):
272272
self.relabel_nodes()
273273

274274
# --- Getters ---
275-
def get_branchings(self):
275+
def branching_nodes(self):
276276
"""
277277
Gets all branching nodes in the graph.
278278
@@ -283,7 +283,13 @@ def get_branchings(self):
283283
"""
284284
return [i for i in self.nodes if self.degree[i] > 2]
285285

286-
def get_connected_nodes(self, root):
286+
def component_id_from_swc_id(self, query_swc_id):
287+
for component_id, swc_id in self.component_id_to_swc_id.items():
288+
if query_swc_id == swc_id:
289+
return component_id
290+
raise ValueError(f"SWC ID={query_swc_id} not found")
291+
292+
def connected_nodes(self, root):
287293
"""
288294
Gets all nodes connected to the given root node.
289295
@@ -307,7 +313,7 @@ def get_connected_nodes(self, root):
307313
visited.add(j)
308314
return visited
309315

310-
def get_leafs(self):
316+
def leaf_nodes(self):
311317
"""
312318
Gets all leaf nodes in the graph.
313319
@@ -318,23 +324,7 @@ def get_leafs(self):
318324
"""
319325
return [i for i in self.nodes if self.degree[i] == 1]
320326

321-
def get_voxel(self, i):
322-
"""
323-
Gets the voxel coordinate of the given node.
324-
325-
Parameters
326-
----------
327-
i : int
328-
Node ID.
329-
330-
Returns
331-
-------
332-
float
333-
Voxel coordinate of the given node.
334-
"""
335-
return img_util.to_voxels(self.node_xyz[i], self.anisotropy)
336-
337-
def get_local_voxel(self, node, offset):
327+
def node_local_voxel(self, node, offset):
338328
"""
339329
Computes the local voxel coordinate of the given node within the image
340330
patch defined by "center" and "patch_shape".
@@ -352,10 +342,9 @@ def get_local_voxel(self, node, offset):
352342
Local voxel coordinate of the given node within the image patch
353343
defined by "center" and "patch_shape".
354344
"""
355-
voxel = self.get_voxel(node)
356-
return tuple([v - o for v, o in zip(voxel, offset)])
345+
return tuple([v - o for v, o in zip(self.node_voxel(node), offset)])
357346

358-
def get_node_segment_id(self, node):
347+
def node_segment_id(self, node):
359348
"""
360349
Gets the segment ID corresponding to the given node.
361350
@@ -371,7 +360,40 @@ def get_node_segment_id(self, node):
371360
"""
372361
return self.get_swc_id(node).split(".")[0]
373362

374-
def get_nodes_with_component_id(self, component_id):
363+
def node_swc_id(self, i):
364+
"""
365+
Gets the SWC ID of the given node.
366+
367+
Parameters
368+
----------
369+
i : int
370+
Node ID.
371+
372+
Returns
373+
-------
374+
str
375+
SWC ID of the given node.
376+
"""
377+
component_id = self.node_component_id[i]
378+
return self.component_id_to_swc_id[component_id]
379+
380+
def node_voxel(self, i):
381+
"""
382+
Gets the voxel coordinate of the given node.
383+
384+
Parameters
385+
----------
386+
i : int
387+
Node ID.
388+
389+
Returns
390+
-------
391+
float
392+
Voxel coordinate of the given node.
393+
"""
394+
return img_util.to_voxels(self.node_xyz[i], self.anisotropy)
395+
396+
def nodes_with_component_id(self, component_id):
375397
"""
376398
Gets all nodes with the given componenet ID.
377399
@@ -387,7 +409,7 @@ def get_nodes_with_component_id(self, component_id):
387409
"""
388410
return set(np.where(self.node_component_id == component_id)[0])
389411

390-
def get_nodes_with_segment_id(self, segment_id):
412+
def nodes_with_segment_id(self, segment_id):
391413
"""
392414
Gets all nodes with the given segment ID.
393415
@@ -406,19 +428,13 @@ def get_nodes_with_segment_id(self, segment_id):
406428
for swc_id in self.get_swc_ids():
407429
segment_id = int(swc_id.replace(".0", ""))
408430
if segment_id == query_id:
409-
component_id = self.get_component_id_from_swc_id(swc_id)
431+
component_id = self.component_id_from_swc_id(swc_id)
410432
nodes = nodes.union(
411-
self.get_nodes_with_component_id(component_id)
433+
self.nodes_with_component_id(component_id)
412434
)
413435
return nodes
414436

415-
def get_component_id_from_swc_id(self, query_swc_id):
416-
for component_id, swc_id in self.component_id_to_swc_id.items():
417-
if query_swc_id == swc_id:
418-
return component_id
419-
raise ValueError(f"SWC ID={query_swc_id} not found")
420-
421-
def nodes_within_distance(self, root, radius):
437+
def nodes_within_distance(self, root, max_depth):
422438
queue = [(root, 0)]
423439
visited = {root}
424440
while queue:
@@ -428,11 +444,28 @@ def nodes_within_distance(self, root, radius):
428444
# Populate queue
429445
for j in self.neighbors(i):
430446
dist_j = dist_i + self.dist(i, j)
431-
if dist_j < radius and j not in visited:
447+
if dist_j < max_depth and j not in visited:
432448
queue.append((j, dist_j))
433449
visited.add(j)
434450
return list(visited)
435451

452+
def path_from_leaf(self, leaf, max_depth=np.inf):
453+
queue = [(leaf, 0)]
454+
path = [leaf]
455+
while queue:
456+
# Visit node
457+
i, dist_i = queue.pop()
458+
if self.degree[i] != 2 and dist_i > 0:
459+
return path
460+
461+
# Update queue
462+
for j in self.neighbors(i):
463+
dist_j = dist_i + self.dist(i, j)
464+
if dist_j < max_depth and j not in path:
465+
queue.append((j, dist_j))
466+
path.append(j)
467+
return path
468+
436469
def rooted_subgraph(self, root, radius):
437470
"""
438471
Gets a rooted subgraph with the given radius (in microns).
@@ -478,23 +511,6 @@ def rooted_subgraph(self, root, radius):
478511
subgraph.node_xyz = self.node_xyz[idxs]
479512
return subgraph
480513

481-
def get_swc_id(self, i):
482-
"""
483-
Gets the SWC ID of the given node.
484-
485-
Parameters
486-
----------
487-
i : int
488-
Node ID.
489-
490-
Returns
491-
-------
492-
str
493-
SWC ID of the given node.
494-
"""
495-
component_id = self.node_component_id[i]
496-
return self.component_id_to_swc_id[component_id]
497-
498514
def get_swc_ids(self):
499515
"""
500516
Gets the set of all unique SWC IDs of nodes in the graph.
@@ -581,7 +597,7 @@ def write_entry(node, parent):
581597
write_entry(j, i)
582598

583599
# Finish
584-
filename = self.get_swc_id(i)
600+
filename = self.node_swc_id(i)
585601
filename = util.set_zip_path(zip_writer, filename, ".swc")
586602
zip_writer.writestr(filename, text_buffer.getvalue())
587603

@@ -625,7 +641,7 @@ def clip_to_bbox(self, metadata_path):
625641
# Clip graph
626642
nodes = list()
627643
for i in self.nodes:
628-
voxel = np.array(self.get_voxel(i))
644+
voxel = np.array(self.node_voxel(i))
629645
if not img_util.is_contained(voxel - origin, shape):
630646
nodes.append(i)
631647
self.remove_nodes_from(nodes)

0 commit comments

Comments
 (0)