Skip to content

Commit 06e91e7

Browse files
committed
fix: draw kwarg issues #730 #731
1 parent 6223e7e commit 06e91e7

2 files changed

Lines changed: 105 additions & 25 deletions

File tree

tests/drawing/test_draw.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,58 @@ def test_issue_515(edgelist8):
768768
xgi.draw_multilayer(H, node_fc=["black"] * H.num_nodes)
769769

770770
plt.close("all")
771+
772+
773+
def test_draw_kwargs_validation(edgelist8):
774+
"""Regression tests for kwarg validation bugs fixed in issues #730 and #731."""
775+
H = xgi.Hypergraph(edgelist8)
776+
S = xgi.SimplicialComplex(edgelist8)
777+
778+
# issue #731: typo is caught even when labels are off (the default)
779+
with pytest.raises(TypeError):
780+
xgi.draw(H, font_siz_nodes=20)
781+
plt.close("all")
782+
783+
# issue #731: same typo is caught consistently when labels are on
784+
with pytest.raises(TypeError):
785+
xgi.draw(H, node_labels=True, font_siz_nodes=20)
786+
plt.close("all")
787+
788+
# issue #730 bug A: a valid edge-label kwarg must not be rejected when
789+
# node_labels=True (previously the node-label validator ran first and
790+
# rejected kwargs intended for draw_hyperedge_labels)
791+
xgi.draw(H, node_labels=True, hyperedge_labels=True, font_size_edges=14)
792+
plt.close("all")
793+
794+
# issue #730 bug A: same for SimplicialComplex
795+
xgi.draw(S, node_labels=True, hyperedge_labels=True, font_size_edges=14)
796+
plt.close("all")
797+
798+
# issue #730 bug B: a settings kwarg must not be rejected when node_labels=True
799+
# (previously min_node_size was not in draw_node_labels' signature and raised)
800+
xgi.draw(H, node_labels=True, min_node_size=3)
801+
plt.close("all")
802+
803+
# issue #730 bug C: unknown kwargs raise TypeError, not ValueError
804+
with pytest.raises(TypeError):
805+
xgi.draw(H, not_a_real_kwarg=99)
806+
plt.close("all")
807+
808+
# valid node-label kwargs are accepted and forwarded correctly
809+
xgi.draw(H, node_labels=True, font_size_nodes=14, font_color_nodes="red")
810+
plt.close("all")
811+
812+
# valid edge-label kwargs are accepted and forwarded correctly
813+
xgi.draw(H, hyperedge_labels=True, font_size_edges=14, font_color_edges="blue")
814+
plt.close("all")
815+
816+
# mixing all three kwarg buckets at once works without error
817+
xgi.draw(
818+
H,
819+
node_labels=True,
820+
hyperedge_labels=True,
821+
min_node_size=3,
822+
font_size_nodes=12,
823+
font_size_edges=10,
824+
)
825+
plt.close("all")

xgi/drawing/draw.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def draw(
218218
If those are single values, `interpolate_sizes` is ignored
219219
for it. By default, True.
220220
**kwargs : optional args
221-
Alternate default values. Values that can be overwritten are the following:
221+
Accepts three non-overlapping groups of keyword arguments:
222+
223+
Size-rescaling settings (see also ``rescale_sizes``):
222224
223225
* "min_node_size" (default: 5)
224226
* "max_node_size" (default: 30)
@@ -227,6 +229,15 @@ def draw(
227229
* "min_dyad_lw" (default: 1)
228230
* "max_dyad_lw" (default: 10)
229231
232+
Node-label styling (only used when ``node_labels`` is enabled);
233+
see :func:`draw_node_labels` for the full list.
234+
235+
Hyperedge-label styling (only used when ``hyperedge_labels`` is enabled);
236+
see :func:`draw_hyperedge_labels` for the full list.
237+
238+
Passing an argument that does not belong to any of these three groups
239+
raises a ``TypeError``.
240+
230241
Returns
231242
-------
232243
ax : matplotlib Axes
@@ -266,7 +277,29 @@ def draw(
266277
"max_node_lw": 5,
267278
}
268279

269-
settings.update(kwargs)
280+
# Split **kwargs into three non-overlapping buckets and validate upfront.
281+
# Doing this here — unconditionally and against all valid keys at once —
282+
# catches typos even when labels are off, and prevents cross-bucket confusion
283+
# (e.g. a valid edge-label kwarg being rejected by node-label validation).
284+
_settings_keys = set(settings)
285+
_node_label_keys = (
286+
signature(draw_node_labels).parameters.keys()
287+
- {"H", "pos", "ax_nodes", "node_labels"}
288+
)
289+
_edge_label_keys = (
290+
signature(draw_hyperedge_labels).parameters.keys()
291+
- {"H", "pos", "ax_edges", "hyperedge_labels"}
292+
)
293+
unknown = set(kwargs) - (_settings_keys | _node_label_keys | _edge_label_keys)
294+
if unknown:
295+
raise TypeError(
296+
f"draw() got unexpected keyword argument(s): {', '.join(sorted(unknown))}"
297+
)
298+
settings_kwargs = {k: v for k, v in kwargs.items() if k in _settings_keys}
299+
node_label_kwargs = {k: v for k, v in kwargs.items() if k in _node_label_keys}
300+
edge_label_kwargs = {k: v for k, v in kwargs.items() if k in _edge_label_keys}
301+
302+
settings.update(settings_kwargs)
270303

271304
ax, pos = _draw_init(H, ax, pos)
272305

@@ -293,7 +326,7 @@ def draw(
293326
max_order=max_order,
294327
hyperedge_labels=hyperedge_labels,
295328
rescale_sizes=rescale_sizes,
296-
**kwargs,
329+
**edge_label_kwargs,
297330
)
298331

299332
elif isinstance(H, Hypergraph):
@@ -319,7 +352,7 @@ def draw(
319352
hull=hull,
320353
radius=radius,
321354
rescale_sizes=rescale_sizes,
322-
**kwargs,
355+
**edge_label_kwargs,
323356
)
324357
else:
325358
raise XGIError("The input must be a SimplicialComplex or Hypergraph")
@@ -340,7 +373,7 @@ def draw(
340373
params=settings,
341374
node_labels=node_labels,
342375
rescale_sizes=rescale_sizes,
343-
**kwargs,
376+
**node_label_kwargs,
344377
)
345378

346379
# compute axis limits
@@ -467,7 +500,9 @@ def draw_nodes(
467500
}
468501

469502
settings.update(params)
470-
settings.update(kwargs)
503+
# Only absorb recognised settings keys from kwargs so that label kwargs
504+
# passed via a direct call to draw_nodes() don't pollute the settings dict.
505+
settings.update({k: v for k, v in kwargs.items() if k in settings})
471506

472507
ax, pos = _draw_init(H, ax, pos)
473508

@@ -523,15 +558,9 @@ def draw_nodes(
523558
)
524559

525560
if node_labels:
526-
# Get all valid keywords by inspecting the signatures of draw_node_labels
527-
valid_label_kwds = signature(draw_node_labels).parameters.keys()
528-
# Remove the arguments of this function (draw_networkx)
529-
valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "node_labels"}
530-
if any([k not in valid_label_kwds for k in kwargs]):
531-
invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds])
532-
raise ValueError(f"Received invalid argument(s): {invalid_args}")
533-
label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds}
534-
draw_node_labels(H, pos, node_labels, ax_nodes=ax, **label_kwds)
561+
# kwargs here are already validated and filtered to node-label keys by
562+
# draw() when called from there; for direct calls they are passed as-is.
563+
draw_node_labels(H, pos, node_labels, ax_nodes=ax, **kwargs)
535564

536565
# compute axis limits
537566
_update_lims(pos, ax)
@@ -694,7 +723,9 @@ def draw_hyperedges(
694723
}
695724

696725
settings.update(params)
697-
settings.update(kwargs)
726+
# Only absorb recognised settings keys from kwargs so that label kwargs
727+
# passed via a direct call to draw_hyperedges() don't pollute the settings dict.
728+
settings.update({k: v for k, v in kwargs.items() if k in settings})
698729

699730
ax, pos = _draw_init(H, ax, pos)
700731

@@ -819,15 +850,9 @@ def draw_hyperedges(
819850
ax.add_collection(edge_collection)
820851

821852
if hyperedge_labels:
822-
# Get all valid keywords by inspecting the signatures of draw_node_labels
823-
valid_label_kwds = signature(draw_hyperedge_labels).parameters.keys()
824-
# Remove the arguments of this function (draw_networkx)
825-
valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "hyperedge_labels"}
826-
if any([k not in valid_label_kwds for k in kwargs]):
827-
invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds])
828-
raise ValueError(f"Received invalid argument(s): {invalid_args}")
829-
label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds}
830-
draw_hyperedge_labels(H, pos, hyperedge_labels, ax_edges=ax, **label_kwds)
853+
# kwargs here are already validated and filtered to edge-label keys by
854+
# draw() when called from there; for direct calls they are passed as-is.
855+
draw_hyperedge_labels(H, pos, hyperedge_labels, ax_edges=ax, **kwargs)
831856

832857
# compute axis limits
833858
_update_lims(pos, ax)

0 commit comments

Comments
 (0)