Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class UnconsumedParameterWarning(UserWarning):


T = TypeVar("T", default=str)
_OptT = TypeVar("_OptT")


class ConfigPath(TypedDict):
Expand Down Expand Up @@ -426,21 +427,31 @@ def _parser_kwargs(cls, param_name, task_name=None):
}


class OptionalParameterMixin:
class OptionalParameterMixin(Generic[_OptT]):
"""
Mixin to make a parameter class optional and treat empty string as None.
"""

expected_type = type(None)
expected_type: type = type(None)

def __init__(
self,
default: Union[_OptT, None, _NoValueType] = _no_value,
**kwargs: Unpack[_ParameterKwargs],
):
super().__init__(default=default, **kwargs) # type: ignore[arg-type, call-arg, misc]

@overload
def __get__(self: "Parameter[T]", instance: None, owner: Any) -> "Parameter[Optional[T]]": ...
def __get__(self, instance: None, owner: Any) -> "Parameter[Optional[_OptT]]": ...

@overload
def __get__(self: "Parameter[T]", instance: Any, owner: Any) -> Optional[T]: ...
def __get__(self, instance: Any, owner: Any) -> Optional[_OptT]: ...

def __get__(self, instance: Any, owner: Any) -> Any:
return super().__get__(instance, owner)
return super().__get__(instance, owner) # type: ignore[misc]

def __set__(self, instance: Any, value: Optional[_OptT]):
super().__set__(instance, value) # type: ignore[misc]

def serialize(self, x):
"""
Expand Down Expand Up @@ -485,13 +496,13 @@ def next_in_enumeration(self, value):
return None


class OptionalParameter(OptionalParameterMixin, Parameter[Optional[str]]):
class OptionalParameter(OptionalParameterMixin[str], Parameter[Optional[str]]):
"""Class to parse optional parameters."""

expected_type = str


class OptionalStrParameter(OptionalParameterMixin, Parameter[Optional[str]]):
class OptionalStrParameter(OptionalParameterMixin[str], Parameter[Optional[str]]):
"""Class to parse optional str parameters."""

expected_type = str
Expand Down Expand Up @@ -798,7 +809,7 @@ def next_in_enumeration(self, value):
return value + 1


class OptionalIntParameter(OptionalParameterMixin, IntParameter):
class OptionalIntParameter(OptionalParameterMixin[int], IntParameter): # type: ignore[misc]
"""Class to parse optional int parameters."""

expected_type = int
Expand All @@ -816,7 +827,7 @@ def parse(self, x):
return float(x)


class OptionalFloatParameter(OptionalParameterMixin, FloatParameter):
class OptionalFloatParameter(OptionalParameterMixin[float], FloatParameter): # type: ignore[misc]
"""Class to parse optional float parameters."""

expected_type = float
Expand Down Expand Up @@ -897,7 +908,7 @@ def _parser_kwargs(self, *args, **kwargs):
return parser_kwargs


class OptionalBoolParameter(OptionalParameterMixin, BoolParameter):
class OptionalBoolParameter(OptionalParameterMixin[bool], BoolParameter): # type: ignore[misc]
"""Class to parse optional bool parameters."""

expected_type = bool
Expand Down Expand Up @@ -1299,7 +1310,7 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalDictParameter(OptionalParameterMixin, DictParameter):
class OptionalDictParameter(OptionalParameterMixin[FrozenOrderedDict], DictParameter): # type: ignore[misc]
"""Class to parse optional dict parameters."""

expected_type = FrozenOrderedDict
Expand Down Expand Up @@ -1454,7 +1465,7 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalListParameter(OptionalParameterMixin, ListParameter):
class OptionalListParameter(OptionalParameterMixin[ListT], ListParameter): # type: ignore[misc]
"""Class to parse optional list parameters."""

expected_type = tuple
Expand Down Expand Up @@ -1525,7 +1536,7 @@ def _convert_iterable_to_tuple(self, x):
return tuple(x)


class OptionalTupleParameter(OptionalParameterMixin, TupleParameter):
class OptionalTupleParameter(OptionalParameterMixin[ListT], TupleParameter): # type: ignore[misc]
"""Class to parse optional tuple parameters."""

expected_type = tuple
Expand Down Expand Up @@ -1588,13 +1599,13 @@ def __init__(
"""
if var_type is None:
raise ParameterException("var_type must be specified")
self._var_type = var_type
self._var_type: Type[NumericalType] = var_type
if min_value is None:
raise ParameterException("min_value must be specified")
self._min_value = min_value
self._min_value: NumericalType = min_value
if max_value is None:
raise ParameterException("max_value must be specified")
self._max_value = max_value
self._max_value: NumericalType = max_value
self._left_op = left_op
self._right_op = right_op
self._permitted_range = "{var_type} in {left_endpoint}{min_value}, {max_value}{right_endpoint}".format(
Expand All @@ -1604,7 +1615,7 @@ def __init__(
left_endpoint="[" if left_op == operator.le else "(",
right_endpoint=")" if right_op == operator.lt else "]",
)
super().__init__(default=default, **kwargs)
super().__init__(default=default, **kwargs) # type: ignore[arg-type]
if self.description:
self.description += " "
else:
Expand All @@ -1619,15 +1630,15 @@ def parse(self, x):
raise ValueError("{s} is not in the set of {permitted_range}".format(s=x, permitted_range=self._permitted_range))


class OptionalNumericalParameter(OptionalParameterMixin, NumericalParameter):
class OptionalNumericalParameter(OptionalParameterMixin[NumericalType], NumericalParameter[NumericalType]): # type: ignore[misc]
"""Class to parse optional numerical parameters."""

def __init__(
self,
default: Union[Optional[NumericalType], _NoValueType] = _no_value,
**kwargs: Unpack[_ParameterKwargs],
):
super().__init__(default=default, **kwargs)
NumericalParameter.__init__(self, default=default, **kwargs) # type: ignore[arg-type, misc]
self.expected_type = self._var_type


Expand Down Expand Up @@ -1664,7 +1675,7 @@ def __init__(
default: Union[ChoiceType, _NoValueType] = _no_value,
*,
choices: Optional[Sequence[ChoiceType]] = None,
var_type: Type[ChoiceType] = str,
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
**kwargs: Unpack[_ParameterKwargs],
):
"""
Expand Down Expand Up @@ -1726,7 +1737,7 @@ class MyTask(luigi.Task):

_sep = ","

@overload
@overload # type: ignore[override]
def __get__(self, instance: None, owner: Any) -> "Parameter[Tuple[ChoiceType, ...]]": ...

@overload
Expand All @@ -1738,7 +1749,7 @@ def __get__(self, instance: Any, owner: Any) -> Any:
def __init__(
self,
default: Union[Tuple[ChoiceType, ...], _NoValueType] = _no_value,
var_type: Type[ChoiceType] = str,
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
choices: Optional[Sequence[ChoiceType]] = None,
**kwargs: Unpack[_ParameterKwargs],
):
Expand All @@ -1758,17 +1769,17 @@ def serialize(self, x):
return self._sep.join(x)


class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter[ChoiceType]):
class OptionalChoiceParameter(OptionalParameterMixin[ChoiceType], ChoiceParameter[ChoiceType]): # type: ignore[misc]
"""Class to parse optional choice parameters."""

def __init__(
self,
default: Union[Optional[ChoiceType], _NoValueType] = _no_value,
var_type: Type[ChoiceType] = str,
var_type: Type[ChoiceType] = str, # type: ignore[assignment]
choices: Optional[Sequence[ChoiceType]] = None,
**kwargs: Unpack[_ParameterKwargs],
):
super().__init__(default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type]
ChoiceParameter.__init__(self, default=default, var_type=var_type, choices=choices, **kwargs) # type: ignore[arg-type, misc]
self.expected_type = self._var_type


Expand Down Expand Up @@ -1831,7 +1842,7 @@ def normalize(self, x):
return path


class OptionalPathParameter(OptionalParameter, PathParameter):
class OptionalPathParameter(OptionalParameter, PathParameter): # type: ignore[misc]
"""Class to parse optional path parameters."""

expected_type = (str, Path) # type: ignore
expected_type = (str, Path) # type: ignore[assignment]
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ module = [
"luigi.contrib.sqla",
"luigi.interface",
"luigi.notifications",
"luigi.parameter",
"luigi.tools.range",
"luigi.worker",
]
Expand Down
Loading