Skip to content

Commit 5940b91

Browse files
committed
Switching to test_contexts
1 parent 9a9f35b commit 5940b91

File tree

3 files changed

+28
-38
lines changed

3 files changed

+28
-38
lines changed

tests/test_aperture_turn_ele_and_monitor.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
import numpy as np
22

33
import xobjects as xo
4-
from xobjects.context import available
54
import xtrack as xt
65
import xline as xl
76

87

98
def test_aperture_turn_ele_and_monitor():
109

11-
for CTX in xo.ContextCpu, xo.ContextPyopencl, xo.ContextCupy:
12-
if CTX not in available:
13-
continue
14-
15-
print(f"Test {CTX}")
16-
context = CTX()
10+
for context in xo.context.get_test_contexts():
11+
print(f"Test {context.__class__}")
1712

1813
x_aper_min = -0.1
1914
x_aper_max = 0.2
@@ -117,12 +112,8 @@ def test_aperture_turn_ele_and_monitor():
117112

118113
def test_custom_monitor():
119114

120-
for CTX in xo.ContextCpu, xo.ContextPyopencl, xo.ContextCupy:
121-
if CTX not in available:
122-
continue
123-
124-
print(f"Test {CTX}")
125-
context = CTX()
115+
for context in xo.context.get_test_contexts():
116+
print(f"Test {context.__class__}")
126117

127118
x_aper_min = -0.1
128119
x_aper_max = 0.2

tests/test_collective_tracker.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,14 @@
33
import numpy as np
44

55
import xobjects as xo
6-
from xobjects.context import available
76
import xline as xl
87
import xtrack as xt
98
import xfields as xf
109

1110
def test_collective_tracker():
1211

13-
for CTX in xo.ContextCpu, xo.ContextPyopencl, xo.ContextCupy:
14-
if CTX not in available:
15-
continue
16-
17-
print(f"Test {CTX}")
18-
context = CTX()
12+
for context in xo.context.get_test_contexts():
13+
print(f"Test {context.__class__}")
1914

2015
test_data_folder = pathlib.Path(
2116
__file__).parent.joinpath('../test_data').absolute()

tests/test_dress.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,28 @@ class ElementData(xo.Struct):
1010

1111
class Element(dress(ElementData)):
1212

13-
def __init__(self, vv):
14-
self.xoinitialize(n=len(vv), b=np.sum(vv), vv=vv)
15-
16-
ele = Element([1,2,3])
17-
assert ele.n == ele._xobject.n == 3
18-
assert ele.b == ele._xobject.b == 6
19-
assert ele.vv[1] == ele._xobject.vv[1] == 2
20-
21-
ele.vv = [7,8,9]
22-
assert ele.n == ele._xobject.n == 3
23-
assert ele.b == ele._xobject.b == 6
24-
assert ele.vv[1] == ele._xobject.vv[1] == 8
25-
26-
ele.n = 5.
27-
assert ele.n == ele._xobject.n == 5
28-
29-
ele.b = 50
30-
assert ele.b == ele._xobject.b == 50.
13+
def __init__(self, vv, **kwargs):
14+
self.xoinitialize(n=len(vv), b=np.sum(vv), vv=vv,
15+
**kwargs)
16+
for context in xo.context.get_test_contexts():
17+
print(f"Test {context.__class__}")
18+
19+
ele = Element([1,2,3], _context=context)
20+
assert ele.n == ele._xobject.n == 3
21+
assert ele.b == ele._xobject.b == 6
22+
assert ele.vv[1] == ele._xobject.vv[1] == 2
23+
24+
new_vv = context.nparray_to_context_array(np.array([7,8,9]))
25+
ele.vv = new_vv
26+
assert ele.n == ele._xobject.n == 3
27+
assert ele.b == ele._xobject.b == 6
28+
assert ele.vv[1] == ele._xobject.vv[1] == 8
29+
30+
ele.n = 5.
31+
assert ele.n == ele._xobject.n == 5
32+
33+
ele.b = 50
34+
assert ele.b == ele._xobject.b == 50.
3135

3236

3337
def test_explicit_buffer():

0 commit comments

Comments
 (0)