Skip to content

Commit 0f01c99

Browse files
committed
update
1 parent 9a9ac29 commit 0f01c99

File tree

3 files changed

+72
-13
lines changed

3 files changed

+72
-13
lines changed

tests/_utils/test_load_table.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ def test_create_schema_with_invalid_column_name(spark: SparkSession):
231231
df = spark.createDataFrame([("Alice", 24), ("Bob", 25)], ["first-name", "age"])
232232
ds, schema = create_schema(df)
233233

234+
df2 = ds.to_dataframe()
235+
assert_df_equality(df, df2)
236+
234237

235238
def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSession):
236239
data = [
@@ -261,6 +264,9 @@ def test_create_schema_with_invalid_column_name_in_a_structtype(spark: SparkSess
261264
df = spark.createDataFrame(data)
262265
ds, schema = create_schema(df)
263266

267+
df2 = ds.to_dataframe()
268+
assert_df_equality(df, df2)
269+
264270

265271
def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: SparkSession):
266272
data = [
@@ -298,3 +304,6 @@ def test_create_schema_with_invalid_column_name_in_a_nested_structtype(spark: Sp
298304

299305
df = spark.createDataFrame(data)
300306
ds, schema = create_schema(df)
307+
308+
df2 = ds.to_dataframe()
309+
assert_df_equality(df, df2)

typedspark/_core/dataset.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pyspark.sql import DataFrame
2323
from typing_extensions import Concatenate, ParamSpec
2424

25-
from typedspark._core.rename_columns import rename_columns
25+
from typedspark._core.rename_columns import rename_columns, rename_columns_2
2626
from typedspark._core.validate_schema import validate_schema
2727
from typedspark._schema.schema import Schema
2828
from typedspark._transforms.transform_to_schema import transform_to_schema
@@ -91,7 +91,7 @@ class Person(Schema):
9191

9292
schema = cls._schema_annotations # type: ignore
9393

94-
df = rename_columns(df, schema.get_structtype())
94+
df = rename_columns(df, schema)
9595
df = transform_to_schema(df, schema)
9696
if register_to_schema:
9797
schema = register_schema_to_dataset(df, schema)
@@ -114,14 +114,12 @@ class Person(Schema):
114114
df = cast(DataFrame, self)
115115
df.__class__ = DataFrame
116116

117-
for column in self._schema_annotations.get_structtype().fields:
118-
if column.metadata:
119-
df = df.withColumnRenamed(
120-
column.name, column.metadata.get("external_name", column.name)
121-
)
117+
df = rename_columns_2(df, self._schema_annotations)
122118

123119
return df
124120

121+
# return rename_columns(df, schema)
122+
125123
"""The following functions are equivalent to their parents in ``DataFrame``, but
126124
since they don't affect the ``Schema``, we can add type annotations here.
127125

typedspark/_core/rename_columns.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""Helper functions to rename columns from their external name (defined in
22
`ColumnMeta(external_name=...)`) to their internal name."""
33

4-
from typing import Optional
4+
from typing import Optional, Type
55

66
from pyspark.sql import Column, DataFrame
77
from pyspark.sql.functions import col, lit, struct, when
88
from pyspark.sql.types import StructField, StructType
99

10+
from typedspark._schema.schema import Schema
1011

11-
def rename_columns(df: DataFrame, schema: StructType) -> DataFrame:
12+
13+
def rename_columns(df: DataFrame, schema: Type[Schema]) -> DataFrame:
1214
"""Helper functions to rename columns from their external name (defined in
13-
`ColumnMeta(external_name=...)`) to their internal name."""
14-
for field in schema.fields:
15+
`ColumnMeta(external_name=...)`) to their internal name (as used in the Schema)."""
16+
for field in schema.get_structtype().fields:
1517
internal_name = field.name
1618

1719
if field.metadata and "external_name" in field.metadata:
@@ -25,6 +27,23 @@ def rename_columns(df: DataFrame, schema: StructType) -> DataFrame:
2527
return df
2628

2729

30+
def rename_columns_2(df: DataFrame, schema: Type[Schema]) -> DataFrame:
31+
"""Helper functions to rename columns from their internal name (as used in the
32+
Schema) to their external name (defined in `ColumnMeta(external_name=...)`)."""
33+
for field in schema.get_structtype().fields:
34+
internal_name = field.name
35+
36+
if field.metadata and "external_name" in field.metadata:
37+
external_name = field.metadata["external_name"]
38+
df = df.withColumnRenamed(internal_name, external_name) # swap
39+
40+
if isinstance(field.dataType, StructType):
41+
structtype = _create_renamed_structtype_2(field.dataType, internal_name)
42+
df = df.withColumn(external_name, structtype) # swap
43+
44+
return df
45+
46+
2847
def _create_renamed_structtype(
2948
schema: StructType,
3049
parent: str,
@@ -35,7 +54,7 @@ def _create_renamed_structtype(
3554

3655
mapping = []
3756
for field in schema.fields:
38-
external_name = _get_external_name(field, full_parent_path)
57+
external_name = _get_updated_parent_path(full_parent_path, field)
3958

4059
if isinstance(field.dataType, StructType):
4160
mapping += [
@@ -51,11 +70,44 @@ def _create_renamed_structtype(
5170
return _produce_nested_structtype(mapping, parent, full_parent_path)
5271

5372

54-
def _get_external_name(field: StructField, full_parent_path: str) -> str:
73+
def _create_renamed_structtype_2(
74+
schema: StructType,
75+
parent: str,
76+
full_parent_path: Optional[str] = None,
77+
) -> Column:
78+
if not full_parent_path:
79+
full_parent_path = f"`{parent}`"
80+
81+
mapping = []
82+
for field in schema.fields:
83+
internal_name = field.name
84+
external_name = field.metadata.get("external_name", internal_name)
85+
86+
updated_parent_path = _get_updated_parent_path_2(full_parent_path, internal_name) # swap
87+
88+
if isinstance(field.dataType, StructType):
89+
mapping += [
90+
_create_renamed_structtype(
91+
field.dataType,
92+
parent=external_name, # swap
93+
full_parent_path=updated_parent_path,
94+
)
95+
]
96+
else:
97+
mapping += [col(updated_parent_path).alias(external_name)] # swap
98+
99+
return _produce_nested_structtype(mapping, parent, full_parent_path)
100+
101+
102+
def _get_updated_parent_path(full_parent_path: str, field: StructField) -> str:
55103
external_name = field.metadata.get("external_name", field.name)
56104
return f"{full_parent_path}.`{external_name}`"
57105

58106

107+
def _get_updated_parent_path_2(full_parent_path: str, field: str) -> str:
108+
return f"{full_parent_path}.`{field}`"
109+
110+
59111
def _produce_nested_structtype(
60112
mapping: list[Column],
61113
parent: str,

0 commit comments

Comments
 (0)