@@ -225,7 +225,7 @@ def reassign_component_ids(self):
225225 component_id_to_swc_id = dict ()
226226 for i , nodes in enumerate (nx .connected_components (self )):
227227 nodes = np .array (list (nodes ), dtype = int )
228- component_id_to_swc_id [i + 1 ] = self .get_swc_id (nodes [0 ])
228+ component_id_to_swc_id [i + 1 ] = self .node_swc_id (nodes [0 ])
229229 self .node_component_id [nodes ] = i + 1
230230 self .component_id_to_swc_id = component_id_to_swc_id
231231
@@ -240,12 +240,11 @@ def relabel_nodes(self):
240240 # Set edge ids
241241 old_to_new = dict (zip (old_node_ids , new_node_ids ))
242242 old_edge_ids = list (self .edges )
243- edge_attrs = {(i , j ): data for i , j , data in self .edges (data = True )}
244243
245244 # Reset graph
246245 self .clear ()
247246 for (i , j ) in old_edge_ids :
248- self .add_edge (old_to_new [i ], old_to_new [j ], ** edge_attrs [( i , j )] )
247+ self .add_edge (old_to_new [i ], old_to_new [j ])
249248
250249 # Update attributes
251250 self .node_radius = self .node_radius [old_node_ids ]
@@ -255,6 +254,7 @@ def relabel_nodes(self):
255254 self .reassign_component_ids ()
256255 if self .kdtree :
257256 self .set_kdtree ()
257+ return old_to_new
258258
259259 def remove_nodes (self , nodes , relabel_nodes = True ):
260260 """
@@ -272,7 +272,7 @@ def remove_nodes(self, nodes, relabel_nodes=True):
272272 self .relabel_nodes ()
273273
274274 # --- Getters ---
275- def get_branchings (self ):
275+ def branching_nodes (self ):
276276 """
277277 Gets all branching nodes in the graph.
278278
@@ -283,7 +283,13 @@ def get_branchings(self):
283283 """
284284 return [i for i in self .nodes if self .degree [i ] > 2 ]
285285
286- def get_connected_nodes (self , root ):
286+ def component_id_from_swc_id (self , query_swc_id ):
287+ for component_id , swc_id in self .component_id_to_swc_id .items ():
288+ if query_swc_id == swc_id :
289+ return component_id
290+ raise ValueError (f"SWC ID={ query_swc_id } not found" )
291+
292+ def connected_nodes (self , root ):
287293 """
288294 Gets all nodes connected to the given root node.
289295
@@ -307,7 +313,7 @@ def get_connected_nodes(self, root):
307313 visited .add (j )
308314 return visited
309315
310- def get_leafs (self ):
316+ def leaf_nodes (self ):
311317 """
312318 Gets all leaf nodes in the graph.
313319
@@ -318,23 +324,7 @@ def get_leafs(self):
318324 """
319325 return [i for i in self .nodes if self .degree [i ] == 1 ]
320326
321- def get_voxel (self , i ):
322- """
323- Gets the voxel coordinate of the given node.
324-
325- Parameters
326- ----------
327- i : int
328- Node ID.
329-
330- Returns
331- -------
332- float
333- Voxel coordinate of the given node.
334- """
335- return img_util .to_voxels (self .node_xyz [i ], self .anisotropy )
336-
337- def get_local_voxel (self , node , offset ):
327+ def node_local_voxel (self , node , offset ):
338328 """
339329 Computes the local voxel coordinate of the given node within the image
340330 patch defined by "center" and "patch_shape".
@@ -352,10 +342,9 @@ def get_local_voxel(self, node, offset):
352342 Local voxel coordinate of the given node within the image patch
353343 defined by "center" and "patch_shape".
354344 """
355- voxel = self .get_voxel (node )
356- return tuple ([v - o for v , o in zip (voxel , offset )])
345+ return tuple ([v - o for v , o in zip (self .node_voxel (node ), offset )])
357346
358- def get_node_segment_id (self , node ):
347+ def node_segment_id (self , node ):
359348 """
360349 Gets the segment ID corresponding to the given node.
361350
@@ -371,7 +360,40 @@ def get_node_segment_id(self, node):
371360 """
372361 return self .get_swc_id (node ).split ("." )[0 ]
373362
374- def get_nodes_with_component_id (self , component_id ):
363+ def node_swc_id (self , i ):
364+ """
365+ Gets the SWC ID of the given node.
366+
367+ Parameters
368+ ----------
369+ i : int
370+ Node ID.
371+
372+ Returns
373+ -------
374+ str
375+ SWC ID of the given node.
376+ """
377+ component_id = self .node_component_id [i ]
378+ return self .component_id_to_swc_id [component_id ]
379+
380+ def node_voxel (self , i ):
381+ """
382+ Gets the voxel coordinate of the given node.
383+
384+ Parameters
385+ ----------
386+ i : int
387+ Node ID.
388+
389+ Returns
390+ -------
391+ float
392+ Voxel coordinate of the given node.
393+ """
394+ return img_util .to_voxels (self .node_xyz [i ], self .anisotropy )
395+
396+ def nodes_with_component_id (self , component_id ):
375397 """
376398 Gets all nodes with the given componenet ID.
377399
@@ -387,7 +409,7 @@ def get_nodes_with_component_id(self, component_id):
387409 """
388410 return set (np .where (self .node_component_id == component_id )[0 ])
389411
390- def get_nodes_with_segment_id (self , segment_id ):
412+ def nodes_with_segment_id (self , segment_id ):
391413 """
392414 Gets all nodes with the given segment ID.
393415
@@ -406,19 +428,13 @@ def get_nodes_with_segment_id(self, segment_id):
406428 for swc_id in self .get_swc_ids ():
407429 segment_id = int (swc_id .replace (".0" , "" ))
408430 if segment_id == query_id :
409- component_id = self .get_component_id_from_swc_id (swc_id )
431+ component_id = self .component_id_from_swc_id (swc_id )
410432 nodes = nodes .union (
411- self .get_nodes_with_component_id (component_id )
433+ self .nodes_with_component_id (component_id )
412434 )
413435 return nodes
414436
415- def get_component_id_from_swc_id (self , query_swc_id ):
416- for component_id , swc_id in self .component_id_to_swc_id .items ():
417- if query_swc_id == swc_id :
418- return component_id
419- raise ValueError (f"SWC ID={ query_swc_id } not found" )
420-
421- def nodes_within_distance (self , root , radius ):
437+ def nodes_within_distance (self , root , max_depth ):
422438 queue = [(root , 0 )]
423439 visited = {root }
424440 while queue :
@@ -428,11 +444,28 @@ def nodes_within_distance(self, root, radius):
428444 # Populate queue
429445 for j in self .neighbors (i ):
430446 dist_j = dist_i + self .dist (i , j )
431- if dist_j < radius and j not in visited :
447+ if dist_j < max_depth and j not in visited :
432448 queue .append ((j , dist_j ))
433449 visited .add (j )
434450 return list (visited )
435451
452+ def path_from_leaf (self , leaf , max_depth = np .inf ):
453+ queue = [(leaf , 0 )]
454+ path = [leaf ]
455+ while queue :
456+ # Visit node
457+ i , dist_i = queue .pop ()
458+ if self .degree [i ] != 2 and dist_i > 0 :
459+ return path
460+
461+ # Update queue
462+ for j in self .neighbors (i ):
463+ dist_j = dist_i + self .dist (i , j )
464+ if dist_j < max_depth and j not in path :
465+ queue .append ((j , dist_j ))
466+ path .append (j )
467+ return path
468+
436469 def rooted_subgraph (self , root , radius ):
437470 """
438471 Gets a rooted subgraph with the given radius (in microns).
@@ -478,23 +511,6 @@ def rooted_subgraph(self, root, radius):
478511 subgraph .node_xyz = self .node_xyz [idxs ]
479512 return subgraph
480513
481- def get_swc_id (self , i ):
482- """
483- Gets the SWC ID of the given node.
484-
485- Parameters
486- ----------
487- i : int
488- Node ID.
489-
490- Returns
491- -------
492- str
493- SWC ID of the given node.
494- """
495- component_id = self .node_component_id [i ]
496- return self .component_id_to_swc_id [component_id ]
497-
498514 def get_swc_ids (self ):
499515 """
500516 Gets the set of all unique SWC IDs of nodes in the graph.
@@ -581,7 +597,7 @@ def write_entry(node, parent):
581597 write_entry (j , i )
582598
583599 # Finish
584- filename = self .get_swc_id (i )
600+ filename = self .node_swc_id (i )
585601 filename = util .set_zip_path (zip_writer , filename , ".swc" )
586602 zip_writer .writestr (filename , text_buffer .getvalue ())
587603
@@ -625,7 +641,7 @@ def clip_to_bbox(self, metadata_path):
625641 # Clip graph
626642 nodes = list ()
627643 for i in self .nodes :
628- voxel = np .array (self .get_voxel (i ))
644+ voxel = np .array (self .node_voxel (i ))
629645 if not img_util .is_contained (voxel - origin , shape ):
630646 nodes .append (i )
631647 self .remove_nodes_from (nodes )
0 commit comments