88from torch .distributions import MultivariateNormal , Distribution
99
1010from cheetah .particles import ParticleBeam
11- from cheetah .utils .bmadx import bmad_to_cheetah_coords
1211
1312
1413class BeamGenerator (torch .nn .Module , ABC ):
@@ -55,14 +54,15 @@ def __init__(
5554 energy : float ,
5655 base_dist : Distribution = None ,
5756 transformer : NNTransform = None ,
57+ output_scale : float = 1e-2 ,
5858 n_dim : int = 6 ,
5959 ):
6060 super (NNParticleBeamGenerator , self ).__init__ ()
6161 self .base_dist = base_dist or MultivariateNormal (
6262 torch .zeros (n_dim ), torch .eye (n_dim )
6363 )
6464 self .transformer = transformer or NNTransform (
65- 2 , 20 , output_scale = 1e-2 , phase_space_dim = n_dim
65+ 2 , 20 , output_scale = output_scale , phase_space_dim = n_dim
6666 )
6767 self .register_buffer ("beam_energy" , torch .tensor (energy ))
6868 self .register_buffer ("particle_charges" , torch .tensor (1.0 ))
@@ -80,14 +80,16 @@ def forward(self) -> ParticleBeam:
8080
8181 # create near zero coordinates into which we deposit the transformed beam
8282 # Note: these need to be near zero to maintain finite emittances
83- bmad_coords = torch .randn (len (transformed_beam ), 6 ).to (transformed_beam ) * 1e-7
84- bmad_coords [:, : transformed_beam .shape [1 ]] = transformed_beam
83+ coords = torch .randn (len (transformed_beam ), 6 ).to (transformed_beam ) * 1e-7
84+ coords [:, : transformed_beam .shape [1 ]] = transformed_beam
8585
86- transformed_beam = bmad_to_cheetah_coords (
87- bmad_coords , self . beam_energy , torch .tensor ( 0.511e6 )
86+ coords = torch . cat (
87+ ( coords , torch .ones_like ( coords [:, 0 ]. unsqueeze ( dim = - 1 ))), dim = - 1
8888 )
89+
8990 return ParticleBeam (
90- * transformed_beam ,
91+ particles = coords ,
92+ energy = self .beam_energy ,
9193 particle_charges = self .particle_charges ,
9294 survival_probabilities = self .survival_probabilities ,
9395 )
@@ -249,10 +251,14 @@ def forward(self) -> tuple[ParticleBeam, torch.Tensor]:
249251
250252 entropy = - torch .mean (log_p - log_q )
251253
252- particles , ref_energy = bmad_to_cheetah_coords (x , self .energy , self .mass )
253- particles [:, 4 ] *= - 1.0 # [TO DO] why is sign wrong?
254+ coords = torch .randn (len (x ), 6 ).to (x ) * 1e-7
255+ coords [:, : x .shape [1 ]] = x
256+ coords = torch .cat (
257+ (coords , torch .ones_like (coords [:, 0 ].unsqueeze (dim = - 1 ))), dim = - 1
258+ )
259+
254260 beam = ParticleBeam (
255- particles , energy = ref_energy , particle_charges = self .particle_charges
261+ particles = coords , energy = self . energy , particle_charges = self .particle_charges
256262 )
257263 return (beam , entropy )
258264
0 commit comments