Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion typedspark/_transforms/transform_to_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from typedspark._core.dataset import DataSet
from typedspark._schema.schema import Schema
from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings, \
add_nulls_for_unspecified_nested_fields

T = TypeVar("T", bound=Schema)

Expand Down Expand Up @@ -44,6 +45,7 @@ def transform_to_schema(
schema: Type[T],
transformations: Optional[Dict[Column, SparkColumn]] = None,
fill_unspecified_columns_with_nulls: bool = False,
fill_unspecified_inner_fields_with_nulls: bool = False,
run_sequentially: bool = True,
) -> DataSet[T]:
"""On the provided DataFrame ``df``, it performs the ``transformations`` (if
Expand All @@ -69,6 +71,9 @@ def transform_to_schema(
if fill_unspecified_columns_with_nulls:
transform = add_nulls_for_unspecified_columns(transform, schema, dataframe.columns)

if fill_unspecified_inner_fields_with_nulls:
transform = add_nulls_for_unspecified_nested_fields(transform, schema, dataframe.schema)

transform = RenameDuplicateColumns(transform, schema, dataframe.columns)

return DataSet[schema]( # type: ignore
Expand Down
44 changes: 44 additions & 0 deletions typedspark/_transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from pyspark.sql import Column as SparkColumn
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType

from typedspark import structtype_column
from typedspark._core.column import Column
from typedspark._schema.schema import Schema

Expand Down Expand Up @@ -44,3 +46,45 @@ def convert_keys_to_strings(
)

return _transformations

def _transform_struct_column_fill_nulls(
    ts_schema: Schema,
    data_schema: StructType,
    struct_column: Column,
):
    """Rebuild ``struct_column`` as a ``structtype_column`` of ``ts_schema``, recursing
    into nested structs, so that fields declared in the target schema but absent from
    the source data are filled with nulls (via ``fill_unspecified_columns_with_nulls``).

    Args:
        ts_schema: Target typedspark schema describing the struct.
        data_schema: Spark ``StructType`` of the corresponding struct in the source data.
        struct_column: Column (or nested field reference) holding the source struct.
    """
    return structtype_column(
        ts_schema,
        {
            getattr(ts_schema, field_name): (
                _transform_struct_column_fill_nulls(
                    getattr(ts_schema, field_name).dtype,
                    data_schema[field_name].dataType,
                    getattr(struct_column, field_name),
                )
                # Bug fix: ``dataType == StructType`` compared an *instance* to the
                # *class* and was always False, so nested structs were never recursed
                # into. ``isinstance`` is the correct check.
                if isinstance(data_schema[field_name].dataType, StructType)
                else getattr(struct_column, field_name)
            )
            for field_name in data_schema.fieldNames()
        },
        fill_unspecified_columns_with_nulls=True,
    )

def add_nulls_for_unspecified_nested_fields(
    transformations: Dict[str, SparkColumn],
    schema: Type[Schema],
    data_schema: StructType,
) -> Dict[str, SparkColumn]:
    """Takes the struct columns from the target ``schema`` that are not present in the
    ``transformations`` and, for each one found in the source ``data_schema``, rebuilds
    it so that any of its (nested) fields missing from the source data are filled with
    nulls.

    Args:
        transformations: Mapping of column name to its transformation; updated in place.
        schema: Target typedspark schema.
        data_schema: Spark ``StructType`` of the source DataFrame.

    Returns:
        The (mutated) ``transformations`` dictionary.
    """
    for field in schema.get_structtype().fields:
        if (
            field.name not in transformations
            # Bug fix: ``field.dataType == StructType`` compared an instance to the
            # class and was always False; use isinstance().
            and isinstance(field.dataType, StructType)
            # Bug fix: ``fieldNames`` is a method — without the call parentheses,
            # ``in`` against the bound method would raise a TypeError.
            and field.name in data_schema.fieldNames()
        ):
            # Bug fix: pass the struct field's own sub-schema and sub-StructType,
            # mirroring the recursive call inside _transform_struct_column_fill_nulls;
            # previously the top-level schema/data_schema were passed, which would
            # rebuild the wrong struct.
            transformations[field.name] = _transform_struct_column_fill_nulls(
                getattr(schema, field.name).dtype,
                data_schema[field.name].dataType,
                getattr(schema, field.name),
            )

    return transformations
Loading