34 changes: 34 additions & 0 deletions tests/_core/test_dataset.py
@@ -1,12 +1,15 @@
import functools
from typing import Annotated

import pandas as pd
import pytest
from chispa import assert_df_equality # type: ignore
from pyspark import StorageLevel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema
from typedspark._core.column_meta import ColumnMeta
from typedspark._core.dataset import DataSetImplements
from typedspark._utils.create_dataset import create_empty_dataset

@@ -138,3 +141,34 @@ def test_resetting_of_schema_annotations(spark: SparkSession):
# and then to None again
a = DataSet(df)
assert a._schema_annotations is None


def test_from_dataframe(spark: SparkSession):
df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"])
ds, _ = DataSet[A].from_dataframe(df)

assert isinstance(ds, DataSet)
assert_df_equality(ds, df)

df_2 = ds.to_dataframe()

assert isinstance(df_2, DataFrame)
assert_df_equality(df_2, df)


class Person(Schema):
name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
age: Column[LongType]


def test_from_dataframe_with_external_name(spark: SparkSession):
df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"])
ds, _ = DataSet[Person].from_dataframe(df)

assert isinstance(ds, DataSet)
assert ds.columns == ["name", "age"]

df_2 = ds.to_dataframe()
assert isinstance(df_2, DataFrame)
assert df_2.columns == ["first-name", "age"]
assert_df_equality(df_2, df)
1 change: 1 addition & 0 deletions typedspark/_core/column_meta.py
@@ -20,6 +20,7 @@ class A(Schema):
"""

comment: Optional[str] = None
external_name: Optional[str] = None

def get_metadata(self) -> Optional[Dict[str, str]]:
"""Returns the metadata of this column."""
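Note: the new ``external_name`` field is what ``DataSet.from_dataframe`` (below) looks up in the column metadata. A minimal sketch of how a schema might carry it, assuming ``get_metadata()`` serializes the non-None ``ColumnMeta`` fields into the ``StructField`` metadata dict; ``Order`` and ``order-id`` are hypothetical names for illustration only:

from typing import Annotated

from pyspark.sql.types import StringType

from typedspark import Column, Schema
from typedspark._core.column_meta import ColumnMeta


class Order(Schema):
    # "order-id" is a hypothetical external column name; "order_id" is the
    # attribute name used inside typedspark.
    order_id: Annotated[Column[StringType], ColumnMeta(external_name="order-id")]


# Assuming get_metadata() returns the non-None fields, the resulting StructField
# metadata would contain {"external_name": "order-id"}, which is the key that
# DataSet.from_dataframe() reads back.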
80 changes: 76 additions & 4 deletions typedspark/_core/dataset.py
@@ -3,7 +3,19 @@
from __future__ import annotations

from copy import deepcopy
from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
from typing import (
Callable,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)

from pyspark import StorageLevel
from pyspark.sql import Column as SparkColumn
@@ -12,6 +24,8 @@

from typedspark._core.validate_schema import validate_schema
from typedspark._schema.schema import Schema
from typedspark._transforms.transform_to_schema import transform_to_schema
from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset

_Schema = TypeVar("_Schema", bound=Schema)
_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
@@ -54,9 +68,67 @@ def typedspark_schema(self) -> Type[_Implementation]:
"""Returns the ``Schema`` of the ``DataSet``."""
return self._schema_annotations

"""The following functions are equivalent to their parents in ``DataFrame``, but since they
don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
such that the docstring from the parent will appear."""
@classmethod
def from_dataframe(
cls, df: DataFrame
) -> Tuple[DataSet[_Implementation], Type[_Implementation]]:
"""Converts a DataFrame to a DataSet and registers the Schema to the DataSet.
Also renames the columns to their internal names, for example to deal with
characters that are not allowed in class attribute names.

.. code-block:: python

class Person(Schema):
                name: Annotated[Column[StringType], ColumnMeta(external_name="full-name")]
age: Column[LongType]

            df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["full-name", "age"])
ds, schema = DataSet[Person].from_dataframe(df)
"""
if not hasattr(cls, "_schema_annotations"): # pragma: no cover
            raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataframe(df)`.")

schema = cls._schema_annotations # type: ignore

for column in schema.get_structtype().fields:
if column.metadata:
df = df.withColumnRenamed(
column.metadata.get("external_name", column.name), column.name
)
df = transform_to_schema(df, schema)
schema = register_schema_to_dataset(df, schema)
return df, schema

def to_dataframe(self) -> DataFrame:
"""Converts a DataSet to a DataFrame. Also renames the columns to their external
names.

.. code-block:: python

class Person(Schema):
name: Annotated[Column[StringType], ColumnMeta(external_name="full-name")]
age: Column[LongType]

df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["name", "age"])
ds, schema = DataSet[Person].from_dataframe(df)
df = ds.to_dataframe()
"""
df = cast(DataFrame, self)
df.__class__ = DataFrame

for column in self._schema_annotations.get_structtype().fields:
if column.metadata:
df = df.withColumnRenamed(
column.name, column.metadata.get("external_name", column.name)
)

return df

"""The following functions are equivalent to their parents in ``DataFrame``, but
since they don't affect the ``Schema``, we can add type annotations here.

We're omitting docstrings, such that the docstring from the parent will appear.
"""

def alias(self, alias: str) -> DataSet[_Implementation]:
return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore
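As a usage note, a minimal round-trip sketch of the two new methods, assuming a running SparkSession and the ``Person`` schema from the tests above:

from typing import Annotated

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema
from typedspark._core.column_meta import ColumnMeta


class Person(Schema):
    name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
    age: Column[LongType]


spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])

# from_dataframe() renames the external "first-name" to the attribute name "name"
# and registers the schema to the resulting DataSet.
ds, person = DataSet[Person].from_dataframe(df)
assert ds.columns == ["name", "age"]

# to_dataframe() restores the external names for downstream consumers.
df_out = ds.to_dataframe()
assert df_out.columns == ["first-name", "age"]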
11 changes: 8 additions & 3 deletions typedspark/_transforms/transform_to_schema.py
@@ -1,17 +1,19 @@
"""Module containing functions that are related to transformations to DataSets."""

from functools import reduce
from typing import Dict, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union

from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame

from typedspark._core.column import Column
from typedspark._core.dataset import DataSet
from typedspark._schema.schema import Schema
from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings

if TYPE_CHECKING: # pragma: no cover
from typedspark._core.dataset import DataSet

T = TypeVar("T", bound=Schema)


@@ -40,7 +42,7 @@ def transform_to_schema(
schema: Type[T],
transformations: Optional[Dict[Column, SparkColumn]] = None,
fill_unspecified_columns_with_nulls: bool = False,
) -> DataSet[T]:
) -> "DataSet[T]":
"""On the provided DataFrame ``df``, it performs the ``transformations`` (if
provided), and subsequently subsets the resulting DataFrame to the columns specified
in ``schema``.
@@ -58,6 +60,9 @@
}
)
"""
# importing within the function to avoid circular imports
from typedspark import DataSet

transform: Union[dict[str, SparkColumn], RenameDuplicateColumns]
transform = convert_keys_to_strings(transformations)

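For reference, the circular-import fix used here (and mirrored in register_schema_to_dataset.py below) follows the standard TYPE_CHECKING pattern: import the type only for the type checker, reference it as a string annotation, and defer the real import into the function body. A generic sketch with hypothetical module names:

# module_a.py (hypothetical); module_b imports module_a, so a top-level import
# of module_b here would create a cycle.
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated only by the type checker, never at runtime
    from module_b import Thing


def build_thing() -> "Thing":  # string annotation, resolved lazily
    # importing within the function to avoid circular imports
    from module_b import Thing

    return Thing()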
12 changes: 7 additions & 5 deletions typedspark/_utils/register_schema_to_dataset.py
@@ -1,19 +1,21 @@
"""Module containing functions that are related to registering schema's to DataSets."""

import itertools
from typing import Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Tuple, Type, TypeVar

from typedspark._core.dataset import DataSet
from typedspark._schema.schema import Schema

if TYPE_CHECKING: # pragma: no cover
from typedspark._core.dataset import DataSet

T = TypeVar("T", bound=Schema)


def _counter(count: itertools.count = itertools.count()):
return next(count)


def register_schema_to_dataset(dataframe: DataSet[T], schema: Type[T]) -> Type[T]:
def register_schema_to_dataset(dataframe: "DataSet[T]", schema: Type[T]) -> Type[T]:
"""Helps combat column ambiguity. For example:

.. code-block:: python
@@ -65,8 +67,8 @@ class LinkedSchema(schema): # type: ignore # pylint: disable=missing-class-doc


def register_schema_to_dataset_with_alias(
dataframe: DataSet[T], schema: Type[T], alias: str
) -> Tuple[DataSet[T], Type[T]]:
dataframe: "DataSet[T]", schema: Type[T], alias: str
) -> Tuple["DataSet[T]", Type[T]]:
"""When dealing with self-joins, running `register_dataset_to_schema()` is not
enough.
