Skip to content

Commit 5ea53eb

Browse files
authored
Fix duplicate edges (#139)
* fix duplicate edge creation * more extensive testing * test for groups * nitpick
1 parent 547345e commit 5ea53eb

4 files changed

Lines changed: 668 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "znflow"
3-
version = "0.2.6"
3+
version = "0.2.7"
44
description = "A general purpose framework for building and running computational graphs."
55
authors = [
66
{ name = "Fabian Zills", email = "fzills@icp.uni-stuttgart.de" },
@@ -17,6 +17,7 @@ dev = [
1717
"attrs>=25.1.0",
1818
"dask>=2025.2.0",
1919
"distributed>=2025.2.0",
20+
"ipykernel>=6.30.1",
2021
"pydantic>=2.10.6",
2122
"pytest>=8.3.4",
2223
"pytest-cov>=6.0.0",

tests/test_edges.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import dataclasses
2+
3+
import znflow
4+
5+
6+
@dataclasses.dataclass
7+
class Node(znflow.Node):
8+
a: int
9+
b: int
10+
c: int | None = None
11+
12+
13+
def test_no_duplicate_edges():
14+
"""Test that connections don't create duplicate edges."""
15+
project = znflow.DiGraph()
16+
17+
with project:
18+
a = Node(a=1, b=2)
19+
b = Node(a=3, b=4)
20+
_ = Node(a=a.c, b=b.c)
21+
22+
# Should have 3 nodes
23+
assert len(project.nodes) == 3
24+
25+
# Should have 2 unique edges: a.c -> c.a and b.c -> c.b
26+
assert len(project.edges) == 2
27+
28+
# Verify the edges are what we expect
29+
edge_descriptions = {
30+
f"{d['u_attr']}->{d['v_attr']}" for _, _, d in project.edges(data=True)
31+
}
32+
assert edge_descriptions == {"c->a", "c->b"}
33+
34+
35+
def test_multiple_connections_same_nodes():
36+
"""Test that multiple different connections between same nodes work."""
37+
38+
@dataclasses.dataclass
39+
class MultiNode(znflow.Node):
40+
x: int | None = None
41+
y: int | None = None
42+
z: int | None = None
43+
44+
project = znflow.DiGraph()
45+
46+
with project:
47+
a = MultiNode()
48+
_ = MultiNode(x=a.x, y=a.y)
49+
50+
# Should have 2 edges: a.x -> b.x and a.y -> b.y
51+
assert len(project.edges) == 2
52+
53+
# Verify the edges are what we expect
54+
edge_descriptions = {
55+
f"{d['u_attr']}->{d['v_attr']}" for _, _, d in project.edges(data=True)
56+
}
57+
assert edge_descriptions == {"x->x", "y->y"}
58+
59+
60+
def test_no_duplicate_edges_group():
61+
project = znflow.DiGraph()
62+
63+
with project.group("group"):
64+
a = Node(a=1, b=2)
65+
b = Node(a=3, b=4)
66+
_ = Node(a=a.c, b=b.c)
67+
68+
with project:
69+
pass
70+
71+
with project.group("other-group"):
72+
pass
73+
74+
# Should have 3 nodes
75+
assert len(project.nodes) == 3
76+
77+
# Should have 2 unique edges: a.c -> c.a and b.c -> c.b
78+
assert len(project.edges) == 2
79+
80+
edge_descriptions = {
81+
f"{d['u_attr']}->{d['v_attr']}" for _, _, d in project.edges(data=True)
82+
}
83+
assert edge_descriptions == {"c->a", "c->b"}
84+
85+
86+
def test_no_duplicate_edges_iterable():
87+
project = znflow.DiGraph()
88+
89+
with project.group("subgraph"):
90+
a = Node(a=1, b=2)
91+
b = Node(a=3, b=4)
92+
c = Node(a=5, b=6)
93+
_ = Node(a=[a.a], b=[b.b, c.c])
94+
95+
with project:
96+
pass
97+
98+
with project.group("other-subgraph"):
99+
pass
100+
101+
assert len(project.nodes) == 4
102+
assert len(project.edges) == 3
103+
104+
edge_descriptions = {
105+
f"{d['u_attr']}->{d['v_attr']}" for _, _, d in project.edges(data=True)
106+
}
107+
assert edge_descriptions == {"a->a", "b->b", "c->b"}

0 commit comments

Comments
 (0)