Skip to content

Commit 9c02072

Browse files
authored
Merge pull request #704 from rdemaria/expose-masses
expose masses
2 parents 7832b16 + 58a3335 commit 9c02072

File tree

4 files changed

+75
-38
lines changed

4 files changed

+75
-38
lines changed

tests/test_particles_basics.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,14 @@ class KickAngExact(xt.BeamElement):
697697
particles.move(_context=xo.ContextCpu())
698698
xo.assert_allclose(particles.px, [23.99302e-3, 18.01805e-3], atol=1e-14, rtol=5e-7)
699699
xo.assert_allclose(particles.py, [-1.81976e-3, -4.73529e-3], atol=1e-14, rtol=5e-7)
700+
701+
702+
@for_all_test_contexts
703+
def test_update_rigidity0(test_context):
704+
p = xt.Particles("proton",p0c=7000e9)
705+
pb82 = xt.Particles("Pb208",q0=82,rigidity0=p.rigidity0)
706+
assert np.allclose(pb82.p0c/82,p.p0c)
707+
pb82.rigidity0=p.rigidity0
708+
assert np.allclose(pb82.p0c/82,p.p0c)
709+
710+

xtrack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .general import _pkg_root, _print, START, END
77

88
from .particles import (Particles, PROTON_MASS_EV, ELECTRON_MASS_EV,
9-
enable_pyheadtail_interface, disable_pyheadtail_interface)
9+
enable_pyheadtail_interface, disable_pyheadtail_interface, masses)
1010

1111
from .base_element import BeamElement, Replica
1212
from .beam_elements import *

xtrack/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,8 +1861,7 @@ def __setattr__(self, key, value):
18611861

18621862
def copy(self, **kwargs):
18631863
return self._resolved.copy(**kwargs)
1864-
1865-
1864+
18661865
class EnvVars:
18671866

18681867
def __init__(self, env):
@@ -2200,3 +2199,4 @@ def _disable_name_clash_checks(env):
22002199
yield
22012200
finally:
22022201
env._enable_name_clash_check = old_value
2202+

xtrack/particles/particles.py

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def __init__(
179179
Reference relativistic gamma
180180
beta0 : array_like of float, optional
181181
Reference relativistic beta
182+
rigidity0 : array_like of float, optional
183+
Reference magnetic rigidity [T.m]
184+
kinetic_energy0 : array_like of float, optional
185+
Reference kinetic energy [eV]
182186
mass_ratio : array_like of float, optional
183187
mass/mass0 (this is used to track particles of
184188
different species. Note that mass is the rest mass
@@ -224,7 +228,7 @@ def __init__(
224228

225229
accepted_args = set(self._xofields.keys()) | {
226230
'energy0', 'tau', 'pzeta', 'mass_ratio', 'mass', 'kinetic_energy0',
227-
'_context', '_buffer', '_offset', 'p0', 'name',
231+
'_context', '_buffer', '_offset', 'name', 'rigidity0',
228232
}
229233
if set(kwargs.keys()) - accepted_args:
230234
raise NameError(f'Invalid argument(s) provided: '
@@ -236,6 +240,8 @@ def __init__(
236240
per_part_input_vars = (
237241
self.per_particle_vars +
238242
((xo.Float64, 'energy0'),
243+
(xo.Float64, 'kinetic_energy0'),
244+
(xo.Float64, 'rigidity0'),
239245
(xo.Float64, 'tau'),
240246
(xo.Float64, 'pzeta'),
241247
(xo.Float64, 'mass_ratio'))
@@ -322,11 +328,6 @@ def __init__(
322328
self.start_tracking_at_element = kwargs.get(
323329
'start_tracking_at_element', -1)
324330

325-
# Init refs
326-
if 'kinetic_energy0' in kwargs.keys():
327-
assert kwargs.get('energy0') is None
328-
kwargs['energy0'] = kwargs.pop('kinetic_energy0') + self.mass0
329-
330331
# Ensure that all per particle inputs are numpy arrays of the same
331332
# length, and move them to the target context
332333
for xotype, field in per_part_input_vars:
@@ -351,20 +352,22 @@ def __init__(
351352
# Init independent per particle vars
352353
self.init_independent_per_part_vars(kwargs)
353354

355+
# Init chi and charge ratio
356+
self._update_chi_charge_ratio(
357+
chi=kwargs.get('chi'),
358+
charge_ratio=kwargs.get('charge_ratio'),
359+
mass_ratio=kwargs.get('mass_ratio'),
360+
mask=input_mask,
361+
)
354362

363+
# Init reference momentum and related vars
355364
self._update_refs(
356365
p0c=kwargs.get('p0c'),
357366
energy0=kwargs.get('energy0'),
358367
gamma0=kwargs.get('gamma0'),
359368
beta0=kwargs.get('beta0'),
360-
mask=input_mask,
361-
)
362-
363-
# Init chi and charge ratio
364-
self._update_chi_charge_ratio(
365-
chi=kwargs.get('chi'),
366-
charge_ratio=kwargs.get('charge_ratio'),
367-
mass_ratio=kwargs.get('mass_ratio'),
369+
kinetic_energy0=kwargs.get('kinetic_energy0'),
370+
rigidity0=kwargs.get('rigidity0'),
368371
mask=input_mask,
369372
)
370373

@@ -1235,13 +1238,25 @@ def p0c(self):
12351238
def p0c(self, value):
12361239
self.p0c[:] = value
12371240

1241+
def _rigidity0_setitem(self, indx, val):
1242+
ctx = self._buffer.context
1243+
temp_rigidity0 = ctx.zeros(shape=self._p0c.shape, dtype=np.float64)
1244+
temp_rigidity0[:] = np.nan
1245+
temp_rigidity0[indx] = val
1246+
self.update_p0c(temp_rigidity0 * clight * self.q0)
1247+
12381248
@property
12391249
def rigidity0(self):
12401250
rigidity0 = self.p0c / clight / self.q0
12411251
return self._buffer.context.linked_array_type.from_array(
12421252
rigidity0,
1243-
mode='readonly',
1244-
container=self)
1253+
mode='setitem_from_container',
1254+
container=self,
1255+
container_setitem_name='_rigidity0_setitem')
1256+
1257+
@rigidity0.setter
1258+
def rigidity0(self, value):
1259+
self.rigidity0[:] = value
12451260

12461261
def update_gamma0(self, new_gamma0):
12471262

@@ -1277,7 +1292,6 @@ def gamma0(self):
12771292
def gamma0(self, value):
12781293
self.gamma0[:] = value
12791294

1280-
12811295

12821296
def update_beta0(self, new_beta0):
12831297

@@ -1528,13 +1542,13 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
15281542
src_lines.append('if (set_scalar){')
15291543
for _, vv in cls.size_vars + cls.scalar_vars:
15301544
src_lines.append(
1531-
f' ParticlesData_set_' + vv + '(dest,'
1545+
' ParticlesData_set_' + vv + '(dest,'
15321546
f' LocalParticle_get_{vv}(source));')
15331547
src_lines.append('}')
15341548

15351549
for _, vv in cls.per_particle_vars:
15361550
src_lines.append(
1537-
f' ParticlesData_set_' + vv + '(dest, id, '
1551+
' ParticlesData_set_' + vv + '(dest, id, '
15381552
f' LocalParticle_get_{vv}(source));')
15391553
src_lines.append('}')
15401554
src_local_to_particles = '\n'.join(src_lines)
@@ -1584,15 +1598,15 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
15841598
for tt, vv in cls.size_vars + cls.scalar_vars:
15851599
src_lines.append('/*gpufun*/')
15861600
src_lines.append(f'{tt._c_type} LocalParticle_get_' + vv
1587-
+ f'(LocalParticle* part)'
1601+
+ '(LocalParticle* part)'
15881602
+ '{')
15891603
src_lines.append(f' return part->{vv};')
15901604
src_lines.append('}')
15911605

15921606
for tt, vv in cls.per_particle_vars:
15931607
src_lines.append('/*gpufun*/')
15941608
src_lines.append(f'{tt._c_type} LocalParticle_get_' + vv
1595-
+ f'(LocalParticle* part)'
1609+
+ '(LocalParticle* part)'
15961610
+ '{')
15971611
src_lines.append(f' return part->{vv}[part->ipart];')
15981612
src_lines.append('}')
@@ -1610,13 +1624,13 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16101624
src_angles_lines.append(f' double const p{xx} = LocalParticle_get_p{xx}(part);')
16111625
if exact == 'exact_':
16121626
src_angles_lines.append(f' double const p{yy} = LocalParticle_get_p{yy}(part);')
1613-
src_angles_lines.append(f' double const one_plus_delta = 1. + LocalParticle_get_delta(part);')
1627+
src_angles_lines.append(' double const one_plus_delta = 1. + LocalParticle_get_delta(part);')
16141628
src_angles_lines.append(
1615-
f' double const rpp = 1./sqrt(one_plus_delta*one_plus_delta - px*px - py*py);')
1629+
' double const rpp = 1./sqrt(one_plus_delta*one_plus_delta - px*px - py*py);')
16161630
else:
1617-
src_angles_lines.append(f' double const rpp = LocalParticle_get_rpp(part);')
1618-
src_angles_lines.append(f' // INFO: this is not the angle, but sin(angle)')
1619-
src_angles_lines.append(f' return p{xx}*rpp;')
1631+
src_angles_lines.append(' double const rpp = LocalParticle_get_rpp(part);')
1632+
src_angles_lines.append(' // INFO: this is not the angle, but sin(angle)')
1633+
src_angles_lines.append(' return p{xx}*rpp;')
16201634
src_angles_lines.append('}')
16211635
src_angles_lines.append('')
16221636

@@ -1625,15 +1639,15 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16251639
src_angles_lines.append('/*gpufun*/')
16261640
src_angles_lines.append(f'void LocalParticle_set_{exact}{xx}p(LocalParticle* part, double {xx}p){{')
16271641
src_angles_lines.append(f'#ifndef FREEZE_VAR_p{xx}')
1628-
src_angles_lines.append(f' double rpp = LocalParticle_get_rpp(part);')
1642+
src_angles_lines.append(' double rpp = LocalParticle_get_rpp(part);')
16291643
if exact == 'exact_':
16301644
src_angles_lines.append(
16311645
f' // Careful! If {yy}p also changes, use LocalParticle_set_{exact}xp_yp!')
16321646
src_angles_lines.append(f' double const {yy}p = LocalParticle_get_{exact}{yy}p(part);')
1633-
src_angles_lines.append(f' rpp *= sqrt(1 + xp*xp + yp*yp);')
1647+
src_angles_lines.append(' rpp *= sqrt(1 + xp*xp + yp*yp);')
16341648
src_angles_lines.append(f' // INFO: {xx}p is not the angle, but sin(angle)')
16351649
src_angles_lines.append(f' LocalParticle_set_p{xx}(part, {xx}p/rpp);')
1636-
src_angles_lines.append(f'#endif')
1650+
src_angles_lines.append('#endif')
16371651
src_angles_lines.append('}')
16381652
src_angles_lines.append('')
16391653

@@ -1644,7 +1658,7 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16441658
src_angles_lines.append(f'#ifndef FREEZE_VAR_p{xx}')
16451659
src_angles_lines.append(f' LocalParticle_set_{exact}{xx}p(part, '
16461660
+ f'LocalParticle_get_{exact}{xx}p(part) + {xx}p);')
1647-
src_angles_lines.append(f'#endif')
1661+
src_angles_lines.append('#endif')
16481662
src_angles_lines.append('}')
16491663
src_angles_lines.append('')
16501664
# Scaler
@@ -1653,19 +1667,19 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16531667
src_angles_lines.append(f'#ifndef FREEZE_VAR_p{xx}')
16541668
src_angles_lines.append(f' LocalParticle_set_{exact}{xx}p(part, '
16551669
+ f'LocalParticle_get_{exact}{xx}p(part) * value);')
1656-
src_angles_lines.append(f'#endif')
1670+
src_angles_lines.append('#endif')
16571671
src_angles_lines.append('}')
16581672
src_angles_lines.append('')
16591673
# Double setter, adder, scaler
16601674
src_angles_lines.append('/*gpufun*/')
16611675
src_angles_lines.append(f'void LocalParticle_set_{exact}xp_yp(LocalParticle* part, double xp, double yp){{')
1662-
src_angles_lines.append(f' double rpp = LocalParticle_get_rpp(part);')
1676+
src_angles_lines.append(' double rpp = LocalParticle_get_rpp(part);')
16631677
if exact == 'exact_':
1664-
src_angles_lines.append(f' rpp *= sqrt(1 + xp*xp + yp*yp);')
1678+
src_angles_lines.append(' rpp *= sqrt(1 + xp*xp + yp*yp);')
16651679
for xx in ['x', 'y']:
16661680
src_angles_lines.append(f'#ifndef FREEZE_VAR_p{xx}')
16671681
src_angles_lines.append(f' LocalParticle_set_p{xx}(part, {xx}p/rpp);')
1668-
src_angles_lines.append(f'#endif')
1682+
src_angles_lines.append('#endif')
16691683
src_angles_lines.append('}')
16701684
src_angles_lines.append('')
16711685
src_angles_lines.append('/*gpufun*/')
@@ -2099,8 +2113,10 @@ def _setattr_if_consistent(self, varname, given_value, computed_value,
20992113
getattr(self, varname)[mask] = target_val[mask]
21002114

21012115
def _update_refs(self, p0c=None, energy0=None, gamma0=None, beta0=None,
2116+
kinetic_energy0=None, rigidity0=None,
21022117
mask=None):
2103-
if not any(ff is not None for ff in (p0c, energy0, gamma0, beta0)):
2118+
if not any(ff is not None for ff in (p0c, energy0, gamma0, beta0,
2119+
kinetic_energy0, rigidity0)):
21042120
self._p0c = 1e9
21052121
p0c = self._p0c
21062122

@@ -2125,6 +2141,16 @@ def _update_refs(self, p0c=None, energy0=None, gamma0=None, beta0=None,
21252141
_energy0 = self.mass0 * _gamma0
21262142
_p0c = _energy0 * beta0
21272143
_beta0 = beta0
2144+
elif kinetic_energy0 is not None:
2145+
_energy0 = kinetic_energy0 + self.mass0
2146+
_p0c = _sqrt(_energy0 ** 2 - self.mass0 ** 2)
2147+
_beta0 = _p0c / _energy0
2148+
_gamma0 = _energy0 / self.mass0
2149+
elif rigidity0 is not None:
2150+
_p0c = rigidity0 * abs(self.q0) * clight
2151+
_energy0 = _sqrt(_p0c ** 2 + self.mass0 ** 2)
2152+
_beta0 = _p0c / _energy0
2153+
_gamma0 = _energy0 / self.mass0
21282154
else:
21292155
raise RuntimeError('This statement is unreachable.')
21302156

0 commit comments

Comments
 (0)