Skip to content

Commit 7db25ee

Browse files
committed
add type hints
1 parent a21c009 commit 7db25ee

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

metaflow/click_api.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
from collections import OrderedDict
55
from typeguard import check_type, TypeCheckError
66
import uuid, datetime
7-
from typing import Optional, List
7+
from typing import (
8+
Optional,
9+
List,
10+
OrderedDict as TOrderedDict,
11+
Any,
12+
Union,
13+
Dict,
14+
Callable,
15+
)
816
from metaflow import FlowSpec, Parameter
917
from metaflow.cli import start
1018
from metaflow._vendor import click
11-
from metaflow._vendor.click import Command, Group, Argument, Option
1219
from metaflow.parameters import JSONTypeClass
1320
from metaflow.includefile import FilePathClass
1421
from metaflow._vendor.click.types import (
@@ -41,8 +48,12 @@
4148

4249

4350
def _method_sanity_check(
44-
possible_arg_params, possible_opt_params, annotations, defaults, **kwargs
45-
):
51+
possible_arg_params: TOrderedDict[str, click.Argument],
52+
possible_opt_params: TOrderedDict[str, click.Option],
53+
annotations: TOrderedDict[str, Any],
54+
defaults: TOrderedDict[str, Any],
55+
**kwargs
56+
) -> Dict[str, Any]:
4657
method_params = {"args": {}, "options": {}}
4758

4859
possible_params = OrderedDict()
@@ -85,7 +96,7 @@ def _method_sanity_check(
8596
return method_params
8697

8798

88-
def get_annotation(param):
99+
def get_annotation(param: Union[click.Argument, click.Option]):
89100
py_type = click_to_python_types[type(param.type)]
90101
if not param.required:
91102
if param.multiple or param.nargs == -1:
@@ -99,7 +110,7 @@ def get_annotation(param):
99110
return py_type
100111

101112

102-
def get_inspect_param_obj(p, kind):
113+
def get_inspect_param_obj(p: Union[click.Argument, click.Option], kind: str):
103114
return inspect.Parameter(
104115
name=p.name,
105116
kind=kind,
@@ -108,7 +119,7 @@ def get_inspect_param_obj(p, kind):
108119
)
109120

110121

111-
def extract_flowspec_params(flow_file):
122+
def extract_flowspec_params(flow_file: str) -> List[Parameter]:
112123
spec = importlib.util.spec_from_file_location("module", flow_file)
113124
module = importlib.util.module_from_spec(spec)
114125
spec.loader.exec_module(module)
@@ -140,16 +151,16 @@ def chain(self):
140151
return self._chain
141152

142153
@classmethod
143-
def from_cli(cls, flow_file, cli_collection):
154+
def from_cli(cls, flow_file: str, cli_collection: Callable) -> Callable:
144155
flow_parameters = extract_flowspec_params(flow_file)
145156
class_dict = {"__module__": "metaflow", "_API_NAME": flow_file}
146157
command_groups = cli_collection.sources
147158
for each_group in command_groups:
148159
for _, cmd_obj in each_group.commands.items():
149-
if isinstance(cmd_obj, Group):
160+
if isinstance(cmd_obj, click.Group):
150161
# TODO: possibly check for fake groups with cmd_obj.name in ["cli", "main"]
151162
class_dict[cmd_obj.name] = extract_group(cmd_obj, flow_parameters)
152-
elif isinstance(cmd_obj, Command):
163+
elif isinstance(cmd_obj, click.Command):
153164
class_dict[cmd_obj.name] = extract_command(cmd_obj, flow_parameters)
154165
else:
155166
raise RuntimeError(
@@ -188,7 +199,7 @@ def _method(_self, **kwargs):
188199

189200
return m
190201

191-
def execute(self):
202+
def execute(self) -> List[str]:
192203
parents = []
193204
current = self
194205
while current.parent:
@@ -225,7 +236,7 @@ def execute(self):
225236
return components
226237

227238

228-
def extract_all_params(cmd_obj):
239+
def extract_all_params(cmd_obj: Union[click.Command, click.Group]):
229240
arg_params_sigs = OrderedDict()
230241
opt_params_sigs = OrderedDict()
231242
params_sigs = OrderedDict()
@@ -236,12 +247,12 @@ def extract_all_params(cmd_obj):
236247
defaults = OrderedDict()
237248

238249
for each_param in cmd_obj.params:
239-
if isinstance(each_param, Argument):
250+
if isinstance(each_param, click.Argument):
240251
arg_params_sigs[each_param.name] = get_inspect_param_obj(
241252
each_param, inspect.Parameter.POSITIONAL_ONLY
242253
)
243254
arg_parameters[each_param.name] = each_param
244-
elif isinstance(each_param, Option):
255+
elif isinstance(each_param, click.Option):
245256
opt_params_sigs[each_param.name] = get_inspect_param_obj(
246257
each_param, inspect.Parameter.KEYWORD_ONLY
247258
)
@@ -260,13 +271,13 @@ def extract_all_params(cmd_obj):
260271
return params_sigs, arg_parameters, opt_parameters, annotations, defaults
261272

262273

263-
def extract_group(cmd_obj, flow_parameters):
274+
def extract_group(cmd_obj: click.Group, flow_parameters: List[Parameter]) -> Callable:
264275
class_dict = {"__module__": "metaflow", "_API_NAME": cmd_obj.name}
265276
for _, sub_cmd_obj in cmd_obj.commands.items():
266-
if isinstance(sub_cmd_obj, Group):
277+
if isinstance(sub_cmd_obj, click.Group):
267278
# recursion
268279
class_dict[sub_cmd_obj.name] = extract_group(sub_cmd_obj, flow_parameters)
269-
elif isinstance(sub_cmd_obj, Command):
280+
elif isinstance(sub_cmd_obj, click.Command):
270281
class_dict[sub_cmd_obj.name] = extract_command(sub_cmd_obj, flow_parameters)
271282
else:
272283
raise RuntimeError(
@@ -302,7 +313,9 @@ def _method(_self, **kwargs):
302313
return m
303314

304315

305-
def extract_command(cmd_obj, flow_parameters):
316+
def extract_command(
317+
cmd_obj: click.Command, flow_parameters: List[Parameter]
318+
) -> Callable:
306319
if getattr(cmd_obj, "has_flow_params", False):
307320
for p in flow_parameters[::-1]:
308321
cmd_obj.params.insert(0, click.Option(("--" + p.name,), **p.kwargs))

0 commit comments

Comments
 (0)