Skip to content

Commit 44c66fc

Browse files
committed
change model handling
1 parent d896f6b commit 44c66fc

File tree

10 files changed

+44361
-90
lines changed

10 files changed

+44361
-90
lines changed

ir_amplitude_detuning/lhc_detuning_corrections.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@
3939
)
4040
from ir_amplitude_detuning.detuning.measurements import MeasureValue
4141
from ir_amplitude_detuning.detuning.targets import Target
42-
from ir_amplitude_detuning.simulation.lhc_simulation import FakeLHCBeam, LHCBeam, LHCCorrectors
42+
from ir_amplitude_detuning.simulation.lhc_simulation import (
43+
ACC_MODELS,
44+
FakeLHCBeam,
45+
LHCBeam,
46+
LHCCorrectors,
47+
pathstr,
48+
)
4349
from ir_amplitude_detuning.utilities.constants import (
4450
AMPDET_CALC_ERR_ID,
4551
AMPDET_CALC_ID,
@@ -68,12 +74,20 @@
6874
from ir_amplitude_detuning.detuning.targets import Target
6975

7076

71-
LOG = logging.getLogger(__name__)
72-
7377
LHCBeams: TypeAlias = dict[int, LHCBeam]
7478
LHCBeamsPerXing: TypeAlias = dict[str, LHCBeams]
7579

7680

81+
LOG = logging.getLogger(__name__)
82+
83+
84+
def get_optics(year: int) -> str:
85+
return {
86+
2018: pathstr("optics2018", "PROTON", "opticsfile.22_ctpps2"),
87+
2022: pathstr(ACC_MODELS, "strengths", "ATS_Nominal", "2022", "squeeze", "ats_30cm.madx")
88+
}[year]
89+
90+
7791
@dataclass(slots=True)
7892
class CorrectionResults:
7993
"""Class to store the results of a correction calculation.
@@ -97,7 +111,7 @@ def create_optics(
97111
outputdir: Path,
98112
output_id: str = '',
99113
xing: dict[str, dict] | None = None, # default set below
100-
optics: str = "round3030", # 30cm round optics
114+
optics: str | Path | None = None, # defaults to 30cm round optics
101115
year: int = 2018, # lhc year
102116
tune_x: float = 62.28, # horizontal tune
103117
tune_y: float = 60.31, # vertical tune
@@ -120,19 +134,15 @@ def create_optics(
120134
Returns:
121135
LHCBeams: The LHC beams, i.e. a dictionary of LHCBeam objects.
122136
"""
123-
# set mutable defaults ----
124-
if xing is None:
125-
xing = {'scheme': 'top'} # use top-energy crossing scheme
126-
127137
# Setup LHC for both beams -------------------------------------------------
128138
lhc_beams = {}
129139
for beam in beams:
130140
output_subdir = get_label_outputdir(outputdir, output_id, beam)
131141
lhc_beam = LHCBeam(
132142
beam=beam,
133143
outputdir=output_subdir,
134-
xing=xing,
135-
optics=optics,
144+
xing=xing or {'scheme': 'top'}, # use top-energy crossing scheme
145+
optics=optics or get_optics(year),
136146
year=year,
137147
tune_x=tune_x,
138148
tune_y=tune_y,
@@ -339,7 +349,7 @@ def check_corrections_ptc(
339349
# Below only needed if lhc_beams is None ---
340350
beams: Sequence[int] | None = None,
341351
xing: dict[str, dict] | None = None,
342-
optics: str = "round3030", # 30cm round optics
352+
optics: Path | None = None, # defaults to 30cm round optics
343353
year: int = 2018, # lhc year
344354
tune_x: float = 62.28, # horizontal tune
345355
tune_y: float = 60.31, # vertical tune
@@ -361,7 +371,7 @@ def check_corrections_ptc(
361371
lhc_beams (dict[int, LHCBeam]): Pre-run LHC beams.
362372
beams (Sequence[int]): Beams (if ``lhc_beams`` is None).
363373
xing (dict[str, dict]): Crossing scheme (if ``lhc_beams`` is `None`).
364-
optics (str): Optics (if ``lhc_beams`` is `None`).
374+
optics (Path): Path to the optics file (if ``lhc_beams`` is `None`).
365375
year (int): Year (if ``lhc_beams`` is `None`).
366376
tune_x (float): Horizontal tune (if ``lhc_beams`` is `None`).
367377
tune_y (float): Vertical tune (if ``lhc_beams`` is `None`).
@@ -371,15 +381,13 @@ def check_corrections_ptc(
371381
lhc_beams = {}
372382
if beams is None:
373383
raise ValueError("Either lhc_beams or beams must be given.")
374-
if xing is None:
375-
xing = {'scheme': 'top'}
376384

377385
for beam in beams:
378386
lhc_beam = LHCBeam(
379387
beam=beam,
380388
outputdir=get_label_outputdir(outputdir, 'tmp_ptc', beam),
381-
xing=xing,
382-
optics=optics,
389+
xing=xing or {'scheme': 'top'},
390+
optics=optics or get_optics(year),
383391
year=year,
384392
tune_x=tune_x,
385393
tune_y=tune_y,

ir_amplitude_detuning/simulation/lhc_simulation.py

Lines changed: 15 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@
3737
LOG = logging.getLogger(__name__) # setup in main()
3838
LOG_LEVEL = logging.DEBUG
3939

40-
ACC_MODELS = "acc-models-lhc"
40+
ACC_MODELS: str = "acc-models-lhc"
4141

42-
PATHS = {
43-
"db5": Path("/afs/cern.ch/eng/lhc/optics/V6.503"),
44-
"optics2016": Path("/afs/cern.ch/eng/lhc/optics/runII/2016"),
42+
PATHS: dict[str, Path] = {
4543
"optics2018": Path("/afs/cern.ch/eng/lhc/optics/runII/2018"),
4644
"optics_repo": Path("/afs/cern.ch/eng/acc-models/lhc"),
4745
ACC_MODELS: Path(ACC_MODELS),
@@ -62,45 +60,6 @@ def pathstr(key: str, *args: str) -> str:
6260
return str(PATHS[key].joinpath(*args))
6361

6462

65-
def get_optics_path(year: int, name: str | Path):
66-
"""Get optics by name, i.e. a collection of optics path-strings to the optics files.
67-
68-
Args:
69-
year (int): Year of the optics
70-
name (str, Path): Name for the optics or a path to the optics file.
71-
72-
Returns:
73-
str: Path to the optics file.
74-
"""
75-
if isinstance(name, Path):
76-
return str(name)
77-
78-
# Predefined optics paths ---
79-
optics_map = {
80-
2018: {
81-
'inj': pathstr("optics2018", "PROTON", "opticsfile.1"),
82-
'flat6015': pathstr("optics2018", 'MDflatoptics2018', 'opticsfile_flattele60cm.21'),
83-
'round3030': pathstr("optics2018", "PROTON", "opticsfile.22_ctpps2"),
84-
},
85-
2022: {
86-
'round3030': pathstr(ACC_MODELS, "strengths", "ATS_Nominal", "2022", "squeeze", "ats_30cm.madx")
87-
}
88-
}
89-
return optics_map[year][name]
90-
91-
92-
def get_wise_path(seed: int):
93-
"""Get the wise errordefinition file by seed-number.
94-
95-
Args:
96-
seed (int): Seed for the error realization.
97-
98-
Returns:
99-
str: Path to the wise errortable file.
100-
"""
101-
return pathstr('wise', f"WISE.errordef.{seed:04d}.tfs")
102-
103-
10463
def drop_allzero_columns(df: TfsDataFrame, keep: Sequence = ()) -> TfsDataFrame:
10564
"""Drop columns that contain only zeros, to save harddrive space.
10665
@@ -145,9 +104,8 @@ class LHCBeam:
145104
beam: int
146105
outputdir: Path
147106
xing: dict
148-
optics: str
107+
optics: str | Path | None
149108
year: int = 2018
150-
thin: bool = False
151109
tune_x: float = 62.28
152110
tune_y: float = 60.31
153111
chroma: float = 3
@@ -176,6 +134,15 @@ def __post_init__(self):
176134
# Define Sequence to use
177135
self.seq_name, self.seq_file, self.bv_flag = get_lhc_sequence_filename_and_bv(self.beam, accel="lhc" if self.year < 2020 else "hllhc") # `hllhc` just for naming of the sequence file, i.e. without _as_built
178136

137+
self.path_to_use = "optics2018"
138+
if self.year > 2019: # after 2019, use acc-models
139+
self.path_to_use = ACC_MODELS
140+
141+
acc_models_path = PATHS[ACC_MODELS]
142+
if acc_models_path.exists():
143+
acc_models_path.unlink()
144+
acc_models_path.symlink_to(pathstr("optics_repo", str(self.year)))
145+
179146
# Output Helper ---
180147
def output_path(self, type_: str, output_id: str, dir_: Path | None = None, suffix: str = ".tfs") -> Path:
181148
"""Returns the output path for standardized tfs names in the default output directory.
@@ -303,30 +270,13 @@ def setup_machine(self):
303270
Initialized the beam and applies optics, crossing."""
304271
self.reinstate_loggers()
305272
madx = self.madx # shorthand
306-
mvars = madx.globals # shorthand
307273

308274
# Load Macros
309-
madx.call(pathstr("optics2018", "toolkit", "macro.madx"))
275+
madx.call(pathstr(self.path_to_use, "toolkit", "macro.madx"))
310276

311277
# Lattice Setup ---------------------------------------
312278
# Load Sequence
313-
if self.year > 2019: # after 2019, use acc-models
314-
acc_models_path = PATHS[ACC_MODELS]
315-
if acc_models_path.exists():
316-
acc_models_path.unlink()
317-
acc_models_path.symlink_to(pathstr("optics_repo", str(self.year)))
318-
madx.call(pathstr(ACC_MODELS, self.seq_file))
319-
else:
320-
madx.call(pathstr("optics2018", self.seq_file))
321-
322-
# Slice Sequence
323-
if self.thin:
324-
mvars.slicefactor = 4
325-
madx.beam()
326-
madx.call(pathstr("optics2018", "toolkit", "myslice.madx"))
327-
madx.beam()
328-
madx.use(sequence=self.seq_name)
329-
madx.makethin(sequence=self.seq_name, style="teapot", makedipedge=True)
279+
madx.call(pathstr(self.path_to_use, self.seq_file))
330280

331281
# Cycling w.r.t. to IP3 (mandatory to find closed orbit in collision in the presence of errors)
332282
madx.seqedit(sequence=self.seq_name)
@@ -335,10 +285,7 @@ def setup_machine(self):
335285
madx.endedit()
336286

337287
# Define Optics and make beam
338-
madx.call(get_optics_path(self.year, self.optics))
339-
if self.optics == 'inj':
340-
mvars.NRJ = 450.000 # not defined in injection optics.1 but in the others
341-
288+
madx.call(str(self.optics))
342289
madx.beam(sequence=self.seq_name, bv=self.bv_flag,
343290
energy="NRJ", particle="proton", npart=self.n_particles,
344291
kbunch=1, ex=self.emittance, ey=self.emittance)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ test = [
6262
"pytest-cov >= 2.9",
6363
"pytest-timeout >= 1.4",
6464
"pytest-dependency >= 0.6.0",
65+
"gitpython >= 3.1", # imported as 'git', used for acc-models fixture
6566
]
6667
doc = [
6768
"sphinx >= 7.0",

tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,20 @@
99
import sys
1010
from pathlib import Path
1111

12+
import git
1213
import pytest
1314

15+
GITLAB_REPO_ACC_MODELS: str = "https://gitlab.cern.ch/acc-models/acc-models-{}.git"
1416

1517

1618
def assert_exists_and_not_empty(file_path: Path):
1719
"""Assert that a file exists and is not empty."""
1820
assert file_path.exists(), f"File {file_path} does not exist."
19-
assert file_path.stat().st_size > 0, f"File {file_path} is empty."
21+
assert file_path.stat().st_size > 0, f"File {file_path} is empty."
22+
23+
24+
def clone_acc_models(tmp_path_factory: pytest.TempPathFactory, accel: str, year: int) -> Path:
25+
""" Clone the acc-models directory for the specified accelerator from github into a temporary directory. """
26+
tmp_path = tmp_path_factory.mktemp(f"acc-models-{accel}-{year}")
27+
git.Repo.clone_from(GITLAB_REPO_ACC_MODELS.format(accel), tmp_path, branch=str(year))
28+
return tmp_path

0 commit comments

Comments
 (0)