diff --git a/trackpy/linking/linking.py b/trackpy/linking/linking.py index 622da90d..609f47c3 100644 --- a/trackpy/linking/linking.py +++ b/trackpy/linking/linking.py @@ -11,8 +11,13 @@ coords_from_df, coords_from_df_iter, SubnetOversizeException) from .subnet import HashBTree, HashKDTree, Subnets, split_subnet -from .subnetlinker import (subnet_linker_recursive, subnet_linker_drop, - subnet_linker_numba, subnet_linker_nonrecursive) +from .subnetlinker import ( + subnet_linker_lsa, + subnet_linker_recursive, + subnet_linker_drop, + subnet_linker_numba, + subnet_linker_nonrecursive, +) logger = logging.getLogger(__name__) @@ -50,7 +55,7 @@ def link_iter(coords_iter, search_range, **kwargs): Reduce search_range by multiplying it by this factor. neighbor_strategy : {'KDTree', 'BTree'} algorithm used to identify nearby features. Default 'KDTree'. - link_strategy : {'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'} + link_strategy : {'lsa', 'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'} algorithm used to resolve subnetworks of nearby particles 'auto' uses hybrid (numba+recursive) if available 'drop' causes particles in subnetworks to go unlinked @@ -143,7 +148,7 @@ def link(f, search_range, pos_columns=None, t_column='frame', **kwargs): Reduce search_range by multiplying it by this factor. neighbor_strategy : {'KDTree', 'BTree'} algorithm used to identify nearby features. Default 'KDTree'. - link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'} + link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'} algorithm used to resolve subnetworks of nearby particles 'auto' uses hybrid (numba+recursive) if available 'drop' causes particles in subnetworks to go unlinked @@ -241,7 +246,7 @@ def link_df_iter(f_iter, search_range, pos_columns=None, Reduce search_range by multiplying it by this factor. neighbor_strategy : {'KDTree', 'BTree'} algorithm used to identify nearby features. Default 'KDTree'. - link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'} + link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'} algorithm used to resolve subnetworks of nearby particles 'auto' uses hybrid (numba+recursive) if available 'drop' causes particles in subnetworks to go unlinked @@ -391,9 +396,11 @@ def __init__(self, search_range, memory=0, predictor=None, if NUMBA_AVAILABLE: link_strategy = 'hybrid' else: - link_strategy = 'recursive' + link_strategy = 'lsa' - if link_strategy == 'recursive': + if link_strategy == 'lsa': + subnet_linker = subnet_linker_lsa + elif link_strategy == 'recursive': subnet_linker = subnet_linker_recursive elif link_strategy == 'hybrid': subnet_linker = subnet_linker_numba diff --git a/trackpy/linking/subnetlinker.py b/trackpy/linking/subnetlinker.py index b7b1d3be..782767fd 100644 --- a/trackpy/linking/subnetlinker.py +++ b/trackpy/linking/subnetlinker.py @@ -7,11 +7,41 @@ from collections import deque import numpy as np +from scipy import optimize from .utils import SubnetOversizeException from ..try_numba import try_numba_jit +def subnet_linker_lsa(source_set, dest_set, search_range, max_size=None): + src = [*source_set] + dst = [*dest_set] + dst_uuid2idx = {d.uuid: i for i, d in enumerate(dst)} + # "Too-far" pairs are actually assigned a distance of search_range; see test_penalty. + inf = search_range ** 2 + d2s = np.full((len(src), len(dst)), inf) + for i, s in enumerate(src): + for d, dist in s.forward_cands: + d2s[i, dst_uuid2idx[d.uuid]] = dist ** 2 + src_idxs, dst_idxs = optimize.linear_sum_assignment(d2s) + keep = d2s[src_idxs, dst_idxs] < inf # Other pairs were actually too far. + src_idxs = src_idxs[keep].tolist() + dst_idxs = dst_idxs[keep].tolist() + lost_src = sorted({*range(len(src))} - {*src_idxs}) + lost_dst = sorted({*range(len(dst))} - {*dst_idxs}) + sn_spl = [ + *[src[i] for i in src_idxs], + *[src[i] for i in lost_src], + *[None] * len(lost_dst), + ] + sn_dpl = [ + *[dst[i] for i in dst_idxs], + *[None] * len(lost_src), + *[dst[i] for i in lost_dst], + ] + return sn_spl, sn_dpl + + def recursive_linker_obj(s_sn, dest_size, search_range, max_size=30, diag=False): snl = SubnetLinker(s_sn, dest_size, search_range, max_size=max_size) # In Python 3, we must convert to lists to return mutable collections. diff --git a/trackpy/tests/test_linking.py b/trackpy/tests/test_linking.py index 4ff5fe03..9a55d410 100644 --- a/trackpy/tests/test_linking.py +++ b/trackpy/tests/test_linking.py @@ -304,12 +304,13 @@ def to_eucl(arr): def test_oversize_fail(self): with self.assertRaises(SubnetOversizeException): df = contracting_grid() - self.link(df, search_range=2) + self.link(df, search_range=2, link_strategy='recursive') def test_adaptive_fail(self): """Check recursion limit""" with self.assertRaises(SubnetOversizeException): - self.link(contracting_grid(), search_range=2, adaptive_stop=1.84) + self.link(contracting_grid(), search_range=2, adaptive_stop=1.84, + link_strategy='recursive') def link(self, f, search_range, *args, **kwargs): kwargs = dict(self.linker_opts, **kwargs) diff --git a/trackpy/tests/test_predict.py b/trackpy/tests/test_predict.py index 9f118b2b..722d230b 100644 --- a/trackpy/tests/test_predict.py +++ b/trackpy/tests/test_predict.py @@ -102,10 +102,10 @@ def test_fail_predict(self): def test_subnet_fail(self): with self.assertRaises(trackpy.SubnetOversizeException): Nside = Nside_oversize - ll = self.get_linked_lengths_from_iterfunc((mkframe(0, Nside), - mkframe(25, Nside), - mkframe(75, Nside)), - self.get_unwrapped_linker(), 100) + ll = self.get_linked_lengths_from_iterfunc( + (mkframe(0, Nside), mkframe(25, Nside), mkframe(75, Nside)), + functools.partial(self.get_unwrapped_linker(), link_strategy='recursive'), + 100) class BaselinePredictIterTests(LinkIterWithPrediction, BaselinePredictTests, StrictTestCase):