From 627043aab9a73ee37f5bc755dfd6ab237db86e3d Mon Sep 17 00:00:00 2001 From: Frederik Van der Veken Date: Tue, 6 Feb 2024 23:33:36 +0100 Subject: [PATCH 1/2] Corrected overflow with ele_stop <= ele_start. Overflow should only happen when num_turns == 1 --- tests/test_tracker.py | 9 +++++++-- xtrack/tracker.py | 12 +++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 91481baa6..e524134ca 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -291,14 +291,19 @@ def _ele_start_to_ele_stop(line, particles_init): assert line.record_last_track.x.shape==(len(particles.x), expected_num_monitor) -# Track from any ele_start until any ele_stop that is smaller than or equal to ele_start (turn increses by one) +# Track from any ele_start until any ele_stop that is smaller than or equal to ele_start # for one, two, and ten turns def _ele_start_to_ele_stop_with_overflow(line, particles_init): n_elem = len(line.element_names) for turns in [1, 2, 10]: for start in range(n_elem): for stop in range(start+1): - expected_end_turn = turns + if stop == 0: + # last turn is a complete turn + expected_end_turn = turns + else: + # last turn is incomplete, but overflow if turns == 1 + expected_end_turn = turns if turns==1 else turns - 1 expected_end_element = stop expected_num_monitor = expected_end_turn if expected_end_element==0 else expected_end_turn+1 diff --git a/xtrack/tracker.py b/xtrack/tracker.py index 589d5d031..2ac170e29 100644 --- a/xtrack/tracker.py +++ b/xtrack/tracker.py @@ -322,7 +322,7 @@ def _track(self, particles, *args, with_progress: Union[bool, int] = False, if ele_stop is None: ele_stop = len(self.line) - if ele_start >= ele_stop: + if ele_start >= ele_stop and num_turns == 1: # we need an additional turn and space in the monitor for # the incomplete turn num_turns += 1 @@ -726,7 +726,7 @@ def _prepare_collective_track_session(self, particles, ele_start, ele_stop, if ele_stop == 0: ele_stop = None - if ele_stop is not None and ele_stop <= ele_start: + if ele_stop is not None and ele_stop <= ele_start and num_turns == 1: num_turns += 1 if ele_stop is not None: @@ -1156,6 +1156,10 @@ def _track_no_collective( else: # We are using ele_start, ele_stop, and num_turns + if isinstance(ele_stop, str): + ele_stop = self.line.element_names.index(ele_stop) + if ele_stop == 0: + ele_stop = None if num_turns is None: num_turns = 1 else: @@ -1166,11 +1170,9 @@ def _track_no_collective( num_elements_first_turn = self.num_elements - ele_start num_middle_turns = num_turns - 1 else: - if isinstance(ele_stop, str): - ele_stop = self.line.element_names.index(ele_stop) assert ele_stop >= 0 assert ele_stop <= self.num_elements - if ele_stop <= ele_start: + if ele_stop <= ele_start and num_turns == 1: # Correct for overflow: num_turns += 1 if num_turns == 1: From 5aee85569a50c91444c6d1c0c4e52e748810c807 Mon Sep 17 00:00:00 2001 From: Frederik Van der Veken Date: Wed, 7 Feb 2024 09:31:15 +0100 Subject: [PATCH 2/2] Updated tbt monitor test to expect the correct behaviour --- tests/test_tracker.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_tracker.py b/tests/test_tracker.py index e524134ca..5439aaff6 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -551,22 +551,25 @@ def test_tracking_with_progress(test_context, with_progress, turns, collective): @for_all_test_contexts @pytest.mark.parametrize( - 'ele_start,ele_stop,expected_x', + 'ele_start,ele_stop,num_turns,expected_x', [ - (None, None, [0, 0.005, 0.010, 0.015, 0.020, 0.025]), - (None, 3, [0, 0.005, 0.010, 0.015, 0.020, 0.023]), - (2, None, [0, 0.003, 0.008, 0.013, 0.018, 0.023]), - (2, 3, [0, 0.003, 0.008, 0.013, 0.018, 0.021]), - (3, 2, [0, 0.002, 0.007, 0.012, 0.017, 0.022, 0.024]), + (None, None, 5, [0, 0.005, 0.010, 0.015, 0.020, 0.025]), + (None, 3, 5, [0, 0.005, 0.010, 0.015, 0.020, 0.023]), + (2, None, 5, [0, 0.003, 0.008, 0.013, 0.018, 0.023]), + (2, 3, 5, [0, 0.003, 0.008, 0.013, 0.018, 0.021]), + (3, 2, 5, [0, 0.002, 0.007, 0.012, 0.017, 0.019]), + (2, 3, 1, [0, 0.001]), + (3, 2, 1, [0, 0.002, 0.004]), ], ) @pytest.mark.parametrize('with_progress', [False, True, 1, 2, 3]) -def test_tbt_monitor_with_progress(test_context, ele_start, ele_stop, expected_x, with_progress): +def test_tbt_monitor_with_progress(test_context, ele_start, ele_stop, num_turns, expected_x, with_progress): line = xt.Line(elements=[xt.Drift(length=1, _context=test_context)] * 5) line.build_tracker(_context=test_context) p = xt.Particles(px=0.001, _context=test_context) - line.track(p, num_turns=5, turn_by_turn_monitor=True, with_progress=with_progress, ele_start=ele_start, ele_stop=ele_stop) + line.track(p, num_turns=num_turns, turn_by_turn_monitor=True, + with_progress=with_progress, ele_start=ele_start, ele_stop=ele_stop) p.move(_context=xo.context_default) monitor_recorded_x = line.record_last_track.x