Skip to content

Commit a2b8022

Browse files
committed
add ctc metrics for track comparision
1 parent 251c862 commit a2b8022

2 files changed

Lines changed: 88 additions & 4 deletions

File tree

src/napatrackmater/Trackcomparator.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import os
77
import numpy as np
88
import pandas as pd
9+
from typing import List, Tuple
910
from scipy.spatial import cKDTree
1011
from scipy.optimize import linear_sum_assignment
1112
from typing import Union, Dict
1213
from .Trackmate import TrackMate
14+
1315
import concurrent.futures
1416
from tqdm import tqdm
1517

@@ -44,7 +46,7 @@ def _track_cloud(self, df: pd.DataFrame) -> Dict[int, np.ndarray]:
4446
"""Group spots by unique_id and return dict of track_id -> Nx3 array."""
4547
return {tid: grp[['z','y','x']].values for tid, grp in df.groupby('unique_id')}
4648

47-
def evaluate(self, threshold: float) -> Dict[str, object]:
49+
def evaluate(self, threshold: float, compute_bci: bool = False) -> Dict[str, object]:
4850
"""
4951
Perform optimal assignment between GT and predicted tracks.
5052
Returns dict with:
@@ -86,8 +88,11 @@ def compute_row(item):
8688
# metrics
8789
cca = self.cca_metric()
8890
ct = self.ct_metric(assignments)
91+
bci = self.bci_metric(assignments) if compute_bci else None
92+
8993
return {'assignments':assignments,'num_hits':num_hits,
90-
'num_gt':num_gt,'num_pred':num_pred,'cca':cca,'ct':ct}
94+
'num_gt':num_gt,'num_pred':num_pred,'cca':cca,'ct':ct,
95+
'bci': bci}
9196

9297
def cca_metric(self) -> float:
9398
"""Cell Cycle Accuracy: CDF distance of track-length histograms."""
@@ -124,3 +129,82 @@ def ct_metric(self, assignments: pd.DataFrame) -> float:
124129
gt_id,pr_id = r['gt_track'], r['pred_track']
125130
if spans['gt'][gt_id]==spans['pred'][pr_id]: T_rc+=1
126131
return float(2*T_rc/(T_r+T_c)) if (T_r+T_c)>0 else np.nan
132+
133+
def _get_mitosis_events(self, tm: TrackMate) -> List[Tuple[int,int]]:
134+
"""
135+
Walk every dividing track in tm and return a list of
136+
(track_id, split_frame) for *every* split event.
137+
"""
138+
events = []
139+
for trk in tm.DividingTrackIds:
140+
if trk is None or trk == tm.TrackidBox:
141+
continue
142+
for spot in tm.all_current_cell_ids[int(trk)]:
143+
children = tm.edge_target_lookup.get(spot, [])
144+
# split if >1 children
145+
if isinstance(children, list) and len(children) > 1:
146+
t_split = tm.unique_spot_properties[spot][tm.frameid_key]
147+
events.append((int(trk), int(t_split)))
148+
# no break → capture multiple splits per track
149+
return events
150+
151+
def bci_metric(self,
152+
assignments: pd.DataFrame,
153+
tol: int = 1
154+
) -> float:
155+
"""
156+
Branching Correctness Index (F1) of mitotic events.
157+
Matches every GT (track,time) split to its assigned pred track within ±tol frames.
158+
"""
159+
gt_events = self._get_mitosis_events(self.gt)
160+
pred_events = self._get_mitosis_events(self.pred)
161+
162+
# map GT → assigned pred
163+
assign_map = {
164+
int(r.gt_track): int(r.pred_track)
165+
for r in assignments.itertuples(index=False)
166+
if r.matched
167+
}
168+
169+
tp = fp = fn = 0
170+
171+
# true / false negatives
172+
for gt_tid, t_gt in gt_events:
173+
pr_tid = assign_map.get(gt_tid)
174+
if pr_tid is None:
175+
fn += 1
176+
else:
177+
# did that pred track also split near the same time?
178+
if any(pr == pr_tid and abs(t_pr - t_gt) <= tol
179+
for pr, t_pr in pred_events):
180+
tp += 1
181+
else:
182+
fn += 1
183+
184+
# false positives: pred splits that didn’t match any GT
185+
matched_preds = set(assign_map.values())
186+
for pr_tid, t_pr in pred_events:
187+
# if this pred track wasn’t assigned at all → FP
188+
if pr_tid not in matched_preds:
189+
fp += 1
190+
else:
191+
# if its matched GT track never split at this time → FP
192+
# find all GT tracks that map to this pred
193+
gt_tids = [gt for gt, pr in assign_map.items() if pr == pr_tid]
194+
# for any such GT, is there a GT split within tol of t_pr?
195+
matched_any = False
196+
for candidate_gt in gt_tids:
197+
for gt2, t_gt2 in gt_events:
198+
if gt2 == candidate_gt and abs(t_pr - t_gt2) <= tol:
199+
matched_any = True
200+
break
201+
if matched_any:
202+
break
203+
if not matched_any:
204+
fp += 1
205+
206+
# finally F1
207+
precision = tp / max(tp + fp, 1)
208+
recall = tp / max(tp + fn, 1)
209+
return 2 * (precision * recall) / max(precision + recall, 1e-4)
210+

src/napatrackmater/_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = version = "5.8.0"
2-
__version_tuple__ = version_tuple = (5, 8.0)
1+
__version__ = version = "5.8.1"
2+
__version_tuple__ = version_tuple = (5, 8, 1)

0 commit comments

Comments
 (0)