7
7
from __future__ import annotations
8
8
9
9
import re
10
+ import sys
11
+ import warnings
10
12
from collections .abc import Iterator
11
- from typing import Callable , Final , Literal
13
+ from dataclasses import dataclass
14
+ from enum import Enum
15
+ from typing import Any , Callable , Final , Literal
12
16
13
17
import luigi
14
18
from mypy .expandtype import expand_type
15
19
from mypy .nodes import (
20
+ ARG_NAMED ,
16
21
ARG_NAMED_OPT ,
22
+ ArgKind ,
17
23
Argument ,
18
24
AssignmentStmt ,
19
25
Block ,
32
38
TypeInfo ,
33
39
Var ,
34
40
)
41
+ from mypy .options import Options
35
42
from mypy .plugin import ClassDefContext , FunctionContext , Plugin , SemanticAnalyzerPluginInterface
36
43
from mypy .plugins .common import (
37
44
add_method_to_class ,
56
63
PARAMETER_TMP_MATCHER : Final = re .compile (r'^\w*Parameter$' )
57
64
58
65
66
+ class PluginOptions (Enum ):
67
+ DISALLOW_MISSING_PARAMETERS = 'disallow_missing_parameters'
68
+
69
+
70
+ @dataclass
71
+ class TaskOnKartPluginOptions :
72
+ # Whether to error on missing parameters in the constructor.
73
+ # Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
74
+ disallow_missing_parameters : bool = False
75
+
76
+ @classmethod
77
+ def _parse_toml (cls , config_file : str ) -> dict [str , Any ]:
78
+ if sys .version_info >= (3 , 11 ):
79
+ import tomllib as toml_
80
+ else :
81
+ try :
82
+ import tomli as toml_
83
+ except ImportError : # pragma: no cover
84
+ warnings .warn ('install tomli to parse pyproject.toml under Python 3.10' , stacklevel = 1 )
85
+ return {}
86
+
87
+ with open (config_file , 'rb' ) as f :
88
+ return toml_ .load (f )
89
+
90
+ @classmethod
91
+ def parse_config_file (cls , config_file : str ) -> TaskOnKartPluginOptions :
92
+ # TODO: support other configuration file formats if necessary.
93
+ if not config_file .endswith ('.toml' ):
94
+ warnings .warn ('gokart mypy plugin can be configured by pyproject.toml' , stacklevel = 1 )
95
+ return cls ()
96
+
97
+ config = cls ._parse_toml (config_file )
98
+ gokart_plugin_config = config .get ('tool' , {}).get ('gokart-mypy' , {})
99
+
100
+ disallow_missing_parameters = gokart_plugin_config .get (PluginOptions .DISALLOW_MISSING_PARAMETERS .value , False )
101
+ if not isinstance (disallow_missing_parameters , bool ):
102
+ raise ValueError (f'{ PluginOptions .DISALLOW_MISSING_PARAMETERS .value } must be a boolean value' )
103
+ return cls (disallow_missing_parameters = disallow_missing_parameters )
104
+
105
+
59
106
class TaskOnKartPlugin (Plugin ):
107
+ def __init__ (self , options : Options ) -> None :
108
+ super ().__init__ (options )
109
+ if options .config_file is not None :
110
+ self ._options = TaskOnKartPluginOptions .parse_config_file (options .config_file )
111
+ else :
112
+ self ._options = TaskOnKartPluginOptions ()
113
+
60
114
def get_base_class_hook (self , fullname : str ) -> Callable [[ClassDefContext ], None ] | None :
61
115
# The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
62
116
# the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
@@ -78,7 +132,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
78
132
return None
79
133
80
134
def _task_on_kart_class_maker_callback (self , ctx : ClassDefContext ) -> None :
81
- transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api )
135
+ transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api , self . _options )
82
136
transformer .transform ()
83
137
84
138
def _task_on_kart_parameter_field_callback (self , ctx : FunctionContext ) -> Type :
@@ -125,6 +179,7 @@ def __init__(
125
179
type : Type | None ,
126
180
info : TypeInfo ,
127
181
api : SemanticAnalyzerPluginInterface ,
182
+ options : TaskOnKartPluginOptions ,
128
183
) -> None :
129
184
self .name = name
130
185
self .has_default = has_default
@@ -133,12 +188,12 @@ def __init__(
133
188
self .type = type # Type as __init__ argument
134
189
self .info = info
135
190
self ._api = api
191
+ self ._options = options
136
192
137
193
def to_argument (self , current_info : TypeInfo , * , of : Literal ['__init__' ,]) -> Argument :
138
194
if of == '__init__' :
139
- # All arguments to __init__ are keyword-only and optional
140
- # This is because gokart can set parameters by configuration'
141
- arg_kind = ARG_NAMED_OPT
195
+ arg_kind = self ._get_arg_kind_by_options ()
196
+
142
197
return Argument (
143
198
variable = self .to_var (current_info ),
144
199
type_annotation = self .expand_type (current_info ),
@@ -170,10 +225,10 @@ def serialize(self) -> JsonDict:
170
225
}
171
226
172
227
@classmethod
173
- def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface ) -> TaskOnKartAttribute :
228
+ def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface , options : TaskOnKartPluginOptions ) -> TaskOnKartAttribute :
174
229
data = data .copy ()
175
230
typ = deserialize_and_fixup_type (data .pop ('type' ), api )
176
- return cls (type = typ , info = info , ** data , api = api )
231
+ return cls (type = typ , info = info , ** data , api = api , options = options )
177
232
178
233
def expand_typevar_from_subtype (self , sub_type : TypeInfo ) -> None :
179
234
"""Expands type vars in the context of a subtype when an attribute is inherited
@@ -182,6 +237,22 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
182
237
with state .strict_optional_set (self ._api .options .strict_optional ):
183
238
self .type = map_type_from_supertype (self .type , sub_type , self .info )
184
239
240
+ def _get_arg_kind_by_options (self ) -> Literal [ArgKind .ARG_NAMED , ArgKind .ARG_NAMED_OPT ]:
241
+ """Set the argument kind based on the options.
242
+
243
+ if `disallow_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value.
244
+ This means the that all the parameters are passed to the constructor as keyword-only arguments.
245
+
246
+ Returns:
247
+ Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
248
+ """
249
+ if not self ._options .disallow_missing_parameters :
250
+ return ARG_NAMED_OPT
251
+ if self .has_default :
252
+ return ARG_NAMED_OPT
253
+ # required parameter
254
+ return ARG_NAMED
255
+
185
256
186
257
class TaskOnKartTransformer :
187
258
"""Implement the behavior of gokart.TaskOnKart."""
@@ -191,10 +262,12 @@ def __init__(
191
262
cls : ClassDef ,
192
263
reason : Expression | Statement ,
193
264
api : SemanticAnalyzerPluginInterface ,
265
+ options : TaskOnKartPluginOptions ,
194
266
) -> None :
195
267
self ._cls = cls
196
268
self ._reason = reason
197
269
self ._api = api
270
+ self ._options = options
198
271
199
272
def transform (self ) -> bool :
200
273
"""Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
@@ -267,7 +340,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
267
340
for data in info .metadata [METADATA_TAG ]['attributes' ]:
268
341
name : str = data ['name' ]
269
342
270
- attr = TaskOnKartAttribute .deserialize (info , data , self ._api )
343
+ attr = TaskOnKartAttribute .deserialize (info , data , self ._api , self . _options )
271
344
# TODO: We shouldn't be performing type operations during the main
272
345
# semantic analysis pass, since some TypeInfo attributes might
273
346
# still be in flux. This should be performed in a later phase.
@@ -337,6 +410,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
337
410
type = init_type ,
338
411
info = cls .info ,
339
412
api = self ._api ,
413
+ options = self ._options ,
340
414
)
341
415
342
416
return list (found_attrs .values ())
0 commit comments