Skip to content

Commit cea92aa

Browse files
committed
Add test for ReplayTrackFeeder
1 parent a85fe20 commit cea92aa

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

stonesoup/feeder/tests/test_track.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from ...types.state import GaussianState
55
from ...types.track import Track
6-
from ..track import Tracks2GaussianDetectionFeeder
6+
from ..track import Tracks2GaussianDetectionFeeder, ReplayTrackFeeder
77

88
t1 = Track(GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=2))
99
t2 = Track([GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=1),
@@ -13,6 +13,8 @@
1313
GaussianState([3, 1], np.diag([2, 2]), timestamp=2)])
1414
t4 = Track(GaussianState([1, 0, 1, 0, 1, 0], np.diag([2, 2, 2, 2, 2, 2]), timestamp=2))
1515

16+
times = [0, 1, 2]
17+
1618

1719
@pytest.mark.parametrize(
1820
"tracks",
@@ -53,3 +55,23 @@ def test_Track2GaussianDetectionFeeder(tracks):
5355
if detection.metadata['track_id'] == track.id))
5456
assert np.all(detection.state_vector == track.state_vector)
5557
assert np.all(detection.covar == track.covar)
58+
59+
60+
@pytest.mark.parametrize(("tracks", "times"),
61+
[([t1], times), ([t1], None),
62+
([t1, t2], times), ([t1, t2], None),
63+
([t2, t3], times), ([t2, t3], None),
64+
([t1, t2, t3, t4], times), ([t1, t2, t3, t4], None)])
65+
def test_ReplayTrackFeeder(tracks, times):
66+
feeder = ReplayTrackFeeder(reader=tracks, times=times)
67+
feeder_times = []
68+
feeder_tracks = set()
69+
for new_time, new_tracks in feeder:
70+
print(new_time, new_tracks)
71+
feeder_times.append(new_time)
72+
feeder_tracks |= new_tracks
73+
if times is not None:
74+
assert times == feeder_times
75+
assert len(tracks) == len(feeder_tracks)
76+
assert (sorted([len(track) for track in tracks]) ==
77+
sorted([len(track) for track in feeder_tracks]))

0 commit comments

Comments
 (0)