Skip to content

Commit 55ecf68

Browse files
Support DiHypergraph in from_bipartite_graph (#633)
* Support DiHypergraph in from_bipartite_graph * PR feedback * Add the ability to check bipartiteness of directed and undirected * format with black . * PR feedback * fix additional PR comments * fixed small bugs --------- Co-authored-by: Nicholas Landry <nicholas.landry.91@gmail.com>
1 parent b4f9f38 commit 55ecf68

3 files changed

Lines changed: 131 additions & 26 deletions

File tree

tests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,37 @@ def bipartite_graph4():
135135
return G
136136

137137

138+
@pytest.fixture
139+
def bipartite_digraph1():
140+
G = nx.DiGraph()
141+
G.add_nodes_from([1, 2, 3, 4], bipartite=0)
142+
G.add_nodes_from(["a", "b", "c"], bipartite=1)
143+
G.add_edges_from(
144+
[
145+
("a", 1),
146+
(1, "b"),
147+
("b", 2),
148+
(2, "c"),
149+
("c", 3),
150+
(4, "a"),
151+
]
152+
)
153+
return G
154+
155+
156+
@pytest.fixture
157+
def bipartite_digraph2():
158+
G = nx.DiGraph()
159+
G.add_nodes_from([1], bipartite=0)
160+
G.add_nodes_from(["a"], bipartite=1)
161+
G.add_edges_from(
162+
[
163+
("a", 1, {"direction": "invalid"}),
164+
]
165+
)
166+
return G
167+
168+
138169
@pytest.fixture
139170
def attr0():
140171
return {"color": "brown", "name": "camel"}

tests/convert/test_bipartite_graph.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import networkx as nx
12
import pytest
23

34
import xgi
45
from xgi.exception import XGIError
56

67

7-
def test_to_bipartite_graph(edgelist1, edgelist3, edgelist4):
8+
def test_to_bipartite_graph(edgelist1, edgelist3, edgelist4, diedgelist1):
89
H1 = xgi.Hypergraph(edgelist1)
910
H2 = xgi.Hypergraph(edgelist3)
1011
H3 = xgi.Hypergraph(edgelist4)
@@ -49,9 +50,26 @@ def test_to_bipartite_graph(edgelist1, edgelist3, edgelist4):
4950
assert sorted(bi_el3) == sorted(true_bi_el3)
5051
assert G3.edges() == xgi.to_bipartite_graph(H3, index=False).edges()
5152

53+
## Directed
54+
H = xgi.DiHypergraph(diedgelist1)
55+
G = xgi.to_bipartite_graph(H)
56+
57+
assert isinstance(G, nx.DiGraph)
58+
assert set(G.nodes) == {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
59+
assert (0, 8) in G.edges
60+
assert (8, 0) not in G.edges
61+
assert (9, 6) in G.edges
62+
assert (6, 9) not in G.edges
63+
assert (9, 5) in G.edges
64+
assert (9, 5) in G.edges
65+
5266

5367
def test_from_bipartite_graph(
54-
bipartite_graph1, bipartite_graph2, bipartite_graph3, bipartite_graph4
68+
bipartite_graph1,
69+
bipartite_graph2,
70+
bipartite_graph3,
71+
bipartite_graph4,
72+
bipartite_digraph1,
5573
):
5674
H = xgi.from_bipartite_graph(bipartite_graph1)
5775

@@ -81,3 +99,15 @@ def test_from_bipartite_graph(
8199
# not bipartite
82100
with pytest.raises(XGIError):
83101
H = xgi.from_bipartite_graph(bipartite_graph4, dual=True)
102+
103+
### Directed
104+
H = xgi.from_bipartite_graph(bipartite_digraph1)
105+
106+
assert set(H.nodes) == {1, 2, 3, 4}
107+
assert set(H.edges) == {"a", "b", "c"}
108+
assert H.edges.head("a") == {1}
109+
assert H.edges.tail("a") == {4}
110+
assert H.edges.head("b") == {2}
111+
assert H.edges.tail("b") == {1}
112+
assert H.edges.head("c") == {3}
113+
assert H.edges.tail("c") == {2}

xgi/convert/bipartite_graph.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""Methods for converting to and from bipartite graphs."""
22

33
import networkx as nx
4-
from networkx import bipartite
54

5+
from ..core import DiHypergraph, Hypergraph
66
from ..exception import XGIError
7-
from ..generators import empty_hypergraph
87

98
__all__ = ["from_bipartite_graph", "to_bipartite_graph"]
109

1110

12-
def from_bipartite_graph(G, create_using=None, dual=False):
11+
def from_bipartite_graph(G, dual=False):
1312
"""
1413
Create a Hypergraph from a NetworkX bipartite graph.
1514
@@ -30,16 +29,13 @@ def from_bipartite_graph(G, create_using=None, dual=False):
3029
A networkx bipartite graph. Each node in the graph has a property
3130
'bipartite' taking the value of 0 or 1 indicating the type of node.
3231
33-
create_using : Hypergraph constructor, optional
34-
The hypergraph object to add the data to, by default None
35-
3632
dual : bool, default : False
3733
If True, get edges from bipartite=0 and nodes from bipartite=1
3834
3935
Returns
4036
-------
41-
Hypergraph
42-
The equivalent hypergraph
37+
Hypergraph or DiHypergraph
38+
The equivalent hypergraph or directed hypergraph
4339
4440
References
4541
----------
@@ -58,8 +54,14 @@ def from_bipartite_graph(G, create_using=None, dual=False):
5854
>>> H = xgi.from_bipartite_graph(G)
5955
6056
"""
57+
if isinstance(G, nx.DiGraph):
58+
directed = True
59+
else:
60+
directed = False
61+
6162
edges = []
6263
nodes = []
64+
6365
for n, d in G.nodes(data=True):
6466
try:
6567
node_type = d["bipartite"]
@@ -73,34 +75,59 @@ def from_bipartite_graph(G, create_using=None, dual=False):
7375
else:
7476
raise XGIError("Invalid type specifier")
7577

76-
if not bipartite.is_bipartite_node_set(G, nodes):
78+
if not _is_bipartite(G, nodes, edges):
7779
raise XGIError("The network is not bipartite")
7880

79-
H = empty_hypergraph(create_using)
81+
if directed:
82+
H = DiHypergraph()
83+
else:
84+
H = Hypergraph()
85+
8086
H.add_nodes_from(nodes)
81-
for edge in edges:
82-
nodes_in_edge = list(G.neighbors(edge))
83-
H.add_edge(nodes_in_edge, idx=edge)
87+
88+
for u, v in G.edges:
89+
if directed:
90+
if v in edges:
91+
H.add_node_to_edge(v, u, direction="in")
92+
else:
93+
H.add_node_to_edge(u, v, direction="out")
94+
else:
95+
H.add_node_to_edge(v, u)
96+
8497
return H.dual() if dual else H
8598

8699

100+
def _is_bipartite(G, nodes1, nodes2):
101+
"""Assumption is that nodes1.union(nodes2) == G.nodes"""
102+
for i, j in G.edges:
103+
cond1 = i in nodes1
104+
cond2 = j in nodes2
105+
if not cond1 == cond2: # if not both true or both false
106+
return False
107+
return True
108+
109+
87110
def to_bipartite_graph(H, index=False):
88111
"""Create a NetworkX bipartite network from a hypergraph.
89112
90113
Parameters
91114
----------
92-
H: xgi.Hypergraph
115+
H: xgi.Hypergraph or xgi.DiHypergraph
93116
The XGI hypergraph object of interest
94117
index: bool (default False)
95118
If False (default), return only the graph. If True, additionally return the
96119
index-to-node and index-to-edge mappings.
97120
98121
Returns
99122
-------
100-
nx.Graph[, dict, dict]
101-
The resulting equivalent bipartite graph, and optionally the index-to-unit
102-
mappings.
103-
123+
if xgi.Hypergraph
124+
nx.Graph[, dict, dict]
125+
The resulting equivalent bipartite graph, and optionally the index-to-unit
126+
mappings.
127+
if xgi.Hypergraph
128+
nx.DiGraph[, dict, dict]
129+
The resulting equivalent directed bipartite graph, and optionally the index-to-unit
130+
mappings.
104131
References
105132
----------
106133
The Why, How, and When of Representations for Complex Systems,
@@ -116,24 +143,41 @@ def to_bipartite_graph(H, index=False):
116143
>>> G, itn, ite = xgi.to_bipartite_graph(H, index=True)
117144
118145
"""
119-
G = nx.Graph()
146+
if isinstance(H, DiHypergraph):
147+
directed = True
148+
else:
149+
directed = False
120150

121151
n = H.num_nodes
122152
m = H.num_edges
123153

124154
node_dict = dict(zip(H.nodes, range(n)))
125155
edge_dict = dict(zip(H.edges, range(n, n + m)))
156+
157+
if directed:
158+
G = nx.DiGraph()
159+
else:
160+
G = nx.Graph()
161+
126162
G.add_nodes_from(node_dict.values(), bipartite=0)
127163
G.add_nodes_from(edge_dict.values(), bipartite=1)
128-
for node in H.nodes:
129-
for edge in H.nodes.memberships(node):
130-
G.add_edge(node_dict[node], edge_dict[edge])
164+
165+
if directed:
166+
for e in H.edges:
167+
for v in H.edges.tail(e):
168+
G.add_edge(node_dict[v], edge_dict[e])
169+
for v in H.edges.head(e):
170+
G.add_edge(edge_dict[e], node_dict[v])
171+
else:
172+
for e in H.edges:
173+
for v in H.edges.members(e):
174+
G.add_edge(node_dict[v], edge_dict[e])
131175

132176
if index:
133177
return (
134178
G,
135-
dict(zip(range(n), H.nodes)),
136-
dict(zip(range(n, n + m), H.edges)),
179+
{v: k for k, v in node_dict.items()},
180+
{v: k for k, v in edge_dict.items()},
137181
)
138182
else:
139183
return G

0 commit comments

Comments
 (0)