diff --git a/requirements-dev.txt b/requirements-dev.txt
index 65d0b372..fbb0c121 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,7 @@
 # pyspark
 pyspark==3.5.3
+# pyspark-connect
+pyspark[connect]==3.5.3
 # linters
 flake8==7.1.1
 pylint==3.3.2
diff --git a/tests/_connect/test_connect.py b/tests/_connect/test_connect.py
new file mode 100644
index 00000000..bd2b523a
--- /dev/null
+++ b/tests/_connect/test_connect.py
@@ -0,0 +1,22 @@
+from pyspark.sql import SparkSession
+from pyspark.sql.types import LongType
+
+from typedspark import Column, Schema
+
+
+class A(Schema):
+    a: Column[LongType]
+    b: Column[LongType]
+    c: Column[LongType]
+
+
+def test_regular_spark_works(spark: SparkSession):
+    df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+    assert df.count() == 3
+    assert type(df).__module__ == "pyspark.sql.dataframe"
+
+
+def test_spark_connect_works(sparkConnect: SparkSession):
+    df = sparkConnect.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+    assert df.count() == 3
+    assert type(df).__module__ == "pyspark.sql.connect.dataframe"
diff --git a/tests/_connect/test_dataset.py b/tests/_connect/test_dataset.py
new file mode 100644
index 00000000..150a4cbe
--- /dev/null
+++ b/tests/_connect/test_dataset.py
@@ -0,0 +1,140 @@
+import functools
+
+import pandas as pd
+import pytest
+from pyspark import StorageLevel
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import LongType, StringType
+
+from typedspark import Column, DataSet, Schema
+from typedspark._core.dataset import DataSetImplements
+from typedspark._utils.create_dataset import create_empty_dataset
+
+
+class A(Schema):
+    a: Column[LongType]
+    b: Column[StringType]
+
+
+class B(Schema):
+    a: Column[LongType]
+    b: Column[StringType]
+
+
+def create_dataframe(sparkConnect: SparkSession, d):
+    return sparkConnect.createDataFrame(pd.DataFrame(d))
+
+
+def test_dataset(sparkConnect: SparkSession):
+    d = dict(
+        a=[1, 2, 3],
+        b=["a", "b", "c"],
+    )
+    df = create_dataframe(sparkConnect, d)
+    DataSet[A](df)
+
+
+def test_dataset_allow_underscored_columns_not_in_schema(sparkConnect: SparkSession):
+    d = {"a": [1, 2, 3], "b": ["a", "b", "c"], "__c": [1, 2, 3]}
+    df = create_dataframe(sparkConnect, d)
+    DataSet[A](df)
+
+
+def test_dataset_single_underscored_column_should_raise(sparkConnect: SparkSession):
+    d = {"a": [1, 2, 3], "b": ["a", "b", "c"], "_c": [1, 2, 3]}
+    df = create_dataframe(sparkConnect, d)
+    with pytest.raises(TypeError):
+        DataSet[A](df)
+
+
+def test_dataset_missing_colnames(sparkConnect: SparkSession):
+    d = dict(
+        a=[1, 2, 3],
+    )
+    df = create_dataframe(sparkConnect, d)
+    with pytest.raises(TypeError):
+        DataSet[A](df)
+
+
+def test_dataset_too_many_colnames(sparkConnect: SparkSession):
+    d = dict(
+        a=[1, 2, 3],
+        b=["a", "b", "c"],
+        c=[1, 2, 3],
+    )
+    df = create_dataframe(sparkConnect, d)
+    with pytest.raises(TypeError):
+        DataSet[A](df)
+
+
+def test_wrong_type(sparkConnect: SparkSession):
+    d = dict(
+        a=[1, 2, 3],
+        b=[1, 2, 3],
+    )
+    df = create_dataframe(sparkConnect, d)
+    with pytest.raises(TypeError):
+        DataSet[A](df)
+
+
+def test_inherrited_functions(sparkConnect: SparkSession):
+    df = create_empty_dataset(sparkConnect, A)
+
+    df.distinct()
+    cached1: DataSet[A] = df.cache()
+    cached2: DataSet[A] = df.persist(StorageLevel.MEMORY_AND_DISK)
+    df.filter(A.a == 1)
+    df.orderBy(A.a)
+    df.transform(lambda df: df)
+
+    cached1.unpersist(True)
+    cached2.unpersist(True)
+
+
+def test_inherrited_functions_with_other_dataset(sparkConnect: SparkSession):
+    df_a = create_empty_dataset(sparkConnect, A)
+    df_b = create_empty_dataset(sparkConnect, A)
+
+    df_a.join(df_b, A.a.str)
+    df_a.unionByName(df_b)
+
+
+def test_schema_property_of_dataset(sparkConnect: SparkSession):
+    df = create_empty_dataset(sparkConnect, A)
+    assert df.typedspark_schema == A
+
+
+def test_initialize_dataset_implements(sparkConnect: SparkSession):
+    with pytest.raises(NotImplementedError):
+        DataSetImplements()
+
+
+def test_reduce(sparkConnect: SparkSession):
+    functools.reduce(
+        DataSet.unionByName,
+        [create_empty_dataset(sparkConnect, A), create_empty_dataset(sparkConnect, A)],
+    )
+
+
+def test_resetting_of_schema_annotations(sparkConnect: SparkSession):
+    df = create_empty_dataset(sparkConnect, A)
+
+    a: DataFrame
+
+    # if no schema is specified, the annotation should be None
+    a = DataSet(df)
+    assert a._schema_annotations is None
+
+    # when we specify a schema, the class variable will be set to A, but afterwards it should be
+    # reset to None again when we initialize a new object without specifying a schema
+    DataSet[A]
+    a = DataSet(df)
+    assert a._schema_annotations is None
+
+    # and then to B
+    a = DataSet[B](df)
+    assert a._schema_annotations == B
+
+    # and then to None again
+    a = DataSet(df)
+    assert a._schema_annotations is None
diff --git a/tests/conftest.py b/tests/conftest.py
index 22ba1940..08c8a1e8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,3 +14,17 @@ def spark():
     spark = SparkSession.Builder().getOrCreate()
     yield spark
     spark.stop()
+
+
+@pytest.fixture(scope="session")
+def sparkConnect():
+    """Fixture for creating a Spark Connect session."""
+
+    spark = (
+        SparkSession.Builder()
+        .config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
+        .remote('local')
+        .getOrCreate()
+    )
+    yield spark
+    spark.stop()
diff --git a/typedspark/_core/connect/dataset.py b/typedspark/_core/connect/dataset.py
new file mode 100644
index 00000000..8d49f742
--- /dev/null
+++ b/typedspark/_core/connect/dataset.py
@@ -0,0 +1,327 @@
+"""Module containing classes and functions related to TypedSpark DataSets."""
+
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
+
+from pyspark import StorageLevel
+from pyspark.sql.connect.column import Column as SparkColumn
+from pyspark.sql.connect.dataframe import DataFrame
+from typing_extensions import Concatenate, ParamSpec
+
+from typedspark._core.validate_schema import validate_schema
+from typedspark._schema.schema import Schema
+
+_Schema = TypeVar("_Schema", bound=Schema)
+_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
+_Implementation = TypeVar("_Implementation", bound=Schema, covariant=True)
+
+P = ParamSpec("P")
+_ReturnType = TypeVar("_ReturnType", bound=DataFrame)  # pylint: disable=C0103
+
+
+class DataSetImplements(DataFrame, Generic[_Protocol, _Implementation]):
+    """DataSetImplements allows us to define functions such as:
+
+    .. code-block:: python
+        class Age(Schema, Protocol):
+            age: Column[LongType]
+
+        def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
+            return transform_to_schema(
+                df,
+                df.typedspark_schema,
+                {Age.age: Age.age + 1},
+            )
+
+    Such a function:
+    1. Takes as an input ``DataSetImplements[Age, T]``: a ``DataSet`` that implements the protocol
+       ``Age`` as ``T``.
+    2. Returns a ``DataSet[T]``: a ``DataSet`` of the same type as the one that was provided.
+
+    ``DataSetImplements`` should solely be used as a type annotation, it is never initialized."""
+
+    _schema_annotations: Type[_Implementation]
+
+    def __init__(self):
+        raise NotImplementedError(
+            "DataSetImplements should solely be used as a type annotation, it is never initialized."
+        )
+
+    @property
+    def typedspark_schema(self) -> Type[_Implementation]:
+        """Returns the ``Schema`` of the ``DataSet``."""
+        return self._schema_annotations
+
+    """The following functions are equivalent to their parents in ``DataFrame``, but since they
+    don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
+    such that the docstring from the parent will appear."""
+
+    def alias(self, alias: str) -> DataSet[_Implementation]:
+        return DataSet[self._schema_annotations](super().alias(alias))  # type: ignore
+
+    def cache(self) -> DataSet[_Implementation]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().cache())  # type: ignore
+
+    def persist(
+        self,
+        storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
+    ) -> DataSet[_Implementation]:
+        return DataSet[self._schema_annotations](super().persist(storageLevel))  # type: ignore
+
+    def unpersist(self, blocking: bool = False) -> DataSet[_Implementation]:
+        return DataSet[self._schema_annotations](super().unpersist(blocking))  # type: ignore
+
+    def distinct(self) -> DataSet[_Implementation]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().distinct())  # type: ignore
+
+    def filter(self, condition) -> DataSet[_Implementation]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().filter(condition))  # type: ignore
+
+    @overload
+    def join(  # type: ignore
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: None = ...,
+    ) -> DataFrame: ...  # pragma: no cover
+
+    @overload
+    def join(
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: Literal["semi"] = ...,
+    ) -> DataSet[_Implementation]: ...  # pragma: no cover
+
+    @overload
+    def join(
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: Optional[str] = ...,
+    ) -> DataFrame: ...  # pragma: no cover
+
+    def join(  # pylint: disable=C0116
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = None,
+        how: Optional[str] = None,
+    ) -> DataFrame:
+        return super().join(other, on, how)  # type: ignore
+
+    def orderBy(self, *args, **kwargs) -> DataSet[_Implementation]:  # type: ignore  # noqa: N802, E501  # pylint: disable=C0116, C0103
+        return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs))  # type: ignore
+
+    def transform(
+        self,
+        func: Callable[Concatenate[DataSet[_Implementation], P], _ReturnType],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> _ReturnType:
+        return super().transform(func, *args, **kwargs)  # type: ignore
+
+    @overload
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataSet[_Implementation],
+        allowMissingColumns: Literal[False] = ...,  # noqa: N803
+    ) -> DataSet[_Implementation]: ...  # pragma: no cover
+
+    @overload
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataFrame,
+        allowMissingColumns: bool = ...,  # noqa: N803
+    ) -> DataFrame: ...  # pragma: no cover
+
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataFrame,
+        allowMissingColumns: bool = False,  # noqa: N803
+    ) -> DataFrame:
+        res = super().unionByName(other, allowMissingColumns)
+        if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations:
+            return DataSet[self._schema_annotations](res)  # type: ignore
+        return res  # pragma: no cover
+
+
+class DataSet(DataSetImplements[_Schema, _Schema]):
+    """``DataSet`` subclasses pyspark ``DataFrame`` and hence has all the same
+    functionality, with in addition the possibility to define a schema.
+
+    .. code-block:: python
+
+        class Person(Schema):
+            name: Column[StringType]
+            age: Column[LongType]
+
+        def foo(df: DataSet[Person]) -> DataSet[Person]:
+            # do stuff
+            return df
+    """
+
+    def __new__(cls, dataframe: DataFrame) -> DataSet[_Schema]:
+        """``__new__()`` instantiates the object (prior to ``__init__()``).
+
+        Here, we simply take the provided ``df`` and cast it to a
+        ``DataSet``. This allows us to bypass the ``DataFrame``
+        constructor in ``__init__()``, which requires parameters that may
+        be difficult to access. Subsequently, we perform schema validation, if
+        the schema annotations are provided.
+        """
+        dataframe = cast(DataSet, dataframe)
+        dataframe.__class__ = DataSet
+
+        # first we reset the schema annotations to None, in case they are inherited through the
+        # passed DataFrame
+        dataframe._schema_annotations = None  # type: ignore
+
+        # then we use the class' schema annotations to validate the schema and add metadata
+        if hasattr(cls, "_schema_annotations"):
+            dataframe._schema_annotations = cls._schema_annotations  # type: ignore
+            dataframe._validate_schema()
+            dataframe._add_schema_metadata()
+
+        return dataframe  # type: ignore
+
+    def __init__(self, dataframe: DataFrame):
+        pass
+
+    def __class_getitem__(cls, item):
+        """Allows us to define a schema for the ``DataSet``.
+
+        To make sure that the DataSet._schema_annotations variable isn't reused globally, we
+        generate a subclass of the ``DataSet`` with the schema annotations as a class variable.
+        """
+        subclass_name = f"{cls.__name__}[{item.__name__}]"
+        subclass = type(subclass_name, (cls,), {"_schema_annotations": item})
+        return subclass
+
+    def _validate_schema(self) -> None:
+        """Validates the schema of the ``DataSet`` against the schema annotations."""
+        validate_schema(
+            self._schema_annotations.get_structtype(),
+            deepcopy(self.schema),
+            self._schema_annotations.get_schema_name(),
+        )
+
+    def _add_schema_metadata(self) -> None:
+        """Adds the ``ColumnMeta`` comments as metadata to the ``DataSet``.
+
+        Previously set metadata is deleted. Hence, if ``foo(dataframe: DataSet[A]) -> DataSet[B]``,
+        then ``DataSet[B]`` will not inherit any metadata from ``DataSet[A]``.
+
+        Assumes validate_schema() in __setattr__() has been run.
+        """
+        for field in self._schema_annotations.get_structtype().fields:
+            self.schema[field.name].metadata = field.metadata
+
+    """The following functions are equivalent to their parents in ``DataSetImplements``. However,
+    to support functions like ``functools.reduce(DataSet.unionByName, datasets)``, we also add them
+    here. Unfortunately, this leads to some code redundancy, but we'll take that for granted."""
+
+    def alias(self, alias: str) -> DataSet[_Schema]:
+        return DataSet[self._schema_annotations](super().alias(alias))  # type: ignore
+
+    def cache(self) -> DataSet[_Schema]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().cache())  # type: ignore
+
+    def persist(
+        self,
+        storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
+    ) -> DataSet[_Schema]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().persist(storageLevel))  # type: ignore
+
+    def unpersist(self, blocking: bool = False) -> DataSet[_Schema]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().unpersist(blocking))  # type: ignore
+
+    def distinct(self) -> DataSet[_Schema]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().distinct())  # type: ignore
+
+    def filter(self, condition) -> DataSet[_Schema]:  # pylint: disable=C0116
+        return DataSet[self._schema_annotations](super().filter(condition))  # type: ignore
+
+    @overload
+    def join(  # type: ignore
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: None = ...,
+    ) -> DataFrame: ...  # pragma: no cover
+
+    @overload
+    def join(
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: Literal["semi"] = ...,
+    ) -> DataSet[_Schema]: ...  # pragma: no cover
+
+    @overload
+    def join(
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = ...,
+        how: Optional[str] = ...,
+    ) -> DataFrame: ...  # pragma: no cover
+
+    def join(  # pylint: disable=C0116
+        self,
+        other: DataFrame,
+        on: Optional[  # pylint: disable=C0103
+            Union[str, List[str], SparkColumn, List[SparkColumn]]
+        ] = None,
+        how: Optional[str] = None,
+    ) -> DataFrame:
+        return super().join(other, on, how)  # type: ignore
+
+    def orderBy(self, *args, **kwargs) -> DataSet[_Schema]:  # type: ignore  # noqa: N802, E501  # pylint: disable=C0116, C0103
+        return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs))  # type: ignore
+
+    def transform(
+        self,
+        func: Callable[Concatenate[DataSet[_Schema], P], _ReturnType],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> _ReturnType:
+        return super().transform(func, *args, **kwargs)  # type: ignore
+
+    @overload
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataSet[_Schema],
+        allowMissingColumns: Literal[False] = ...,  # noqa: N803
+    ) -> DataSet[_Schema]: ...  # pragma: no cover
+
+    @overload
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataFrame,
+        allowMissingColumns: bool = ...,  # noqa: N803
+    ) -> DataFrame: ...  # pragma: no cover
+
+    def unionByName(  # noqa: N802  # pylint: disable=C0116, C0103
+        self,
+        other: DataFrame,
+        allowMissingColumns: bool = False,  # noqa: N803
+    ) -> DataFrame:
+        res = super().unionByName(other, allowMissingColumns)
+        if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations:
+            return DataSet[self._schema_annotations](res)  # type: ignore
+        return res  # pragma: no cover
diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py
index 2b0f1210..77cb163b 100644
--- a/typedspark/_core/dataset.py
+++ b/typedspark/_core/dataset.py
@@ -8,8 +8,10 @@
 from pyspark import StorageLevel
 from pyspark.sql import Column as SparkColumn
 from pyspark.sql import DataFrame
+from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
 from typing_extensions import Concatenate, ParamSpec
 
+from typedspark._core.connect.dataset import DataSet as SparkConnectDataSet
 from typedspark._core.validate_schema import validate_schema
 from typedspark._schema.schema import Schema
 
@@ -180,7 +182,10 @@ def __new__(cls, dataframe: DataFrame) -> DataSet[_Schema]:
         the schema annotations are provided.
         """
         dataframe = cast(DataSet, dataframe)
-        dataframe.__class__ = DataSet
+        if isinstance(dataframe, SparkConnectDataFrame):
+            dataframe.__class__ = SparkConnectDataSet
+        else:
+            dataframe.__class__ = DataSet
 
         # first we reset the schema annotations to None, in case they are inherrited through the
         # passed DataFrame
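
For reference, a minimal usage sketch of the new Spark Connect path (not part of the diff). It assumes pyspark[connect] is installed and that .remote("local") can spin up a local Spark Connect server, as in the sparkConnect fixture above; the Person schema is purely illustrative.

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema


class Person(Schema):  # illustrative schema, not part of the diff
    name: Column[StringType]
    age: Column[LongType]


# Start a Spark Connect session, mirroring the sparkConnect fixture in tests/conftest.py.
spark = (
    SparkSession.Builder()
    .config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
    .remote("local")
    .getOrCreate()
)

# createDataFrame now returns a pyspark.sql.connect.dataframe.DataFrame; the patched
# DataSet.__new__ detects this and casts it to the Connect-aware DataSet, so schema
# validation behaves the same as with a classic SparkSession.
df = spark.createDataFrame([("Alice", 23), ("Bob", 16)], ["name", "age"])
persons: DataSet[Person] = DataSet[Person](df)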