Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 130 additions & 13 deletions pandera/api/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ def less_than_or_equal_to(cls, max_value: Any, **kwargs) -> "Check":
@classmethod
def in_range(
cls,
min_value: T,
max_value: T,
*args,
min_value: T | None = None,
max_value: T | None = None,
include_min: bool = True,
include_max: bool = True,
**kwargs,
Expand All @@ -353,6 +354,12 @@ def in_range(
Both endpoints must be a type comparable to the dtype of the
data object to be validated.

:param args: Positional arguments. Two values are interpreted as
min_value and max_value; a third and fourth value, if present, are
interpreted as include_min and include_max respectively. A single
value is only accepted together with one keyword endpoint, in which
case it fills whichever of min_value / max_value was not passed by
keyword; a single value with no keyword endpoint raises TypeError.
:param min_value: Left / lower endpoint of the interval.
:param max_value: Right / upper endpoint of the interval. Must not be
smaller than min_value.
Expand All @@ -362,7 +369,63 @@ def in_range(
:param include_max: Defines whether max_value is also an allowed value
(the default) or whether all values must be strictly smaller than
max_value.

:example:

>>> import pandera as pa
>>>
>>> positional_check = pa.Check.in_range(0, 1)
>>> positional_include_min_check = pa.Check.in_range(0, 1, True)
>>> positional_include_min_max_check = pa.Check.in_range(0, 1, True, True)
>>> keyword_check = pa.Check.in_range(min_value=0, max_value=1)
>>> keyword_include_min_check = pa.Check.in_range(min_value=0, max_value=1, include_min=True)
>>> keyword_include_min_max_check = pa.Check.in_range(min_value=0, max_value=1, include_min=True, include_max=True)
"""
# Handle positional arguments for backward compatibility
# in_range(0, 1) or in_range(0, 1, True, False) should work
# Track whether values were provided (vs being default None)
min_value_provided = min_value is not None
max_value_provided = max_value is not None

if len(args) >= 2:
min_value = args[0]
max_value = args[1]
min_value_provided = True
max_value_provided = True
elif len(args) == 1:
# If only one positional arg is provided without keyword args,
# raise TypeError to match original behavior
if not min_value_provided and not max_value_provided:
raise TypeError(
"in_range() missing required argument: 'max_value'"
)
# One positional arg with one keyword arg
if not min_value_provided:
min_value = args[0]
min_value_provided = True
elif not max_value_provided:
max_value = args[0]
max_value_provided = True
if len(args) >= 3:
include_min = args[2]
if len(args) >= 4:
include_max = args[3]

# Check for missing required arguments
if not min_value_provided and not max_value_provided:
raise TypeError(
"in_range() missing required arguments: 'min_value' and 'max_value'"
)
if not min_value_provided:
raise TypeError(
"in_range() missing required argument: 'min_value'"
)
if not max_value_provided:
raise TypeError(
"in_range() missing required argument: 'max_value'"
)

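# Check for invalid None values. Only a None passed positionally can
# reach this point: a None passed by keyword is indistinguishable from
# the default and is reported above as a missing argument (TypeError).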
if min_value is None:
raise ValueError("min_value must not be None")
if max_value is None:
Expand All @@ -386,7 +449,9 @@ def in_range(
)

@classmethod
def isin(cls, allowed_values: Iterable, **kwargs) -> "Check":
def isin(
cls, *args, allowed_values: Iterable | None = None, **kwargs
) -> "Check":
"""Ensure only allowed values occur within a series.

This checks whether all elements of a data object
Expand All @@ -396,51 +461,103 @@ def isin(cls, allowed_values: Iterable, **kwargs) -> "Check":
in allowed_values at least once can meet this condition. If you
want to check for substrings use :meth:`Check.str_contains`.

:param args: Positional arguments. If a single iterable (e.g. a list,
tuple, or string) is provided, it is used as the collection of allowed
values. Otherwise all positional values together form the allowed
values. Ignored when allowed_values is passed by keyword.
:param allowed_values: The set of allowed values. May be any iterable.
:param kwargs: key-word arguments passed into the `Check` initializer.

:example:

>>> import pandera as pa
>>>
>>> positional_check = pa.Check.isin([1, 2, 3])
>>> positional_values_check = pa.Check.isin(1, 2, 3)
>>> keyword_check = pa.Check.isin(allowed_values=[1, 2, 3])
>>> keyword_values_check = pa.Check.isin(allowed_values=[1, 2, 3])
"""
values: Iterable
if allowed_values is not None:
values = allowed_values
elif len(args) == 1 and hasattr(args[0], "__iter__"):
# Single iterable passed as positional arg (including strings)
values = args[0]
elif args:
# Multiple values passed as positional args
values = args
else:
raise ValueError(
"Argument allowed_values must be provided. "
"Use Check.isin([1, 2, 3]) or Check.isin(allowed_values=[1, 2, 3])"
)
try:
allowed_values_mod = frozenset(allowed_values)
allowed_values_mod = frozenset(values)
except TypeError as exc:
raise ValueError(
f"Argument allowed_values must be iterable. Got {allowed_values}"
f"Argument allowed_values must be iterable. Got {values}"
) from exc
return cls.from_builtin_check_name(
"isin",
kwargs,
error=f"isin({allowed_values})",
error=f"isin({values})",
defaults={"determined_by_unique": True},
statistics={"allowed_values": allowed_values},
statistics={"allowed_values": values},
allowed_values=allowed_values_mod,
)

@classmethod
def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check":
def notin(
cls, *args, forbidden_values: Iterable | None = None, **kwargs
) -> "Check":
"""Ensure some defined values don't occur within a series.

Like :meth:`Check.isin` this check operates on single characters if
it is applied on strings. If forbidden_values is a string, it is
understood as set of prohibited characters. Any string of length > 1
can't be in it by design.

:param args: Positional arguments. If a single iterable (e.g. a list,
tuple, or string) is provided, it is used as the collection of forbidden
values. Otherwise all positional values together form the forbidden
values. Ignored when forbidden_values is passed by keyword.
:param forbidden_values: The set of values which should not occur. May
be any iterable.
:param raise_warning: if True, check raises SchemaWarning instead of
SchemaError on validation.

:example:

>>> import pandera as pa
>>>
>>> positional_check = pa.Check.notin([1, 2, 3])
>>> positional_values_check = pa.Check.notin(1, 2, 3)
>>> keyword_check = pa.Check.notin(forbidden_values=[1, 2, 3])
"""
values: Iterable
if forbidden_values is not None:
values = forbidden_values
elif len(args) == 1 and hasattr(args[0], "__iter__"):
# Single iterable passed as positional arg (including strings)
values = args[0]
elif args:
# Multiple values passed as positional args
values = args
else:
raise ValueError(
"Argument forbidden_values must be provided. "
"Use Check.notin([1, 2, 3]) or Check.notin(forbidden_values=[1, 2, 3])"
)
try:
forbidden_values_mod = frozenset(forbidden_values)
forbidden_values_mod = frozenset(values)
except TypeError as exc:
raise ValueError(
"Argument forbidden_values must be iterable. "
f"Got {forbidden_values}"
f"Argument forbidden_values must be iterable. Got {values}"
) from exc
return cls.from_builtin_check_name(
"notin",
kwargs,
error=f"notin({forbidden_values})",
error=f"notin({values})",
defaults={"determined_by_unique": True},
statistics={"forbidden_values": forbidden_values},
statistics={"forbidden_values": values},
forbidden_values=forbidden_values_mod,
)

Expand Down
22 changes: 15 additions & 7 deletions pandera/api/dataframe/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Callable, Iterable
from typing import (
Any,
Optional,
Union,
cast,
)
Expand Down Expand Up @@ -119,16 +118,25 @@ def Field(
ge: Any | None = None,
lt: Any | None = None,
le: Any | None = None,
in_range: dict[str, Any] | None = None,
in_range: Union[
tuple[Any, Any],
tuple[Any, Any, bool, bool],
tuple[Any, Any, bool, bool, bool],
tuple[Any, Any, bool, bool, bool, bool],
dict[str, Any],
None,
] = None,
isin: Iterable[Any] | None = None,
notin: Iterable[Any] | None = None,
str_contains: str | None = None,
str_endswith: str | None = None,
str_length: int
| tuple[int]
| tuple[int, int]
| dict[str, int]
| None = None,
str_length: Union[
int,
tuple[int],
tuple[int, int],
dict[str, int],
None,
] = None,
str_matches: str | None = None,
str_startswith: str | None = None,
nullable: bool = False,
Expand Down
23 changes: 16 additions & 7 deletions pandera/api/pyspark/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from collections.abc import Callable, Iterable
from typing import (
Any,
Optional,
TypeVar,
Union,
)

from pandera.api.dataframe.components import ComponentSchema
Expand All @@ -23,16 +23,25 @@ def Field(
ge: Any = None,
lt: Any = None,
le: Any = None,
in_range: dict[str, Any] | None = None,
in_range: Union[
tuple[Any, Any],
tuple[Any, Any, bool, bool],
tuple[Any, Any, bool, bool, bool],
tuple[Any, Any, bool, bool, bool, bool],
dict[str, Any],
None,
] = None,
isin: Iterable[Any] | None = None,
notin: Iterable[Any] | None = None,
str_contains: str | None = None,
str_endswith: str | None = None,
str_length: int
| tuple[int]
| tuple[int, int]
| dict[str, int]
| None = None,
str_length: Union[
int,
tuple[int],
tuple[int, int],
dict[str, int],
None,
] = None,
str_matches: str | None = None,
str_startswith: str | None = None,
nullable: bool = False,
Expand Down
Loading
Loading