|
6 | 6 | import os |
7 | 7 | import numpy as np |
8 | 8 | import pandas as pd |
| 9 | +from typing import List, Tuple |
9 | 10 | from scipy.spatial import cKDTree |
10 | 11 | from scipy.optimize import linear_sum_assignment |
11 | 12 | from typing import Union, Dict |
12 | 13 | from .Trackmate import TrackMate |
| 14 | + |
13 | 15 | import concurrent.futures |
14 | 16 | from tqdm import tqdm |
15 | 17 |
|
@@ -44,7 +46,7 @@ def _track_cloud(self, df: pd.DataFrame) -> Dict[int, np.ndarray]: |
44 | 46 | """Group spots by unique_id and return dict of track_id -> Nx3 array.""" |
45 | 47 | return {tid: grp[['z','y','x']].values for tid, grp in df.groupby('unique_id')} |
46 | 48 |
|
47 | | - def evaluate(self, threshold: float) -> Dict[str, object]: |
| 49 | + def evaluate(self, threshold: float, compute_bci: bool = False) -> Dict[str, object]: |
48 | 50 | """ |
49 | 51 | Perform optimal assignment between GT and predicted tracks. |
50 | 52 | Returns dict with: |
@@ -86,8 +88,11 @@ def compute_row(item): |
86 | 88 | # metrics |
87 | 89 | cca = self.cca_metric() |
88 | 90 | ct = self.ct_metric(assignments) |
| 91 | + bci = self.bci_metric(assignments) if compute_bci else None |
| 92 | + |
89 | 93 | return {'assignments':assignments,'num_hits':num_hits, |
90 | | - 'num_gt':num_gt,'num_pred':num_pred,'cca':cca,'ct':ct} |
| 94 | + 'num_gt':num_gt,'num_pred':num_pred,'cca':cca,'ct':ct, |
| 95 | + 'bci': bci} |
91 | 96 |
|
92 | 97 | def cca_metric(self) -> float: |
93 | 98 | """Cell Cycle Accuracy: CDF distance of track-length histograms.""" |
@@ -124,3 +129,82 @@ def ct_metric(self, assignments: pd.DataFrame) -> float: |
124 | 129 | gt_id,pr_id = r['gt_track'], r['pred_track'] |
125 | 130 | if spans['gt'][gt_id]==spans['pred'][pr_id]: T_rc+=1 |
126 | 131 | return float(2*T_rc/(T_r+T_c)) if (T_r+T_c)>0 else np.nan |
| 132 | + |
| 133 | + def _get_mitosis_events(self, tm: TrackMate) -> List[Tuple[int,int]]: |
| 134 | + """ |
| 135 | + Walk every dividing track in tm and return a list of |
| 136 | + (track_id, split_frame) for *every* split event. |
| 137 | + """ |
| 138 | + events = [] |
| 139 | + for trk in tm.DividingTrackIds: |
| 140 | + if trk is None or trk == tm.TrackidBox: |
| 141 | + continue |
| 142 | + for spot in tm.all_current_cell_ids[int(trk)]: |
| 143 | + children = tm.edge_target_lookup.get(spot, []) |
| 144 | + # split if >1 children |
| 145 | + if isinstance(children, list) and len(children) > 1: |
| 146 | + t_split = tm.unique_spot_properties[spot][tm.frameid_key] |
| 147 | + events.append((int(trk), int(t_split))) |
| 148 | + # no break → capture multiple splits per track |
| 149 | + return events |
| 150 | + |
| 151 | + def bci_metric(self, |
| 152 | + assignments: pd.DataFrame, |
| 153 | + tol: int = 1 |
| 154 | + ) -> float: |
| 155 | + """ |
| 156 | + Branching Correctness Index (F1) of mitotic events. |
| 157 | + Matches every GT (track,time) split to its assigned pred track within ±tol frames. |
| 158 | + """ |
| 159 | + gt_events = self._get_mitosis_events(self.gt) |
| 160 | + pred_events = self._get_mitosis_events(self.pred) |
| 161 | + |
| 162 | + # map GT → assigned pred |
| 163 | + assign_map = { |
| 164 | + int(r.gt_track): int(r.pred_track) |
| 165 | + for r in assignments.itertuples(index=False) |
| 166 | + if r.matched |
| 167 | + } |
| 168 | + |
| 169 | + tp = fp = fn = 0 |
| 170 | + |
| 171 | + # true / false negatives |
| 172 | + for gt_tid, t_gt in gt_events: |
| 173 | + pr_tid = assign_map.get(gt_tid) |
| 174 | + if pr_tid is None: |
| 175 | + fn += 1 |
| 176 | + else: |
| 177 | + # did that pred track also split near the same time? |
| 178 | + if any(pr == pr_tid and abs(t_pr - t_gt) <= tol |
| 179 | + for pr, t_pr in pred_events): |
| 180 | + tp += 1 |
| 181 | + else: |
| 182 | + fn += 1 |
| 183 | + |
| 184 | + # false positives: pred splits that didn’t match any GT |
| 185 | + matched_preds = set(assign_map.values()) |
| 186 | + for pr_tid, t_pr in pred_events: |
| 187 | + # if this pred track wasn’t assigned at all → FP |
| 188 | + if pr_tid not in matched_preds: |
| 189 | + fp += 1 |
| 190 | + else: |
| 191 | + # if its matched GT track never split at this time → FP |
| 192 | + # find all GT tracks that map to this pred |
| 193 | + gt_tids = [gt for gt, pr in assign_map.items() if pr == pr_tid] |
| 194 | + # for any such GT, is there a GT split within tol of t_pr? |
| 195 | + matched_any = False |
| 196 | + for candidate_gt in gt_tids: |
| 197 | + for gt2, t_gt2 in gt_events: |
| 198 | + if gt2 == candidate_gt and abs(t_pr - t_gt2) <= tol: |
| 199 | + matched_any = True |
| 200 | + break |
| 201 | + if matched_any: |
| 202 | + break |
| 203 | + if not matched_any: |
| 204 | + fp += 1 |
| 205 | + |
| 206 | + # finally F1 |
| 207 | + precision = tp / max(tp + fp, 1) |
| 208 | + recall = tp / max(tp + fn, 1) |
| 209 | + return 2 * (precision * recall) / max(precision + recall, 1e-4) |
| 210 | + |
0 commit comments