Skip to content

Commit 6200bf8

Browse files
authored
Fixes kwarg validation bugs in draw() reported in #730 and #731. (#732)
* fix: draw kwarg issues #730 #731 * tests: improved * docs: added note * fix: bug that was silently ignored in tuto
1 parent 6223e7e commit 6200bf8

3 files changed

Lines changed: 196 additions & 122 deletions

File tree

tests/drawing/test_draw.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,105 @@ def test_issue_515(edgelist8):
768768
xgi.draw_multilayer(H, node_fc=["black"] * H.num_nodes)
769769

770770
plt.close("all")
771+
772+
773+
def test_draw_kwargs_validation(edgelist8):
774+
"""Regression tests for kwarg validation bugs fixed in issues #730 and #731."""
775+
H = xgi.Hypergraph(edgelist8)
776+
S = xgi.SimplicialComplex(edgelist8)
777+
778+
# issue #731: typo is caught even when labels are off (the default)
779+
with pytest.raises(TypeError):
780+
xgi.draw(H, font_siz_nodes=20)
781+
plt.close("all")
782+
783+
# issue #731: same typo is caught consistently when labels are on
784+
with pytest.raises(TypeError):
785+
xgi.draw(H, node_labels=True, font_siz_nodes=20)
786+
plt.close("all")
787+
788+
# issue #730 bug A: a valid edge-label kwarg must not be rejected when
789+
# node_labels=True (previously the node-label validator ran first and
790+
# rejected kwargs intended for draw_hyperedge_labels)
791+
xgi.draw(H, node_labels=True, hyperedge_labels=True, font_size_edges=14)
792+
plt.close("all")
793+
794+
# issue #730 bug A: same for SimplicialComplex
795+
xgi.draw(S, node_labels=True, hyperedge_labels=True, font_size_edges=14)
796+
plt.close("all")
797+
798+
# issue #730 bug B: a settings kwarg must not be rejected when node_labels=True
799+
# (previously min_node_size was not in draw_node_labels' signature and raised)
800+
xgi.draw(H, node_labels=True, min_node_size=3)
801+
plt.close("all")
802+
803+
# issue #730 bug C: unknown kwargs raise TypeError, not ValueError
804+
with pytest.raises(TypeError):
805+
xgi.draw(H, not_a_real_kwarg=99)
806+
plt.close("all")
807+
808+
# valid label kwargs are accepted even when labels are off — draw() validates
809+
# against all known keys unconditionally, so users can pass label kwargs
810+
# without having to enable labels first
811+
xgi.draw(H, node_labels=False, font_size_nodes=14)
812+
plt.close("all")
813+
xgi.draw(H, hyperedge_labels=False, font_size_edges=14)
814+
plt.close("all")
815+
816+
# mixing all three kwarg buckets at once works without error
817+
xgi.draw(
818+
H,
819+
node_labels=True,
820+
hyperedge_labels=True,
821+
min_node_size=3,
822+
font_size_nodes=12,
823+
font_size_edges=10,
824+
)
825+
plt.close("all")
826+
827+
828+
def test_draw_kwargs_forwarding(edgelist8):
829+
"""Check that label kwargs are actually applied to the rendered text objects."""
830+
H = xgi.Hypergraph(edgelist8)
831+
832+
# node-label font size and color are forwarded to the Text objects on the axes
833+
fig, ax = plt.subplots()
834+
xgi.draw(H, ax=ax, node_labels=True, font_size_nodes=18, font_color_nodes="red")
835+
for text in ax.texts:
836+
assert text.get_fontsize() == 18
837+
assert text.get_color() == "red"
838+
plt.close("all")
839+
840+
# edge-label font size and color are forwarded to the Text objects on the axes
841+
fig, ax = plt.subplots()
842+
xgi.draw(H, ax=ax, hyperedge_labels=True, font_size_edges=14, font_color_edges="blue")
843+
for text in ax.texts:
844+
assert text.get_fontsize() == 14
845+
assert text.get_color() == "blue"
846+
plt.close("all")
847+
848+
# when both label types are drawn, each uses its own kwargs
849+
fig, ax = plt.subplots()
850+
xgi.draw(
851+
H,
852+
ax=ax,
853+
node_labels=True,
854+
hyperedge_labels=True,
855+
font_size_nodes=16,
856+
font_size_edges=10,
857+
)
858+
font_sizes = {text.get_fontsize() for text in ax.texts}
859+
assert 16 in font_sizes
860+
assert 10 in font_sizes
861+
plt.close("all")
862+
863+
# when labels are off, no text is rendered regardless of label kwargs passed
864+
fig, ax = plt.subplots()
865+
xgi.draw(H, ax=ax, node_labels=False, font_size_nodes=18, font_color_nodes="red")
866+
assert len(ax.texts) == 0
867+
plt.close("all")
868+
869+
fig, ax = plt.subplots()
870+
xgi.draw(H, ax=ax, hyperedge_labels=False, font_size_edges=14, font_color_edges="blue")
871+
assert len(ax.texts) == 0
872+
plt.close("all")

tutorials/getting_started/quickstart.ipynb

Lines changed: 38 additions & 97 deletions
Large diffs are not rendered by default.

xgi/drawing/draw.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def draw(
218218
If those are single values, `interpolate_sizes` is ignored
219219
for it. By default, True.
220220
**kwargs : optional args
221-
Alternate default values. Values that can be overwritten are the following:
221+
Accepts three non-overlapping groups of keyword arguments:
222+
223+
Size-rescaling settings (see also ``rescale_sizes``):
222224
223225
* "min_node_size" (default: 5)
224226
* "max_node_size" (default: 30)
@@ -227,6 +229,15 @@ def draw(
227229
* "min_dyad_lw" (default: 1)
228230
* "max_dyad_lw" (default: 10)
229231
232+
Node-label styling (only used when ``node_labels`` is enabled);
233+
see :func:`draw_node_labels` for the full list.
234+
235+
Hyperedge-label styling (only used when ``hyperedge_labels`` is enabled);
236+
see :func:`draw_hyperedge_labels` for the full list.
237+
238+
Passing an argument that does not belong to any of these three groups
239+
raises a ``TypeError``.
240+
230241
Returns
231242
-------
232243
ax : matplotlib Axes
@@ -247,6 +258,12 @@ def draw(
247258
>>> H.add_edges_from([[1,2,3],[3,4],[4,5,6,7],[7,8,9,10,11]])
248259
>>> ax = xgi.draw(H, pos=xgi.barycenter_spring_layout(H))
249260
261+
Notes
262+
-----
263+
For finer control over node or hyperedge label appearance, call
264+
:func:`draw_node_labels` or :func:`draw_hyperedge_labels` directly after
265+
``draw()``.
266+
250267
See Also
251268
--------
252269
draw_nodes
@@ -266,7 +283,29 @@ def draw(
266283
"max_node_lw": 5,
267284
}
268285

269-
settings.update(kwargs)
286+
# Split **kwargs into three non-overlapping buckets and validate upfront.
287+
# Doing this here — unconditionally and against all valid keys at once —
288+
# catches typos even when labels are off, and prevents cross-bucket confusion
289+
# (e.g. a valid edge-label kwarg being rejected by node-label validation).
290+
_settings_keys = set(settings)
291+
_node_label_keys = (
292+
signature(draw_node_labels).parameters.keys()
293+
- {"H", "pos", "ax_nodes", "node_labels"}
294+
)
295+
_edge_label_keys = (
296+
signature(draw_hyperedge_labels).parameters.keys()
297+
- {"H", "pos", "ax_edges", "hyperedge_labels"}
298+
)
299+
unknown = set(kwargs) - (_settings_keys | _node_label_keys | _edge_label_keys)
300+
if unknown:
301+
raise TypeError(
302+
f"draw() got unexpected keyword argument(s): {', '.join(sorted(unknown))}"
303+
)
304+
settings_kwargs = {k: v for k, v in kwargs.items() if k in _settings_keys}
305+
node_label_kwargs = {k: v for k, v in kwargs.items() if k in _node_label_keys}
306+
edge_label_kwargs = {k: v for k, v in kwargs.items() if k in _edge_label_keys}
307+
308+
settings.update(settings_kwargs)
270309

271310
ax, pos = _draw_init(H, ax, pos)
272311

@@ -293,7 +332,7 @@ def draw(
293332
max_order=max_order,
294333
hyperedge_labels=hyperedge_labels,
295334
rescale_sizes=rescale_sizes,
296-
**kwargs,
335+
**edge_label_kwargs,
297336
)
298337

299338
elif isinstance(H, Hypergraph):
@@ -319,7 +358,7 @@ def draw(
319358
hull=hull,
320359
radius=radius,
321360
rescale_sizes=rescale_sizes,
322-
**kwargs,
361+
**edge_label_kwargs,
323362
)
324363
else:
325364
raise XGIError("The input must be a SimplicialComplex or Hypergraph")
@@ -340,7 +379,7 @@ def draw(
340379
params=settings,
341380
node_labels=node_labels,
342381
rescale_sizes=rescale_sizes,
343-
**kwargs,
382+
**node_label_kwargs,
344383
)
345384

346385
# compute axis limits
@@ -467,7 +506,9 @@ def draw_nodes(
467506
}
468507

469508
settings.update(params)
470-
settings.update(kwargs)
509+
# Only absorb recognised settings keys from kwargs so that label kwargs
510+
# passed via a direct call to draw_nodes() don't pollute the settings dict.
511+
settings.update({k: v for k, v in kwargs.items() if k in settings})
471512

472513
ax, pos = _draw_init(H, ax, pos)
473514

@@ -523,15 +564,9 @@ def draw_nodes(
523564
)
524565

525566
if node_labels:
526-
# Get all valid keywords by inspecting the signatures of draw_node_labels
527-
valid_label_kwds = signature(draw_node_labels).parameters.keys()
528-
# Remove the arguments of this function (draw_networkx)
529-
valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "node_labels"}
530-
if any([k not in valid_label_kwds for k in kwargs]):
531-
invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds])
532-
raise ValueError(f"Received invalid argument(s): {invalid_args}")
533-
label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds}
534-
draw_node_labels(H, pos, node_labels, ax_nodes=ax, **label_kwds)
567+
# kwargs here are already validated and filtered to node-label keys by
568+
# draw() when called from there; for direct calls they are passed as-is.
569+
draw_node_labels(H, pos, node_labels, ax_nodes=ax, **kwargs)
535570

536571
# compute axis limits
537572
_update_lims(pos, ax)
@@ -694,7 +729,9 @@ def draw_hyperedges(
694729
}
695730

696731
settings.update(params)
697-
settings.update(kwargs)
732+
# Only absorb recognised settings keys from kwargs so that label kwargs
733+
# passed via a direct call to draw_hyperedges() don't pollute the settings dict.
734+
settings.update({k: v for k, v in kwargs.items() if k in settings})
698735

699736
ax, pos = _draw_init(H, ax, pos)
700737

@@ -819,15 +856,9 @@ def draw_hyperedges(
819856
ax.add_collection(edge_collection)
820857

821858
if hyperedge_labels:
822-
# Get all valid keywords by inspecting the signatures of draw_node_labels
823-
valid_label_kwds = signature(draw_hyperedge_labels).parameters.keys()
824-
# Remove the arguments of this function (draw_networkx)
825-
valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "hyperedge_labels"}
826-
if any([k not in valid_label_kwds for k in kwargs]):
827-
invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds])
828-
raise ValueError(f"Received invalid argument(s): {invalid_args}")
829-
label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds}
830-
draw_hyperedge_labels(H, pos, hyperedge_labels, ax_edges=ax, **label_kwds)
859+
# kwargs here are already validated and filtered to edge-label keys by
860+
# draw() when called from there; for direct calls they are passed as-is.
861+
draw_hyperedge_labels(H, pos, hyperedge_labels, ax_edges=ax, **kwargs)
831862

832863
# compute axis limits
833864
_update_lims(pos, ax)

0 commit comments

Comments
 (0)