Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions libraries/dagster-delta/dagster_delta/_handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,25 @@ def handle_output(
logger = logging.getLogger()
logger.setLevel("DEBUG")
definition_metadata = context.definition_metadata or {}
output_metadata = context.output_metadata or {}
# Gets merge_predicate or merge_operations_config in this order: runtime metadata -> definition metadata -> IO Manager config
merge_predicate_from_metadata = output_metadata.get(
"merge_predicate",
)
if merge_predicate_from_metadata is not None:
merge_predicate_from_metadata = merge_predicate_from_metadata.value
if merge_predicate_from_metadata is None:
merge_predicate_from_metadata = definition_metadata.get("merge_predicate")

merge_predicate_from_metadata = definition_metadata.get("merge_predicate")
merge_operations_config_from_metadata = definition_metadata.get("merge_operations_config")

merge_operations_config_from_metadata = output_metadata.get(
"merge_operations_config",
)
if merge_operations_config_from_metadata is not None:
merge_operations_config_from_metadata = merge_operations_config_from_metadata.value
if merge_operations_config_from_metadata is None:
merge_operations_config_from_metadata = definition_metadata.get(
"merge_operations_config",
)
additional_table_config = definition_metadata.get("table_configuration", {})
if connection.table_config is not None:
table_config = additional_table_config | connection.table_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pyarrow as pa
import pytest
from dagster import (
OpExecutionContext,
Out,
graph,
op,
Expand All @@ -20,6 +21,41 @@
from dagster_delta.config import WhenMatchedUpdateAll, WhenNotMatchedInsertAll


@op(out=Out(metadata={"schema": "a_df2"}))
def a_df2(context: OpExecutionContext) -> pa.Table:
    """Emit a small fixture table, attaching merge configuration as runtime
    output metadata so the IO manager picks it up at handle_output time."""
    merge_ops = MergeOperationsConfig(
        when_not_matched_insert_all=[WhenNotMatchedInsertAll()],
        when_matched_update_all=[WhenMatchedUpdateAll()],
    )
    context.add_output_metadata(
        {
            "merge_predicate": "s.a = t.a",
            "merge_operations_config": merge_ops.model_dump(),
        },
    )
    return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})


@op(out=Out(metadata={"schema": "add_one2"}))
def add_one2(context: OpExecutionContext, df: pa.Table):  # noqa: ANN201
    """Replace column "a" with shifted values, attaching merge configuration
    as runtime output metadata for the IO manager to consume."""
    operations = MergeOperationsConfig(
        when_not_matched_insert_all=[WhenNotMatchedInsertAll()],
        when_matched_update_all=[WhenMatchedUpdateAll()],
    )
    context.add_output_metadata(
        {
            "merge_predicate": "s.a = t.a",
            "merge_operations_config": operations.model_dump(),
        },
    )
    return df.set_column(0, "a", pa.array([2, 3, 4]))


@graph
def add_one_to_dataframe_2():
    """Wire a_df2 into add_one2 for the runtime-metadata merge test."""
    add_one2(a_df2())


@op(out=Out(metadata={"schema": "a_df"}))
def a_df() -> pa.Table:
    """Produce the base fixture table with integer columns "a" and "b"."""
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}
    return pa.Table.from_pydict(columns)
Expand Down Expand Up @@ -102,3 +138,35 @@ def test_deltalake_io_manager_custom_merge(tmp_path):
dt = DeltaTable(os.path.join(tmp_path, "add_one/result"))
out_df = dt.to_pyarrow_table()
assert sorted(out_df["a"].to_pylist()) == [2, 3, 4]


def test_deltalake_io_manager_runtime_metadata_merge_configuration(tmp_path):
    """Verify that 'merge_predicate' and 'merge_operations_config' provided via
    runtime output metadata are honored by the merge-mode IO manager."""
    io_manager = DeltaLakePyarrowIOManager(
        root_uri=str(tmp_path),
        storage_options=LocalConfig(),
        mode=WriteMode.merge,
        merge_config=MergeConfig(
            merge_type=MergeType.custom,
            source_alias="s",
            target_alias="t",
        ),
    )
    job = add_one_to_dataframe_2.to_job(resource_defs={"io_manager": io_manager})

    # Execute twice: the second run merges into the tables the first run
    # created, proving the runtime merge configuration is actually applied.
    for _ in range(2):
        result = job.execute_in_process()
        assert result.success

        source = DeltaTable(os.path.join(tmp_path, "a_df2/result")).to_pyarrow_table()
        assert sorted(source["a"].to_pylist()) == [1, 2, 3]

        merged = DeltaTable(os.path.join(tmp_path, "add_one2/result")).to_pyarrow_table()
        assert sorted(merged["a"].to_pylist()) == [2, 3, 4]
2 changes: 1 addition & 1 deletion libraries/dagster-delta/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dagster-delta"
version = "0.4.0"
version = "0.4.1"
description = "Deltalake IO Managers for Dagster with pyarrow and Polars support."
readme = "README.md"
requires-python = ">=3.9"
Expand Down
2 changes: 1 addition & 1 deletion libraries/dagster-delta/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.