11import argparse
22import json
33import os
4- from typing import Optional
4+ from typing import List , Optional
55
66import pandas as pd
77from flamingo_tools .s3_utils import get_s3_path
@@ -67,86 +67,117 @@ def label_custom_components(tsv_table, custom_dict):
6767 return tsv_table
6868
6969
70- def repro_label_components (
71- ddict : dict ,
72- output_dir : str ,
def _load_json_as_list(ddict_path: str) -> List[dict]:
    """Load a JSON file and return its content as a list.

    A file that holds a single JSON value (e.g. one parameter dictionary)
    is wrapped in a one-element list, so callers can always iterate over
    the result.

    Args:
        ddict_path: Path to the JSON file.

    Returns:
        The parsed content if it already is a list, otherwise a list
        containing the parsed content as its only element.
    """
    # JSON is UTF-8 encoded; do not depend on the platform's default encoding.
    with open(ddict_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # ensure the result is always a list
    return data if isinstance(data, list) else [data]
75+
76+
def label_components_single(
    table_path: str,
    out_path: str,
    cell_type: str = "sgn",
    component_list: Optional[List[int]] = None,
    max_edge_distance: float = 30,
    min_component_length: int = 50,
    min_size: int = 1000,
    s3: bool = False,
    s3_credentials: Optional[str] = None,
    s3_bucket_name: Optional[str] = None,
    s3_service_endpoint: Optional[str] = None,
    custom_dic: Optional[dict] = None,
    **_
):
    """Process a single cochlea using one set of parameters or a custom_dic.

    The segmentation table is read (from S3 or from the local file system),
    component labels are assigned, summary statistics are printed and the
    resulting table is written to `out_path` as a tab-separated file.

    Args:
        table_path: Path to the segmentation table (TSV). Passed to
            `get_s3_path` if `s3` is set, otherwise read as a local file.
        out_path: Path of the output TSV file.
        cell_type: Cell type of the segmentation. Either 'sgn' or 'ihc'.
            Only evaluated if no `custom_dic` is given.
        component_list: Component labels reported in the printed summary.
            Defaults to [1].
        max_edge_distance: Forwarded to the component labeling function.
        min_component_length: Forwarded to the component labeling function.
        min_size: Forwarded to the component labeling function.
        s3: Whether the table is read from an S3 bucket.
        s3_credentials: Credential file forwarded to `get_s3_path`.
        s3_bucket_name: Bucket name forwarded to `get_s3_path`.
        s3_service_endpoint: Service endpoint forwarded to `get_s3_path`.
        custom_dic: If given, the table is labeled with
            `label_custom_components` instead of the cell type specific
            labeling function.
        **_: Additional keyword arguments are accepted and ignored, so that
            a full parameter dictionary can be unpacked into this function.

    Raises:
        ValueError: If no `custom_dic` is given and `cell_type` is neither
            'sgn' nor 'ihc'.
    """
    if component_list is None:
        # None stands in for the default [1]; avoids a shared mutable default.
        component_list = [1]
    else:
        # Command line values may arrive as strings. Normalize to int (as the
        # annotation promises) so that "1" and 1 are treated alike below.
        component_list = [int(comp) for comp in component_list]

    if s3:
        tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name,
                                   service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
        with fs.open(tsv_path, "r") as f:
            table = pd.read_csv(f, sep="\t")
    else:
        # Without S3 the table path refers to a local file.
        table = pd.read_csv(table_path, sep="\t")

    if custom_dic is not None:
        tsv_table = label_custom_components(table, custom_dic)
    else:
        if cell_type == "sgn":
            tsv_table = label_components_sgn(table, min_size=min_size,
                                             min_component_length=min_component_length,
                                             max_edge_distance=max_edge_distance)
        elif cell_type == "ihc":
            tsv_table = label_components_ihc(table, min_size=min_size,
                                             min_component_length=min_component_length,
                                             max_edge_distance=max_edge_distance)
        else:
            raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.")

    # Number of instances that belong to any of the requested components.
    custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)])
    print(f"Total {cell_type.upper()}s: {len(tsv_table)}")
    if component_list == [1]:
        print(f"Largest component has {custom_comp} {cell_type.upper()}s.")
    else:
        for comp in component_list:
            num_instances = len(tsv_table[tsv_table["component_labels"] == comp])
            print(f"Component {comp} has {num_instances} instances.")
        print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.")

    tsv_table.to_csv(out_path, sep="\t", index=False)
139125
140- tsv_table .to_csv (out_path , sep = "\t " , index = False )
126+
def repro_label_components(
    output_path: str,
    table_path: Optional[str] = None,
    ddict: Optional[str] = None,
    **kwargs
):
    """Wrapper function for labeling connected components using a segmentation table.
    The function is used to distinguish between a passed parameter dictionary in JSON format
    and the explicit setting of parameters.

    Args:
        output_path: Output path. Either an existing directory or a specific file.
            In JSON mode a file name is derived per entry only if this is an
            existing directory; otherwise every entry is written to this path.
        table_path: Path to the segmentation table. Required if `ddict` is None.
        ddict: Path to a JSON file with one parameter dictionary or a list of them.
            Each entry must contain the keys "cochlea" and "segmentation_channel".
        **kwargs: Parameters forwarded to `label_components_single`. In JSON mode
            they act as shared settings (e.g. S3 options); values in a JSON
            entry take precedence over them.

    Raises:
        ValueError: If neither `table_path` nor `ddict` is given.
    """
    if ddict is None:
        if table_path is None:
            raise ValueError("Provide either a path to a segmentation table or a JSON parameter dictionary.")
        label_components_single(table_path, output_path, **kwargs)
        return

    for params in _load_json_as_list(ddict):
        cochlea = params["cochlea"]
        print(f"\n{cochlea}")
        seg_channel = params["segmentation_channel"]
        entry_table_path = os.path.join(f"{cochlea}", "tables", seg_channel, "default.tsv")

        if os.path.isdir(output_path):
            cochlea_str = "-".join(cochlea.split("_"))
            table_str = "-".join(seg_channel.split("_"))
            save_path = os.path.join(output_path, "_".join([cochlea_str, f"{table_str}.tsv"]))
        else:
            save_path = output_path

        # Shared settings must reach every entry; otherwise options such as the
        # S3 flag and credentials are lost in JSON mode. Entry values win.
        entry_kwargs = {**kwargs, **params}
        label_components_single(table_path=entry_table_path, out_path=save_path, **entry_kwargs)
141155
142156
143157def main ():
144158 parser = argparse .ArgumentParser (
145159 description = "Script to label segmentation using a segmentation table and graph connected components." )
146160
147- parser .add_argument ("-i" , "--input" , type = str , required = True , help = "Input JSON dictionary." )
148- parser .add_argument ("-o" , "--output" , type = str , required = True , help = "Output directory." )
161+ parser .add_argument ("-o" , "--output" , type = str , required = True ,
162+ help = "Output path. Either directory or specific file." )
163+
164+ parser .add_argument ("-i" , "--input" , type = str , default = None , help = "Input path to segmentation table." )
165+ parser .add_argument ("-j" , "--json" , type = str , default = None , help = "Input JSON dictionary." )
166+
167+ parser .add_argument ("--cell_type" , type = str , default = "sgn" ,
168+ help = "Cell type of segmentation. Either 'sgn' or 'ihc'." )
169+
170+ # options for post-processing
171+ parser .add_argument ("--min_size" , type = int , default = 1000 ,
172+ help = "Minimal number of pixels for filtering small instances." )
173+ parser .add_argument ("--min_component_length" , type = int , default = 50 ,
174+ help = "Minimal length for filtering out connected components." )
175+ parser .add_argument ("--max_edge_distance" , type = float , default = 30 ,
176+ help = "Maximal distance in micrometer between points to create edges for connected components." )
177+ parser .add_argument ("-c" , "--components" , type = str , nargs = "+" , default = [1 ], help = "List of connected components." )
149178
179+ # options for S3 bucket
180+ parser .add_argument ("--s3" , action = "store_true" , help = "Flag for using S3 bucket." )
150181 parser .add_argument ("--s3_credentials" , type = str , default = None ,
151182 help = "Input file containing S3 credentials. "
152183 "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported." )
@@ -158,8 +189,18 @@ def main():
158189 args = parser .parse_args ()
159190
160191 repro_label_components (
161- args .input , args .output ,
162- args .s3_credentials , args .s3_bucket_name , args .s3_service_endpoint ,
192+ output_path = args .output ,
193+ table_path = args .input ,
194+ ddict = args .json ,
195+ cell_type = args .cell_type ,
196+ component_list = args .components ,
197+ max_edge_distance = args .max_edge_distance ,
198+ min_component_length = args .min_component_length ,
199+ min_size = args .min_size ,
200+ s3 = args .s3 ,
201+ s3_credentials = args .s3_credentials ,
202+ s3_bucket_name = args .s3_bucket_name ,
203+ s3_service_endpoint = args .s3_service_endpoint ,
163204 )
164205
165206
0 commit comments