9
9
import subprocess
10
10
from multiprocessing import Pool
11
11
12
+ from metaflow .cli import start , run
12
13
from metaflow ._vendor import click
14
+ from metaflow .click_api import MetaflowAPI , extract_all_params , click_to_python_types
13
15
14
16
from metaflow_test import MetaflowTest
15
17
from metaflow_test .formatter import FlowFormatter
@@ -61,14 +63,72 @@ def log(msg, formatter=None, context=None, real_bad=False, real_good=False):
61
63
click .echo ("[pid %s] %s" % (pid , line ))
62
64
63
65
64
- def run_test (formatter , context , debug , checks , env_base ):
66
+ def run_test (formatter , context , debug , checks , env_base , use_chaining_api = False ):
65
67
def run_cmd (mode ):
66
68
cmd = [context ["python" ], "-B" , "test_flow.py" ]
67
69
cmd .extend (context ["top_options" ])
68
70
cmd .extend ((mode , "--run-id-file" , "run-id" ))
69
71
cmd .extend (context ["run_options" ])
70
72
return cmd
71
73
74
+ def construct_arg_dict (params_opts , cli_options ):
75
+ result_dict = {}
76
+ has_value = False
77
+ secondary_supplied = False
78
+
79
+ for arg in cli_options :
80
+ if "=" in arg :
81
+ given_opt , val = arg .split ("=" )
82
+ has_value = True
83
+ else :
84
+ given_opt = arg
85
+
86
+ for key , each_param in params_opts .items ():
87
+ py_type = click_to_python_types [type (each_param .type )]
88
+ if given_opt in each_param .opts :
89
+ secondary_supplied = False
90
+ elif given_opt in each_param .secondary_opts :
91
+ secondary_supplied = True
92
+ else :
93
+ continue
94
+
95
+ if has_value :
96
+ value = val
97
+ else :
98
+ if secondary_supplied :
99
+ value = not each_param .default
100
+ else :
101
+ value = each_param .default
102
+
103
+ if each_param .multiple :
104
+ if key not in result_dict :
105
+ result_dict [key ] = [py_type (value )]
106
+ else :
107
+ result_dict [key ].append (py_type (value ))
108
+ else :
109
+ result_dict [key ] = py_type (value )
110
+
111
+ has_value = False
112
+ secondary_supplied = False
113
+
114
+ return result_dict
115
+
116
+ def construct_cmd_from_click_api (mode ):
117
+ api = MetaflowAPI .from_cli ("test_flow.py" , start )
118
+ _ , _ , param_opts , _ , _ = extract_all_params (start )
119
+ top_level_options = context ["top_options" ]
120
+ top_level_dict = construct_arg_dict (param_opts , top_level_options )
121
+
122
+ _ , _ , param_opts , _ , _ = extract_all_params (run )
123
+ run_level_options = context ["run_options" ]
124
+ run_level_dict = construct_arg_dict (param_opts , run_level_options )
125
+ run_level_dict ["run_id_file" ] = "run-id"
126
+
127
+ cmd = getattr (api (** top_level_dict ), mode )(** run_level_dict )
128
+ command = [context ["python" ], "-B" ]
129
+ command .extend (cmd )
130
+ return command
131
+
72
132
cwd = os .getcwd ()
73
133
tempdir = tempfile .mkdtemp ("_metaflow_test" )
74
134
package = os .path .dirname (os .path .abspath (__file__ ))
@@ -123,7 +183,10 @@ def run_cmd(mode):
123
183
return pre_ret , path
124
184
125
185
# run flow
126
- flow_ret = subprocess .call (run_cmd ("run" ), env = env )
186
+ if use_chaining_api :
187
+ flow_ret = subprocess .call (construct_cmd_from_click_api ("run" ), env = env )
188
+ else :
189
+ flow_ret = subprocess .call (run_cmd ("run" ), env = env )
127
190
if flow_ret :
128
191
if formatter .should_fail :
129
192
log ("Flow failed as expected." )
@@ -242,6 +305,33 @@ def run_test_cases(args):
242
305
return failed
243
306
else :
244
307
log ("success" , formatter , context , real_good = True )
308
+
309
+ log ("running [with chaining api]" , formatter , context )
310
+ ret , path = run_test (
311
+ formatter ,
312
+ context ,
313
+ debug ,
314
+ contexts ["checks" ],
315
+ base_env ,
316
+ use_chaining_api = True ,
317
+ )
318
+
319
+ if ret :
320
+ tstid = "%s in context %s [with chaining api]" % (
321
+ formatter ,
322
+ context ["name" ],
323
+ )
324
+ failed .append ((tstid , path ))
325
+ log ("failed [with chaining api]" , formatter , context , real_bad = True )
326
+ if debug :
327
+ return failed
328
+ else :
329
+ log (
330
+ "success [with chaining api]" ,
331
+ formatter ,
332
+ context ,
333
+ real_good = True ,
334
+ )
245
335
else :
246
336
log ("not a valid combination. Skipped." , formatter )
247
337
return failed
@@ -258,13 +348,13 @@ def run_test_cases(args):
258
348
"--tests" ,
259
349
default = "" ,
260
350
type = str ,
261
- help = "A comma-separate list of graphs to include (default: all)." ,
351
+ help = "A comma-separated list of tests to include (default: all)." ,
262
352
)
263
353
@click .option (
264
354
"--graphs" ,
265
355
default = "" ,
266
356
type = str ,
267
- help = "A comma-separate list of graphs to include (default: all)." ,
357
+ help = "A comma-separated list of graphs to include (default: all)." ,
268
358
)
269
359
@click .option (
270
360
"--debug" ,
0 commit comments