2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -1,5 +1,7 @@
# pyspark
pyspark==3.5.3
# pyspark-connect
pyspark[connect]==3.5.3
# linters
flake8==7.1.1
pylint==3.3.2
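
Note: the connect extra installs the Spark Connect client dependencies on top of plain pyspark. A minimal sketch for sanity-checking the install (the module path is the one the new tests assert below):

# Importable only when pyspark[connect] is installed; this is the module
# whose name the new tests assert.
import pyspark.sql.connect.dataframe  # noqa: F401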
28 changes: 28 additions & 0 deletions tests/_connect/test_connect.py
@@ -0,0 +1,28 @@
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType

from typedspark import Column, Schema, create_schema, transform_to_schema


class A(Schema):
a: Column[LongType]
b: Column[LongType]
c: Column[LongType]


def test_regular_spark_works(spark: SparkSession):
df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
assert df.count() == 3
assert type(df).__module__ == "pyspark.sql.dataframe"


def test_spark_connect_works(sparkConnect: SparkSession):
df = sparkConnect.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
assert df.count() == 3
assert type(df).__module__ == "pyspark.sql.connect.dataframe"


def test_transform_to_schema_works(sparkConnect: SparkSession):
df = sparkConnect.createDataFrame([(14, 23, 16)], ["a", "b", "c"])
typed_df = transform_to_schema(df, A)
assert typed_df.count() == 1
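
For completeness, transform_to_schema also accepts a transformations dictionary (see the DataSetImplements docstring in dataset.py below). A hypothetical companion test using the classic session fixture:

def test_transform_to_schema_with_transformations(spark: SparkSession):
    # Hypothetical test: bump column a while validating the result against A.
    # The {A.a: A.a + 1} pattern follows the docstring in dataset.py.
    df = spark.createDataFrame([(14, 23, 16)], ["a", "b", "c"])
    typed_df = transform_to_schema(df, A, {A.a: A.a + 1})
    assert typed_df.first().a == 15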
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -14,3 +14,17 @@ def spark():
spark = SparkSession.Builder().getOrCreate()
yield spark
spark.stop()


@pytest.fixture(scope="session")
def sparkConnect():
"""Fixture for creating a spark session."""

spark = (
SparkSession.Builder()
.config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
.remote('local')
.getOrCreate()
)
yield spark
spark.stop()
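
Note that remote("local") starts an in-process Spark Connect server backed by the jar configured above. Connecting to an already-running server would use an sc:// URL instead; a sketch with a hypothetical host and port (15002 is Spark's documented default):

@pytest.fixture(scope="session")
def sparkConnectRemote():
    """Hypothetical fixture for an external Spark Connect server."""

    spark = SparkSession.Builder().remote("sc://localhost:15002").getOrCreate()
    yield spark
    spark.stop()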
327 changes: 327 additions & 0 deletions typedspark/_core/connect/dataset.py
@aurokk (author) commented on Dec 11, 2024:
This is a copy of typedspark/_core/dataset.py, with the DataFrame and Column types replaced by their counterparts from pyspark.sql.connect.
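
For reference, the substantive difference is the import block: the classic module presumably takes these names from pyspark.sql, while this file takes them from the Connect client (a sketch of the swap, not a verbatim diff):

# classic typedspark/_core/dataset.py (presumed):
#   from pyspark.sql import Column as SparkColumn
#   from pyspark.sql import DataFrame
# this file:
from pyspark.sql.connect.column import Column as SparkColumn
from pyspark.sql.connect.dataframe import DataFrame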

@@ -0,0 +1,327 @@
"""Module containing classes and functions related to TypedSpark DataSets."""

from __future__ import annotations

from copy import deepcopy
from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload

from pyspark import StorageLevel
from pyspark.sql.connect.column import Column as SparkColumn
from pyspark.sql.connect.dataframe import DataFrame
from typing_extensions import Concatenate, ParamSpec

from typedspark._core.validate_schema import validate_schema
from typedspark._schema.schema import Schema

_Schema = TypeVar("_Schema", bound=Schema)
_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
_Implementation = TypeVar("_Implementation", bound=Schema, covariant=True)

P = ParamSpec("P")
_ReturnType = TypeVar("_ReturnType", bound=DataFrame) # pylint: disable=C0103


class DataSetImplements(DataFrame, Generic[_Protocol, _Implementation]):
"""DataSetImplements allows us to define functions such as:

.. code-block:: python
class Age(Schema, Protocol):
age: Column[LongType]

def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
return transform_to_schema(
df,
df.typedspark_schema,
{Age.age: Age.age + 1},
)

Such a function:
1. Takes as an input ``DataSetImplements[Age, T]``: a ``DataSet`` that implements the protocol
``Age`` as ``T``.
2. Returns a ``DataSet[T]``: a ``DataSet`` of the same type as the one that was provided.

    ``DataSetImplements`` should be used solely as a type annotation; it is never initialized."""

_schema_annotations: Type[_Implementation]

def __init__(self):
raise NotImplementedError(
"DataSetImplements should solely be used as a type annotation, it is never initialized."
)

@property
def typedspark_schema(self) -> Type[_Implementation]:
"""Returns the ``Schema`` of the ``DataSet``."""
return self._schema_annotations

"""The following functions are equivalent to their parents in ``DataFrame``, but since they
don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
such that the docstring from the parent will appear."""

def alias(self, alias: str) -> DataSet[_Implementation]:
return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore

def cache(self) -> DataSet[_Implementation]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().cache()) # type: ignore

def persist(
self,
storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
) -> DataSet[_Implementation]:
return DataSet[self._schema_annotations](super().persist(storageLevel)) # type: ignore

def unpersist(self, blocking: bool = False) -> DataSet[_Implementation]:
return DataSet[self._schema_annotations](super().unpersist(blocking)) # type: ignore

def distinct(self) -> DataSet[_Implementation]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().distinct()) # type: ignore

def filter(self, condition) -> DataSet[_Implementation]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().filter(condition)) # type: ignore

@overload
def join( # type: ignore
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: None = ...,
) -> DataFrame: ... # pragma: no cover

@overload
def join(
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: Literal["semi"] = ...,
) -> DataSet[_Implementation]: ... # pragma: no cover

@overload
def join(
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: Optional[str] = ...,
) -> DataFrame: ... # pragma: no cover

def join( # pylint: disable=C0116
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = None,
how: Optional[str] = None,
) -> DataFrame:
return super().join(other, on, how) # type: ignore

def orderBy(self, *args, **kwargs) -> DataSet[_Implementation]: # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103
return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs)) # type: ignore

def transform(
self,
func: Callable[Concatenate[DataSet[_Implementation], P], _ReturnType],
*args: P.args,
**kwargs: P.kwargs,
) -> _ReturnType:
return super().transform(func, *args, **kwargs) # type: ignore

@overload
def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataSet[_Implementation],
allowMissingColumns: Literal[False] = ..., # noqa: N803
) -> DataSet[_Implementation]: ... # pragma: no cover

@overload
def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataFrame,
allowMissingColumns: bool = ..., # noqa: N803
) -> DataFrame: ... # pragma: no cover

def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataFrame,
allowMissingColumns: bool = False, # noqa: N803
) -> DataFrame:
res = super().unionByName(other, allowMissingColumns)
if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations:
return DataSet[self._schema_annotations](res) # type: ignore
return res # pragma: no cover


class DataSet(DataSetImplements[_Schema, _Schema]):
"""``DataSet`` subclasses pyspark ``DataFrame`` and hence has all the same
functionality, with in addition the possibility to define a schema.

.. code-block:: python

class Person(Schema):
name: Column[StringType]
age: Column[LongType]

def foo(df: DataSet[Person]) -> DataSet[Person]:
# do stuff
return df
"""

def __new__(cls, dataframe: DataFrame) -> DataSet[_Schema]:
"""``__new__()`` instantiates the object (prior to ``__init__()``).

        Here, we simply take the provided ``dataframe`` and cast it to a
        ``DataSet``. This allows us to bypass the ``DataFrame``
        constructor in ``__init__()``, which requires parameters that may
        be difficult to access. Subsequently, we perform schema validation
        if the schema annotations are provided.
"""
dataframe = cast(DataSet, dataframe)
dataframe.__class__ = DataSet

        # first we reset the schema annotations to None, in case they are inherited through the
# passed DataFrame
dataframe._schema_annotations = None # type: ignore

# then we use the class' schema annotations to validate the schema and add metadata
if hasattr(cls, "_schema_annotations"):
dataframe._schema_annotations = cls._schema_annotations # type: ignore
dataframe._validate_schema()
dataframe._add_schema_metadata()

return dataframe # type: ignore

def __init__(self, dataframe: DataFrame):
pass

def __class_getitem__(cls, item):
"""Allows us to define a schema for the ``DataSet``.

To make sure that the DataSet._schema_annotations variable isn't reused globally, we
generate a subclass of the ``DataSet`` with the schema annotations as a class variable.
"""
subclass_name = f"{cls.__name__}[{item.__name__}]"
subclass = type(subclass_name, (cls,), {"_schema_annotations": item})
return subclass

def _validate_schema(self) -> None:
"""Validates the schema of the ``DataSet`` against the schema annotations."""
validate_schema(
self._schema_annotations.get_structtype(),
deepcopy(self.schema),
self._schema_annotations.get_schema_name(),
)

def _add_schema_metadata(self) -> None:
"""Adds the ``ColumnMeta`` comments as metadata to the ``DataSet``.

Previously set metadata is deleted. Hence, if ``foo(dataframe: DataSet[A]) -> DataSet[B]``,
        then ``DataSet[B]`` will not inherit any metadata from ``DataSet[A]``.

        Assumes ``_validate_schema()`` in ``__new__()`` has been run.
"""
for field in self._schema_annotations.get_structtype().fields:
self.schema[field.name].metadata = field.metadata

"""The following functions are equivalent to their parents in ``DataSetImplements``. However,
to support functions like ``functools.reduce(DataSet.unionByName, datasets)``, we also add them
here. Unfortunately, this leads to some code redundancy, but we'll take that for granted."""

def alias(self, alias: str) -> DataSet[_Schema]:
return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore

def cache(self) -> DataSet[_Schema]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().cache()) # type: ignore

def persist(
self,
storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
) -> DataSet[_Schema]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().persist(storageLevel)) # type: ignore

def unpersist(self, blocking: bool = False) -> DataSet[_Schema]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().unpersist(blocking)) # type: ignore

def distinct(self) -> DataSet[_Schema]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().distinct()) # type: ignore

def filter(self, condition) -> DataSet[_Schema]: # pylint: disable=C0116
return DataSet[self._schema_annotations](super().filter(condition)) # type: ignore

@overload
def join( # type: ignore
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: None = ...,
) -> DataFrame: ... # pragma: no cover

@overload
def join(
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: Literal["semi"] = ...,
) -> DataSet[_Schema]: ... # pragma: no cover

@overload
def join(
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = ...,
how: Optional[str] = ...,
) -> DataFrame: ... # pragma: no cover

def join( # pylint: disable=C0116
self,
other: DataFrame,
on: Optional[ # pylint: disable=C0103
Union[str, List[str], SparkColumn, List[SparkColumn]]
] = None,
how: Optional[str] = None,
) -> DataFrame:
return super().join(other, on, how) # type: ignore

def orderBy(self, *args, **kwargs) -> DataSet[_Schema]: # type: ignore # noqa: N802, E501 # pylint: disable=C0116, C0103
return DataSet[self._schema_annotations](super().orderBy(*args, **kwargs)) # type: ignore

def transform(
self,
func: Callable[Concatenate[DataSet[_Schema], P], _ReturnType],
*args: P.args,
**kwargs: P.kwargs,
) -> _ReturnType:
return super().transform(func, *args, **kwargs) # type: ignore

@overload
def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataSet[_Schema],
allowMissingColumns: Literal[False] = ..., # noqa: N803
) -> DataSet[_Schema]: ... # pragma: no cover

@overload
def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataFrame,
allowMissingColumns: bool = ..., # noqa: N803
) -> DataFrame: ... # pragma: no cover

def unionByName( # noqa: N802 # pylint: disable=C0116, C0103
self,
other: DataFrame,
allowMissingColumns: bool = False, # noqa: N803
) -> DataFrame:
res = super().unionByName(other, allowMissingColumns)
if isinstance(other, DataSet) and other._schema_annotations == self._schema_annotations:
return DataSet[self._schema_annotations](res) # type: ignore
return res # pragma: no cover
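
Taken together, the Connect variant keeps the classic typedspark surface. A minimal usage sketch; the import path mirrors the file added in this PR, and how the class will ultimately be exported is an assumption:

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType

from typedspark import Column, Schema
from typedspark._core.connect.dataset import DataSet  # path per this PR


class Person(Schema):
    age: Column[LongType]


# Connect session, configured as in tests/conftest.py.
spark = (
    SparkSession.Builder()
    .config("spark.jars.packages", "org.apache.spark:spark-connect_2.12:3.5.3")
    .remote("local")
    .getOrCreate()
)
df = spark.createDataFrame([(42,)], ["age"])

# Validates df.schema against Person and attaches the annotations,
# so typed.typedspark_schema returns Person.
typed = DataSet[Person](df)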