Skip to content

Commit 9de1722

Browse files
author
anna-grim
committed
refactor: updated feature generation
1 parent 605b8e4 commit 9de1722

4 files changed

Lines changed: 166 additions & 284 deletions

File tree

src/neuron_proofreader/proposal_graph.py

Lines changed: 12 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -388,25 +388,24 @@ def proposal_avg_radii(self, proposal):
388388
def proposal_directionals(self, proposal, depth):
389389
# Extract points along branches
390390
i, j = tuple(proposal)
391-
xyz_list_i = self.truncated_edge_attr_xyz(i, depth)
392-
xyz_list_j = self.truncated_edge_attr_xyz(j, depth)
393-
origin = self.proposal_midpoint(proposal)
391+
path_i = self.path_thru_node(i, depth)
392+
path_j = self.path_thru_node(j, depth)
393+
path_xyz_i = self.node_xyz[np.array(path_i)]
394+
path_xyz_j = self.node_xyz[np.array(path_j)]
395+
print("Path Lengths:", len(path_i), len(path_j))
394396

395397
# Compute tangent vectors - branches
396-
direction_i = geometry.get_directional(xyz_list_i, origin, depth)
397-
direction_j = geometry.get_directional(xyz_list_j, origin, depth)
398-
direction = geometry.tangent(self.proposal_attr(proposal, "xyz"))
399-
if np.isnan(direction).any():
400-
direction[0] = 0
401-
direction[1] = 0
398+
dir_i = geometry.tangent(path_xyz_i)
399+
dir_j = geometry.tangent(path_xyz_j)
400+
dir_proposal = geometry.tangent(self.proposal_attr(proposal, "xyz"))
402401

403402
# Compute features
404-
dot_i = abs(np.dot(direction, direction_i))
405-
dot_j = abs(np.dot(direction, direction_j))
403+
dot_i = abs(np.dot(dir_proposal, dir_i))
404+
dot_j = abs(np.dot(dir_proposal, dir_j))
406405
if self.is_simple(proposal):
407-
dot_ij = np.dot(direction_i, direction_j)
406+
dot_ij = np.dot(dir_i, dir_j)
408407
else:
409-
dot_ij = np.dot(direction_i, direction_j)
408+
dot_ij = np.dot(dir_i, dir_j)
410409
if not self.is_simple(proposal):
411410
dot_ij = max(dot_ij, -dot_ij)
412411
return np.array([dot_i, dot_j, dot_ij])
@@ -418,67 +417,7 @@ def proposal_midpoint(self, proposal):
418417
i, j = tuple(proposal)
419418
return geometry.midpoint(self.node_xyz[i], self.node_xyz[j])
420419

421-
def truncated_edge_attr_xyz(self, i, depth):
422-
xyz_path_list = self.edge_attr(i, "xyz")
423-
return [geometry.truncate_path(path, depth) for path in xyz_path_list]
424-
425420
# --- Helpers ---
426-
def node_attr(self, i, key):
427-
if key == "xyz":
428-
return self.node_xyz[i]
429-
elif key == "radius":
430-
return self.node_radius[i]
431-
else:
432-
return self.nodes[i][key]
433-
434-
def edge_attr(self, i, key="xyz", ignore=False):
435-
"""
436-
Gets the edge attribute specified by "key" for all edges connected to
437-
the given node.
438-
439-
Parameters
440-
----------
441-
i : int
442-
Node for which the edge attributes are to be retrieved.
443-
key : str, optional
444-
Key specifying the type of edge attribute to retrieve. The default
445-
is "xyz".
446-
ignore : bool, optional
447-
If True, it will only consider direct neighbors of node "i". If
448-
False, the method will follow add the edge attributes along the
449-
path of chain-like connections from node "i" to its neighbors,
450-
provided that the neighbor nodes have degree 2.
451-
452-
Returns
453-
-------
454-
List[numpy.ndarray]
455-
Edge attribute specified by "key" for all edges connected to the
456-
given node.
457-
"""
458-
attrs = list()
459-
for j in self.neighbors(i):
460-
attr_ij = self.orient_edge_attr((i, j), i, key=key)
461-
if not ignore:
462-
root = i
463-
while self.degree[j] == 2:
464-
k = [k for k in self.neighbors(j) if k != root][0]
465-
attr_jk = self.orient_edge_attr((j, k), j, key=key)
466-
if key == "xyz":
467-
attr_ij = np.vstack([attr_ij, attr_jk])
468-
else:
469-
attr_ij = np.concatenate((attr_ij, attr_jk))
470-
root = j
471-
j = k
472-
attrs.append(attr_ij)
473-
return attrs
474-
475-
def edge_length(self, edge):
476-
xyz = self.edges[edge]["xyz"]
477-
if len(xyz) < 2:
478-
return 0.0
479-
else:
480-
return np.linalg.norm(xyz[1:] - xyz[:-1], axis=1).sum()
481-
482421
def find_fragments_near_xyz(self, query_xyz, max_dist):
483422
hits = dict()
484423
xyz_list = geometry.query_ball(self.kdtree, query_xyz, max_dist)
@@ -506,13 +445,6 @@ def is_soma(self, i):
506445
"""
507446
return self.node_component_id[i] in self.soma_ids
508447

509-
def orient_edge_attr(self, edge, i, key="xyz"):
510-
node_attr = self.node_attr(i, key)
511-
if (self.edges[edge][key][0] == node_attr).all():
512-
return self.edges[edge][key]
513-
else:
514-
return np.flip(self.edges[edge][key], axis=0)
515-
516448
def update_component_ids(self, component_id, root):
517449
"""
518450
Updates the component_id of all nodes connected to "root".

src/neuron_proofreader/skeleton_graph.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,36 @@ def path_length(self, path):
769769
else:
770770
return 0
771771

772+
def path_thru_node(self, i, max_depth=np.inf):
773+
if self.degree[i] == 1:
774+
return self.path_from_leaf(i, max_depth)
775+
else:
776+
assert self.degree[i] == 2
777+
j, k = self.neighbors(i)
778+
path_ij = self.directed_path(i, j, max_depth=max_depth)
779+
path_ik = self.directed_path(i, k, max_depth=max_depth)
780+
return path_ij[::-1] + path_ik[1:]
781+
782+
def directed_path(self, start_node, next_node, max_depth=np.inf):
783+
queue = [(next_node, 0)]
784+
visited = [start_node, next_node]
785+
path = list()
786+
while queue:
787+
# Visit node
788+
i, dist_i = queue.pop()
789+
if self.degree[i] != 2:
790+
return path
791+
else:
792+
path.append(i)
793+
794+
# Update queue
795+
for j in self.neighbors(i):
796+
dist_j = dist_i + self.dist(i, j)
797+
if dist_j < max_depth and j not in visited:
798+
queue.append((j, dist_j))
799+
visited.append(j)
800+
return visited
801+
772802
def rooted_subgraph(self, root, radius):
773803
"""
774804
Gets a rooted subgraph with the given radius (in microns).

src/neuron_proofreader/split_proofreading/split_feature_extraction.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
1010
"""
1111

12-
from concurrent.futures import ThreadPoolExecutor, as_completed
12+
from concurrent.futures import as_completed, ThreadPoolExecutor
13+
from scipy.spatial import KDTree
1314
from skimage.transform import resize
1415
from torch_geometric.data import HeteroData
1516

@@ -98,9 +99,7 @@ def __init__(self, graph):
9899
"""
99100
# Instance attributes
100101
self.graph = graph
101-
102-
# Build KD-tree from leaf nodes
103-
self.graph.set_kdtree(node_type="leaf")
102+
self.kdtree = KDTree(graph.node_xyz[np.array(graph.leaf_nodes())])
104103

105104
def __call__(self, subgraph, features):
106105
"""
@@ -398,9 +397,10 @@ def __init__(
398397
self.patch_shape = patch_shape
399398

400399
# Annotate mask
401-
node1, node2 = tuple(self.proposal)
402-
self.annotate_edge(node1)
403-
self.annotate_edge(node2)
400+
i, j = self.proposal
401+
self.voxels = {u: self.get_branch_voxels(u) for u in [i, j]}
402+
self.annotate_edge(i)
403+
self.annotate_edge(j)
404404
self.annotate_proposal()
405405

406406
# --- Core Routines ---
@@ -462,25 +462,8 @@ def get_branch_profile(self, node):
462462
profile : numpy.ndarray
463463
Intensity profile along the branch containing the given node.
464464
"""
465-
def check_emptiness():
466-
"""
467-
Checks if voxels is empty.
468-
"""
469-
if len(voxels) < 2:
470-
voxels.append(self.graph.get_local_voxel(node, self.offset))
471-
472-
# Get branch voxel coordinates
473-
voxels = self.get_branch_voxels(node)
474-
voxels = geometry_util.make_voxels_connected(voxels)
475-
voxels = img_util.get_contained_voxels(voxels, self.mask.shape)
476-
check_emptiness()
477-
478-
# Resample voxels
479-
voxels = np.array(voxels)
480-
voxels = geometry_util.resample_curve_3d(voxels, 16).astype(int)
481-
voxels = img_util.get_contained_voxels(voxels, self.mask.shape)
482-
check_emptiness()
483-
return self._extract_profile(voxels)
465+
profile = self._extract_profile(self.voxels[node])
466+
return geometry_util.resample_curve_1d(profile, 16)
484467

485468
def _extract_profile(self, voxels):
486469
"""
@@ -513,9 +496,7 @@ def annotate_edge(self, node):
513496
node : int
514497
Node ID used to get branch to be annotated.
515498
"""
516-
voxels = self.get_branch_voxels(node)
517-
voxels = geometry_util.make_voxels_connected(voxels)
518-
img_util.annotate_voxels(self.mask, voxels, val=0.5)
499+
img_util.annotate_voxels(self.mask, self.voxels[node], val=0.5)
519500

520501
def annotate_proposal(self):
521502
"""
@@ -540,8 +521,8 @@ def get_profile_line(self, n_pts=None):
540521
Voxel line between the two nodes of a proposal.
541522
"""
542523
node1, node2 = self.proposal
543-
voxel1 = self.graph.get_local_voxel(node1, self.offset)
544-
voxel2 = self.graph.get_local_voxel(node2, self.offset)
524+
voxel1 = self.graph.node_local_voxel(node1, self.offset)
525+
voxel2 = self.graph.node_local_voxel(node2, self.offset)
545526
if n_pts:
546527
return geometry_util.make_line(voxel1, voxel2, n_pts)
547528
else:
@@ -563,10 +544,22 @@ def get_branch_voxels(self, node):
563544
Voxel coordinates representing the edge path in local patch
564545
coordinates.
565546
"""
566-
pts = np.vstack(self.graph.edge_attr(node, "xyz"))
567-
anisotropy = self.graph.anisotropy
568-
voxels = [img_util.to_voxels(xyz, anisotropy) for xyz in pts]
569-
return geometry_util.shift_path(voxels, self.offset)
547+
queue = [(node, self.graph.node_local_voxel(node, self.offset))]
548+
visited = {node}
549+
voxels = list()
550+
while queue:
551+
# Visit node
552+
i, voxel_i = queue.pop()
553+
voxels.append(voxel_i)
554+
555+
# Update queue
556+
for j in self.graph.neighbors(i):
557+
voxel_j = self.graph.node_local_voxel(j, self.offset)
558+
contained_j = img_util.is_contained(voxel_j, self.patch_shape)
559+
if contained_j and j not in visited:
560+
queue.append((j, voxel_j))
561+
visited.add(j)
562+
return geometry_util.make_voxels_connected(voxels)
570563

571564

572565
# --- Feature Data Structures ---
@@ -929,19 +922,22 @@ def __init__(self, object_ids):
929922

930923

931924
# --- Helpers ---
932-
def check_list_length(my_list, min_length=2):
925+
def check_list_length(arr, min_length=2):
933926
"""
934-
Checks that the list contains at least "min_length" items.
927+
Checks that the array contains at least "min_length" items.
935928
936929
Parameters
937930
----------
938-
my_list : list
939-
List to be checked.
931+
arr : list
932+
Array to be checked.
940933
min_length : int
941-
Minimum items that must be contained in the list
934+
Minimum length of the array.
942935
"""
943-
while len(my_list) < min_length:
944-
my_list.append(my_list[-1])
936+
if arr.shape[0] < min_length:
937+
pad_size = min_length - arr.shape[0]
938+
padding = np.repeat(arr[-1:], pad_size, axis=0)
939+
arr = np.concatenate([arr, padding], axis=0)
940+
return arr
945941

946942

947943
def get_feature_dict():

0 commit comments

Comments
 (0)