Skip to content

Commit 87b6677

Browse files
Cleanup twiss_plot (#59)
1 parent eef84e5 commit 87b6677

File tree

2 files changed

+46
-130
lines changed

2 files changed

+46
-130
lines changed

apace/plot.py

Lines changed: 42 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ class Color:
3535
def draw_lattice(
3636
lattice,
3737
ax=None,
38+
x_min=-inf,
39+
x_max=inf,
3840
draw_elements=True,
3941
annotate_elements=True,
4042
draw_sub_lattices=True,
4143
annotate_sub_lattices=True,
42-
x_min=-inf,
43-
x_max=inf,
4444
):
4545
"""Draw elements of a lattice to a matplotlib axes
4646
@@ -73,7 +73,7 @@ def draw_lattice(
7373
if element is next_element:
7474
continue
7575

76-
if isinstance(element, Drift) or start > x_max or end < x_min:
76+
if isinstance(element, Drift) or start >= x_max or end <= x_min:
7777
start = end
7878
continue
7979

@@ -105,11 +105,12 @@ def draw_lattice(
105105
start = end
106106

107107
if draw_sub_lattices:
108-
length_list = [0]
109-
length_list.extend(obj.length for obj in lattice.tree)
110-
position_list = np.add.accumulate(length_list)
111-
ax.set_xticks(position_list)
112-
ax.grid(axis="x", linestyle="--")
108+
length_gen = [0] + [obj.length for obj in lattice.tree]
109+
position_list = np.add.accumulate(length_gen)
110+
i_min = np.searchsorted(position_list, x_min)
111+
i_max = np.searchsorted(position_list, x_max)
112+
ax.set_xticks(position_list[i_min:i_max])
113+
ax.grid(linestyle="--")
113114

114115
if annotate_sub_lattices:
115116
y0 = y_max - 3 * rect_height
@@ -134,13 +135,7 @@ def draw_lattice(
134135

135136

136137
def plot_twiss(
137-
twiss,
138-
ax=None,
139-
line_style="solid",
140-
line_width=1.3,
141-
alpha=1.0,
142-
eta_x_scale=10,
143-
show_legend=True,
138+
twiss, ax=None, line_style="solid", line_width=1.3, alpha=1.0, eta_x_scale=10,
144139
):
145140
if ax is None:
146141
_, ax = plt.subplots()
@@ -176,101 +171,13 @@ def plot_twiss(
176171
label=f"{eta_x_scale}$\\eta_x$/m",
177172
)
178173

179-
if show_legend:
180-
ax.legend(
181-
loc="lower left",
182-
bbox_to_anchor=(0.0, 1.05),
183-
ncol=10,
184-
borderaxespad=0,
185-
frameon=False,
186-
)
187-
188174
ax.set_xlabel("Orbit Position $s$ / m")
189-
190175
return ax
191176

192177

193-
def set_limits(ax, lattice, x_min=None, x_max=None, y_min=None, y_max=None):
194-
x_min = x_min if x_min else 0
195-
x_max = x_max if x_max else lattice.length
196-
y_lim = ax.get_ylim()
197-
y_min = y_min if y_min else -0.5
198-
y_max = y_max if y_max else 1.1 * y_lim[1]
199-
ax.set_xlim((x_min, x_max))
200-
ax.set_ylim((y_min, y_max))
201-
return x_min, x_max, y_min, y_max
202-
203-
204-
def set_grid(ax, lattice, x_min, x_max, y_min, y_max, n_x_ticks, n_y_ticks):
205-
ax.xaxis.grid(which="minor", linestyle="dotted")
206-
ax.yaxis.grid(alpha=0.5, zorder=0, linestyle="dotted")
207-
if n_x_ticks:
208-
lin = np.linspace(x_min, x_max, n_x_ticks)
209-
lattice_length_list = [0]
210-
lattice_length_list.extend([lattice.length for lattice in lattice.tree])
211-
pos_tick = (
212-
np.add.accumulate(lattice_length_list) if not (x_min and x_max) else lin
213-
)
214-
ax.set_xticks(pos_tick, minor=True)
215-
ax.set_xticks(lin)
216-
if n_y_ticks:
217-
ax.set_yticks(np.arange(int(y_min), int(y_max), n_y_ticks))
218-
ax.set_xlabel("orbit position $s$/m")
219-
220-
221-
def annotate_info(lattice, twiss, ax=None):
222-
# fig = plt.gcf()
223-
if ax is None:
224-
ax = plt.gca()
225-
226-
margin = 0.02
227-
fs = 15
228-
229-
ax.annotate(
230-
lattice.name,
231-
xy=(1 - margin, 1 - margin),
232-
xycoords="figure fraction",
233-
va="top",
234-
ha="right",
235-
fontsize=fs,
236-
)
237-
238-
# annolist_string1 = f"$Q_x$: {twiss.Qx:.2f} ({twiss.Qx_freq:.0f} kHz) $Q_y$: {twiss.Qy:.2f} " \
239-
# f"({twiss.Qy_freq:.0f} kHz) $\\alpha_C$: {twiss.alphac:.2e}"
240-
# fig.annotate(annolist_string1, xy=(start, height_1),
241-
# xycoords='figure fraction', va='center', ha='left', fontsize=fs)
242-
# r = fig.canvas.get_renderer()
243-
244-
x = margin
245-
y = 1 - margin
246-
for line in ax.get_lines():
247-
label = line.get_label()
248-
plt.annotate(
249-
label,
250-
xy=(x, y),
251-
xycoords="figure fraction",
252-
color=line.get_color(),
253-
fontsize=fs,
254-
va="top",
255-
ha="left",
256-
)
257-
x += len(label) * 0.005
258-
259-
# space = 0
260-
# x_min, x_max = ax.get_xlim()
261-
# for i, s in enumerate(string):
262-
# t = ax.annotate(s, xy=(w, 1 - margin), xycoords='figure fraction',
263-
# color=mpl.cm.Set1(i / 9), fontsize=fs, va="top", ha="left")
264-
# transform = ax.transData.inverted()
265-
# bb = t.get_window_extent(renderer=r)
266-
# bb = bb.transformed(transform)
267-
# w = w + (bb.x_max - bb.x_min) / (x_max - x_min) + space
268-
269-
270178
def _twiss_plot_section(
271179
twiss,
272180
ax=None,
273-
lattice=None,
274181
x_min=None,
275182
x_max=None,
276183
y_min=None,
@@ -289,22 +196,26 @@ def _twiss_plot_section(
289196
):
290197
if overwrite:
291198
ax.clear()
292-
293199
if ref_twiss:
294-
plot_twiss(
295-
ref_twiss, ax, ref_line_style, ref_line_width, alpha=0.5,
296-
)
200+
plot_twiss(ref_twiss, ax, ref_line_style, ref_line_width, alpha=0.5)
297201

298202
plot_twiss(twiss, ax, line_style, line_width, eta_x_scale)
299-
if lattice:
300-
x_min, x_max, y_min, y_max = set_limits(ax, lattice, x_min, x_max, y_min, y_max)
301-
set_grid(ax, lattice, x_min, x_max, y_min, y_max, n_x_ticks, n_y_ticks)
302-
draw_lattice(lattice, ax, annotate_elements, x_min, x_max)
203+
if x_min is None:
204+
x_min = 0
205+
if x_max is None:
206+
x_max = twiss.lattice.length
207+
if y_min is None:
208+
y_min = -0.5
209+
if y_max is None:
210+
y_max = ax.get_ylim()[1]
211+
212+
draw_lattice(twiss.lattice, ax, x_min, x_max, annotate_elements=annotate_elements)
213+
ax.set_xlim((x_min, x_max))
214+
ax.set_ylim((y_min, y_max))
303215

304216

305217
def twiss_plot(
306218
twiss,
307-
lattice=None,
308219
main=True,
309220
fig_size=(16, 9),
310221
sections=None,
@@ -314,46 +225,50 @@ def twiss_plot(
314225
ref_twiss=None,
315226
path=None,
316227
):
317-
fig = plt.figure(figsize=fig_size) # , constrained_layout=True)
228+
fig = plt.figure(figsize=fig_size)
318229
height_ratios = [2, 7] if (main and sections) else [1]
319230
main_grid = grid_spec.GridSpec(
320-
len(height_ratios), 1, figure=fig, height_ratios=height_ratios
231+
len(height_ratios), 1, fig, height_ratios=height_ratios
321232
)
322233

323234
if main:
324235
ax = fig.add_subplot(main_grid[0])
325236
_twiss_plot_section(
326237
twiss,
327238
ax,
328-
lattice,
329239
ref_twiss=ref_twiss,
330240
y_min=y_min,
331241
y_max=y_max,
332-
annotate_elements=True,
242+
annotate_elements=False,
333243
eta_x_scale=eta_x_scale,
334244
)
335245

246+
ax.legend(
247+
loc="lower left",
248+
bbox_to_anchor=(0.0, 1.05),
249+
ncol=10,
250+
borderaxespad=0,
251+
frameon=False,
252+
)
253+
336254
if sections:
337255
if isinstance(sections, str) or not isinstance(sections[0], Iterable):
338256
sections = [sections]
339257

340-
N_sections = len(sections)
341-
rows, cols = find_optimal_grid(N_sections)
342-
sub_grid = grid_spec.GridSpecFromSubplotSpec(
343-
rows, cols, subplot_spec=main_grid[-1]
344-
)
258+
n_sections = len(sections)
259+
rows, cols = find_optimal_grid(n_sections)
260+
sub_grid = grid_spec.GridSpecFromSubplotSpec(rows, cols, main_grid[-1])
345261
for i, section in enumerate(sections):
346262
ax = fig.add_subplot(sub_grid[i])
347263

348264
if isinstance(section, str):
349-
pass # TODO: implement cell_start + cell_end
265+
raise NotImplementedError # TODO: implement cell_start + cell_end
350266
else:
351-
x_min, x_max = section[0], section[1]
267+
x_min, x_max = section
352268

353269
_twiss_plot_section(
354-
ax,
355270
twiss,
356-
lattice,
271+
ax,
357272
ref_twiss=ref_twiss,
358273
x_min=x_min,
359274
x_max=x_max,
@@ -363,10 +278,9 @@ def twiss_plot(
363278
n_x_ticks=None,
364279
)
365280

366-
fig.suptitle(lattice.name)
281+
fig.suptitle(twiss.lattice.name)
367282
fig.tight_layout()
368-
fig.subplots_adjust(top=0.93)
369-
# annotate_info(lattice, twiss)
283+
# fig.subplots_adjust(top=0.93)
370284
if path:
371285
fig.savefig(path)
372286

tests/test_plot.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111

1212

1313
def test_draw_lattice():
14-
plot_twiss(twiss)
14+
_, ax = plt.subplots()
15+
plot_twiss(twiss, ax=ax)
1516
draw_lattice(lattice)
1617
plt.tight_layout()
1718
plt.savefig("/tmp/apace_test_draw_lattice.pdf")
1819

1920

2021
def test_floor_plan():
21-
ax = floor_plan(lattice)
22+
_, ax = plt.subplots()
23+
ax = floor_plan(lattice, ax=ax)
2224
ax.invert_yaxis()
2325
plt.tight_layout()
2426
plt.savefig("/tmp/apace_test_floor_plan.pdf")

0 commit comments

Comments
 (0)