Skip to content

Commit 6677cb9

Browse files
committed
Introduced test for random number generation
1 parent 0ca4a37 commit 6677cb9

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

examples/radiation/000_test_rng.py renamed to examples/radiation/000_random_generation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pathlib import Path
77

88
ctx = xo.ContextCpu()
9-
ctx = xo.ContextCupy()
10-
ctx = xo.ContextPyopencl()
9+
#ctx = xo.ContextCupy()
10+
#ctx = xo.ContextPyopencl()
1111

1212
part = xt.Particles(_context=ctx, p0c=6.5e12, x=[1,2,3])
1313
part._init_random_number_generator()
@@ -34,6 +34,7 @@ class TestElement(xt.BeamElement):
3434

3535
telem.track(part)
3636

37+
# Use turn-by-turn monitor to acquire some statistics
3738

3839
tracker = xt.Tracker(_buffer=telem._buffer,
3940
sequence=xl.Line(elements=[telem],

tests/test_random_gen.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
3+
import xobjects as xo
4+
import xtrack as xt
5+
import xline as xl
6+
from pathlib import Path
7+
def test_random_generation():
8+
for ctx in xo.context.get_test_contexts():
9+
print(f'{ctx}')
10+
11+
part = xt.Particles(_context=ctx, p0c=6.5e12, x=[1,2,3])
12+
part._init_random_number_generator()
13+
14+
class TestElement(xt.BeamElement):
15+
_xofields={
16+
'dummy': xo.Float64,
17+
}
18+
TestElement.XoStruct.extra_sources = [
19+
xt._pkg_root.joinpath(
20+
'random_number_generator/rng_src/base_rng.h'),
21+
xt._pkg_root.joinpath(
22+
'random_number_generator/rng_src/local_particle_rng.h'),
23+
]
24+
TestElement.XoStruct.extra_sources.append('''
25+
/*gpufun*/
26+
void TestElement_track_local_particle(
27+
TestElementData el, LocalParticle* part0){
28+
//start_per_particle_block (part0->part)
29+
double rr = LocalParticle_generate_random_double(part);
30+
LocalParticle_set_x(part, rr);
31+
//end_per_particle_block
32+
}
33+
''')
34+
35+
telem = TestElement(_context=ctx)
36+
37+
telem.track(part)
38+
39+
# Use turn-by turin monitor to acquire some statistics
40+
41+
tracker = xt.Tracker(_buffer=telem._buffer,
42+
sequence=xl.Line(elements=[telem],
43+
element_names='test_element'),
44+
save_source_as='source.c')
45+
46+
tracker.track(part, num_turns=1e6, turn_by_turn_monitor=True)
47+
48+
import matplotlib.pyplot as plt
49+
plt.close('all')
50+
for i_part in range(part._capacity):
51+
x = tracker.record_last_track.x[i_part, :]
52+
assert np.all(x>0)
53+
assert np.all(x<1)
54+
hstgm, bin_edges = np.histogram(x, bins=50, range=(0, 1), density=True)
55+
assert np.allclose(hstgm, 1, rtol=1e-10, atol=0.03)
56+

0 commit comments

Comments
 (0)