@@ -179,6 +179,10 @@ def __init__(
179179 Reference relativistic gamma
180180 beta0 : array_like of float, optional
181181 Reference relativistic beta
182+ rigidity0 : array_like of float, optional
183+ Reference magnetic rigidity [T.m]
184+ kinetic_energy0 : array_like of float, optional
185+ Reference kinetic energy [eV]
182186 mass_ratio : array_like of float, optional
183187 mass/mass0 (this is used to track particles of
184188 different species. Note that mass is the rest mass
@@ -224,7 +228,7 @@ def __init__(
224228
225229 accepted_args = set (self ._xofields .keys ()) | {
226230 'energy0' , 'tau' , 'pzeta' , 'mass_ratio' , 'mass' , 'kinetic_energy0' ,
227- '_context' , '_buffer' , '_offset' , 'p0 ' , 'name ' ,
231+ '_context' , '_buffer' , '_offset' , 'name ' , 'rigidity0 ' ,
228232 }
229233 if set (kwargs .keys ()) - accepted_args :
230234 raise NameError (f'Invalid argument(s) provided: '
@@ -236,6 +240,8 @@ def __init__(
236240 per_part_input_vars = (
237241 self .per_particle_vars +
238242 ((xo .Float64 , 'energy0' ),
243+ (xo .Float64 , 'kinetic_energy0' ),
244+ (xo .Float64 , 'rigidity0' ),
239245 (xo .Float64 , 'tau' ),
240246 (xo .Float64 , 'pzeta' ),
241247 (xo .Float64 , 'mass_ratio' ))
@@ -322,11 +328,6 @@ def __init__(
322328 self .start_tracking_at_element = kwargs .get (
323329 'start_tracking_at_element' , - 1 )
324330
325- # Init refs
326- if 'kinetic_energy0' in kwargs .keys ():
327- assert kwargs .get ('energy0' ) is None
328- kwargs ['energy0' ] = kwargs .pop ('kinetic_energy0' ) + self .mass0
329-
330331 # Ensure that all per particle inputs are numpy arrays of the same
331332 # length, and move them to the target context
332333 for xotype , field in per_part_input_vars :
@@ -351,20 +352,22 @@ def __init__(
351352 # Init independent per particle vars
352353 self .init_independent_per_part_vars (kwargs )
353354
355+ # Init chi and charge ratio
356+ self ._update_chi_charge_ratio (
357+ chi = kwargs .get ('chi' ),
358+ charge_ratio = kwargs .get ('charge_ratio' ),
359+ mass_ratio = kwargs .get ('mass_ratio' ),
360+ mask = input_mask ,
361+ )
354362
363+ # Init reference momentum and related vars
355364 self ._update_refs (
356365 p0c = kwargs .get ('p0c' ),
357366 energy0 = kwargs .get ('energy0' ),
358367 gamma0 = kwargs .get ('gamma0' ),
359368 beta0 = kwargs .get ('beta0' ),
360- mask = input_mask ,
361- )
362-
363- # Init chi and charge ratio
364- self ._update_chi_charge_ratio (
365- chi = kwargs .get ('chi' ),
366- charge_ratio = kwargs .get ('charge_ratio' ),
367- mass_ratio = kwargs .get ('mass_ratio' ),
369+ kinetic_energy0 = kwargs .get ('kinetic_energy0' ),
370+ rigidity0 = kwargs .get ('rigidity0' ),
368371 mask = input_mask ,
369372 )
370373
@@ -1235,13 +1238,25 @@ def p0c(self):
12351238 def p0c (self , value ):
12361239 self .p0c [:] = value
12371240
1241+ def _rigidity0_setitem (self , indx , val ):
1242+ ctx = self ._buffer .context
1243+ temp_rigidity0 = ctx .zeros (shape = self ._p0c .shape , dtype = np .float64 )
1244+ temp_rigidity0 [:] = np .nan
1245+ temp_rigidity0 [indx ] = val
1246+ self .update_p0c (temp_rigidity0 * clight * self .q0 )
1247+
12381248 @property
12391249 def rigidity0 (self ):
12401250 rigidity0 = self .p0c / clight / self .q0
12411251 return self ._buffer .context .linked_array_type .from_array (
12421252 rigidity0 ,
1243- mode = 'readonly' ,
1244- container = self )
1253+ mode = 'setitem_from_container' ,
1254+ container = self ,
1255+ container_setitem_name = '_rigidity0_setitem' )
1256+
1257+ @rigidity0 .setter
1258+ def rigidity0 (self , value ):
1259+ self .rigidity0 [:] = value
12451260
12461261 def update_gamma0 (self , new_gamma0 ):
12471262
@@ -1277,7 +1292,6 @@ def gamma0(self):
12771292 def gamma0 (self , value ):
12781293 self .gamma0 [:] = value
12791294
1280-
12811295
12821296 def update_beta0 (self , new_beta0 ):
12831297
@@ -1528,13 +1542,13 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
15281542 src_lines .append ('if (set_scalar){' )
15291543 for _ , vv in cls .size_vars + cls .scalar_vars :
15301544 src_lines .append (
1531- f ' ParticlesData_set_' + vv + '(dest,'
1545+ ' ParticlesData_set_' + vv + '(dest,'
15321546 f' LocalParticle_get_{ vv } (source));' )
15331547 src_lines .append ('}' )
15341548
15351549 for _ , vv in cls .per_particle_vars :
15361550 src_lines .append (
1537- f ' ParticlesData_set_' + vv + '(dest, id, '
1551+ ' ParticlesData_set_' + vv + '(dest, id, '
15381552 f' LocalParticle_get_{ vv } (source));' )
15391553 src_lines .append ('}' )
15401554 src_local_to_particles = '\n ' .join (src_lines )
@@ -1584,15 +1598,15 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
15841598 for tt , vv in cls .size_vars + cls .scalar_vars :
15851599 src_lines .append ('/*gpufun*/' )
15861600 src_lines .append (f'{ tt ._c_type } LocalParticle_get_' + vv
1587- + f '(LocalParticle* part)'
1601+ + '(LocalParticle* part)'
15881602 + '{' )
15891603 src_lines .append (f' return part->{ vv } ;' )
15901604 src_lines .append ('}' )
15911605
15921606 for tt , vv in cls .per_particle_vars :
15931607 src_lines .append ('/*gpufun*/' )
15941608 src_lines .append (f'{ tt ._c_type } LocalParticle_get_' + vv
1595- + f '(LocalParticle* part)'
1609+ + '(LocalParticle* part)'
15961610 + '{' )
15971611 src_lines .append (f' return part->{ vv } [part->ipart];' )
15981612 src_lines .append ('}' )
@@ -1610,13 +1624,13 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16101624 src_angles_lines .append (f' double const p{ xx } = LocalParticle_get_p{ xx } (part);' )
16111625 if exact == 'exact_' :
16121626 src_angles_lines .append (f' double const p{ yy } = LocalParticle_get_p{ yy } (part);' )
1613- src_angles_lines .append (f ' double const one_plus_delta = 1. + LocalParticle_get_delta(part);' )
1627+ src_angles_lines .append (' double const one_plus_delta = 1. + LocalParticle_get_delta(part);' )
16141628 src_angles_lines .append (
1615- f ' double const rpp = 1./sqrt(one_plus_delta*one_plus_delta - px*px - py*py);' )
1629+ ' double const rpp = 1./sqrt(one_plus_delta*one_plus_delta - px*px - py*py);' )
16161630 else :
1617- src_angles_lines .append (f ' double const rpp = LocalParticle_get_rpp(part);' )
1618- src_angles_lines .append (f ' // INFO: this is not the angle, but sin(angle)' )
1619- src_angles_lines .append (f ' return p{ xx } *rpp;' )
1631+ src_angles_lines .append (' double const rpp = LocalParticle_get_rpp(part);' )
1632+ src_angles_lines .append (' // INFO: this is not the angle, but sin(angle)' )
1633+ src_angles_lines .append (' return p{xx}*rpp;' )
16201634 src_angles_lines .append ('}' )
16211635 src_angles_lines .append ('' )
16221636
@@ -1625,15 +1639,15 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16251639 src_angles_lines .append ('/*gpufun*/' )
16261640 src_angles_lines .append (f'void LocalParticle_set_{ exact } { xx } p(LocalParticle* part, double { xx } p){{' )
16271641 src_angles_lines .append (f'#ifndef FREEZE_VAR_p{ xx } ' )
1628- src_angles_lines .append (f ' double rpp = LocalParticle_get_rpp(part);' )
1642+ src_angles_lines .append (' double rpp = LocalParticle_get_rpp(part);' )
16291643 if exact == 'exact_' :
16301644 src_angles_lines .append (
16311645 f' // Careful! If { yy } p also changes, use LocalParticle_set_{ exact } xp_yp!' )
16321646 src_angles_lines .append (f' double const { yy } p = LocalParticle_get_{ exact } { yy } p(part);' )
1633- src_angles_lines .append (f ' rpp *= sqrt(1 + xp*xp + yp*yp);' )
1647+ src_angles_lines .append (' rpp *= sqrt(1 + xp*xp + yp*yp);' )
16341648 src_angles_lines .append (f' // INFO: { xx } p is not the angle, but sin(angle)' )
16351649 src_angles_lines .append (f' LocalParticle_set_p{ xx } (part, { xx } p/rpp);' )
1636- src_angles_lines .append (f '#endif' )
1650+ src_angles_lines .append ('#endif' )
16371651 src_angles_lines .append ('}' )
16381652 src_angles_lines .append ('' )
16391653
@@ -1644,7 +1658,7 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16441658 src_angles_lines .append (f'#ifndef FREEZE_VAR_p{ xx } ' )
16451659 src_angles_lines .append (f' LocalParticle_set_{ exact } { xx } p(part, '
16461660 + f'LocalParticle_get_{ exact } { xx } p(part) + { xx } p);' )
1647- src_angles_lines .append (f '#endif' )
1661+ src_angles_lines .append ('#endif' )
16481662 src_angles_lines .append ('}' )
16491663 src_angles_lines .append ('' )
16501664 # Scaler
@@ -1653,19 +1667,19 @@ def gen_local_particle_api(cls, mode='no_local_copy'):
16531667 src_angles_lines .append (f'#ifndef FREEZE_VAR_p{ xx } ' )
16541668 src_angles_lines .append (f' LocalParticle_set_{ exact } { xx } p(part, '
16551669 + f'LocalParticle_get_{ exact } { xx } p(part) * value);' )
1656- src_angles_lines .append (f '#endif' )
1670+ src_angles_lines .append ('#endif' )
16571671 src_angles_lines .append ('}' )
16581672 src_angles_lines .append ('' )
16591673 # Double setter, adder, scaler
16601674 src_angles_lines .append ('/*gpufun*/' )
16611675 src_angles_lines .append (f'void LocalParticle_set_{ exact } xp_yp(LocalParticle* part, double xp, double yp){{' )
1662- src_angles_lines .append (f ' double rpp = LocalParticle_get_rpp(part);' )
1676+ src_angles_lines .append (' double rpp = LocalParticle_get_rpp(part);' )
16631677 if exact == 'exact_' :
1664- src_angles_lines .append (f ' rpp *= sqrt(1 + xp*xp + yp*yp);' )
1678+ src_angles_lines .append (' rpp *= sqrt(1 + xp*xp + yp*yp);' )
16651679 for xx in ['x' , 'y' ]:
16661680 src_angles_lines .append (f'#ifndef FREEZE_VAR_p{ xx } ' )
16671681 src_angles_lines .append (f' LocalParticle_set_p{ xx } (part, { xx } p/rpp);' )
1668- src_angles_lines .append (f '#endif' )
1682+ src_angles_lines .append ('#endif' )
16691683 src_angles_lines .append ('}' )
16701684 src_angles_lines .append ('' )
16711685 src_angles_lines .append ('/*gpufun*/' )
@@ -2099,8 +2113,10 @@ def _setattr_if_consistent(self, varname, given_value, computed_value,
20992113 getattr (self , varname )[mask ] = target_val [mask ]
21002114
21012115 def _update_refs (self , p0c = None , energy0 = None , gamma0 = None , beta0 = None ,
2116+ kinetic_energy0 = None , rigidity0 = None ,
21022117 mask = None ):
2103- if not any (ff is not None for ff in (p0c , energy0 , gamma0 , beta0 )):
2118+ if not any (ff is not None for ff in (p0c , energy0 , gamma0 , beta0 ,
2119+ kinetic_energy0 , rigidity0 )):
21042120 self ._p0c = 1e9
21052121 p0c = self ._p0c
21062122
@@ -2125,6 +2141,16 @@ def _update_refs(self, p0c=None, energy0=None, gamma0=None, beta0=None,
21252141 _energy0 = self .mass0 * _gamma0
21262142 _p0c = _energy0 * beta0
21272143 _beta0 = beta0
2144+ elif kinetic_energy0 is not None :
2145+ _energy0 = kinetic_energy0 + self .mass0
2146+ _p0c = _sqrt (_energy0 ** 2 - self .mass0 ** 2 )
2147+ _beta0 = _p0c / _energy0
2148+ _gamma0 = _energy0 / self .mass0
2149+ elif rigidity0 is not None :
2150+ _p0c = rigidity0 * abs (self .q0 ) * clight
2151+ _energy0 = _sqrt (_p0c ** 2 + self .mass0 ** 2 )
2152+ _beta0 = _p0c / _energy0
2153+ _gamma0 = _energy0 / self .mass0
21282154 else :
21292155 raise RuntimeError ('This statement is unreachable.' )
21302156
0 commit comments