34 changes: 34 additions & 0 deletions tests/_core/test_dataset.py
@@ -1,12 +1,15 @@
import functools
from typing import Annotated

import pandas as pd
import pytest
from chispa import assert_df_equality # type: ignore
from pyspark import StorageLevel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import LongType, StringType

from typedspark import Column, DataSet, Schema
from typedspark._core.column_meta import ColumnMeta
from typedspark._core.dataset import DataSetImplements
from typedspark._utils.create_dataset import create_empty_dataset

@@ -138,3 +141,34 @@ def test_resetting_of_schema_annotations(spark: SparkSession):
# and then to None again
a = DataSet(df)
assert a._schema_annotations is None


def test_from_dataframe(spark: SparkSession):
df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"])
ds, _ = DataSet[A].from_dataframe(df)

assert isinstance(ds, DataSet)
assert_df_equality(ds, df)

df_2 = ds.to_dataframe()

assert isinstance(df_2, DataFrame)
assert_df_equality(df_2, df)


class Person(Schema):
name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
age: Column[LongType]


def test_from_dataframe_with_external_name(spark: SparkSession):
df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"])
ds, _ = DataSet[Person].from_dataframe(df)

assert isinstance(ds, DataSet)
assert ds.columns == ["name", "age"]

df_2 = ds.to_dataframe()
assert isinstance(df_2, DataFrame)
assert df_2.columns == ["first-name", "age"]
assert_df_equality(df_2, df)
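
A usage sketch for the round trip these tests exercise (assuming an active `spark` session and the `Person` schema above; the `adults` filter is illustrative and not part of this PR):

    df = spark.createDataFrame([("Alice", 24), ("Bob", 17)], ["first-name", "age"])
    ds, person = DataSet[Person].from_dataframe(df)

    # the registered schema handle exposes typed column references on the internal names
    adults = ds.filter(person.age > 18)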
84 changes: 83 additions & 1 deletion tests/_utils/test_load_table.py
@@ -2,7 +2,7 @@

import pytest
from chispa.dataframe_comparer import assert_df_equality # type: ignore
from pyspark.sql import SparkSession
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import first
from pyspark.sql.types import IntegerType, StringType

@@ -225,3 +225,85 @@ def test_get_spark_session_without_spark_session():
if SparkSession.getActiveSession() is None:
with pytest.raises(ValueError):
_get_spark_session(None)


def test_create_schema_with_invalid_column_name(spark: SparkSession):
df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
    ds, _ = create_schema(df)

df2 = ds.to_dataframe()
assert_df_equality(df, df2)


def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSession):
data = [
Row(
**{
"full-name": Row(
**{
"first-name": "Alice",
"last-name": "Smith",
},
),
"age": 24,
}
),
Row(
**{
"full-name": Row(
**{
"first-name": "Bob",
"last-name": "Brown",
},
),
"age": 25,
},
),
]

df = spark.createDataFrame(data)
    ds, _ = create_schema(df)

df2 = ds.to_dataframe()
assert_df_equality(df, df2)


def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: SparkSession):
data = [
Row(
**{
"details": Row(
**{
"full-name": Row(
**{
"first-name": "Alice",
"last-name": "Smith",
}
),
"age": 24,
}
)
}
),
Row(
**{
"details": Row(
**{
"full-name": Row(
**{
"first-name": "Bob",
"last-name": "Brown",
}
),
"age": 25,
}
)
}
),
]

df = spark.createDataFrame(data)
    ds, _ = create_schema(df)

df2 = ds.to_dataframe()
assert_df_equality(df, df2)
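
These tests rely on `create_schema` mapping each column name that is not a valid Python identifier to one that is, while recording the original under `external_name`. For the first test, the inferred schema would presumably look roughly like this (a sketch of the assumed output, not code from this PR):

    class PersonSchema(Schema):
        first_name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
        age: Column[LongType]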
1 change: 1 addition & 0 deletions typedspark/_core/column_meta.py
@@ -20,6 +20,7 @@ class A(Schema):
"""

comment: Optional[str] = None
external_name: Optional[str] = None

def get_metadata(self) -> Optional[Dict[str, str]]:
"""Returns the metadata of this column."""
76 changes: 72 additions & 4 deletions typedspark/_core/dataset.py
@@ -3,15 +3,30 @@
from __future__ import annotations

from copy import copy, deepcopy
from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
from typing import (
Callable,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)

from pyspark import StorageLevel
from pyspark.sql import Column as SparkColumn
from pyspark.sql import DataFrame
from typing_extensions import Concatenate, ParamSpec

from typedspark._core.rename_columns import rename_columns, rename_columns_to_external
from typedspark._core.validate_schema import validate_schema
from typedspark._schema.schema import Schema
from typedspark._transforms.transform_to_schema import transform_to_schema
from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset

_Schema = TypeVar("_Schema", bound=Schema)
_Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
@@ -54,9 +69,62 @@ def typedspark_schema(self) -> Type[_Implementation]:
"""Returns the ``Schema`` of the ``DataSet``."""
return self._schema_annotations

"""The following functions are equivalent to their parents in ``DataFrame``, but since they
don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
such that the docstring from the parent will appear."""
@classmethod
def from_dataframe(
cls, df: DataFrame, register_to_schema: bool = True
) -> Tuple[DataSet[_Implementation], Type[_Implementation]]:
"""Converts a DataFrame to a DataSet and registers the Schema to the DataSet.
        Also renames columns from their external names to their internal names, for
        example to handle characters that are not allowed in Python attribute names.

.. code-block:: python

class Person(Schema):
                name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
age: Column[LongType]

df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
ds, schema = DataSet[Person].from_dataframe(df)
"""
if not hasattr(cls, "_schema_annotations"): # pragma: no cover
raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataset(df)`.")

schema = cls._schema_annotations # type: ignore

df = rename_columns(df, schema)
df = transform_to_schema(df, schema)
if register_to_schema:
schema = register_schema_to_dataset(df, schema)
return df, schema

def to_dataframe(self) -> DataFrame:
"""Converts a DataSet to a DataFrame. Also renames the columns to their external
names.

.. code-block:: python

class Person(Schema):
name: Annotated[Column[StringType], ColumnMeta(external_name="full-name")]
age: Column[LongType]

df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["name", "age"])
ds, schema = DataSet[Person].from_dataframe(df)
df = ds.to_dataframe()
"""
        df = cast(DataFrame, copy(self))
        df.__class__ = DataFrame  # downcast a shallow copy, so `self` remains a DataSet

        return rename_columns_to_external(df, self._schema_annotations)

"""The following functions are equivalent to their parents in ``DataFrame``, but
since they don't affect the ``Schema``, we can add type annotations here.

We're omitting docstrings, such that the docstring from the parent will appear.
"""

def alias(self, alias: str) -> DataSet[_Implementation]:
return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore
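
Taken together, the two new methods give the following round trip (a sketch, assuming the `Person` schema from the docstrings above; per the implementation, `register_to_schema=False` simply skips the `register_schema_to_dataset` call and returns the schema unregistered):

    df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])

    ds, person = DataSet[Person].from_dataframe(df)  # columns: ["name", "age"]
    df_2 = ds.to_dataframe()                         # columns: ["first-name", "age"]

    # when join-style disambiguation is not needed:
    ds_2, person_2 = DataSet[Person].from_dataframe(df, register_to_schema=False)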
123 changes: 123 additions & 0 deletions typedspark/_core/rename_columns.py
@@ -0,0 +1,123 @@
"""Helper functions to rename columns from their external name (defined in
`ColumnMeta(external_name=...)`) to their internal name."""

from typing import List, Optional, Type

from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col, lit, struct, when
from pyspark.sql.types import StructField, StructType

from typedspark._schema.schema import Schema


def rename_columns(df: DataFrame, schema: Type[Schema]) -> DataFrame:
"""Helper functions to rename columns from their external name (defined in
`ColumnMeta(external_name=...)`) to their internal name (as used in the Schema)."""
for field in schema.get_structtype().fields:
internal_name = field.name

if field.metadata and "external_name" in field.metadata:
external_name = field.metadata["external_name"]
df = df.withColumnRenamed(external_name, internal_name)

if isinstance(field.dataType, StructType):
structtype = _create_renamed_structtype(field.dataType, internal_name)
df = df.withColumn(internal_name, structtype)

return df


def rename_columns_to_external(df: DataFrame, schema: Type[Schema]) -> DataFrame:
    """Helper function to rename columns from their internal name (as used in the
    Schema) to their external name (defined in `ColumnMeta(external_name=...)`)."""
    for field in schema.get_structtype().fields:
        internal_name = field.name
        external_name = field.metadata.get("external_name", internal_name)

        if external_name != internal_name:
            df = df.withColumnRenamed(internal_name, external_name)

        if isinstance(field.dataType, StructType):
            # the column already carries its external name at this point
            structtype = _create_renamed_structtype_to_external(field.dataType, external_name)
            df = df.withColumn(external_name, structtype)

    return df


def _create_renamed_structtype(
schema: StructType,
parent: str,
full_parent_path: Optional[str] = None,
) -> Column:
if not full_parent_path:
full_parent_path = f"`{parent}`"

mapping = []
for field in schema.fields:
external_name = _get_updated_parent_path(full_parent_path, field)

if isinstance(field.dataType, StructType):
mapping += [
_create_renamed_structtype(
field.dataType,
parent=field.name,
full_parent_path=external_name,
)
]
else:
mapping += [col(external_name).alias(field.name)]

return _produce_nested_structtype(mapping, parent, full_parent_path)


def _create_renamed_structtype_to_external(
    schema: StructType,
    parent: str,
    full_parent_path: Optional[str] = None,
) -> Column:
    if not full_parent_path:
        full_parent_path = f"`{parent}`"

    mapping = []
    for field in schema.fields:
        internal_name = field.name
        external_name = field.metadata.get("external_name", internal_name)

        # nested fields still carry their internal names, so the path uses those
        updated_parent_path = _get_updated_parent_path_to_external(
            full_parent_path, internal_name
        )

        if isinstance(field.dataType, StructType):
            mapping += [
                _create_renamed_structtype_to_external(
                    field.dataType,
                    parent=external_name,
                    full_parent_path=updated_parent_path,
                )
            ]
        else:
            mapping += [col(updated_parent_path).alias(external_name)]

    return _produce_nested_structtype(mapping, parent, full_parent_path)


def _get_updated_parent_path(full_parent_path: str, field: StructField) -> str:
external_name = field.metadata.get("external_name", field.name)
return f"{full_parent_path}.`{external_name}`"


def _get_updated_parent_path_to_external(full_parent_path: str, field_name: str) -> str:
    return f"{full_parent_path}.`{field_name}`"


def _produce_nested_structtype(
    mapping: List[Column],
parent: str,
full_parent_path: str,
) -> Column:
return (
when(
col(full_parent_path).isNotNull(),
struct(*mapping),
)
.otherwise(lit(None))
.alias(parent)
)
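
To make the mechanics concrete, this is roughly the expression the module builds for one struct column with external names at both levels: external "full-name" containing "first-name", internal `full_name` containing `first_name`. A hand-expanded sketch of what `rename_columns` produces via `_create_renamed_structtype` and `_produce_nested_structtype`; the column names are assumed for illustration:

    from pyspark.sql.functions import col, lit, struct, when

    # step 1: rename the top-level column, external -> internal
    df = df.withColumnRenamed("full-name", "full_name")

    # step 2: rebuild the struct so the nested field also gets its internal name;
    # at this point the nested field still carries its external name
    df = df.withColumn(
        "full_name",
        when(
            col("`full_name`").isNotNull(),
            struct(col("`full_name`.`first-name`").alias("first_name")),
        )
        .otherwise(lit(None))
        .alias("full_name"),
    )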