21 changes: 14 additions & 7 deletions trackpy/linking/linking.py
@@ -11,8 +11,13 @@
coords_from_df, coords_from_df_iter,
SubnetOversizeException)
from .subnet import HashBTree, HashKDTree, Subnets, split_subnet
-from .subnetlinker import (subnet_linker_recursive, subnet_linker_drop,
-                           subnet_linker_numba, subnet_linker_nonrecursive)
+from .subnetlinker import (
+    subnet_linker_lsa,
+    subnet_linker_recursive,
+    subnet_linker_drop,
+    subnet_linker_numba,
+    subnet_linker_nonrecursive,
+)

logger = logging.getLogger(__name__)

@@ -50,7 +55,7 @@ def link_iter(coords_iter, search_range, **kwargs):
Reduce search_range by multiplying it by this factor.
neighbor_strategy : {'KDTree', 'BTree'}
algorithm used to identify nearby features. Default 'KDTree'.
-link_strategy : {'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'}
+link_strategy : {'lsa', 'recursive', 'nonrecursive', 'hybrid', 'numba', 'drop', 'auto'}
algorithm used to resolve subnetworks of nearby particles
'auto' uses hybrid (numba+recursive) if available
'drop' causes particles in subnetworks to go unlinked
@@ -143,7 +148,7 @@ def link(f, search_range, pos_columns=None, t_column='frame', **kwargs):
Reduce search_range by multiplying it by this factor.
neighbor_strategy : {'KDTree', 'BTree'}
algorithm used to identify nearby features. Default 'KDTree'.
-link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
+link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
algorithm used to resolve subnetworks of nearby particles
'auto' uses hybrid (numba+recursive) if available
'drop' causes particles in subnetworks to go unlinked
@@ -241,7 +246,7 @@ def link_df_iter(f_iter, search_range, pos_columns=None,
Reduce search_range by multiplying it by this factor.
neighbor_strategy : {'KDTree', 'BTree'}
algorithm used to identify nearby features. Default 'KDTree'.
-link_strategy : {'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
+link_strategy : {'lsa', 'recursive', 'nonrecursive', 'numba', 'hybrid', 'drop', 'auto'}
algorithm used to resolve subnetworks of nearby particles
'auto' uses hybrid (numba+recursive) if available
'drop' causes particles in subnetworks to go unlinked
@@ -391,9 +396,11 @@ def __init__(self, search_range, memory=0, predictor=None,
             if NUMBA_AVAILABLE:
                 link_strategy = 'hybrid'
             else:
-                link_strategy = 'recursive'
+                link_strategy = 'lsa'

-        if link_strategy == 'recursive':
+        if link_strategy == 'lsa':
+            subnet_linker = subnet_linker_lsa
+        elif link_strategy == 'recursive':
             subnet_linker = subnet_linker_recursive
         elif link_strategy == 'hybrid':
             subnet_linker = subnet_linker_numba
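For context, a minimal usage sketch of the new option (not part of the diff; it assumes this branch is installed and a DataFrame with trackpy's usual x, y, frame columns):

import pandas as pd
import trackpy as tp

# Two frames of three nearby particles, close enough to form a single subnet.
f = pd.DataFrame({
    'x':     [0.0, 1.0, 2.0, 0.2, 1.1, 2.3],
    'y':     [0.0, 0.0, 0.0, 0.1, 0.0, 0.1],
    'frame': [0,   0,   0,   1,   1,   1],
})

# Select the globally optimal (linear sum assignment) subnet solver added in this PR.
linked = tp.link(f, search_range=1.0, link_strategy='lsa')
print(linked.sort_values(['particle', 'frame']))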
30 changes: 30 additions & 0 deletions trackpy/linking/subnetlinker.py
@@ -7,11 +7,41 @@
from collections import deque

import numpy as np
+from scipy import optimize

from .utils import SubnetOversizeException
from ..try_numba import try_numba_jit


+def subnet_linker_lsa(source_set, dest_set, search_range, max_size=None):
+    src = [*source_set]
+    dst = [*dest_set]
+    dst_uuid2idx = {d.uuid: i for i, d in enumerate(dst)}
+    # "Too-far" pairs are actually assigned a distance of search_range; see test_penalty.
+    inf = search_range ** 2
+    d2s = np.full((len(src), len(dst)), inf)
+    for i, s in enumerate(src):
+        for d, dist in s.forward_cands:
+            d2s[i, dst_uuid2idx[d.uuid]] = dist ** 2
+    src_idxs, dst_idxs = optimize.linear_sum_assignment(d2s)
+    keep = d2s[src_idxs, dst_idxs] < inf  # Other pairs were actually too far.
+    src_idxs = src_idxs[keep].tolist()
+    dst_idxs = dst_idxs[keep].tolist()
+    lost_src = sorted({*range(len(src))} - {*src_idxs})
+    lost_dst = sorted({*range(len(dst))} - {*dst_idxs})
+    sn_spl = [
+        *[src[i] for i in src_idxs],
+        *[src[i] for i in lost_src],
+        *[None] * len(lost_dst),
+    ]
+    sn_dpl = [
+        *[dst[i] for i in dst_idxs],
+        *[None] * len(lost_src),
+        *[dst[i] for i in lost_dst],
+    ]
+    return sn_spl, sn_dpl


def recursive_linker_obj(s_sn, dest_size, search_range, max_size=30, diag=False):
snl = SubnetLinker(s_sn, dest_size, search_range, max_size=max_size)
# In Python 3, we must convert to lists to return mutable collections.
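As background on the mechanism (a toy illustration with made-up numbers, not trackpy's data structures): scipy.optimize.linear_sum_assignment returns the row/column pairing with the smallest total cost, and because pairs without a forward candidate are padded with search_range ** 2, the filter cost < search_range ** 2, as in the function above, discards matches that exist only because of the padding.

import numpy as np
from scipy.optimize import linear_sum_assignment

search_range = 1.0
too_far = search_range ** 2            # padding cost for pairs with no candidate link

# 2 source particles x 3 destination candidates, squared displacements.
cost = np.full((2, 3), too_far)
cost[0, 0] = 0.04                      # source 0 is 0.2 away from dest 0
cost[0, 1] = 0.25                      # ... and 0.5 away from dest 1
cost[1, 1] = 0.09                      # source 1 is 0.3 away from dest 1

rows, cols = linear_sum_assignment(cost)
keep = cost[rows, cols] < too_far      # drop pairings that hit the padding value
matched = list(zip(rows[keep].tolist(), cols[keep].tolist()))
print(matched)                         # [(0, 0), (1, 1)]; dest 2 is left unlinked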
5 changes: 3 additions & 2 deletions trackpy/tests/test_linking.py
@@ -304,12 +304,13 @@ def to_eucl(arr):
def test_oversize_fail(self):
with self.assertRaises(SubnetOversizeException):
df = contracting_grid()
-self.link(df, search_range=2)
+self.link(df, search_range=2, link_strategy='recursive')

def test_adaptive_fail(self):
"""Check recursion limit"""
with self.assertRaises(SubnetOversizeException):
-self.link(contracting_grid(), search_range=2, adaptive_stop=1.84)
+self.link(contracting_grid(), search_range=2, adaptive_stop=1.84,
+          link_strategy='recursive')

def link(self, f, search_range, *args, **kwargs):
kwargs = dict(self.linker_opts, **kwargs)
8 changes: 4 additions & 4 deletions trackpy/tests/test_predict.py
@@ -102,10 +102,10 @@ def test_fail_predict(self):
def test_subnet_fail(self):
with self.assertRaises(trackpy.SubnetOversizeException):
Nside = Nside_oversize
-ll = self.get_linked_lengths_from_iterfunc((mkframe(0, Nside),
-                                            mkframe(25, Nside),
-                                            mkframe(75, Nside)),
-                                           self.get_unwrapped_linker(), 100)
+ll = self.get_linked_lengths_from_iterfunc(
+    (mkframe(0, Nside), mkframe(25, Nside), mkframe(75, Nside)),
+    functools.partial(self.get_unwrapped_linker(), link_strategy='recursive'),
+    100)


class BaselinePredictIterTests(LinkIterWithPrediction, BaselinePredictTests, StrictTestCase):
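The functools.partial call in the updated test simply pre-binds link_strategy onto whatever linker callable the test harness returns; a self-contained sketch of the same pattern, using a hypothetical stand-in for that callable:

import functools

def fake_link(frames, search_range, link_strategy='auto'):
    # Stand-in for the callable returned by get_unwrapped_linker().
    return f"linked {len(frames)} frames using {link_strategy!r}"

recursive_link = functools.partial(fake_link, link_strategy='recursive')
print(recursive_link(('f0', 'f1', 'f2'), 100))   # linked 3 frames using 'recursive'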