1- from enum import Enum
2- from typing import Union , List , Tuple
31import matplotlib as mpl
42import matplotlib .pyplot as plt
53import matplotlib .patches as patches
97from matplotlib .path import Path
108from matplotlib .ticker import AutoMinorLocator , ScalarFormatter
119import numpy as np
12- from enum import Enum
1310from math import inf
1411from collections .abc import Iterable
1512from .classes import Base , Drift , Dipole , Quadrupole , Sextupole , Octupole , Lattice
1613
1714
15+ FONT_SIZE = 8
16+
17+
1818class Color :
1919 RED = "crimson"
2020 GREEN = "mediumseagreen"
@@ -34,26 +34,24 @@ class Color:
3434 Octupole : Color .BLUE ,
3535}
3636
37- FONT_SIZE = 8
3837
39- [
40- "beta_x" ,
41- "beta_y" ,
42- "eta_x" ,
43- "psi_x" ,
44- "psi_y" ,
45- "alpha_x" ,
46- "alpha_y" ,
47- "gamma_x" ,
48- "gamma_y" ,
49- ]
38+ OPTICAL_FUNCTIONS = {
39+ "beta_x" : (r"$\beta_x$/m" , Color .RED ),
40+ "beta_y" : (r"$\beta_y$/m" , Color .BLUE ),
41+ "eta_x" : (r"$\eta_x$/m" , Color .GREEN ),
42+ "psi_x" : (r"$\psi_x$" , Color .YELLOW ),
43+ "psi_y" : (r"$\psi_y$" , Color .ORANGE ),
44+ "alpha_x" : (r"$\alpha_x$" , Color .MAGENTA ),
45+ "alpha_y" : (r"$\alpha_y$" , Color .BLACK ),
46+ }
5047
5148
5249def draw_lattice (
5350 lattice ,
5451 ax = None ,
5552 x_min = - inf ,
5653 x_max = inf ,
54+ location = "top" ,
5755 draw_elements = True ,
5856 annotate_elements = True ,
5957 draw_sub_lattices = True ,
@@ -82,6 +80,12 @@ def draw_lattice(
8280 y_span = y_max - y_min
8381 rect_height = y_span / 32
8482
83+ y0 = - rect_height / 2
84+ if location == "top" :
85+ y0 += y_max
86+ elif location == "bottom" :
87+ y0 += y_min
88+
8589 if draw_elements :
8690 start = end = 0
8791 arrangement = lattice .arrangement
@@ -96,7 +100,7 @@ def draw_lattice(
96100
97101 rec_length = min (end , x_max ) - max (start , x_min )
98102 rectangle = plt .Rectangle (
99- (start if start > x_min else x_min , y_max - rect_height / 2 ),
103+ (start if start > x_min else x_min , y0 ),
100104 rec_length ,
101105 rect_height ,
102106 fc = ELEMENT_COLOR [type (element )],
@@ -112,7 +116,7 @@ def draw_lattice(
112116 )
113117 ax .annotate (
114118 element .name ,
115- xy = (center , y_max + sign * 0.75 * rect_height ),
119+ xy = (center , y0 + sign * 0.75 * rect_height ),
116120 fontsize = FONT_SIZE ,
117121 ha = "center" ,
118122 va = va ,
@@ -135,7 +139,7 @@ def draw_lattice(
135139 ax .grid (linestyle = "--" )
136140
137141 if annotate_sub_lattices :
138- y0 = y_max - 3 * rect_height
142+ y0_anno = y0 - 3 * rect_height
139143 end = 0
140144 for obj in lattice .tree :
141145 end += obj .length
@@ -145,7 +149,7 @@ def draw_lattice(
145149 x0 = end - obj .length / 2
146150 ax .annotate (
147151 obj .name ,
148- xy = (x0 , y0 ),
152+ xy = (x0 , y0_anno ),
149153 fontsize = FONT_SIZE ,
150154 fontstyle = "oblique" ,
151155 alpha = 0.5 ,
@@ -158,35 +162,37 @@ def draw_lattice(
158162
159163def plot_twiss (
160164 twiss ,
165+ twiss_functions = ("beta_x" , "beta_y" , "eta_x" ),
166+ * ,
167+ scales = {"eta_x" : 10 },
161168 ax = None ,
162169 line_style = "solid" ,
163170 line_width = 1.3 ,
164171 alpha = 1.0 ,
165- eta_scale = 10 ,
166172 show_ylabels = False ,
167173):
168174 if ax is None :
169175 ax = plt .gca ()
170-
171- text_areas = [None ] * 3
172- for value , label , color , order in (
173- (twiss .beta_x , r"$\beta_x$/m" , Color .RED , 2 ),
174- (twiss .beta_y , r"$\beta_y$/m" , Color .BLUE , 1 ),
175- (twiss .eta_x * eta_scale , rf"{ eta_scale } $\eta_x$/m" , Color .GREEN , 0 ),
176- # (twiss.curly_h, rf"{eta_scale}$\mathscr{{H}}_x$", Color.ORANGE, -1),
177- ):
176+ if scales is None :
177+ scales = {}
178+
179+ text_areas = []
180+ for i , function in enumerate (twiss_functions ):
181+ value = getattr (twiss , function )
182+ scale = scales .get (function , "" )
183+ label , color = OPTICAL_FUNCTIONS [function ]
184+ label = str (scale ) + label
178185 ax .plot (
179186 twiss .s ,
180- value ,
187+ value if scale == "" else scale * value ,
181188 color = color ,
182189 linewidth = line_width ,
183190 linestyle = line_style ,
184191 alpha = alpha ,
185- zorder = order ,
192+ zorder = 10 - i ,
186193 label = label ,
187194 )
188-
189- text_areas [order ] = TextArea (label , textprops = dict (color = color , rotation = 90 ))
195+ text_areas .append (TextArea (label , textprops = dict (color = color , rotation = 90 )))
190196
191197 ax .set_xlabel ("Orbit Position $s$ / m" )
192198 if show_ylabels :
@@ -218,22 +224,23 @@ def _twiss_plot_section(
218224 ref_twiss = None ,
219225 ref_line_style = "dashed" ,
220226 ref_line_width = 2.5 ,
221- eta_scale = 10 ,
227+ scales = { "eta_x" : 10 } ,
222228 overwrite = False ,
223229):
224230 if overwrite :
225231 ax .clear ()
226232 if ref_twiss :
227233 plot_twiss (
228234 ref_twiss ,
229- ax ,
230- ref_line_style ,
231- ref_line_width ,
235+ ax = ax ,
236+ line_style = ref_line_style ,
237+ line_width = ref_line_width ,
232238 alpha = 0.5 ,
233- eta_scale = eta_scale ,
234239 )
235240
236- plot_twiss (twiss , ax , line_style , line_width , eta_scale = eta_scale )
241+ plot_twiss (
242+ twiss , ax = ax , line_style = line_style , line_width = line_width , scales = scales
243+ )
237244 if x_min is None :
238245 x_min = 0
239246 if x_max is None :
@@ -251,6 +258,7 @@ def _twiss_plot_section(
251258# TODO:
252259# * make sub_class of figure
253260# * add attribute which defines which twiss parameters are plotted
261+ # * add twiss_functions argument similar to plot_twiss
254262class TwissPlot :
255263 """Convenience class to plot twiss parameters
256264
@@ -260,7 +268,7 @@ class TwissPlot:
260268 :param y_max float: Maximum y-limit
261269 :param y_min float: Minimum y-limit
262270 :param main bool: Wheter to plot whole ring or only given sections
263- :param eta_scale int: Scaling factor of the dipsersion function
271+ :param scales Dict[str, int]: Optional scaling factors for optical functions
264272 :param Twiss ref_twiss: Reference twiss values. Will be plotted as dashed lines.
265273 :param pairs: List of (element, attribute)-pairs to create interactice sliders for.
266274 :type pairs: List[Tuple[Element, str]]
@@ -269,25 +277,31 @@ class TwissPlot:
269277 def __init__ (
270278 self ,
271279 twiss ,
280+ twiss_functions = ("beta_x" , "beta_y" , "eta_x" ),
281+ * ,
272282 sections = None ,
273283 y_min = None ,
274284 y_max = None ,
275285 main = True ,
276- eta_scale = 10 ,
286+ scales = { "eta_x" : 10 } ,
277287 ref_twiss = None ,
278288 pairs = None ,
279289 ):
280290 self .fig = plt .figure ()
281291 self .twiss = twiss
282292 self .lattice = twiss .lattice
283- self .eta_scale = eta_scale
293+ self .twiss_functions = twiss_functions
294+ self .scales = scales
284295 height_ratios = [4 , 14 ] if (main and sections ) else [1 ]
285296 main_grid = grid_spec .GridSpec (
286297 len (height_ratios ), 1 , self .fig , height_ratios = height_ratios
287298 )
299+ self .axs_sections = [] # TODO: needed for update function
288300
289301 if pairs :
290302 fig_sliders , axs = plt .subplots (nrows = len (pairs ))
303+ if not isinstance (axs , Iterable ):
304+ axs = (axs ,)
291305 self .sliders = []
292306 for ax , (element , attribute ) in zip (axs , pairs ):
293307 initial_value = getattr (element , attribute )
@@ -310,7 +324,7 @@ def __init__(
310324 y_min = y_min ,
311325 y_max = y_max ,
312326 annotate_elements = False ,
313- eta_scale = eta_scale ,
327+ scales = scales ,
314328 )
315329
316330 if sections :
@@ -338,32 +352,22 @@ def __init__(
338352 y_min = y_min ,
339353 y_max = y_max ,
340354 annotate_elements = True ,
341- eta_scale = eta_scale ,
355+ scales = scales ,
342356 )
343357
344358 handles , labels = self .fig .axes [0 ].get_legend_handles_labels ()
345- self .fig .legend (
346- handles ,
347- labels ,
348- loc = "upper left" ,
349- ncol = 10 ,
350- frameon = False ,
351- )
359+ self .fig .legend (handles , labels , loc = "upper left" , ncol = 10 , frameon = False )
352360 self .fig .suptitle (twiss .lattice .name , ha = "right" , x = 0.98 )
353361 self .fig .tight_layout ()
354362
355363 def update (self ):
356364 twiss = self .twiss
357365 for ax in [self .ax_main ] + self .axs_sections :
358- for line , data in zip (
359- ax .lines ,
360- (
361- twiss .beta_x ,
362- twiss .beta_y ,
363- twiss .eta_x * self .eta_scale ,
364- twiss .curly_h * self .eta_scale ,
365- ),
366- ):
366+ for line , function in zip (ax .lines , self .twiss_functions ):
367+ data = getattr (twiss , function )
368+ scale = self .scales .get (function )
369+ if scale is not None :
370+ data *= scale
367371 line .set_data (twiss .s , data )
368372 self .fig .canvas .draw_idle ()
369373
0 commit comments