diff --git a/typedspark/_transforms/transform_to_schema.py b/typedspark/_transforms/transform_to_schema.py
index e90d4786..e951f4b6 100644
--- a/typedspark/_transforms/transform_to_schema.py
+++ b/typedspark/_transforms/transform_to_schema.py
@@ -10,7 +10,11 @@
 from typedspark._core.dataset import DataSet
 from typedspark._schema.schema import Schema
 from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
-from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
+from typedspark._transforms.utils import (
+    add_nulls_for_unspecified_columns,
+    add_nulls_for_unspecified_nested_fields,
+    convert_keys_to_strings,
+)
 
 T = TypeVar("T", bound=Schema)
 
@@ -44,6 +48,7 @@ def transform_to_schema(
     schema: Type[T],
     transformations: Optional[Dict[Column, SparkColumn]] = None,
     fill_unspecified_columns_with_nulls: bool = False,
+    fill_unspecified_inner_fields_with_nulls: bool = False,
     run_sequentially: bool = True,
 ) -> DataSet[T]:
     """On the provided DataFrame ``df``, it performs the ``transformations`` (if
@@ -69,6 +74,9 @@ def transform_to_schema(
     if fill_unspecified_columns_with_nulls:
         transform = add_nulls_for_unspecified_columns(transform, schema, dataframe.columns)
 
+    if fill_unspecified_inner_fields_with_nulls:
+        transform = add_nulls_for_unspecified_nested_fields(transform, schema, dataframe.schema)
+
     transform = RenameDuplicateColumns(transform, schema, dataframe.columns)
 
     return DataSet[schema](  # type: ignore
diff --git a/typedspark/_transforms/utils.py b/typedspark/_transforms/utils.py
index 2eb65332..082d12b9 100644
--- a/typedspark/_transforms/utils.py
+++ b/typedspark/_transforms/utils.py
@@ -4,7 +4,8 @@
 from pyspark.sql import Column as SparkColumn
 from pyspark.sql.functions import lit
+from pyspark.sql.types import StructType
 
 from typedspark._core.column import Column
 from typedspark._schema.schema import Schema
 
 
@@ -44,3 +45,56 @@ def convert_keys_to_strings(
     )
 
     return _transformations
+
+
+def _transform_struct_column_fill_nulls(
+    ts_schema: Type[Schema],
+    data_schema: StructType,
+    struct_column: Column,
+) -> SparkColumn:
+    """Recursively builds a struct column containing every field of the target
+    (sub)schema ``ts_schema``: fields present in the source data are taken from
+    ``struct_column``, the remaining ones are filled with nulls."""
+    # Imported inside the function to avoid a circular import:
+    # structtype_column itself imports from this module.
+    from typedspark import structtype_column
+
+    return structtype_column(
+        ts_schema,
+        {
+            getattr(ts_schema, field_name): (
+                _transform_struct_column_fill_nulls(
+                    getattr(ts_schema, field_name).dtype.schema,
+                    data_schema[field_name].dataType,
+                    getattr(struct_column, field_name),
+                )
+                if isinstance(data_schema[field_name].dataType, StructType)
+                else getattr(struct_column, field_name)
+            )
+            for field_name in data_schema.fieldNames()
+        },
+        fill_unspecified_columns_with_nulls=True,
+    )
+
+
+def add_nulls_for_unspecified_nested_fields(
+    transformations: Dict[str, SparkColumn],
+    schema: Type[Schema],
+    data_schema: StructType,
+) -> Dict[str, SparkColumn]:
+    """Takes every struct column of the target ``schema`` that is not present in the
+    ``transformations`` but does exist in the source data, and recursively fills any
+    of its fields that are missing from the source data with nulls."""
+    for field in schema.get_structtype().fields:
+        if (
+            field.name not in transformations
+            and isinstance(field.dataType, StructType)
+            and field.name in data_schema.fieldNames()
+        ):
+            transformations[field.name] = _transform_struct_column_fill_nulls(
+                getattr(schema, field.name).dtype.schema,
+                data_schema[field.name].dataType,
+                getattr(schema, field.name),
+            )
+
+    return transformations
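
For reviewers, a minimal usage sketch of the new flag. The Inner/Output schemas and the input data below are made up for illustration and are not part of this diff:

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
from typedspark import Column, Schema, StructType, transform_to_schema


class Inner(Schema):
    a: Column[IntegerType]
    b: Column[StringType]  # not present in the source data


class Output(Schema):
    id: Column[IntegerType]
    nested: Column[StructType[Inner]]


spark = SparkSession.builder.getOrCreate()

# Source data where the struct lacks the field `b`.
df = spark.createDataFrame([(1, (10,))], "id int, nested struct<a:int>")

result = transform_to_schema(
    df,
    Output,
    fill_unspecified_inner_fields_with_nulls=True,
)
result.show()

Without the flag, the resulting DataSet should fail validation against Output because the struct in the source lacks nested.b; with the flag, nested.b is added as a null column cast to its declared type, mirroring what fill_unspecified_columns_with_nulls already does for top-level columns.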