1
1
import argparse
2
- import numpy as np
3
- from os .path import join , basename , exists
2
+ from os .path import join , basename
4
3
from multiprocessing import Pool , cpu_count
4
+ import numpy as np
5
5
6
6
from ctc_metrics .metrics import (
7
7
valid , det , seg , tra , ct , tf , bc , cca , op_ctb , op_csb , bio , op_clb , lnk
@@ -73,7 +73,7 @@ def load_data(
73
73
74
74
# Match golden truth tracking masks to result masks
75
75
traj = {}
76
- is_valid = True
76
+ is_valid = 0
77
77
if trajectory_data :
78
78
traj = match_computed_to_reference_masks (
79
79
ref_tra_masks , comp_masks , threads = threads )
@@ -99,7 +99,6 @@ def calculate_metrics(
99
99
ref_tracks : np .ndarray ,
100
100
traj : dict ,
101
101
segm : dict ,
102
- comp_masks : list ,
103
102
metrics : list = None ,
104
103
is_valid : bool = None ,
105
104
): # pylint: disable=too-complex
@@ -111,7 +110,6 @@ def calculate_metrics(
111
110
ref_tracks: The reference tracks result file.
112
111
traj: The frame-wise trajectory match data.
113
112
segm: The frame-wise segmentation match data.
114
- comp_masks: The computed masks.
115
113
metrics: The metrics to evaluate.
116
114
is_valid: A Flag if the results are valid
117
115
@@ -186,7 +184,7 @@ def calculate_metrics(
186
184
if "BIO(0)" in results and "LNK" in results :
187
185
for i in range (4 ):
188
186
results [f"OP_CLB({ i } )" ] = op_clb (
189
- results [f"LNK" ], results [f"BIO({ i } )" ])
187
+ results [f"LNK( { i } ) " ], results [f"BIO({ i } )" ])
190
188
191
189
return results
192
190
@@ -219,17 +217,17 @@ def evaluate_sequence(
219
217
trajectory_data = True
220
218
segmentation_data = True
221
219
222
- if metrics == ["SEG" ] or metrics == ["CCA" ]:
220
+ if metrics in ( ["SEG" ], ["CCA" ]) :
223
221
trajectory_data = False
224
222
225
223
if "SEG" not in metrics :
226
224
segmentation_data = False
227
225
228
- comp_tracks , ref_tracks , traj , segm , comp_masks , is_valid = load_data (
226
+ comp_tracks , ref_tracks , traj , segm , _ , is_valid = load_data (
229
227
res , gt , trajectory_data , segmentation_data , threads )
230
228
231
229
results = calculate_metrics (
232
- comp_tracks , ref_tracks , traj , segm , comp_masks , metrics , is_valid )
230
+ comp_tracks , ref_tracks , traj , segm , metrics , is_valid )
233
231
234
232
print ("with results: " , results , " done!" )
235
233
0 commit comments