Skip to content

Commit 6151261

Browse files
add location parameter to draw-lattice (#85)
1 parent 5df8a2b commit 6151261

File tree

2 files changed

+77
-63
lines changed

2 files changed

+77
-63
lines changed

apace/cli.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,26 @@ def cli():
1515
@cli.command()
1616
@click.argument("location")
1717
@click.option("-o", "--output", required=False, type=click.Path(), help="Output path.")
18+
@click.option("-i", "--interactive", required=False, multiple=True, type=(str, str))
1819
@click.option("--ref-lattice", required=False, help="Path or URL to reference lattice.")
1920
@click.option("-s", "--sections", required=False, multiple=True, type=(float, float))
2021
@click.option("--y-min", required=False, type=float)
2122
@click.option("--y-max", required=False, type=float)
22-
def twiss(location, output, ref_lattice, sections, y_min, y_max):
23+
def twiss(location, output, interactive, ref_lattice, sections, y_min, y_max):
2324
"""Plot the Twiss parameter of the lattice at LOCATION (path or URL)."""
2425
lattice = Lattice.from_file(location)
25-
twiss = Twiss(lattice)
26-
ref_twiss = Twiss(ref_lattice) if ref_lattice is not None else None
27-
TwissPlot(twiss, ref_twiss=ref_twiss, sections=sections, y_min=y_min, y_max=y_max)
26+
options = dict(twiss=Twiss(lattice), sections=sections, y_min=y_min, y_max=y_max)
27+
if interactive: # TODO: interactive seems frozen
28+
options["pairs"] = (
29+
[(lattice[name], attr) for name, attr in interactive]
30+
if interactive
31+
else None
32+
)
33+
34+
if ref_lattice:
35+
options["ref_twiss"] = Twiss(ref_lattice)
36+
37+
TwissPlot(**options)
2838
if output is None:
2939
plt.show()
3040
else:

apace/plot.py

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from enum import Enum
2-
from typing import Union, List, Tuple
31
import matplotlib as mpl
42
import matplotlib.pyplot as plt
53
import matplotlib.patches as patches
@@ -9,12 +7,14 @@
97
from matplotlib.path import Path
108
from matplotlib.ticker import AutoMinorLocator, ScalarFormatter
119
import numpy as np
12-
from enum import Enum
1310
from math import inf
1411
from collections.abc import Iterable
1512
from .classes import Base, Drift, Dipole, Quadrupole, Sextupole, Octupole, Lattice
1613

1714

15+
FONT_SIZE = 8
16+
17+
1818
class Color:
1919
RED = "crimson"
2020
GREEN = "mediumseagreen"
@@ -34,26 +34,24 @@ class Color:
3434
Octupole: Color.BLUE,
3535
}
3636

37-
FONT_SIZE = 8
3837

39-
[
40-
"beta_x",
41-
"beta_y",
42-
"eta_x",
43-
"psi_x",
44-
"psi_y",
45-
"alpha_x",
46-
"alpha_y",
47-
"gamma_x",
48-
"gamma_y",
49-
]
38+
OPTICAL_FUNCTIONS = {
39+
"beta_x": (r"$\beta_x$/m", Color.RED),
40+
"beta_y": (r"$\beta_y$/m", Color.BLUE),
41+
"eta_x": (r"$\eta_x$/m", Color.GREEN),
42+
"psi_x": (r"$\psi_x$", Color.YELLOW),
43+
"psi_y": (r"$\psi_y$", Color.ORANGE),
44+
"alpha_x": (r"$\alpha_x$", Color.MAGENTA),
45+
"alpha_y": (r"$\alpha_y$", Color.BLACK),
46+
}
5047

5148

5249
def draw_lattice(
5350
lattice,
5451
ax=None,
5552
x_min=-inf,
5653
x_max=inf,
54+
location="top",
5755
draw_elements=True,
5856
annotate_elements=True,
5957
draw_sub_lattices=True,
@@ -82,6 +80,12 @@ def draw_lattice(
8280
y_span = y_max - y_min
8381
rect_height = y_span / 32
8482

83+
y0 = -rect_height / 2
84+
if location == "top":
85+
y0 += y_max
86+
elif location == "bottom":
87+
y0 += y_min
88+
8589
if draw_elements:
8690
start = end = 0
8791
arrangement = lattice.arrangement
@@ -96,7 +100,7 @@ def draw_lattice(
96100

97101
rec_length = min(end, x_max) - max(start, x_min)
98102
rectangle = plt.Rectangle(
99-
(start if start > x_min else x_min, y_max - rect_height / 2),
103+
(start if start > x_min else x_min, y0),
100104
rec_length,
101105
rect_height,
102106
fc=ELEMENT_COLOR[type(element)],
@@ -112,7 +116,7 @@ def draw_lattice(
112116
)
113117
ax.annotate(
114118
element.name,
115-
xy=(center, y_max + sign * 0.75 * rect_height),
119+
xy=(center, y0 + sign * 0.75 * rect_height),
116120
fontsize=FONT_SIZE,
117121
ha="center",
118122
va=va,
@@ -135,7 +139,7 @@ def draw_lattice(
135139
ax.grid(linestyle="--")
136140

137141
if annotate_sub_lattices:
138-
y0 = y_max - 3 * rect_height
142+
y0_anno = y0 - 3 * rect_height
139143
end = 0
140144
for obj in lattice.tree:
141145
end += obj.length
@@ -145,7 +149,7 @@ def draw_lattice(
145149
x0 = end - obj.length / 2
146150
ax.annotate(
147151
obj.name,
148-
xy=(x0, y0),
152+
xy=(x0, y0_anno),
149153
fontsize=FONT_SIZE,
150154
fontstyle="oblique",
151155
alpha=0.5,
@@ -158,35 +162,37 @@ def draw_lattice(
158162

159163
def plot_twiss(
160164
twiss,
165+
twiss_functions=("beta_x", "beta_y", "eta_x"),
166+
*,
167+
scales={"eta_x": 10},
161168
ax=None,
162169
line_style="solid",
163170
line_width=1.3,
164171
alpha=1.0,
165-
eta_scale=10,
166172
show_ylabels=False,
167173
):
168174
if ax is None:
169175
ax = plt.gca()
170-
171-
text_areas = [None] * 3
172-
for value, label, color, order in (
173-
(twiss.beta_x, r"$\beta_x$/m", Color.RED, 2),
174-
(twiss.beta_y, r"$\beta_y$/m", Color.BLUE, 1),
175-
(twiss.eta_x * eta_scale, rf"{eta_scale}$\eta_x$/m", Color.GREEN, 0),
176-
# (twiss.curly_h, rf"{eta_scale}$\mathscr{{H}}_x$", Color.ORANGE, -1),
177-
):
176+
if scales is None:
177+
scales = {}
178+
179+
text_areas = []
180+
for i, function in enumerate(twiss_functions):
181+
value = getattr(twiss, function)
182+
scale = scales.get(function, "")
183+
label, color = OPTICAL_FUNCTIONS[function]
184+
label = str(scale) + label
178185
ax.plot(
179186
twiss.s,
180-
value,
187+
value if scale == "" else scale * value,
181188
color=color,
182189
linewidth=line_width,
183190
linestyle=line_style,
184191
alpha=alpha,
185-
zorder=order,
192+
zorder=10 - i,
186193
label=label,
187194
)
188-
189-
text_areas[order] = TextArea(label, textprops=dict(color=color, rotation=90))
195+
text_areas.append(TextArea(label, textprops=dict(color=color, rotation=90)))
190196

191197
ax.set_xlabel("Orbit Position $s$ / m")
192198
if show_ylabels:
@@ -218,22 +224,23 @@ def _twiss_plot_section(
218224
ref_twiss=None,
219225
ref_line_style="dashed",
220226
ref_line_width=2.5,
221-
eta_scale=10,
227+
scales={"eta_x": 10},
222228
overwrite=False,
223229
):
224230
if overwrite:
225231
ax.clear()
226232
if ref_twiss:
227233
plot_twiss(
228234
ref_twiss,
229-
ax,
230-
ref_line_style,
231-
ref_line_width,
235+
ax=ax,
236+
line_style=ref_line_style,
237+
line_width=ref_line_width,
232238
alpha=0.5,
233-
eta_scale=eta_scale,
234239
)
235240

236-
plot_twiss(twiss, ax, line_style, line_width, eta_scale=eta_scale)
241+
plot_twiss(
242+
twiss, ax=ax, line_style=line_style, line_width=line_width, scales=scales
243+
)
237244
if x_min is None:
238245
x_min = 0
239246
if x_max is None:
@@ -251,6 +258,7 @@ def _twiss_plot_section(
251258
# TODO:
252259
# * make sub_class of figure
253260
# * add attribute which defines which twiss parameters are plotted
261+
# * add twiss_functions argument similar to plot_twiss
254262
class TwissPlot:
255263
"""Convenience class to plot twiss parameters
256264
@@ -260,7 +268,7 @@ class TwissPlot:
260268
:param y_max float: Maximum y-limit
261269
:param y_min float: Minimum y-limit
262270
:param main bool: Wheter to plot whole ring or only given sections
263-
:param eta_scale int: Scaling factor of the dipsersion function
271+
:param scales Dict[str, int]: Optional scaling factors for optical functions
264272
:param Twiss ref_twiss: Reference twiss values. Will be plotted as dashed lines.
265273
:param pairs: List of (element, attribute)-pairs to create interactice sliders for.
266274
:type pairs: List[Tuple[Element, str]]
@@ -269,25 +277,31 @@ class TwissPlot:
269277
def __init__(
270278
self,
271279
twiss,
280+
twiss_functions=("beta_x", "beta_y", "eta_x"),
281+
*,
272282
sections=None,
273283
y_min=None,
274284
y_max=None,
275285
main=True,
276-
eta_scale=10,
286+
scales={"eta_x": 10},
277287
ref_twiss=None,
278288
pairs=None,
279289
):
280290
self.fig = plt.figure()
281291
self.twiss = twiss
282292
self.lattice = twiss.lattice
283-
self.eta_scale = eta_scale
293+
self.twiss_functions = twiss_functions
294+
self.scales = scales
284295
height_ratios = [4, 14] if (main and sections) else [1]
285296
main_grid = grid_spec.GridSpec(
286297
len(height_ratios), 1, self.fig, height_ratios=height_ratios
287298
)
299+
self.axs_sections = [] # TODO: needed for update function
288300

289301
if pairs:
290302
fig_sliders, axs = plt.subplots(nrows=len(pairs))
303+
if not isinstance(axs, Iterable):
304+
axs = (axs,)
291305
self.sliders = []
292306
for ax, (element, attribute) in zip(axs, pairs):
293307
initial_value = getattr(element, attribute)
@@ -310,7 +324,7 @@ def __init__(
310324
y_min=y_min,
311325
y_max=y_max,
312326
annotate_elements=False,
313-
eta_scale=eta_scale,
327+
scales=scales,
314328
)
315329

316330
if sections:
@@ -338,32 +352,22 @@ def __init__(
338352
y_min=y_min,
339353
y_max=y_max,
340354
annotate_elements=True,
341-
eta_scale=eta_scale,
355+
scales=scales,
342356
)
343357

344358
handles, labels = self.fig.axes[0].get_legend_handles_labels()
345-
self.fig.legend(
346-
handles,
347-
labels,
348-
loc="upper left",
349-
ncol=10,
350-
frameon=False,
351-
)
359+
self.fig.legend(handles, labels, loc="upper left", ncol=10, frameon=False)
352360
self.fig.suptitle(twiss.lattice.name, ha="right", x=0.98)
353361
self.fig.tight_layout()
354362

355363
def update(self):
356364
twiss = self.twiss
357365
for ax in [self.ax_main] + self.axs_sections:
358-
for line, data in zip(
359-
ax.lines,
360-
(
361-
twiss.beta_x,
362-
twiss.beta_y,
363-
twiss.eta_x * self.eta_scale,
364-
twiss.curly_h * self.eta_scale,
365-
),
366-
):
366+
for line, function in zip(ax.lines, self.twiss_functions):
367+
data = getattr(twiss, function)
368+
scale = self.scales.get(function)
369+
if scale is not None:
370+
data *= scale
367371
line.set_data(twiss.s, data)
368372
self.fig.canvas.draw_idle()
369373

0 commit comments

Comments
 (0)