Skip to content

Commit abb9dcb

Browse files
committed
Update tests
1 parent d1d57ee commit abb9dcb

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

orbit_tools/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,9 @@
99
from . import sim
1010
from . import utils
1111
from .core import *
12+
from .bunch import *
13+
from .diag import Diagnostic
14+
from .diag import BunchHistogram
15+
from .lattice import *
16+
from .sim import *
17+
from .utils import *

tests/test_coupling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
15-
bunch = Bunch(mass=mass, energy=energy)
15+
bunch = Bunch()
1616
bunch.mass(mass)
1717
bunch.getSyncParticle().kinEnergy(energy)
1818
return bunch

tests/test_diag.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
import numpy as np
3+
import pytest
4+
5+
from orbit.core.bunch import Bunch
6+
from orbit.lattice import AccLattice
7+
from orbit.lattice import AccNode
8+
from orbit.teapot import DriftTEAPOT
9+
from orbit.teapot import QuadTEAPOT
10+
from orbit.teapot import TEAPOT_Lattice
11+
12+
from orbit_tools.bunch import set_bunch_coords
13+
from orbit_tools.diag import BunchHistogram
14+
15+
16+
def test_hist():
17+
nbins = 100
18+
seed = 123
19+
20+
rng = np.random.default_rng(seed)
21+
x = rng.normal(size=(10_000, 6))
22+
23+
bunch = Bunch()
24+
bunch.mass(0.938)
25+
bunch.getSyncParticle().kinEnergy(1.000)
26+
bunch.macroSize(1.0)
27+
bunch = set_bunch_coords(bunch, x)
28+
29+
axis_list = []
30+
for i in range(6):
31+
axis_list.append((i,))
32+
33+
for i in range(6):
34+
for j in range(i):
35+
axis_list.append((i, j))
36+
37+
for axis in axis_list:
38+
ndim = len(axis)
39+
shape = tuple(ndim * [nbins])
40+
limits = ndim * [(-5.0, 5.0)]
41+
42+
# Compute histogram using BunchHistogram
43+
hist = BunchHistogram(axis=axis, shape=shape, limits=limits)
44+
values = hist.compute_histogram(bunch)
45+
values = values / np.max(values)
46+
47+
# Compute histogram using NumPy
48+
values_np, _ = np.histogramdd(x[:, axis], bins=hist.edges)
49+
values_np = values_np / np.max(values_np)
50+
51+
# Compare the histograms. There will be differences because Grid
52+
# classes use weighting.
53+
print(np.max(np.abs(values - values_np)))
54+
55+
56+

tests/test_ring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
19-
bunch = Bunch(mass=mass, energy=energy)
19+
bunch = Bunch()
2020
bunch.mass(mass)
2121
bunch.getSyncParticle().kinEnergy(energy)
2222
return bunch

tests/test_sim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def make_bunch(mass: float = 0.938, energy: float = 1.000) -> Bunch:
15-
bunch = Bunch(mass=mass, energy=energy)
15+
bunch = Bunch()
1616
bunch.mass(mass)
1717
bunch.getSyncParticle().kinEnergy(energy)
1818
return bunch

0 commit comments

Comments
 (0)