Skip to content

Commit ef648ea

Browse files
authored
Merge branch 'master' into 395-screen-image-orientation-is-incorrect
2 parents 44e6b34 + 9bcb5d7 commit ef648ea

File tree

6 files changed

+54
-28
lines changed

6 files changed

+54
-28
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
### 🚀 Features
1111

1212
- Implement `split` method for the `Solenoid` element (see #380) (@cr-xu)
13-
- Implement a more robust RPN parser, fixing a bug where short strings in an Elegant variable definition would cause parsing to fail. (see #387) (@amylizzle, @Hespe, @jank324)
13+
- Implement a more robust RPN parser, fixing a bug where short strings in an Elegant variable definition would cause parsing to fail. (see #387, #417) (@amylizzle, @Hespe, @jank324)
1414
- Add a `Sextupole` element (see #406) (@jank324, @Hespe)
1515

1616
### 🐛 Bug fixes
@@ -21,6 +21,7 @@
2121
- Fix issue where `Quadrupole.tracking_method` was not preserved on cloning (see #404) (@RemiLehe, @jank324)
2222
- The vertical screen misalignment is now correctly applied to `y` instead of `px` (see #405) (@RemiLehe)
2323
- Fix issues when generating screen images caused by the sign of particle charges (see #394) (@Hespe, @jank324)
24+
- Fix an issue where newer versions of `torch` only accept a `torch.Tensor` as input to `torch.rad2deg` (see #417) (@jank324)
2425
- Fix bug that caused correlations to be lost in the conversion from a `ParameterBeam` to a `ParticleBeam` (see #408) (@jank324, @Hespe)
2526

2627
### 🐆 Other

cheetah/accelerator/segment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def from_ocelot(
304304
:param cell: Ocelot cell, i.e. a list of Ocelot elements to be converted.
305305
:param name: Unique identifier for the entire segment.
306306
:param warnings: Whether to print warnings when objects are not supported by
307-
Cheetah or converted with potentially unexpected behavior.
307+
Cheetah or converted with potentially unexpected behaviour.
308308
:param device: Device to place the lattice elements on.
309309
:param dtype: Data type to use for the lattice elements.
310310
:return: Cheetah segment closely resembling the Ocelot cell.
@@ -324,6 +324,7 @@ def from_bmad(
324324
cls,
325325
bmad_lattice_file_path: str,
326326
environment_variables: dict | None = None,
327+
warnings: bool = True,
327328
device: torch.device | None = None,
328329
dtype: torch.dtype | None = None,
329330
) -> "Segment":
@@ -338,13 +339,15 @@ def from_bmad(
338339
:param bmad_lattice_file_path: Path to the Bmad lattice file.
339340
:param environment_variables: Dictionary of environment variables to use when
340341
parsing the lattice file.
342+
:param warnings: Whether to print warnings when elements or expressions are not
343+
supported by Cheetah or converted with potentially unexpected behaviour.
341344
:param device: Device to place the lattice elements on.
342345
:param dtype: Data type to use for the lattice elements.
343346
:return: Cheetah `Segment` representing the Bmad lattice.
344347
"""
345348
bmad_lattice_file_path = Path(bmad_lattice_file_path)
346349
return bmad.convert_lattice_to_cheetah(
347-
bmad_lattice_file_path, environment_variables, device, dtype
350+
bmad_lattice_file_path, environment_variables, warnings, device, dtype
348351
)
349352

350353
@classmethod

cheetah/converters/bmad.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ def convert_element(
185185
return cheetah.Cavity(
186186
length=torch.tensor(bmad_parsed["l"], **factory_kwargs),
187187
voltage=torch.tensor(bmad_parsed.get("voltage", 0.0), **factory_kwargs),
188-
phase=torch.tensor(
189-
-torch.rad2deg(bmad_parsed.get("phi0", 0.0) * 2 * torch.pi),
190-
**factory_kwargs,
188+
phase=-torch.rad2deg(
189+
torch.tensor(bmad_parsed.get("phi0", 0.0), **factory_kwargs)
190+
* 2
191+
* torch.pi
191192
),
192193
frequency=torch.tensor(bmad_parsed["rf_frequency"], **factory_kwargs),
193194
name=name,
@@ -285,6 +286,7 @@ def convert_element(
285286
def convert_lattice_to_cheetah(
286287
bmad_lattice_file_path: Path,
287288
environment_variables: dict | None = None,
289+
warnings: bool = True,
288290
device: torch.device | None = None,
289291
dtype: torch.dtype | None = None,
290292
) -> "cheetah.Element":
@@ -299,6 +301,8 @@ def convert_lattice_to_cheetah(
299301
:param bmad_lattice_file_path: Path to the Bmad lattice file.
300302
:param environment_variables: Dictionary of environment variables to use when
301303
parsing the lattice file.
304+
:param warnings: Whether to print warnings when elements or expressions are not
305+
supported by Cheetah or converted with potentially unexpected behaviour.
302306
:param device: Device to use for the lattice. If `None`, the current default device
303307
of PyTorch is used.
304308
:param dtype: Data type to use for the lattice. If `None`, the current default dtype
@@ -337,7 +341,7 @@ def convert_lattice_to_cheetah(
337341
), "Merging lines should never produce more lines than there were before."
338342

339343
# Parse the lattice file(s), i.e. basically execute them
340-
context = parse_lines(merged_lines)
344+
context = parse_lines(merged_lines, warnings)
341345

342346
# Convert the parsed lattice info to Cheetah elements
343347
return convert_element(context["__use__"], context, device, dtype)

cheetah/converters/utils/fortran_namelist.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import os
23
import re
34
from copy import deepcopy
@@ -94,13 +95,15 @@ def merge_delimiter_continued_lines(
9495
return merged_lines
9596

9697

97-
def evaluate_expression(expression: str, context: dict) -> Any:
98+
def evaluate_expression(expression: str, context: dict, warnings: bool = True) -> Any:
9899
"""
99100
Evaluate an expression in the context of a dictionary of variables.
100101
101102
:param expression: Expression to evaluate.
102103
:param context: Dictionary of variables to evaluate the expression in the context
103104
of.
105+
:param warnings: Whether to print warnings for unrecognised expressions that might
106+
lead to unexpected behaviour when parsed as strings.
104107
:return: Result of evaluating the expression.
105108
"""
106109

@@ -146,10 +149,11 @@ def evaluate_expression(expression: str, context: dict) -> Any:
146149

147150
return eval(expression, context)
148151
except SyntaxError:
149-
print(
150-
f"WARNING: Could not evaluate expression {expression}. It will now be "
151-
f"treated as a string. This may lead to unexpected behaviour."
152-
)
152+
if warnings:
153+
print(
154+
f"WARNING: Could not evaluate expression {expression}. It will now "
155+
"be treated as a string. This may lead to unexpected behaviour."
156+
)
153157
return expression
154158
except Exception as e:
155159
print(expression)
@@ -180,13 +184,15 @@ def resolve_object_name_wildcard(wildcard_pattern: str, context: dict) -> list:
180184
return type_matching_keys
181185

182186

183-
def assign_property(line: str, context: dict) -> dict:
187+
def assign_property(line: str, context: dict, warnings: bool = True) -> dict:
184188
"""
185189
Assign a property of an element to the context.
186190
187191
:param line: Line of a property assignment to be parsed.
188192
:param context: Dictionary of variables to assign the property to and from which to
189193
read variables.
194+
:param warnings: Whether to print warnings for unrecognised expressions that might
195+
lead to unexpected behaviour when parsed as strings.
190196
:return: Updated context.
191197
"""
192198
pattern = r"([a-z0-9_\*:]+)\[([a-z0-9_%]+)\]\s*=(.*)"
@@ -201,7 +207,7 @@ def assign_property(line: str, context: dict) -> dict:
201207
else:
202208
object_names = [object_name]
203209

204-
expression_result = evaluate_expression(property_expression, context)
210+
expression_result = evaluate_expression(property_expression, context, warnings)
205211

206212
for name in object_names:
207213
if name not in context:
@@ -211,13 +217,15 @@ def assign_property(line: str, context: dict) -> dict:
211217
return context
212218

213219

214-
def assign_variable(line: str, context: dict) -> dict:
220+
def assign_variable(line: str, context: dict, warnings: bool = True) -> dict:
215221
"""
216222
Assign a variable to the context.
217223
218224
:param line: Line of a variable assignment to be parsed.
219225
:param context: Dictionary of variables to assign the variable to and from which to
220226
read variables.
227+
:param warnings: Whether to print warnings for unrecognised expressions that might
228+
lead to unexpected behaviour when parsed as strings.
221229
:return: Updated context.
222230
"""
223231
pattern = r"([a-z0-9_]+)\s*=(.*)"
@@ -226,18 +234,20 @@ def assign_variable(line: str, context: dict) -> dict:
226234
variable_name = match.group(1).strip()
227235
variable_expression = match.group(2).strip()
228236

229-
context[variable_name] = evaluate_expression(variable_expression, context)
237+
context[variable_name] = evaluate_expression(variable_expression, context, warnings)
230238

231239
return context
232240

233241

234-
def define_element(line: str, context: dict) -> dict:
242+
def define_element(line: str, context: dict, warnings: bool = True) -> dict:
235243
"""
236244
Define an element in the context.
237245
238246
:param line: Line of an element definition to be parsed.
239247
:param context: Dictionary of variables to define the element in and from which to
240248
read variables.
249+
:param warnings: Whether to print warnings for unrecognised expressions that might
250+
lead to unexpected behaviour when parsed as strings.
241251
:return: Updated context.
242252
"""
243253
pattern = r"([a-z0-9_\.]+)\s*\:\s*([a-z0-9_]+)(\s*\,(.*))?"
@@ -267,7 +277,7 @@ def define_element(line: str, context: dict) -> dict:
267277
property_expression = property_expression.strip()
268278

269279
element_properties[property_name] = evaluate_expression(
270-
property_expression, context
280+
property_expression, context, warnings
271281
)
272282

273283
context[element_name] = element_properties
@@ -366,12 +376,14 @@ def parse_use_line(line: str, context: dict) -> dict:
366376
return context
367377

368378

369-
def parse_lines(lines: str) -> dict:
379+
def parse_lines(lines: str, warnings: bool = True) -> dict:
370380
"""
371381
Parse a list of lines from a Bmad or Elegant lattice file. They should be cleaned
372382
and merged before being passed to this function.
373383
374384
:param lines: List of lines to parse.
385+
:param warnings: Whether to print warnings for unrecognised expressions that might
386+
lead to unexpected behaviour when parsed as strings.
375387
:return: Dictionary of variables defined in the lattice file.
376388
"""
377389
property_assignment_pattern = r"[a-z0-9_\*:]+\[[a-z0-9_%]+\]\s*=.*"
@@ -389,21 +401,25 @@ def parse_lines(lines: str) -> dict:
389401
"m_electron": (
390402
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
391403
),
392-
"raddeg": scipy.constants.degree,
404+
"sqrt": math.sqrt,
405+
"asin": math.asin,
406+
"sin": math.sin,
407+
"cos": math.cos,
393408
"abs_func": abs,
409+
"raddeg": scipy.constants.degree,
394410
}
395411

396412
for line in lines:
397413
if re.fullmatch(property_assignment_pattern, line):
398-
context = assign_property(line, context)
414+
context = assign_property(line, context, warnings)
399415
elif re.fullmatch(variable_assignment_pattern, line):
400-
context = assign_variable(line, context)
416+
context = assign_variable(line, context, warnings)
401417
elif re.fullmatch(line_definition_pattern, line):
402418
context = define_line(line, context)
403419
elif re.fullmatch(overlay_definition_pattern, line):
404420
context = define_overlay(line, context)
405421
elif re.fullmatch(element_definition_pattern, line):
406-
context = define_element(line, context)
422+
context = define_element(line, context, warnings)
407423
elif re.fullmatch(use_line_pattern, line):
408424
context = parse_use_line(line, context)
409425

cheetah/converters/utils/rpn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def evaluate_expression(expression: str, context: dict | None = None) -> Any:
1212

1313
stack = []
1414
stripped = expression.strip()
15-
for token in stripped.split(" "):
15+
for token in stripped.split():
1616
match token:
1717
case "+":
1818
try:
@@ -99,12 +99,12 @@ def evaluate_expression(expression: str, context: dict | None = None) -> Any:
9999
raise SyntaxError(
100100
f"Invalid expression: {expression} - Need one value before asin"
101101
)
102-
case _: # all other tokens
103-
# commment, ignore this and all following tokens
102+
case _: # All other tokens
103+
# Commment ... ignore this and all following tokens
104104
if token[0] == "#":
105105
break
106106
try:
107-
# read as float since it's all torch in the back anyway
107+
# Read as float since it's all torch in the back anyway
108108
number = float(token)
109109
except ValueError:
110110
if token in context:

tests/resources/bmad_tutorial_lattice.bmad

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ parameter[particle] = electron ! Reference particle.
1010
abs = -0.6
1111

1212
d: drift, L = 0.5 * (0.3 + 0.7)
13-
b: sbend, L = 0.5, g = 1, e1 = 0.1, dg = 0.001 ! g = 1/design_radius
13+
! The two spaces in the expression for L are on purpose to test the parser's ability to
14+
! handle this
15+
b: sbend, L = 0.6 -0.1, g = 1, e1 = 0.1, dg = sqrt(0.000001) ! g = 1/design_radius
1416
q: quadrupole, L = abs(abs), k1 = 0.23
1517

1618
lat: line = (d, b, q) ! List of lattice elements

0 commit comments

Comments
 (0)