1 change: 1 addition & 0 deletions pandera/api/pyspark/__init__.py
@@ -2,3 +2,4 @@

from pandera.api.pyspark.components import Column
from pandera.api.pyspark.container import DataFrameSchema
from pandera.api.pyspark.model import DataFrameModel

82 changes: 80 additions & 2 deletions pandera/typing/pyspark.py
@@ -1,14 +1,22 @@
"""Pandera type annotations for Pyspark Pandas."""

from typing import TYPE_CHECKING, Generic, TypeVar
import functools
import json
from typing import TYPE_CHECKING, Generic, TypeVar, Any, get_args

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from pandera.engines import PYDANTIC_V2
from pandera.errors import SchemaInitError

from pandera.typing.common import (
DataFrameBase,
GenericDtype,
IndexBase,
SeriesBase,
_GenericAlias,
)
from pandera.typing.pandas import DataFrameModel, _GenericAlias
from pandera.typing.pandas import DataFrameModel

try:
import pyspark.pandas as ps
@@ -39,6 +47,76 @@
"""Define this to override's pyspark.pandas generic type."""
return _GenericAlias(cls, item)

@classmethod
def pydantic_validate(cls, obj: Any, schema_model: T) -> ps.DataFrame:
"""
Verify that the input can be converted into a pyspark.pandas dataframe that
meets all schema requirements.

This is for pydantic >= v2
"""
try:
schema = schema_model.to_schema() # type: ignore[attr-defined]
except SchemaInitError as exc:
error_message = (
f"Cannot use {cls} as a pydantic type as its "
"DataFrameModel cannot be converted to a DataFrameSchema.\n"
f"Please revisit the model to address the following errors:"
f"\n{exc}"
)
raise ValueError(error_message) from exc

validated_data = schema.validate(obj)

if validated_data.pandera.errors:
errors = json.dumps(
dict(validated_data.pandera.errors), indent=4
)
raise ValueError(errors)

return validated_data

if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
schema_model = get_args(_source_type)[0]
return core_schema.no_info_plain_validator_function(
functools.partial(
cls.pydantic_validate,
schema_model=schema_model,
),
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _get_schema_model(cls, field):
if not field.sub_fields:
raise TypeError(
"Expected a typed pandera.typing.DataFrame,"
" e.g. DataFrame[Schema]"
)
schema_model = field.sub_fields[0].type_
return schema_model

@classmethod
def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
"""
Verify that the input can be converted into a pyspark.pandas dataframe that
meets all schema requirements.

This is for pydantic < v2
"""
schema_model = cls._get_schema_model(field)
return cls.pydantic_validate(obj, schema_model)

# pylint:disable=too-few-public-methods,arguments-renamed
class Series(SeriesBase, ps.Series, Generic[GenericDtype]): # type: ignore [misc] # noqa
"""Representation of pandas.Series, only used for type annotation.
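To make the new hook concrete: once pandera/typing/pyspark.py defines __get_pydantic_core_schema__, a parametrized DataFrame[Model] annotation becomes a first-class pydantic field type. The following is a minimal sketch, not part of this PR — the PriceSchema/Container names are invented, and it assumes pydantic v2 plus an active Spark session:

import pyspark.pandas as psp
from pydantic import BaseModel

import pandera as pa
from pandera.typing.pyspark import DataFrame, Series


class PriceSchema(pa.DataFrameModel):
    # Hypothetical schema: a single non-negative integer column.
    price: Series[int] = pa.Field(ge=0)


class Container(BaseModel):
    # pydantic v2 resolves this annotation via __get_pydantic_core_schema__,
    # which binds PriceSchema to pydantic_validate through functools.partial.
    data: DataFrame[PriceSchema]


# Validation runs at model construction; schema failures surface as a
# pydantic ValidationError wrapping the ValueError raised by pydantic_validate.
Container(data=psp.DataFrame({"price": [5, 10]}))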
98 changes: 85 additions & 13 deletions pandera/typing/pyspark_sql.py
@@ -1,9 +1,16 @@
"""Pandera type annotations for Pyspark."""
"""Pandera type annotations for Pyspark SQL."""

from typing import TypeVar, Union
import functools
import json
from typing import Union, TypeVar, Any, get_args, Generic

from pandera.typing.common import DataFrameBase
from pandera.typing.pandas import DataFrameModel, _GenericAlias
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from pandera.engines import pyspark_engine, PYDANTIC_V2
from pandera.errors import SchemaInitError
from pandera.typing.common import DataFrameBase, _GenericAlias
from pandera.api.pyspark import DataFrameModel

try:
import pyspark.sql as ps
@@ -12,9 +19,9 @@
except ImportError: # pragma: no cover
PYSPARK_SQL_INSTALLED = False

if PYSPARK_SQL_INSTALLED:
from pandera.engines import pyspark_engine
T = TypeVar("T", bound=DataFrameModel)

if PYSPARK_SQL_INSTALLED:
PysparkString = pyspark_engine.String
PysparkInt = pyspark_engine.Int
PysparkLongInt = pyspark_engine.BigInt
@@ -43,13 +50,6 @@
PysparkBinary, # type: ignore
],
)
from typing import TYPE_CHECKING, Generic

# pylint:disable=invalid-name
if TYPE_CHECKING:
T = TypeVar("T") # pragma: no cover
else:
T = DataFrameModel

if PYSPARK_SQL_INSTALLED:
# pylint: disable=too-few-public-methods,arguments-renamed
@@ -64,3 +64,75 @@
def __class_getitem__(cls, item):
"""Define this to override's pyspark.pandas generic type."""
return _GenericAlias(cls, item) # pragma: no cover

@classmethod
def pydantic_validate(
cls, obj: ps.DataFrame, schema_model: T
) -> ps.DataFrame:
"""
Verify that the input can be converted into a pyspark.sql dataframe that
meets all schema requirements.

This is for pydantic v1 and v2.
"""
try:
schema = schema_model.to_schema()
except SchemaInitError as exc:
error_message = (
f"Cannot use {cls} as a pydantic type as its "
"DataFrameModel cannot be converted to a DataFrameSchema.\n"
f"Please revisit the model to address the following errors:"
f"\n{exc}"
)
raise ValueError(error_message) from exc

validated_data = schema.validate(obj)

if validated_data.pandera.errors:
errors = json.dumps(
dict(validated_data.pandera.errors), indent=4
)
raise ValueError(errors)

return validated_data

if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
schema_model = get_args(_source_type)[0]
return core_schema.no_info_plain_validator_function(
functools.partial(
cls.pydantic_validate,
schema_model=schema_model,
),
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _get_schema_model(cls, field):
if not field.sub_fields:
raise TypeError(
"Expected a typed pandera.typing.DataFrame,"
" e.g. DataFrame[Schema]"
)
schema_model = field.sub_fields[0].type_
return schema_model

@classmethod
def _pydantic_validate(cls, obj: Any, field) -> ps.DataFrame:
"""
Verify that the input can be converted into a pyspark.sql dataframe that
meets all schema requirements.

This is for pydantic v1
"""
schema_model = cls._get_schema_model(field)
return cls.pydantic_validate(obj, schema_model)

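Read alongside the new tests below, the pyspark.sql flavor is used the same way. A condensed sketch follows — the names mirror the test fixtures, and it assumes an active SparkSession bound to spark:

import pyspark.sql.types as T
from pydantic import BaseModel, ValidationError

import pandera.pyspark as pa
from pandera.typing.pyspark_sql import DataFrame


class SampleSchema(pa.DataFrameModel):
    product: T.StringType() = pa.Field()
    price: T.IntegerType() = pa.Field()


class PydanticContainer(BaseModel):
    data: DataFrame[SampleSchema]


# Build valid data with an explicit schema so the integer column is not
# inferred as long, which would fail the IntegerType dtype check.
valid = spark.createDataFrame(
    [("Apples", 1)],
    T.StructType(
        [
            T.StructField("product", T.StringType()),
            T.StructField("price", T.IntegerType()),
        ]
    ),
)
PydanticContainer(data=valid)  # passes; the validated frame is stored

# Swapped column types fail validation; pandera's accumulated errors are
# JSON-serialized into the pydantic ValidationError message.
invalid = spark.createDataFrame([(1, "Apples")], ["product", "price"])
try:
    PydanticContainer(data=invalid)
except ValidationError as exc:
    print(exc)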
2 changes: 2 additions & 0 deletions tests/pyspark/conftest.py
@@ -17,6 +17,7 @@ def spark() -> SparkSession:
creates spark session
"""
spark: SparkSession = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", False)
yield spark
spark.stop()

@@ -29,6 +30,7 @@ def spark_connect() -> SparkSession:
# Set location of localhost Spark Connect server
os.environ["SPARK_LOCAL_REMOTE"] = "sc://localhost"
spark: SparkSession = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", False)
yield spark
spark.stop()

1 change: 1 addition & 0 deletions tests/pyspark/test_pyspark_accessor.py
@@ -12,6 +12,7 @@
from pandera.pyspark import pyspark_sql_accessor

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", False)


@pytest.mark.parametrize(
81 changes: 81 additions & 0 deletions tests/pyspark/test_pyspark_pydantic_integration.py
@@ -0,0 +1,81 @@
"""Tests for the integration between PySpark and Pydantic."""

import pytest
from pydantic import BaseModel, ValidationError
from pyspark.testing.utils import assertDataFrameEqual
import pyspark.sql.types as T

import pandera.pyspark as pa
from pandera.typing.pyspark_sql import DataFrame as PySparkSQLDataFrame
from pandera.typing.pyspark import DataFrame as PySparkDataFrame
from pandera.pyspark import DataFrameModel


@pytest.fixture
def sample_schema_model():
class SampleSchema(DataFrameModel):
"""
Sample schema model with data checks.
"""

product: T.StringType() = pa.Field()
price: T.IntegerType() = pa.Field()

return SampleSchema


@pytest.fixture(
params=[PySparkDataFrame, PySparkSQLDataFrame],
ids=["pyspark", "pyspark_sql"],
)
def pydantic_container(request, sample_schema_model):
TypingClass = request.param

class PydanticContainer(BaseModel):
"""
Pydantic container with a DataFrameModel as a field.
"""

data: TypingClass[sample_schema_model]

return PydanticContainer


@pytest.fixture
def correct_data(spark, sample_data, sample_spark_schema):
"""
Correct data that should pass validation.
"""
return spark.createDataFrame(sample_data, sample_spark_schema)


@pytest.fixture
def incorrect_data(spark):
"""
Incorrect data that should fail validation.
"""
data = [
(1, "Apples"),
(2, "Bananas"),
]
return spark.createDataFrame(data, ["product", "price"])


def test_pydantic_model_instantiates_with_correct_data(
correct_data, pydantic_container
):
"""
Test that a Pydantic model can be instantiated with a DataFrameModel when data is valid.
"""
my_container = pydantic_container(data=correct_data)
assertDataFrameEqual(my_container.data, correct_data)


def test_pydantic_model_throws_validation_error_with_incorrect_data(
incorrect_data, pydantic_container
):
"""
Test that a Pydantic model throws a ValidationError when data is invalid.
"""
with pytest.raises(ValidationError):
pydantic_container(data=incorrect_data)
2 changes: 2 additions & 0 deletions tests/pyspark/test_schemas_on_pyspark_pandas.py
@@ -278,8 +278,10 @@ def test_index_dtypes(
not in {
pandas_engine.Engine.dtype(pandas_engine.BOOL),
pandas_engine.DateTime(tz="UTC"), # type: ignore[call-arg]
pandas_engine.Engine.dtype(pa.dtypes.Timedelta), # type: ignore[call-arg]
}
],
ids=lambda x: str(x)
)
@hypothesis.given(st.data())
def test_nullable(