3636from torch .nn .functional import sigmoid
3737from tqdm import tqdm
3838
39- import ast
4039import networkx as nx
4140import numpy as np
4241import os
@@ -68,7 +67,7 @@ def __init__(
6867 output_dir ,
6968 config ,
7069 segmentation_path = None ,
71- somas_path = None ,
70+ soma_centroids = None ,
7271 s3_dict = None ,
7372 ):
7473 """
@@ -92,9 +91,8 @@ def __init__(
9291 for the inference pipeline.
9392 segmentation_path : str, optional
9493 Path to segmentation stored in GCS bucket. The default is None.
95- somas_path : str, optional
96- Path to a txt file containing xyz coordinates of detected somas.
97- The default is None.
94+ soma_centroids : List[Tuple[float]] or None, optional
95+ Physcial coordinates of soma centroids. The default is None.
9896 s3_dict : dict, optional
9997 ...
10098
@@ -110,7 +108,7 @@ def __init__(
110108 self .brain_id = brain_id
111109 self .segmentation_id = segmentation_id
112110 self .segmentation_path = segmentation_path
113- self .somas_path = somas_path
111+ self .soma_centroids = soma_centroids
114112 self .s3_dict = s3_dict
115113
116114 # Extract config settings
@@ -148,14 +146,14 @@ def run(self, fragments_pointer):
148146
149147 # Main
150148 self .build_graph (fragments_pointer )
151- self .connect_soma_fragments () if self .somas_path else None
149+ self .connect_soma_fragments () if self .soma_centroids else None
152150 self .generate_proposals (self .graph_config .search_radius )
153151 self .classify_proposals (self .ml_config .threshold )
154152
155153 # Finish
156154 t , unit = util .time_writer (time () - t0 )
157155 self .report_graph (prefix = "\n Final" )
158- self .report (f"Total Runtime: { round ( t , 4 ) } { unit } \n " )
156+ self .report (f"Total Runtime: { t :.2f } { unit } \n " )
159157 self .save_results ()
160158
161159 def run_schedule (
@@ -177,7 +175,7 @@ def run_schedule(
177175 # Finish
178176 t , unit = util .time_writer (time () - t0 )
179177 self .report_graph (prefix = "\n Final" )
180- self .report (f"Total Runtime: { round ( t , 4 ) } { unit } \n " )
178+ self .report (f"Total Runtime: { t :.2f } { unit } \n " )
181179 self .save_results ()
182180
183181 def build_graph (self , fragments_pointer ):
@@ -207,7 +205,7 @@ def build_graph(self, fragments_pointer):
207205 remove_high_risk_merges = self .graph_config .remove_high_risk_merges ,
208206 segmentation_path = self .segmentation_path ,
209207 smooth_bool = self .graph_config .smooth_bool ,
210- somas_path = self .somas_path ,
208+ soma_centroids = self .soma_centroids ,
211209 verbose = True ,
212210 )
213211 self .graph .load_fragments (fragments_pointer )
@@ -219,7 +217,7 @@ def build_graph(self, fragments_pointer):
219217
220218 t , unit = util .time_writer (time () - t0 )
221219 self .report_graph (prefix = "\n Initial" )
222- self .report (f"Module Runtime: { round ( t , 4 ) } { unit } \n " )
220+ self .report (f"Module Runtime: { t :.2f } { unit } \n " )
223221
224222 def filter_fragments (self ):
225223 self .graph = fragment_filtering .remove_curvy (self .graph , 200 )
@@ -235,7 +233,7 @@ def connect_soma_fragments(self):
235233 # Parse locations
236234 nodes_list = list ()
237235 merge_cnt , soma_cnt = 0 , 0
238- for soma_xyz in util . load_soma_locations ( self .somas_path ) :
236+ for soma_xyz in self .soma_centroids :
239237 hits = self .graph .find_fragments_near_xyz (soma_xyz , 20 )
240238 if len (hits ) > 1 :
241239 # Determine new swc id
@@ -295,13 +293,13 @@ def generate_proposals(self, radius):
295293 proposals_per_leaf = self .graph_config .proposals_per_leaf ,
296294 trim_endpoints_bool = self .graph_config .trim_endpoints_bool ,
297295 )
298- n_proposals = util . reformat_number (self .graph .n_proposals ())
296+ n_proposals = format (self .graph .n_proposals (), "," )
299297
300298 # Report results
301299 t , unit = util .time_writer (time () - t0 )
302300 self .report (f"# Proposals: { n_proposals } " )
303301 self .report (f"# Proposals Blocked: { self .graph .n_proposals_blocked } " )
304- self .report (f"Module Runtime: { round ( t , 4 ) } { unit } \n " )
302+ self .report (f"Module Runtime: { t :.2f } { unit } \n " )
305303
306304 def classify_proposals (self , accept_threshold ):
307305 """
@@ -341,9 +339,9 @@ def classify_proposals(self, accept_threshold):
341339 # Report results
342340 t , unit = util .time_writer (time () - t0 )
343341 self .report (f"# Merges Blocked: { self .graph .n_merges_blocked } " )
344- self .report (f"# Accepted: { util . reformat_number (len (accepts ))} " )
345- self .report (f"% Accepted: { round ( len (accepts ) / n_proposals , 4 ) } " )
346- self .report (f"Module Runtime: { round ( t , 4 ) } { unit } \n " )
342+ self .report (f"# Accepted: { format (len (accepts ), ',' )} " )
343+ self .report (f"% Accepted: { len (accepts ) / n_proposals :.4f } " )
344+ self .report (f"Module Runtime: { t :.4f } { unit } \n " )
347345
348346 def save_results (self ):
349347 """
@@ -447,7 +445,7 @@ def write_metadata(self):
447445 "min_fragment_size" : f"{ self .graph_config .min_size } um" ,
448446 "node_spacing" : self .graph_config .node_spacing ,
449447 "remove_doubles" : self .graph_config .remove_doubles ,
450- "use_somas" : self .segmentation_path and self . somas_path ,
448+ "use_somas" : len ( self .soma_centroids ) > 0 ,
451449 "complex_proposals" : self .graph_config .complex_bool ,
452450 "long_range_bool" : self .graph_config .long_range_bool ,
453451 "proposals_per_leaf" : self .graph_config .proposals_per_leaf ,
@@ -486,17 +484,16 @@ def report_graph(self, prefix="\n"):
486484 """
487485 # Compute values
488486 n_components = nx .number_connected_components (self .graph )
489- n_components = util .reformat_number (n_components )
490- n_nodes = util .reformat_number (self .graph .number_of_nodes ())
491- n_edges = util .reformat_number (self .graph .number_of_edges ())
492- usage = round (util .get_memory_usage (), 2 )
487+ n_components = format (n_components , "," )
488+ n_nodes = format (self .graph .number_of_nodes (), "," )
489+ n_edges = format (self .graph .number_of_edges (), "," )
493490
494491 # Report
495492 self .report (f"{ prefix } Graph" )
496493 self .report (f"# Connected Components: { n_components } " )
497494 self .report (f"# Nodes: { n_nodes } " )
498495 self .report (f"# Edges: { n_edges } " )
499- self .report (f"Memory Consumption: { usage } GBs" )
496+ self .report (f"Memory Consumption: { util . get_memory_usage ():.2f } GBs" )
500497
501498
502499class InferenceEngine :
0 commit comments