
Commit b601f84

Merge pull request #68 from roussel-ryan/67-small-changes
small changes in `beams.py` and `train.py`
2 parents 5a63040 + 3468b94

2 files changed (+32 -23 lines)


gpsr/beams.py

Lines changed: 16 additions & 10 deletions
@@ -8,7 +8,6 @@
 from torch.distributions import MultivariateNormal, Distribution
 
 from cheetah.particles import ParticleBeam
-from cheetah.utils.bmadx import bmad_to_cheetah_coords
 
 
 class BeamGenerator(torch.nn.Module, ABC):
@@ -55,14 +54,15 @@ def __init__(
         energy: float,
         base_dist: Distribution = None,
         transformer: NNTransform = None,
+        output_scale: float = 1e-2,
         n_dim: int = 6,
     ):
         super(NNParticleBeamGenerator, self).__init__()
         self.base_dist = base_dist or MultivariateNormal(
             torch.zeros(n_dim), torch.eye(n_dim)
         )
         self.transformer = transformer or NNTransform(
-            2, 20, output_scale=1e-2, phase_space_dim=n_dim
+            2, 20, output_scale=output_scale, phase_space_dim=n_dim
         )
         self.register_buffer("beam_energy", torch.tensor(energy))
         self.register_buffer("particle_charges", torch.tensor(1.0))
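The new `output_scale` keyword now threads through to the default `NNTransform` instead of the previous hard-coded `1e-2`. A minimal usage sketch under assumptions: `energy` is taken to be the first required constructor argument, as this hunk suggests, and any parameters defined above the hunk are omitted here.

```python
import torch
from gpsr.beams import NNParticleBeamGenerator

# the spread of the generated beam can now be tuned without supplying a
# custom transformer (assumption: no required args other than those shown)
generator = NNParticleBeamGenerator(
    energy=10.0e6,       # illustrative reference energy in eV
    output_scale=1e-3,   # overrides the former hard-coded 1e-2
)
beam = generator()       # forward() returns a cheetah ParticleBeam
```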
@@ -80,14 +80,16 @@ def forward(self) -> ParticleBeam:
 
         # create near zero coordinates into which we deposit the transformed beam
         # Note: these need to be near zero to maintain finite emittances
-        bmad_coords = torch.randn(len(transformed_beam), 6).to(transformed_beam) * 1e-7
-        bmad_coords[:, : transformed_beam.shape[1]] = transformed_beam
+        coords = torch.randn(len(transformed_beam), 6).to(transformed_beam) * 1e-7
+        coords[:, : transformed_beam.shape[1]] = transformed_beam
 
-        transformed_beam = bmad_to_cheetah_coords(
-            bmad_coords, self.beam_energy, torch.tensor(0.511e6)
+        coords = torch.cat(
+            (coords, torch.ones_like(coords[:, 0].unsqueeze(dim=-1))), dim=-1
         )
+
         return ParticleBeam(
-            *transformed_beam,
+            particles=coords,
+            energy=self.beam_energy,
             particle_charges=self.particle_charges,
             survival_probabilities=self.survival_probabilities,
         )
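The replacement for `bmad_to_cheetah_coords` is plain tensor manipulation and can be exercised standalone. A minimal sketch using only `torch`; the trailing column of ones reflects the `(n, 7)` particle layout that cheetah's `ParticleBeam` expects.

```python
import torch

# stand-in for the transformer output: (n_particles, n_dim) with n_dim <= 6
transformed_beam = torch.randn(1000, 4)

# near-zero noise in all six phase-space columns keeps emittances finite
coords = torch.randn(len(transformed_beam), 6).to(transformed_beam) * 1e-7
coords[:, : transformed_beam.shape[1]] = transformed_beam

# append a column of ones to obtain cheetah's (n, 7) particle layout
coords = torch.cat(
    (coords, torch.ones_like(coords[:, 0].unsqueeze(dim=-1))), dim=-1
)
assert coords.shape == (1000, 7)
```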
@@ -249,10 +251,14 @@ def forward(self) -> tuple[ParticleBeam, torch.Tensor]:
 
         entropy = -torch.mean(log_p - log_q)
 
-        particles, ref_energy = bmad_to_cheetah_coords(x, self.energy, self.mass)
-        particles[:, 4] *= -1.0  # [TO DO] why is sign wrong?
+        coords = torch.randn(len(x), 6).to(x) * 1e-7
+        coords[:, : x.shape[1]] = x
+        coords = torch.cat(
+            (coords, torch.ones_like(coords[:, 0].unsqueeze(dim=-1))), dim=-1
+        )
+
         beam = ParticleBeam(
-            particles, energy=ref_energy, particle_charges=self.particle_charges
+            particles=coords, energy=self.energy, particle_charges=self.particle_charges
         )
         return (beam, entropy)
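The same pad-and-append-ones construction now appears in both `forward` implementations. A hypothetical helper, not part of this diff, that would factor out the shared pattern:

```python
import torch

def to_cheetah_particles(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Pad (n, d<=6) phase-space coordinates to cheetah's (n, 7) layout."""
    coords = torch.randn(len(x), 6).to(x) * eps  # near-zero filler columns
    coords[:, : x.shape[1]] = x
    ones = torch.ones_like(coords[:, 0].unsqueeze(dim=-1))
    return torch.cat((coords, ones), dim=-1)
```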

gpsr/train.py

Lines changed: 16 additions & 13 deletions
@@ -37,14 +37,17 @@ def training_step(self, batch, batch_idx):
                 f"prediction {i} shape {pred[i].shape} does not match target shape {y[i].shape}"
             )
 
-        # normalize images
-        y_normalized = [normalize_images(y_ele) for y_ele in y]
-        pred_normalized = [normalize_images(pred_ele) for pred_ele in pred]
-
         # add up the loss functions from each prediction
         loss = 0.0
-        for y_ele, pred_ele in zip(y_normalized, pred_normalized):
-            loss += torch.sum(self.loss_func(y_ele, pred_ele))
+        for y_ele, pred_ele in zip(y, pred):
+            # normalize images
+            y_ele_norm = normalize_images(y_ele)
+            pred_ele_norm = normalize_images(pred_ele)
+            # add loss
+            loss += self.loss_func(y_ele_norm, pred_ele_norm)
+
+        # normalize loss by number of outputs
+        loss /= len(y)
 
         # log the loss function at the end of each epoch
         self.log("loss", loss, on_epoch=True)
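Note the reduction change: per-output losses are now averaged rather than summed, so the total loss no longer scales with the number of diagnostic outputs. A self-contained sketch; `normalize_images` below is a stand-in for the repo's helper, and `F.mse_loss` stands in for `self.loss_func`.

```python
import torch
import torch.nn.functional as F

def normalize_images(images: torch.Tensor) -> torch.Tensor:
    # stand-in for gpsr's normalize_images: scale each image to unit sum
    return images / images.sum(dim=(-2, -1), keepdim=True)

y = [torch.rand(8, 32, 32), torch.rand(8, 64, 64)]     # two measurement screens
pred = [torch.rand(8, 32, 32), torch.rand(8, 64, 64)]

loss = 0.0
for y_ele, pred_ele in zip(y, pred):
    loss += F.mse_loss(normalize_images(y_ele), normalize_images(pred_ele))
loss /= len(y)  # average over outputs, matching the new training_step
```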
@@ -57,7 +60,7 @@ def configure_optimizers(self):
 
 
 def train_gpsr(
-    model,
+    gpsr_model: GPSR,
     train_dataloader,
     n_epochs: int = 100,
     lr: float = 1e-3,
@@ -72,7 +75,7 @@
 
     Arguments
     ---------
-    model: GPSRModel
+    gpsr_model: GPSR
         GPSR model to be trained.
     train_dataloader: DataLoader
         DataLoader for the training data.
@@ -88,8 +91,8 @@
 
     Returns
     -------
-    model: GPSRModel
-        Trained GPSR model.
+    lit_gpsr_model: LitGPSR
+        Trained LitGPSR model.
 
     """
 
@@ -112,13 +115,13 @@
         **kwargs,
     )
 
-    gpsr_model = LitGPSR(model, lr, loss_func=loss_func)
+    lit_gpsr_model = LitGPSR(gpsr_model, lr, loss_func=loss_func)
     trainer.fit(
-        gpsr_model,
+        lit_gpsr_model,
         train_dataloader,
     )
 
-    return model
+    return lit_gpsr_model
 
 
 class EntropyLitGPSR(L.LightningModule, ABC):
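Call sites need a matching update, since `train_gpsr` now returns the `LitGPSR` wrapper rather than the raw model. A hedged sketch: `build_gpsr_model()` and `build_dataloader()` are placeholders for the user's own setup, and how the wrapped model is exposed on the wrapper should be checked against `LitGPSR.__init__`.

```python
from gpsr.train import train_gpsr

gpsr_model = build_gpsr_model()        # placeholder: constructs a GPSR instance
train_dataloader = build_dataloader()  # placeholder: constructs a DataLoader

lit_gpsr_model = train_gpsr(gpsr_model, train_dataloader, n_epochs=100, lr=1e-3)
# the trained GPSR model now lives inside the returned Lightning wrapper
```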
