Skip to content

Commit 573f00d

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents ad585fb + 222a257 commit 573f00d

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

ctc_metrics/metrics/validation/valid.py

+21
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,25 @@ def no_empty_frames(
170170
return int(is_valid)
171171

172172

173+
def no_empty_tracking_result(
174+
tracks: np.ndarray
175+
):
176+
"""
177+
Checks if there is at least one detection in th results.
178+
179+
Args:
180+
tracks: The tracks to inspect
181+
182+
Returns:
183+
1 if there are detections, 0 otherwise.
184+
"""
185+
is_valid = 1
186+
if len(tracks) == 0:
187+
warnings.warn("No tracks in result.", UserWarning)
188+
is_valid = 0
189+
return is_valid
190+
191+
173192
def valid(
174193
masks: list,
175194
tracks: np.ndarray,
@@ -195,6 +214,8 @@ def valid(
195214
196215
"""
197216
is_valid = 1
217+
# If tracks is empty, the result is invalid
218+
is_valid = no_empty_tracking_result(tracks)
198219
# Get the labels in each frame
199220
num_frames = max(tracks[:, 2].max() + 1, len(masks))
200221
frames = [[] for _ in range(num_frames)]

ctc_metrics/scripts/evaluate.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ def calculate_metrics(
151151
Returns:
152152
The results stored in a dictionary.
153153
"""
154+
# Check if results are valid
155+
results = {x: None for x in metrics}
156+
if not is_valid:
157+
print("Invalid results!")
158+
results["Valid"] = 0
159+
return results
160+
154161
# Create merge tracks
155162
if traj:
156163
new_tracks, new_labels, new_mapped = merge_tracks(
@@ -175,12 +182,6 @@ def calculate_metrics(
175182
)
176183

177184
# Calculate metrics
178-
results = {x: None for x in metrics}
179-
if not is_valid:
180-
print("Invalid results!")
181-
results["Valid"] = 0
182-
return results
183-
184185
if "Valid" in metrics:
185186
results["Valid"] = is_valid
186187

ctc_metrics/utils/filesystem.py

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def read_tracking_file(
7373
return None
7474
with open(path, "r", encoding="utf-8") as f:
7575
lines = f.readlines()
76+
if len(lines) == 0:
77+
return np.zeros((0, 4))
7678
seperator = " " if " " in lines[0] else "\t"
7779
lines = [x.strip().split(seperator) for x in lines]
7880
lines = [[int(y) for y in x if y != ""] for x in lines]

0 commit comments

Comments
 (0)