Skip to content

Commit 7879a85

Browse files
Refactor draw_elements (#95)
1 parent ce7bfa7 commit 7879a85

File tree

4 files changed

+105
-81
lines changed

4 files changed

+105
-81
lines changed

apace/plot.py

Lines changed: 93 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Iterable
22
from itertools import zip_longest
33
from math import inf
4+
from typing import List, Optional, Tuple
45

56
import matplotlib as mpl
67
import matplotlib.gridspec as grid_spec
@@ -12,7 +13,16 @@
1213
from matplotlib.ticker import AutoMinorLocator, ScalarFormatter
1314
from matplotlib.widgets import Slider
1415

15-
from .classes import Base, Dipole, Drift, Lattice, Octupole, Quadrupole, Sextupole
16+
from .classes import (
17+
Base,
18+
Dipole,
19+
Drift,
20+
Element,
21+
Lattice,
22+
Octupole,
23+
Quadrupole,
24+
Sextupole,
25+
)
1626

1727
FONT_SIZE = 8
1828

@@ -30,7 +40,6 @@ class Color:
3040

3141

3242
ELEMENT_COLOR = {
33-
Drift: Color.BLACK,
3443
Dipole: Color.YELLOW,
3544
Quadrupole: Color.RED,
3645
Sextupole: Color.GREEN,
@@ -39,10 +48,10 @@ class Color:
3948

4049

4150
OPTICAL_FUNCTIONS = {
42-
"beta_x": (r"$\beta_x$/m", Color.RED),
43-
"beta_y": (r"$\beta_y$/m", Color.BLUE),
44-
"eta_x": (r"$\eta_x$/m", Color.GREEN),
45-
"eta_x_dds": (r"$\eta_x'$/m", Color.ORANGE),
51+
"beta_x": (r"$\beta_x$ / m", Color.RED),
52+
"beta_y": (r"$\beta_y$ / m", Color.BLUE),
53+
"eta_x": (r"$\eta_x$ / m", Color.GREEN),
54+
"eta_x_dds": (r"$\eta_x'$ / m", Color.ORANGE),
4655
"psi_x": (r"$\psi_x$", Color.YELLOW),
4756
"psi_y": (r"$\psi_y$", Color.ORANGE),
4857
"alpha_x": (r"$\alpha_x$", Color.PURPLE),
@@ -57,51 +66,54 @@ def draw_elements(
5766
labels: bool = True,
5867
location: str = "top",
5968
):
60-
"""Draw elements of a lattice to a matplotlib axes
61-
62-
:param ax: matplotlib axes, if not provided use current axes
63-
:type ax: matplotlib.axes
64-
:param lattice: lattice which gets drawn
65-
:type lattice: ap.Lattice
66-
:param labels: whether to display the names of elments, defaults to False
67-
:type labels: bool, optional
68-
:param draw_sub_lattices: Whether to show the start and end position of the sub lattices,
69-
defaults to True
70-
:type draw_sublattices: bool, optional
71-
"""
72-
69+
"""Draw the elements of a lattice onto a matplotlib axes."""
7370
x_min, x_max = ax.get_xlim()
7471
y_min, y_max = ax.get_ylim()
7572
rect_height = 0.05 * (y_max - y_min)
76-
y0 = y_max if location == "top" else y_min
73+
if location == "top":
74+
y0 = y_max = y_max + rect_height
75+
else:
76+
y0 = y_min - rect_height
77+
y_min -= 3 * rect_height
78+
plt.hlines(y0, x_min, x_max, color="black", linewidth=1)
79+
ax.set_ylim(y_min, y_max)
7780

78-
start = end = 0
7981
arrangement = lattice.arrangement
82+
position = start = end = 0
83+
sign = 1
8084
for element, next_element in zip_longest(arrangement, arrangement[1:]):
81-
end += element.length
82-
if element is next_element:
85+
position += element.length
86+
if element is next_element or position <= x_min:
8387
continue
88+
elif start >= x_max:
89+
break
8490

85-
if isinstance(element, Drift) or start >= x_max or end <= x_min:
86-
start = end
91+
start, end = end, position
92+
try:
93+
color = ELEMENT_COLOR[type(element)]
94+
except KeyError:
8795
continue
8896

89-
rec_length = min(end, x_max) - max(start, x_min)
90-
rectangle = plt.Rectangle(
91-
(max(start, x_min), y0 - rect_height / 2),
92-
rec_length,
93-
rect_height,
94-
fc=ELEMENT_COLOR[type(element)],
95-
clip_on=False,
96-
zorder=10,
97+
y0_local = y0
98+
if isinstance(element, Dipole) and element.angle < 0:
99+
y0_local += rect_height / 4
100+
101+
ax.add_patch(
102+
plt.Rectangle(
103+
(max(start, x_min), y0_local - rect_height / 2),
104+
min(end, x_max) - max(start, x_min),
105+
rect_height,
106+
facecolor=color,
107+
clip_on=False,
108+
zorder=10,
109+
)
97110
)
98-
ax.add_patch(rectangle)
99-
start = end
100-
if labels:
101-
sign = (isinstance(element, Quadrupole) << 1) - 1
111+
if labels and type(element) in {Dipole, Quadrupole}:
112+
# sign = (isinstance(element, Quadrupole) << 1) - 1
113+
sign = -sign
102114
ax.annotate(
103115
element.name,
104-
xy=((start + end) / 2, y0 + sign * rect_height),
116+
xy=((start + end) / 2, y0 - sign * rect_height),
105117
fontsize=FONT_SIZE,
106118
ha="center",
107119
va="center",
@@ -115,6 +127,7 @@ def draw_sub_lattices(
115127
lattice: Lattice,
116128
*,
117129
labels: bool = True,
130+
location: str = "bottom",
118131
):
119132
x_min, x_max = ax.get_xlim()
120133
length_gen = [0.0, *(obj.length for obj in lattice.tree)]
@@ -123,14 +136,20 @@ def draw_sub_lattices(
123136
i_max = np.searchsorted(position_list, x_max)
124137
ticks = position_list[i_min : i_max + 1]
125138
ax.set_xticks(ticks)
126-
if len(ticks) < 5:
127-
ax.xaxis.set_minor_locator(AutoMinorLocator())
128-
ax.xaxis.set_minor_formatter(ScalarFormatter())
129-
ax.grid(linestyle="--")
139+
# if len(ticks) < 5:
140+
# ax.xaxis.set_minor_locator(AutoMinorLocator())
141+
# ax.xaxis.set_minor_formatter(ScalarFormatter())
142+
ax.grid(axis="x", linestyle="--")
130143

131144
if labels:
132145
y_min, y_max = ax.get_ylim()
133-
y0 = y_max - 0.1 * (y_max - y_min)
146+
height = 0.08 * (y_max - y_min)
147+
if location == "top":
148+
y0, y_max = y_max, y_max + height
149+
else:
150+
y0, y_min = y_min - height / 3, y_min - height
151+
152+
ax.set_ylim(y_min, y_max)
134153
start = end = 0
135154
for obj in lattice.tree:
136155
end += obj.length
@@ -141,7 +160,7 @@ def draw_sub_lattices(
141160
ax.annotate(
142161
obj.name,
143162
xy=(x0, y0),
144-
fontsize=FONT_SIZE,
163+
fontsize=FONT_SIZE + 2,
145164
fontstyle="oblique",
146165
alpha=0.5,
147166
va="center",
@@ -169,32 +188,33 @@ def plot_twiss(
169188
text_areas = []
170189
for i, function in enumerate(twiss_functions):
171190
value = getattr(twiss, function)
172-
scale = scales.get(function, "")
173191
label, color = OPTICAL_FUNCTIONS[function]
174-
label = str(scale) + label
192+
scale = scales.get(function)
193+
if scale is not None:
194+
label = f"{scale} {label}"
195+
value = scale * value
196+
175197
ax.plot(
176198
twiss.s,
177-
value if scale == "" else scale * value,
199+
value,
178200
color=color,
179201
linewidth=line_width,
180202
linestyle=line_style,
181203
alpha=alpha,
182204
zorder=10 - i,
183205
label=label,
184206
)
185-
text_areas.append(TextArea(label, textprops=dict(color=color, rotation=90)))
207+
text_areas.insert(0, TextArea(label, textprops=dict(color=color, rotation=90)))
186208

187209
ax.set_xlabel("Orbit Position $s$ / m")
188210
if show_ylabels:
189211
ax.add_artist(
190212
AnchoredOffsetbox(
191-
loc=8,
192-
child=VPacker(children=text_areas, align="bottom", pad=0, sep=10),
193-
pad=0.0,
194-
frameon=False,
195-
bbox_to_anchor=(-0.08, 0.3),
213+
child=VPacker(children=text_areas, align="bottom", pad=0, sep=20),
214+
loc="center left",
215+
bbox_to_anchor=(-0.125, 0, 1.125, 1),
196216
bbox_transform=ax.transAxes,
197-
borderpad=0.0,
217+
frameon=False,
198218
)
199219
)
200220

@@ -270,7 +290,7 @@ def __init__(
270290
main=True,
271291
scales={"eta_x": 10},
272292
ref_twiss=None,
273-
pairs=None,
293+
pairs: Optional[List[Tuple[Element, str]]] = None,
274294
):
275295
self.fig = plt.figure()
276296
self.twiss = twiss
@@ -372,25 +392,30 @@ def find_optimal_grid(n):
372392

373393

374394
def floor_plan(
375-
lattice, ax=None, start_angle=0, annotate_elements=True, direction="clockwise"
395+
ax: mpl.axes.Axes,
396+
lattice: Lattice,
397+
*,
398+
start_angle: float = 0,
399+
labels: bool = True,
400+
direction: str = "clockwise",
376401
):
377-
if ax is None:
378-
ax = plt.gca()
379-
380402
ax.set_aspect("equal")
381-
codes = [Path.MOVETO, Path.LINETO]
403+
codes = Path.MOVETO, Path.LINETO
382404
current_angle = start_angle
383-
384405
start = np.zeros(2)
385406
end = np.zeros(2)
386407
x_min = y_min = 0
387408
x_max = y_max = 0
388409
arrangement = lattice.arrangement
389-
arrangement_shifted = arrangement[1:] + arrangement[0:1]
390-
for element, next_element in zip(arrangement, arrangement_shifted):
391-
color = ELEMENT_COLOR[type(element)]
410+
sign = 1
411+
for element, next_element in zip_longest(arrangement, arrangement[1:]):
392412
length = element.length
393-
line_width = 0.5 if isinstance(element, Drift) else 3
413+
if isinstance(element, Drift):
414+
color = Color.BLACK
415+
line_width = 1
416+
else:
417+
color = ELEMENT_COLOR[type(element)]
418+
line_width = 6
394419

395420
# TODO: refactor current angle
396421
angle = 0
@@ -440,9 +465,9 @@ def floor_plan(
440465
if element is next_element:
441466
continue
442467

443-
if annotate_elements and not isinstance(element, Drift):
468+
if labels and isinstance(element, (Dipole, Quadrupole)):
444469
angle_center = (current_angle - angle / 2) + np.pi / 2
445-
sign = -1 if isinstance(element, Quadrupole) else 1
470+
sign = -sign
446471
center = (start + end) / 2 + sign * 0.5 * np.array(
447472
[np.cos(angle_center), np.sin(angle_center)]
448473
)
@@ -459,8 +484,7 @@ def floor_plan(
459484

460485
start = end.copy()
461486

462-
margin = 0.05 * max((x_max - x_min), (y_max - y_min))
487+
margin = 0.01 * max((x_max - x_min), (y_max - y_min))
463488
ax.set_xlim(x_min - margin, x_max + margin)
464489
ax.set_ylim(y_min - margin, y_max + margin)
465-
466490
return ax

conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def pytest_addoption(parser):
1313
"--runslow", action="store_true", default=False, help="run slow tests"
1414
)
1515
parser.addoption(
16-
"--plots", action="store_true", default=False, help="Plot test results"
16+
"--plot", action="store_true", default=False, help="Plot test results"
1717
)
1818

1919

@@ -32,8 +32,8 @@ def pytest_collection_modifyitems(config, items):
3232

3333

3434
@pytest.fixture
35-
def plots(request):
36-
return request.config.getoption("--plots")
35+
def plot(request):
36+
return request.config.getoption("--plot")
3737

3838

3939
BASE_PATH = Path(__file__).parent

tests/test_plot.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,28 @@
99
import matplotlib.pyplot as plt
1010

1111

12-
def test_draw_elements(fodo_ring, test_output_dir, plots):
12+
def test_draw_elements(fodo_ring, test_output_dir, plot):
1313
twiss = ap.Twiss(fodo_ring)
1414
fig, ax = plt.subplots()
1515
plot_twiss(ax, twiss)
1616
draw_elements(ax, fodo_ring, location="top")
1717
draw_sub_lattices(ax, fodo_ring)
18-
if plots:
18+
if plot:
1919
fig.tight_layout()
2020
fig.savefig(test_output_dir / "test_draw_elements.svg")
2121

2222

23-
def test_TwissPlot(fodo_ring, test_output_dir, plots):
23+
def test_TwissPlot(fodo_ring, test_output_dir, plot):
2424
twiss = ap.Twiss(fodo_ring)
2525
fig = TwissPlot(twiss).fig
26-
if plots:
26+
if plot:
2727
fig.savefig(test_output_dir / "test_TwissPlot.svg")
2828

2929

30-
def test_floor_plan(fodo_ring, test_output_dir, plots):
30+
def test_floor_plan(fodo_ring, test_output_dir, plot):
3131
_, ax = plt.subplots()
32-
ax = floor_plan(fodo_ring, ax=ax)
32+
ax = floor_plan(ax, fodo_ring)
3333
ax.invert_yaxis()
34-
if plots:
34+
if plot:
3535
plt.tight_layout()
3636
plt.savefig(test_output_dir / "test_floor_plan.svg")

tests/test_tracking_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# TODO: improve and add assert statments
8-
def test_quadrupole(test_output_dir, plots):
8+
def test_quadrupole(test_output_dir, plot):
99
d1 = ap.Drift("D1", length=5)
1010
d2 = ap.Drift("D2", length=0.2)
1111
q = ap.Quadrupole("Q1", length=0.2, k1=1.5)
@@ -18,7 +18,7 @@ def test_quadrupole(test_output_dir, plots):
1818
)
1919
tracking = Tracking(lattice)
2020
s, trajectory = tracking.track(distribution)
21-
if plots:
21+
if plot:
2222
_, ax = plt.subplots(figsize=(20, 5))
2323
ax.plot(s, trajectory[:, 0])
2424
plt.ylim(-0.0002, 0.0002)

0 commit comments

Comments
 (0)