1
1
import ast
2
- from metaflow . graph import deindent_docstring , DAGNode
2
+ import re
3
3
4
4
# NOTE: This is a custom implementation of the FlowGraph class from the Metaflow client
5
5
# which can parse a graph out of a flow_name and a source code string, instead of relying on
6
6
# importing the source code as a module.
7
7
8
8
9
def deindent_docstring(doc):
    """Remove the common leading indentation from a docstring.

    Args:
        doc: The raw docstring, or a falsy value (``None`` / ``""``).

    Returns:
        ``""`` when ``doc`` is falsy; ``doc`` unchanged when no indent to
        remove could be determined; otherwise the docstring with the detected
        indent removed after every newline and surrounding whitespace
        stripped.
    """
    if not doc:
        return ""

    # Find the indent to remove from the docstring. We consider the following
    # possibilities:
    # Option 1:
    #  """This is the first line
    #     This is the second line
    #  """
    # Option 2:
    #  """
    # This is the first line
    # This is the second line
    # """
    # Option 3:
    #  """
    #     This is the first line
    #     This is the second line
    #  """
    #
    # In all cases, we can find the indent to remove by doing the following:
    #  - Check the first non-empty line, if it has an indent, use that as the
    #    base indent
    #  - If it does not have an indent and there is a second line, check the
    #    indent of the second line and use that
    saw_first_line = False
    matched_indent = None
    for line in doc.splitlines():
        if line:
            matched_indent = re.match("[\t ]+", line)
            if matched_indent is not None or saw_first_line:
                break
            saw_first_line = True

    if matched_indent:
        # Replace "<newline><indent>" with a bare newline. The replacement must
        # not contain any whitespace of its own, otherwise every continuation
        # line keeps a stray leading space.
        return re.sub(r"\n" + matched_indent.group(), "\n", doc).strip()
    return doc
45
+
46
+
9
47
class StepVisitor (ast .NodeVisitor ):
10
48
11
49
def __init__ (self , nodes ):
@@ -20,6 +58,93 @@ def visit_FunctionDef(self, node):
20
58
self .nodes [node .name ] = DAGNode (node , decos , doc if doc else '' )
21
59
22
60
61
class DAGNode(object):
    """One step function of a flow, described from its AST node.

    The constructor records the step's name, line number, decorators and
    deindented docstring, then inspects the last statement of the function
    body to classify the ``self.next(...)`` transition (see ``_parse``).
    """

    def __init__(self, func_ast, decos, doc):
        """Build the node.

        Args:
            func_ast: ``ast.FunctionDef`` of the step function.
            decos: Iterable of decorator objects attached to the step.
            doc: Raw docstring of the step (may be empty or ``None``).
        """
        self.name = func_ast.name
        self.func_lineno = func_ast.lineno
        self.decorators = decos
        self.doc = deindent_docstring(doc)
        self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)

        # these attributes are populated by _parse
        self.tail_next_lineno = 0
        self.type = None
        self.out_funcs = []
        self.has_tail_next = False
        self.invalid_tail_next = False
        self.num_args = 0
        self.foreach_param = None
        self.num_parallel = 0
        self.parallel_foreach = False
        self._parse(func_ast)

        # these attributes are populated by _traverse_graph
        self.in_funcs = set()
        self.split_parents = []
        self.matching_join = None
        # these attributes are populated by _postprocess
        self.is_inside_foreach = False

    def _expr_str(self, expr):
        """Render an attribute access such as ``self.next`` as ``"self.next"``.

        Raises ``AttributeError`` when ``expr`` is not an attribute access on
        a plain name; ``_parse`` relies on that to reject other expressions.
        """
        return "%s.%s" % (expr.value.id, expr.attr)

    @staticmethod
    def _keyword_literal(value_node):
        """Return the literal carried by a keyword-argument AST node, or None.

        Literals are ``ast.Constant`` nodes on current interpreters. The
        legacy ``.s`` attribute is deprecated there and absent on newer
        Python versions, so ``.value`` is read instead; ``.s`` is only used as
        a fallback for interpreters that still produce ``ast.Str`` nodes.
        """
        if isinstance(value_node, ast.Constant):
            return value_node.value
        return getattr(value_node, "s", None)

    def _parse(self, func_ast):
        """Classify the step from the last statement of its body.

        Sets ``num_args`` and, when the tail statement is a ``self.next(...)``
        call, ``has_tail_next``, ``tail_next_lineno``, ``out_funcs``, ``type``
        and the foreach / parallel attributes. ``invalid_tail_next`` stays
        ``True`` when the call does not match one of the recognised shapes.
        """
        self.num_args = len(func_ast.args.args)
        tail = func_ast.body[-1]

        # end doesn't need a transition
        if self.name == "end":
            # TYPE: end
            self.type = "end"

        # ensure that the tail is an expression
        if not isinstance(tail, ast.Expr):
            return

        # determine the type of self.next transition
        try:
            if not self._expr_str(tail.value.func) == "self.next":
                return

            # Flag the transition as invalid up front; it is cleared only when
            # one of the recognised shapes below matches. If an AttributeError
            # interrupts parsing, the flag therefore remains set.
            self.has_tail_next = True
            self.invalid_tail_next = True
            self.tail_next_lineno = tail.lineno
            self.out_funcs = [e.attr for e in tail.value.args]

            keywords = dict(
                (k.arg, self._keyword_literal(k.value)) for k in tail.value.keywords
            )
            if len(keywords) == 1:
                if "foreach" in keywords:
                    # TYPE: foreach
                    self.type = "foreach"
                    if len(self.out_funcs) == 1:
                        self.foreach_param = keywords["foreach"]
                        self.invalid_tail_next = False
                elif "num_parallel" in keywords:
                    # TYPE: foreach (parallel variant)
                    self.type = "foreach"
                    self.parallel_foreach = True
                    if len(self.out_funcs) == 1:
                        self.num_parallel = keywords["num_parallel"]
                        self.invalid_tail_next = False
            elif len(keywords) == 0:
                if len(self.out_funcs) > 1:
                    # TYPE: split
                    self.type = "split"
                    self.invalid_tail_next = False
                elif len(self.out_funcs) == 1:
                    # TYPE: linear
                    if self.name == "start":
                        self.type = "start"
                    elif self.num_args > 1:
                        self.type = "join"
                    else:
                        self.type = "linear"
                    self.invalid_tail_next = False
        except AttributeError:
            # The tail expression is not shaped like self.next(self.<step>, ...)
            return
146
+
147
+
23
148
class FlowGraph (object ):
24
149
# NOTE: This implementation relies on passing in the name of the FlowSpec class
25
150
# to be parsed from the sourcecode.
0 commit comments