Skip to content

Commit c1c34ef

Browse files
committed
Match subnets in O(n^3) via linear_sum_assignment.
To solve subnets, instead of brute-forcing the n! combinations, use scipy.optimize.linear_sum_assignment (commonly known as the Hungarian/Munkres algorithm, although scipy actually uses a different algorithm) which provides a solution in O(n^3). This is not an original idea; see e.g. https://imagej.net/imagej-wiki-static/TrackMate_Algorithms.html#Solving_LAP Locally, this method solves the "slow" example in adaptive-search.ipynb (`tracks_regular = trackpy.link_df(cg, 0.75)`) in ~50ms instead of 1min. linear_sum_assignment ("lsa") was extensively benchmarked in scipy PR#12541 (which was most about adding a *sparse* variant, but one can just look at the performance of lsa), which shows sub-second runtimes for thousands of inputs. This is also why this PR fully skips the use of MAX_SUB_NET (at most, one may consider setting an alternate value in the thousands for this strategy...). (In fact, it may perhaps(?) be possible to get even better performance via by completely dropping the subnet paradigm and just passing the whole distance matrix (sparsified by removing edges greater than the search range) to scipy.sparse.csgraph.min_weight_full_bipartite_matching, but 1) one needs to figure out how to handle unmatched points (maybe by adding "dummy" points with a high but finite link cost), and 2) this would require more massive code surgery anyways.) The docs will need to be updated, but this PR is just to discuss the implementation.
1 parent 42e5315 commit c1c34ef

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

trackpy/linking/linking.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@
1111
coords_from_df, coords_from_df_iter,
1212
SubnetOversizeException)
1313
from .subnet import HashBTree, HashKDTree, Subnets, split_subnet
14-
from .subnetlinker import (subnet_linker_recursive, subnet_linker_drop,
15-
subnet_linker_numba, subnet_linker_nonrecursive)
14+
from .subnetlinker import (
15+
subnet_linker_lsa,
16+
subnet_linker_recursive,
17+
subnet_linker_drop,
18+
subnet_linker_numba,
19+
subnet_linker_nonrecursive,
20+
)
1621

1722
logger = logging.getLogger(__name__)
1823

@@ -50,7 +55,7 @@ def link_iter(coords_iter, search_range, **kwargs):
5055
Reduce search_range by multiplying it by this factor.
5156
neighbor_strategy : {'KDTree', 'BTree'}
5257
algorithm used to identify nearby features. Default 'KDTree'.
53-
link_strategy : {'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'}
58+
link_strategy : {'lsa', 'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'}
5459
algorithm used to resolve subnetworks of nearby particles
5560
'auto' uses hybrid (numba+recursive) if available
5661
'drop' causes particles in subnetworks to go unlinked
@@ -143,7 +148,7 @@ def link(f, search_range, pos_columns=None, t_column='frame', **kwargs):
143148
Reduce search_range by multiplying it by this factor.
144149
neighbor_strategy : {'KDTree', 'BTree'}
145150
algorithm used to identify nearby features. Default 'KDTree'.
146-
link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
151+
link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
147152
algorithm used to resolve subnetworks of nearby particles
148153
'auto' uses hybrid (numba+recursive) if available
149154
'drop' causes particles in subnetworks to go unlinked
@@ -241,7 +246,7 @@ def link_df_iter(f_iter, search_range, pos_columns=None,
241246
Reduce search_range by multiplying it by this factor.
242247
neighbor_strategy : {'KDTree', 'BTree'}
243248
algorithm used to identify nearby features. Default 'KDTree'.
244-
link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
249+
link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
245250
algorithm used to resolve subnetworks of nearby particles
246251
'auto' uses hybrid (numba+recursive) if available
247252
'drop' causes particles in subnetworks to go unlinked
@@ -391,9 +396,11 @@ def __init__(self, search_range, memory=0, predictor=None,
391396
if NUMBA_AVAILABLE:
392397
link_strategy = 'hybrid'
393398
else:
394-
link_strategy = 'recursive'
399+
link_strategy = 'lsa'
395400

396-
if link_strategy == 'recursive':
401+
if link_strategy == 'lsa':
402+
subnet_linker = subnet_linker_lsa
403+
elif link_strategy == 'recursive':
397404
subnet_linker = subnet_linker_recursive
398405
elif link_strategy == 'hybrid':
399406
subnet_linker = subnet_linker_numba

trackpy/linking/subnetlinker.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,41 @@
77
from collections import deque
88

99
import numpy as np
10+
from scipy import optimize
1011

1112
from .utils import SubnetOversizeException
1213
from ..try_numba import try_numba_jit
1314

1415

16+
def subnet_linker_lsa(source_set, dest_set, search_range, max_size=None):
17+
src = [*source_set]
18+
dst = [*dest_set]
19+
dst_uuid2idx = {d.uuid: i for i, d in enumerate(dst)}
20+
# "Too-far" pairs are actually assigned a distance of search_range; see test_penalty.
21+
inf = search_range ** 2
22+
d2s = np.full((len(src), len(dst)), inf)
23+
for i, s in enumerate(src):
24+
for d, dist in s.forward_cands:
25+
d2s[i, dst_uuid2idx[d.uuid]] = dist ** 2
26+
src_idxs, dst_idxs = optimize.linear_sum_assignment(d2s)
27+
keep = d2s[src_idxs, dst_idxs] < inf # Other pairs were actually too far.
28+
src_idxs = src_idxs[keep].tolist()
29+
dst_idxs = dst_idxs[keep].tolist()
30+
lost_src = sorted({*range(len(src))} - {*src_idxs})
31+
lost_dst = sorted({*range(len(dst))} - {*dst_idxs})
32+
sn_spl = [
33+
*[src[i] for i in src_idxs],
34+
*[src[i] for i in lost_src],
35+
*[None] * len(lost_dst),
36+
]
37+
sn_dpl = [
38+
*[dst[i] for i in dst_idxs],
39+
*[None] * len(lost_src),
40+
*[dst[i] for i in lost_dst],
41+
]
42+
return sn_spl, sn_dpl
43+
44+
1545
def recursive_linker_obj(s_sn, dest_size, search_range, max_size=30, diag=False):
1646
snl = SubnetLinker(s_sn, dest_size, search_range, max_size=max_size)
1747
# In Python 3, we must convert to lists to return mutable collections.

trackpy/tests/test_linking.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,13 @@ def to_eucl(arr):
304304
def test_oversize_fail(self):
305305
with self.assertRaises(SubnetOversizeException):
306306
df = contracting_grid()
307-
self.link(df, search_range=2)
307+
self.link(df, search_range=2, link_strategy='recursive')
308308

309309
def test_adaptive_fail(self):
310310
"""Check recursion limit"""
311311
with self.assertRaises(SubnetOversizeException):
312-
self.link(contracting_grid(), search_range=2, adaptive_stop=1.84)
312+
self.link(contracting_grid(), search_range=2, adaptive_stop=1.84,
313+
link_strategy='recursive')
313314

314315
def link(self, f, search_range, *args, **kwargs):
315316
kwargs = dict(self.linker_opts, **kwargs)

trackpy/tests/test_predict.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ def test_fail_predict(self):
102102
def test_subnet_fail(self):
103103
with self.assertRaises(trackpy.SubnetOversizeException):
104104
Nside = Nside_oversize
105-
ll = self.get_linked_lengths_from_iterfunc((mkframe(0, Nside),
106-
mkframe(25, Nside),
107-
mkframe(75, Nside)),
108-
self.get_unwrapped_linker(), 100)
105+
ll = self.get_linked_lengths_from_iterfunc(
106+
(mkframe(0, Nside), mkframe(25, Nside), mkframe(75, Nside)),
107+
functools.partial(self.get_unwrapped_linker(), link_strategy='recursive'),
108+
100)
109109

110110

111111
class BaselinePredictIterTests(LinkIterWithPrediction, BaselinePredictTests, StrictTestCase):

0 commit comments

Comments
 (0)