1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ __pycache__/

# C extensions
*.so
.ds_store

# Distribution / packaging
.Python
2 changes: 0 additions & 2 deletions libraries/dagster-delta/dagster_delta/__init__.py
@@ -21,7 +21,6 @@
BaseDeltaLakeIOManager,
SchemaMode,
WriteMode,
WriterEngine,
)
from dagster_delta.resources import DeltaTableResource

@@ -35,7 +34,6 @@
"MergeConfig",
"MergeType",
"WriteMode",
"WriterEngine",
"SchemaMode",
"DeltaTableResource",
"BaseDeltaLakeIOManager",
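Since WriterEngine is dropped from the public exports, downstream imports need a small update. A minimal sketch of the import surface that remains after this change, using only the names visible in the updated __all__ above (other exports may exist that this diff does not show):

# Sketch: imports that stay valid once WriterEngine is removed.
# Names are taken from the __all__ shown above; nothing else is assumed.
from dagster_delta import (
    BaseDeltaLakeIOManager,
    DeltaTableResource,
    MergeConfig,
    MergeType,
    SchemaMode,
    WriteMode,
)
# WriterEngine is no longer importable; the writer engine is no longer user-selectable.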
169 changes: 68 additions & 101 deletions libraries/dagster-delta/dagster_delta/_handler/base.py
@@ -1,10 +1,9 @@
import logging
from abc import abstractmethod
from typing import Any, Generic, Optional, TypeVar, Union, cast
from typing import Any, Generic, TypeVar, Union, cast

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
from arro3.core import RecordBatchReader, Table
from arro3.core.types import ArrowArrayExportable, ArrowStreamExportable
from dagster import (
InputContext,
MetadataValue,
@@ -13,17 +12,16 @@
TableSchema,
)
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
from deltalake import CommitProperties, DeltaTable, WriterProperties, write_deltalake
from deltalake import CommitProperties, DeltaTable, QueryBuilder, WriterProperties, write_deltalake
from deltalake.exceptions import TableNotFoundError
from deltalake.schema import Schema, _convert_pa_schema_to_delta
from deltalake.table import FilterLiteralType
from deltalake.schema import Schema
from deltalake.writer._conversion import _convert_arro3_schema_to_delta

from dagster_delta._handler.merge import merge_execute
from dagster_delta._handler.utils import (
create_predicate,
extract_date_format_from_partition_definition,
partition_dimensions_to_dnf,
read_table,
)
from dagster_delta.config import MergeConfig
from dagster_delta.io_manager.base import (
@@ -32,19 +30,30 @@
)

T = TypeVar("T")
ArrowTypes = Union[pa.Table, pa.RecordBatchReader, ds.Dataset]

ArrowTypes = Union[RecordBatchReader, Table]
try:
import pyarrow as pa

ArrowTypes = Union[RecordBatchReader, Table, pa.Table, pa.RecordBatchReader]
except ImportError:
pass


class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]):
"""Base TypeHandler implementation for arrow supported libraries used to handle deltalake IO."""

@abstractmethod
def from_arrow(self, obj: pa.RecordBatchReader, target_type: type) -> T:
def from_arrow(
self,
obj: Union[ArrowStreamExportable, ArrowArrayExportable],
target_type: type,
) -> T:
"""Abstract method to convert arrow to target type"""
pass

@abstractmethod
def to_arrow(self, obj: T) -> tuple[ArrowTypes, dict[str, Any]]:
def to_arrow(self, obj: T) -> RecordBatchReader: # type: ignore
"""Abstract method to convert type to arrow"""
pass

@@ -122,18 +131,18 @@ def handle_output(
table_config = additional_table_config
resource_config = context.resource_config or {}
object_stats = self.get_output_stats(obj)
data, delta_params = self.to_arrow(obj=obj)
delta_schema = Schema.from_pyarrow(_convert_pa_schema_to_delta(data.schema))

data = self.to_arrow(obj=obj)
delta_schema = Schema.from_arrow(_convert_arro3_schema_to_delta(data.schema))
resource_config = cast(_DeltaTableIOManagerResourceConfig, context.resource_config)
engine = resource_config.get("writer_engine")
save_mode = definition_metadata.get("mode")
main_save_mode = resource_config.get("mode")
custom_metadata = definition_metadata.get("custom_metadata") or resource_config.get(
"custom_metadata",
)
schema_mode = definition_metadata.get("schema_mode") or resource_config.get(
"schema_mode",
)
if schema_mode is not None:
schema_mode = str(schema_mode)

writer_properties = resource_config.get("writer_properties")
writer_properties = (
WriterProperties(**writer_properties) if writer_properties is not None else None # type: ignore
@@ -159,56 +168,44 @@ def handle_output(
logger.debug("Writing with mode: `%s`", main_save_mode)

merge_stats = None
partition_filters = None
partition_columns = None
predicate = None

if table_slice.partition_dimensions is not None:
partition_filters = partition_dimensions_to_dnf(
partition_dimensions=table_slice.partition_dimensions,
table_schema=delta_schema,
str_values=True,
date_format=date_format,
)
if partition_filters is not None and engine == "rust":
if partition_filters is not None:
## Convert partition_filter to predicate
predicate = create_predicate(partition_filters)
partition_filters = None
else:
predicate = None
# TODO(): make robust and move to function

partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions]

if main_save_mode not in ["merge", "create_or_replace"]:
if predicate is not None and engine == "rust":
if predicate is not None:
logger.debug("Using explicit partition predicate: \n%s", predicate)
elif partition_filters is not None and engine == "pyarrow":
logger.debug("Using explicit partition_filter: \n%s", partition_filters)
write_deltalake( # type: ignore
table_or_uri=connection.table_uri,
data=data,
storage_options=connection.storage_options,
mode=main_save_mode,
partition_filters=partition_filters,
mode=main_save_mode, # type: ignore
predicate=predicate,
partition_by=partition_columns,
engine=engine,
schema_mode=schema_mode,
schema_mode=schema_mode, # type: ignore
configuration=table_config,
custom_metadata=custom_metadata,
writer_properties=writer_properties,
writer_properties=writer_properties, # type: ignore
commit_properties=commit_properties,
**delta_params,
)
elif main_save_mode == "create_or_replace":
DeltaTable.create(
table_uri=connection.table_uri,
schema=_convert_pa_schema_to_delta(data.schema),
schema=delta_schema,
mode="overwrite",
partition_by=partition_columns,
configuration=table_config,
storage_options=connection.storage_options,
custom_metadata=custom_metadata,
)
else:
if merge_config is None:
@@ -221,31 +218,24 @@
logger.debug("Creating a DeltaTable first before merging.")
dt = DeltaTable.create(
table_uri=connection.table_uri,
schema=_convert_pa_schema_to_delta(data.schema),
schema=delta_schema,
partition_by=partition_columns,
configuration=table_config,
storage_options=connection.storage_options,
custom_metadata=custom_metadata,
commit_properties=commit_properties,
)
merge_stats = merge_execute(
dt,
data,
MergeConfig.model_validate(merge_config),
writer_properties=writer_properties,
commit_properties=commit_properties,
custom_metadata=custom_metadata,
delta_params=delta_params,
merge_predicate_from_metadata=merge_predicate_from_metadata,
merge_operations_config=merge_operations_config_from_metadata,
partition_filters=partition_filters,
)

dt = DeltaTable(connection.table_uri, storage_options=connection.storage_options)
try:
stats = _get_partition_stats(dt=dt, partition_filters=partition_filters)
except Exception as e:
context.log.warning(f"error while computing table stats: {e}")
stats = {}

output_metadata = {
# "dagster/table_name": table_slice.table,
@@ -255,12 +245,15 @@
TableSchema(
columns=[
TableColumn(name=name, type=str(dtype))
for name, dtype in zip(data.schema.names, data.schema.types)
for name, dtype in zip(
delta_schema.to_arrow().names,
delta_schema.to_arrow().types,
)
],
),
),
"table_version": MetadataValue.int(dt.version()),
**stats,
# **stats,
**object_stats,
}
if merge_stats is not None:
@@ -277,59 +270,33 @@ def load_input(
table_slice: TableSlice,
connection: TableConnection,
) -> T:
"""Loads the input as a pyarrow Table or RecordBatchReader."""
parquet_read_options = None
if context.resource_config is not None:
parquet_read_options = context.resource_config.get("parquet_read_options", None)
parquet_read_options = (
ds.ParquetReadOptions(**parquet_read_options)
if parquet_read_options is not None
else None
"""Loads the input as a arro3 Table or RecordBatchReader."""
table = DeltaTable(
table_uri=connection.table_uri,
storage_options=connection.storage_options,
)
logger = logging.getLogger()
logger.setLevel("DEBUG")
logger.debug("Connection timeout duration %s", connection.storage_options.get("timeout"))

predicate = None
if table_slice.partition_dimensions is not None:
partition_filters = partition_dimensions_to_dnf(
partition_dimensions=table_slice.partition_dimensions,
table_schema=table.schema(),
input_dnf=True,
)

dataset = read_table(table_slice, connection, parquet_read_options=parquet_read_options)

if context.dagster_type.typing_type == ds.Dataset:
if table_slice.columns is not None:
raise ValueError("Cannot select columns when loading as Dataset.")
return dataset

scanner = dataset.scanner(columns=table_slice.columns)
return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)


def _get_partition_stats(
dt: DeltaTable,
partition_filters: Optional[list[FilterLiteralType]] = None,
) -> dict[str, Any]:
"""Gets the stats for a partition

Args:
dt (DeltaTable): DeltaTable object
partition_filters (list[FilterLiteralType] | None, optional): filters to grabs stats with. Defaults to None.

Returns:
dict[str, MetadataValue]: Partition stats
"""
files = pa.array(dt.files(partition_filters=partition_filters))
files_table = pa.Table.from_arrays([files], names=["path"])
actions_table = pa.Table.from_batches([dt.get_add_actions(flatten=True)])
actions_table = actions_table.select(["path", "size_bytes", "num_records"])
table = files_table.join(actions_table, keys="path")

stats: dict[str, Any]

stats = {
"size_MB": MetadataValue.float(
pc.sum(table.column("size_bytes")).as_py() * 0.00000095367432, # type: ignore
),
}
row_count = MetadataValue.int(
pc.sum(table.column("num_records")).as_py(), # type: ignore
)
if partition_filters is not None:
stats["dagster/partition_row_count"] = row_count
else:
stats["dagster/row_count"] = row_count

return stats
if partition_filters is not None:
## Convert partition_filter to predicate
predicate = create_predicate(partition_filters)

logger.debug("Dataset input predicate %s", predicate)

col_select = table_slice.columns if table_slice.columns is not None else "*"
query = f"SELECT {col_select} FROM tbl"
if predicate is not None:
query = f"{query} WHERE {predicate}"
data = QueryBuilder().register("tbl", table).execute(query)

return self.from_arrow(data, context.dagster_type.typing_type)
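For context on the new read path: load_input now queries the Delta table through deltalake's QueryBuilder instead of a pyarrow Dataset scanner. A minimal standalone sketch of that pattern, mirroring the query construction above; the table path, column names, and predicate are hypothetical, and a deltalake version that ships QueryBuilder is assumed:

from deltalake import DeltaTable, QueryBuilder

# Hypothetical local table; in the IO manager the URI comes from connection.table_uri.
table = DeltaTable("./my_table")

columns = ["id", "value"]            # stand-in for table_slice.columns
predicate = "date = '2024-01-01'"    # stand-in for the predicate built by create_predicate()

query = f"SELECT {', '.join(columns)} FROM tbl"
if predicate is not None:
    query = f"{query} WHERE {predicate}"

# Register the table under the alias used in the SQL text, then execute.
# The result is Arrow-stream exportable, which from_arrow() converts to the target type.
data = QueryBuilder().register("tbl", table).execute(query)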
10 changes: 2 additions & 8 deletions libraries/dagster-delta/dagster_delta/_handler/merge.py
@@ -1,26 +1,22 @@
import logging
from typing import Any, Optional, TypeVar, Union

import pyarrow as pa
import pyarrow.dataset as ds
from arro3.core.types import ArrowArrayExportable, ArrowStreamExportable
from deltalake import CommitProperties, DeltaTable, WriterProperties
from deltalake.table import FilterLiteralType, TableMerger

from dagster_delta._handler.utils import create_predicate
from dagster_delta.config import MergeConfig, MergeOperationsConfig, MergeType

T = TypeVar("T")
ArrowTypes = Union[pa.Table, pa.RecordBatchReader, ds.Dataset]


def merge_execute(
dt: DeltaTable,
data: Union[pa.RecordBatchReader, pa.Table],
data: Union[ArrowStreamExportable, ArrowArrayExportable],
merge_config: MergeConfig,
writer_properties: Optional[WriterProperties],
commit_properties: Optional[CommitProperties],
custom_metadata: Optional[dict[str, str]],
delta_params: dict[str, Any],
merge_predicate_from_metadata: Optional[str],
merge_operations_config: Optional[MergeOperationsConfig],
partition_filters: Optional[list[FilterLiteralType]] = None,
@@ -53,8 +49,6 @@ def merge_execute(
error_on_type_mismatch=error_on_type_mismatch,
writer_properties=writer_properties,
commit_properties=commit_properties,
custom_metadata=custom_metadata,
**delta_params,
)

if merge_type == MergeType.update_only:
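The data parameter now accepts anything Arrow-exportable (arro3 or pyarrow) rather than pyarrow types only, and custom_metadata/delta_params are no longer forwarded to the merge call. A rough sketch of the kind of deltalake merge this helper wraps, assuming deltalake's standard TableMerger chain; the table path, predicate, aliases, and upsert behaviour shown here are illustrative, not the exact logic of merge_execute:

import pyarrow as pa
from deltalake import DeltaTable

dt = DeltaTable("./my_table")  # hypothetical existing Delta table
source = pa.Table.from_pydict({"id": [1, 2], "value": ["a", "b"]})  # any Arrow-exportable data works

(
    dt.merge(
        source=source,
        predicate="target.id = source.id",  # hypothetical merge predicate
        source_alias="source",
        target_alias="target",
        error_on_type_mismatch=True,
    )
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute()
)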
2 changes: 0 additions & 2 deletions libraries/dagster-delta/dagster_delta/_handler/utils/__init__.py
@@ -1,11 +1,9 @@
from dagster_delta._handler.utils.date_format import extract_date_format_from_partition_definition
from dagster_delta._handler.utils.dnf import partition_dimensions_to_dnf
from dagster_delta._handler.utils.predicates import create_predicate
from dagster_delta._handler.utils.reader import read_table

__all__ = [
"create_predicate",
"read_table",
"extract_date_format_from_partition_definition",
"partition_dimensions_to_dnf",
]