From 54e3738bb0a2890516c2702d4968ab9682f4e0bb Mon Sep 17 00:00:00 2001 From: Benito Buchheim Date: Fri, 22 Aug 2025 10:09:31 +0200 Subject: [PATCH 1/3] Using TypeAdapter for validation Using TypeAdapter instead of the raw BaseModel to allow validating more complex types (e.g. discriminated union) --- src/pandantic/basemodel.py | 4 +-- src/pandantic/plugins/pandas.py | 42 +++++++++++++--------- src/pandantic/types.py | 3 -- src/pandantic/validators/pandas.py | 25 +++++++------ tests/test_discriminated_union.py | 56 ++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 32 deletions(-) create mode 100644 tests/test_discriminated_union.py diff --git a/src/pandantic/basemodel.py b/src/pandantic/basemodel.py index d2de2f3..421e2b7 100644 --- a/src/pandantic/basemodel.py +++ b/src/pandantic/basemodel.py @@ -7,7 +7,7 @@ import pandas as pd -from pandantic.types import SchemaTypes, TableTypes +from pandantic.types import TableTypes from pandantic.validators.base import BaseValidator from pandantic.validators.pandas import PandasValidator @@ -15,7 +15,7 @@ class CoreValidator: """An implementation of the Pydantic BaseValidator.""" - def __init__(self, schema: SchemaTypes): + def __init__(self, schema: Any): self.schema = schema def _get_implementation(self, dataframe: TableTypes) -> BaseValidator: diff --git a/src/pandantic/plugins/pandas.py b/src/pandantic/plugins/pandas.py index 2ad8e8e..fa6d309 100644 --- a/src/pandantic/plugins/pandas.py +++ b/src/pandantic/plugins/pandas.py @@ -15,7 +15,7 @@ from typing import Any, Optional import pandas as pd -from pydantic import BaseModel, ValidationError +from pydantic import TypeAdapter, ValidationError from pandantic.basemodel import CoreValidator @@ -25,6 +25,15 @@ @pd.api.extensions.register_dataframe_accessor("pandantic") class PydanticAccessor: + def _validate_schema_type(self, schema: Any) -> None: + """Raise TypeError if schema is not a valid pydantic model, type, or Union.""" + try: + TypeAdapter(schema) + except Exception as e: + raise TypeError( + "Arg `schema` must be a valid pydantic model, type, or Union!" + ) from e + def __init__(self, pandas_obj: pd.DataFrame): assert isinstance(pandas_obj, pd.DataFrame), "Only works with DataFrames!" if not any(isinstance(col, str) for col in pandas_obj.columns): @@ -37,13 +46,11 @@ def obj(self) -> pd.DataFrame: def validate( self, - schema: BaseModel, + schema: Any, n_jobs: Optional[int] = None, **kwargs: Optional[dict[str, Any]], ) -> bool: - if not isinstance(schema, type(BaseModel)): - raise TypeError("Arg `schema` must be a pydantic.BaseModel subclass!") - + self._validate_schema_type(schema) schema_validator = CoreValidator(schema) # type: ignore try: _ = schema_validator.validate( @@ -60,19 +67,14 @@ def validate( def filter( self, - schema: BaseModel, + schema: Any, n_jobs: Optional[int] = None, verbose: bool = True, **kwargs: Optional[dict[str, Any]], ) -> pd.DataFrame: - if not isinstance(schema, type(BaseModel)): - raise TypeError("Arg `schema` must be a pydantic.BaseModel subclass!") - + self._validate_schema_type(schema) schema_validator = CoreValidator(schema) # type: ignore - if verbose: - errors = "log" - else: - errors = "skip" + errors = "log" if verbose else "skip" filtered_df: pd.DataFrame = schema_validator.validate( dataframe=self.obj, errors=errors, @@ -84,13 +86,17 @@ def filter( def itertuples( self, - schema: BaseModel, + schema: Any, verbose: bool = True, ) -> Iterable[tuple[Any, ...]]: """Same as normal .itertuples(), except invalid rows are skipped.""" + self._validate_schema_type(schema) + adapter = TypeAdapter(schema) for row in self.obj.itertuples(name=None): try: - _ = schema(**dict(zip(self.obj.columns, row[1:]))) # type: ignore + _ = adapter.validate_python( + dict(zip(self.obj.columns, row[1:])) + ) except ValidationError as e: if verbose: logger.info(f"Invalid row {row} with error: {e}") @@ -98,16 +104,18 @@ def itertuples( yield row def iterrows( # type: ignore[no-untyped-def] - self, schema: BaseModel, verbose: bool = True, **kwargs + self, schema: Any, verbose: bool = True, **kwargs ) -> Iterable[tuple[Hashable, pd.Series]]: # type: ignore[type-arg] """Same as normal .iterrows(), except invalid rows are skipped.""" + self._validate_schema_type(schema) schema_validator = CoreValidator(schema) for i, _ in schema_validator.iterate(dataframe=self.obj, context=kwargs, verbose=verbose): yield i, self.obj.loc[i] # type: ignore[call-overload] def iterschemas( # type: ignore[no-untyped-def] - self, schema: BaseModel, verbose: bool = True, **kwargs + self, schema: Any, verbose: bool = True, **kwargs ) -> Iterable[tuple[Hashable, Any]]: """Iterate over DataFrame rows as validated schema models.""" + self._validate_schema_type(schema) schema_validator = CoreValidator(schema) return schema_validator.iterate(dataframe=self.obj, context=kwargs, verbose=verbose) diff --git a/src/pandantic/types.py b/src/pandantic/types.py index 8c7090c..969eece 100644 --- a/src/pandantic/types.py +++ b/src/pandantic/types.py @@ -1,8 +1,5 @@ from typing import Union, TypeAlias import pandas as pd -import pydantic - -SchemaTypes: TypeAlias = Union[type[pydantic.BaseModel]] TableTypes: TypeAlias = Union[pd.DataFrame] diff --git a/src/pandantic/validators/pandas.py b/src/pandantic/validators/pandas.py index 98ea205..6fba003 100644 --- a/src/pandantic/validators/pandas.py +++ b/src/pandantic/validators/pandas.py @@ -9,15 +9,16 @@ Process, Queue, ) -from pydantic import ValidationError -from pandantic.types import SchemaTypes +from pydantic import ValidationError, TypeAdapter + from pandantic.validators.base import BaseValidator class PandasValidator(BaseValidator): - def __init__(self, schema: SchemaTypes): + def __init__(self, schema: Any): self.schema = schema + self.adapter = TypeAdapter(schema) def validate( self, @@ -47,6 +48,10 @@ def validate( # check for extra columns and handle strict mode # NOTE: this will need to be abstracted to handle different types of schema objects if strict: + if not hasattr(self.schema, "model_fields"): + # TODO: Implement for complex schemas which rely on TypeAdapter for validation + raise ValueError("Strict mode is only supported for BaseModel schemas.") + extras = { col for col in dataframe.columns if col not in self.schema.model_fields.keys() } @@ -108,8 +113,8 @@ def validate( else: for index, row_dict in dataframe.to_dict("index").items(): try: - self.schema.model_validate( - obj=row_dict, + self.adapter.validate_python( + row_dict, context=context, ) except ValidationError as exc: # pylint: disable=broad-exception-caught @@ -147,8 +152,8 @@ def _validate_chunk( for index, row_dict in chunk.items(): try: - self.schema.model_validate( - obj=row_dict, + self.adapter.validate_python( + row_dict, context=context, ) except ValidationError as exc: # pylint: disable=broad-exception-caught @@ -169,14 +174,14 @@ def iterate( dict[str, Any] ] = None, # pylint: disable=consider-alternative-union-syntax,useless-suppression verbose: bool = True, - ) -> Iterable[tuple[Hashable, SchemaTypes]]: + ) -> Iterable[tuple[Hashable, Any]]: """Iterate over a DataFrame and yield validated schema models.""" for i, row in dataframe.iterrows(): try: yield ( i, - self.schema.model_validate( - obj=row.to_dict(), + self.adapter.validate_python( + row.to_dict(), context=context, ), ) diff --git a/tests/test_discriminated_union.py b/tests/test_discriminated_union.py new file mode 100644 index 0000000..83f8eb6 --- /dev/null +++ b/tests/test_discriminated_union.py @@ -0,0 +1,56 @@ +"""Test complex Union types.""" + +from typing import Annotated, Literal +import pandas as pd +import pytest +from pydantic import BaseModel, Field + +from pandantic import Pandantic + + +class Cat(BaseModel): + pet_type: Literal['cat'] + extra_lives_left: int + +class Dog(BaseModel): + pet_type: Literal['dog'] + extra_lives_left: Literal[0] + +Pet = Annotated[Cat | Dog, Field(discriminator="pet_type")] + +def test_valid_df_passes(): + """Test that a valid DataFrame with discriminated unions passes validation.""" + + # GIVEN + validator = Pandantic(schema=Pet) + + example_df_valid = pd.DataFrame( + data={ + "pet_type": ["cat", "dog"], + "extra_lives_left": [9, 0] + } + ) + + validator.validate( + dataframe=example_df_valid + ) + +def test_invalid_df_raises(): + """Test that an invalid DataFrame with discriminated unions raises a ValueError.""" + + # GIVEN + validator = Pandantic(schema=Pet) + + example_df_invalid = pd.DataFrame( + data={ + "pet_type": ["cat", "dog"], + "extra_lives_left": [9, 1] + } + ) + + # THEN + with pytest.raises(ValueError): + # WHEN + validator.validate( + dataframe=example_df_invalid + ) \ No newline at end of file From b7147580763783bc657fa6ffa7a63f177754972c Mon Sep 17 00:00:00 2001 From: Benito Buchheim Date: Wed, 3 Sep 2025 09:44:38 +0200 Subject: [PATCH 2/3] removing unused variable --- src/pandantic/plugins/pandas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pandantic/plugins/pandas.py b/src/pandantic/plugins/pandas.py index fa6d309..f40d4f4 100644 --- a/src/pandantic/plugins/pandas.py +++ b/src/pandantic/plugins/pandas.py @@ -53,7 +53,7 @@ def validate( self._validate_schema_type(schema) schema_validator = CoreValidator(schema) # type: ignore try: - _ = schema_validator.validate( + schema_validator.validate( dataframe=self.obj, errors="raise", context=kwargs, From f1691ac7174c101c89b83efc38bbaa18cc37a5ea Mon Sep 17 00:00:00 2001 From: Benito Buchheim Date: Wed, 3 Sep 2025 09:46:52 +0200 Subject: [PATCH 3/3] removing unused variable --- src/pandantic/plugins/pandas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pandantic/plugins/pandas.py b/src/pandantic/plugins/pandas.py index f40d4f4..c13a1e3 100644 --- a/src/pandantic/plugins/pandas.py +++ b/src/pandantic/plugins/pandas.py @@ -94,7 +94,7 @@ def itertuples( adapter = TypeAdapter(schema) for row in self.obj.itertuples(name=None): try: - _ = adapter.validate_python( + adapter.validate_python( dict(zip(self.obj.columns, row[1:])) ) except ValidationError as e: