Skip to content

Commit ff6039e

Browse files
anna-grimanna-grim
andauthored
refactor: simplified soma loading (#411)
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 64b132f commit ff6039e

5 files changed

Lines changed: 105 additions & 157 deletions

File tree

src/deep_neurographs/fragments_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
remove_high_risk_merges=False,
7272
segmentation_path=None,
7373
smooth_bool=True,
74-
somas_path=None,
74+
soma_centroids=None,
7575
verbose=False,
7676
):
7777
"""
@@ -99,9 +99,8 @@ def __init__(
9999
smooth_bool : bool, optional
100100
Indication of whether to smooth xyz coordinates from SWC files.
101101
The default is True.
102-
somas_path : str, optional
103-
Path to a txt file containing xyz coordinates of detected somas.
104-
The default is None.
102+
soma_centroids : List[Tuple[float]] or None, optional
103+
Physcial coordinates of soma centroids. The default is None.
105104
verbose : bool, optional
106105
Indication of whether to display a progress bar while building
107106
FragmentsGraph. The default is True.
@@ -123,7 +122,7 @@ def __init__(
123122
remove_high_risk_merges=remove_high_risk_merges,
124123
segmentation_path=segmentation_path,
125124
smooth_bool=smooth_bool,
126-
somas_path=somas_path,
125+
soma_centroids=soma_centroids,
127126
verbose=verbose,
128127
)
129128
self.swc_reader = swc_util.Reader(anisotropy, min_size)

src/deep_neurographs/inference.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from torch.nn.functional import sigmoid
3737
from tqdm import tqdm
3838

39-
import ast
4039
import networkx as nx
4140
import numpy as np
4241
import os
@@ -68,7 +67,7 @@ def __init__(
6867
output_dir,
6968
config,
7069
segmentation_path=None,
71-
somas_path=None,
70+
soma_centroids=None,
7271
s3_dict=None,
7372
):
7473
"""
@@ -92,9 +91,8 @@ def __init__(
9291
for the inference pipeline.
9392
segmentation_path : str, optional
9493
Path to segmentation stored in GCS bucket. The default is None.
95-
somas_path : str, optional
96-
Path to a txt file containing xyz coordinates of detected somas.
97-
The default is None.
94+
soma_centroids : List[Tuple[float]] or None, optional
95+
Physcial coordinates of soma centroids. The default is None.
9896
s3_dict : dict, optional
9997
...
10098
@@ -110,7 +108,7 @@ def __init__(
110108
self.brain_id = brain_id
111109
self.segmentation_id = segmentation_id
112110
self.segmentation_path = segmentation_path
113-
self.somas_path = somas_path
111+
self.soma_centroids = soma_centroids
114112
self.s3_dict = s3_dict
115113

116114
# Extract config settings
@@ -148,14 +146,14 @@ def run(self, fragments_pointer):
148146

149147
# Main
150148
self.build_graph(fragments_pointer)
151-
self.connect_soma_fragments() if self.somas_path else None
149+
self.connect_soma_fragments() if self.soma_centroids else None
152150
self.generate_proposals(self.graph_config.search_radius)
153151
self.classify_proposals(self.ml_config.threshold)
154152

155153
# Finish
156154
t, unit = util.time_writer(time() - t0)
157155
self.report_graph(prefix="\nFinal")
158-
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")
156+
self.report(f"Total Runtime: {t:.2f} {unit}\n")
159157
self.save_results()
160158

161159
def run_schedule(
@@ -177,7 +175,7 @@ def run_schedule(
177175
# Finish
178176
t, unit = util.time_writer(time() - t0)
179177
self.report_graph(prefix="\nFinal")
180-
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")
178+
self.report(f"Total Runtime: {t:.2f} {unit}\n")
181179
self.save_results()
182180

183181
def build_graph(self, fragments_pointer):
@@ -207,7 +205,7 @@ def build_graph(self, fragments_pointer):
207205
remove_high_risk_merges=self.graph_config.remove_high_risk_merges,
208206
segmentation_path=self.segmentation_path,
209207
smooth_bool=self.graph_config.smooth_bool,
210-
somas_path=self.somas_path,
208+
soma_centroids=self.soma_centroids,
211209
verbose=True,
212210
)
213211
self.graph.load_fragments(fragments_pointer)
@@ -219,7 +217,7 @@ def build_graph(self, fragments_pointer):
219217

220218
t, unit = util.time_writer(time() - t0)
221219
self.report_graph(prefix="\nInitial")
222-
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")
220+
self.report(f"Module Runtime: {t:.2f} {unit}\n")
223221

224222
def filter_fragments(self):
225223
self.graph = fragment_filtering.remove_curvy(self.graph, 200)
@@ -235,7 +233,7 @@ def connect_soma_fragments(self):
235233
# Parse locations
236234
nodes_list = list()
237235
merge_cnt, soma_cnt = 0, 0
238-
for soma_xyz in util.load_soma_locations(self.somas_path):
236+
for soma_xyz in self.soma_centroids:
239237
hits = self.graph.find_fragments_near_xyz(soma_xyz, 20)
240238
if len(hits) > 1:
241239
# Determine new swc id
@@ -295,13 +293,13 @@ def generate_proposals(self, radius):
295293
proposals_per_leaf=self.graph_config.proposals_per_leaf,
296294
trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
297295
)
298-
n_proposals = util.reformat_number(self.graph.n_proposals())
296+
n_proposals = format(self.graph.n_proposals(), ",")
299297

300298
# Report results
301299
t, unit = util.time_writer(time() - t0)
302300
self.report(f"# Proposals: {n_proposals}")
303301
self.report(f"# Proposals Blocked: {self.graph.n_proposals_blocked}")
304-
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")
302+
self.report(f"Module Runtime: {t:.2f} {unit}\n")
305303

306304
def classify_proposals(self, accept_threshold):
307305
"""
@@ -341,9 +339,9 @@ def classify_proposals(self, accept_threshold):
341339
# Report results
342340
t, unit = util.time_writer(time() - t0)
343341
self.report(f"# Merges Blocked: {self.graph.n_merges_blocked}")
344-
self.report(f"# Accepted: {util.reformat_number(len(accepts))}")
345-
self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}")
346-
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")
342+
self.report(f"# Accepted: {format(len(accepts), ',')}")
343+
self.report(f"% Accepted: {len(accepts) / n_proposals:.4f}")
344+
self.report(f"Module Runtime: {t:.4f} {unit}\n")
347345

348346
def save_results(self):
349347
"""
@@ -447,7 +445,7 @@ def write_metadata(self):
447445
"min_fragment_size": f"{self.graph_config.min_size}um",
448446
"node_spacing": self.graph_config.node_spacing,
449447
"remove_doubles": self.graph_config.remove_doubles,
450-
"use_somas": self.segmentation_path and self.somas_path,
448+
"use_somas": len(self.soma_centroids) > 0,
451449
"complex_proposals": self.graph_config.complex_bool,
452450
"long_range_bool": self.graph_config.long_range_bool,
453451
"proposals_per_leaf": self.graph_config.proposals_per_leaf,
@@ -486,17 +484,16 @@ def report_graph(self, prefix="\n"):
486484
"""
487485
# Compute values
488486
n_components = nx.number_connected_components(self.graph)
489-
n_components = util.reformat_number(n_components)
490-
n_nodes = util.reformat_number(self.graph.number_of_nodes())
491-
n_edges = util.reformat_number(self.graph.number_of_edges())
492-
usage = round(util.get_memory_usage(), 2)
487+
n_components = format(n_components, ",")
488+
n_nodes = format(self.graph.number_of_nodes(), ",")
489+
n_edges = format(self.graph.number_of_edges(), ",")
493490

494491
# Report
495492
self.report(f"{prefix} Graph")
496493
self.report(f"# Connected Components: {n_components}")
497494
self.report(f"# Nodes: {n_nodes}")
498495
self.report(f"# Edges: {n_edges}")
499-
self.report(f"Memory Consumption: {usage} GBs")
496+
self.report(f"Memory Consumption: {util.get_memory_usage():.2f} GBs")
500497

501498

502499
class InferenceEngine:

src/deep_neurographs/machine_learning/groundtruth_generation.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def find_alignments(gt_graph, pred_graph, kdtree):
102102
)
103103
if aligned:
104104
i = util.sample_once(nodes)
105-
pred_id = pred_graph.nodes[i]["swc_id"]
105+
pred_id = pred_graph.nodes[i]["swc_id"]
106106
valid_ids.add(pred_id)
107107
pred_to_target[pred_id] = target_id
108108
return pred_to_target
@@ -150,7 +150,7 @@ def is_component_aligned(gt_graph, pred_graph, nodes, kdtree):
150150

151151
intersects = True if percent_aligned > 0.5 else False
152152
aligned_score = np.mean(dists[dists < np.percentile(dists, 80)])
153-
153+
154154
# Deterine whether aligned
155155
if (aligned_score < ALIGNED_THRESHOLD and hat_swc_id) and intersects:
156156
return True, hat_swc_id
@@ -196,7 +196,7 @@ def is_valid(gt_graph, pred_graph, kdtree, gt_id, proposal):
196196
if is_connected(hat_edge_i, hat_edge_j):
197197
# Orient ground truth edges
198198
hat_edge_xyz_i, hat_edge_xyz_j = orient_edges(
199-
gt_graph.edges[hat_edge_i]["xyz"],
199+
gt_graph.edges[hat_edge_i]["xyz"],
200200
gt_graph.edges[hat_edge_j]["xyz"]
201201
)
202202

@@ -303,5 +303,3 @@ def length_to_idx(xyz_list, idx):
303303
for i in range(0, idx):
304304
length += geometry_util.dist(xyz_list[i], xyz_list[i + 1])
305305
return length
306-
307-

src/deep_neurographs/utils/graph_util.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
from scipy.spatial import KDTree
3535
from tqdm import tqdm
3636

37-
import ast
38-
import multiprocessing
3937
import networkx as nx
4038
import numpy as np
4139
import os
@@ -61,7 +59,7 @@ def __init__(
6159
remove_high_risk_merges=False,
6260
segmentation_path=None,
6361
smooth_bool=True,
64-
somas_path=None,
62+
soma_centroids=None,
6563
verbose=False,
6664
):
6765
"""
@@ -90,9 +88,8 @@ def __init__(
9088
smooth_bool : bool, optional
9189
Indication of whether to smooth xyz coordinates from SWC files.
9290
The default is True.
93-
somas_path : str, optional
94-
Path to a txt file containing xyz coordinates of detected somas.
95-
The default is None.
91+
soma_centroids : List[Tuple[float]] or None, optional
92+
Physcial coordinates of soma centroids. The default is None.
9693
verbose : bool, optional
9794
Indication of whether to display a progress bar while building
9895
FragmentsGraph. The default is True.
@@ -108,65 +105,54 @@ def __init__(
108105
self.node_spacing = node_spacing
109106
self.prune_depth = prune_depth
110107
self.smooth_bool = smooth_bool
111-
self.soma_kdtree = None
108+
self.soma_centroids = soma_centroids
112109
self.verbose = verbose
113110

114111
# Set irreducibles extracter
115-
if somas_path and remove_high_risk_merges:
112+
if soma_centroids and remove_high_risk_merges:
116113
self.extracter = self.break_and_extract
117114
else:
118115
self.extracter = self.extract
119116

120117
# Load somas
121-
if segmentation_path and somas_path:
122-
self.load_somas(segmentation_path, somas_path)
118+
if segmentation_path and soma_centroids:
119+
self.soma_kdtree = KDTree(self.soma_centroids)
120+
self.ingest_somas(segmentation_path)
121+
else:
122+
self.soma_kdtree = None
123123

124-
def load_somas(self, segmentation_path, somas_path):
124+
def ingest_somas(self, segmentation_path):
125125
"""
126-
Loads soma locations from a specified file and detects merges in a
127-
segmentation.
126+
Loads soma locations from a specified file and search for interestions
127+
between soma locations and objects in segmentation mask.
128128
129129
Parameters
130130
----------
131131
segmentation_path : str
132-
Path to segmentation stored in GCS bucket. The default is None.
133-
somas_path : str
134-
Path to a txt file containing xyz coordinates of detected somas.
132+
Path to segmentation stored in GCS bucket.
135133
136134
Returns
137135
-------
138136
None
139137
140138
"""
141-
# Read soma locations txt file
142-
if isinstance(somas_path, str):
143-
xyz_list = util.read_txt(somas_path)
144-
elif isinstance(somas_path, dict):
145-
xyz_list = util.read_s3_txt_file(somas_path)
146-
else:
147-
raise Exception(f"Invalid format - somas_path={somas_path}")
148-
149-
# Process soma locations
150139
reader = img_util.TensorStoreReader(segmentation_path)
151140
with ThreadPoolExecutor() as executor:
152141
# Assign threads
153142
threads = list()
154-
for xyz in util.load_soma_locations(somas_path):
143+
for xyz in self.soma_centroids:
155144
voxel = img_util.to_voxels(xyz, (0.748, 0.748, 1.0))
156145
threads.append(executor.submit(reader.read_voxel, voxel, xyz))
157146

158147
# Store results
159-
soma_xyz_list = list()
160148
for thread in as_completed(threads):
161149
xyz, seg_id = thread.result()
162150
if seg_id != 0:
163151
self.id_to_soma[str(seg_id)].append(xyz)
164-
soma_xyz_list.append(xyz)
165-
self.soma_kdtree = KDTree(soma_xyz_list)
166152

167153
# Report results
168154
if self.verbose:
169-
print("# Somas:", len(soma_xyz_list))
155+
print("# Somas:", len(self.soma_centroids))
170156
print("# Soma-Fragment Intersections:", len(self.id_to_soma))
171157

172158
# --- Irreducibles Extraction ---
@@ -198,7 +184,6 @@ def extract_irreducibles(self, swc_dicts):
198184
high_risk_cnt = 0
199185
desc = "Extract Graphs"
200186
pbar = tqdm(total=len(swc_dicts), desc=desc) if self.verbose else None
201-
multiprocessing.set_start_method('spawn', force=True)
202187
with ProcessPoolExecutor() as executor:
203188
processes = list()
204189
while swc_dicts:

0 commit comments

Comments
 (0)