11from collections .abc import Iterable
22from itertools import zip_longest
33from math import inf
4+ from typing import List , Optional , Tuple
45
56import matplotlib as mpl
67import matplotlib .gridspec as grid_spec
1213from matplotlib .ticker import AutoMinorLocator , ScalarFormatter
1314from matplotlib .widgets import Slider
1415
15- from .classes import Base , Dipole , Drift , Lattice , Octupole , Quadrupole , Sextupole
16+ from .classes import (
17+ Base ,
18+ Dipole ,
19+ Drift ,
20+ Element ,
21+ Lattice ,
22+ Octupole ,
23+ Quadrupole ,
24+ Sextupole ,
25+ )
1626
1727FONT_SIZE = 8
1828
@@ -30,7 +40,6 @@ class Color:
3040
3141
3242ELEMENT_COLOR = {
33- Drift : Color .BLACK ,
3443 Dipole : Color .YELLOW ,
3544 Quadrupole : Color .RED ,
3645 Sextupole : Color .GREEN ,
@@ -39,10 +48,10 @@ class Color:
3948
4049
4150OPTICAL_FUNCTIONS = {
42- "beta_x" : (r"$\beta_x$/ m" , Color .RED ),
43- "beta_y" : (r"$\beta_y$/ m" , Color .BLUE ),
44- "eta_x" : (r"$\eta_x$/ m" , Color .GREEN ),
45- "eta_x_dds" : (r"$\eta_x'$/ m" , Color .ORANGE ),
51+ "beta_x" : (r"$\beta_x$ / m" , Color .RED ),
52+ "beta_y" : (r"$\beta_y$ / m" , Color .BLUE ),
53+ "eta_x" : (r"$\eta_x$ / m" , Color .GREEN ),
54+ "eta_x_dds" : (r"$\eta_x'$ / m" , Color .ORANGE ),
4655 "psi_x" : (r"$\psi_x$" , Color .YELLOW ),
4756 "psi_y" : (r"$\psi_y$" , Color .ORANGE ),
4857 "alpha_x" : (r"$\alpha_x$" , Color .PURPLE ),
@@ -57,51 +66,54 @@ def draw_elements(
5766 labels : bool = True ,
5867 location : str = "top" ,
5968):
60- """Draw elements of a lattice to a matplotlib axes
61-
62- :param ax: matplotlib axes, if not provided use current axes
63- :type ax: matplotlib.axes
64- :param lattice: lattice which gets drawn
65- :type lattice: ap.Lattice
66- :param labels: whether to display the names of elments, defaults to False
67- :type labels: bool, optional
68- :param draw_sub_lattices: Whether to show the start and end position of the sub lattices,
69- defaults to True
70- :type draw_sublattices: bool, optional
71- """
72-
69+ """Draw the elements of a lattice onto a matplotlib axes."""
7370 x_min , x_max = ax .get_xlim ()
7471 y_min , y_max = ax .get_ylim ()
7572 rect_height = 0.05 * (y_max - y_min )
76- y0 = y_max if location == "top" else y_min
73+ if location == "top" :
74+ y0 = y_max = y_max + rect_height
75+ else :
76+ y0 = y_min - rect_height
77+ y_min -= 3 * rect_height
78+ plt .hlines (y0 , x_min , x_max , color = "black" , linewidth = 1 )
79+ ax .set_ylim (y_min , y_max )
7780
78- start = end = 0
7981 arrangement = lattice .arrangement
82+ position = start = end = 0
83+ sign = 1
8084 for element , next_element in zip_longest (arrangement , arrangement [1 :]):
81- end += element .length
82- if element is next_element :
85+ position += element .length
86+ if element is next_element or position <= x_min :
8387 continue
88+ elif start >= x_max :
89+ break
8490
85- if isinstance (element , Drift ) or start >= x_max or end <= x_min :
86- start = end
91+ start , end = end , position
92+ try :
93+ color = ELEMENT_COLOR [type (element )]
94+ except KeyError :
8795 continue
8896
89- rec_length = min (end , x_max ) - max (start , x_min )
90- rectangle = plt .Rectangle (
91- (max (start , x_min ), y0 - rect_height / 2 ),
92- rec_length ,
93- rect_height ,
94- fc = ELEMENT_COLOR [type (element )],
95- clip_on = False ,
96- zorder = 10 ,
97+ y0_local = y0
98+ if isinstance (element , Dipole ) and element .angle < 0 :
99+ y0_local += rect_height / 4
100+
101+ ax .add_patch (
102+ plt .Rectangle (
103+ (max (start , x_min ), y0_local - rect_height / 2 ),
104+ min (end , x_max ) - max (start , x_min ),
105+ rect_height ,
106+ facecolor = color ,
107+ clip_on = False ,
108+ zorder = 10 ,
109+ )
97110 )
98- ax .add_patch (rectangle )
99- start = end
100- if labels :
101- sign = (isinstance (element , Quadrupole ) << 1 ) - 1
111+ if labels and type (element ) in {Dipole , Quadrupole }:
112+ # sign = (isinstance(element, Quadrupole) << 1) - 1
113+ sign = - sign
102114 ax .annotate (
103115 element .name ,
104- xy = ((start + end ) / 2 , y0 + sign * rect_height ),
116+ xy = ((start + end ) / 2 , y0 - sign * rect_height ),
105117 fontsize = FONT_SIZE ,
106118 ha = "center" ,
107119 va = "center" ,
@@ -115,6 +127,7 @@ def draw_sub_lattices(
115127 lattice : Lattice ,
116128 * ,
117129 labels : bool = True ,
130+ location : str = "bottom" ,
118131):
119132 x_min , x_max = ax .get_xlim ()
120133 length_gen = [0.0 , * (obj .length for obj in lattice .tree )]
@@ -123,14 +136,20 @@ def draw_sub_lattices(
123136 i_max = np .searchsorted (position_list , x_max )
124137 ticks = position_list [i_min : i_max + 1 ]
125138 ax .set_xticks (ticks )
126- if len (ticks ) < 5 :
127- ax .xaxis .set_minor_locator (AutoMinorLocator ())
128- ax .xaxis .set_minor_formatter (ScalarFormatter ())
129- ax .grid (linestyle = "--" )
139+ # if len(ticks) < 5:
140+ # ax.xaxis.set_minor_locator(AutoMinorLocator())
141+ # ax.xaxis.set_minor_formatter(ScalarFormatter())
142+ ax .grid (axis = "x" , linestyle = "--" )
130143
131144 if labels :
132145 y_min , y_max = ax .get_ylim ()
133- y0 = y_max - 0.1 * (y_max - y_min )
146+ height = 0.08 * (y_max - y_min )
147+ if location == "top" :
148+ y0 , y_max = y_max , y_max + height
149+ else :
150+ y0 , y_min = y_min - height / 3 , y_min - height
151+
152+ ax .set_ylim (y_min , y_max )
134153 start = end = 0
135154 for obj in lattice .tree :
136155 end += obj .length
@@ -141,7 +160,7 @@ def draw_sub_lattices(
141160 ax .annotate (
142161 obj .name ,
143162 xy = (x0 , y0 ),
144- fontsize = FONT_SIZE ,
163+ fontsize = FONT_SIZE + 2 ,
145164 fontstyle = "oblique" ,
146165 alpha = 0.5 ,
147166 va = "center" ,
@@ -169,32 +188,33 @@ def plot_twiss(
169188 text_areas = []
170189 for i , function in enumerate (twiss_functions ):
171190 value = getattr (twiss , function )
172- scale = scales .get (function , "" )
173191 label , color = OPTICAL_FUNCTIONS [function ]
174- label = str (scale ) + label
192+ scale = scales .get (function )
193+ if scale is not None :
194+ label = f"{ scale } { label } "
195+ value = scale * value
196+
175197 ax .plot (
176198 twiss .s ,
177- value if scale == "" else scale * value ,
199+ value ,
178200 color = color ,
179201 linewidth = line_width ,
180202 linestyle = line_style ,
181203 alpha = alpha ,
182204 zorder = 10 - i ,
183205 label = label ,
184206 )
185- text_areas .append ( TextArea (label , textprops = dict (color = color , rotation = 90 )))
207+ text_areas .insert ( 0 , TextArea (label , textprops = dict (color = color , rotation = 90 )))
186208
187209 ax .set_xlabel ("Orbit Position $s$ / m" )
188210 if show_ylabels :
189211 ax .add_artist (
190212 AnchoredOffsetbox (
191- loc = 8 ,
192- child = VPacker (children = text_areas , align = "bottom" , pad = 0 , sep = 10 ),
193- pad = 0.0 ,
194- frameon = False ,
195- bbox_to_anchor = (- 0.08 , 0.3 ),
213+ child = VPacker (children = text_areas , align = "bottom" , pad = 0 , sep = 20 ),
214+ loc = "center left" ,
215+ bbox_to_anchor = (- 0.125 , 0 , 1.125 , 1 ),
196216 bbox_transform = ax .transAxes ,
197- borderpad = 0.0 ,
217+ frameon = False ,
198218 )
199219 )
200220
@@ -270,7 +290,7 @@ def __init__(
270290 main = True ,
271291 scales = {"eta_x" : 10 },
272292 ref_twiss = None ,
273- pairs = None ,
293+ pairs : Optional [ List [ Tuple [ Element , str ]]] = None ,
274294 ):
275295 self .fig = plt .figure ()
276296 self .twiss = twiss
@@ -372,25 +392,30 @@ def find_optimal_grid(n):
372392
373393
374394def floor_plan (
375- lattice , ax = None , start_angle = 0 , annotate_elements = True , direction = "clockwise"
395+ ax : mpl .axes .Axes ,
396+ lattice : Lattice ,
397+ * ,
398+ start_angle : float = 0 ,
399+ labels : bool = True ,
400+ direction : str = "clockwise" ,
376401):
377- if ax is None :
378- ax = plt .gca ()
379-
380402 ax .set_aspect ("equal" )
381- codes = [ Path .MOVETO , Path .LINETO ]
403+ codes = Path .MOVETO , Path .LINETO
382404 current_angle = start_angle
383-
384405 start = np .zeros (2 )
385406 end = np .zeros (2 )
386407 x_min = y_min = 0
387408 x_max = y_max = 0
388409 arrangement = lattice .arrangement
389- arrangement_shifted = arrangement [1 :] + arrangement [0 :1 ]
390- for element , next_element in zip (arrangement , arrangement_shifted ):
391- color = ELEMENT_COLOR [type (element )]
410+ sign = 1
411+ for element , next_element in zip_longest (arrangement , arrangement [1 :]):
392412 length = element .length
393- line_width = 0.5 if isinstance (element , Drift ) else 3
413+ if isinstance (element , Drift ):
414+ color = Color .BLACK
415+ line_width = 1
416+ else :
417+ color = ELEMENT_COLOR [type (element )]
418+ line_width = 6
394419
395420 # TODO: refactor current angle
396421 angle = 0
@@ -440,9 +465,9 @@ def floor_plan(
440465 if element is next_element :
441466 continue
442467
443- if annotate_elements and not isinstance (element , Drift ):
468+ if labels and isinstance (element , ( Dipole , Quadrupole ) ):
444469 angle_center = (current_angle - angle / 2 ) + np .pi / 2
445- sign = - 1 if isinstance ( element , Quadrupole ) else 1
470+ sign = - sign
446471 center = (start + end ) / 2 + sign * 0.5 * np .array (
447472 [np .cos (angle_center ), np .sin (angle_center )]
448473 )
@@ -459,8 +484,7 @@ def floor_plan(
459484
460485 start = end .copy ()
461486
462- margin = 0.05 * max ((x_max - x_min ), (y_max - y_min ))
487+ margin = 0.01 * max ((x_max - x_min ), (y_max - y_min ))
463488 ax .set_xlim (x_min - margin , x_max + margin )
464489 ax .set_ylim (y_min - margin , y_max + margin )
465-
466490 return ax
0 commit comments