77import xwakes as xw
88import xtrack as xt
99import xobjects as xo
10+ from xobjects .test_helpers import for_all_test_contexts
1011
1112test_data_folder = pathlib .Path (__file__ ).parent .joinpath (
1213 '../test_data' ).absolute ()
1314
14- def test_xwakes_kick_vs_pyheadtail_table_dipolar ():
15+ @for_all_test_contexts (excluding = "ContextPyopencl" )
16+ def test_xwakes_kick_vs_pyheadtail_table_dipolar (test_context ):
1517
1618 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
1719
18- p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ))
20+ p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ),
21+ _context = test_context )
1922 p .x [p .zeta > 0 ] += 1e-3
2023 p .y [p .zeta > 0 ] += 1e-3
2124 p_table = p .copy ()
@@ -31,7 +34,8 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
3134 'quadrupolar_xy' , 'dipolar_yx' , 'quadrupolar_yx' ,
3235 'constant_x' , 'constant_y' ])
3336 wake_from_table = xw .WakeFromTable (table , columns = ['dipolar_x' , 'dipolar_y' ])
34- wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 100 )
37+ wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 100 ,
38+ _context = test_context )
3539
3640 # Zotter convention
3741 assert table ['dipolar_x' ].values [1 ] > 0
@@ -72,11 +76,13 @@ def test_xwakes_kick_vs_pyheadtail_table_dipolar():
7276 xo .assert_allclose (p_table .px , p_ref .px , atol = 1e-30 , rtol = 2e-3 )
7377 xo .assert_allclose (p_table .py , p_ref .py , atol = 1e-30 , rtol = 2e-3 )
7478
75- def test_xwakes_kick_vs_pyheadtail_table_quadrupolar ():
79+ @for_all_test_contexts (excluding = "ContextPyopencl" )
80+ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar (test_context ):
7681
7782 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
7883
79- p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ))
84+ p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ),
85+ _context = test_context )
8086 p .x [p .zeta > 0 ] += 1e-3
8187 p .y [p .zeta > 0 ] += 1e-3
8288 p_table = p .copy ()
@@ -93,7 +99,8 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
9399 'constant_x' , 'constant_y' ])
94100
95101 wake_from_table = xw .WakeFromTable (table , columns = ['quadrupolar_x' , 'quadrupolar_y' ])
96- wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 100 )
102+ wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 100 ,
103+ _context = test_context )
97104
98105 # This is specific of this table
99106 assert table ['quadrupolar_x' ].values [1 ] < 0
@@ -136,13 +143,16 @@ def test_xwakes_kick_vs_pyheadtail_table_quadrupolar():
136143 xo .assert_allclose (p_table .py , p_ref .py , atol = 1e-30 , rtol = 2e-3 )
137144
138145
139- def test_xwakes_kick_vs_pyheadtail_table_longitudinal ():
146+ @for_all_test_contexts (excluding = "ContextPyopencl" )
147+ def test_xwakes_kick_vs_pyheadtail_table_longitudinal (test_context ):
140148
141149 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
142150
143151 p = xt .Particles .merge ([
144- xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1e-3 , 1e-3 , 100000 )),
145- xt .Particles (p0c = 7e12 , zeta = 1e-6 + np .zeros (100000 ))
152+ xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1e-3 , 1e-3 , 100000 ),
153+ _context = test_context ),
154+ xt .Particles (p0c = 7e12 , zeta = 1e-6 + np .zeros (100000 ),
155+ _context = test_context ),
146156 ])
147157
148158 p_table = p .copy ()
@@ -158,7 +168,9 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
158168 'quadrupolar_xy' , 'dipolar_yx' , 'quadrupolar_yx' ,
159169 'constant_x' , 'constant_y' ])
160170 wake_from_table = xw .WakeFromTable (table , columns = ['time' , 'longitudinal' ])
161- wake_from_table .configure_for_tracking (zeta_range = (- 2e-3 , 2e-3 ), num_slices = 1000 )
171+ wake_from_table .configure_for_tracking (zeta_range = (- 2e-3 , 2e-3 ),
172+ num_slices = 1000 ,
173+ _context = test_context )
162174
163175 assert len (wake_from_table .components ) == 1
164176 assert wake_from_table .components [0 ].plane == 'z'
@@ -191,12 +203,13 @@ def test_xwakes_kick_vs_pyheadtail_table_longitudinal():
191203 assert np .max (p_ref .delta ) > 1e-12
192204 xo .assert_allclose (p_table .delta , p_ref .delta , atol = 1e-14 , rtol = 0 )
193205
194- def test_xwakes_kick_vs_pyheadtail_resonator_dipolar ():
206+ @for_all_test_contexts (excluding = "ContextPyopencl" )
207+ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar (test_context ):
195208
196209 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
197210
198211 p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ),
199- weight = 1e14 )
212+ weight = 1e14 , _context = test_context )
200213 p .x [p .zeta > 0 ] += 1e-3
201214 p .y [p .zeta > 0 ] += 1e-3
202215 p_table = p .copy ()
@@ -207,7 +220,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
207220 r = 1e8 , q = 1e7 , f_r = 1e9 ,
208221 kind = xw .Yokoya ('circular' ), # equivalent to: kind=['dipolar_x', 'dipolar_y'],
209222 )
210- wake .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 )
223+ wake .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 ,
224+ _context = test_context )
211225
212226 assert len (wake .components ) == 2
213227 assert wake .components [0 ].plane == 'x'
@@ -237,7 +251,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
237251 table = pd .DataFrame ({'time' : t_samples , 'dipolar_x' : w_dipole_x_samples ,
238252 'dipolar_y' : w_dipole_y_samples })
239253 wake_from_table = xw .WakeFromTable (table )
240- wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 )
254+ wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 ,
255+ _context = test_context )
241256
242257 assert len (wake_from_table .components ) == 2
243258 assert wake_from_table .components [0 ].plane == 'x'
@@ -277,12 +292,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_dipolar():
277292 xo .assert_allclose (p_table .px , p_ref .px , rtol = 0 , atol = 2e-3 * np .max (np .abs (p_ref .px )))
278293 xo .assert_allclose (p_table .py , p_ref .py , rtol = 0 , atol = 2e-3 * np .max (np .abs (p_ref .py )))
279294
280- def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar ():
295+ @for_all_test_contexts (excluding = "ContextPyopencl" )
296+ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar (test_context ):
281297
282298 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
283299
284300 p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ),
285- weight = 1e14 )
301+ weight = 1e14 , _context = test_context )
286302 p .x [p .zeta > 0 ] += 1e-3
287303 p .y [p .zeta > 0 ] += 1e-3
288304 p_table = p .copy ()
@@ -293,7 +309,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
293309 r = 1e8 , q = 1e7 , f_r = 1e9 ,
294310 kind = ['quadrupolar_x' , 'quadrupolar_y' ],
295311 )
296- wake .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 )
312+ wake .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 ,
313+ _context = test_context )
297314
298315 assert len (wake .components ) == 2
299316 assert wake .components [0 ].plane == 'x'
@@ -322,7 +339,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
322339 table = pd .DataFrame ({'time' : t_samples , 'quadrupolar_x' : w_quadrupole_x_samples ,
323340 'quadrupolar_y' : w_quadrupole_y_samples })
324341 wake_from_table = xw .WakeFromTable (table )
325- wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 )
342+ wake_from_table .configure_for_tracking (zeta_range = (- 1 , 1 ), num_slices = 50 ,
343+ _context = test_context )
326344
327345 assert len (wake_from_table .components ) == 2
328346 assert wake_from_table .components [0 ].plane == 'x'
@@ -365,12 +383,13 @@ def test_xwakes_kick_vs_pyheadtail_resonator_quadrupolar():
365383 xo .assert_allclose (p_table .px , p_ref .px , rtol = 0 , atol = 2e-3 * np .max (np .abs (p_ref .px )))
366384 xo .assert_allclose (p_table .py , p_ref .py , rtol = 0 , atol = 2e-3 * np .max (np .abs (p_ref .py )))
367385
368- def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal ():
386+ @for_all_test_contexts (excluding = "ContextPyopencl" )
387+ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal (test_context ):
369388
370389 from xpart .pyheadtail_interface .pyhtxtparticles import PyHtXtParticles
371390
372391 p = xt .Particles (p0c = 7e12 , zeta = np .linspace (- 1 , 1 , 100000 ),
373- weight = 1e14 )
392+ weight = 1e14 , _context = test_context )
374393 p .x [p .zeta > 0 ] += 1e-3
375394 p .y [p .zeta > 0 ] += 1e-3
376395 p_table = p .copy ()
@@ -381,7 +400,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
381400 r = 1e8 , q = 1e7 , f_r = 1e9 ,
382401 kind = 'longitudinal'
383402 )
384- wake .configure_for_tracking (zeta_range = (- 1.01 , 1.01 ), num_slices = 50 )
403+ wake .configure_for_tracking (zeta_range = (- 1.01 , 1.01 ), num_slices = 50 ,
404+ _context = test_context )
385405
386406 assert len (wake .components ) == 1
387407 assert wake .components [0 ].plane == 'z'
@@ -402,7 +422,8 @@ def test_xwakes_kick_vs_pyheadtail_resonator_longitudinal():
402422 w_longitudinal_x_samples [0 ] *= 2 # Undo sampling weight
403423 table = pd .DataFrame ({'time' : t_samples , 'longitudinal' : w_longitudinal_x_samples })
404424 wake_from_table = xw .WakeFromTable (table )
405- wake_from_table .configure_for_tracking (zeta_range = (- 1.01 , 1.01 ), num_slices = 50 )
425+ wake_from_table .configure_for_tracking (zeta_range = (- 1.01 , 1.01 ), num_slices = 50 ,
426+ _context = test_context )
406427
407428 assert len (wake_from_table .components ) == 1
408429 assert wake_from_table .components [0 ].plane == 'z'
0 commit comments