Skip to content

Commit be3fb42

Browse files
committed
Update Figures: naming and documentation
1 parent e2421a6 commit be3fb42

File tree

13 files changed

+898
-697
lines changed

13 files changed

+898
-697
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def map_frequency(table: pd.DataFrame, animal: str = "mouse", otof: bool = False
557557
if otof and animal == "mouse":
558558
# freq_min = 4.84 kHz
559559
# freq_max = 78.8 kHz
560-
# Mueller, Hearing Research 202 (2005) 6373, https://doi.org/10.1016/j.heares.2004.08.011
560+
# Mueller, Hearing Research 202 (2005) 63-73, https://doi.org/10.1016/j.heares.2004.08.011
561561
# function has format f(x) = 10 ** (a * (k - (1-x)))
562562
var_a = 100 / 82.5
563563
var_k = 1.565

scripts/baselines/eval_baseline.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import imageio.v3 as imageio
55
import numpy as np
66
import pandas as pd
7+
from glob import glob
8+
from pathlib import Path
79

810
from flamingo_tools.validation import compute_matches_for_annotated_slice
911

@@ -291,12 +293,69 @@ def print_accuracy_ihc():
291293
print_accuracy(os.path.join(seg_dir, baseline))
292294

293295

296+
def runtimes_sgn():
297+
for_comparison = ["distance_unet", "micro-sam", "cellpose3", "cellpose-sam", "stardist"]
298+
299+
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
300+
val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn"
301+
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation"
302+
303+
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
304+
305+
runtimes = {name: [] for name in for_comparison}
306+
307+
for path in image_paths:
308+
eval_fname = Path(path).stem + "_dic.json"
309+
for seg_name in for_comparison:
310+
eval_path = os.path.join(val_sgn_dir, seg_name, eval_fname)
311+
with open(eval_path, "r") as f:
312+
result = json.load(f)
313+
rt = result["time"]
314+
runtimes[seg_name].append(rt)
315+
316+
for name, rts in runtimes.items():
317+
print(name, ":", np.mean(rts), "+-", np.std(rts))
318+
319+
320+
def runtimes_ihc():
321+
for_comparison = ["distance_unet_v3", "micro-sam", "cellpose3", "cellpose-sam"]
322+
323+
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
324+
val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc"
325+
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs"
326+
327+
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
328+
329+
runtimes = {name: [] for name in for_comparison}
330+
331+
for path in image_paths:
332+
eval_fname = Path(path).stem + "_dic.json"
333+
for seg_name in for_comparison:
334+
eval_path = os.path.join(val_sgn_dir, seg_name, eval_fname)
335+
if not os.path.exists(eval_path):
336+
continue
337+
with open(eval_path, "r") as f:
338+
result = json.load(f)
339+
rt = result["time"]
340+
runtimes[seg_name].append(rt)
341+
342+
for name, rts in runtimes.items():
343+
print(name, ":", np.mean(rts), "+-", np.std(rts))
344+
345+
294346
def main():
295347
eval_all_sgn()
296348
eval_all_ihc()
297349
print_accuracy_sgn()
298350
print_accuracy_ihc()
299351

352+
# average runtimes and standard deviation
353+
print("SGNs:")
354+
runtimes_sgn()
355+
print()
356+
print("IHCs:")
357+
runtimes_ihc()
358+
300359

301360
if __name__ == "__main__":
302361
main()

scripts/figures/plot_fig2.py

Lines changed: 16 additions & 213 deletions
Original file line numberDiff line numberDiff line change
@@ -21,206 +21,11 @@
2121
COLOR_LITERATURE = "#27339C"
2222

2323

24-
def plot_legend_suppfig02(save_path):
25-
"""Plot common legend for figure 2c.
26-
27-
Args:
28-
save_path: save path to save legend.
29-
"""
30-
# Colors
31-
color = [COLOR_P, COLOR_R, COLOR_F, COLOR_T]
32-
label = ["Precision", "Recall", "F1-score", "Processing time"]
33-
34-
handles = [get_flatline_handle(c) for c in color]
35-
legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False)
36-
export_legend(legend, save_path)
37-
legend.remove()
38-
plt.close()
39-
40-
41-
def supp_fig_02(save_path, plot=False, segm="SGN", mode="precision"):
42-
# SGN
43-
value_dict = {
44-
"SGN": {
45-
"stardist": {
46-
"label": "Stardist",
47-
"precision": 0.706,
48-
"recall": 0.630,
49-
"f1-score": 0.628,
50-
"marker": "*",
51-
"runtime": 536.5,
52-
"runtime_std": 148.4
53-
},
54-
"micro_sam": {
55-
"label": "µSAM",
56-
"precision": 0.140,
57-
"recall": 0.782,
58-
"f1-score": 0.228,
59-
"marker": "D",
60-
"runtime": 407.5,
61-
"runtime_std": 107.5
62-
},
63-
"cellpose_3": {
64-
"label": "Cellpose 3",
65-
"precision": 0.117,
66-
"recall": 0.607,
67-
"f1-score": 0.186,
68-
"marker": "v",
69-
"runtime": 167.9116359,
70-
"runtime_std": 40.2,
71-
},
72-
"cellpose_sam": {
73-
"label": "Cellpose-SAM",
74-
"precision": 0.250,
75-
"recall": 0.003,
76-
"f1-score": 0.005,
77-
"marker": "^",
78-
"runtime": 2232.007748,
79-
"runtime_std": None,
80-
},
81-
"spiner2D": {
82-
"label": "Spiner",
83-
"precision": 0.373,
84-
"recall": 0.340,
85-
"f1-score": 0.326,
86-
"marker": "o",
87-
"runtime": None,
88-
"runtime_std": None,
89-
},
90-
"distance_unet": {
91-
"label": "CochleaNet",
92-
"precision": 0.886,
93-
"recall": 0.804,
94-
"f1-score": 0.837,
95-
"marker": "s",
96-
"runtime": 168.8,
97-
"runtime_std": 21.8
98-
},
99-
},
100-
"IHC": {
101-
"micro_sam": {
102-
"label": "µSAM",
103-
"precision": 0.053,
104-
"recall": 0.684,
105-
"f1-score": 0.094,
106-
"marker": "D",
107-
"runtime": 445.6,
108-
"runtime_std": 106.6
109-
},
110-
"cellpose_3": {
111-
"label": "Cellpose 3",
112-
"precision": 0.375,
113-
"recall": 0.554,
114-
"f1-score": 0.329,
115-
"marker": "v",
116-
"runtime": 162.3493934,
117-
"runtime_std": 30.1,
118-
},
119-
"cellpose_sam": {
120-
"label": "Cellpose-SAM",
121-
"precision": 0.636,
122-
"recall": 0.025,
123-
"f1-score": 0.047,
124-
"marker": "^",
125-
"runtime": 2137.944779,
126-
"runtime_std": None
127-
},
128-
"distance_unet": {
129-
"label": "CochleaNet",
130-
"precision": 0.693,
131-
"recall": 0.567,
132-
"f1-score": 0.618,
133-
"marker": "s",
134-
"runtime": 69.01,
135-
"runtime_std": None
136-
},
137-
}
138-
}
139-
140-
# Convert setting labels to numerical x positions
141-
offset = 0.08 # horizontal shift for scatter separation
142-
143-
# Plot
144-
tick_rotation = 0
145-
146-
main_label_size = 20
147-
main_tick_size = 16
148-
marker_size = 200
149-
150-
labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()]
151-
152-
if mode == "precision":
153-
fig, ax = plt.subplots(figsize=(10, 5))
154-
# Convert setting labels to numerical x positions
155-
offset = 0.08 # horizontal shift for scatter separation
156-
for num, key in enumerate(list(value_dict[segm].keys())):
157-
precision = [value_dict[segm][key]["precision"]]
158-
recall = [value_dict[segm][key]["recall"]]
159-
f1score = [value_dict[segm][key]["f1-score"]]
160-
marker = value_dict[segm][key]["marker"]
161-
x_pos = num + 1
162-
163-
plt.scatter([x_pos - offset], precision, label="Precision manual",
164-
color=COLOR_P, marker=marker, s=marker_size)
165-
plt.scatter([x_pos], recall, label="Recall manual",
166-
color=COLOR_R, marker=marker, s=marker_size)
167-
plt.scatter([x_pos + offset], f1score, label="F1-score manual",
168-
color=COLOR_F, marker=marker, s=marker_size)
169-
170-
# Labels and formatting
171-
x_pos = np.arange(1, len(labels)+1)
172-
plt.xticks(x_pos, labels, fontsize=main_tick_size, rotation=tick_rotation)
173-
plt.yticks(fontsize=main_tick_size)
174-
plt.ylabel("Value", fontsize=main_label_size)
175-
plt.ylim(-0.1, 1)
176-
# plt.legend(loc="lower right", fontsize=legendsize)
177-
plt.grid(axis="y", linestyle="solid", alpha=0.5)
178-
179-
elif mode == "runtime":
180-
fig, ax = plt.subplots(figsize=(8.5, 5))
181-
if "Spiner" in labels:
182-
labels.remove("Spiner")
183-
184-
# Convert setting labels to numerical x positions
185-
offset = 0.08 # horizontal shift for scatter separation
186-
x_pos = 1
187-
for num, key in enumerate(list(value_dict[segm].keys())):
188-
runtime = [value_dict[segm][key]["runtime"]]
189-
if runtime[0] is None:
190-
continue
191-
marker = value_dict[segm][key]["marker"]
192-
plt.scatter([x_pos], runtime, label="Runtime", color=COLOR_T, marker=marker, s=marker_size)
193-
x_pos = x_pos + 1
194-
195-
# Labels and formatting
196-
x_pos = np.arange(1, len(labels)+1)
197-
plt.xticks(x_pos, labels, fontsize=16, rotation=tick_rotation)
198-
plt.yticks(fontsize=main_tick_size)
199-
plt.ylabel("Processing time [s]", fontsize=main_label_size)
200-
plt.ylim(10, 2600)
201-
plt.yscale('log')
202-
# plt.legend(loc="lower right", fontsize=legendsize)
203-
plt.grid(axis="y", linestyle="solid", alpha=0.5)
204-
205-
else:
206-
raise ValueError("Unsupported mode for plotting.")
207-
208-
plt.tight_layout()
209-
prism_cleanup_axes(ax)
210-
211-
if ".png" in save_path:
212-
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
213-
else:
214-
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
215-
216-
if plot:
217-
plt.show()
218-
else:
219-
plt.close()
220-
221-
222-
def plot_legend_fig02c(save_path, plot_mode="shapes"):
223-
"""Plot common legend for figure 2c.
24+
def plot_legend_fig02c(
25+
save_path: str,
26+
plot_mode: str = "shapes",
27+
):
28+
"""Plot common legend for Figure 2c.
22429
22530
Args:.
22631
save_path: save path to save legend.
@@ -253,7 +58,10 @@ def plot_legend_fig02c(save_path, plot_mode="shapes"):
25358
raise ValueError("Choose either 'shapes' or 'colors' as plot_mode.")
25459

25560

256-
def fig_02c(save_path, plot=False, all_versions=False):
61+
def fig_02c(
62+
save_path: str,
63+
plot: bool = False,
64+
):
25765
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
25866
IHC (distance U-Net, manual), and synapse detection (U-Net).
25967
"""
@@ -336,7 +144,11 @@ def _load_ribbon_synapse_counts():
336144
return syn_counts
337145

338146

339-
def fig_02d(save_path, plot=False, plot_average_ribbon_synapses=False):
147+
def fig_02d(
148+
save_path: str,
149+
plot: bool = False,
150+
plot_average_ribbon_synapses: bool = False,
151+
):
340152
"""Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values.
341153
"""
342154
prism_style()
@@ -446,31 +258,22 @@ def fig_02d(save_path, plot=False, plot_average_ribbon_synapses=False):
446258

447259

448260
def main():
449-
parser = argparse.ArgumentParser(description="Generate plots for Fig 2 of the cochlea paper.")
261+
parser = argparse.ArgumentParser(description="Generate plots for Figure 2 of the CochleaNet paper.")
450262
parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig2")
451263
parser.add_argument("--plot", action="store_true")
452264
args = parser.parse_args()
453265

454266
os.makedirs(args.figure_dir, exist_ok=True)
455267

456268
# Panel C: Evaluation of the segmentation results:
457-
fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False)
269+
fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot)
458270
plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_shapes.{FILE_EXTENSION}"), plot_mode="shapes")
459271
plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_colors.{FILE_EXTENSION}"), plot_mode="colors")
460272

461273
# Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC
462274
fig_02d(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"),
463275
plot=args.plot, plot_average_ribbon_synapses=True)
464276

465-
# Supplementary Figure 2: Comparing other methods in terms of segmentation accuracy and runtime
466-
plot_legend_suppfig02(save_path=os.path.join(args.figure_dir, f"figsupp_02_legend_colors.{FILE_EXTENSION}"))
467-
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN")
468-
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC")
469-
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn_time.{FILE_EXTENSION}"),
470-
segm="SGN", mode="runtime")
471-
supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc_time.{FILE_EXTENSION}"),
472-
segm="IHC", mode="runtime")
473-
474277

475278
if __name__ == "__main__":
476279
main()

0 commit comments

Comments
 (0)