Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 225 additions & 115 deletions pandera/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import sys
import types
from collections.abc import Callable, Iterable
from typing import (
from typing import ( # noqa
Any,
Dict,
List,
NoReturn,
Optional,
Tuple,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -589,139 +592,246 @@ def check_types(
inplace=inplace,
)

class _AnnotationInfoWithDataFrameModelTree:
def __init__(
self,
annotation_info: AnnotationInfo,
children: list["_AnnotationInfoWithDataFrameModelTree"]
| None = None,
dataframe_model: DataFrameModel | None = None,
) -> None:
if children and dataframe_model:
raise ValueError(
"At most one of children or dataframe_model should be set"
)
self._children = children
self._dataframe_model = dataframe_model
self._annotation_info = annotation_info

def __repr__(self) -> str:
return f"_AnnotationInfoWithDataFrameModelTree(annotation_info={self._annotation_info}, children={self._children}, dataframe_model={self._dataframe_model})"

@property
def annotation_info(self) -> AnnotationInfo:
return self._annotation_info

@property
def dataframe_model(self) -> DataFrameModel | None:
return self._dataframe_model

@property
def children(
self,
) -> list["_AnnotationInfoWithDataFrameModelTree"] | None:
return self._children

def child_at_index(
self, index: int
) -> Union["_AnnotationInfoWithDataFrameModelTree", None]:
"""
Returns the child at the given index, if it exists. Otherwise None.
"""
if self.children and len(self.children) > index:
return self.children[index]
else:
return None

@staticmethod
def from_annotation(
annotation: type,
) -> "_AnnotationInfoWithDataFrameModelTree":
annotation_info = AnnotationInfo(annotation)
if annotation_info.is_generic_df:
# Base condition
return _AnnotationInfoWithDataFrameModelTree(
annotation_info=annotation_info,
dataframe_model=cast(DataFrameModel, annotation_info.arg),
)
elif annotation_info.args and len(annotation_info.args) > 0:
# Recursive condition
return _AnnotationInfoWithDataFrameModelTree(
annotation_info=annotation_info,
children=[
_AnnotationInfoWithDataFrameModelTree.from_annotation(
arg
)
for arg in annotation_info.args
],
)
else:
# Base condition
return _AnnotationInfoWithDataFrameModelTree(
annotation_info=annotation_info, children=None
)

# Front-load annotation parsing
# @functools.lru_cache
def _get_annotated_schema_models(
wrapped: Callable,
) -> dict[
str,
Iterable[
tuple[Union[DataFrameModel, None], Union[AnnotationInfo, None]]
],
_AnnotationInfoWithDataFrameModelTree,
]:
annotated_schema_models: dict[
str,
Iterable[
tuple[Union[DataFrameModel, None], Union[AnnotationInfo, None]]
],
] = {}
for arg_name_, annotation in get_type_hints(
wrapped, include_extras=True
).items():
annotation_info = AnnotationInfo(annotation)
if not annotation_info.is_generic_df:
if annotation_info.origin == Union:
annotation_model_pairs = []
for annot in annotation_info.args: # type: ignore[union-attr]
sub_annotation_info = AnnotationInfo(annot)
if not sub_annotation_info.is_generic_df:
continue

schema_model = cast(
DataFrameModel, sub_annotation_info.arg
)
annotation_model_pairs.append(
(schema_model, sub_annotation_info)
)
else:
continue
else:
schema_model = cast(DataFrameModel, annotation_info.arg)
annotation_model_pairs = [(schema_model, annotation_info)]
return {
arg_name: _AnnotationInfoWithDataFrameModelTree.from_annotation(
annotation
)
for arg_name, annotation in get_type_hints(
wrapped, include_extras=True
).items()
}

def _check_arg_value_against_model(
arg_value: Any,
schema_model: DataFrameModel | None,
annotation_info: AnnotationInfo,
) -> Any:
if schema_model is None or (
annotation_info.optional and arg_value is None
):
# the pandera.schema attribute should only be available when
# schema.validate has been called in the DF. There's probably
# a better way of doing this
return arg_value

annotated_schema_models[arg_name_] = annotation_model_pairs
return annotated_schema_models
config = schema_model.__config__
data_container_type = annotation_info.origin
schema = schema_model.to_schema()

def _check_arg(arg_name: str, arg_value: Any) -> Any:
"""
Validate function's argument if annotated with a schema, else
pass-through.
"""
annotated_schema_models = _get_annotated_schema_models(wrapped)
annotation_model_pairs = annotated_schema_models.get(
arg_name, [(None, None)]
)
if data_container_type and config and config.from_format:
arg_value = data_container_type.from_format(arg_value, config)

if not annotation_model_pairs:
# Don't do checks if value is still a built-in type
if isinstance(
arg_value, (int, str, bool, float, dict, list, tuple, set)
):
return arg_value

error_handler = ErrorHandler(lazy=True)
for schema_model, annotation_info in annotation_model_pairs:
if schema_model is None:
return arg_value

if (
annotation_info
and not (annotation_info.optional and arg_value is None)
# the pandera.schema attribute should only be available when
# schema.validate has been called in the DF. There's probably
# a better way of doing this
):
config = schema_model.__config__
data_container_type = annotation_info.origin
schema = schema_model.to_schema()

if data_container_type and config and config.from_format:
arg_value = data_container_type.from_format(
arg_value, config
)

# Don't do checks if value is still a built-in type
if isinstance(
arg_value, (int, str, bool, float, dict, list, tuple, set)
):
return arg_value

if (
not hasattr(arg_value, "pandera")
or arg_value.pandera.schema is None
# don't re-validate a dataframe that contains the same
# exact schema
or arg_value.pandera.schema != schema
):
try:
arg_value = schema.validate(
arg_value,
head,
tail,
sample,
random_state,
lazy,
inplace,
)
except errors.SchemaError as e:
error_handler.collect_error(
get_error_category(
errors.SchemaErrorReason.INVALID_TYPE
),
errors.SchemaErrorReason.INVALID_TYPE,
_parse_schema_error(
"check_types",
wrapped,
schema,
arg_value,
e,
errors.SchemaErrorReason.INVALID_TYPE,
),
)
continue
if (
hasattr(arg_value, "pandera")
and arg_value.pandera.schema is not None
and arg_value.pandera.schema == schema
):
return arg_value

if data_container_type and config and config.to_format:
arg_value = data_container_type.to_format(
arg_value, config
)
arg_value = schema.validate(
arg_value,
head,
tail,
sample,
random_state,
lazy,
inplace,
)

return arg_value
if data_container_type and config and config.to_format:
arg_value = data_container_type.to_format(arg_value, config)

if error_handler.schema_errors:
if len(error_handler.schema_errors) == 1:
raise error_handler.schema_errors[0]
return arg_value

def _check_arg_value_against_union(
arg_value: Any,
union_child_nodes: list[_AnnotationInfoWithDataFrameModelTree],
) -> Any:
# Check if the arg value matches any of the children
schema_errors = []
for child in union_child_nodes:
try:
return _check_arg_value_against_model(
arg_value, child.dataframe_model, child.annotation_info
)
except errors.SchemaError as e:
schema_errors.append(e)
if schema_errors:
raise errors.SchemaErrors(
schema=schema,
schema_errors=error_handler.schema_errors,
schema=child.dataframe_model.to_schema()
if child.dataframe_model
else None,
schema_errors=schema_errors,
data=arg_value,
)
return arg_value

def _check_arg_value_against_tuple(
arg_value: Any,
tuple_child_nodes: list[_AnnotationInfoWithDataFrameModelTree],
) -> Any:
# Each of the children should match their respective schema
for child_arg_value, child_annotation_model_tree in zip(
arg_value, tuple_child_nodes
):
_check_arg_value(child_arg_value, child_annotation_model_tree)
return arg_value

def _check_arg_value_against_list(
arg_value: Any,
list_child_node: _AnnotationInfoWithDataFrameModelTree | None,
) -> Any:
if not list_child_node:
# List of no specific type
return arg_value

# Check all children conform to the schema
for x in arg_value:
_check_arg_value(x, list_child_node)
return arg_value

def _check_arg_value_against_dict(
arg_value: Any,
dict_child_node: _AnnotationInfoWithDataFrameModelTree | None,
) -> Any:
if not dict_child_node:
# Dict of no specific value type
return arg_value

# Check all children conform to the schema
for _, x in arg_value.items():
_check_arg_value(x, dict_child_node)
return arg_value

def _check_arg_value(
    arg_value: Any,
    annotation_model_tree: _AnnotationInfoWithDataFrameModelTree,
) -> Any:
    """Dispatch validation of ``arg_value`` on the annotation's origin.

    ``Union``, tuple, list and dict annotations are handed to their
    dedicated checkers; anything else is validated directly against the
    node's dataframe model.
    """
    node = annotation_model_tree
    origin = node.annotation_info.origin

    if origin == Union:
        return _check_arg_value_against_union(arg_value, node.children or [])

    # NOTE: We use string literals for Tuple, List, and Dict here to prevent
    # pyupgrade from (incorrectly) converting them to tuple, list, and dict.
    # This is important because we want to match both list and List, for example.
    if origin in [tuple, "Tuple"]:
        return _check_arg_value_against_tuple(arg_value, node.children or [])

    if origin in [list, "List"]:
        # The element type is the first type argument.
        return _check_arg_value_against_list(
            arg_value, node.child_at_index(0)
        )

    if origin in [dict, "Dict"]:
        # The value type is the second type argument.
        return _check_arg_value_against_dict(
            arg_value, node.child_at_index(1)
        )

    return _check_arg_value_against_model(
        arg_value,
        node.dataframe_model,
        node.annotation_info,
    )

def _check_arg(arg_name: str, arg_value: Any) -> Any:
    """
    Validate function's argument if annotated with a schema, else
    pass-through.

    Arguments without an entry in the parsed annotations of ``wrapped`` are
    returned unchanged.
    """
    trees_by_arg = _get_annotated_schema_models(wrapped)

    if arg_name in trees_by_arg:
        return _check_arg_value(arg_value, trees_by_arg[arg_name])

    return arg_value

sig = inspect.signature(wrapped)

Expand Down
Loading
Loading