Skip to content

Commit 2b88525

Browse files
committed
Edge labels
1 parent b970c49 commit 2b88525

File tree

8 files changed

+222
-39
lines changed

8 files changed

+222
-39
lines changed

iplotx/edge/label.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,26 @@
88
)
99

1010

11-
@_forwarder(
12-
(
13-
"set_clip_path",
14-
"set_clip_box",
15-
"set_transform",
16-
"set_snap",
17-
"set_sketch_params",
18-
"set_figure",
19-
"set_animated",
20-
"set_picker",
21-
)
22-
)
2311
class LabelCollection(mpl.artist.Artist):
24-
def __init__(self, labels, offsets=None, style=None):
25-
self._create_labels(labels, offsets, style)
12+
def __init__(self, labels, style=None):
13+
self._labels = labels
14+
self._style = style
15+
super().__init__()
2616

27-
def _create_labels(self, labels, offsets, style):
28-
if offsets is None:
29-
offsets = np.zeros((len(labels), 2))
30-
if style is None:
31-
style = {}
17+
def _create_labels(self):
18+
style = self._style if self._style is not None else {}
3219

3320
arts = []
34-
for label, offset in zip(labels, offsets):
21+
for label in self._labels:
3522
art = mpl.text.Text(
36-
offset[0],
37-
offset[1],
23+
0,
24+
0,
3825
label,
3926
transform=self.axes.transData,
4027
**style,
4128
)
29+
art.set_figure(self.figure)
30+
art.axes = self.axes
4231
arts.append(art)
4332
self._labels = arts
4433

@@ -49,10 +38,6 @@ def set_offsets(self, offsets):
4938
for art, offset in zip(self._labels, offsets):
5039
art.set_position((offset[0], offset[1]))
5140

52-
@property
53-
def stale(self):
54-
return super().stale
55-
5641
@_stale_wrapper
5742
def draw(self, renderer, *args, **kwds):
5843
"""Draw each of the children, with some buffering mechanism."""
@@ -62,4 +47,4 @@ def draw(self, renderer, *args, **kwds):
6247
# We should manage zorder ourselves, but we need to compute
6348
# the new offsets and angles of arrows from the edges before drawing them
6449
for art in self.get_children():
65-
art.draw(renderer, *args, **kwargs)
50+
art.draw(renderer, *args, **kwds)

iplotx/edge/undirected.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import matplotlib as mpl
55

66
from .common import _compute_loops_per_angle
7+
from .label import LabelCollection
8+
from ..tools.matplotlib import (
9+
_compute_mid_coord,
10+
_stale_wrapper,
11+
)
712

813

914
class UndirectedEdgeCollection(mpl.collections.PatchCollection):
@@ -325,28 +330,50 @@ def _shorten_path_undirected_curved(
325330
return path
326331

327332
def _compute_labels(self):
333+
style = self._style.get("label", None) if self._style is not None else None
328334
offsets = []
329335
for path in self._paths:
330336
offset = _compute_mid_coord(path)
331337
offsets.append(offset)
332338

333339
if not hasattr(self, "_label_collection"):
334-
self._label_collection = LabelCollection(self._labels, offsets=offsets)
335-
else:
336-
self._label_collection.set_offsets(offsets)
340+
self._label_collection = LabelCollection(
341+
self._labels,
342+
style=style,
343+
)
344+
345+
# Forward a bunch of mpl settings that are needed
346+
self._label_collection.set_figure(self.figure)
347+
self._label_collection.axes = self.axes
348+
# forward the clippath/box to the children need this logic
349+
# because mpl exposes some fast-path logic
350+
clip_path = self.get_clip_path()
351+
if clip_path is None:
352+
clip_box = self.get_clip_box()
353+
self._label_collection.set_clip_box(clip_box)
354+
else:
355+
self._label_collection.set_clip_path(clip_path)
356+
357+
# Finally make the patches
358+
self._label_collection._create_labels()
359+
self._label_collection.set_offsets(offsets)
337360

338361
def get_children(self):
339362
children = []
340363
if hasattr(self, "_label_collection"):
341364
children.append(self._label_collection)
342365
return children
343366

344-
def draw(self, renderer):
367+
@_stale_wrapper
368+
def draw(self, renderer, *args, **kwds):
345369
if self._vertex_paths is not None:
346370
self._paths = self._compute_paths()
347371
if self._labels is not None:
348372
self._compute_labels()
349-
return super().draw(renderer)
373+
super().draw(renderer)
374+
375+
for child in self.get_children():
376+
child.draw(renderer, *args, **kwds)
350377

351378
@property
352379
def stale(self):

iplotx/matplotlib.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from functools import wraps, partial
2+
import matplotlib as mpl
3+
4+
5+
# NOTE: https://github.com/networkx/grave/blob/main/grave/grave.py
6+
def _stale_wrapper(func):
7+
"""Decorator to manage artist state."""
8+
9+
@wraps(func)
10+
def inner(self, *args, **kwargs):
11+
try:
12+
func(self, *args, **kwargs)
13+
finally:
14+
self.stale = False
15+
16+
return inner
17+
18+
19+
def _forwarder(forwards, cls=None):
20+
"""Decorator to forward specific methods to Artist children."""
21+
if cls is None:
22+
return partial(_forwarder, forwards)
23+
24+
def make_forward(name):
25+
def method(self, *args, **kwargs):
26+
ret = getattr(cls.mro()[1], name)(self, *args, **kwargs)
27+
for c in self.get_children():
28+
getattr(c, name)(*args, **kwargs)
29+
return ret
30+
31+
return method
32+
33+
for f in forwards:
34+
method = make_forward(f)
35+
method.__name__ = f
36+
method.__doc__ = "broadcasts {} to children".format(f)
37+
setattr(cls, f, method)
38+
39+
return cls
40+
41+
42+
def _additional_set_methods(attributes, cls=None):
43+
"""Decorator to add specific set methods for children properties."""
44+
if cls is None:
45+
return partial(_additional_set_methods, attributes)
46+
47+
def make_setter(name):
48+
def method(self, value):
49+
self.set(**{name: value})
50+
51+
return method
52+
53+
for attr in attributes:
54+
desc = attr.replace("_", " ")
55+
method = make_setter(attr)
56+
method.__name__ = f"set_{attr}"
57+
method.__doc__ = f"Set {desc}."
58+
setattr(cls, f"set_{attr}", method)
59+
60+
return cls
61+
62+
63+
# FIXME: this method appears quite inconsistent, would be better to improve.
64+
# The issue is that to really know the size of a label on screen, we need to
65+
# render it first. Therefore, we should render the labels, then render the
66+
# vertices. Leaving for now, since this can be styled manually which covers
67+
# many use cases.
68+
def _get_label_width_height(text, hpadding=18, vpadding=12, **kwargs):
69+
"""Get the bounding box size for a text with certain properties."""
70+
forbidden_props = ["horizontalalignment", "verticalalignment", "ha", "va"]
71+
for prop in forbidden_props:
72+
if prop in kwargs:
73+
del kwargs[prop]
74+
75+
path = mpl.textpath.TextPath((0, 0), text, **kwargs)
76+
boundingbox = path.get_extents()
77+
width = boundingbox.width + hpadding
78+
height = boundingbox.height + vpadding
79+
return (width, height)
80+
81+
82+
def _compute_mid_coord(path):
83+
if (path.codes[-1] in (mpl.path.CURVE4, mpl.CURVE3)):
84+

iplotx/network.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Union, Sequence
22
import numpy as np
33
import pandas as pd
44
import matplotlib as mpl
@@ -55,6 +55,7 @@ def __init__(
5555
network: GraphType,
5656
layout: LayoutType = None,
5757
vertex_labels: Union[None, list, dict, pd.Series] = None,
58+
edge_labels: Union[None, Sequence] = None,
5859
):
5960
"""Network container artist that groups all plotting elements.
6061
@@ -66,12 +67,16 @@ def __init__(
6667
vertex_labels (list, dict, or pandas.Series): The labels for the vertices. If None, no vertex labels
6768
will be drawn. If a list, the labels are taken from the list. If a dict, the keys
6869
should be the vertex IDs and the values should be the labels.
70+
elge_labels (sequence): The labels for the edges. If None, no edge labels will be drawn.
6971
"""
7072
super().__init__()
7173

7274
self.network = network
7375
self._ipx_internal_data = _create_internal_data(
74-
network, layout, vertex_labels=vertex_labels
76+
network,
77+
layout,
78+
vertex_labels=vertex_labels,
79+
edge_labels=edge_labels,
7580
)
7681
self._clear_state()
7782

@@ -410,7 +415,12 @@ def draw(self, renderer, *args, **kwds):
410415

411416

412417
# INTERNAL ROUTINES
413-
def _create_internal_data(network, layout=None, vertex_labels=None):
418+
def _create_internal_data(
419+
network,
420+
layout=None,
421+
vertex_labels=None,
422+
edge_labels=None,
423+
):
414424
"""Create internal data for the network."""
415425
nl = network_library(network)
416426
directed = detect_directedness(network)
@@ -427,6 +437,10 @@ def _create_internal_data(network, layout=None, vertex_labels=None):
427437

428438
# Vertex labels
429439
if vertex_labels is not None:
440+
if len(vertex_labels) != len(vertex_df):
441+
raise ValueError(
442+
"Vertex labels must be the same length as the number of vertices."
443+
)
430444
vertex_df["label"] = vertex_labels
431445

432446
# Edges are a list of tuples, because of multiedges
@@ -438,6 +452,14 @@ def _create_internal_data(network, layout=None, vertex_labels=None):
438452
edge_df = pd.DataFrame(tmp)
439453
del tmp
440454

455+
# Edge labels
456+
if edge_labels is not None:
457+
if len(edge_labels) != len(edge_df):
458+
raise ValueError(
459+
"Edge labels must be the same length as the number of edges."
460+
)
461+
edge_df["labels"] = edge_labels
462+
441463
else:
442464
# Vertices are ordered integers, no gaps
443465
layout = normalise_layout(layout)
@@ -450,6 +472,10 @@ def _create_internal_data(network, layout=None, vertex_labels=None):
450472

451473
# Vertex labels
452474
if vertex_labels is not None:
475+
if len(vertex_labels) != len(vertex_df):
476+
raise ValueError(
477+
"Vertex labels must be the same length as the number of vertices."
478+
)
453479
vertex_df["label"] = vertex_labels
454480

455481
# Edges are a list of tuples, because of multiedges
@@ -461,6 +487,14 @@ def _create_internal_data(network, layout=None, vertex_labels=None):
461487
edge_df = pd.DataFrame(tmp)
462488
del tmp
463489

490+
# Edge labels
491+
if edge_labels is not None:
492+
if len(edge_labels) != len(edge_df):
493+
raise ValueError(
494+
"Edge labels must be the same length as the number of edges."
495+
)
496+
edge_df["labels"] = edge_labels
497+
464498
internal_data = {
465499
"vertex_df": vertex_df,
466500
"edge_df": edge_df,

iplotx/plotting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Union, Sequence
22
import pandas as pd
33
import matplotlib.pyplot as plt
44

@@ -13,6 +13,7 @@ def plot(
1313
network: GraphType,
1414
layout: Union[LayoutType, None] = None,
1515
vertex_labels: Union[None, list, dict, pd.Series] = None,
16+
edge_labels: Union[None, Sequence] = None,
1617
ax: Union[None, object] = None,
1718
):
1819
"""Plot this network using the specified layout.
@@ -31,7 +32,12 @@ def plot(
3132
if ax is None:
3233
fig, ax = plt.subplots()
3334

34-
nwkart = NetworkArtist(network, layout, vertex_labels=vertex_labels)
35+
nwkart = NetworkArtist(
36+
network,
37+
layout,
38+
vertex_labels=vertex_labels,
39+
edge_labels=edge_labels,
40+
)
3541

3642
ax.add_artist(nwkart)
3743

iplotx/styles.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"color": "black",
2020
"curved": True,
2121
"tension": +1.5,
22+
"label": {
23+
"horizontalalignment": "center",
24+
"verticalalignment": "center",
25+
},
2226
},
2327
"arrow": {
2428
"marker": "|>",
@@ -83,6 +87,7 @@ def use(style: Union[str, dict, Sequence]):
8387
"default" resets the style to the default one. If this is a sequence,
8488
each style is applied in order.
8589
"""
90+
global current
8691

8792
def _update(style: dict, current: dict):
8893
for key, value in style.items():
@@ -107,8 +112,9 @@ def _update(style: dict, current: dict):
107112
reset()
108113
else:
109114
if isinstance(style, str):
110-
style = get_style(style)
111-
_update(style, current)
115+
current = get_style(style)
116+
else:
117+
_update(style, current)
112118

113119

114120
def reset():

iplotx/tools/geometry.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# See also this link for the general answer (using scipy to compute coefficients):
2+
# https://stackoverflow.com/questions/12643079/b%C3%A9zier-curve-fitting-with-scipy
3+
def _evaluate_squared_bezier(points, t):
4+
"""Evaluate a squared Bezier curve at t."""
5+
p0, p1, p2 = points
6+
return (1 - t) ** 2 * p0 + 2 * (1 - t) * t * p1 + t**2 * p2
7+
8+
9+
def _evaluate_cubic_bezier(points, t):
10+
"""Evaluate a cubic Bezier curve at t."""
11+
p0, p1, p2, p3 = points
12+
return (
13+
(1 - t) ** 3 * p0
14+
+ 3 * (1 - t) ** 2 * t * p1
15+
+ 3 * (1 - t) * t**2 * p2
16+
+ t**3 * p3
17+
)

0 commit comments

Comments
 (0)