11from collections .abc import Iterable
2- from itertools import zip_longest
2+ from itertools import groupby
33from math import inf
4- from typing import List , Optional , Tuple
4+ from typing import Dict , List , Optional , Tuple
55
66import matplotlib as mpl
77import matplotlib .gridspec as grid_spec
1010import numpy as np
1111from matplotlib .offsetbox import AnchoredOffsetbox , TextArea , VPacker
1212from matplotlib .path import Path
13- from matplotlib .ticker import AutoMinorLocator , ScalarFormatter
1413from matplotlib .widgets import Slider
1514
1615from .classes import (
@@ -37,9 +36,10 @@ class Color:
3736 CYAN = "#06B6D4"
3837 WHITE = "white"
3938 BLACK = "black"
39+ LIGHT_GRAY = "#E5E7EB"
4040
4141
42- ELEMENT_COLOR = {
42+ ELEMENT_COLOR : Dict [ type , str ] = {
4343 Dipole : Color .YELLOW ,
4444 Quadrupole : Color .RED ,
4545 Sextupole : Color .GREEN ,
@@ -78,17 +78,16 @@ def draw_elements(
7878 plt .hlines (y0 , x_min , x_max , color = "black" , linewidth = 1 )
7979 ax .set_ylim (y_min , y_max )
8080
81- sequence = lattice . sequence
82- position = start = end = 0
83- sign = 1
84- for element , next_element in zip_longest ( sequence , sequence [ 1 :]):
85- position += element .length
86- if element is next_element or position <= x_min :
81+ sign = - 1
82+ start = end = 0
83+ for element , group in groupby ( lattice . sequence ):
84+ start = end
85+ end += element .length * sum ( 1 for _ in group )
86+ if end <= x_min :
8787 continue
8888 elif start >= x_max :
8989 break
9090
91- start , end = end , position
9291 try :
9392 color = ELEMENT_COLOR [type (element )]
9493 except KeyError :
@@ -109,11 +108,10 @@ def draw_elements(
109108 )
110109 )
111110 if labels and type (element ) in {Dipole , Quadrupole }:
112- # sign = (isinstance(element, Quadrupole) << 1) - 1
113111 sign = - sign
114112 ax .annotate (
115113 element .name ,
116- xy = ((start + end ) / 2 , y0 - sign * rect_height ),
114+ xy = ((start + end ) / 2 , y0 + sign * rect_height ),
117115 fontsize = FONT_SIZE ,
118116 ha = "center" ,
119117 va = "center" ,
@@ -139,7 +137,7 @@ def draw_sub_lattices(
139137 # if len(ticks) < 5:
140138 # ax.xaxis.set_minor_locator(AutoMinorLocator())
141139 # ax.xaxis.set_minor_formatter(ScalarFormatter())
142- ax .grid (axis = "x" , linestyle = "--" )
140+ ax .grid (axis = "x" , color = Color . LIGHT_GRAY , linestyle = "--" , linewidth = 1 )
143141
144142 if labels :
145143 y_min , y_max = ax .get_ylim ()
@@ -406,10 +404,10 @@ def floor_plan(
406404 end = np .zeros (2 )
407405 x_min = y_min = 0
408406 x_max = y_max = 0
409- sequence = lattice .sequence
410407 sign = 1
411- for element , next_element in zip_longest (sequence , sequence [1 :]):
412- length = element .length
408+ for element , group in groupby (lattice .sequence ):
409+ start = end .copy ()
410+ length = element .length * sum (1 for _ in group )
413411 if isinstance (element , Drift ):
414412 color = Color .BLACK
415413 line_width = 1
@@ -420,7 +418,7 @@ def floor_plan(
420418 # TODO: refactor current angle
421419 angle = 0
422420 if isinstance (element , Dipole ):
423- angle = element .angle
421+ angle = element .k0 * length
424422 radius = length / angle
425423 vec = radius * np .array ([np .sin (angle ), 1 - np .cos (angle )])
426424 sin = np .sin (current_angle )
@@ -462,8 +460,6 @@ def floor_plan(
462460 y_max = max (y_max , end [1 ])
463461
464462 ax .add_patch (line ) # TODO: currently splitted elements get drawn twice
465- if element is next_element :
466- continue
467463
468464 if labels and isinstance (element , (Dipole , Quadrupole )):
469465 angle_center = (current_angle - angle / 2 ) + np .pi / 2
@@ -482,8 +478,6 @@ def floor_plan(
482478 zorder = 11 ,
483479 )
484480
485- start = end .copy ()
486-
487481 margin = 0.01 * max ((x_max - x_min ), (y_max - y_min ))
488482 ax .set_xlim (x_min - margin , x_max + margin )
489483 ax .set_ylim (y_min - margin , y_max + margin )
0 commit comments