33import numpy as np
44import matplotlib as mpl
55from matplotlib .patches import PathPatch
6+ from matplotlib .transforms import Affine2D
67
78from .common import _compute_loops_per_angle
89from .undirected import UndirectedEdgeCollection
@@ -29,13 +30,17 @@ class DirectedEdgeCollection(mpl.artist.Artist):
2930 def __init__ (self , edges , arrows , ** kwargs ):
3031 super ().__init__ ()
3132
32- kwargs_arrows = {}
33- if "color" in kwargs .get ("style" , {}):
34- kwargs_arrows ["color" ] = kwargs ["style" ]["color" ]
35-
3633 # FIXME: do we need a separate _clear_state and _process like in the network
3734 self ._edges = UndirectedEdgeCollection (edges , ** kwargs )
38- self ._arrows = EdgeArrowCollection (arrows , ** kwargs_arrows )
35+
36+ # NOTE: offsets are a placeholder for later
37+ self ._arrows = EdgeArrowCollection (
38+ arrows ,
39+ offsets = np .zeros ((len (arrows ), 2 )),
40+ offset_transform = kwargs ["transform" ],
41+ transform = Affine2D (),
42+ match_original = True ,
43+ )
3944 self ._processed = False
4045
4146 def get_children (self ):
@@ -81,6 +86,29 @@ def _process(self):
8186
8287 self ._processed = True
8388
89+ def _set_edge_info_for_arrows (self , which = "end" , transform = None ):
90+ """Extract the start and/or end angles of the paths to compute arrows."""
91+ if transform is None :
92+ transform = self .get_transform ()
93+ trans = transform .transform
94+ trans_inv = transform .inverted ().transform
95+
96+ arrow_offsets = self ._arrows ._offsets
97+ for i , epath in enumerate (self ._edges ._paths ):
98+ # Offset the arrow to point to the end of the edge
99+ self ._arrows ._offsets [i ] = epath .vertices [- 1 ]
100+
101+ # Rotate the arrow to point in the direction of the edge
102+ apath = self ._arrows ._paths [i ]
103+ # NOTE: because the tip of the arrow is at (0, 0) in patch space,
104+ # in theory it will rotate around that point already
105+ v2 = trans (epath .vertices [- 1 ])
106+ v1 = trans (epath .vertices [- 2 ])
107+ dv = v2 - v1
108+ theta = atan2 (* (dv [::- 1 ]))
109+ mrot = np .array ([[cos (theta ), - sin (theta )], [sin (theta ), cos (theta )]])
110+ apath .vertices = apath .vertices @ mrot
111+
84112 @_stale_wrapper
85113 def draw (self , renderer , * args , ** kwds ):
86114 """Draw each of the children, with some buffering mechanism."""
@@ -90,12 +118,11 @@ def draw(self, renderer, *args, **kwds):
90118 if not self ._processed :
91119 self ._process ()
92120
93- # NOTE: looks like we have to manage the zorder ourselves
94- # this is kind of funny actually
95- children = list (self .get_children ())
96- children .sort (key = lambda x : x .zorder )
97- for art in children :
98- art .draw (renderer , * args , ** kwds )
121+ # We should manage zorder ourselves, but we need to compute
122+ # the new offsets and angles of arrows from the edges before drawing them
123+ self ._edges .draw (renderer , * args , ** kwds )
124+ self ._set_edge_info_for_arrows (which = "end" )
125+ self ._arrows .draw (renderer , * args , ** kwds )
99126
100127
101128class EdgeArrowCollection (mpl .collections .PatchCollection ):
@@ -115,9 +142,9 @@ def stale(self, val):
115142 self .stale_callback_post (self )
116143
117144
118- def make_arrow_patch (marker : str = "|>" , width : float = 3 , ** kwargs ):
145+ def make_arrow_patch (marker : str = "|>" , width : float = 8 , ** kwargs ):
119146 """Make a patch of the given marker shape and size."""
120- height = kwargs .pop ("height" , width * 1.5 )
147+ height = kwargs .pop ("height" , width * 1.3 )
121148
122149 if marker == "|>" :
123150 codes = ["MOVETO" , "LINETO" , "LINETO" ]
0 commit comments