Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,19 @@ def _ele_start_to_ele_stop(line, particles_init):
assert line.record_last_track.x.shape==(len(particles.x), expected_num_monitor)


# Track from any ele_start until any ele_stop that is smaller than or equal to ele_start (turn increses by one)
# Track from any ele_start until any ele_stop that is smaller than or equal to ele_start
# for one, two, and ten turns
def _ele_start_to_ele_stop_with_overflow(line, particles_init):
n_elem = len(line.element_names)
for turns in [1, 2, 10]:
for start in range(n_elem):
for stop in range(start+1):
expected_end_turn = turns
if stop == 0:
# last turn is a complete turn
expected_end_turn = turns
else:
# last turn is incomplete, but overflow if turns == 1
expected_end_turn = turns if turns==1 else turns - 1
expected_end_element = stop
expected_num_monitor = expected_end_turn if expected_end_element==0 else expected_end_turn+1

Expand Down Expand Up @@ -546,22 +551,25 @@ def test_tracking_with_progress(test_context, with_progress, turns, collective):

@for_all_test_contexts
@pytest.mark.parametrize(
'ele_start,ele_stop,expected_x',
'ele_start,ele_stop,num_turns,expected_x',
[
(None, None, [0, 0.005, 0.010, 0.015, 0.020, 0.025]),
(None, 3, [0, 0.005, 0.010, 0.015, 0.020, 0.023]),
(2, None, [0, 0.003, 0.008, 0.013, 0.018, 0.023]),
(2, 3, [0, 0.003, 0.008, 0.013, 0.018, 0.021]),
(3, 2, [0, 0.002, 0.007, 0.012, 0.017, 0.022, 0.024]),
(None, None, 5, [0, 0.005, 0.010, 0.015, 0.020, 0.025]),
(None, 3, 5, [0, 0.005, 0.010, 0.015, 0.020, 0.023]),
(2, None, 5, [0, 0.003, 0.008, 0.013, 0.018, 0.023]),
(2, 3, 5, [0, 0.003, 0.008, 0.013, 0.018, 0.021]),
(3, 2, 5, [0, 0.002, 0.007, 0.012, 0.017, 0.019]),
(2, 3, 1, [0, 0.001]),
(3, 2, 1, [0, 0.002, 0.004]),
],
)
@pytest.mark.parametrize('with_progress', [False, True, 1, 2, 3])
def test_tbt_monitor_with_progress(test_context, ele_start, ele_stop, expected_x, with_progress):
def test_tbt_monitor_with_progress(test_context, ele_start, ele_stop, num_turns, expected_x, with_progress):
line = xt.Line(elements=[xt.Drift(length=1, _context=test_context)] * 5)
line.build_tracker(_context=test_context)

p = xt.Particles(px=0.001, _context=test_context)
line.track(p, num_turns=5, turn_by_turn_monitor=True, with_progress=with_progress, ele_start=ele_start, ele_stop=ele_stop)
line.track(p, num_turns=num_turns, turn_by_turn_monitor=True,
with_progress=with_progress, ele_start=ele_start, ele_stop=ele_stop)
p.move(_context=xo.context_default)

monitor_recorded_x = line.record_last_track.x
Expand Down
12 changes: 7 additions & 5 deletions xtrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _track(self, particles, *args, with_progress: Union[bool, int] = False,
if ele_stop is None:
ele_stop = len(self.line)

if ele_start >= ele_stop:
if ele_start >= ele_stop and num_turns == 1:
# we need an additional turn and space in the monitor for
# the incomplete turn
num_turns += 1
Expand Down Expand Up @@ -726,7 +726,7 @@ def _prepare_collective_track_session(self, particles, ele_start, ele_stop,
if ele_stop == 0:
ele_stop = None

if ele_stop is not None and ele_stop <= ele_start:
if ele_stop is not None and ele_stop <= ele_start and num_turns == 1:
num_turns += 1

if ele_stop is not None:
Expand Down Expand Up @@ -1156,6 +1156,10 @@ def _track_no_collective(

else:
# We are using ele_start, ele_stop, and num_turns
if isinstance(ele_stop, str):
ele_stop = self.line.element_names.index(ele_stop)
if ele_stop == 0:
ele_stop = None
if num_turns is None:
num_turns = 1
else:
Expand All @@ -1166,11 +1170,9 @@ def _track_no_collective(
num_elements_first_turn = self.num_elements - ele_start
num_middle_turns = num_turns - 1
else:
if isinstance(ele_stop, str):
ele_stop = self.line.element_names.index(ele_stop)
assert ele_stop >= 0
assert ele_stop <= self.num_elements
if ele_stop <= ele_start:
if ele_stop <= ele_start and num_turns == 1:
# Correct for overflow:
num_turns += 1
if num_turns == 1:
Expand Down