Skip to content

Commit d9a5cb6

Browse files
committed
fix some tests
1 parent cd00be1 commit d9a5cb6

File tree

3 files changed

+192
-107
lines changed

3 files changed

+192
-107
lines changed

tests/test_xwakes_kick_vs_pyheadtail.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
import xwakes as xw
88
import xtrack as xt
99
import xobjects as xo
10+
from xobjects.test_helpers import for_all_test_contexts
1011

1112
test_data_folder = pathlib.Path(__file__).parent.joinpath(
1213
'../test_data').absolute()
1314

14-
def test_xwakes_kick_vs_pyheadtail_table_dipolar():
15+
@for_all_test_contexts(excluding="ContextPyopencl")
16+
def test_xwakes_kick_vs_pyheadtail_table_dipolar(test_context):
1517

1618
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
1719

18-
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000))
20+
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
21+
_context=test_context)
1922
p.x[p.zeta > 0] += 1e-3
2023
p.y[p.zeta > 0] += 1e-3
2124
p_table = p.copy()
@@ -31,7 +34,8 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
3134
'quadrupolar_xy', 'dipolar_yx', 'quadrupolar_yx',
3235
'constant_x', 'constant_y'])
3336
wake_from_table = xw.WakeFromTable(table, columns=['dipolar_x', 'dipolar_y'])
34-
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100)
37+
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100,
38+
_context=test_context)
3539

3640
# Zotter convention
3741
assert table['dipolar_x'].values[1] > 0
@@ -72,11 +76,13 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
7276
xo.assert_allclose(p_table.px, p_ref.px, atol=1e-30, rtol=2e-3)
7377
xo.assert_allclose(p_table.py, p_ref.py, atol=1e-30, rtol=2e-3)
7478

75-
def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
79+
@for_all_test_contexts(excluding="ContextPyopencl")
80+
def test_xwakes_kick_vs_pyheadtail_table_quadrupolar(test_context):
7681

7782
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
7883

79-
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000))
84+
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
85+
_context=test_context)
8086
p.x[p.zeta > 0] += 1e-3
8187
p.y[p.zeta > 0] += 1e-3
8288
p_table = p.copy()
@@ -93,7 +99,8 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
9399
'constant_x', 'constant_y'])
94100

95101
wake_from_table = xw.WakeFromTable(table, columns=['quadrupolar_x', 'quadrupolar_y'])
96-
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100)
102+
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=100,
103+
_context=test_context)
97104

98105
# This is specific of this table
99106
assert table['quadrupolar_x'].values[1] < 0
@@ -136,13 +143,16 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
136143
xo.assert_allclose(p_table.py, p_ref.py, atol=1e-30, rtol=2e-3)
137144

138145

139-
def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
146+
@for_all_test_contexts(excluding="ContextPyopencl")
147+
def test_xwakes_kick_vs_pyheadtail_table_longitudinal(test_context):
140148

141149
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
142150

143151
p = xt.Particles.merge([
144-
xt.Particles(p0c=7e12, zeta=np.linspace(-1e-3, 1e-3, 100000)),
145-
xt.Particles(p0c=7e12, zeta=1e-6+np.zeros(100000))
152+
xt.Particles(p0c=7e12, zeta=np.linspace(-1e-3, 1e-3, 100000),
153+
_context=test_context),
154+
xt.Particles(p0c=7e12, zeta=1e-6+np.zeros(100000),
155+
_context=test_context),
146156
])
147157

148158
p_table = p.copy()
@@ -158,7 +168,9 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
158168
'quadrupolar_xy', 'dipolar_yx', 'quadrupolar_yx',
159169
'constant_x', 'constant_y'])
160170
wake_from_table = xw.WakeFromTable(table, columns=['time', 'longitudinal'])
161-
wake_from_table.configure_for_tracking(zeta_range=(-2e-3, 2e-3), num_slices=1000)
171+
wake_from_table.configure_for_tracking(zeta_range=(-2e-3, 2e-3),
172+
num_slices=1000,
173+
_context=test_context)
162174

163175
assert len(wake_from_table.components) == 1
164176
assert wake_from_table.components[0].plane == 'z'
@@ -191,12 +203,13 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
191203
assert np.max(p_ref.delta) > 1e-12
192204
xo.assert_allclose(p_table.delta, p_ref.delta, atol=1e-14, rtol=0)
193205

194-
def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
206+
@for_all_test_contexts(excluding="ContextPyopencl")
207+
def test_xwakes_kick_vs_pyheadtail_resonator_dipolar(test_context):
195208

196209
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
197210

198211
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
199-
weight=1e14)
212+
weight=1e14, _context=test_context)
200213
p.x[p.zeta > 0] += 1e-3
201214
p.y[p.zeta > 0] += 1e-3
202215
p_table = p.copy()
@@ -207,7 +220,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
207220
r=1e8, q=1e7, f_r=1e9,
208221
kind=xw.Yokoya('circular'), # equivalent to: kind=['dipolar_x', 'dipolar_y'],
209222
)
210-
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
223+
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
224+
_context=test_context)
211225

212226
assert len(wake.components) == 2
213227
assert wake.components[0].plane == 'x'
@@ -237,7 +251,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
237251
table = pd.DataFrame({'time': t_samples, 'dipolar_x': w_dipole_x_samples,
238252
'dipolar_y': w_dipole_y_samples})
239253
wake_from_table = xw.WakeFromTable(table)
240-
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
254+
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
255+
_context=test_context)
241256

242257
assert len(wake_from_table.components) == 2
243258
assert wake_from_table.components[0].plane == 'x'
@@ -277,12 +292,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
277292
xo.assert_allclose(p_table.px, p_ref.px, rtol=0, atol=2e-3*np.max(np.abs(p_ref.px)))
278293
xo.assert_allclose(p_table.py, p_ref.py, rtol=0, atol=2e-3*np.max(np.abs(p_ref.py)))
279294

280-
def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
295+
@for_all_test_contexts(excluding="ContextPyopencl")
296+
def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar(test_context):
281297

282298
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
283299

284300
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
285-
weight=1e14)
301+
weight=1e14, _context=test_context)
286302
p.x[p.zeta > 0] += 1e-3
287303
p.y[p.zeta > 0] += 1e-3
288304
p_table = p.copy()
@@ -293,7 +309,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
293309
r=1e8, q=1e7, f_r=1e9,
294310
kind=['quadrupolar_x', 'quadrupolar_y'],
295311
)
296-
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
312+
wake.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
313+
_context=test_context)
297314

298315
assert len(wake.components) == 2
299316
assert wake.components[0].plane == 'x'
@@ -322,7 +339,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
322339
table = pd.DataFrame({'time': t_samples, 'quadrupolar_x': w_quadrupole_x_samples,
323340
'quadrupolar_y': w_quadrupole_y_samples})
324341
wake_from_table = xw.WakeFromTable(table)
325-
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50)
342+
wake_from_table.configure_for_tracking(zeta_range=(-1, 1), num_slices=50,
343+
_context=test_context)
326344

327345
assert len(wake_from_table.components) == 2
328346
assert wake_from_table.components[0].plane == 'x'
@@ -365,12 +383,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
365383
xo.assert_allclose(p_table.px, p_ref.px, rtol=0, atol=2e-3*np.max(np.abs(p_ref.px)))
366384
xo.assert_allclose(p_table.py, p_ref.py, rtol=0, atol=2e-3*np.max(np.abs(p_ref.py)))
367385

368-
def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
386+
@for_all_test_contexts(excluding="ContextPyopencl")
387+
def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal(test_context):
369388

370389
from xpart.pyheadtail_interface.pyhtxtparticles import PyHtXtParticles
371390

372391
p = xt.Particles(p0c=7e12, zeta=np.linspace(-1, 1, 100000),
373-
weight=1e14)
392+
weight=1e14, _context=test_context)
374393
p.x[p.zeta > 0] += 1e-3
375394
p.y[p.zeta > 0] += 1e-3
376395
p_table = p.copy()
@@ -381,7 +400,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
381400
r=1e8, q=1e7, f_r=1e9,
382401
kind='longitudinal'
383402
)
384-
wake.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50)
403+
wake.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50,
404+
_context=test_context)
385405

386406
assert len(wake.components) == 1
387407
assert wake.components[0].plane == 'z'
@@ -402,7 +422,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
402422
w_longitudinal_x_samples[0] *= 2 # Undo sampling weight
403423
table = pd.DataFrame({'time': t_samples, 'longitudinal': w_longitudinal_x_samples})
404424
wake_from_table = xw.WakeFromTable(table)
405-
wake_from_table.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50)
425+
wake_from_table.configure_for_tracking(zeta_range=(-1.01, 1.01), num_slices=50,
426+
_context=test_context)
406427

407428
assert len(wake_from_table.components) == 1
408429
assert wake_from_table.components[0].plane == 'z'

tests/test_xwakes_tune_shift.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77
import xpart as xp
88
import xwakes as xw
99
import xobjects as xo
10+
from xobjects.test_helpers import for_all_test_contexts
1011

12+
@for_all_test_contexts(excluding="ContextPyopencl")
1113
@pytest.mark.parametrize('wake_type', ['dipolar', 'quadrupolar'])
1214
@pytest.mark.parametrize('plane', ['x', 'y'])
13-
def test_tune_shift_transverse(wake_type, plane):
15+
def test_tune_shift_transverse(wake_type, plane, test_context):
1416

1517
table = pd.DataFrame({'time': [0, 10],
1618
f'{wake_type}_{plane}': [6e15, 6e15]})
1719

1820
wake = xw.WakeFromTable(table)
19-
wake.configure_for_tracking(zeta_range=(-20e-2, 20e-2), num_slices=100)
21+
wake.configure_for_tracking(zeta_range=(-20e-2, 20e-2), num_slices=100,
22+
_context=test_context)
2023

2124
assert len(wake.components) == 1
2225
assert wake.components[0].plane == plane
@@ -49,15 +52,15 @@ def test_tune_shift_transverse(wake_type, plane):
4952
line_no_wake = xt.Line(elements=[one_turn_map])
5053
line_with_wake = xt.Line(elements=[one_turn_map, wake])
5154

52-
line_no_wake.particle_ref = xt.Particles(p0c=2e9)
53-
line_with_wake.particle_ref = xt.Particles(p0c=2e9)
55+
line_no_wake.particle_ref = xt.Particles(p0c=2e9, _context=test_context)
56+
line_with_wake.particle_ref = xt.Particles(p0c=2e9, _context=test_context)
5457

55-
line_no_wake.build_tracker()
56-
line_with_wake.build_tracker()
58+
line_no_wake.build_tracker(_context=test_context)
59+
line_with_wake.build_tracker(_context=test_context)
5760

5861
p = xp.generate_matched_gaussian_bunch(line=line_no_wake,
5962
num_particles=1000, nemitt_x=1e-6, nemitt_y=1e-6, sigma_z=0.07,
60-
total_intensity_particles=1e11)
63+
total_intensity_particles=1e11, _context=test_context)
6164
p.x += 1e-3
6265
p.y += 1e-3
6366

@@ -67,8 +70,8 @@ def test_tune_shift_transverse(wake_type, plane):
6770
mean_y=lambda line, part: part.y.mean())
6871

6972
line_no_wake.track(p0, num_turns=100,
70-
log=mylog,
71-
with_progress=10)
73+
log=mylog,
74+
with_progress=10)
7275
log_no_wake = line_no_wake.log_last_track
7376

7477
line_with_wake.track(p, num_turns=100,
@@ -78,8 +81,15 @@ def test_tune_shift_transverse(wake_type, plane):
7881

7982
import nafflib as nl
8083

81-
tune_no_wake = nl.get_tune(log_no_wake[f'mean_{plane}'])
82-
tune_with_wake = nl.get_tune(log_with_wake[f'mean_{plane}'])
84+
mean_no_wake = test_context.nplike_lib.array(log_no_wake[f'mean_{plane}'])
85+
mean_with_wake = test_context.nplike_lib.array(log_with_wake[f'mean_{plane}'])
86+
87+
if hasattr(mean_no_wake, 'get'):
88+
tune_no_wake = nl.get_tune(mean_no_wake.get())
89+
tune_with_wake = nl.get_tune(mean_with_wake.get())
90+
else:
91+
tune_no_wake = nl.get_tune(mean_no_wake)
92+
tune_with_wake = nl.get_tune(mean_with_wake)
8393

8494
print(f'Tune without wake: {tune_no_wake}')
8595
print(f'Tune with wake: {tune_with_wake}')
@@ -89,13 +99,15 @@ def test_tune_shift_transverse(wake_type, plane):
8999
xo.assert_allclose(tune_no_wake, {'x': 0.28, 'y': 0.31}[plane], atol=1e-6, rtol=0)
90100
xo.assert_allclose(tune_with_wake, tune_no_wake - 2e-3, atol=0.3e-3, rtol=0)
91101

92-
def test_tune_shift_longitudinal():
102+
@for_all_test_contexts(excluding="ContextPyopencl")
103+
def test_tune_shift_longitudinal(test_context):
93104

94105
table = pd.DataFrame({'time': [0, 10],
95106
'longitudinal': [1e13, 1e13]})
96107

97108
wake = xw.WakeFromTable(table)
98-
wake.configure_for_tracking(zeta_range=(-20e-2, 20e-2), num_slices=100)
109+
wake.configure_for_tracking(zeta_range=(-20e-2, 20e-2), num_slices=100,
110+
_context=test_context)
99111

100112
assert len(wake.components) == 1
101113
assert wake.components[0].plane == 'z'
@@ -124,19 +136,20 @@ def test_tune_shift_longitudinal():
124136
line_no_wake = xt.Line(elements=[one_turn_map])
125137
line_with_wake = xt.Line(elements=[one_turn_map, wake])
126138

127-
line_no_wake.particle_ref = xt.Particles(p0c=2e9)
128-
line_with_wake.particle_ref = xt.Particles(p0c=2e9)
139+
line_no_wake.particle_ref = xt.Particles(p0c=2e9, _context=test_context)
140+
line_with_wake.particle_ref = xt.Particles(p0c=2e9, _context=test_context)
129141

130-
line_no_wake.build_tracker()
131-
line_with_wake.build_tracker()
142+
line_no_wake.build_tracker(_context=test_context)
143+
line_with_wake.build_tracker(_context=test_context)
132144

133145
p = xp.generate_matched_gaussian_bunch(line=line_no_wake,
134146
num_particles=1000, nemitt_x=1e-6, nemitt_y=1e-6, sigma_z=0.07,
135-
total_intensity_particles=1e11)
147+
total_intensity_particles=1e11,
148+
_context=test_context)
136149
p.zeta += 5e-3
137150

138151
p0 = p.copy()
139-
152+
140153
mylog = xt.Log(mean_zeta=lambda line, part: part.zeta.mean(),
141154
mean_delta=lambda line, part: part.delta.mean())
142155

@@ -146,14 +159,23 @@ def test_tune_shift_longitudinal():
146159
log_no_wake = line_no_wake.log_last_track
147160

148161
line_with_wake.track(p, num_turns=3000,
149-
log=mylog,
150-
with_progress=10)
162+
log=mylog,
163+
with_progress=10)
151164
log_with_wake = line_with_wake.log_last_track
152165

153166
import nafflib as nl
154167

155-
tune_no_wake = nl.get_tune(log_no_wake[f'mean_zeta'])
156-
tune_with_wake = nl.get_tune(log_with_wake[f'mean_zeta']-np.mean(log_with_wake[f'mean_zeta']))
168+
mean_zeta_no_wake = test_context.nplike_lib.array(log_no_wake[f'mean_zeta'])
169+
mean_zeta_with_wake = test_context.nplike_lib.array(log_with_wake[f'mean_zeta'])
170+
171+
if hasattr(mean_zeta_no_wake, 'get'):
172+
tune_no_wake = nl.get_tune(mean_zeta_no_wake.get())
173+
tune_with_wake = nl.get_tune(mean_zeta_with_wake.get() -
174+
np.mean(mean_zeta_with_wake.mean().get()))
175+
else:
176+
tune_no_wake = nl.get_tune(mean_zeta_no_wake)
177+
tune_with_wake = nl.get_tune(mean_zeta_with_wake -
178+
np.mean(mean_zeta_with_wake))
157179

158180
print(f'Tune without wake: {tune_no_wake}')
159181
print(f'Tune with wake: {tune_with_wake}')

0 commit comments

Comments
 (0)