Skip to content

feat: add mypy plugin options to handle missing parameters for a task #428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/mypy_plugin.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,17 @@ Mypy plugin checks TaskOnKart generic types.
str_task=StrTask(), # mypy ok
int_task=StrTask(), # mypy error: Argument "int_task" to "StrTask" has incompatible type "StrTask"; expected "TaskOnKart[int]"
)

Configurations (only pyproject.toml)
------------------------------------

You can configure the Mypy plugin using the ``pyproject.toml`` file.
The following options are available:

.. code:: toml

[tool.gokart-mypy]
# If true, Mypy will raise an error if a task is missing required parameters.
# Note: with this enabled, parameters that are meant to be set via `luigi.Config()`
# (and therefore not passed to the constructor) are also reported as missing.
# Default: false
disallow_missing_parameters = true
90 changes: 82 additions & 8 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
from __future__ import annotations

import re
import sys
import warnings
from collections.abc import Iterator
from typing import Callable, Final, Literal
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Final, Literal

import luigi
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
ArgKind,
Argument,
AssignmentStmt,
Block,
Expand All @@ -32,6 +38,7 @@
TypeInfo,
Var,
)
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
add_method_to_class,
Expand All @@ -56,7 +63,54 @@
PARAMETER_TMP_MATCHER: Final = re.compile(r'^\w*Parameter$')


class PluginOptions(Enum):
    """Option keys recognised under ``[tool.gokart-mypy]`` in ``pyproject.toml``."""

    # Key of the boolean switch that is loaded into
    # ``TaskOnKartPluginOptions.disallow_missing_parameters``.
    DISALLOW_MISSING_PARAMETERS = 'disallow_missing_parameters'


@dataclass
class TaskOnKartPluginOptions:
    """Options of the gokart mypy plugin, read from ``[tool.gokart-mypy]`` in a TOML config file."""

    # Whether to error on missing parameters in the constructor.
    # Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
    disallow_missing_parameters: bool = False

    @classmethod
    def _parse_toml(cls, config_file: str) -> dict[str, Any]:
        """Load ``config_file`` as TOML.

        Returns an empty dict (after emitting a warning) when no TOML parser
        can be imported, so that the plugin falls back to its defaults.
        """
        if sys.version_info >= (3, 11):
            import tomllib as toml_
        else:
            try:
                import tomli as toml_
            except ImportError:  # pragma: no cover
                # tomllib only exists from Python 3.11, so every older interpreter needs tomli.
                warnings.warn('install tomli to parse pyproject.toml on Python 3.10 or earlier', stacklevel=1)
                return {}

        with open(config_file, 'rb') as f:
            return toml_.load(f)

    @classmethod
    def parse_config_file(cls, config_file: str) -> TaskOnKartPluginOptions:
        """Build the plugin options from the config file mypy was started with.

        Args:
            config_file: Path of the mypy configuration file.

        Returns:
            The parsed options. Default options are returned (with a warning)
            when ``config_file`` is not a ``.toml`` file.

        Raises:
            ValueError: If ``tool`` or ``tool.gokart-mypy`` is not a table, or
                if an option has a value of the wrong type.
        """
        # TODO: support other configuration file formats if necessary.
        if not config_file.endswith('.toml'):
            warnings.warn('gokart mypy plugin can be configured by pyproject.toml', stacklevel=1)
            return cls()

        config = cls._parse_toml(config_file)

        # Validate the nesting explicitly: a scalar where a table is expected would
        # otherwise surface as an opaque AttributeError on ``.get``.
        tool_table = config.get('tool', {})
        if not isinstance(tool_table, dict):
            raise ValueError('[tool] must be a table')
        gokart_plugin_config = tool_table.get('gokart-mypy', {})
        if not isinstance(gokart_plugin_config, dict):
            raise ValueError('[tool.gokart-mypy] must be a table')

        option_name = PluginOptions.DISALLOW_MISSING_PARAMETERS.value
        disallow_missing_parameters = gokart_plugin_config.get(option_name, False)
        if not isinstance(disallow_missing_parameters, bool):
            raise ValueError(f'{option_name} must be a boolean value')
        return cls(disallow_missing_parameters=disallow_missing_parameters)


class TaskOnKartPlugin(Plugin):
def __init__(self, options: Options) -> None:
    """Initialise the plugin and load the gokart-specific options.

    The options are parsed from the config file recorded in ``options``;
    when there is no config file the default options are used.
    """
    super().__init__(options)
    config_file = options.config_file
    self._options = TaskOnKartPluginOptions() if config_file is None else TaskOnKartPluginOptions.parse_config_file(config_file)

def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
# The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
# the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
Expand All @@ -78,7 +132,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return None

def _task_on_kart_class_maker_callback(self, ctx: ClassDefContext) -> None:
    """Run the TaskOnKart class transformation for the class described by ``ctx``.

    The plugin options are forwarded to the transformer.
    """
    transformer = TaskOnKartTransformer(ctx.cls, ctx.reason, ctx.api, self._options)
    transformer.transform()

def _task_on_kart_parameter_field_callback(self, ctx: FunctionContext) -> Type:
Expand Down Expand Up @@ -125,6 +179,7 @@ def __init__(
type: Type | None,
info: TypeInfo,
api: SemanticAnalyzerPluginInterface,
options: TaskOnKartPluginOptions,
) -> None:
self.name = name
self.has_default = has_default
Expand All @@ -133,12 +188,12 @@ def __init__(
self.type = type # Type as __init__ argument
self.info = info
self._api = api
self._options = options

def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Argument:
if of == '__init__':
# All arguments to __init__ are keyword-only and optional
# This is because gokart can set parameters by configuration'
arg_kind = ARG_NAMED_OPT
arg_kind = self._get_arg_kind_by_options()

return Argument(
variable=self.to_var(current_info),
type_annotation=self.expand_type(current_info),
Expand Down Expand Up @@ -170,10 +225,10 @@ def serialize(self) -> JsonDict:
}

@classmethod
def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions) -> TaskOnKartAttribute:
    """Rebuild an attribute from its serialized form.

    ``data`` is copied first so that popping ``'type'`` does not mutate the
    caller's dict; the remaining entries are passed to the constructor as
    keyword arguments together with ``info``, ``api`` and ``options``.
    """
    data = data.copy()
    typ = deserialize_and_fixup_type(data.pop('type'), api)
    return cls(type=typ, info=info, **data, api=api, options=options)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
Expand All @@ -182,6 +237,22 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
with state.strict_optional_set(self._api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)

def _get_arg_kind_by_options(self) -> Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]:
    """Choose the ``__init__`` argument kind for this attribute from the plugin options.

    The argument is a required keyword-only argument (``ARG_NAMED``) only when
    `disallow_missing_parameters` is True and the attribute has no default value.
    In every other case it is an optional keyword-only argument (``ARG_NAMED_OPT``).

    Returns:
        Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
    """
    is_required = self._options.disallow_missing_parameters and not self.has_default
    return ARG_NAMED if is_required else ARG_NAMED_OPT


class TaskOnKartTransformer:
"""Implement the behavior of gokart.TaskOnKart."""
Expand All @@ -191,10 +262,12 @@ def __init__(
cls: ClassDef,
reason: Expression | Statement,
api: SemanticAnalyzerPluginInterface,
options: TaskOnKartPluginOptions,
) -> None:
self._cls = cls
self._reason = reason
self._api = api
self._options = options

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
Expand Down Expand Up @@ -267,7 +340,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
for data in info.metadata[METADATA_TAG]['attributes']:
name: str = data['name']

attr = TaskOnKartAttribute.deserialize(info, data, self._api)
attr = TaskOnKartAttribute.deserialize(info, data, self._api, self._options)
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
Expand Down Expand Up @@ -337,6 +410,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
type=init_type,
info=cls.info,
api=self._api,
options=self._options,
)

return list(found_attrs.values())
Expand Down
1 change: 1 addition & 0 deletions test/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@

CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve()
PYPROJECT_TOML: Final[Path] = CONFIG_DIR / 'pyproject.toml'
PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR / 'pyproject_disallow_missing_parameters.toml'
TEST_CONFIG_INI: Final[Path] = CONFIG_DIR / 'test_config.ini'
2 changes: 1 addition & 1 deletion test/config/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.mypy]
plugins = ["gokart.mypy:plugin"]
plugins = ["gokart.mypy"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
Expand Down
9 changes: 9 additions & 0 deletions test/config/pyproject_disallow_missing_parameters.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[tool.mypy]
plugins = ["gokart.mypy"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]

[tool.gokart-mypy]
disallow_missing_parameters = true
58 changes: 53 additions & 5 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from mypy import api

from test.config import PYPROJECT_TOML
from test.config import PYPROJECT_TOML, PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS


class TestMyMypyPlugin(unittest.TestCase):
Expand All @@ -16,7 +16,6 @@ def test_plugin_no_issue(self):


class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo: int = luigi.IntParameter() # type: ignore
bar: str = luigi.Parameter() # type: ignore
baz: bool = gokart.ExplicitBoolParameter()
Expand Down Expand Up @@ -44,7 +43,6 @@ def test_plugin_invalid_arg(self):


class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo: int = luigi.IntParameter() # type: ignore
bar: str = luigi.Parameter() # type: ignore
baz: bool = gokart.ExplicitBoolParameter()
Expand Down Expand Up @@ -79,7 +77,6 @@ class MyEnum(enum.Enum):
FOO = enum.auto()

class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
Expand Down Expand Up @@ -110,7 +107,6 @@ def test_parameter_has_default_type_no_issue_pattern(self):
import gokart

class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
Expand All @@ -122,3 +118,55 @@ class MyTask(gokart.TaskOnKart):
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('Success: no issues found', result[0])

def test_no_issue_found_when_missing_parameter_when_default_option(self):
    """
    If `disallow_missing_parameters` is False (or default), mypy doesn't show any error when missing parameters.
    """
    # The snippet is written flush-left because it is handed to mypy verbatim;
    # the parameters must stay indented so that they belong to the class body.
    test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    foo = luigi.IntParameter()
    bar = luigi.Parameter(default="bar")

MyTask()
"""
    with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
        test_file.write(test_code.encode('utf-8'))
        test_file.flush()
        result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
        self.assertIn('Success: no issues found', result[0])

def test_issue_found_when_missing_parameter_when_disallow_missing_parameters_set_true(self):
    """
    If `disallow_missing_parameters` is True, mypy shows an error when missing parameters.
    """
    # The snippet is written flush-left because it is handed to mypy verbatim;
    # the parameters must stay indented so that they belong to the class body.
    test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    # issue: foo is missing
    foo = luigi.IntParameter()
    # bar has default value, so it is not required to set it.
    bar = luigi.Parameter(default="bar")

MyTask()
"""
    with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
        test_file.write(test_code.encode('utf-8'))
        test_file.flush()
        result = api.run(
            [
                '--show-traceback',
                '--no-incremental',
                '--cache-dir=/dev/null',
                '--config-file',
                str(PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS),
                test_file.name,
            ]
        )
        self.assertIn('error: Missing named argument "foo" for "MyTask"  [call-arg]', result[0])
        self.assertIn('Found 1 error in 1 file (checked 1 source file)', result[0])