|
21 | 21 | COLOR_LITERATURE = "#27339C" |
22 | 22 |
|
23 | 23 |
|
24 | | -def plot_legend_suppfig02(save_path): |
25 | | - """Plot common legend for figure 2c. |
26 | | -
|
27 | | - Args: |
28 | | - save_path: save path to save legend. |
29 | | - """ |
30 | | - # Colors |
31 | | - color = [COLOR_P, COLOR_R, COLOR_F, COLOR_T] |
32 | | - label = ["Precision", "Recall", "F1-score", "Processing time"] |
33 | | - |
34 | | - handles = [get_flatline_handle(c) for c in color] |
35 | | - legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False) |
36 | | - export_legend(legend, save_path) |
37 | | - legend.remove() |
38 | | - plt.close() |
39 | | - |
40 | | - |
41 | | -def supp_fig_02(save_path, plot=False, segm="SGN", mode="precision"): |
42 | | - # SGN |
43 | | - value_dict = { |
44 | | - "SGN": { |
45 | | - "stardist": { |
46 | | - "label": "Stardist", |
47 | | - "precision": 0.706, |
48 | | - "recall": 0.630, |
49 | | - "f1-score": 0.628, |
50 | | - "marker": "*", |
51 | | - "runtime": 536.5, |
52 | | - "runtime_std": 148.4 |
53 | | - }, |
54 | | - "micro_sam": { |
55 | | - "label": "µSAM", |
56 | | - "precision": 0.140, |
57 | | - "recall": 0.782, |
58 | | - "f1-score": 0.228, |
59 | | - "marker": "D", |
60 | | - "runtime": 407.5, |
61 | | - "runtime_std": 107.5 |
62 | | - }, |
63 | | - "cellpose_3": { |
64 | | - "label": "Cellpose 3", |
65 | | - "precision": 0.117, |
66 | | - "recall": 0.607, |
67 | | - "f1-score": 0.186, |
68 | | - "marker": "v", |
69 | | - "runtime": 167.9116359, |
70 | | - "runtime_std": 40.2, |
71 | | - }, |
72 | | - "cellpose_sam": { |
73 | | - "label": "Cellpose-SAM", |
74 | | - "precision": 0.250, |
75 | | - "recall": 0.003, |
76 | | - "f1-score": 0.005, |
77 | | - "marker": "^", |
78 | | - "runtime": 2232.007748, |
79 | | - "runtime_std": None, |
80 | | - }, |
81 | | - "spiner2D": { |
82 | | - "label": "Spiner", |
83 | | - "precision": 0.373, |
84 | | - "recall": 0.340, |
85 | | - "f1-score": 0.326, |
86 | | - "marker": "o", |
87 | | - "runtime": None, |
88 | | - "runtime_std": None, |
89 | | - }, |
90 | | - "distance_unet": { |
91 | | - "label": "CochleaNet", |
92 | | - "precision": 0.886, |
93 | | - "recall": 0.804, |
94 | | - "f1-score": 0.837, |
95 | | - "marker": "s", |
96 | | - "runtime": 168.8, |
97 | | - "runtime_std": 21.8 |
98 | | - }, |
99 | | - }, |
100 | | - "IHC": { |
101 | | - "micro_sam": { |
102 | | - "label": "µSAM", |
103 | | - "precision": 0.053, |
104 | | - "recall": 0.684, |
105 | | - "f1-score": 0.094, |
106 | | - "marker": "D", |
107 | | - "runtime": 445.6, |
108 | | - "runtime_std": 106.6 |
109 | | - }, |
110 | | - "cellpose_3": { |
111 | | - "label": "Cellpose 3", |
112 | | - "precision": 0.375, |
113 | | - "recall": 0.554, |
114 | | - "f1-score": 0.329, |
115 | | - "marker": "v", |
116 | | - "runtime": 162.3493934, |
117 | | - "runtime_std": 30.1, |
118 | | - }, |
119 | | - "cellpose_sam": { |
120 | | - "label": "Cellpose-SAM", |
121 | | - "precision": 0.636, |
122 | | - "recall": 0.025, |
123 | | - "f1-score": 0.047, |
124 | | - "marker": "^", |
125 | | - "runtime": 2137.944779, |
126 | | - "runtime_std": None |
127 | | - }, |
128 | | - "distance_unet": { |
129 | | - "label": "CochleaNet", |
130 | | - "precision": 0.693, |
131 | | - "recall": 0.567, |
132 | | - "f1-score": 0.618, |
133 | | - "marker": "s", |
134 | | - "runtime": 69.01, |
135 | | - "runtime_std": None |
136 | | - }, |
137 | | - } |
138 | | - } |
139 | | - |
140 | | - # Convert setting labels to numerical x positions |
141 | | - offset = 0.08 # horizontal shift for scatter separation |
142 | | - |
143 | | - # Plot |
144 | | - tick_rotation = 0 |
145 | | - |
146 | | - main_label_size = 20 |
147 | | - main_tick_size = 16 |
148 | | - marker_size = 200 |
149 | | - |
150 | | - labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()] |
151 | | - |
152 | | - if mode == "precision": |
153 | | - fig, ax = plt.subplots(figsize=(10, 5)) |
154 | | - # Convert setting labels to numerical x positions |
155 | | - offset = 0.08 # horizontal shift for scatter separation |
156 | | - for num, key in enumerate(list(value_dict[segm].keys())): |
157 | | - precision = [value_dict[segm][key]["precision"]] |
158 | | - recall = [value_dict[segm][key]["recall"]] |
159 | | - f1score = [value_dict[segm][key]["f1-score"]] |
160 | | - marker = value_dict[segm][key]["marker"] |
161 | | - x_pos = num + 1 |
162 | | - |
163 | | - plt.scatter([x_pos - offset], precision, label="Precision manual", |
164 | | - color=COLOR_P, marker=marker, s=marker_size) |
165 | | - plt.scatter([x_pos], recall, label="Recall manual", |
166 | | - color=COLOR_R, marker=marker, s=marker_size) |
167 | | - plt.scatter([x_pos + offset], f1score, label="F1-score manual", |
168 | | - color=COLOR_F, marker=marker, s=marker_size) |
169 | | - |
170 | | - # Labels and formatting |
171 | | - x_pos = np.arange(1, len(labels)+1) |
172 | | - plt.xticks(x_pos, labels, fontsize=main_tick_size, rotation=tick_rotation) |
173 | | - plt.yticks(fontsize=main_tick_size) |
174 | | - plt.ylabel("Value", fontsize=main_label_size) |
175 | | - plt.ylim(-0.1, 1) |
176 | | - # plt.legend(loc="lower right", fontsize=legendsize) |
177 | | - plt.grid(axis="y", linestyle="solid", alpha=0.5) |
178 | | - |
179 | | - elif mode == "runtime": |
180 | | - fig, ax = plt.subplots(figsize=(8.5, 5)) |
181 | | - if "Spiner" in labels: |
182 | | - labels.remove("Spiner") |
183 | | - |
184 | | - # Convert setting labels to numerical x positions |
185 | | - offset = 0.08 # horizontal shift for scatter separation |
186 | | - x_pos = 1 |
187 | | - for num, key in enumerate(list(value_dict[segm].keys())): |
188 | | - runtime = [value_dict[segm][key]["runtime"]] |
189 | | - if runtime[0] is None: |
190 | | - continue |
191 | | - marker = value_dict[segm][key]["marker"] |
192 | | - plt.scatter([x_pos], runtime, label="Runtime", color=COLOR_T, marker=marker, s=marker_size) |
193 | | - x_pos = x_pos + 1 |
194 | | - |
195 | | - # Labels and formatting |
196 | | - x_pos = np.arange(1, len(labels)+1) |
197 | | - plt.xticks(x_pos, labels, fontsize=16, rotation=tick_rotation) |
198 | | - plt.yticks(fontsize=main_tick_size) |
199 | | - plt.ylabel("Processing time [s]", fontsize=main_label_size) |
200 | | - plt.ylim(10, 2600) |
201 | | - plt.yscale('log') |
202 | | - # plt.legend(loc="lower right", fontsize=legendsize) |
203 | | - plt.grid(axis="y", linestyle="solid", alpha=0.5) |
204 | | - |
205 | | - else: |
206 | | - raise ValueError("Unsupported mode for plotting.") |
207 | | - |
208 | | - plt.tight_layout() |
209 | | - prism_cleanup_axes(ax) |
210 | | - |
211 | | - if ".png" in save_path: |
212 | | - plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) |
213 | | - else: |
214 | | - plt.savefig(save_path, bbox_inches='tight', pad_inches=0) |
215 | | - |
216 | | - if plot: |
217 | | - plt.show() |
218 | | - else: |
219 | | - plt.close() |
220 | | - |
221 | | - |
222 | | -def plot_legend_fig02c(save_path, plot_mode="shapes"): |
223 | | - """Plot common legend for figure 2c. |
| 24 | +def plot_legend_fig02c( |
| 25 | + save_path: str, |
| 26 | + plot_mode: str = "shapes", |
| 27 | +): |
| 28 | + """Plot common legend for Figure 2c. |
224 | 29 |
|
225 | 30 | Args:. |
226 | 31 | save_path: save path to save legend. |
@@ -253,7 +58,10 @@ def plot_legend_fig02c(save_path, plot_mode="shapes"): |
253 | 58 | raise ValueError("Choose either 'shapes' or 'colors' as plot_mode.") |
254 | 59 |
|
255 | 60 |
|
256 | | -def fig_02c(save_path, plot=False, all_versions=False): |
| 61 | +def fig_02c( |
| 62 | + save_path: str, |
| 63 | + plot: bool = False, |
| 64 | +): |
257 | 65 | """Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual), |
258 | 66 | IHC (distance U-Net, manual), and synapse detection (U-Net). |
259 | 67 | """ |
@@ -336,7 +144,11 @@ def _load_ribbon_synapse_counts(): |
336 | 144 | return syn_counts |
337 | 145 |
|
338 | 146 |
|
339 | | -def fig_02d(save_path, plot=False, plot_average_ribbon_synapses=False): |
| 147 | +def fig_02d( |
| 148 | + save_path: str, |
| 149 | + plot: bool = False, |
| 150 | + plot_average_ribbon_synapses: bool = False, |
| 151 | +): |
340 | 152 | """Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values. |
341 | 153 | """ |
342 | 154 | prism_style() |
@@ -446,31 +258,22 @@ def fig_02d(save_path, plot=False, plot_average_ribbon_synapses=False): |
446 | 258 |
|
447 | 259 |
|
448 | 260 | def main(): |
449 | | - parser = argparse.ArgumentParser(description="Generate plots for Fig 2 of the cochlea paper.") |
| 261 | + parser = argparse.ArgumentParser(description="Generate plots for Figure 2 of the CochleaNet paper.") |
450 | 262 | parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig2") |
451 | 263 | parser.add_argument("--plot", action="store_true") |
452 | 264 | args = parser.parse_args() |
453 | 265 |
|
454 | 266 | os.makedirs(args.figure_dir, exist_ok=True) |
455 | 267 |
|
456 | 268 | # Panel C: Evaluation of the segmentation results: |
457 | | - fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False) |
| 269 | + fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot) |
458 | 270 | plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_shapes.{FILE_EXTENSION}"), plot_mode="shapes") |
459 | 271 | plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_colors.{FILE_EXTENSION}"), plot_mode="colors") |
460 | 272 |
|
461 | 273 | # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC |
462 | 274 | fig_02d(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"), |
463 | 275 | plot=args.plot, plot_average_ribbon_synapses=True) |
464 | 276 |
|
465 | | - # Supplementary Figure 2: Comparing other methods in terms of segmentation accuracy and runtime |
466 | | - plot_legend_suppfig02(save_path=os.path.join(args.figure_dir, f"figsupp_02_legend_colors.{FILE_EXTENSION}")) |
467 | | - supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN") |
468 | | - supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC") |
469 | | - supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn_time.{FILE_EXTENSION}"), |
470 | | - segm="SGN", mode="runtime") |
471 | | - supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc_time.{FILE_EXTENSION}"), |
472 | | - segm="IHC", mode="runtime") |
473 | | - |
474 | 277 |
|
475 | 278 | if __name__ == "__main__": |
476 | 279 | main() |
0 commit comments