forked from charlesxuuu/EnumMatching
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path enum_matching.py
337 lines (285 loc) · 10.8 KB
/
enum_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
from networkx.algorithms.bipartite import hopcroft_karp_matching
from networkx.algorithms.cycles import find_cycle
from networkx.algorithms.matching import is_perfect_matching
from networkx.exception import NetworkXNoCycle
from networkx import DiGraph, connected_components, weakly_connected_components, is_directed
from more_itertools import peekable
from networkx import get_node_attributes
import argparse
def enum_perfect_matching(g):
    """Enumerate the perfect matchings of the bipartite graph ``g``.

    A first maximum matching is computed with ``maximum_matching_all``. If it
    is not perfect, a message is printed and an empty list is returned.
    Otherwise the recursive search in ``enum_perfect_matching_iter`` collects
    further perfect matchings.

    Args:
        g: bipartite graph whose nodes carry a 'biparite' attribute (0 or 1).
            An undirected graph is first oriented from side 0 to side 1 with
            ``build_g``.

    Returns:
        list of matchings in the dict form produced by ``maximum_matching_all``
        (each matched node mapped to its partner); ``[]`` when ``g`` has no
        perfect matching.
    """
    match = maximum_matching_all(g)
    if not is_perfect_matching(g, match):
        print("No perfect matching found!")
        return []
    # The recursion walks successors, so it needs edges oriented 0 -> 1.
    directed = g if g.is_directed() else build_g(g)
    matches = [match]
    enum_perfect_matching_iter(matches, directed, match)
    # Matchings produced by the recursion are DiGraphs; convert them to the
    # same dict form as the initial matching.
    return [found if isinstance(found, dict) else maximum_matching_all(found)
            for found in matches]
def enum_perfect_matching_iter(matches, g, m):
    """Recursively enumerate the perfect matchings of ``g`` other than ``m``.

    Each call looks for a directed cycle in D(G, M). If one exists, the edges
    along it are exchanged to obtain a new perfect matching M', which is
    appended to ``matches``. The search then splits on the first matching edge
    e of that cycle:
    - G+(e) is G with every edge at e's endpoints removed, and is explored
      with M.
    - G-(e) is G without e, and is explored with M'.

    Args:
        matches: list to which every newly found matching (a DiGraph) is
            appended.
        g: DiGraph whose edges point from side 0 to side 1; nodes carry a
            'biparite' attribute.
        m: current perfect matching of ``g``. Either a DiGraph holding the
            matched edges oriented 0 -> 1, or a partner dict (as returned by
            ``maximum_matching_all``), which is converted with ``build_d``.
    """
    if isinstance(m, dict):
        m, _ = build_d(g, m)
    # Step 1 - nothing left to enumerate
    if not peekable(g.edges()):
        return
    # Step 2 - find a cycle in D(G, M). Matched edges point 0 -> 1 and
    # unmatched edges point 1 -> 0.
    d = construct_d_from_gm2(g, m)
    cycle = find_cycle_in_dgm(d)
    if not cycle:
        return
    # Steps 3 and 4 - the cycle found already contains the edge e we branch
    # on. A cycle of D(G, M) alternates matched / unmatched edges. Rotate it by
    # one node if needed so that e = (cycle[0], cycle[1]) is a matched edge;
    # otherwise e would not be an edge of M (or of g).
    if not m.has_edge(cycle[0], cycle[1]):
        cycle = cycle[1:] + cycle[1:2]
    e_start, e_end = cycle[0], cycle[1]
    # Step 5 - exchange edges along the cycle and output the new matching M'
    m_prime = m.copy()
    for i, (u, v) in enumerate(zip(cycle, cycle[1:])):
        if i % 2 == 0:
            # matched D-edge u -> v leaves the matching
            m_prime.remove_edge(u, v)
        else:
            # unmatched D-edge u -> v is the G-edge v -> u; it joins
            m_prime.add_edge(v, u)
    matches.append(m_prime)
    # Steps 6 and 7 - recurse on G+(e) with M and on G-(e) with M'
    g_plus = construct_g_plus(g, e_start, e_end)
    g_minus = construct_g_minus(g, e_start, e_end)
    enum_perfect_matching_iter(matches, g_plus, m)
    enum_perfect_matching_iter(matches, g_minus, m_prime)
def enum_maximum_matching(g):
    """Enumerate maximum matchings of the bipartite graph ``g``.

    Args:
        g: bipartite graph (directed 0 -> 1, or undirected) whose nodes carry
            a 'biparite' attribute.

    Returns:
        list of matchings in the dict form produced by
        ``maximum_matching_all``. The first entry is the initial maximum
        matching.
    """
    initial = maximum_matching_all(g)
    m, d = build_d(g, initial)
    found = [m]
    # The recursion needs edges oriented from side 0 to side 1.
    oriented = g if g.is_directed() else build_g(g)
    enum_maximum_matching_iter(found, oriented, m, d)
    # The recursion stores matchings as DiGraphs; turn each into a dict.
    return [maximum_matching_all(matching_graph) for matching_graph in found]
def enum_maximum_matching_iter(matches, g, m, d):
    """Recursively enumerate the maximum matchings of ``g`` other than ``m``.

    - Step 1: stop when G or D(G, M) has no edges.
    - Steps 2-7: if D(G, M) has a directed cycle, exchange edges along it to
      get M'. Output M', then recurse on G+(e) with M and on G-(e) with M',
      where e is the first matching edge of the cycle.
    - Step 8: otherwise, look for an uncovered vertex adjacent to a covered
      one. Swap that vertex in to get M', output M', and recurse on G+/G- of
      the new edge.

    Args:
        matches: list to which every newly found matching (a DiGraph) is
            appended.
        g: DiGraph with edges from side 0 to side 1 ('biparite' node attribute).
        m: DiGraph of the current matching, edges oriented 0 -> 1, containing
            every node of ``g``.
        d: D(G, M) for ``g`` and ``m``: matched edges point 0 -> 1, unmatched
            edges point 1 -> 0.
    """
    # Step 1 - nothing left to enumerate
    if not peekable(g.edges()) or not peekable(d.edges()):
        print("D(G, M) or G has no edges!")
        return

    # Step 2 - find a cycle in D(G, M)
    cycle = find_cycle_in_dgm(d)
    if cycle:
        # Steps 3 and 4 - the cycle already contains the edge e we branch on.
        # A cycle of D(G, M) alternates matched / unmatched edges. Rotate it by
        # one node if needed so that e = (cycle[0], cycle[1]) is a matched
        # edge; otherwise e would not be an edge of M (or of g).
        if not m.has_edge(cycle[0], cycle[1]):
            cycle = cycle[1:] + cycle[1:2]
        e_start, e_end = cycle[0], cycle[1]

        # Step 5 - exchange edges along the cycle and output M'
        m_prime = m.copy()
        for i, (u, v) in enumerate(zip(cycle, cycle[1:])):
            if i % 2 == 0:
                # matched D-edge u -> v leaves the matching
                m_prime.remove_edge(u, v)
            else:
                # unmatched D-edge u -> v is the G-edge v -> u; it joins
                m_prime.add_edge(v, u)
        matches.append(m_prime)

        # Steps 6 and 7 - recurse on G+(e) with M and on G-(e) with M'.
        # G+(e) keeps no edge leaving e_start, so D(G+(e), M) is the same
        # whether or not e is in M.
        g_plus = construct_g_plus(g, e_start, e_end)
        g_minus = construct_g_minus(g, e_start, e_end)
        d_plus = construct_d_from_gm2(g_plus, m)
        d_minus = construct_d_from_gm2(g_minus, m_prime)
        enum_maximum_matching_iter(matches, g_plus, m, d_plus)
        enum_maximum_matching_iter(matches, g_minus, m_prime, d_minus)
        return

    # Step 8 - no cycle: use a length-2 alternating path from an uncovered
    # vertex. ``partner`` maps every covered vertex to its mate in M.
    partner = {}
    for v in g.nodes():
        for w in m.successors(v):
            partner[v] = w
            partner[w] = v

    def branch(left, right, dropped):
        # Output M + (left, right) - dropped, then recurse on G+/G- of the
        # new edge (left, right).
        m_prime = m.copy()
        m_prime.add_edge(left, right)
        m_prime.remove_edge(*dropped)
        matches.append(m_prime)
        g_plus = construct_g_plus(g, left, right)
        g_minus = construct_g_minus(g, left, right)
        d_plus = construct_d_from_gm2(g_plus, m_prime)
        d_minus = construct_d_from_gm2(g_minus, m)
        enum_maximum_matching_iter(matches, g_plus, m_prime, d_plus)
        enum_maximum_matching_iter(matches, g_minus, m, d_minus)

    for v in g.nodes():
        if v in partner:
            continue
        # Uncovered v on side 0: its G-edges v -> w lead to side 1.
        for w in g.successors(v):
            if w in partner:
                branch(v, w, (partner[w], w))
                return
        # Uncovered v on side 1: its D-edges v -> w are unmatched G-edges
        # w -> v.
        for w in d.successors(v):
            if w in partner:
                branch(w, v, (w, partner[w]))
                return
# -----------------------------Helper functions--------------------------
def maximum_matching_all(bipartite_graph):
    """Return a maximum matching of ``bipartite_graph`` as a partner dict.

    ``hopcroft_karp_matching`` is run separately on each (weakly, for directed
    graphs) connected component, and the per-component results are merged.
    Each matched node maps to its partner, as returned by
    ``hopcroft_karp_matching``.
    """
    if is_directed(bipartite_graph):
        components = weakly_connected_components(bipartite_graph)
    else:
        components = connected_components(bipartite_graph)
    merged = {}
    for component in components:
        merged.update(hopcroft_karp_matching(bipartite_graph.subgraph(component)))
    return merged
def build_g(undirected_graph):
    """Orient an undirected bipartite graph from side 0 to side 1.

    Args:
        undirected_graph: graph whose nodes carry a 'biparite' attribute
            (0 for one side, anything else for the other).

    Returns:
        DiGraph with the same nodes ('biparite' normalised to 0/1) and one
        directed edge per undirected edge, pointing away from the side-0
        endpoint. Edges with no side-0 endpoint are skipped.
    """
    g = DiGraph()
    for n, attrs in undirected_graph.nodes(data=True):
        g.add_node(n, biparite=0 if attrs['biparite'] == 0 else 1)
    top = get_node_attributes(undirected_graph, 'biparite')
    for source, target in undirected_graph.edges():
        # Undirected edges may be reported from either endpoint, so orient
        # each one explicitly instead of dropping the (side 1, side 0) form.
        if top[source] == 0:
            g.add_edge(source, target)
        elif top[target] == 0:
            g.add_edge(target, source)
    return g
def build_d(g, match):
    """Build the matching graph M and the directed graph D(G, M) for ``g``.

    Args:
        g: bipartite graph (directed or undirected) whose nodes carry a
            'biparite' attribute; 0 marks one side, anything else the other.
        match: mapping of matched node -> partner (e.g. the dict from
            ``maximum_matching_all``). An edge of ``g`` counts as matched when
            either of its orientations appears in ``match.items()``.

    Returns:
        (m, d): two DiGraphs containing every node of ``g`` with its
        'biparite' attribute.
        - ``m`` holds the matched edges, oriented from side 0 to side 1.
        - ``d`` holds every edge of ``g``: matched edges point 0 -> 1 and
          unmatched edges point 1 -> 0.
    """
    d = DiGraph()
    m = DiGraph()
    for node, attrs in g.nodes(data=True):
        d.add_node(node, biparite=attrs['biparite'])
        m.add_node(node, biparite=attrs['biparite'])
    # Set for O(1) membership tests instead of scanning a list per edge.
    matched = set(match.items())
    side = get_node_attributes(g, 'biparite')
    for source, target in g.edges():
        # Orient every edge from its side-0 endpoint to its side-1 endpoint.
        if side[source] == 0:
            left, right = source, target
        else:
            left, right = target, source
        if (source, target) in matched or (target, source) in matched:
            d.add_edge(left, right)
            # Record the matched edge in M regardless of which endpoint the
            # edge was reported from.
            m.add_edge(left, right)
        else:
            d.add_edge(right, left)
    return m, d
def find_cycle_in_dgm(d):
    """Return one directed cycle of ``d`` as a closed node walk, or None.

    The returned list holds the cycle's nodes in traversal order, with the
    first node repeated at the end (e.g. ``[a, b, c, d, a]``). Consecutive
    entries are therefore edges of ``d``. Returns None when ``d`` is acyclic.
    """
    for node in d.nodes():
        try:
            cycle = find_cycle(d, source=node, orientation=None)
        except NetworkXNoCycle:
            continue
        # find_cycle may return a cycle that is only reachable from `node`
        # without passing through it, so close the walk on the cycle's own
        # first node rather than on `node`.
        path = [tail for tail, _ in cycle]
        path.append(path[0])
        return path
    return None
def construct_g_minus(g, e_start, e_end):
    """Return G-(e): a copy of ``g`` without the edge (e_start, e_end).

    ``g`` itself is left unchanged.
    """
    without_edge = g.copy()
    without_edge.remove_edge(e_start, e_end)
    return without_edge
def construct_g_plus(g, e_start, e_end):
    """Return G+(e) for e = (e_start, e_end): a copy of ``g`` with e's endpoints cut off.

    Every edge leaving ``e_start`` (e included) and every edge entering
    ``e_end`` is removed; ``g`` itself is left unchanged. This assumes edges
    are oriented from side 0 to side 1, so these are all the edges at e's
    endpoints. NOTE(review): confirm for other orientations.
    """
    g_plus = g.copy()
    # Iterate over ``g`` rather than ``g_plus`` so the graph being mutated is
    # never the one being iterated.
    for x in g.successors(e_start):
        g_plus.remove_edge(e_start, x)
    # predecessors() lists in-neighbours directly; there is no need to build
    # a full reversed copy of the graph for that.
    for x in g.predecessors(e_end):
        if x != e_start:  # (e_start, e_end) was already removed above
            g_plus.remove_edge(x, e_end)
    return g_plus
def construct_d_from_gm2(g_plus, m_prime):
    """Build D(G, M) from a directed graph and a matching.

    Returns a copy of ``g_plus`` in which every edge that is not in
    ``m_prime`` has been reversed; edges of ``m_prime`` keep their direction.
    Neither argument is modified.
    """
    d = g_plus.copy()
    for tail in g_plus.nodes():
        for head in g_plus.successors(tail):
            if m_prime.has_edge(tail, head):
                continue  # matched edges keep their orientation
            d.add_edge(head, tail)
            d.remove_edge(tail, head)
    return d
# Follows the same input format as the Java implementation. Unlike the Java
# version, node ids here do not have to be integers.
def read_graph(graph_file):
    """Read a bipartite edge list into a DiGraph.

    Expected file layout: two header lines (node and edge counts, ignored),
    then one ``source target`` pair per line. Sources get ``biparite=0``,
    targets ``biparite=1``, and each pair becomes a directed edge
    source -> target. Node ids are kept as strings. Blank lines are skipped,
    and fields may be separated by any whitespace.

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields.
    """
    input_graph = DiGraph()
    print("Opening Graph: " + graph_file)
    with open(graph_file, 'r') as fd:
        # Skip node and edge counts; NetworkX does not need them. A default
        # avoids StopIteration on files shorter than the header.
        next(fd, None)
        next(fd, None)
        # Process each edge
        for line in fd:
            fields = line.split()
            if not fields:
                continue  # e.g. a trailing empty line at end of file
            source, target = fields
            input_graph.add_node(source, biparite=0)
            input_graph.add_node(target, biparite=1)
            input_graph.add_edge(source, target)
    return input_graph
def read_answers():
    """Load pre-configured answers from ``answers.csv`` in the working directory.

    NOTE(review): parsing is not implemented. Every line is read and
    discarded, so the result is always an empty dict. Opening a missing file
    still raises the usual OSError.
    """
    with open('answers.csv', 'r') as answers_file:
        for _ in answers_file:
            pass  # TODO: parse each line into the answers mapping
    return {}
if __name__ == '__main__':
    # Command-line entry point: read a bipartite edge list and print every
    # enumerated maximum matching, one per line.
    # The sentence below is a description, not a program name, so it goes in
    # `description` (argparse shows `prog` in the usage line).
    parser = argparse.ArgumentParser(description='A python program that provided a bipartite graph can compute all '
                                                 'maximum matches')
    # Without --input, read_graph would fail with a TypeError on None, so let
    # argparse report the missing option instead.
    parser.add_argument('--input', '-i', dest='graph', action='store', required=True,
                        help="Input bipartite graph to get all maximum matches", type=str)
    args = parser.parse_args()
    graph = read_graph(args.graph)
    for max_match in enum_maximum_matching(graph):
        print(max_match)