Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions libraries/dagster-delta/dagster_delta/_handler/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import abstractmethod
from typing import Any, Generic, Optional, TypeVar, Union, cast
from typing import Any, Generic, Optional, TypeVar, Union, cast, Dict

import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -53,6 +53,50 @@ def get_output_stats(self, obj: T) -> dict[str, MetadataValue]:
"""Abstract method to return output stats"""
pass

def _find_keys_in_metadata(
    self,
    context: OutputContext,
    keys: Optional[list[str]] = None,
) -> dict[str, Any]:
    """Resolve metadata keys, preferring runtime metadata over definition metadata.

    Each requested key is looked up in this order:
        1. Runtime (output) metadata, ``context.output_metadata``
        2. Definition metadata, ``context.definition_metadata``

    A key present in neither source maps to ``None``. This method does not
    consult the IO Manager config; any fallback to it has to be applied by
    the caller.

    E.g., `merge_predicate` and `merge_operations_config`

    Args:
        context (OutputContext): The output context
        keys (list[str], optional): The keys to find in the metadata.
            Defaults to ["merge_predicate", "merge_operations_config"].

    Returns:
        dict[str, Any]: One entry per requested key, holding the resolved
            value or ``None`` when the key was not found.
    """
    # Build the default inside the call so no list object is shared between
    # invocations (a mutable default argument would be).
    if keys is None:
        keys = ["merge_predicate", "merge_operations_config"]

    metadata_definition = context.definition_metadata or {}
    metadata_output = context.output_metadata or {}

    result: dict[str, Any] = {}

    for key in keys:
        # Runtime metadata wins over definition metadata.
        if key in metadata_output:
            value = metadata_output[key]
        else:
            value = metadata_definition.get(key)

        # Unwrap MetadataValue wrappers into a plain string (or None).
        # NOTE(review): this check matches every MetadataValue subclass, not
        # only text values, so a non-text payload is converted with str() as
        # well -- confirm that is intended for `merge_operations_config`.
        if isinstance(value, MetadataValue):
            raw = value.value
            value = str(raw) if raw is not None else None

        result[key] = value

    return result

def handle_output(
self,
context: OutputContext,
Expand All @@ -63,26 +107,17 @@ def handle_output(
"""Stores pyarrow types in Delta table."""
logger = logging.getLogger()
logger.setLevel("DEBUG")
definition_metadata = context.definition_metadata or {}
output_metadata = context.output_metadata or {}
# Gets merge_predicate or merge_operations_config in this order: runtime metadata -> definition metadata -> IO Manager config
merge_predicate_from_metadata = output_metadata.get(
"merge_predicate",

keys_from_metadata = self._find_keys_in_metadata(
context, ["merge_predicate", "merge_operations_config"]
)
if merge_predicate_from_metadata is not None:
merge_predicate_from_metadata = merge_predicate_from_metadata.value
if merge_predicate_from_metadata is None:
merge_predicate_from_metadata = definition_metadata.get("merge_predicate")

merge_operations_config_from_metadata = output_metadata.get(
"merge_operations_config",
merge_predicate_from_metadata = keys_from_metadata.get("merge_predicate", None)
merge_operations_config_from_metadata = keys_from_metadata.get(
"merge_operations_config", None
)
if merge_operations_config_from_metadata is not None:
merge_operations_config_from_metadata = merge_operations_config_from_metadata.value
if merge_operations_config_from_metadata is None:
merge_operations_config_from_metadata = definition_metadata.get(
"merge_operations_config",
)

definition_metadata = context.definition_metadata or {}
additional_table_config = definition_metadata.get("table_configuration", {})
if connection.table_config is not None:
table_config = additional_table_config | connection.table_config
Expand Down
Loading