4
4
from collections import OrderedDict
5
5
from typeguard import check_type , TypeCheckError
6
6
import uuid , datetime
7
- from typing import Optional , List
7
+ from typing import (
8
+ Optional ,
9
+ List ,
10
+ OrderedDict as TOrderedDict ,
11
+ Any ,
12
+ Union ,
13
+ Dict ,
14
+ Callable ,
15
+ )
8
16
from metaflow import FlowSpec , Parameter
9
17
from metaflow .cli import start
10
18
from metaflow ._vendor import click
11
- from metaflow ._vendor .click import Command , Group , Argument , Option
12
19
from metaflow .parameters import JSONTypeClass
13
20
from metaflow .includefile import FilePathClass
14
21
from metaflow ._vendor .click .types import (
41
48
42
49
43
50
def _method_sanity_check (
44
- possible_arg_params , possible_opt_params , annotations , defaults , ** kwargs
45
- ):
51
+ possible_arg_params : TOrderedDict [str , click .Argument ],
52
+ possible_opt_params : TOrderedDict [str , click .Option ],
53
+ annotations : TOrderedDict [str , Any ],
54
+ defaults : TOrderedDict [str , Any ],
55
+ ** kwargs
56
+ ) -> Dict [str , Any ]:
46
57
method_params = {"args" : {}, "options" : {}}
47
58
48
59
possible_params = OrderedDict ()
@@ -85,7 +96,7 @@ def _method_sanity_check(
85
96
return method_params
86
97
87
98
88
- def get_annotation (param ):
99
+ def get_annotation (param : Union [ click . Argument , click . Option ] ):
89
100
py_type = click_to_python_types [type (param .type )]
90
101
if not param .required :
91
102
if param .multiple or param .nargs == - 1 :
@@ -99,7 +110,7 @@ def get_annotation(param):
99
110
return py_type
100
111
101
112
102
- def get_inspect_param_obj (p , kind ):
113
+ def get_inspect_param_obj (p : Union [ click . Argument , click . Option ], kind : str ):
103
114
return inspect .Parameter (
104
115
name = p .name ,
105
116
kind = kind ,
@@ -108,7 +119,7 @@ def get_inspect_param_obj(p, kind):
108
119
)
109
120
110
121
111
- def extract_flowspec_params (flow_file ) :
122
+ def extract_flowspec_params (flow_file : str ) -> List [ Parameter ] :
112
123
spec = importlib .util .spec_from_file_location ("module" , flow_file )
113
124
module = importlib .util .module_from_spec (spec )
114
125
spec .loader .exec_module (module )
@@ -140,16 +151,16 @@ def chain(self):
140
151
return self ._chain
141
152
142
153
@classmethod
143
- def from_cli (cls , flow_file , cli_collection ) :
154
+ def from_cli (cls , flow_file : str , cli_collection : Callable ) -> Callable :
144
155
flow_parameters = extract_flowspec_params (flow_file )
145
156
class_dict = {"__module__" : "metaflow" , "_API_NAME" : flow_file }
146
157
command_groups = cli_collection .sources
147
158
for each_group in command_groups :
148
159
for _ , cmd_obj in each_group .commands .items ():
149
- if isinstance (cmd_obj , Group ):
160
+ if isinstance (cmd_obj , click . Group ):
150
161
# TODO: possibly check for fake groups with cmd_obj.name in ["cli", "main"]
151
162
class_dict [cmd_obj .name ] = extract_group (cmd_obj , flow_parameters )
152
- elif isinstance (cmd_obj , Command ):
163
+ elif isinstance (cmd_obj , click . Command ):
153
164
class_dict [cmd_obj .name ] = extract_command (cmd_obj , flow_parameters )
154
165
else :
155
166
raise RuntimeError (
@@ -188,7 +199,7 @@ def _method(_self, **kwargs):
188
199
189
200
return m
190
201
191
- def execute (self ):
202
+ def execute (self ) -> List [ str ] :
192
203
parents = []
193
204
current = self
194
205
while current .parent :
@@ -225,7 +236,7 @@ def execute(self):
225
236
return components
226
237
227
238
228
- def extract_all_params (cmd_obj ):
239
+ def extract_all_params (cmd_obj : Union [ click . Command , click . Group ] ):
229
240
arg_params_sigs = OrderedDict ()
230
241
opt_params_sigs = OrderedDict ()
231
242
params_sigs = OrderedDict ()
@@ -236,12 +247,12 @@ def extract_all_params(cmd_obj):
236
247
defaults = OrderedDict ()
237
248
238
249
for each_param in cmd_obj .params :
239
- if isinstance (each_param , Argument ):
250
+ if isinstance (each_param , click . Argument ):
240
251
arg_params_sigs [each_param .name ] = get_inspect_param_obj (
241
252
each_param , inspect .Parameter .POSITIONAL_ONLY
242
253
)
243
254
arg_parameters [each_param .name ] = each_param
244
- elif isinstance (each_param , Option ):
255
+ elif isinstance (each_param , click . Option ):
245
256
opt_params_sigs [each_param .name ] = get_inspect_param_obj (
246
257
each_param , inspect .Parameter .KEYWORD_ONLY
247
258
)
@@ -260,13 +271,13 @@ def extract_all_params(cmd_obj):
260
271
return params_sigs , arg_parameters , opt_parameters , annotations , defaults
261
272
262
273
263
- def extract_group (cmd_obj , flow_parameters ) :
274
+ def extract_group (cmd_obj : click . Group , flow_parameters : List [ Parameter ]) -> Callable :
264
275
class_dict = {"__module__" : "metaflow" , "_API_NAME" : cmd_obj .name }
265
276
for _ , sub_cmd_obj in cmd_obj .commands .items ():
266
- if isinstance (sub_cmd_obj , Group ):
277
+ if isinstance (sub_cmd_obj , click . Group ):
267
278
# recursion
268
279
class_dict [sub_cmd_obj .name ] = extract_group (sub_cmd_obj , flow_parameters )
269
- elif isinstance (sub_cmd_obj , Command ):
280
+ elif isinstance (sub_cmd_obj , click . Command ):
270
281
class_dict [sub_cmd_obj .name ] = extract_command (sub_cmd_obj , flow_parameters )
271
282
else :
272
283
raise RuntimeError (
@@ -302,7 +313,9 @@ def _method(_self, **kwargs):
302
313
return m
303
314
304
315
305
- def extract_command (cmd_obj , flow_parameters ):
316
+ def extract_command (
317
+ cmd_obj : click .Command , flow_parameters : List [Parameter ]
318
+ ) -> Callable :
306
319
if getattr (cmd_obj , "has_flow_params" , False ):
307
320
for p in flow_parameters [::- 1 ]:
308
321
cmd_obj .params .insert (0 , click .Option (("--" + p .name ,), ** p .kwargs ))
0 commit comments