88import xtrack as xt
99
1010from .base_collimator import BaseCollimator , InvalidCollimator
11- from ..scattering_routines .everest import Material , CrystalMaterial , EverestEngine
11+ from ..scattering_routines .everest import GeneralMaterial , Material , CrystalMaterial , EverestEngine
1212from ..general import _pkg_root
1313
1414
2020# only activated around the track command. Furthermore, because of 'iscollective = False' we need to specify
2121# get_backtrack_element. We want it nicer..
2222
23+ # TODO: _per_particle_kernels should be a normal kernel (such that we don't need to pass a dummy Particles() )
24+
2325class EverestCollimator (BaseCollimator ):
2426 _xofields = { ** BaseCollimator ._xofields ,
2527 '_material' : Material ,
@@ -41,6 +43,13 @@ class EverestCollimator(BaseCollimator):
4143 _pkg_root .joinpath ('beam_elements' ,'collimators_src' ,'everest_collimator.h' )
4244 ]
4345
46+ _per_particle_kernels = {
47+ '_EverestCollimator_set_material' : xo .Kernel (
48+ c_name = 'EverestCollimator_set_material' ,
49+ args = []
50+ )
51+ }
52+
4453
4554 def __init__ (self , ** kwargs ):
4655 if '_xobject' not in kwargs :
@@ -54,8 +63,8 @@ def __init__(self, **kwargs):
5463 kwargs .setdefault ('rutherford_rng' , xt .RandomRutherford ())
5564 kwargs .setdefault ('_tracking' , True )
5665 super ().__init__ (** kwargs )
57- # if '_xobject' not in kwargs:
58- # self.random_generator.set_rutherford_by_xcoll_material(self.material )
66+ if '_xobject' not in kwargs :
67+ self ._EverestCollimator_set_material ( xp . Particles () )
5968
6069 @property
6170 def material (self ):
@@ -67,7 +76,7 @@ def material(self, material):
6776 if not isinstance ('material' , dict ) or material ['__class__' ] != "Material" :
6877 raise ValueError ("Invalid material!" )
6978 self ._material = material
70- # self.random_generator.set_rutherford_by_xcoll_material(material )
79+ self ._EverestCollimator_set_material ( xp . Particles () )
7180
7281 def get_backtrack_element (self , _context = None , _buffer = None , _offset = None ):
7382 # TODO: this should be an InvalidCollimator
@@ -103,6 +112,13 @@ class EverestCrystal(BaseCollimator):
103112 _pkg_root .joinpath ('beam_elements' ,'collimators_src' ,'everest_crystal.h' )
104113 ]
105114
115+ _per_particle_kernels = {
116+ '_EverestCrystal_set_material' : xo .Kernel (
117+ c_name = 'EverestCrystal_set_material' ,
118+ args = []
119+ )
120+ }
121+
106122
107123 def __init__ (self , ** kwargs ):
108124 if '_xobject' not in kwargs :
@@ -122,8 +138,8 @@ def __init__(self, **kwargs):
122138 kwargs .setdefault ('rutherford_rng' , xt .RandomRutherford ())
123139 kwargs .setdefault ('_tracking' , True )
124140 super ().__init__ (** kwargs )
125- # if '_xobject' not in kwargs:
126- # self.random_generator.set_rutherford_by_xcoll_material(self.material )
141+ if '_xobject' not in kwargs :
142+ self ._EverestCrystal_set_material ( xp . Particles () )
127143
128144 @property
129145 def lattice (self ):
@@ -148,7 +164,7 @@ def material(self, material):
148164 if not isinstance (material , dict ) or material ['__class__' ] != "CrystalMaterial" :
149165 raise ValueError ("Invalid material!" )
150166 self ._material = material
151- # self.random_generator.set_rutherford_by_xcoll_material(material )
167+ self ._EverestCrystal_set_material ( xp . Particles () )
152168
153169 def get_backtrack_element (self , _context = None , _buffer = None , _offset = None ):
154170 # TODO: this should be an InvalidCollimator
0 commit comments