diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py
index fc04c20e..d58c5ade 100644
--- a/tests/_core/test_dataset.py
+++ b/tests/_core/test_dataset.py
@@ -1,12 +1,15 @@
 import functools
+from typing import Annotated
 
 import pandas as pd
 import pytest
+from chispa import assert_df_equality  # type: ignore
 from pyspark import StorageLevel
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import LongType, StringType
 
 from typedspark import Column, DataSet, Schema
+from typedspark._core.column_meta import ColumnMeta
 from typedspark._core.dataset import DataSetImplements
 from typedspark._utils.create_dataset import create_empty_dataset
 
@@ -138,3 +141,34 @@ def test_resetting_of_schema_annotations(spark: SparkSession):
     # and then to None again
     a = DataSet(df)
     assert a._schema_annotations is None
+
+
+def test_from_dataframe(spark: SparkSession):
+    df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"])
+    ds, _ = DataSet[A].from_dataframe(df)
+
+    assert isinstance(ds, DataSet)
+    assert_df_equality(ds, df)
+
+    df_2 = ds.to_dataframe()
+
+    assert isinstance(df_2, DataFrame)
+    assert_df_equality(df_2, df)
+
+
+class Person(Schema):
+    name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
+    age: Column[LongType]
+
+
+def test_from_dataframe_with_external_name(spark: SparkSession):
+    df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"])
+    ds, _ = DataSet[Person].from_dataframe(df)
+
+    assert isinstance(ds, DataSet)
+    assert ds.columns == ["name", "age"]
+
+    df_2 = ds.to_dataframe()
+    assert isinstance(df_2, DataFrame)
+    assert df_2.columns == ["first-name", "age"]
+    assert_df_equality(df_2, df)
diff --git a/tests/_utils/test_load_table.py b/tests/_utils/test_load_table.py
index 115778c1..1a53fbd6 100644
--- a/tests/_utils/test_load_table.py
+++ b/tests/_utils/test_load_table.py
@@ -2,7 +2,7 @@
 
 import pytest
 from chispa.dataframe_comparer import assert_df_equality  # type: ignore
-from pyspark.sql import SparkSession
+from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import first
 from pyspark.sql.types import IntegerType, StringType
 
@@ -225,3 +225,85 @@ def test_get_spark_session_without_spark_session():
     if SparkSession.getActiveSession() is None:
         with pytest.raises(ValueError):
             _get_spark_session(None)
+
+
+def test_create_schema_with_invalid_column_name(spark: SparkSession):
+    df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
+    ds, schema = create_schema(df)
+
+    df2 = ds.to_dataframe()
+    assert_df_equality(df, df2)
+
+
+def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSession):
+    data = [
+        Row(
+            **{
+                "full-name": Row(
+                    **{
+                        "first-name": "Alice",
+                        "last-name": "Smith",
+                    },
+                ),
+                "age": 24,
+            }
+        ),
+        Row(
+            **{
+                "full-name": Row(
+                    **{
+                        "first-name": "Bob",
+                        "last-name": "Brown",
+                    },
+                ),
+                "age": 25,
+            },
+        ),
+    ]
+
+    df = spark.createDataFrame(data)
+    ds, schema = create_schema(df)
+
+    df2 = ds.to_dataframe()
+    assert_df_equality(df, df2)
+
+
+def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: SparkSession):
+    data = [
+        Row(
+            **{
+                "details": Row(
+                    **{
+                        "full-name": Row(
+                            **{
+                                "first-name": "Alice",
+                                "last-name": "Smith",
+                            }
+                        ),
+                        "age": 24,
+                    }
+                )
+            }
+        ),
+        Row(
+            **{
+                "details": Row(
+                    **{
+                        "full-name": Row(
+                            **{
+                                "first-name": "Bob",
+                                "last-name": "Brown",
+                            }
+                        ),
+                        "age": 25,
+                    }
+                )
+            }
+        ),
+    ]
+
+    df = spark.createDataFrame(data)
+    ds, schema = create_schema(df)
+
+    df2 = ds.to_dataframe()
+    assert_df_equality(df, df2)
diff --git a/typedspark/_core/column_meta.py b/typedspark/_core/column_meta.py
index a6427940..6d97834e 100644
--- a/typedspark/_core/column_meta.py
+++ b/typedspark/_core/column_meta.py
@@ -20,6 +20,7 @@ class A(Schema):
     """
 
     comment: Optional[str] = None
+    external_name: Optional[str] = None
 
     def get_metadata(self) -> Optional[Dict[str, str]]:
         """Returns the metadata of this column."""
diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py
index 2b0f1210..4bf5a95c 100644
--- a/typedspark/_core/dataset.py
+++ b/typedspark/_core/dataset.py
@@ -3,15 +3,30 @@
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload
+from typing import (
+    Callable,
+    Generic,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 from pyspark import StorageLevel
 from pyspark.sql import Column as SparkColumn
 from pyspark.sql import DataFrame
 from typing_extensions import Concatenate, ParamSpec
 
+from typedspark._core.rename_columns import rename_columns, rename_columns_2
 from typedspark._core.validate_schema import validate_schema
 from typedspark._schema.schema import Schema
+from typedspark._transforms.transform_to_schema import transform_to_schema
+from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
 
 _Schema = TypeVar("_Schema", bound=Schema)
 _Protocol = TypeVar("_Protocol", bound=Schema, covariant=True)
@@ -54,9 +69,62 @@ def typedspark_schema(self) -> Type[_Implementation]:
         """Returns the ``Schema`` of the ``DataSet``."""
         return self._schema_annotations
 
-    """The following functions are equivalent to their parents in ``DataFrame``, but since they
-    don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings,
-    such that the docstring from the parent will appear."""
+    @classmethod
+    def from_dataframe(
+        cls, df: DataFrame, register_to_schema: bool = True
+    ) -> Tuple[DataSet[_Implementation], Type[_Implementation]]:
+        """Converts a DataFrame to a DataSet and registers the Schema to the DataSet.
+        Also renames the columns to their internal names, for example to deal with
+        characters that are not allowed in class attribute names.
+
+        .. code-block:: python
+
+            class Person(Schema):
+                name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
+                age: Column[LongType]
+
+            df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
+            ds, schema = DataSet[Person].from_dataframe(df)
+        """
+        if not hasattr(cls, "_schema_annotations"):  # pragma: no cover
+            raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataframe(df)`.")
+
+        schema = cls._schema_annotations  # type: ignore
+
+        df = rename_columns(df, schema)
+        df = transform_to_schema(df, schema)
+        if register_to_schema:
+            schema = register_schema_to_dataset(df, schema)
+        return df, schema
+
+    def to_dataframe(self) -> DataFrame:
+        """Converts a DataSet to a DataFrame. Also renames the columns to their external
+        names.
+
+        .. code-block:: python
+
+            class Person(Schema):
+                name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
+                age: Column[LongType]
+
+            df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
+            ds, schema = DataSet[Person].from_dataframe(df)
+            df = ds.to_dataframe()
+        """
+        df = cast(DataFrame, self)
+        df.__class__ = DataFrame
+
+        df = rename_columns_2(df, self._schema_annotations)
+
+        return df
+
+    """The following functions are equivalent to their parents in ``DataFrame``, but
+    since they don't affect the ``Schema``, we can add type annotations here.
+
+    We're omitting docstrings, such that the docstring from the parent will appear.
+    """
 
     def alias(self, alias: str) -> DataSet[_Implementation]:
         return DataSet[self._schema_annotations](super().alias(alias))  # type: ignore
diff --git a/typedspark/_core/rename_columns.py b/typedspark/_core/rename_columns.py
new file mode 100644
index 00000000..533423bd
--- /dev/null
+++ b/typedspark/_core/rename_columns.py
@@ -0,0 +1,123 @@
+"""Helper functions to rename columns between their external name (defined in
+`ColumnMeta(external_name=...)`) and their internal name (as used in the Schema)."""
+
+from typing import Optional, Type
+
+from pyspark.sql import Column, DataFrame
+from pyspark.sql.functions import col, lit, struct, when
+from pyspark.sql.types import StructField, StructType
+
+from typedspark._schema.schema import Schema
+
+
+def rename_columns(df: DataFrame, schema: Type[Schema]) -> DataFrame:
+    """Helper function to rename columns from their external name (defined in
+    `ColumnMeta(external_name=...)`) to their internal name (as used in the Schema)."""
+    for field in schema.get_structtype().fields:
+        internal_name = field.name
+
+        if field.metadata and "external_name" in field.metadata:
+            external_name = field.metadata["external_name"]
+            df = df.withColumnRenamed(external_name, internal_name)
+
+        if isinstance(field.dataType, StructType):
+            structtype = _create_renamed_structtype(field.dataType, internal_name)
+            df = df.withColumn(internal_name, structtype)
+
+    return df
+
+
+def rename_columns_2(df: DataFrame, schema: Type[Schema]) -> DataFrame:
+    """Helper function to rename columns from their internal name (as used in the
+    Schema) to their external name (defined in `ColumnMeta(external_name=...)`)."""
+    for field in schema.get_structtype().fields:
+        internal_name = field.name
+        external_name = internal_name
+
+        if field.metadata and "external_name" in field.metadata:
+            external_name = field.metadata["external_name"]
+            df = df.withColumnRenamed(internal_name, external_name)
+
+        if isinstance(field.dataType, StructType):
+            structtype = _create_renamed_structtype_2(field.dataType, external_name)
+            df = df.withColumn(external_name, structtype)
+
+    return df
+
+
+def _create_renamed_structtype(
+    schema: StructType,
+    parent: str,
+    full_parent_path: Optional[str] = None,
+) -> Column:
+    if not full_parent_path:
+        full_parent_path = f"`{parent}`"
+
+    mapping = []
+    for field in schema.fields:
+        external_name = _get_updated_parent_path(full_parent_path, field)
+
+        if isinstance(field.dataType, StructType):
+            mapping += [
+                _create_renamed_structtype(
+                    field.dataType,
+                    parent=field.name,
+                    full_parent_path=external_name,
+                )
+            ]
+        else:
+            mapping += [col(external_name).alias(field.name)]
+
+    return _produce_nested_structtype(mapping, parent, full_parent_path)
+
+
+def _create_renamed_structtype_2(
+    schema: StructType,
+    parent: str,
+    full_parent_path: Optional[str] = None,
+) -> Column:
+    if not full_parent_path:
+        full_parent_path = f"`{parent}`"
+
+    mapping = []
+    for field in schema.fields:
+        internal_name = field.name
+        external_name = field.metadata.get("external_name", internal_name)
+
+        updated_parent_path = _get_updated_parent_path_2(full_parent_path, internal_name)
+
+        if isinstance(field.dataType, StructType):
+            mapping += [
+                _create_renamed_structtype_2(
+                    field.dataType,
+                    parent=external_name,
+                    full_parent_path=updated_parent_path,
+                )
+            ]
+        else:
+            mapping += [col(updated_parent_path).alias(external_name)]
+
+    return _produce_nested_structtype(mapping, parent, full_parent_path)
+
+
+def _get_updated_parent_path(full_parent_path: str, field: StructField) -> str:
+    external_name = field.metadata.get("external_name", field.name)
+    return f"{full_parent_path}.`{external_name}`"
+
+
+def _get_updated_parent_path_2(full_parent_path: str, field: str) -> str:
+    return f"{full_parent_path}.`{field}`"
+
+
+def _produce_nested_structtype(
+    mapping: list[Column],
+    parent: str,
+    full_parent_path: str,
+) -> Column:
+    return (
+        when(
+            col(full_parent_path).isNotNull(),
+            struct(*mapping),
+        )
+        .otherwise(lit(None))
+        .alias(parent)
+    )
diff --git a/typedspark/_transforms/transform_to_schema.py b/typedspark/_transforms/transform_to_schema.py
index 2cb23fc1..294740ca 100644
--- a/typedspark/_transforms/transform_to_schema.py
+++ b/typedspark/_transforms/transform_to_schema.py
@@ -1,17 +1,19 @@
 """Module containing functions that are related to transformations to DataSets."""
 
 from functools import reduce
-from typing import Dict, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union
 
 from pyspark.sql import Column as SparkColumn
 from pyspark.sql import DataFrame
 
 from typedspark._core.column import Column
-from typedspark._core.dataset import DataSet
 from typedspark._schema.schema import Schema
 from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
 from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typedspark._core.dataset import DataSet
+
 T = TypeVar("T", bound=Schema)
 
 
@@ -40,7 +42,7 @@ def transform_to_schema(
     schema: Type[T],
     transformations: Optional[Dict[Column, SparkColumn]] = None,
     fill_unspecified_columns_with_nulls: bool = False,
-) -> DataSet[T]:
+) -> "DataSet[T]":
     """On the provided DataFrame ``df``, it performs the ``transformations`` (if
     provided), and subsequently subsets the resulting DataFrame to the columns
     specified in ``schema``.
@@ -58,6 +60,9 @@ def transform_to_schema(
             }
         )
     """
+    # importing within the function to avoid circular imports
+    from typedspark import DataSet
+
     transform: Union[dict[str, SparkColumn], RenameDuplicateColumns]
 
     transform = convert_keys_to_strings(transformations)
diff --git a/typedspark/_utils/create_dataset_from_structtype.py b/typedspark/_utils/create_dataset_from_structtype.py
index 30bc7b8c..9a2dc1dc 100644
--- a/typedspark/_utils/create_dataset_from_structtype.py
+++ b/typedspark/_utils/create_dataset_from_structtype.py
@@ -1,15 +1,18 @@
 """Utility functions for creating a ``Schema`` from a ``StructType``"""
 
-from typing import Dict, Literal, Optional, Type
+import re
+from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Type
 
 from pyspark.sql.types import ArrayType as SparkArrayType
 from pyspark.sql.types import DataType
 from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
 from pyspark.sql.types import DecimalType as SparkDecimalType
 from pyspark.sql.types import MapType as SparkMapType
+from pyspark.sql.types import StructField
 from pyspark.sql.types import StructType as SparkStructType
 
 from typedspark._core.column import Column
+from typedspark._core.column_meta import ColumnMeta
 from typedspark._core.datatypes import (
     ArrayType,
     DayTimeIntervalType,
@@ -27,11 +30,12 @@ def create_schema_from_structtype(
     """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``"""
     type_annotations = {}
     attributes: Dict[str, None] = {}
+    column_name_mapping = _create_column_name_mapping(structtype)
+
     for column in structtype:
-        name = column.name
-        data_type = _extract_data_type(column.dataType, name)
-        type_annotations[name] = Column[data_type]  # type: ignore
-        attributes[name] = None
+        column_name, column_class = _create_column(column, column_name_mapping)
+        type_annotations[column_name] = column_class
+        attributes[column_name] = None
 
     if not schema_name:
         schema_name = "DynamicallyLoadedSchema"
@@ -42,6 +46,49 @@ def create_schema_from_structtype(
 
     return schema  # type: ignore
+
+
+def _create_column_name_mapping(structtype: SparkStructType) -> Dict[str, str]:
+    """Creates a mapping from the original column names to the renamed column names."""
+    mapping = {column: _replace_illegal_characters(column) for column in structtype.names}
+
+    renamed_columns = list(mapping.values())
+    duplicates = {
+        column: column_renamed
+        for column, column_renamed in mapping.items()
+        if renamed_columns.count(column_renamed) > 1
+    }
+
+    if len(duplicates) > 0:
+        raise ValueError(
+            "You're trying to dynamically generate a Schema from a DataFrame. "
+            + "However, typedspark has detected that the DataFrame contains duplicate columns "
+            + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n"
+            + "The following columns have led to duplicates:\n"
+            + f"{duplicates}\n\n"
+            + "Please rename these columns in your DataFrame."
+        )
+
+    return mapping
+
+
+def _replace_illegal_characters(column_name: str) -> str:
+    """Replaces illegal characters in a column name with an underscore."""
+    return re.sub("[^A-Za-z0-9]", "_", column_name)
+
+
+def _create_column(column: StructField, column_name_mapping: Dict[str, str]) -> Tuple[str, Any]:
+    """Creates a column object, optionally with an `external_name` if the mapped_name is
+    different from the original name (due to illegal characters, such as `-` in the
+    original name)."""
+    name = column.name
+    mapped_name = column_name_mapping[name]
+    data_type = _extract_data_type(column.dataType, name)
+
+    if mapped_name == name:
+        return name, Column[data_type]  # type: ignore
+
+    return mapped_name, Annotated[Column[data_type], ColumnMeta(external_name=name)]  # type: ignore
+
+
 def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
     """Given an instance of a ``DataType``, it extracts the corresponding ``DataType``
     class, potentially including annotations (e.g. ``ArrayType[StringType]``)."""
diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py
index ab578442..c437c3a7 100644
--- a/typedspark/_utils/load_table.py
+++ b/typedspark/_utils/load_table.py
@@ -1,53 +1,12 @@
 """Functions for loading `DataSet` and `Schema` in notebooks."""
 
-import re
-from typing import Dict, Optional, Tuple, Type
+from typing import Optional, Tuple, Type
 
 from pyspark.sql import DataFrame, SparkSession
 
 from typedspark._core.dataset import DataSet
 from typedspark._schema.schema import Schema
 from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
-from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
-
-
-def _replace_illegal_column_names(dataframe: DataFrame) -> DataFrame:
-    """Replaces illegal column names with a legal version."""
-    mapping = _create_mapping(dataframe)
-
-    for column, column_renamed in mapping.items():
-        if column != column_renamed:
-            dataframe = dataframe.withColumnRenamed(column, column_renamed)
-
-    return dataframe
-
-
-def _create_mapping(dataframe: DataFrame) -> Dict[str, str]:
-    """Checks if there are duplicate columns after replacing illegal characters."""
-    mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns}
-    renamed_columns = list(mapping.values())
-    duplicates = {
-        column: column_renamed
-        for column, column_renamed in mapping.items()
-        if renamed_columns.count(column_renamed) > 1
-    }
-
-    if len(duplicates) > 0:
-        raise ValueError(
-            "You're trying to dynamically generate a Schema from a DataFrame. "
-            + "However, typedspark has detected that the DataFrame contains duplicate columns "
-            + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n"
-            + "The folowing columns have lead to duplicates:\n"
-            + f"{duplicates}\n\n"
-            + "Please rename these columns in your DataFrame."
-        )
-
-    return mapping
-
-
-def _replace_illegal_characters(column_name: str) -> str:
-    """Replaces illegal characters in a column name with an underscore."""
-    return re.sub("[^A-Za-z0-9]", "_", column_name)
 
 
 def create_schema(
@@ -63,11 +22,8 @@
 
         df, Person = create_schema(df)
     """
-    dataframe = _replace_illegal_column_names(dataframe)
     schema = create_schema_from_structtype(dataframe.schema, schema_name)
-    dataset = DataSet[schema](dataframe)  # type: ignore
-    schema = register_schema_to_dataset(dataset, schema)
-    return dataset, schema
+    return DataSet[schema].from_dataframe(dataframe)  # type: ignore
 
 
 def load_table(
diff --git a/typedspark/_utils/register_schema_to_dataset.py b/typedspark/_utils/register_schema_to_dataset.py
index 628108fd..ee8a4fcd 100644
--- a/typedspark/_utils/register_schema_to_dataset.py
+++ b/typedspark/_utils/register_schema_to_dataset.py
@@ -1,11 +1,13 @@
 """Module containing functions that are related to registering schema's to DataSets."""
 
 import itertools
-from typing import Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, Tuple, Type, TypeVar
 
-from typedspark._core.dataset import DataSet
 from typedspark._schema.schema import Schema
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typedspark._core.dataset import DataSet
+
 T = TypeVar("T", bound=Schema)
 
 
@@ -13,7 +15,7 @@ def _counter(count: itertools.count = itertools.count()):
     return next(count)
 
 
-def register_schema_to_dataset(dataframe: DataSet[T], schema: Type[T]) -> Type[T]:
+def register_schema_to_dataset(dataframe: "DataSet[T]", schema: Type[T]) -> Type[T]:
     """Helps combat column ambiguity. For example:
 
     .. code-block:: python
@@ -65,8 +67,8 @@ class LinkedSchema(schema):  # type: ignore  # pylint: disable=missing-class-doc
 
 
 def register_schema_to_dataset_with_alias(
-    dataframe: DataSet[T], schema: Type[T], alias: str
-) -> Tuple[DataSet[T], Type[T]]:
+    dataframe: "DataSet[T]", schema: Type[T], alias: str
+) -> Tuple["DataSet[T]", Type[T]]:
     """When dealing with self-joins, running `register_dataset_to_schema()` is not enough.