Skip to content

Commit 412e731

Browse files
Simplification of the code
1 parent b30778f commit 412e731

File tree

1 file changed

+30
-56
lines changed

1 file changed

+30
-56
lines changed

01_element/test_drift2.py

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,56 +4,37 @@
44
import uproot
55
import numpy as np
66
import scipy.constants as const
7-
from scipy import stats
87

9-
def getSamplerData(f,samplerName, variable, mask) :
8+
def getSamplerData(f, samplerName, variable, mask) :
109
d = f[samplerName+variable].array(library='np')
1110
elem = d[0]
1211
if np.ndim(elem) == 0:
1312
d_out=np.array(d)
1413
else:
1514
d_out = np.array([x[0] for x in d if x.size>0])
1615
d_out = d_out[mask]
16+
d_out_mean=np.mean(d_out)
1717

18-
return d_out
18+
return d_out_mean
1919

20-
def get_p_value(mu0,x, rtol=1e-6,atol=1e-12):
21-
mean = np.mean(x)
22-
sigma = np.std(x, ddof=1)
23-
test_value=0
24-
if abs(mean) > atol:
25-
if np.abs(sigma/mean) <rtol:
26-
test_value=1
27-
else:
28-
if sigma< atol:
29-
test_value=1
30-
31-
if test_value:
32-
if np.isclose(mean, mu0, rtol):
33-
return 0.0 # H0 perfect
34-
else:
35-
return 1.0 # H0 clearly wrong
36-
37-
t_stat, p_value = stats.ttest_1samp(x, mu0)
38-
39-
return p_value
4020

41-
def get_theoretical_beam_parameters(ek_in,p_type,sim_env):
21+
def get_theoretical_beam_parameters(Ek_in, p_type, sim_env):
4222
c=const.c
4323
e=const.e
4424
if p_type=="e-":
4525
m=const.electron_mass
4626
elif p_type=="proton":
4727
m=const.proton_mass
48-
gamma=ek_in*e*1e9/(m*c**2)+1
28+
gamma=Ek_in*e*1e9/(m*c**2)+1
4929
v=c*np.sqrt(1-1/gamma**2)
5030
time_out_th=sim_env["l"]/v*1e9 #s
51-
e_m=m*c**2
52-
E_tot=ek_in*e*1e9+e_m
53-
pz_rel=np.sqrt(E_tot**2-e_m**2)/c
54-
pz_rel=pz_rel*c/(e*1e9) # GeV/c
31+
E_m=m*c**2
32+
E_tot=Ek_in*e*1e9+E_m
33+
pz_in=np.sqrt(E_tot**2-E_m**2)/c
34+
pz_in=pz_in*c/(e*1e9) # GeV/c
35+
36+
return E_tot, pz_in, time_out_th
5537

56-
return E_tot, pz_rel, time_out_th
5738

5839
@pytest.fixture
5940
def sim_env():
@@ -70,12 +51,11 @@ def sim_env():
7051
"l": 2 #m
7152
}
7253

54+
7355
class TestClass:
7456

75-
alpha_test=0.05
76-
7757
@pytest.mark.parametrize(
78-
"ek_in, p_type",
58+
"Ek_in, p_type",
7959
[
8060
(1e-3, "e-"),
8161
(1e3, "e-"), #GeV
@@ -84,12 +64,12 @@ class TestClass:
8464
]
8565
)
8666

87-
def test_drift2(self, sim_env, ek_in, p_type) :
88-
89-
E_tot, pz_rel, time_out_th= get_theoretical_beam_parameters(ek_in,p_type,sim_env)
67+
def test_drift2(self, sim_env, Ek_in, p_type) :
68+
69+
E_tot, pz_in, time_out_th= get_theoretical_beam_parameters(Ek_in,p_type,sim_env)
9070
data = {
9171
'LENGTH': str(sim_env["l"]),
92-
'BEAM_ENERGY' : str(ek_in), #GeV
72+
'BEAM_ENERGY' : str(Ek_in), #GeV
9373
'P_TYPE': p_type
9474
}
9575

@@ -104,34 +84,28 @@ def test_drift2(self, sim_env, ek_in, p_type) :
10484
mask=parent_ID==0 #from initial beam only
10585
n_out=np.sum(mask)
10686

107-
theta_out=getSamplerData(f,samplerName,"theta", mask)
108-
p_theta=get_p_value(sim_env["theta_in"],theta_out)
87+
theta_out_mean=getSamplerData(f,samplerName,"theta", mask)
10988

11089
a=f[samplerName+'p'].array(library='np')
11190
p_out= np.array([x[0] for x in a if len(x)>0])
11291
a=f[samplerName+'zp'].array(library='np')
11392
pz_frac= np.array([x[0] for x in a if len(x)>0])
11493
pz_out=np.multiply(pz_frac,p_out)
11594
pz_out=pz_out[mask]
116-
p_pz=get_p_value(pz_rel,pz_out)
95+
pz_out_mean=np.mean(pz_out)
11796

118-
phi_out=getSamplerData(f,samplerName,"phi", mask)
119-
p_phi=get_p_value(sim_env["phi_in"],phi_out)
97+
phi_out_mean=getSamplerData(f,samplerName,"phi", mask)
12098

121-
ek_out=getSamplerData(f,samplerName,"kineticEnergy", mask)
122-
p_ek=get_p_value(ek_in,ek_out)
99+
Ek_out_mean=getSamplerData(f,samplerName,"kineticEnergy", mask)
123100

124-
t_out=getSamplerData(f,samplerName,"T", mask)
125-
p_t=get_p_value(time_out_th,t_out)
101+
t_out_mean=getSamplerData(f,samplerName,"T", mask)
126102

127-
s_sampler=getSamplerData(f,samplerName,"S", mask)
128-
p_l=get_p_value(sim_env["l"],s_sampler)
129-
130-
assert ( p_ek< self.alpha_test)
131-
assert (p_phi< self.alpha_test)
132-
assert (p_theta< self.alpha_test)
133-
assert (sim_env["n_in"]==n_out)
134-
assert (p_pz< self.alpha_test)
135-
assert ( p_t< self.alpha_test)
136-
assert ( p_l< self.alpha_test)
103+
s_sampler_mean=getSamplerData(f,samplerName,"S", mask)
137104

105+
assert Ek_out_mean==pytest.approx(Ek_in,rel=1e-9)
106+
assert phi_out_mean==pytest.approx(sim_env["phi_in"], rel=1e-9)
107+
assert theta_out_mean==pytest.approx(sim_env["theta_in"],rel=1e-9)
108+
assert sim_env["n_in"]==n_out
109+
assert pz_out_mean==pytest.approx(pz_in,1e-5)
110+
assert t_out_mean==pytest.approx(time_out_th,1e-6)
111+
assert s_sampler_mean==pytest.approx(sim_env["l"],1-9)

0 commit comments

Comments
 (0)