-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexpander.py
More file actions
132 lines (93 loc) · 3.2 KB
/
Copy pathexpander.py
File metadata and controls
132 lines (93 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# -*- coding: utf-8 -*-
"""The expander use the DAG to generate curve by topological order."""
import inspect
from collections import deque
# Module-level registry of nodes added through reg_node, keyed by function name.
# Each value is a dict with the keys "deps", "processor" and "curve".
__nodes = {}
def reg_node(func):
    """Register ``func`` as a node and hand it back unchanged.

    The node is stored under the function's ``__name__``. The names of the
    function's parameters are recorded as the node's dependencies, the
    function itself as its processor, and its curve starts out as ``None``.
    Because the function is returned as-is, this can be used as a decorator.

    Raises:
        AssertionError: if a node with the same name is already registered.
    """
    node_name = func.__name__
    assert node_name not in __nodes, f"Node {node_name} already exists."

    # Iterating the parameters mapping yields the parameter names in order.
    param_names = list(inspect.signature(func).parameters)
    entry = {"deps": param_names, "processor": func, "curve": None}
    __nodes[node_name] = entry
    return func
def all_nodes():
    """Return the node registry itself (the live dict, not a copy)."""
    registry = __nodes
    return registry
def get_processor(name):
    """Return the function registered as the processor of node ``name``.

    Raises:
        KeyError: if no node called ``name`` is registered.
    """
    entry = __nodes[name]
    return entry["processor"]
def get_curve(name):
    """Return the curve currently stored for node ``name``.

    The value is ``None`` for a node whose curve has not been set since
    registration.

    Raises:
        KeyError: if no node called ``name`` is registered.
    """
    entry = __nodes[name]
    return entry["curve"]
def get_deps(name):
    """Return the list of dependency names recorded for node ``name``.

    These are the parameter names of the node's processor, in signature
    order.

    Raises:
        KeyError: if no node called ``name`` is registered.
    """
    entry = __nodes[name]
    return entry["deps"]
def get_node(name):
    """Return the registry entry (a dict) stored for node ``name``.

    The entry is the live object held by the registry, so changes made to
    it are visible to every other accessor.

    Raises:
        KeyError: if no node called ``name`` is registered.
    """
    entry = __nodes[name]
    return entry
class DAG:
    """Directed Acyclic Graph.

    An edge ``(v_from, v_to)`` means that ``v_from`` depends on ``v_to``.
    Vertices must be hashable, since they are used as dictionary keys and
    set members.
    """

    def __init__(self):
        """Create an empty graph."""
        # Vertices in insertion order, plus a set for O(1) membership tests.
        self._vertices = []
        self._vertex_set = set()
        # Every edge as a (v_from, v_to) pair, in insertion order.
        self._edges = []
        # vertex -> list of vertices it depends on.
        self._depend = {}
        # vertex -> list of vertices that depend on it.
        self._depended = {}

    def _valid_vertex(self, *vertices):
        """Raise ``ValueError`` if any of ``vertices`` is not in the graph."""
        for vtx in vertices:
            if vtx not in self._vertex_set:
                raise ValueError(f"vertex {vtx} does not belong to DAG.")

    def _has_path_to(self, v_from, v_to):
        """Return True if ``v_to`` is reachable from ``v_from``.

        The search follows dependency links and treats a vertex as reachable
        from itself. It is iterative and remembers visited vertices, so it
        runs in time linear in the size of the graph and cannot exhaust the
        recursion limit on long dependency chains.
        """
        seen = {v_from}
        stack = [v_from]
        while stack:
            current = stack.pop()
            if current == v_to:
                return True
            for nxt in self.get_depend(current):
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return False

    def add_vertex(self, vertex):
        """Add ``vertex`` to the graph; adding it again is a no-op.

        Ignoring duplicates keeps ``vertices()`` and ``all_starts()`` from
        listing the same vertex more than once.
        """
        if vertex in self._vertex_set:
            return
        self._vertex_set.add(vertex)
        self._vertices.append(vertex)

    def vertices(self):
        """Return the list of vertices in insertion order."""
        return self._vertices

    def add_edge(self, v_from, *v_tos):
        """Record that ``v_from`` depends on each vertex in ``v_tos``.

        Every target is checked before anything is stored, so a rejected
        call leaves the graph unchanged.

        Raises:
            ValueError: if any of the vertices is not in the graph.
            RuntimeError: if any of the edges would create a cycle.
        """
        self._valid_vertex(v_from, *v_tos)

        # Edges leaving v_from cannot create a new path *into* v_from, so
        # checking all targets against the current graph is equivalent to
        # checking them one by one while inserting.
        for v_to in v_tos:
            if self._has_path_to(v_to, v_from):
                raise RuntimeError('The edge will create a cycle.')

        for v_to in v_tos:
            self._edges.append((v_from, v_to))
            # Build fresh lists so lists handed out earlier by get_depend /
            # get_depended are not modified behind the caller's back.
            self._depend[v_from] = self.get_depend(v_from) + [v_to]
            self._depended[v_to] = self.get_depended(v_to) + [v_from]

    def get_depended(self, vertex):
        """Return the vertices that depend on ``vertex`` (``[]`` if none)."""
        return self._depended.get(vertex, [])

    def get_depend(self, vertex):
        """Return the vertices that ``vertex`` depends on (``[]`` if none)."""
        return self._depend.get(vertex, [])

    def indegree(self, vertex):
        """Return the number of vertices that ``vertex`` depends on."""
        return len(self.get_depend(vertex))

    def all_starts(self):
        """Return a new list of the vertices that depend on nothing."""
        return [vtx for vtx in self._vertices if self.indegree(vtx) == 0]
class Expander:
    """Run the registered node processors over a DAG in dependency order.

    Each vertex of the DAG is used as the name of a registered node. A node
    is processed once every vertex it depends on has been processed: its
    processor is called with the curves of those vertices and the result is
    stored as the node's own curve.
    """

    def __init__(self, dag: DAG):
        """Store the graph whose vertices will be executed."""
        self.dag = dag

    def execute(self):
        """Execute in topological order.

        TODO: async and parallel
        """
        # Number of dependencies still waiting to be processed, per vertex.
        remaining = {
            vtx: self.dag.indegree(vtx) for vtx in self.dag.vertices()
        }

        # A deque pops from the front in O(1); list.pop(0) is O(n).
        ready = deque(self.dag.all_starts())
        while ready:
            vtx = ready.popleft()

            # use predecessors as arguments
            deps = get_deps(vtx)
            pres = self.dag.get_depend(vtx)
            assert len(deps) == len(
                pres
            ), f"deps: {deps} and pres: {pres} should be the same."

            # Put each predecessor's curve at the position of the processor
            # parameter that carries the same name.
            cs = [None] * len(pres)
            for pre in pres:
                cs[deps.index(pre)] = get_curve(pre)
            get_node(vtx)["curve"] = get_processor(vtx)(*cs)

            # Use a separate name here so the vertex being processed is not
            # rebound while its dependents are being released.
            for successor in self.dag.get_depended(vtx):
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    ready.append(successor)