Skip to content

Commit ff2d551

Browse files
committed
Re-factor component labeling for usage with CLI
1 parent be3fb42 commit ff2d551

File tree

2 files changed

+111
-139
lines changed

2 files changed

+111
-139
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 3 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -363,110 +363,43 @@ def graph_connected_components(coords: dict, max_edge_distance: float, min_compo
363363

364364
def components_sgn(
365365
table: pd.DataFrame,
366-
keyword: str = "distance_nn100",
367-
threshold_erode: Optional[float] = None,
368366
min_component_length: int = 50,
369367
max_edge_distance: float = 30,
370-
iterations_erode: int = 0,
371-
postprocess_threshold: Optional[float] = None,
372-
postprocess_components: Optional[List[int]] = None,
373368
) -> List[List[int]]:
374369
"""Eroding the SGN segmentation.
375370
376371
Args:
377372
table: Dataframe of segmentation table.
378-
keyword: Keyword of the dataframe column for erosion.
379-
threshold_erode: Threshold of column value after erosion step with spatial statistics.
380373
min_component_length: Minimal length for filtering out connected components.
381374
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
382-
iterations_erode: Number of iterations for erosion.
383-
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
384-
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
385375
386376
Returns:
387377
Subgraph components as lists of label_ids of dataframe.
388378
"""
389-
if keyword not in table:
390-
distance_avg = nearest_neighbor_distance(table, n_neighbors=100)
391-
table.loc[:, keyword] = list(distance_avg)
392-
393379
centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
394380
labels = [int(i) for i in list(table["label_id"])]
395-
396-
distance_nn = list(table[keyword])
397-
distance_nn.sort()
398-
399-
if len(table) < 20000:
400-
min_cells = None
401-
average_dist = int(distance_nn[int(len(table) * 0.8)])
402-
threshold = threshold_erode if threshold_erode is not None else average_dist
403-
else:
404-
min_cells = 20000
405-
threshold = threshold_erode if threshold_erode is not None else 40
406-
407-
if iterations_erode != 0 and iterations_erode is not None:
408-
print(f"Using threshold of {threshold} micrometer for eroding segmentation with keyword {keyword}.")
409-
new_subset = erode_subset(table.copy(), iterations=iterations_erode,
410-
threshold=threshold, min_cells=min_cells, keyword=keyword)
411-
else:
412-
new_subset = table.copy()
413-
414-
# create graph from coordinates of eroded subset
415-
centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
416-
labels_subset = [int(i) for i in list(new_subset["label_id"])]
417381
coords = {}
418-
for index, element in zip(labels_subset, centroids_subset):
382+
for index, element in zip(labels, centroids):
419383
coords[index] = element
420384

421385
components, _ = graph_connected_components(coords, max_edge_distance, min_component_length)
422386

423-
# add original coordinates closer to eroded component than threshold
424-
if postprocess_threshold is not None:
425-
if postprocess_components is None:
426-
pp_components = components
427-
else:
428-
pp_components = [components[i] for i in postprocess_components]
429-
430-
add_coords = []
431-
for label_id, centr in zip(labels, centroids):
432-
if label_id not in labels_subset:
433-
add_coord = []
434-
for comp_index, component in enumerate(pp_components):
435-
for comp_label in component:
436-
dist = math.dist(centr, centroids[comp_label - 1])
437-
if dist <= postprocess_threshold:
438-
add_coord.append([comp_index, label_id])
439-
break
440-
if len(add_coord) != 0:
441-
add_coords.append(add_coord)
442-
if len(add_coords) != 0:
443-
for c in add_coords:
444-
components[c[0][0]].append(c[0][1])
445-
446387
return components
447388

448389

449390
def label_components_sgn(
450391
table: pd.DataFrame,
451392
min_size: int = 1000,
452-
threshold_erode: Optional[float] = None,
453393
min_component_length: int = 50,
454394
max_edge_distance: float = 30,
455-
iterations_erode: int = 0,
456-
postprocess_threshold: Optional[float] = None,
457-
postprocess_components: Optional[List[int]] = None,
458395
) -> List[int]:
459396
"""Label SGN components using graph connected components.
460397
461398
Args:
462399
table: Dataframe of segmentation table.
463400
min_size: Minimal number of pixels for filtering small instances.
464-
threshold_erode: Threshold of column value after erosion step with spatial statistics.
465401
min_component_length: Minimal length for filtering out connected components.
466402
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
467-
iterations_erode: Number of iterations for erosion.
468-
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
469-
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
470403
471404
Returns:
472405
List of component label for each point in dataframe. 0 - background, then in descending order of size
@@ -476,10 +409,8 @@ def label_components_sgn(
476409
entries_filtered = table[table.n_pixels < min_size]
477410
table = table[table.n_pixels >= min_size]
478411

479-
components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
480-
max_edge_distance=max_edge_distance, iterations_erode=iterations_erode,
481-
postprocess_threshold=postprocess_threshold,
482-
postprocess_components=postprocess_components)
412+
components = components_sgn(table, min_component_length=min_component_length,
413+
max_edge_distance=max_edge_distance)
483414

484415
# add size-filtered objects to have same initial length
485416
table = pd.concat([table, entries_filtered], ignore_index=True)

reproducibility/label_components/repro_label_components.py

Lines changed: 108 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
import json
33
import os
4-
from typing import Optional
4+
from typing import List, Optional
55

66
import pandas as pd
77
from flamingo_tools.s3_utils import get_s3_path
@@ -67,86 +67,117 @@ def label_custom_components(tsv_table, custom_dict):
6767
return tsv_table
6868

6969

70-
def repro_label_components(
71-
ddict: dict,
72-
output_dir: str,
70+
def _load_json_as_list(ddict_path: str) -> List[dict]:
71+
with open(ddict_path, "r") as f:
72+
data = json.loads(f.read())
73+
# ensure the result is always a list
74+
return data if isinstance(data, list) else [data]
75+
76+
77+
def label_components_single(
78+
table_path: str,
79+
out_path: str,
80+
cell_type: str = "sgn",
81+
component_list: List[int] = [1],
82+
max_edge_distance: float = 30,
83+
min_component_length: int = 50,
84+
min_size: int = 1000,
85+
s3: bool = False,
7386
s3_credentials: Optional[str] = None,
7487
s3_bucket_name: Optional[str] = None,
7588
s3_service_endpoint: Optional[str] = None,
89+
custom_dic: Optional[dict] = None,
90+
**_
7691
):
77-
default_cell_type = "sgn"
78-
default_component_list = [1]
79-
default_iterations_erode = None
80-
default_max_edge_distance = 30
81-
default_min_length = 50
82-
default_min_size = 1000
83-
default_seg_channel = "SGN_v2"
84-
default_threshold_erode = None
85-
86-
with open(ddict, "r") as myfile:
87-
data = myfile.read()
88-
param_dicts = json.loads(data)
89-
90-
for dic in param_dicts:
91-
cochlea = dic["cochlea"]
92-
print(f"\n{cochlea}")
93-
94-
cell_type = dic.get("cell_type", default_cell_type)
95-
component_list = dic.get("component_list", default_component_list)
96-
iterations_erode = dic.get("iterations_erode", default_iterations_erode)
97-
max_edge_distance = dic.get("max_edge_distance", default_max_edge_distance)
98-
min_component_length = dic.get("min_component_length", default_min_length)
99-
min_size = dic.get("min_size", default_min_size)
100-
table_name = dic.get("segmentation_channel", default_seg_channel)
101-
threshold_erode = dic.get("threshold_erode", default_threshold_erode)
102-
103-
s3_path = os.path.join(f"{cochlea}", "tables", table_name, "default.tsv")
104-
tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name,
92+
"""Process a single cochlea using one set of parameters or a custom_dic.
93+
"""
94+
if s3:
95+
tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name,
10596
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
106-
with fs.open(tsv_path, "r") as f:
107-
table = pd.read_csv(f, sep="\t")
108-
109-
if "custom_dic" in list(dic.keys()):
110-
print(len(table[table["component_labels"] == 1]))
111-
tsv_table = label_custom_components(table, dic["custom_dic"])
97+
with fs.open(tsv_path, "r") as f:
98+
table = pd.read_csv(f, sep="\t")
99+
100+
if custom_dic is not None:
101+
tsv_table = label_custom_components(table, custom_dic)
102+
else:
103+
if cell_type == "sgn":
104+
tsv_table = label_components_sgn(table, min_size=min_size,
105+
min_component_length=min_component_length,
106+
max_edge_distance=max_edge_distance)
107+
elif cell_type == "ihc":
108+
tsv_table = label_components_ihc(table, min_size=min_size,
109+
min_component_length=min_component_length,
110+
max_edge_distance=max_edge_distance)
112111
else:
113-
if cell_type == "sgn":
114-
tsv_table = label_components_sgn(table, min_size=min_size,
115-
threshold_erode=threshold_erode,
116-
min_component_length=min_component_length,
117-
max_edge_distance=max_edge_distance,
118-
iterations_erode=iterations_erode)
119-
elif cell_type == "ihc":
120-
tsv_table = label_components_ihc(table, min_size=min_size,
121-
min_component_length=min_component_length,
122-
max_edge_distance=max_edge_distance)
123-
else:
124-
raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.")
112+
raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.")
125113

126-
custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)])
127-
print(f"Total {cell_type.upper()}s: {len(tsv_table)}")
128-
if component_list == [1]:
129-
print(f"Largest component has {custom_comp} {cell_type.upper()}s.")
130-
else:
131-
for comp in component_list:
132-
print(f"Component {comp} has {len(tsv_table[tsv_table["component_labels"] == comp])} instances.")
133-
print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.")
114+
custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)])
115+
print(f"Total {cell_type.upper()}s: {len(tsv_table)}")
116+
if component_list == [1]:
117+
print(f"Largest component has {custom_comp} {cell_type.upper()}s.")
118+
else:
119+
for comp in component_list:
120+
num_instances = len(tsv_table[tsv_table["component_labels"] == comp])
121+
print(f"Component {comp} has {num_instances} instances.")
122+
print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.")
134123

135-
cochlea_str = "-".join(cochlea.split("_"))
136-
table_str = "-".join(table_name.split("_"))
137-
os.makedirs(output_dir, exist_ok=True)
138-
out_path = os.path.join(output_dir, "_".join([cochlea_str, f"{table_str}.tsv"]))
124+
tsv_table.to_csv(out_path, sep="\t", index=False)
139125

140-
tsv_table.to_csv(out_path, sep="\t", index=False)
126+
127+
def repro_label_components(
128+
output_path: str,
129+
table_path: Optional[str] = None,
130+
ddict: Optional[str] = None,
131+
**kwargs
132+
):
133+
"""Wrapper function for labeling connected components using a segmentation table.
134+
The function is used to distinguish between a passed parameter dictionary in JSON format
135+
and the explicit setting of parameters.
136+
"""
137+
if ddict is None:
138+
label_components_single(table_path, output_path, **kwargs)
139+
else:
140+
param_dicts = _load_json_as_list(ddict)
141+
for params in param_dicts:
142+
143+
cochlea = params["cochlea"]
144+
print(f"\n{cochlea}")
145+
seg_channel = params["segmentation_channel"]
146+
table_path = os.path.join(f"{cochlea}", "tables", seg_channel, "default.tsv")
147+
148+
if os.path.isdir(output_path):
149+
cochlea_str = "-".join(cochlea.split("_"))
150+
table_str = "-".join(seg_channel.split("_"))
151+
save_path = os.path.join(output_path, "_".join([cochlea_str, f"{table_str}.tsv"]))
152+
else:
153+
save_path = output_path
154+
label_components_single(table_path=table_path, out_path=save_path, **params)
141155

142156

143157
def main():
144158
parser = argparse.ArgumentParser(
145159
description="Script to label segmentation using a segmentation table and graph connected components.")
146160

147-
parser.add_argument("-i", "--input", type=str, required=True, help="Input JSON dictionary.")
148-
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.")
161+
parser.add_argument("-o", "--output", type=str, required=True,
162+
help="Output path. Either directory or specific file.")
163+
164+
parser.add_argument("-i", "--input", type=str, default=None, help="Input path to segmentation table.")
165+
parser.add_argument("-j", "--json", type=str, default=None, help="Input JSON dictionary.")
166+
167+
parser.add_argument("--cell_type", type=str, default="sgn",
168+
help="Cell type of segmentation. Either 'sgn' or 'ihc'.")
169+
170+
# options for post-processing
171+
parser.add_argument("--min_size", type=int, default=1000,
172+
help="Minimal number of pixels for filtering small instances.")
173+
parser.add_argument("--min_component_length", type=int, default=50,
174+
help="Minimal length for filtering out connected components.")
175+
parser.add_argument("--max_edge_distance", type=float, default=30,
176+
help="Maximal distance in micrometer between points to create edges for connected components.")
177+
parser.add_argument("-c", "--components", type=str, nargs="+", default=[1], help="List of connected components.")
149178

179+
# options for S3 bucket
180+
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
150181
parser.add_argument("--s3_credentials", type=str, default=None,
151182
help="Input file containing S3 credentials. "
152183
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
@@ -158,8 +189,18 @@ def main():
158189
args = parser.parse_args()
159190

160191
repro_label_components(
161-
args.input, args.output,
162-
args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
192+
output_path=args.output,
193+
table_path=args.input,
194+
ddict=args.json,
195+
cell_type=args.cell_type,
196+
component_list=args.components,
197+
max_edge_distance=args.max_edge_distance,
198+
min_component_length=args.min_component_length,
199+
min_size=args.min_size,
200+
s3=args.s3,
201+
s3_credentials=args.s3_credentials,
202+
s3_bucket_name=args.s3_bucket_name,
203+
s3_service_endpoint=args.s3_service_endpoint,
163204
)
164205

165206

0 commit comments

Comments
 (0)