Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pandera/api/dataframe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,12 @@ def empty(
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.pydantic_validate,
if issubclass(_source_type, cls):
return core_schema.no_info_plain_validator_function(
cls.pydantic_validate,
)
return core_schema.no_info_after_validator_function(
cls.validate, _handler(_source_type)
)

@classmethod
Expand Down
46 changes: 45 additions & 1 deletion tests/core/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Unit tests for pydantic compatibility."""

# pylint:disable=too-few-public-methods,missing-class-docstring
from typing import Optional
from typing import Annotated, Optional

import pandas as pd
import pytest

import pandera as pa
from pandera.engines import pydantic_version
from pandera.errors import SchemaError
from pandera.typing import DataFrame, Series

try:
Expand Down Expand Up @@ -54,6 +55,29 @@ class SeriesSchemaPydantic(BaseModel):
pa_index: Optional[pa.Index]


if PYDANTIC_V2:
from pydantic import ConfigDict

class AnnotatedDfPydantic(BaseModel): # type: ignore[no-redef]
"""Test pydantic model with annotated dataframe model."""

# arbitrary_types_allowed=True required for pandas.DataFrame
model_config = ConfigDict(arbitrary_types_allowed=True)

df: Annotated[pd.DataFrame, SimpleSchema]

else:

class AnnotatedDfPydantic(BaseModel): # type: ignore[no-redef]
"""Test pydantic model with annotated dataframe model."""

# arbitrary_types_allowed=True required for pandas.DataFrame
class Config:
arbitrary_types_allowed = True

df: Annotated[pd.DataFrame, SimpleSchema]


def test_typed_dataframe():
"""Test that typed DataFrame is compatible with pydantic."""
valid_df = pd.DataFrame({"str_col": ["hello", "world"]})
Expand All @@ -64,6 +88,16 @@ def test_typed_dataframe():
TypedDfPydantic(df=invalid_df)


def test_annotated_dataframe_model():
"""Test that annotated DataFrame is compatible with pydantic."""
valid_df = pd.DataFrame({"str_col": ["hello", "world"]})
assert isinstance(AnnotatedDfPydantic(df=valid_df), AnnotatedDfPydantic)

invalid_df = pd.DataFrame({"str_col": ["hello", "hello"]})
with pytest.raises(SchemaError):
AnnotatedDfPydantic(df=invalid_df) # right type, wrong schema


@pytest.mark.skipif(
not PYDANTIC_V2,
reason="Pydantic <2 cannot catch the invalid dataframe model error",
Expand All @@ -88,6 +122,16 @@ class PydanticModel(BaseModel):
PydanticModel(pa_schema=InvalidSchema)


@pytest.mark.skipif(
not PYDANTIC_V2,
reason="Pydantic <2 cannot catch the invalid dataframe model error",
)
def test_invalid_annotated_dataframe():
"""Test that an invalid annotated DataFrame is recognized by pandera."""
with pytest.raises(ValidationError):
AnnotatedDfPydantic(df=1)


def test_dataframemodel():
"""Test that DataFrameModel is compatible with pydantic."""
assert isinstance(
Expand Down
Loading