diff --git a/luigi/parameter.py b/luigi/parameter.py index 9f1ae104d3..bc06127721 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -142,6 +142,7 @@ class UnconsumedParameterWarning(UserWarning): T = TypeVar("T", default=str) +_OptT = TypeVar("_OptT") class ConfigPath(TypedDict): @@ -426,21 +427,31 @@ def _parser_kwargs(cls, param_name, task_name=None): } -class OptionalParameterMixin: +class OptionalParameterMixin(Generic[_OptT]): """ Mixin to make a parameter class optional and treat empty string as None. """ - expected_type = type(None) + expected_type: type = type(None) + + def __init__( + self, + default: Union[_OptT, None, _NoValueType] = _no_value, + **kwargs: Unpack[_ParameterKwargs], + ): + super().__init__(default=default, **kwargs) # type: ignore[arg-type, call-arg, misc] @overload - def __get__(self: "Parameter[T]", instance: None, owner: Any) -> "Parameter[Optional[T]]": ... + def __get__(self, instance: None, owner: Any) -> "Parameter[Optional[_OptT]]": ... @overload - def __get__(self: "Parameter[T]", instance: Any, owner: Any) -> Optional[T]: ... + def __get__(self, instance: Any, owner: Any) -> Optional[_OptT]: ... def __get__(self, instance: Any, owner: Any) -> Any: - return super().__get__(instance, owner) + return super().__get__(instance, owner) # type: ignore[misc] + + def __set__(self, instance: Any, value: Optional[_OptT]): + super().__set__(instance, value) # type: ignore[misc] def serialize(self, x): """ @@ -485,13 +496,13 @@ def next_in_enumeration(self, value): return None -class OptionalParameter(OptionalParameterMixin, Parameter[Optional[str]]): +class OptionalParameter(OptionalParameterMixin[str], Parameter[Optional[str]]): """Class to parse optional parameters.""" expected_type = str -class OptionalStrParameter(OptionalParameterMixin, Parameter[Optional[str]]): +class OptionalStrParameter(OptionalParameterMixin[str], Parameter[Optional[str]]): """Class to parse optional str parameters.""" expected_type = str @@ -798,7 +809,7 @@ def next_in_enumeration(self, value): return value + 1 -class OptionalIntParameter(OptionalParameterMixin, IntParameter): +class OptionalIntParameter(OptionalParameterMixin[int], IntParameter): # type: ignore[misc] """Class to parse optional int parameters.""" expected_type = int @@ -816,7 +827,7 @@ def parse(self, x): return float(x) -class OptionalFloatParameter(OptionalParameterMixin, FloatParameter): +class OptionalFloatParameter(OptionalParameterMixin[float], FloatParameter): # type: ignore[misc] """Class to parse optional float parameters.""" expected_type = float @@ -897,7 +908,7 @@ def _parser_kwargs(self, *args, **kwargs): return parser_kwargs -class OptionalBoolParameter(OptionalParameterMixin, BoolParameter): +class OptionalBoolParameter(OptionalParameterMixin[bool], BoolParameter): # type: ignore[misc] """Class to parse optional bool parameters.""" expected_type = bool @@ -1299,7 +1310,7 @@ def serialize(self, x): return json.dumps(x, cls=_DictParamEncoder) -class OptionalDictParameter(OptionalParameterMixin, DictParameter): +class OptionalDictParameter(OptionalParameterMixin[FrozenOrderedDict], DictParameter): # type: ignore[misc] """Class to parse optional dict parameters.""" expected_type = FrozenOrderedDict @@ -1454,7 +1465,7 @@ def serialize(self, x): return json.dumps(x, cls=_DictParamEncoder) -class OptionalListParameter(OptionalParameterMixin, ListParameter): +class OptionalListParameter(OptionalParameterMixin[ListT], ListParameter): # type: ignore[misc] """Class to parse optional list parameters.""" expected_type = tuple @@ -1525,7 +1536,7 @@ def _convert_iterable_to_tuple(self, x): return tuple(x) -class OptionalTupleParameter(OptionalParameterMixin, TupleParameter): +class OptionalTupleParameter(OptionalParameterMixin[ListT], TupleParameter): # type: ignore[misc] """Class to parse optional tuple parameters.""" expected_type = tuple @@ -1588,13 +1599,13 @@ def __init__( """ if var_type is None: raise ParameterException("var_type must be specified") - self._var_type = var_type + self._var_type: Type[NumericalType] = var_type if min_value is None: raise ParameterException("min_value must be specified") - self._min_value = min_value + self._min_value: NumericalType = min_value if max_value is None: raise ParameterException("max_value must be specified") - self._max_value = max_value + self._max_value: NumericalType = max_value self._left_op = left_op self._right_op = right_op self._permitted_range = "{var_type} in {left_endpoint}{min_value}, {max_value}{right_endpoint}".format( @@ -1604,7 +1615,7 @@ def __init__( left_endpoint="[" if left_op == operator.le else "(", right_endpoint=")" if right_op == operator.lt else "]", ) - super().__init__(default=default, **kwargs) + super().__init__(default=default, **kwargs) # type: ignore[arg-type] if self.description: self.description += " " else: @@ -1619,7 +1630,7 @@ def parse(self, x): raise ValueError("{s} is not in the set of {permitted_range}".format(s=x, permitted_range=self._permitted_range)) -class OptionalNumericalParameter(OptionalParameterMixin, NumericalParameter): +class OptionalNumericalParameter(OptionalParameterMixin[NumericalType], NumericalParameter[NumericalType]): # type: ignore[misc] """Class to parse optional numerical parameters.""" def __init__( @@ -1627,7 +1638,7 @@ def __init__( default: Union[Optional[NumericalType], _NoValueType] = _no_value, **kwargs: Unpack[_ParameterKwargs], ): - super().__init__(default=default, **kwargs) + NumericalParameter.__init__(self, default=default, **kwargs) # type: ignore[arg-type, misc] self.expected_type = self._var_type @@ -1664,7 +1675,7 @@ def __init__( default: Union[ChoiceType, _NoValueType] = _no_value, *, choices: Optional[Sequence[ChoiceType]] = None, - var_type: Type[ChoiceType] = str, + var_type: Type[ChoiceType] = str, # type: ignore[assignment] **kwargs: Unpack[_ParameterKwargs], ): """ @@ -1726,7 +1737,7 @@ class MyTask(luigi.Task): _sep = "," - @overload + @overload # type: ignore[override] def __get__(self, instance: None, owner: Any) -> "Parameter[Tuple[ChoiceType, ...]]": ... @overload @@ -1738,7 +1749,7 @@ def __get__(self, instance: Any, owner: Any) -> Any: def __init__( self, default: Union[Tuple[ChoiceType, ...], _NoValueType] = _no_value, - var_type: Type[ChoiceType] = str, + var_type: Type[ChoiceType] = str, # type: ignore[assignment] choices: Optional[Sequence[ChoiceType]] = None, **kwargs: Unpack[_ParameterKwargs], ): @@ -1758,17 +1769,17 @@ def serialize(self, x): return self._sep.join(x) -class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter[ChoiceType]): +class OptionalChoiceParameter(OptionalParameterMixin[ChoiceType], ChoiceParameter[ChoiceType]): # type: ignore[misc] """Class to parse optional choice parameters.""" def __init__( self, default: Union[Optional[ChoiceType], _NoValueType] = _no_value, - var_type: Type[ChoiceType] = str, + var_type: Type[ChoiceType] = str, # type: ignore[assignment] choices: Optional[Sequence[ChoiceType]] = None, **kwargs: Unpack[_ParameterKwargs], ): - super().__init__(default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type] + ChoiceParameter.__init__(self, default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type, misc] self.expected_type = self._var_type @@ -1831,7 +1842,7 @@ def normalize(self, x): return path -class OptionalPathParameter(OptionalParameter, PathParameter): +class OptionalPathParameter(OptionalParameter, PathParameter): # type: ignore[misc] """Class to parse optional path parameters.""" - expected_type = (str, Path) # type: ignore + expected_type = (str, Path) # type: ignore[assignment] diff --git a/pyproject.toml b/pyproject.toml index 36770fd7eb..a5e2b31275 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,7 +191,6 @@ module = [ "luigi.contrib.sqla", "luigi.interface", "luigi.notifications", - "luigi.parameter", "luigi.tools.range", "luigi.worker", ]