Skip to content

Commit 58d07b9

Browse files
committed
fix(datafusion): add partition pruning, pushdown full predicates for DF integration
1 parent 64f6369 commit 58d07b9

File tree

2 files changed

+76
-20
lines changed

2 files changed

+76
-20
lines changed

crates/core/src/delta_datafusion/mod.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ use datafusion_expr::execution_props::ExecutionProps;
5858
use datafusion_expr::logical_plan::CreateExternalTable;
5959
use datafusion_expr::simplify::SimplifyContext;
6060
use datafusion_expr::utils::conjunction;
61-
use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility};
61+
use datafusion_expr::{
62+
col, BinaryExpr, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility,
63+
};
6264
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
6365
use datafusion_physical_plan::filter::FilterExec;
6466
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
@@ -534,6 +536,10 @@ impl<'a> DeltaScanBuilder<'a> {
534536
Some(schema.clone()),
535537
)?;
536538

539+
// TODO temporarily using full schema to generate pruning predicates
540+
// should we optimize this by only including fields referenced from predicates?
541+
let filter_df_schema = logical_schema.clone().to_dfschema()?;
542+
537543
let logical_schema = if let Some(used_columns) = self.projection {
538544
let mut fields = vec![];
539545
for idx in used_columns {
@@ -545,18 +551,17 @@ impl<'a> DeltaScanBuilder<'a> {
545551
};
546552

547553
let context = SessionContext::new();
548-
let df_schema = logical_schema.clone().to_dfschema()?;
549554

550555
let logical_filter = self.filter.map(|expr| {
551556
// Simplify the expression first
552557
let props = ExecutionProps::new();
553558
let simplify_context =
554-
SimplifyContext::new(&props).with_schema(df_schema.clone().into());
559+
SimplifyContext::new(&props).with_schema(filter_df_schema.clone().into());
555560
let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10);
556561
let simplified = simplifier.simplify(expr).unwrap();
557562

558563
context
559-
.create_physical_expr(simplified, &df_schema)
564+
.create_physical_expr(simplified, &filter_df_schema)
560565
.unwrap()
561566
});
562567

@@ -757,17 +762,47 @@ impl TableProvider for DeltaTable {
757762
&self,
758763
filter: &[&Expr],
759764
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
760-
Ok(filter
761-
.iter()
762-
.map(|_| TableProviderFilterPushDown::Inexact)
763-
.collect())
765+
let partition_cols = self.snapshot()?.metadata().partition_columns.clone();
766+
Ok(get_pushdown_filters(filter, partition_cols))
764767
}
765768

766769
fn statistics(&self) -> Option<Statistics> {
767770
self.snapshot().ok()?.datafusion_table_statistics()
768771
}
769772
}
770773

774+
fn get_pushdown_filters(
775+
filter: &[&Expr],
776+
partition_cols: Vec<String>,
777+
) -> Vec<TableProviderFilterPushDown> {
778+
filter
779+
.iter()
780+
.map(|filter| {
781+
let columns = extract_columns(filter);
782+
if !columns.is_empty() && columns.iter().all(|col| partition_cols.contains(col)) {
783+
TableProviderFilterPushDown::Exact
784+
} else {
785+
TableProviderFilterPushDown::Inexact
786+
}
787+
})
788+
.collect()
789+
}
790+
791+
fn extract_columns(expr: &Expr) -> Vec<String> {
792+
let mut columns = Vec::new();
793+
match expr {
794+
Expr::Column(col) => columns.push(col.name.clone()),
795+
Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
796+
let left_columns = extract_columns(left);
797+
let right_columns = extract_columns(right);
798+
columns.extend(left_columns);
799+
columns.extend(right_columns);
800+
}
801+
_ => {}
802+
}
803+
columns
804+
}
805+
771806
/// A Delta table provider that enables additional metadata columns to be included during the scan
772807
#[derive(Debug)]
773808
pub struct DeltaTableProvider {
@@ -849,10 +884,8 @@ impl TableProvider for DeltaTableProvider {
849884
&self,
850885
filter: &[&Expr],
851886
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
852-
Ok(filter
853-
.iter()
854-
.map(|_| TableProviderFilterPushDown::Inexact)
855-
.collect())
887+
let partition_cols = self.snapshot.metadata().partition_columns.clone();
888+
Ok(get_pushdown_filters(filter, partition_cols))
856889
}
857890

858891
fn statistics(&self) -> Option<Statistics> {

crates/core/src/kernel/snapshot/log_data.rs

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -891,14 +891,37 @@ mod datafusion {
891891
arrow_cast::cast(batch.column_by_name("output")?, &ArrowDataType::UInt64).ok()
892892
}
893893

894-
// This function is required since DataFusion 35.0, but is implemented as a no-op
895-
// https://github.com/apache/arrow-datafusion/blob/ec6abece2dcfa68007b87c69eefa6b0d7333f628/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs#L550
896-
fn contained(
897-
&self,
898-
_column: &Column,
899-
_value: &HashSet<ScalarValue>,
900-
) -> Option<BooleanArray> {
901-
None
894+
// This function is required for partition column pruning to be executed correctly
895+
fn contained(&self, column: &Column, value: &HashSet<ScalarValue>) -> Option<BooleanArray> {
896+
// Check if the column is a partition column
897+
if !self.metadata.partition_columns.contains(&column.name) {
898+
return None;
899+
}
900+
901+
// Retrieve the partition values for the column
902+
let partition_values = self.pick_stats(column, "__dummy__")?;
903+
904+
let partition_values = partition_values
905+
.as_any()
906+
.downcast_ref::<StringArray>()
907+
.ok_or(DeltaTableError::generic(
908+
"failed to downcast string result to StringArray.",
909+
))
910+
.ok()?;
911+
912+
let mut contains = Vec::with_capacity(partition_values.len());
913+
914+
for i in 0..partition_values.len() {
915+
if partition_values.is_null(i) {
916+
contains.push(false);
917+
} else {
918+
let partition_value =
919+
ScalarValue::Utf8(Some(partition_values.value(i).to_string()));
920+
contains.push(value.contains(&partition_value));
921+
}
922+
}
923+
924+
Some(BooleanArray::from(contains))
902925
}
903926
}
904927
}

0 commit comments

Comments
 (0)