77"""
88from __future__ import annotations
99
10- from dataclasses import dataclass
1110import re
1211import os
1312from pathlib import Path
3130 TWISS_ELEMENTS_DAT ,
3231)
3332from omc3 .utils import logging_tools
34- from omc3 .utils .iotools import PathOrStr
33+ from omc3 .utils .iotools import PathOrStr , find_file
3534from generic_parser .entry_datatypes import get_multi_class
3635
3736
3837LOG = logging_tools .get_logger (__name__ )
3938CURRENT_DIR = Path (__file__ ).parent
4039
4140
42- class AccExcitationMode :
41+ class AccExcitationMode : # TODO: use enum! (jdilly, 2025)
4342 # it is very important that FREE = 0
4443 FREE , ACD , ADT = range (3 )
4544
@@ -63,11 +62,10 @@ class Accelerator:
6362 LOCAL_REPO_NAME : str | None = None
6463 # RE_DICT needs to use MAD-X compatible regex patterns (jdilly, 2021)
6564 RE_DICT : dict [str , str ] = {
66- AccElementTypes .BPMS : r".*" ,
65+ AccElementTypes .BPMS : r"^B .*" ,
6766 AccElementTypes .MAGNETS : r".*" ,
68- AccElementTypes .ARC_BPMS : r".*" ,
67+ AccElementTypes .ARC_BPMS : r"^B .*" ,
6968 }
70- BPM_INITIAL : str = "B"
7169
7270 @staticmethod
7371 def get_parameters ():
@@ -184,10 +182,8 @@ def init_from_model_dir(self, model_dir: Path) -> None:
184182 try :
185183 self .model = tfs .read (model_dir / TWISS_DAT , index = NAME )
186184 except IOError :
187- bpm_index = [
188- idx for idx in self .elements .index .to_numpy () if idx .startswith (self .BPM_INITIAL )
189- ]
190- self .model = self .elements .loc [bpm_index , :]
185+ bpm_mask = self .elements .index .str .match (self .RE_DICT [AccElementTypes .BPMS ])
186+ self .model = self .elements .loc [bpm_mask , :]
191187 self .nat_tunes = [float (self .model .headers ["Q1" ]), float (self .model .headers ["Q2" ])]
192188 self .energy = float (self .model .headers ["ENERGY" ]) # always 450GeV because we do not set it anywhere properly...
193189
@@ -231,46 +227,19 @@ def init_from_model_dir(self, model_dir: Path) -> None:
231227 if errordefspath .is_file ():
232228 self .error_defs_file = errordefspath
233229
234- # Class methods ###########################################
235-
236- @ classmethod
237- def get_element_types_mask ( cls , list_of_elements : list [ str ], types ) -> numpy . ndarray :
230+ def find_modifier ( self , modifier : Path | str ):
231+ """ Try to find a modifier file, which might be given only by its name.
232+ By default this is looking for full-path, model-dir and in the acc-models-path,
233+ but should probably be overwritten by the accelerator sub-classes.
238234 """
239- Returns a boolean mask for elements in ``list_of_elements`` that belong to any of the
240- specified types.
241- Needs to handle: `bpm`, `magnet`, `arc_bpm` (see :class:`AccElementTypes` )
235+ dirs = []
236+ if self . model_dir is not None :
237+ dirs . append ( self . model_dir )
242238
243- Args:
244- list_of_elements: list of elements.
245- types: the kinds of elements to look for.
239+ if self .acc_model_path is not None :
240+ dirs .append (self .acc_model_path )
246241
247- Returns:
248- A boolean array of elements of specified kinds.
249- """
250- unknown_elements = [ty for ty in types if ty not in cls .RE_DICT ]
251- if len (unknown_elements ):
252- raise TypeError (f"Unknown element(s): '{ unknown_elements } '" )
253- series = pd .Series (list_of_elements )
254- mask = series .str .match (cls .RE_DICT [types [0 ]], case = False )
255- for ty in types [1 :]:
256- mask = mask | series .str .match (cls .RE_DICT [ty ], case = False )
257- return mask .to_numpy ()
258-
259- @classmethod
260- def get_variables (cls , frm = None , to = None , classes = None ):
261- """
262- Gets the variables with elements in the given range and the given classes. ``None`` means
263- everything.
264- """
265- raise NotImplementedError ("A function should have been overwritten, check stack trace." )
266-
267- @classmethod
268- def get_correctors_variables (cls , frm = None , to = None , classes = None ):
269- """
270- Returns the set of corrector variables between ``frm`` and ``to``, with classes in
271- classes. ``None`` means select all.
272- """
273- raise NotImplementedError ("A function should have been overwritten, check stack trace." )
242+ return find_file (modifier , dirs = dirs )
274243
275244 @property
276245 def beam_direction (self ) -> int :
@@ -322,6 +291,39 @@ def model_driven(self, value):
322291 raise AcceleratorDefinitionError ("Driven model cannot be set for accelerator with free excitation mode." )
323292 self ._model_driven = value
324293
294+ # Class methods ###########################################
295+
296+ @classmethod
297+ def get_element_types_mask (cls , list_of_elements : list [str ], types ) -> numpy .ndarray :
298+ """
299+ Returns a boolean mask for elements in ``list_of_elements`` that belong to any of the
300+ specified types.
301+ Needs to handle: `bpm`, `magnet`, `arc_bpm` (see :class:`AccElementTypes`)
302+
303+ Args:
304+ list_of_elements: list of elements.
305+ types: the kinds of elements to look for.
306+
307+ Returns:
308+ A boolean array of elements of specified kinds.
309+ """
310+ unknown_elements = [ty for ty in types if ty not in cls .RE_DICT ]
311+ if len (unknown_elements ):
312+ raise TypeError (f"Unknown element(s): '{ unknown_elements } '" )
313+ series = pd .Series (list_of_elements )
314+ mask = series .str .match (cls .RE_DICT [types [0 ]], case = False )
315+ for ty in types [1 :]:
316+ mask = mask | series .str .match (cls .RE_DICT [ty ], case = False )
317+ return mask .to_numpy ()
318+
319+ @classmethod
320+ def get_variables (cls , frm = None , to = None , classes = None ):
321+ """
322+ Gets the variables with elements in the given range and the given classes. ``None`` means
323+ everything.
324+ """
325+ raise NotImplementedError ("A function should have been overwritten, check stack trace." )
326+
325327 @classmethod
326328 def get_dir (cls ) -> Path :
327329 """Default directory for accelerator. Should be overwritten if more specific."""
@@ -340,7 +342,6 @@ def get_file(cls, filename: str) -> Path:
340342 f"File { file_path .name } not available for accelerator { cls .NAME } ."
341343 )
342344
343-
344345 ##########################################################################
345346
346347
@@ -359,24 +360,37 @@ class AcceleratorDefinitionError(Exception):
359360def _get_modifiers_from_modeldir (model_dir : Path ) -> list [Path ]:
360361 """Parse modifiers from job.create_model.madx or use modifiers.madx file."""
361362 job_file = model_dir / JOB_MODEL_MADX_NOMINAL
362- if job_file .exists ():
363- job_madx = job_file .read_text ()
364-
365- # find modifier tag in lines and return called file in these lines
366- # the modifier tag is used by the model creator to mark which line defines modifiers
367- # see e.g. `get_base_madx_script()` in `lhc.py`
368- # example for a match to the regex: `call, file = 'modifiers.madx'; MODIFIER_TAG`
369- modifiers = re .findall (
370- fr"\s*call,\s*file\s*=\s*[\"\']?([^;\'\"]+)[\"\']?\s*;\s*{ MODIFIER_TAG } " ,
371- job_madx ,
372- flags = re .IGNORECASE ,
373- )
374- modifiers = [Path (m ) for m in modifiers ]
375- return modifiers or None
363+ if job_file .is_file ():
364+ return find_called_files_with_tag (job_file , MODIFIER_TAG )
376365
377366 # Legacy
378367 modifiers_file = model_dir / MODIFIERS_MADX
379368 if modifiers_file .exists (): # legacy
380369 return [modifiers_file ]
381370
382371 return None
372+
373+
374+ def find_called_files_with_tag (madx_file : Path , tag : str ) -> list [Path ] | None :
375+ """ Parse lines that call a file and are tagged with the given tag and return
376+ a list of paths to these files.
377+
378+ This is mainly used to find the modifier tag in lines and return called file in these lines.
379+
380+ The modifier tag is used by the model creator to mark which line defines modifiers
381+ see e.g. `get_base_madx_script()` in `lhc.py`
382+
383+ An example for a match to the regex: `call, file = 'modifiers.madx'; !@modifiers`.
384+ """
385+ if not madx_file .is_file ():
386+ return None
387+
388+ job_madx = madx_file .read_text ()
389+
390+ called_files = re .findall (
391+ fr"\s*call,\s*file\s*=\s*[\"\']?([^;\'\"]+)[\"\']?\s*;\s*{ tag } " ,
392+ job_madx ,
393+ flags = re .IGNORECASE ,
394+ )
395+ called_files = [Path (m ) for m in called_files ]
396+ return called_files or None
0 commit comments