From c42cc5bb06711fa342fc5bf72eb0999e60c8aed2 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Tue, 25 Jun 2024 20:29:08 +0200 Subject: [PATCH 01/10] implement DataSet[Schema].from_dataframe(...) --- test.ipynb | 67 +++++++++++++++++++ tests/_core/test_dataset.py | 10 +++ typedspark/_core/dataset.py | 45 +++++++++++-- typedspark/_transforms/transform_to_schema.py | 11 ++- typedspark/_utils/load_table.py | 45 +------------ .../_utils/register_schema_to_dataset.py | 12 ++-- .../_utils/replace_illegal_column_names.py | 45 +++++++++++++ 7 files changed, 181 insertions(+), 54 deletions(-) create mode 100644 test.ipynb create mode 100644 typedspark/_utils/replace_illegal_column_names.py diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000..45dfed27 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,67 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(DataFrame[a: bigint],\n", + " \n", + " from pyspark.sql.types import LongType\n", + " \n", + " from typedspark import Column, Schema\n", + " \n", + " \n", + " class A(Schema):\n", + " a: Column[LongType])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pyspark.sql import SparkSession\n", + "from pyspark.sql.types import LongType\n", + "from typedspark import Schema, Column\n", + "from typedspark._core.dataset import DataSet\n", + "\n", + "\n", + "class A(Schema):\n", + " a: Column[LongType]\n", + "\n", + "\n", + "spark = SparkSession.Builder().getOrCreate()\n", + "\n", + "df = spark.createDataFrame([(1,)], [\"a\"])\n", + "\n", + "DataSet[A].from_dataframe(df)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "typedspark-3.11.9", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py index fc04c20e..9c2013fc 100644 --- a/tests/_core/test_dataset.py +++ b/tests/_core/test_dataset.py @@ -2,6 +2,7 @@ import pandas as pd import pytest +from chispa import assert_df_equality # type: ignore from pyspark import StorageLevel from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import LongType, StringType @@ -138,3 +139,12 @@ def test_resetting_of_schema_annotations(spark: SparkSession): # and then to None again a = DataSet(df) assert a._schema_annotations is None + + +def test_from_dataframe(spark: SparkSession): + df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"]) + ds, schema = DataSet[A].from_dataframe(df) + + assert isinstance(ds, DataSet) + assert issubclass(schema, A) + assert_df_equality(ds, df) diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index 2b0f1210..a26f5cc9 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -3,7 +3,19 @@ from __future__ import annotations from copy import deepcopy -from typing import Callable, Generic, List, Literal, Optional, Type, TypeVar, Union, cast, overload +from typing import ( + Callable, + Generic, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) from pyspark import StorageLevel from pyspark.sql import Column 
as SparkColumn @@ -12,6 +24,9 @@ from typedspark._core.validate_schema import validate_schema from typedspark._schema.schema import Schema +from typedspark._transforms.transform_to_schema import transform_to_schema +from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset +from typedspark._utils.replace_illegal_column_names import replace_illegal_column_names _Schema = TypeVar("_Schema", bound=Schema) _Protocol = TypeVar("_Protocol", bound=Schema, covariant=True) @@ -54,9 +69,31 @@ def typedspark_schema(self) -> Type[_Implementation]: """Returns the ``Schema`` of the ``DataSet``.""" return self._schema_annotations - """The following functions are equivalent to their parents in ``DataFrame``, but since they - don't affect the ``Schema``, we can add type annotations here. We're omitting docstrings, - such that the docstring from the parent will appear.""" + @classmethod + def from_dataframe( + cls, df: DataFrame + ) -> Tuple[DataSet[_Implementation], Type[_Implementation]]: + """Converts a DataFrame to a DataSet and registers the Schema to the DataSet. + + Also replaces "illegal" characters in the DataFrame's colnames (.e.g "test-result" + -> "test_result"), so they're compatible with the Schema (after all, Python doesn't allow + for characters such as dashes in attribute names). + """ + if not hasattr(cls, "_schema_annotations"): # pragma: no cover + raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataset(df)`.") + + schema = cls._schema_annotations # type: ignore + + df = replace_illegal_column_names(df) + ds = transform_to_schema(df, schema) + schema = register_schema_to_dataset(ds, schema) + return ds, schema + + """The following functions are equivalent to their parents in ``DataFrame``, but + since they don't affect the ``Schema``, we can add type annotations here. + + We're omitting docstrings, such that the docstring from the parent will appear. + """ def alias(self, alias: str) -> DataSet[_Implementation]: return DataSet[self._schema_annotations](super().alias(alias)) # type: ignore diff --git a/typedspark/_transforms/transform_to_schema.py b/typedspark/_transforms/transform_to_schema.py index 2cb23fc1..294740ca 100644 --- a/typedspark/_transforms/transform_to_schema.py +++ b/typedspark/_transforms/transform_to_schema.py @@ -1,17 +1,19 @@ """Module containing functions that are related to transformations to DataSets.""" from functools import reduce -from typing import Dict, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Dict, Optional, Type, TypeVar, Union from pyspark.sql import Column as SparkColumn from pyspark.sql import DataFrame from typedspark._core.column import Column -from typedspark._core.dataset import DataSet from typedspark._schema.schema import Schema from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings +if TYPE_CHECKING: # pragma: no cover + from typedspark._core.dataset import DataSet + T = TypeVar("T", bound=Schema) @@ -40,7 +42,7 @@ def transform_to_schema( schema: Type[T], transformations: Optional[Dict[Column, SparkColumn]] = None, fill_unspecified_columns_with_nulls: bool = False, -) -> DataSet[T]: +) -> "DataSet[T]": """On the provided DataFrame ``df``, it performs the ``transformations`` (if provided), and subsequently subsets the resulting DataFrame to the columns specified in ``schema``. 
@@ -58,6 +60,9 @@ def transform_to_schema( } ) """ + # importing within the function to avoid circular imports + from typedspark import DataSet + transform: Union[dict[str, SparkColumn], RenameDuplicateColumns] transform = convert_keys_to_strings(transformations) diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py index ab578442..361937ec 100644 --- a/typedspark/_utils/load_table.py +++ b/typedspark/_utils/load_table.py @@ -1,7 +1,6 @@ """Functions for loading `DataSet` and `Schema` in notebooks.""" -import re -from typing import Dict, Optional, Tuple, Type +from typing import Optional, Tuple, Type from pyspark.sql import DataFrame, SparkSession @@ -9,45 +8,7 @@ from typedspark._schema.schema import Schema from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset - - -def _replace_illegal_column_names(dataframe: DataFrame) -> DataFrame: - """Replaces illegal column names with a legal version.""" - mapping = _create_mapping(dataframe) - - for column, column_renamed in mapping.items(): - if column != column_renamed: - dataframe = dataframe.withColumnRenamed(column, column_renamed) - - return dataframe - - -def _create_mapping(dataframe: DataFrame) -> Dict[str, str]: - """Checks if there are duplicate columns after replacing illegal characters.""" - mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns} - renamed_columns = list(mapping.values()) - duplicates = { - column: column_renamed - for column, column_renamed in mapping.items() - if renamed_columns.count(column_renamed) > 1 - } - - if len(duplicates) > 0: - raise ValueError( - "You're trying to dynamically generate a Schema from a DataFrame. " - + "However, typedspark has detected that the DataFrame contains duplicate columns " - + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n" - + "The folowing columns have lead to duplicates:\n" - + f"{duplicates}\n\n" - + "Please rename these columns in your DataFrame." 
- ) - - return mapping - - -def _replace_illegal_characters(column_name: str) -> str: - """Replaces illegal characters in a column name with an underscore.""" - return re.sub("[^A-Za-z0-9]", "_", column_name) +from typedspark._utils.replace_illegal_column_names import replace_illegal_column_names def create_schema( @@ -63,7 +24,7 @@ def create_schema( df, Person = create_schema(df) """ - dataframe = _replace_illegal_column_names(dataframe) + dataframe = replace_illegal_column_names(dataframe) schema = create_schema_from_structtype(dataframe.schema, schema_name) dataset = DataSet[schema](dataframe) # type: ignore schema = register_schema_to_dataset(dataset, schema) diff --git a/typedspark/_utils/register_schema_to_dataset.py b/typedspark/_utils/register_schema_to_dataset.py index 628108fd..ee8a4fcd 100644 --- a/typedspark/_utils/register_schema_to_dataset.py +++ b/typedspark/_utils/register_schema_to_dataset.py @@ -1,11 +1,13 @@ """Module containing functions that are related to registering schema's to DataSets.""" import itertools -from typing import Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Tuple, Type, TypeVar -from typedspark._core.dataset import DataSet from typedspark._schema.schema import Schema +if TYPE_CHECKING: # pragma: no cover + from typedspark._core.dataset import DataSet + T = TypeVar("T", bound=Schema) @@ -13,7 +15,7 @@ def _counter(count: itertools.count = itertools.count()): return next(count) -def register_schema_to_dataset(dataframe: DataSet[T], schema: Type[T]) -> Type[T]: +def register_schema_to_dataset(dataframe: "DataSet[T]", schema: Type[T]) -> Type[T]: """Helps combat column ambiguity. For example: .. code-block:: python @@ -65,8 +67,8 @@ class LinkedSchema(schema): # type: ignore # pylint: disable=missing-class-doc def register_schema_to_dataset_with_alias( - dataframe: DataSet[T], schema: Type[T], alias: str -) -> Tuple[DataSet[T], Type[T]]: + dataframe: "DataSet[T]", schema: Type[T], alias: str +) -> Tuple["DataSet[T]", Type[T]]: """When dealing with self-joins, running `register_dataset_to_schema()` is not enough. diff --git a/typedspark/_utils/replace_illegal_column_names.py b/typedspark/_utils/replace_illegal_column_names.py new file mode 100644 index 00000000..418d3818 --- /dev/null +++ b/typedspark/_utils/replace_illegal_column_names.py @@ -0,0 +1,45 @@ +"""Functions for loading `DataSet` and `Schema` in notebooks.""" + +import re +from typing import Dict + +from pyspark.sql import DataFrame + + +def replace_illegal_column_names(dataframe: DataFrame) -> DataFrame: + """Replaces illegal column names with a legal version.""" + mapping = _create_mapping(dataframe) + + for column, column_renamed in mapping.items(): + if column != column_renamed: + dataframe = dataframe.withColumnRenamed(column, column_renamed) + + return dataframe + + +def _create_mapping(dataframe: DataFrame) -> Dict[str, str]: + """Checks if there are duplicate columns after replacing illegal characters.""" + mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns} + renamed_columns = list(mapping.values()) + duplicates = { + column: column_renamed + for column, column_renamed in mapping.items() + if renamed_columns.count(column_renamed) > 1 + } + + if len(duplicates) > 0: + raise ValueError( + "You're trying to dynamically generate a Schema from a DataFrame. " + + "However, typedspark has detected that the DataFrame contains duplicate columns " + + "after replacing illegal characters (e.g. 
whitespaces, dots, etc.).\n" + + "The folowing columns have lead to duplicates:\n" + + f"{duplicates}\n\n" + + "Please rename these columns in your DataFrame." + ) + + return mapping + + +def _replace_illegal_characters(column_name: str) -> str: + """Replaces illegal characters in a column name with an underscore.""" + return re.sub("[^A-Za-z0-9]", "_", column_name) From f1940289083338aff3f6cd3d096f4490ffdff210 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sun, 30 Jun 2024 13:53:24 +0200 Subject: [PATCH 02/10] update --- tests/_core/test_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py index 9c2013fc..7664f3b3 100644 --- a/tests/_core/test_dataset.py +++ b/tests/_core/test_dataset.py @@ -143,8 +143,7 @@ def test_resetting_of_schema_annotations(spark: SparkSession): def test_from_dataframe(spark: SparkSession): df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"]) - ds, schema = DataSet[A].from_dataframe(df) + ds, _ = DataSet[A].from_dataframe(df) assert isinstance(ds, DataSet) - assert issubclass(schema, A) assert_df_equality(ds, df) From f4425aed3865ea74fd9bae1af98467e1fb34b21c Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sun, 21 Jul 2024 11:18:07 +0200 Subject: [PATCH 03/10] update --- tests/_core/test_dataset.py | 27 +++++++- typedspark/_core/column_meta.py | 9 ++- typedspark/_core/dataset.py | 61 +++++++++++++++---- typedspark/_utils/load_table.py | 45 +++++++++++++- .../_utils/replace_illegal_column_names.py | 45 -------------- 5 files changed, 125 insertions(+), 62 deletions(-) delete mode 100644 typedspark/_utils/replace_illegal_column_names.py diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py index 7664f3b3..5e6eebab 100644 --- a/tests/_core/test_dataset.py +++ b/tests/_core/test_dataset.py @@ -1,4 +1,5 @@ import functools +from typing import Annotated import pandas as pd import pytest @@ -8,6 +9,7 @@ from pyspark.sql.types import LongType, StringType from typedspark import Column, DataSet, Schema +from typedspark._core.column_meta import ColumnMeta from typedspark._core.dataset import DataSetImplements from typedspark._utils.create_dataset import create_empty_dataset @@ -143,7 +145,30 @@ def test_resetting_of_schema_annotations(spark: SparkSession): def test_from_dataframe(spark: SparkSession): df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"]) - ds, _ = DataSet[A].from_dataframe(df) + ds = DataSet[A].from_dataframe(df) assert isinstance(ds, DataSet) assert_df_equality(ds, df) + + df_2 = ds.to_dataframe() + + assert isinstance(df_2, DataFrame) + assert_df_equality(df_2, df) + + +class Person(Schema): + name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")] + age: Column[LongType] + + +def test_from_dataframe_with_external_name(spark: SparkSession): + df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"]) + ds = DataSet[Person].from_dataframe(df) + + assert isinstance(ds, DataSet) + assert ds.columns == ["name", "age"] + + df_2 = ds.to_dataframe() + assert isinstance(df_2, DataFrame) + assert df_2.columns == ["first-name", "age"] + assert_df_equality(df_2, df) diff --git a/typedspark/_core/column_meta.py b/typedspark/_core/column_meta.py index 4f3746b4..1dab1e7a 100644 --- a/typedspark/_core/column_meta.py +++ b/typedspark/_core/column_meta.py @@ -20,7 +20,14 @@ class A(Schema): """ comment: Optional[str] = None + 
external_name: Optional[str] = None def get_metadata(self) -> Optional[Dict[str, str]]: """Returns the metadata of this column.""" - return {"comment": self.comment} if self.comment else None + res = {} + if self.comment: + res["comment"] = self.comment + if self.external_name: + res["external_name"] = self.external_name + + return res if res else None diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index a26f5cc9..92aebe00 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -25,8 +25,6 @@ from typedspark._core.validate_schema import validate_schema from typedspark._schema.schema import Schema from typedspark._transforms.transform_to_schema import transform_to_schema -from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset -from typedspark._utils.replace_illegal_column_names import replace_illegal_column_names _Schema = TypeVar("_Schema", bound=Schema) _Protocol = TypeVar("_Protocol", bound=Schema, covariant=True) @@ -70,24 +68,63 @@ def typedspark_schema(self) -> Type[_Implementation]: return self._schema_annotations @classmethod - def from_dataframe( - cls, df: DataFrame - ) -> Tuple[DataSet[_Implementation], Type[_Implementation]]: + def from_dataframe(cls, df: DataFrame) -> Union[ + DataSet[_Implementation], + Tuple[ + DataSet[_Implementation], + Type[_Implementation], + ], + ]: """Converts a DataFrame to a DataSet and registers the Schema to the DataSet. + Also renames the columns to their internal names, for example to deal with + characters that are not allowed in class attribute names. - Also replaces "illegal" characters in the DataFrame's colnames (.e.g "test-result" - -> "test_result"), so they're compatible with the Schema (after all, Python doesn't allow - for characters such as dashes in attribute names). + .. code-block:: python + + class Person(Schema): + name: Annotation[Column[StringType], ColumnMeta(external_name="full-name")] + age: Column[LongType] + + df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"]) + ds = DataSet[Person].from_dataframe(df) """ if not hasattr(cls, "_schema_annotations"): # pragma: no cover raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataset(df)`.") schema = cls._schema_annotations # type: ignore - df = replace_illegal_column_names(df) - ds = transform_to_schema(df, schema) - schema = register_schema_to_dataset(ds, schema) - return ds, schema + for column in schema.get_structtype().fields: + if column.metadata: + df = df.withColumnRenamed( + column.metadata.get("external_name", column.name), column.name + ) + + return transform_to_schema(df, schema) + + def to_dataframe(self) -> DataFrame: + """Converts a DataSet to a DataFrame. Also renames the columns to their external + names. + + .. code-block:: python + + class Person(Schema): + name: Annotated[Column[StringType], ColumnMeta(external_name="full-name")] + age: Column[LongType] + + df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["name", "age"]) + ds = DataSet[Person].from_dataframe(df) + df = ds.to_dataframe() + """ + df = cast(DataFrame, self) + df.__class__ = DataFrame + + for column in self._schema_annotations.get_structtype().fields: + if column.metadata: + df = df.withColumnRenamed( + column.name, column.metadata.get("external_name", column.name) + ) + + return df """The following functions are equivalent to their parents in ``DataFrame``, but since they don't affect the ``Schema``, we can add type annotations here. 
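Taken together, `ColumnMeta(external_name=...)`, `from_dataframe`, and `to_dataframe` give a lossless round trip between a DataFrame whose column names are not legal Python attribute names and a typed DataSet. A minimal sketch of that round trip, mirroring `test_from_dataframe_with_external_name` above (the `ColumnMeta` import path is the one the tests use; note that a later commit in this series changes `from_dataframe` back to returning a `(DataSet, Schema)` tuple):

.. code-block:: python

    from typing import Annotated

    from pyspark.sql import SparkSession
    from pyspark.sql.types import LongType, StringType

    from typedspark import Column, DataSet, Schema
    from typedspark._core.column_meta import ColumnMeta


    class Person(Schema):
        # "first-name" is not a valid Python attribute name, so the column is
        # stored under the internal name `name`, with the external name kept
        # in the column's metadata.
        name: Annotated[Column[StringType], ColumnMeta(external_name="first-name")]
        age: Column[LongType]


    spark = SparkSession.Builder().getOrCreate()
    df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"])

    ds = DataSet[Person].from_dataframe(df)  # columns: ["name", "age"]
    df2 = ds.to_dataframe()                  # columns: ["first-name", "age"]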
diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py index 361937ec..ab578442 100644 --- a/typedspark/_utils/load_table.py +++ b/typedspark/_utils/load_table.py @@ -1,6 +1,7 @@ """Functions for loading `DataSet` and `Schema` in notebooks.""" -from typing import Optional, Tuple, Type +import re +from typing import Dict, Optional, Tuple, Type from pyspark.sql import DataFrame, SparkSession @@ -8,7 +9,45 @@ from typedspark._schema.schema import Schema from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset -from typedspark._utils.replace_illegal_column_names import replace_illegal_column_names + + +def _replace_illegal_column_names(dataframe: DataFrame) -> DataFrame: + """Replaces illegal column names with a legal version.""" + mapping = _create_mapping(dataframe) + + for column, column_renamed in mapping.items(): + if column != column_renamed: + dataframe = dataframe.withColumnRenamed(column, column_renamed) + + return dataframe + + +def _create_mapping(dataframe: DataFrame) -> Dict[str, str]: + """Checks if there are duplicate columns after replacing illegal characters.""" + mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns} + renamed_columns = list(mapping.values()) + duplicates = { + column: column_renamed + for column, column_renamed in mapping.items() + if renamed_columns.count(column_renamed) > 1 + } + + if len(duplicates) > 0: + raise ValueError( + "You're trying to dynamically generate a Schema from a DataFrame. " + + "However, typedspark has detected that the DataFrame contains duplicate columns " + + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n" + + "The folowing columns have lead to duplicates:\n" + + f"{duplicates}\n\n" + + "Please rename these columns in your DataFrame." 
+ ) + + return mapping + + +def _replace_illegal_characters(column_name: str) -> str: + """Replaces illegal characters in a column name with an underscore.""" + return re.sub("[^A-Za-z0-9]", "_", column_name) def create_schema( @@ -24,7 +63,7 @@ def create_schema( df, Person = create_schema(df) """ - dataframe = replace_illegal_column_names(dataframe) + dataframe = _replace_illegal_column_names(dataframe) schema = create_schema_from_structtype(dataframe.schema, schema_name) dataset = DataSet[schema](dataframe) # type: ignore schema = register_schema_to_dataset(dataset, schema) diff --git a/typedspark/_utils/replace_illegal_column_names.py b/typedspark/_utils/replace_illegal_column_names.py deleted file mode 100644 index 418d3818..00000000 --- a/typedspark/_utils/replace_illegal_column_names.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Functions for loading `DataSet` and `Schema` in notebooks.""" - -import re -from typing import Dict - -from pyspark.sql import DataFrame - - -def replace_illegal_column_names(dataframe: DataFrame) -> DataFrame: - """Replaces illegal column names with a legal version.""" - mapping = _create_mapping(dataframe) - - for column, column_renamed in mapping.items(): - if column != column_renamed: - dataframe = dataframe.withColumnRenamed(column, column_renamed) - - return dataframe - - -def _create_mapping(dataframe: DataFrame) -> Dict[str, str]: - """Checks if there are duplicate columns after replacing illegal characters.""" - mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns} - renamed_columns = list(mapping.values()) - duplicates = { - column: column_renamed - for column, column_renamed in mapping.items() - if renamed_columns.count(column_renamed) > 1 - } - - if len(duplicates) > 0: - raise ValueError( - "You're trying to dynamically generate a Schema from a DataFrame. " - + "However, typedspark has detected that the DataFrame contains duplicate columns " - + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n" - + "The folowing columns have lead to duplicates:\n" - + f"{duplicates}\n\n" - + "Please rename these columns in your DataFrame." 
- ) - - return mapping - - -def _replace_illegal_characters(column_name: str) -> str: - """Replaces illegal characters in a column name with an underscore.""" - return re.sub("[^A-Za-z0-9]", "_", column_name) From a3e6c4757496f3bbf8391560d1798ee095d184b7 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sun, 21 Jul 2024 11:21:27 +0200 Subject: [PATCH 04/10] remove test notebook --- test.ipynb | 67 ------------------------------------------------------ 1 file changed, 67 deletions(-) delete mode 100644 test.ipynb diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 45dfed27..00000000 --- a/test.ipynb +++ /dev/null @@ -1,67 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(DataFrame[a: bigint],\n", - " \n", - " from pyspark.sql.types import LongType\n", - " \n", - " from typedspark import Column, Schema\n", - " \n", - " \n", - " class A(Schema):\n", - " a: Column[LongType])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from pyspark.sql import SparkSession\n", - "from pyspark.sql.types import LongType\n", - "from typedspark import Schema, Column\n", - "from typedspark._core.dataset import DataSet\n", - "\n", - "\n", - "class A(Schema):\n", - " a: Column[LongType]\n", - "\n", - "\n", - "spark = SparkSession.Builder().getOrCreate()\n", - "\n", - "df = spark.createDataFrame([(1,)], [\"a\"])\n", - "\n", - "DataSet[A].from_dataframe(df)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "typedspark-3.11.9", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 331678f5fb25b822eaf3ddb8f26298edfece8c45 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sun, 21 Jul 2024 11:51:13 +0200 Subject: [PATCH 05/10] change interface to include register_to_schema() --- tests/_core/test_dataset.py | 4 ++-- typedspark/_core/dataset.py | 20 +++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/_core/test_dataset.py b/tests/_core/test_dataset.py index 5e6eebab..d58c5ade 100644 --- a/tests/_core/test_dataset.py +++ b/tests/_core/test_dataset.py @@ -145,7 +145,7 @@ def test_resetting_of_schema_annotations(spark: SparkSession): def test_from_dataframe(spark: SparkSession): df = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "b"]) - ds = DataSet[A].from_dataframe(df) + ds, _ = DataSet[A].from_dataframe(df) assert isinstance(ds, DataSet) assert_df_equality(ds, df) @@ -163,7 +163,7 @@ class Person(Schema): def test_from_dataframe_with_external_name(spark: SparkSession): df = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["first-name", "age"]) - ds = DataSet[Person].from_dataframe(df) + ds, _ = DataSet[Person].from_dataframe(df) assert isinstance(ds, DataSet) assert ds.columns == ["name", "age"] diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index 92aebe00..671a252d 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -25,6 +25,7 @@ from typedspark._core.validate_schema import validate_schema from typedspark._schema.schema import Schema from 
typedspark._transforms.transform_to_schema import transform_to_schema +from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset _Schema = TypeVar("_Schema", bound=Schema) _Protocol = TypeVar("_Protocol", bound=Schema, covariant=True) @@ -68,13 +69,9 @@ def typedspark_schema(self) -> Type[_Implementation]: return self._schema_annotations @classmethod - def from_dataframe(cls, df: DataFrame) -> Union[ - DataSet[_Implementation], - Tuple[ - DataSet[_Implementation], - Type[_Implementation], - ], - ]: + def from_dataframe( + cls, df: DataFrame + ) -> Tuple[DataSet[_Implementation], Type[_Implementation]]: """Converts a DataFrame to a DataSet and registers the Schema to the DataSet. Also renames the columns to their internal names, for example to deal with characters that are not allowed in class attribute names. @@ -86,7 +83,7 @@ class Person(Schema): age: Column[LongType] df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"]) - ds = DataSet[Person].from_dataframe(df) + ds, schema = DataSet[Person].from_dataframe(df) """ if not hasattr(cls, "_schema_annotations"): # pragma: no cover raise SyntaxError("Please define a schema, e.g. `DataSet[Person].from_dataset(df)`.") @@ -98,8 +95,9 @@ class Person(Schema): df = df.withColumnRenamed( column.metadata.get("external_name", column.name), column.name ) - - return transform_to_schema(df, schema) + df = transform_to_schema(df, schema) + schema = register_schema_to_dataset(df, schema) + return df, schema def to_dataframe(self) -> DataFrame: """Converts a DataSet to a DataFrame. Also renames the columns to their external @@ -112,7 +110,7 @@ class Person(Schema): age: Column[LongType] df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["name", "age"]) - ds = DataSet[Person].from_dataframe(df) + ds, schema = DataSet[Person].from_dataframe(df) df = ds.to_dataframe() """ df = cast(DataFrame, self) From e96ab1a6cffce1781c4be2624c550ab7b96682fb Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sat, 10 Aug 2024 14:50:11 +0200 Subject: [PATCH 06/10] update --- .../_utils/create_dataset_from_structtype.py | 45 ++++++++++++++++++- typedspark/_utils/load_table.py | 43 +----------------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/typedspark/_utils/create_dataset_from_structtype.py b/typedspark/_utils/create_dataset_from_structtype.py index 30bc7b8c..99704108 100644 --- a/typedspark/_utils/create_dataset_from_structtype.py +++ b/typedspark/_utils/create_dataset_from_structtype.py @@ -1,6 +1,7 @@ """Utility functions for creating a ``Schema`` from a ``StructType``""" -from typing import Dict, Literal, Optional, Type +import re +from typing import Annotated, Dict, Literal, Optional, Type from pyspark.sql.types import ArrayType as SparkArrayType from pyspark.sql.types import DataType @@ -10,6 +11,7 @@ from pyspark.sql.types import StructType as SparkStructType from typedspark._core.column import Column +from typedspark._core.column_meta import ColumnMeta from typedspark._core.datatypes import ( ArrayType, DayTimeIntervalType, @@ -27,10 +29,20 @@ def create_schema_from_structtype( """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``""" type_annotations = {} attributes: Dict[str, None] = {} + column_name_mapping = _create_column_name_mapping(structtype) + for column in structtype: name = column.name data_type = _extract_data_type(column.dataType, name) - type_annotations[name] = Column[data_type] # type: ignore + + 
mapped_name = column_name_mapping[name] + if mapped_name == name: + type_annotations[name] = Column[data_type] # type: ignore + else: + type_annotations[mapped_name] = Annotated[ + Column[data_type], ColumnMeta(external_name=name) + ] + attributes[name] = None if not schema_name: @@ -42,6 +54,35 @@ def create_schema_from_structtype( return schema # type: ignore +def _create_column_name_mapping(structtype: SparkStructType) -> Dict[str, str]: + """Checks if there are duplicate columns after replacing illegal characters.""" + mapping = {column: _replace_illegal_characters(column) for column in structtype.names} + + renamed_columns = list(mapping.values()) + duplicates = { + column: column_renamed + for column, column_renamed in mapping.items() + if renamed_columns.count(column_renamed) > 1 + } + + if len(duplicates) > 0: + raise ValueError( + "You're trying to dynamically generate a Schema from a DataFrame. " + + "However, typedspark has detected that the DataFrame contains duplicate columns " + + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n" + + "The folowing columns have lead to duplicates:\n" + + f"{duplicates}\n\n" + + "Please rename these columns in your DataFrame." + ) + + return mapping + + +def _replace_illegal_characters(column_name: str) -> str: + """Replaces illegal characters in a column name with an underscore.""" + return re.sub("[^A-Za-z0-9]", "_", column_name) + + def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]: """Given an instance of a ``DataType``, it extracts the corresponding ``DataType`` class, potentially including annotations (e.g. ``ArrayType[StringType]``).""" diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py index ab578442..0089386a 100644 --- a/typedspark/_utils/load_table.py +++ b/typedspark/_utils/load_table.py @@ -1,7 +1,6 @@ """Functions for loading `DataSet` and `Schema` in notebooks.""" -import re -from typing import Dict, Optional, Tuple, Type +from typing import Optional, Tuple, Type from pyspark.sql import DataFrame, SparkSession @@ -11,45 +10,6 @@ from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset -def _replace_illegal_column_names(dataframe: DataFrame) -> DataFrame: - """Replaces illegal column names with a legal version.""" - mapping = _create_mapping(dataframe) - - for column, column_renamed in mapping.items(): - if column != column_renamed: - dataframe = dataframe.withColumnRenamed(column, column_renamed) - - return dataframe - - -def _create_mapping(dataframe: DataFrame) -> Dict[str, str]: - """Checks if there are duplicate columns after replacing illegal characters.""" - mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns} - renamed_columns = list(mapping.values()) - duplicates = { - column: column_renamed - for column, column_renamed in mapping.items() - if renamed_columns.count(column_renamed) > 1 - } - - if len(duplicates) > 0: - raise ValueError( - "You're trying to dynamically generate a Schema from a DataFrame. " - + "However, typedspark has detected that the DataFrame contains duplicate columns " - + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n" - + "The folowing columns have lead to duplicates:\n" - + f"{duplicates}\n\n" - + "Please rename these columns in your DataFrame." 
- ) - - return mapping - - -def _replace_illegal_characters(column_name: str) -> str: - """Replaces illegal characters in a column name with an underscore.""" - return re.sub("[^A-Za-z0-9]", "_", column_name) - - def create_schema( dataframe: DataFrame, schema_name: Optional[str] = None ) -> Tuple[DataSet[Schema], Type[Schema]]: @@ -63,7 +23,6 @@ def create_schema( df, Person = create_schema(df) """ - dataframe = _replace_illegal_column_names(dataframe) schema = create_schema_from_structtype(dataframe.schema, schema_name) dataset = DataSet[schema](dataframe) # type: ignore schema = register_schema_to_dataset(dataset, schema) From 7c56566942c4e67dcccb2264514e994317d9b3ab Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:21:37 +0200 Subject: [PATCH 07/10] fix full-name typo --- typedspark/_core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index 671a252d..a3c8dac5 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -79,7 +79,7 @@ def from_dataframe( .. code-block:: python class Person(Schema): - name: Annotation[Column[StringType], ColumnMeta(external_name="full-name")] + name: Annotation[Column[StringType], ColumnMeta(external_name="first-name")] age: Column[LongType] df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"]) From aa5738173b72e9bdd6acb83792ea449781b38215 Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:58:24 +0200 Subject: [PATCH 08/10] add functionality to generate schemas with external_name fields in a notebook --- test.ipynb | 107 ++++++++++++++++++ .../_utils/create_dataset_from_structtype.py | 34 +++--- typedspark/_utils/load_table.py | 5 +- 3 files changed, 128 insertions(+), 18 deletions(-) create mode 100644 test.ipynb diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000..8acbe29f --- /dev/null +++ b/test.ipynb @@ -0,0 +1,107 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import SparkSession\n", + "from typedspark import create_schema\n", + "\n", + "\n", + "spark = SparkSession.Builder().getOrCreate()\n", + "\n", + "df = spark.createDataFrame([(\"Alice\", 24), (\"Bob\", 25)], [\"first-name\", \"age\"])\n", + "\n", + "ds, schema = create_schema(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----------+---+\n", + "|first_name|age|\n", + "+----------+---+\n", + "| Alice| 24|\n", + "| Bob| 25|\n", + "+----------+---+\n", + "\n" + ] + } + ], + "source": [ + "ds.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\n", + "from pyspark.sql.types import LongType, StringType\n", + "\n", + "from typedspark import Column, Schema\n", + "\n", + "\n", + "class DynamicallyLoadedSchema(Schema):\n", + " first_name: Annotated[Column[StringType], ColumnMeta(comment=None, external_name='first-name')]\n", + " age: Column[LongType]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "schema" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "typedspark-3.12.2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/typedspark/_utils/create_dataset_from_structtype.py b/typedspark/_utils/create_dataset_from_structtype.py index 99704108..9a2dc1dc 100644 --- a/typedspark/_utils/create_dataset_from_structtype.py +++ b/typedspark/_utils/create_dataset_from_structtype.py @@ -1,13 +1,14 @@ """Utility functions for creating a ``Schema`` from a ``StructType``""" import re -from typing import Annotated, Dict, Literal, Optional, Type +from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Type from pyspark.sql.types import ArrayType as SparkArrayType from pyspark.sql.types import DataType from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType from pyspark.sql.types import DecimalType as SparkDecimalType from pyspark.sql.types import MapType as SparkMapType +from pyspark.sql.types import StructField from pyspark.sql.types import StructType as SparkStructType from typedspark._core.column import Column @@ -32,18 +33,9 @@ def create_schema_from_structtype( column_name_mapping = _create_column_name_mapping(structtype) for column in structtype: - name = column.name - data_type = _extract_data_type(column.dataType, name) - - mapped_name = column_name_mapping[name] - if mapped_name == name: - type_annotations[name] = Column[data_type] # type: ignore - else: - type_annotations[mapped_name] = Annotated[ - Column[data_type], ColumnMeta(external_name=name) - ] - - attributes[name] = None + column_name, column_class = _create_column(column, column_name_mapping) + type_annotations[column_name] = column_class + attributes[column_name] = None if not schema_name: schema_name = "DynamicallyLoadedSchema" @@ -55,7 +47,7 @@ def create_schema_from_structtype( def _create_column_name_mapping(structtype: SparkStructType) -> Dict[str, str]: - """Checks if there are duplicate columns after replacing illegal characters.""" + """Creates a mapping from the original column names to the renamed column names.""" mapping = {column: _replace_illegal_characters(column) for column in structtype.names} renamed_columns = list(mapping.values()) @@ -83,6 +75,20 @@ def _replace_illegal_characters(column_name: str) -> str: return re.sub("[^A-Za-z0-9]", "_", column_name) +def _create_column(column: StructField, column_name_mapping: Dict[str, str]) -> Tuple[str, Any]: + """Creates a column object, optionally with an `external_name` if the mapped_name is + different from the original name (due to illegal characters, such as `-` in the + original name).""" + name = column.name + mapped_name = column_name_mapping[name] + data_type = _extract_data_type(column.dataType, name) + + if mapped_name == name: + return name, Column[data_type] # type: ignore + + return mapped_name, Annotated[Column[data_type], ColumnMeta(external_name=name)] # type: ignore + + def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]: """Given an instance of a ``DataType``, it extracts the corresponding ``DataType`` class, potentially including annotations (e.g. 
``ArrayType[StringType]``).""" diff --git a/typedspark/_utils/load_table.py b/typedspark/_utils/load_table.py index 0089386a..c437c3a7 100644 --- a/typedspark/_utils/load_table.py +++ b/typedspark/_utils/load_table.py @@ -7,7 +7,6 @@ from typedspark._core.dataset import DataSet from typedspark._schema.schema import Schema from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype -from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset def create_schema( @@ -24,9 +23,7 @@ def create_schema( df, Person = create_schema(df) """ schema = create_schema_from_structtype(dataframe.schema, schema_name) - dataset = DataSet[schema](dataframe) # type: ignore - schema = register_schema_to_dataset(dataset, schema) - return dataset, schema + return DataSet[schema].from_dataframe(dataframe) # type: ignore def load_table( From 9a9ac2923c854ff7cd6e9bd3d216d224bf5dc15c Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Sat, 21 Sep 2024 09:59:13 +0200 Subject: [PATCH 09/10] nested structtypes --- test.ipynb | 107 ----------------------------- tests/_utils/test_load_table.py | 75 +++++++++++++++++++- typedspark/_core/dataset.py | 12 ++-- typedspark/_core/rename_columns.py | 71 +++++++++++++++++++ 4 files changed, 150 insertions(+), 115 deletions(-) delete mode 100644 test.ipynb create mode 100644 typedspark/_core/rename_columns.py diff --git a/test.ipynb b/test.ipynb deleted file mode 100644 index 8acbe29f..00000000 --- a/test.ipynb +++ /dev/null @@ -1,107 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from pyspark.sql import SparkSession\n", - "from typedspark import create_schema\n", - "\n", - "\n", - "spark = SparkSession.Builder().getOrCreate()\n", - "\n", - "df = spark.createDataFrame([(\"Alice\", 24), (\"Bob\", 25)], [\"first-name\", \"age\"])\n", - "\n", - "ds, schema = create_schema(df)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+----------+---+\n", - "|first_name|age|\n", - "+----------+---+\n", - "| Alice| 24|\n", - "| Bob| 25|\n", - "+----------+---+\n", - "\n" - ] - } - ], - "source": [ - "ds.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\n", - "from pyspark.sql.types import LongType, StringType\n", - "\n", - "from typedspark import Column, Schema\n", - "\n", - "\n", - "class DynamicallyLoadedSchema(Schema):\n", - " first_name: Annotated[Column[StringType], ColumnMeta(comment=None, external_name='first-name')]\n", - " age: Column[LongType]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "schema" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "typedspark-3.12.2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git 
a/tests/_utils/test_load_table.py b/tests/_utils/test_load_table.py index 115778c1..ab04aa04 100644 --- a/tests/_utils/test_load_table.py +++ b/tests/_utils/test_load_table.py @@ -2,7 +2,7 @@ import pytest from chispa.dataframe_comparer import assert_df_equality # type: ignore -from pyspark.sql import SparkSession +from pyspark.sql import Row, SparkSession from pyspark.sql.functions import first from pyspark.sql.types import IntegerType, StringType @@ -225,3 +225,76 @@ def test_get_spark_session_without_spark_session(): if SparkSession.getActiveSession() is None: with pytest.raises(ValueError): _get_spark_session(None) + + +def test_create_schema_with_invalid_column_name(spark: SparkSession): + df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"]) + ds, schema = create_schema(df) + + +def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSession): + data = [ + Row( + **{ + "full-name": Row( + **{ + "first-name": "Alice", + "last-name": "Smith", + }, + ), + "age": 24, + } + ), + Row( + **{ + "full-name": Row( + **{ + "first-name": "Bob", + "last-name": "Brown", + }, + ), + "age": 25, + }, + ), + ] + + df = spark.createDataFrame(data) + ds, schema = create_schema(df) + + +def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: SparkSession): + data = [ + Row( + **{ + "details": Row( + **{ + "full-name": Row( + **{ + "first-name": "Alice", + "last-name": "Smith", + } + ), + "age": 24, + } + ) + } + ), + Row( + **{ + "details": Row( + **{ + "full-name": Row( + **{ + "first-name": "Bob", + "last-name": "Brown", + } + ), + "age": 25, + } + ) + } + ), + ] + + df = spark.createDataFrame(data) + ds, schema = create_schema(df) diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index a3c8dac5..a28d3562 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -22,6 +22,7 @@ from pyspark.sql import DataFrame from typing_extensions import Concatenate, ParamSpec +from typedspark._core.rename_columns import rename_columns from typedspark._core.validate_schema import validate_schema from typedspark._schema.schema import Schema from typedspark._transforms.transform_to_schema import transform_to_schema @@ -70,7 +71,7 @@ def typedspark_schema(self) -> Type[_Implementation]: @classmethod def from_dataframe( - cls, df: DataFrame + cls, df: DataFrame, register_to_schema: bool = True ) -> Tuple[DataSet[_Implementation], Type[_Implementation]]: """Converts a DataFrame to a DataSet and registers the Schema to the DataSet. 
Also renames the columns to their internal names, for example to deal with @@ -90,13 +91,10 @@ class Person(Schema): schema = cls._schema_annotations # type: ignore - for column in schema.get_structtype().fields: - if column.metadata: - df = df.withColumnRenamed( - column.metadata.get("external_name", column.name), column.name - ) + df = rename_columns(df, schema.get_structtype()) df = transform_to_schema(df, schema) - schema = register_schema_to_dataset(df, schema) + if register_to_schema: + schema = register_schema_to_dataset(df, schema) return df, schema def to_dataframe(self) -> DataFrame: diff --git a/typedspark/_core/rename_columns.py b/typedspark/_core/rename_columns.py new file mode 100644 index 00000000..d148459c --- /dev/null +++ b/typedspark/_core/rename_columns.py @@ -0,0 +1,71 @@ +"""Helper functions to rename columns from their external name (defined in +`ColumnMeta(external_name=...)`) to their internal name.""" + +from typing import Optional + +from pyspark.sql import Column, DataFrame +from pyspark.sql.functions import col, lit, struct, when +from pyspark.sql.types import StructField, StructType + + +def rename_columns(df: DataFrame, schema: StructType) -> DataFrame: + """Helper functions to rename columns from their external name (defined in + `ColumnMeta(external_name=...)`) to their internal name.""" + for field in schema.fields: + internal_name = field.name + + if field.metadata and "external_name" in field.metadata: + external_name = field.metadata["external_name"] + df = df.withColumnRenamed(external_name, internal_name) + + if isinstance(field.dataType, StructType): + structtype = _create_renamed_structtype(field.dataType, internal_name) + df = df.withColumn(internal_name, structtype) + + return df + + +def _create_renamed_structtype( + schema: StructType, + parent: str, + full_parent_path: Optional[str] = None, +) -> Column: + if not full_parent_path: + full_parent_path = f"`{parent}`" + + mapping = [] + for field in schema.fields: + external_name = _get_external_name(field, full_parent_path) + + if isinstance(field.dataType, StructType): + mapping += [ + _create_renamed_structtype( + field.dataType, + parent=field.name, + full_parent_path=external_name, + ) + ] + else: + mapping += [col(external_name).alias(field.name)] + + return _produce_nested_structtype(mapping, parent, full_parent_path) + + +def _get_external_name(field: StructField, full_parent_path: str) -> str: + external_name = field.metadata.get("external_name", field.name) + return f"{full_parent_path}.`{external_name}`" + + +def _produce_nested_structtype( + mapping: list[Column], + parent: str, + full_parent_path: str, +) -> Column: + return ( + when( + col(full_parent_path).isNotNull(), + struct(*mapping), + ) + .otherwise(lit(None)) + .alias(parent) + ) From 0f01c99b2d387deb35180cea30e497733baf879c Mon Sep 17 00:00:00 2001 From: nanne-aben <47976799+nanne-aben@users.noreply.github.com> Date: Tue, 17 Dec 2024 07:08:04 +0100 Subject: [PATCH 10/10] update --- tests/_utils/test_load_table.py | 9 +++++ typedspark/_core/dataset.py | 12 +++--- typedspark/_core/rename_columns.py | 64 +++++++++++++++++++++++++++--- 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/tests/_utils/test_load_table.py b/tests/_utils/test_load_table.py index ab04aa04..1a53fbd6 100644 --- a/tests/_utils/test_load_table.py +++ b/tests/_utils/test_load_table.py @@ -231,6 +231,9 @@ def test_create_schema_with_invalid_column_name(spark: SparkSession): df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", 
"age"]) ds, schema = create_schema(df) + df2 = ds.to_dataframe() + assert_df_equality(df, df2) + def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSession): data = [ @@ -261,6 +264,9 @@ def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSess df = spark.createDataFrame(data) ds, schema = create_schema(df) + df2 = ds.to_dataframe() + assert_df_equality(df, df2) + def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: SparkSession): data = [ @@ -298,3 +304,6 @@ def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: Sp df = spark.createDataFrame(data) ds, schema = create_schema(df) + + df2 = ds.to_dataframe() + assert_df_equality(df, df2) diff --git a/typedspark/_core/dataset.py b/typedspark/_core/dataset.py index a28d3562..4bf5a95c 100644 --- a/typedspark/_core/dataset.py +++ b/typedspark/_core/dataset.py @@ -22,7 +22,7 @@ from pyspark.sql import DataFrame from typing_extensions import Concatenate, ParamSpec -from typedspark._core.rename_columns import rename_columns +from typedspark._core.rename_columns import rename_columns, rename_columns_2 from typedspark._core.validate_schema import validate_schema from typedspark._schema.schema import Schema from typedspark._transforms.transform_to_schema import transform_to_schema @@ -91,7 +91,7 @@ class Person(Schema): schema = cls._schema_annotations # type: ignore - df = rename_columns(df, schema.get_structtype()) + df = rename_columns(df, schema) df = transform_to_schema(df, schema) if register_to_schema: schema = register_schema_to_dataset(df, schema) @@ -114,14 +114,12 @@ class Person(Schema): df = cast(DataFrame, self) df.__class__ = DataFrame - for column in self._schema_annotations.get_structtype().fields: - if column.metadata: - df = df.withColumnRenamed( - column.name, column.metadata.get("external_name", column.name) - ) + df = rename_columns_2(df, self._schema_annotations) return df + # return rename_columns(df, schema) + """The following functions are equivalent to their parents in ``DataFrame``, but since they don't affect the ``Schema``, we can add type annotations here. 
diff --git a/typedspark/_core/rename_columns.py b/typedspark/_core/rename_columns.py index d148459c..533423bd 100644 --- a/typedspark/_core/rename_columns.py +++ b/typedspark/_core/rename_columns.py @@ -1,17 +1,19 @@ """Helper functions to rename columns from their external name (defined in `ColumnMeta(external_name=...)`) to their internal name.""" -from typing import Optional +from typing import Optional, Type from pyspark.sql import Column, DataFrame from pyspark.sql.functions import col, lit, struct, when from pyspark.sql.types import StructField, StructType +from typedspark._schema.schema import Schema -def rename_columns(df: DataFrame, schema: StructType) -> DataFrame: + +def rename_columns(df: DataFrame, schema: Type[Schema]) -> DataFrame: """Helper functions to rename columns from their external name (defined in - `ColumnMeta(external_name=...)`) to their internal name.""" - for field in schema.fields: + `ColumnMeta(external_name=...)`) to their internal name (as used in the Schema).""" + for field in schema.get_structtype().fields: internal_name = field.name if field.metadata and "external_name" in field.metadata: @@ -25,6 +27,23 @@ def rename_columns(df: DataFrame, schema: StructType) -> DataFrame: return df +def rename_columns_2(df: DataFrame, schema: Type[Schema]) -> DataFrame: + """Helper functions to rename columns from their internal name (as used in the + Schema) to their external name (defined in `ColumnMeta(external_name=...)`).""" + for field in schema.get_structtype().fields: + internal_name = field.name + + if field.metadata and "external_name" in field.metadata: + external_name = field.metadata["external_name"] + df = df.withColumnRenamed(internal_name, external_name) # swap + + if isinstance(field.dataType, StructType): + structtype = _create_renamed_structtype_2(field.dataType, internal_name) + df = df.withColumn(external_name, structtype) # swap + + return df + + def _create_renamed_structtype( schema: StructType, parent: str, @@ -35,7 +54,7 @@ def _create_renamed_structtype( mapping = [] for field in schema.fields: - external_name = _get_external_name(field, full_parent_path) + external_name = _get_updated_parent_path(full_parent_path, field) if isinstance(field.dataType, StructType): mapping += [ @@ -51,11 +70,44 @@ def _create_renamed_structtype( return _produce_nested_structtype(mapping, parent, full_parent_path) -def _get_external_name(field: StructField, full_parent_path: str) -> str: +def _create_renamed_structtype_2( + schema: StructType, + parent: str, + full_parent_path: Optional[str] = None, +) -> Column: + if not full_parent_path: + full_parent_path = f"`{parent}`" + + mapping = [] + for field in schema.fields: + internal_name = field.name + external_name = field.metadata.get("external_name", internal_name) + + updated_parent_path = _get_updated_parent_path_2(full_parent_path, internal_name) # swap + + if isinstance(field.dataType, StructType): + mapping += [ + _create_renamed_structtype( + field.dataType, + parent=external_name, # swap + full_parent_path=updated_parent_path, + ) + ] + else: + mapping += [col(updated_parent_path).alias(external_name)] # swap + + return _produce_nested_structtype(mapping, parent, full_parent_path) + + +def _get_updated_parent_path(full_parent_path: str, field: StructField) -> str: external_name = field.metadata.get("external_name", field.name) return f"{full_parent_path}.`{external_name}`" +def _get_updated_parent_path_2(full_parent_path: str, field: str) -> str: + return f"{full_parent_path}.`{field}`" + + def 
_produce_nested_structtype( mapping: list[Column], parent: str,
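With the series applied end to end, `create_schema` routes through `DataSet[schema].from_dataframe`, so a DataFrame whose column names contain illegal characters, including names nested inside structs, survives a full round trip. A sketch of the resulting flow, following `test_create_schema_with_invalid_column_name_in_a_structtype` above (the in-code comments on column names restate what the tests assert):

.. code-block:: python

    from chispa import assert_df_equality  # type: ignore
    from pyspark.sql import Row, SparkSession

    from typedspark import create_schema

    spark = SparkSession.Builder().getOrCreate()
    df = spark.createDataFrame(
        [
            Row(**{"full-name": Row(**{"first-name": "Alice", "last-name": "Smith"}), "age": 24}),
            Row(**{"full-name": Row(**{"first-name": "Bob", "last-name": "Brown"}), "age": 25}),
        ]
    )

    # Illegal characters become underscores in the generated Schema and DataSet...
    ds, schema = create_schema(df)  # ds.columns == ["full_name", "age"]

    # ...and to_dataframe() restores the external names, nested fields included.
    df2 = ds.to_dataframe()
    assert_df_equality(df, df2)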