Skip to content

Commit b22f1f2

Browse files
authored
Merge pull request #724 from xgi-org/fix-layout-rng
fix: accept np.random.Generator as seed in spring layouts (#712)
2 parents bd8a39e + 9d581c7 commit b22f1f2

2 files changed

Lines changed: 68 additions & 4 deletions

File tree

tests/drawing/test_layout.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,58 @@ def test_edge_positions_from_barycenters(edgelist1):
253253
for idx, e in H.edges.members(dtype=dict).items():
254254
mean_pos = np.mean([node_pos[n] for n in e], axis=0)
255255
assert np.allclose(edge_pos[idx], mean_pos)
256+
257+
258+
def test_spring_layouts_accept_rng(edgelist1):
259+
"""Spring layouts should accept np.random.Generator as seed (issue #712)."""
260+
H = xgi.Hypergraph(edgelist1)
261+
rng = np.random.default_rng(42)
262+
263+
# All four spring layouts should accept a Generator without erroring
264+
xgi.pairwise_spring_layout(H, seed=rng)
265+
xgi.barycenter_spring_layout(H, seed=rng)
266+
xgi.weighted_barycenter_spring_layout(H, seed=rng)
267+
xgi.bipartite_spring_layout(H, seed=rng)
268+
269+
# int seed still works
270+
xgi.barycenter_spring_layout(H, seed=42)
271+
# None still works
272+
xgi.barycenter_spring_layout(H, seed=None)
273+
274+
275+
def test_spring_layouts_rng_semantics(edgelist1):
276+
"""Document expected reproducibility semantics for the seed argument.
277+
278+
The conventions mirror sklearn / scipy:
279+
280+
* Reusing the same `Generator` instance advances its state, so two
281+
consecutive calls produce different layouts.
282+
* Two fresh `Generator`s constructed from the same seed produce identical
283+
layouts.
284+
* Passing an int seed reproduces itself, but is not equivalent to passing a
285+
`Generator` constructed from that same int (the conversion path differs).
286+
"""
287+
H = xgi.Hypergraph(edgelist1)
288+
289+
# Reusing one rng across calls: state advances → different outputs
290+
rng = np.random.default_rng(42)
291+
pos_a = xgi.barycenter_spring_layout(H, seed=rng)
292+
pos_b = xgi.barycenter_spring_layout(H, seed=rng)
293+
assert any(not np.allclose(pos_a[n], pos_b[n]) for n in pos_a)
294+
295+
# Two fresh rngs from the same seed → identical outputs
296+
pos_c = xgi.barycenter_spring_layout(H, seed=np.random.default_rng(42))
297+
pos_d = xgi.barycenter_spring_layout(H, seed=np.random.default_rng(42))
298+
for n in pos_c:
299+
assert np.allclose(pos_c[n], pos_d[n])
300+
301+
# Same int seed reused → identical outputs
302+
pos_e = xgi.barycenter_spring_layout(H, seed=42)
303+
pos_f = xgi.barycenter_spring_layout(H, seed=42)
304+
for n in pos_e:
305+
assert np.allclose(pos_e[n], pos_f[n])
306+
307+
# int seed vs rng built from same seed → NOT equivalent (different paths)
308+
pos_g = xgi.barycenter_spring_layout(H, seed=42)
309+
pos_h = xgi.barycenter_spring_layout(H, seed=np.random.default_rng(42))
310+
assert any(not np.allclose(pos_g[n], pos_h[n]) for n in pos_g)

xgi/drawing/layout.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@
2222
]
2323

2424

25+
def _to_nx_seed(seed):
26+
# NetworkX's spring_layout doesn't accept np.random.Generator yet
27+
# (it calls seed.rand(...), which only exists on legacy RandomState).
28+
# Draw a fresh int from the rng so the user's Generator still advances.
29+
if isinstance(seed, np.random.Generator):
30+
return int(seed.integers(0, 2**32 - 1))
31+
return seed
32+
33+
2534
def random_layout(H, center=None, seed=None):
2635
"""Position nodes uniformly at random in the unit square.
2736
@@ -130,7 +139,7 @@ def pairwise_spring_layout(H, seed=None, k=None, **kwargs):
130139
if isinstance(H, SimplicialComplex):
131140
H = convert.from_max_simplices(H)
132141
G = convert.to_graph(H)
133-
pos = nx.spring_layout(G, seed=seed, k=k, **kwargs)
142+
pos = nx.spring_layout(G, seed=_to_nx_seed(seed), k=k, **kwargs)
134143
return pos
135144

136145

@@ -231,7 +240,7 @@ def bipartite_spring_layout(H, seed=None, k=None, **kwargs):
231240

232241
# Creating a dictionary for the position of the nodes with the standard spring
233242
# layout
234-
pos = nx.spring_layout(G, seed=seed, k=k, **kwargs)
243+
pos = nx.spring_layout(G, seed=_to_nx_seed(seed), k=k, **kwargs)
235244

236245
node_pos = {nodedict[i]: pos[i] for i in nodedict}
237246
edge_pos = {edgedict[i]: pos[i] for i in edgedict}
@@ -340,7 +349,7 @@ def barycenter_spring_layout(
340349

341350
# Creating a dictionary for the position of the nodes with the standard spring
342351
# layout
343-
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed, k=k, **kwargs)
352+
pos_with_phantom_nodes = nx.spring_layout(G, seed=_to_nx_seed(seed), k=k, **kwargs)
344353

345354
# Retaining only the positions of the real nodes
346355
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
@@ -412,7 +421,7 @@ def weighted_barycenter_spring_layout(
412421

413422
# Creating a dictionary for node position with the standard spring layout
414423
pos_with_phantom_nodes = nx.spring_layout(
415-
G, weight="weight", seed=seed, k=k, **kwargs
424+
G, weight="weight", seed=_to_nx_seed(seed), k=k, **kwargs
416425
)
417426

418427
# Retaining only the positions of the real nodes

0 commit comments

Comments
 (0)