diff --git a/crates/core/examples/df_scan.rs b/crates/core/examples/df_scan.rs new file mode 100644 index 0000000000..71789fccb1 --- /dev/null +++ b/crates/core/examples/df_scan.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use arrow_cast::pretty::print_batches; +use datafusion::datasource::TableProvider; +use datafusion::physical_plan::collect_partitioned; +use datafusion::prelude::SessionContext; +use deltalake_core::delta_datafusion::engine::DataFusionEngine; +use deltalake_core::kernel::Snapshot; +use deltalake_core::DeltaTableError; +use url::Url; + +static CASES: &[&str] = &[ + "./dat/v0.0.3/reader_tests/generated/all_primitive_types/delta/", // 0 + "./dat/v0.0.3/reader_tests/generated/basic_append/delta/", // 1 + "./dat/v0.0.3/reader_tests/generated/basic_partitioned/delta/", // 2 + "./dat/v0.0.3/reader_tests/generated/cdf/delta/", // 3 + "./dat/v0.0.3/reader_tests/generated/check_constraints/delta/", // 4 + "./dat/v0.0.3/reader_tests/generated/column_mapping/delta/", // 5 + "./dat/v0.0.3/reader_tests/generated/deletion_vectors/delta/", // 6 + "./dat/v0.0.3/reader_tests/generated/generated_columns/delta/", // 7 + "./dat/v0.0.3/reader_tests/generated/iceberg_compat_v1/delta/", // 8 + "./dat/v0.0.3/reader_tests/generated/multi_partitioned/delta/", // 9 + "./dat/v0.0.3/reader_tests/generated/multi_partitioned_2/delta/", // 10 + "./dat/v0.0.3/reader_tests/generated/nested_types/delta/", // 11 + "./dat/v0.0.3/reader_tests/generated/no_replay/delta/", // 12 + "./dat/v0.0.3/reader_tests/generated/no_stats/delta/", // 13 + "./dat/v0.0.3/reader_tests/generated/partitioned_with_null/delta/", // 14 + "./dat/v0.0.3/reader_tests/generated/stats_as_struct/delta/", // 15 + "./dat/v0.0.3/reader_tests/generated/timestamp_ntz/delta/", // 16 + "./dat/v0.0.3/reader_tests/generated/with_checkpoint/delta/", // 17 + "./dat/v0.0.3/reader_tests/generated/with_schema_change/delta/", // 18 +]; + +#[tokio::main(flavor = "multi_thread", worker_threads = 4)] +async fn main() -> Result<(), DeltaTableError> { + let session = Arc::new(SessionContext::new()); + let engine = DataFusionEngine::new_from_session(&session.state()); + + let path = std::fs::canonicalize(CASES[5]).unwrap(); + let table_url = Url::from_directory_path(path).unwrap(); + let snapshot = + Snapshot::try_new_with_engine(engine.clone(), table_url, Default::default(), None).await?; + + let state = session.state_ref().read().clone(); + + let plan = snapshot.scan(&state, None, &[], None).await?; + + let batches: Vec<_> = collect_partitioned(plan, session.task_ctx()) + .await? 
+ .into_iter() + .flatten() + .collect(); + + print_batches(&batches).unwrap(); + + Ok(()) +} diff --git a/crates/core/src/data_catalog/storage/mod.rs b/crates/core/src/data_catalog/storage/mod.rs index 5cadaff798..848e781d5c 100644 --- a/crates/core/src/data_catalog/storage/mod.rs +++ b/crates/core/src/data_catalog/storage/mod.rs @@ -117,12 +117,12 @@ impl SchemaProvider for ListingSchemaProvider { let Some(location) = self.tables.get(name).map(|t| t.clone()) else { return Ok(None); }; - let provider = open_table_with_storage_options( + let table = open_table_with_storage_options( ensure_table_uri(location)?, self.storage_options.raw.clone(), ) .await?; - Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>)) + Ok(Some(Arc::new(table) as Arc<dyn TableProvider>)) } fn register_table( diff --git a/crates/core/src/delta_datafusion/engine/expressions/mod.rs b/crates/core/src/delta_datafusion/engine/expressions/mod.rs new file mode 100644 index 0000000000..625fc693bf --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/mod.rs @@ -0,0 +1,5 @@ +pub(crate) use self::to_df::*; +pub(crate) use self::to_kernel::*; + +mod to_df; +mod to_kernel; diff --git a/crates/core/src/delta_datafusion/engine/expressions/to_df.rs b/crates/core/src/delta_datafusion/engine/expressions/to_df.rs new file mode 100644 index 0000000000..fb46683840 --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/to_df.rs @@ -0,0 +1,874 @@ +use std::sync::Arc; + +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::common::{not_impl_err, DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::functions::core::expr_ext::FieldAccessor; +use datafusion::functions::expr_fn::named_struct; +use datafusion::logical_expr::{col, lit, Expr}; +use delta_kernel::arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField}; +use delta_kernel::engine::arrow_conversion::TryIntoArrow; +use delta_kernel::expressions::{ + BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, Expression, + JunctionPredicate, JunctionPredicateOp, Scalar, UnaryExpression, UnaryExpressionOp, + UnaryPredicate, UnaryPredicateOp, +}; +use delta_kernel::schema::DataType; +use delta_kernel::Predicate; +use itertools::Itertools; + +pub(crate) fn to_datafusion_expr(expr: &Expression, output_type: &DataType) -> DFResult<Expr> { + match expr { + Expression::Literal(scalar) => scalar_to_df(scalar).map(lit), + Expression::Column(name) => { + let mut name_iter = name.iter(); + let base_name = name_iter.next().ok_or_else(|| { + DataFusionError::Internal("Expected at least one column name".into()) + })?; + Ok(name_iter.fold(col(base_name), |acc, n| acc.field(n))) + } + Expression::Predicate(expr) => predicate_to_df(expr, output_type), + Expression::Struct(fields) => struct_to_df(fields, output_type), + Expression::Binary(expr) => binary_to_df(expr, output_type), + Expression::Opaque(_) => not_impl_err!("Opaque expressions are not yet supported"), + Expression::Unknown(_) => not_impl_err!("Unknown expressions are not yet supported"), + Expression::Transform(_) => not_impl_err!("Transform expressions are not yet supported"), + Expression::Unary(_) => not_impl_err!("Unary expressions are not yet supported"), + Expression::Variadic(_) => not_impl_err!("Variadic expressions are not yet supported"), + } +} + +pub(crate) fn scalar_to_df(scalar: &Scalar) -> DFResult<ScalarValue> { + Ok(match scalar { + Scalar::Boolean(value) => ScalarValue::Boolean(Some(*value)), + Scalar::String(value) => ScalarValue::Utf8(Some(value.clone())), + Scalar::Byte(value) => ScalarValue::Int8(Some(*value)), + Scalar::Short(value) => ScalarValue::Int16(Some(*value)), + Scalar::Integer(value) => ScalarValue::Int32(Some(*value)), + Scalar::Long(value) => ScalarValue::Int64(Some(*value)), + Scalar::Float(value) => ScalarValue::Float32(Some(*value)), + Scalar::Double(value) => ScalarValue::Float64(Some(*value)), + Scalar::Timestamp(value) => { + ScalarValue::TimestampMicrosecond(Some(*value), Some("UTC".into())) + } + Scalar::TimestampNtz(value) => ScalarValue::TimestampMicrosecond(Some(*value), None), + Scalar::Date(value) => ScalarValue::Date32(Some(*value)), + Scalar::Binary(value) => ScalarValue::Binary(Some(value.clone())), + Scalar::Decimal(data) => { + ScalarValue::Decimal128(Some(data.bits()), data.precision(), data.scale() as i8) + } + Scalar::Struct(data) => { + let fields: Vec<ArrowField> = data + .fields() + .iter() + .map(|f| f.try_into_arrow()) + .try_collect()?; + let values: Vec<_> = data.values().iter().map(scalar_to_df).try_collect()?; + fields + .into_iter() + .zip(values.into_iter()) + .fold(ScalarStructBuilder::new(), |builder, (field, value)| { + builder.with_scalar(field, value) + }) + .build()? + } + Scalar::Array(_) => { + return Err(DataFusionError::NotImplemented( + "Array scalar values not implemented".into(), + )); + } + Scalar::Map(_) => { + return Err(DataFusionError::NotImplemented( + "Map scalar values not implemented".into(), + )); + } + Scalar::Null(data_type) => { + let data_type: ArrowDataType = data_type + .try_into_arrow() + .map_err(|e| DataFusionError::External(e.into()))?; + ScalarValue::try_from(&data_type)? + } + }) +} + +fn binary_to_df(bin: &BinaryExpression, output_type: &DataType) -> DFResult<Expr> { + let BinaryExpression { left, op, right } = bin; + let left_expr = to_datafusion_expr(left, output_type)?; + let right_expr = to_datafusion_expr(right, output_type)?; + Ok(match op { + BinaryExpressionOp::Plus => left_expr + right_expr, + BinaryExpressionOp::Minus => left_expr - right_expr, + BinaryExpressionOp::Multiply => left_expr * right_expr, + BinaryExpressionOp::Divide => left_expr / right_expr, + }) +} + +fn unary_to_df(un: &UnaryExpression, output_type: &DataType) -> DFResult<Expr> { + let UnaryExpression { op, expr } = un; + let _expr = to_datafusion_expr(expr, output_type)?; + match op { + // ToJson has no DataFusion equivalent yet; surface a NotImplemented error instead of panicking. + UnaryExpressionOp::ToJson => not_impl_err!("ToJson expressions are not yet supported"), + } +} + +fn binary_pred_to_df(bin: &BinaryPredicate, output_type: &DataType) -> DFResult<Expr> { + let BinaryPredicate { left, op, right } = bin; + let left_expr = to_datafusion_expr(left, output_type)?; + let right_expr = to_datafusion_expr(right, output_type)?; + Ok(match op { + BinaryPredicateOp::Equal => left_expr.eq(right_expr), + BinaryPredicateOp::LessThan => left_expr.lt(right_expr), + BinaryPredicateOp::GreaterThan => left_expr.gt(right_expr), + BinaryPredicateOp::Distinct => Err(DataFusionError::NotImplemented( + "DISTINCT operator not supported".into(), + ))?, + BinaryPredicateOp::In => Err(DataFusionError::NotImplemented( + "IN operator not supported".into(), + ))?, + }) +} + +fn predicate_to_df(predicate: &Predicate, output_type: &DataType) -> DFResult<Expr> { + match predicate { + Predicate::BooleanExpression(expr) => to_datafusion_expr(expr, output_type), + Predicate::Not(expr) => Ok(!(predicate_to_df(expr, output_type)?)), + Predicate::Unary(expr) => unary_pred_to_df(expr, output_type), + Predicate::Binary(expr) => binary_pred_to_df(expr, output_type), + Predicate::Junction(expr) => junction_to_df(expr, output_type), + Predicate::Opaque(_) => not_impl_err!("Opaque predicates are not yet supported"), + Predicate::Unknown(_) => not_impl_err!("Unknown predicates are not yet supported"), + } +} + +fn unary_pred_to_df(unary: &UnaryPredicate, output_type: &DataType) -> DFResult<Expr> { + let UnaryPredicate { op, expr } = unary; + let df_expr = to_datafusion_expr(expr, output_type)?; + Ok(match op { + UnaryPredicateOp::IsNull => df_expr.is_null(), + }) +} + +fn junction_to_df(junction: &JunctionPredicate, output_type: &DataType) -> DFResult<Expr> { + let JunctionPredicate { op, preds } = junction; + let df_exprs: Vec<_> = preds + .iter() + .map(|e| predicate_to_df(e, output_type)) + .try_collect()?; + match op { + JunctionPredicateOp::And => Ok(df_exprs + .into_iter() + .reduce(|a, b| a.and(b)) + .unwrap_or(lit(true))), + JunctionPredicateOp::Or => Ok(df_exprs + .into_iter() + .reduce(|a, b| a.or(b)) + .unwrap_or(lit(false))), + } +} + +fn struct_to_df(fields: &[Arc<Expression>], output_type: &DataType) -> DFResult<Expr> { + let DataType::Struct(struct_type) = output_type else { + return Err(DataFusionError::Execution( + "expected struct output type".into(), + )); + }; + let df_exprs: Vec<_> = fields + .iter() + .zip(struct_type.fields()) + .map(|(expr, field)| { + Ok(vec![ + lit(field.name().to_string()), + to_datafusion_expr(expr, field.data_type())?, + ]) + }) + .flatten_ok() + .try_collect::<_, _, DataFusionError>()?; + Ok(named_struct(df_exprs)) +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + + use datafusion::logical_expr::{col, lit}; + use delta_kernel::expressions::ColumnName; + use delta_kernel::expressions::{ArrayData, BinaryExpression, MapData, Scalar, StructData}; + use delta_kernel::schema::{ArrayType, DataType, MapType, StructField, StructType}; + + use super::*; + + /// Test conversion of primitive scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_primitives() { + let test_cases = vec![ + (Scalar::Boolean(true), ScalarValue::Boolean(Some(true))), + ( + Scalar::String("test".to_string()), + ScalarValue::Utf8(Some("test".to_string())), + ), + (Scalar::Integer(42), ScalarValue::Int32(Some(42))), + (Scalar::Long(42), ScalarValue::Int64(Some(42))), + (Scalar::Float(42.0), ScalarValue::Float32(Some(42.0))), + (Scalar::Double(42.0), ScalarValue::Float64(Some(42.0))), + (Scalar::Byte(42), ScalarValue::Int8(Some(42))), + (Scalar::Short(42), ScalarValue::Int16(Some(42))), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test conversion of temporal scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_temporal() { + let test_cases = vec![ + ( + Scalar::Timestamp(1234567890), + ScalarValue::TimestampMicrosecond(Some(1234567890), Some("UTC".into())), + ), + ( + Scalar::TimestampNtz(1234567890), + ScalarValue::TimestampMicrosecond(Some(1234567890), None), + ), + (Scalar::Date(18262), ScalarValue::Date32(Some(18262))), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test conversion of binary and decimal scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_binary_decimal() { + let binary_data = vec![1, 2, 3]; + let decimal_data = Scalar::decimal(123456789, 10, 2).unwrap(); + + let test_cases = vec![ + ( + Scalar::Binary(binary_data.clone()), + ScalarValue::Binary(Some(binary_data)), + ), + ( + decimal_data, + ScalarValue::Decimal128(Some(123456789), 10, 2), + ), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); +
assert_eq!(result, expected); + } + } + + /// Test conversion of struct scalar type to DataFusion scalar value + #[test] + fn test_scalar_to_df_struct() { + let result = scalar_to_df(&Scalar::Struct( + StructData::try_new( + vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("b", DataType::STRING), + ], + vec![Scalar::Integer(42), Scalar::String("test".to_string())], + ) + .unwrap(), + )) + .unwrap(); + + // Create the expected struct value + let expected = ScalarStructBuilder::new() + .with_scalar( + ArrowField::new("a", ArrowDataType::Int32, true), + ScalarValue::Int32(Some(42)), + ) + .with_scalar( + ArrowField::new("b", ArrowDataType::Utf8, true), + ScalarValue::Utf8(Some("test".to_string())), + ) + .build() + .unwrap(); + + assert_eq!(result, expected); + } + + /// Test conversion of null scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_null() { + let test_cases = vec![ + (Scalar::Null(DataType::INTEGER), ScalarValue::Int32(None)), + (Scalar::Null(DataType::STRING), ScalarValue::Utf8(None)), + (Scalar::Null(DataType::BOOLEAN), ScalarValue::Boolean(None)), + (Scalar::Null(DataType::DOUBLE), ScalarValue::Float64(None)), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test error cases for unsupported scalar types (Array and Map) + #[test] + fn test_scalar_to_df_errors() { + let array_data = ArrayData::try_new( + ArrayType::new(DataType::INTEGER, true), + vec![Scalar::Integer(1), Scalar::Integer(2)], + ) + .unwrap(); + + let map_data = MapData::try_new( + MapType::new(DataType::STRING, DataType::INTEGER, true), + vec![ + (Scalar::String("key1".to_string()), Scalar::Integer(1)), + (Scalar::String("key2".to_string()), Scalar::Integer(2)), + ], + ) + .unwrap(); + + let test_cases = vec![ + ( + Scalar::Array(array_data), + "Array scalar values not implemented", + ), + (Scalar::Map(map_data), "Map scalar values not implemented"), + ]; + + for (input, expected_error) in test_cases { + let result = scalar_to_df(&input); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains(expected_error)); + } + } + + /// Test basic column reference: `test_col` + #[test] + fn test_column_expression() { + let expr = Expression::Column(ColumnName::new(["test_col"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("test_col")); + + let expr = Expression::Column(ColumnName::new(["test_col", "field_1", "field_2"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("test_col").field("field_1").field("field_2")); + } + + /// Test various literal values: + /// - `true` (boolean) + /// - `"test"` (string) + /// - `42` (integer) + /// - `42L` (long) + /// - `42.0f` (float) + /// - `42.0` (double) + /// - `NULL` (null boolean) + #[test] + fn test_literal_expressions() { + // Test various scalar types + let test_cases = vec![ + (Expression::Literal(Scalar::Boolean(true)), lit(true)), + ( + Expression::Literal(Scalar::String("test".to_string())), + lit("test"), + ), + (Expression::Literal(Scalar::Integer(42)), lit(42)), + (Expression::Literal(Scalar::Long(42)), lit(42i64)), + (Expression::Literal(Scalar::Float(42.0)), lit(42.0f32)), + (Expression::Literal(Scalar::Double(42.0)), lit(42.0)), + ( + Expression::Literal(Scalar::Null(DataType::BOOLEAN)), + lit(ScalarValue::Boolean(None)), + ), + ]; + + for (input, expected) in test_cases { + let result = 
to_datafusion_expr(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary operations: + /// - `a = 1` (equality) + /// - `a + b` (addition) + /// - `a * 2` (multiplication) + #[test] + fn test_binary_expressions() { + let test_cases = vec![ + ( + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }), + col("a") + col("b"), + ), + ( + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Literal(Scalar::Integer(2))), + }), + col("a") * lit(2), + ), + ]; + + for (input, expected) in test_cases { + let result = to_datafusion_expr(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary operations: + /// - `a = 1` (equality) + /// - `a + b` (addition) + /// - `a * 2` (multiplication) + #[test] + fn test_binary_predicate() { + let test_cases = vec![( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(1))), + }, + col("a").eq(lit(1)), + )]; + + for (input, expected) in test_cases { + let result = binary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test unary operations: + /// - `a IS NULL` (null check) + /// - `NOT a` (logical negation) + #[test] + fn test_unary_expressions() { + let test_cases = vec![( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(Expression::Column(ColumnName::new(["a"]))), + }, + col("a").is_null(), + )]; + + for (input, expected) in test_cases { + let result = unary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test junction operations: + /// - `a AND b` (logical AND) + /// - `a OR b` (logical OR) + #[test] + fn test_junction_expressions() { + let test_cases = vec![ + ( + JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }, + col("a").and(col("b")), + ), + ( + JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }, + col("a").or(col("b")), + ), + ]; + + for (input, expected) in test_cases { + let result = junction_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test complex nested expression: + /// `(a > 1 AND b < 2) OR (c = 3)` + #[test] + fn test_complex_nested_expressions() { + // Test a complex expression: (a > 1 AND b < 2) OR (c = 3) + let expr = Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Literal(Scalar::Integer(1))), + }), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Literal(Scalar::Integer(2))), + 
}), + ], + }), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["c"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(3))), + }), + ], + }); + + let result = predicate_to_df(&expr, &DataType::BOOLEAN).unwrap(); + let expected = (col("a").gt(lit(1)).and(col("b").lt(lit(2)))).or(col("c").eq(lit(3))); + assert_eq!(result, expected); + } + + #[test] + fn test_struct_expression() { + let expr = Expression::Struct(vec![ + Expression::Column(ColumnName::new(["a"])).into(), + Expression::Column(ColumnName::new(["b"])).into(), + ]); + let result = to_datafusion_expr( + &expr, + &DataType::Struct(Box::new( + StructType::try_new(vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("b", DataType::INTEGER), + ]) + .unwrap(), + )), + ) + .unwrap(); + assert_eq!( + result, + named_struct(vec![lit("a"), col("a"), lit("b"), col("b")]) + ); + } + + /// Test binary expression conversions: + /// - Addition: a + b + /// - Subtraction: a - b + /// - Multiplication: a * b + /// - Division: a / b + #[test] + fn test_binary_to_df() { + let test_cases = vec![ + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") + col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Minus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") - col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") * col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Divide, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") / col("b"), + ), + ]; + + for (input, expected) in test_cases { + let result = binary_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary expression conversions: + /// - Equality: a = b + /// - Inequality: a != b + /// - Less than: a < b + /// - Less than or equal: a <= b + /// - Greater than: a > b + /// - Greater than or equal: a >= b + #[test] + fn test_binary_pred_to_df() { + let test_cases = vec![ + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").eq(col("b")), + ), + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").lt(col("b")), + ), + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").gt(col("b")), + ), + ]; + + for (input, expected) in test_cases { + let result = binary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + + let test_cases = vec![ + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + 
col("a").eq(col("b")).not(), + ), + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + col("a").gt(col("b")).not(), + ), + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + col("a").lt(col("b")).not(), + ), + ]; + + for (input, expected) in test_cases { + let result = predicate_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test junction expression conversions: + /// - Simple AND: a AND b + /// - Simple OR: a OR b + /// - Multiple AND: a AND b AND c + /// - Multiple OR: a OR b OR c + /// - Empty AND (should return true) + /// - Empty OR (should return false) + #[test] + fn test_junction_to_df() { + let test_cases = vec![ + // Simple AND + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }), + col("a").and(col("b")), + ), + // Simple OR + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }), + col("a").or(col("b")), + ), + // Multiple AND + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + col("a").and(col("b")).and(col("c")), + ), + // Multiple OR + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + col("a").or(col("b")).or(col("c")), + ), + // Empty AND (should return true) + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![], + }), + lit(true), + ), + // Empty OR (should return false) + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![], + }), + lit(false), + ), + ]; + + for (input, expected) in test_cases { + let result = predicate_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test to_datafusion_expr with various expression types and combinations: + /// - Column expressions with nested fields + /// - Complex unary expressions + /// - Nested binary expressions + /// - Mixed junction expressions + /// - Struct expressions with nested fields + /// - Complex combinations of all expression types + #[test] + fn test_to_datafusion_expr_comprehensive() { + // Test column expressions with nested fields + let expr = Expression::Column(ColumnName::new(["struct", "field", "nested"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("struct").field("field").field("nested")); + 
+ // Test complex unary expressions + let expr = Expression::Predicate(Box::new(Predicate::Not(Box::new(Predicate::Unary( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(Expression::Column(ColumnName::new(["a"]))), + }, + ))))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, !col("a").is_null()); + + // Test nested binary expressions + let expr = Expression::Binary(BinaryExpression { + left: Box::new(Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + })), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Column(ColumnName::new(["c"]))), + }); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, (col("a") + col("b")) * col("c")); + + // Test mixed junction expressions + let expr = Expression::Predicate(Box::new(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Literal(Scalar::Integer(0))), + }), + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + ], + }))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("a").gt(lit(0)).and(col("b").or(col("c")))); + + // Test struct expressions with nested fields + let expr = Expression::Struct(vec![ + Expression::Column(ColumnName::new(["a"])).into(), + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["c"]))), + }) + .into(), + ]); + let result = to_datafusion_expr( + &expr, + &DataType::Struct(Box::new( + StructType::try_new(vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("sum", DataType::INTEGER), + ]) + .unwrap(), + )), + ) + .unwrap(); + assert_eq!( + result, + named_struct(vec![lit("a"), col("a"), lit("sum"), col("b") + col("c")]) + ); + + // Test complex combination of all expression types + let expr = Expression::Predicate(Box::new(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Not(Box::new(Predicate::BooleanExpression(Expression::Column( + ColumnName::new(["a"]), + )))), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(42))), + }), + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["d"]))), + ], + }), + ], + }))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!( + result, + (!col("a")) + .and(col("b").eq(lit(42))) + .and(col("c").or(col("d"))) + ); + + // Test error case: empty column name + let expr = Expression::Column(ColumnName::new::<&str>([])); + assert!(to_datafusion_expr(&expr, &DataType::BOOLEAN).is_err()); + } +} diff --git 
a/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs b/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs new file mode 100644 index 0000000000..437a28acb2 --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs @@ -0,0 +1,530 @@ +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::{BinaryExpr, Expr, Operator}; +use delta_kernel::expressions::{ + BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, DecimalData, + Expression, JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, UnaryPredicate, + UnaryPredicateOp, +}; +use delta_kernel::schema::{DataType, DecimalType, PrimitiveType}; +use delta_kernel::Error as DeltaKernelError; +use itertools::Itertools; + +pub(crate) fn to_df_err(e: DeltaKernelError) -> DataFusionError { + DataFusionError::External(Box::new(e)) +} + +pub(crate) fn to_delta_predicate(filters: &[Expr]) -> DFResult<Predicate> { + if filters.is_empty() { + return Ok(Predicate::BooleanExpression(Expression::Literal( + Scalar::Boolean(true), + ))); + }; + Ok(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: filters.iter().map(to_predicate).try_collect()?, + })) +} + +pub(crate) fn to_predicate(expr: &Expr) -> DFResult<Predicate> { + match to_delta_expression(expr)? { + Expression::Predicate(pred) => Ok(pred.as_ref().clone()), + expr => Ok(Predicate::BooleanExpression(expr)), + } +} + +/// Convert a DataFusion expression to a Delta expression. +pub(crate) fn to_delta_expression(expr: &Expr) -> DFResult<Expression> { + match expr { + Expr::Column(column) => Ok(Expression::Column( + column + .name + .parse() + .map_err(|e| DataFusionError::External(Box::new(e)))?, + )), + Expr::Literal(scalar, _meta) => { + Ok(Expression::Literal(datafusion_scalar_to_scalar(scalar)?)) + } + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::And | Operator::Or), + .. + }) => { + let preds = flatten_junction_expr(expr, *op)?; + Ok(Expression::Predicate(Box::new(Predicate::Junction( + JunctionPredicate { + op: to_junction_op(*op), + preds, + }, + )))) + } + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::Eq | Operator::Lt | Operator::Gt), + left, + right, + }) => Ok(Expression::Predicate(Box::new(Predicate::Binary( + BinaryPredicate { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_predicate_op(*op)?, + right: Box::new(to_delta_expression(right.as_ref())?), + }, + )))), + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::NotEq | Operator::LtEq | Operator::GtEq), + left, + right, + }) => { + // Kernel predicates have no NotEq/LtEq/GtEq ops; rewrite as NOT of the inverse comparison. + let inverted = match op { + Operator::NotEq => Operator::Eq, + Operator::LtEq => Operator::Gt, + Operator::GtEq => Operator::Lt, + _ => unreachable!(), + }; + Ok(Expression::Predicate(Box::new(Predicate::Not(Box::new( + Predicate::Binary(BinaryPredicate { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_predicate_op(inverted)?, + right: Box::new(to_delta_expression(right.as_ref())?), + }), + ))))) + } + Expr::BinaryExpr(BinaryExpr { op, left, right }) => { + Ok(Expression::Binary(BinaryExpression { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_op(*op)?, + right: Box::new(to_delta_expression(right.as_ref())?), + })) + } + Expr::IsNull(expr) => Ok(Expression::Predicate(Box::new(Predicate::Unary( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(to_delta_expression(expr.as_ref())?), + }, + )))), + Expr::Not(expr) => Ok(Expression::Predicate(Box::new(Predicate::Not(Box::new( + Predicate::BooleanExpression(to_delta_expression(expr.as_ref())?), + ))))), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported expression: {:?}", + expr + ))), + } +} + +fn datafusion_scalar_to_scalar(scalar: &ScalarValue) -> DFResult<Scalar> { + match scalar { + ScalarValue::Boolean(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Boolean(*value)), + None => Ok(Scalar::Null(DataType::BOOLEAN)), + }, + ScalarValue::Utf8(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::String(value.clone())), + None => Ok(Scalar::Null(DataType::STRING)), + }, + ScalarValue::Int8(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Byte(*value)), + None => Ok(Scalar::Null(DataType::BYTE)), + }, + ScalarValue::Int16(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Short(*value)), + None => Ok(Scalar::Null(DataType::SHORT)), + }, + ScalarValue::Int32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Integer(*value)), + None => Ok(Scalar::Null(DataType::INTEGER)), + }, + ScalarValue::Int64(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Long(*value)), + None => Ok(Scalar::Null(DataType::LONG)), + }, + ScalarValue::Float32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Float(*value)), + None => Ok(Scalar::Null(DataType::FLOAT)), + }, + ScalarValue::Float64(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Double(*value)), + None => Ok(Scalar::Null(DataType::DOUBLE)), + }, + ScalarValue::TimestampMicrosecond(maybe_value, Some(_)) => match maybe_value { + Some(value) => Ok(Scalar::Timestamp(*value)), + None => Ok(Scalar::Null(DataType::TIMESTAMP)), + }, + ScalarValue::TimestampMicrosecond(maybe_value, None) => match maybe_value { + Some(value) => Ok(Scalar::TimestampNtz(*value)), + None => Ok(Scalar::Null(DataType::TIMESTAMP_NTZ)), + }, + ScalarValue::Date32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Date(*value)), + None => Ok(Scalar::Null(DataType::DATE)), + }, + ScalarValue::Binary(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Binary(value.clone())), + None => Ok(Scalar::Null(DataType::BINARY)), + }, + ScalarValue::Decimal128(maybe_value, precision, scale) => match maybe_value { + Some(value) => Ok(Scalar::Decimal( + DecimalData::try_new( + *value, + DecimalType::try_new(*precision, *scale as u8).map_err(to_df_err)?, + ) + .map_err(to_df_err)?, + )), + None => Ok(Scalar::Null(DataType::Primitive(PrimitiveType::Decimal( + DecimalType::try_new(*precision, *scale as u8).map_err(to_df_err)?, + )))), + }, + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported scalar value: {:?}", + scalar + ))), + } +} + +fn to_binary_predicate_op(op: Operator) -> DFResult<BinaryPredicateOp> { + match op { + Operator::Eq => Ok(BinaryPredicateOp::Equal), + Operator::Lt => Ok(BinaryPredicateOp::LessThan), + Operator::Gt => Ok(BinaryPredicateOp::GreaterThan), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported operator: {:?}", + op + ))), + } +} + +fn to_binary_op(op: Operator) -> DFResult<BinaryExpressionOp> { + match op { + Operator::Plus => Ok(BinaryExpressionOp::Plus), + Operator::Minus => Ok(BinaryExpressionOp::Minus), + Operator::Multiply => Ok(BinaryExpressionOp::Multiply), + Operator::Divide => Ok(BinaryExpressionOp::Divide), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported operator: {:?}", + op + ))), + } +} + +/// Helper function to flatten nested AND/OR expressions into a single junction expression +fn flatten_junction_expr(expr: &Expr, target_op: Operator) -> DFResult<Vec<Predicate>> { + match expr { + Expr::BinaryExpr(BinaryExpr { op, left, right }) if *op == target_op => { + let mut left_exprs = flatten_junction_expr(left.as_ref(), target_op)?; + let mut right_exprs = flatten_junction_expr(right.as_ref(), target_op)?; + left_exprs.append(&mut right_exprs); + Ok(left_exprs) + } + _ => { + let delta_expr = to_predicate(expr)?; + Ok(vec![delta_expr]) + } + } +} + +fn to_junction_op(op: Operator) -> JunctionPredicateOp { + match op { + Operator::And => JunctionPredicateOp::And, + Operator::Or => JunctionPredicateOp::Or, + _ => unimplemented!("Unsupported operator: {:?}", op), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::logical_expr::{col, lit}; + use delta_kernel::expressions::{BinaryExpressionOp, JunctionPredicateOp, Scalar}; + + fn assert_junction_expr( + expr: &Expr, + expected_op: JunctionPredicateOp, + expected_children: usize, + ) { + let delta_expr = to_delta_expression(expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, expected_op); + assert_eq!(junction.preds.len(), expected_children); + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_simple_and() { + let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 2); + } + + #[test] + fn test_simple_or() { + let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); + assert_junction_expr(&expr, JunctionPredicateOp::Or, 2); + } + + #[test] + fn test_nested_and() { + let expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("c").eq(lit(3))) + .and(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 4); + } + + #[test] + fn test_nested_or() { + let expr = col("a") +
.eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("c").eq(lit(3))) + .or(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::Or, 4); + } + + #[test] + fn test_mixed_nested_and_or() { + // (a AND b) OR (c AND d) + let left = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + let right = col("c").eq(lit(3)).and(col("d").eq(lit(4))); + let expr = left.or(right); + + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, JunctionPredicateOp::Or); + assert_eq!(junction.preds.len(), 2); + + // Check that both children are AND junctions + for child in &junction.preds { + match child { + Predicate::Junction(child_junction) => { + assert_eq!(child_junction.op, JunctionPredicateOp::And); + } + _ => panic!("Expected Junction predicate in child: {:?}", child), + } + } + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression"), + } + } + + #[test] + fn test_deeply_nested_and() { + // (((a AND b) AND c) AND d) + let expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("c").eq(lit(3))) + .and(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 4); + } + + #[test] + fn test_complex_expression() { + // (a AND b) OR ((c AND d) AND e) + let left = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + let right = col("c") + .eq(lit(3)) + .and(col("d").eq(lit(4))) + .and(col("e").eq(lit(5))); + let expr = left.or(right); + + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, JunctionPredicateOp::Or); + assert_eq!(junction.preds.len(), 2); + + // First child should be an AND with 2 expressions + match &junction.preds[0] { + Predicate::Junction(child_junction) => { + assert_eq!(child_junction.op, JunctionPredicateOp::And); + assert_eq!(child_junction.preds.len(), 2); + } + _ => panic!("Expected Junction expression in first child"), + } + + // Second child should be an AND with 3 expressions + match &junction.preds[1] { + Predicate::Junction(child_junction) => { + assert_eq!(child_junction.op, JunctionPredicateOp::And); + assert_eq!(child_junction.preds.len(), 3); + } + _ => panic!("Expected Junction expression in second child"), + } + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression"), + } + } + + #[test] + fn test_column_expression() { + let expr = col("test_column"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Column(name) => assert_eq!(&name.to_string(), "test_column"), + _ => panic!("Expected Column expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_literal_expressions() { + // Test boolean literal + let expr = lit(true); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Boolean(value)) => assert!(value), + _ => panic!("Expected Boolean literal, got {:?}", delta_expr), + } + + // Test string literal + let expr = lit("test"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::String(value)) => assert_eq!(value, "test"), + _ => panic!("Expected String literal, got {:?}", delta_expr), + } + + // Test integer literal + let expr = lit(42i32); + let delta_expr =
to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 42), + _ => panic!("Expected Integer literal, got {:?}", delta_expr), + } + + // Test decimal literal + let expr = lit(ScalarValue::Decimal128(Some(12345), 10, 2)); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Decimal(data)) => { + assert_eq!(data.bits(), 12345); + assert_eq!(data.precision(), 10); + assert_eq!(data.scale(), 2); + } + _ => panic!("Expected Decimal literal, got {:?}", delta_expr), + } + } + + #[test] + fn test_binary_expressions() { + // Test comparison operators + let test_cases = vec![ + (col("a").eq(lit(1)), BinaryPredicateOp::Equal), + (col("a").lt(lit(1)), BinaryPredicateOp::LessThan), + (col("a").gt(lit(1)), BinaryPredicateOp::GreaterThan), + ]; + + for (expr, expected_op) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Binary(binary) => { + assert_eq!(binary.op, expected_op); + match binary.left.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in left operand"), + } + match *binary.right.as_ref() { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 1), + _ => panic!("Expected Integer literal in right operand"), + } + } + _ => panic!("Expected Binary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Binary expression, got {:?}", delta_expr), + } + } + + // Test arithmetic operators + let test_cases = vec![ + (col("a") + lit(1), BinaryExpressionOp::Plus), + (col("a") - lit(1), BinaryExpressionOp::Minus), + (col("a") * lit(1), BinaryExpressionOp::Multiply), + (col("a") / lit(1), BinaryExpressionOp::Divide), + ]; + + for (expr, expected_op) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Binary(binary) => { + assert_eq!(binary.op, expected_op); + match binary.left.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in left operand"), + } + match *binary.right.as_ref() { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 1), + _ => panic!("Expected Integer literal in right operand"), + } + } + _ => panic!("Expected Binary expression, got {:?}", delta_expr), + } + } + } + + #[test] + fn test_unary_expressions() { + // Test IS NULL + let expr = col("a").is_null(); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Unary(unary) => { + assert_eq!(unary.op, UnaryPredicateOp::IsNull); + match unary.expr.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in operand"), + } + } + _ => panic!("Expected Unary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Unary expression, got {:?}", delta_expr), + } + + // Test NOT + let expr = !col("a"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Not(unary) => match unary.as_ref() { + Predicate::BooleanExpression(expr) => match expr { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in operand"), + }, + _ => panic!("Expected Boolean expression in operand"), + 
}, + _ => panic!("Expected Unary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Unary expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_null_literals() { + let test_cases = vec![ + (lit(ScalarValue::Boolean(None)), DataType::BOOLEAN), + (lit(ScalarValue::Utf8(None)), DataType::STRING), + (lit(ScalarValue::Int32(None)), DataType::INTEGER), + (lit(ScalarValue::Float64(None)), DataType::DOUBLE), + ]; + + for (expr, expected_type) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Null(data_type)) => { + assert_eq!(data_type, expected_type); + } + _ => panic!("Expected Null literal, got {:?}", delta_expr), + } + } + } +} diff --git a/crates/core/src/delta_datafusion/engine/mod.rs b/crates/core/src/delta_datafusion/engine/mod.rs index 478541df5c..fc92ea7c4a 100644 --- a/crates/core/src/delta_datafusion/engine/mod.rs +++ b/crates/core/src/delta_datafusion/engine/mod.rs @@ -6,9 +6,12 @@ use delta_kernel::{Engine, EvaluationHandler, JsonHandler, ParquetHandler, Stora use tokio::runtime::Handle; use self::file_formats::DataFusionFileFormatHandler; -use self::storage::DataFusionStorageHandler; use crate::kernel::ARROW_HANDLER; +pub(crate) use expressions::*; +pub(crate) use storage::*; + +mod expressions; mod file_formats; mod storage; diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 7bb7157360..fd91de412b 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -782,14 +782,14 @@ impl TableProviderFactory for DeltaTableFactory { _ctx: &dyn Session, cmd: &CreateExternalTable, ) -> datafusion::error::Result> { - let provider = if cmd.options.is_empty() { + let table = if cmd.options.is_empty() { let table_url = ensure_table_uri(&cmd.to_owned().location)?; open_table(table_url).await? } else { let table_url = ensure_table_uri(&cmd.to_owned().location)?; open_table_with_storage_options(table_url, cmd.to_owned().options).await? }; - Ok(Arc::new(provider)) + Ok(Arc::new(table)) } } @@ -1173,14 +1173,14 @@ mod tests { let df = ctx.sql("select * from test").await.unwrap(); let actual = df.collect().await.unwrap(); let expected = vec! 
[ - "+----+----+----+-------------------------------------------------------------------------------+", - "| c3 | c1 | c2 | file_source |", - "+----+----+----+-------------------------------------------------------------------------------+", - "| 4 | 6 | a | c1=6/c2=a/part-00011-10619b10-b691-4fd0-acc4-2a9608499d7c.c000.snappy.parquet |", - "| 5 | 4 | c | c1=4/c2=c/part-00003-f525f459-34f9-46f5-82d6-d42121d883fd.c000.snappy.parquet |", - "| 6 | 5 | b | c1=5/c2=b/part-00007-4e73fa3b-2c88-424a-8051-f8b54328ffdb.c000.snappy.parquet |", - "+----+----+----+-------------------------------------------------------------------------------+", - ]; + "+----+----+----+-------------------------------------------------------------------------------+", + "| c3 | c1 | c2 | file_source |", + "+----+----+----+-------------------------------------------------------------------------------+", + "| 4 | 6 | a | c1=6/c2=a/part-00011-10619b10-b691-4fd0-acc4-2a9608499d7c.c000.snappy.parquet |", + "| 5 | 4 | c | c1=4/c2=c/part-00003-f525f459-34f9-46f5-82d6-d42121d883fd.c000.snappy.parquet |", + "| 6 | 5 | b | c1=5/c2=b/part-00007-4e73fa3b-2c88-424a-8051-f8b54328ffdb.c000.snappy.parquet |", + "+----+----+----+-------------------------------------------------------------------------------+", + ]; assert_batches_sorted_eq!(&expected, &actual); } @@ -1598,9 +1598,10 @@ mod tests { .unwrap(); let datafusion = SessionContext::new(); - let table = Arc::new(table); - datafusion.register_table("snapshot", table).unwrap(); + datafusion + .register_table("snapshot", Arc::new(table)) + .unwrap(); let df = datafusion .sql("select * from snapshot where id > 10000 and id < 20000") diff --git a/crates/core/src/delta_datafusion/session.rs b/crates/core/src/delta_datafusion/session.rs index 0b6f08453e..6635f3fddb 100644 --- a/crates/core/src/delta_datafusion/session.rs +++ b/crates/core/src/delta_datafusion/session.rs @@ -12,6 +12,20 @@ pub fn create_session() -> DeltaSessionContext { DeltaSessionContext::default() } +#[cfg(test)] +pub fn create_test_session() -> DeltaSessionContext { + use std::sync::Arc; + + use object_store::memory::InMemory; + + let session = DeltaSessionContext::default(); + session.inner.runtime_env().register_object_store( + &url::Url::parse("memory:///").unwrap(), + Arc::new(InMemory::new()), + ); + session +} + // Given a `Session` reference, get the concrete `SessionState` reference // Note: this may stop working in future versions, #[deprecated( diff --git a/crates/core/src/delta_datafusion/table_provider.rs b/crates/core/src/delta_datafusion/table_provider.rs index 7a1ac88781..9aa5081a0a 100644 --- a/crates/core/src/delta_datafusion/table_provider.rs +++ b/crates/core/src/delta_datafusion/table_provider.rs @@ -44,6 +44,7 @@ use datafusion::{ scalar::ScalarValue, }; use delta_kernel::table_properties::DataSkippingNumIndexedCols; +use futures::future::BoxFuture; use futures::StreamExt as _; use itertools::Itertools; use object_store::ObjectMeta; @@ -56,13 +57,16 @@ use crate::delta_datafusion::{ }; use crate::kernel::schema::cast::cast_record_batch; use crate::kernel::transaction::{CommitBuilder, PROTOCOL}; -use crate::kernel::{Action, Add, EagerSnapshot, Remove}; +use crate::kernel::{resolve_snapshot, Action, Add, EagerSnapshot, Remove}; +use crate::logstore::LogStore; use crate::operations::write::writer::{DeltaWriter, WriterConfig}; use crate::operations::write::WriterStatsConfig; use crate::protocol::{DeltaOperation, SaveMode}; use crate::{ensure_table_uri, DeltaTable}; use 
crate::{logstore::LogStoreRef, DeltaResult, DeltaTableError}; +pub(crate) mod next; + const PATH_COLUMN: &str = "__delta_rs_path"; /// DataSink implementation for delta lake diff --git a/crates/core/src/delta_datafusion/table_provider/next/mod.rs b/crates/core/src/delta_datafusion/table_provider/next/mod.rs new file mode 100644 index 0000000000..2c53d10ff1 --- /dev/null +++ b/crates/core/src/delta_datafusion/table_provider/next/mod.rs @@ -0,0 +1,486 @@ +use std::any::Any; +use std::pin::Pin; +use std::{borrow::Cow, sync::Arc}; + +use arrow::datatypes::{DataType, Field, SchemaRef}; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::common::{DataFusionError, HashMap, Result}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::DefaultParquetFileReaderFactory; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::datasource::TableType; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::TableProviderFilterPushDown; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion::physical_plan::union::UnionExec; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; +use datafusion::{ + catalog::{Session, TableProvider}, + logical_expr::{dml::InsertOp, LogicalPlan}, + physical_plan::ExecutionPlan, +}; +use delta_kernel::engine::arrow_conversion::{TryIntoArrow, TryIntoKernel}; +use delta_kernel::scan::ScanMetadata; +use delta_kernel::schema::SchemaRef as KernelSchemaRef; +use delta_kernel::Engine; +use futures::future::ready; +use futures::{Stream, TryStreamExt as _}; +use itertools::Itertools; +use object_store::path::Path; + +use crate::delta_datafusion::engine::{ + to_delta_predicate, AsObjectStoreUrl as _, DataFusionEngine, +}; +use crate::delta_datafusion::table_provider::get_pushdown_filters; +use crate::delta_datafusion::table_provider::next::replay::{ScanFileContext, ScanFileStream}; +use crate::delta_datafusion::table_provider::next::scan::DeltaScanExec; +use crate::delta_datafusion::DataFusionMixins as _; +use crate::kernel::{EagerSnapshot, Scan, Snapshot}; +use crate::DeltaTableError; + +mod replay; +mod scan; + +impl Snapshot { + fn kernel_scan(&self, projection: Option<&Vec<usize>>, filters: &[Expr]) -> Result<Arc<Scan>> { + let (_, projected_kernel_schema) = project_schema(self.read_schema(), projection)?; + Ok(Arc::new( + self.scan_builder() + .with_schema(projected_kernel_schema.clone()) + .with_predicate(Arc::new(to_delta_predicate(filters)?)) + .build()?, + )) + } + + async fn execution_plan( + &self, + session: &dyn Session, + scan: Arc<Scan>, + stream: Pin<Box<dyn Stream<Item = Result<ScanMetadata, DeltaTableError>> + Send>>, + engine: Arc<dyn Engine>, + projection: Option<&Vec<usize>>, + limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + // let (_, projected_kernel_schema) = project_schema(self.read_schema(), projection)?; + + let mut stream = ScanFileStream::new(engine, &scan, stream); + let mut files = Vec::new(); + while let Some(file) = stream.try_next().await? { + files.push(file); + } + + let transforms: HashMap<_, _> = files + .iter() + .flatten() + .flat_map(|file| { + file.transform + .as_ref() + .map(|t| (file.file_url.to_string(), t.clone())) + }) + .collect(); + let dv_stream = stream.dv_stream.build(); + let dvs: HashMap<_, _> = dv_stream + .try_filter_map(|(url, dv)| ready(Ok(dv.map(|dv| (url.to_string(), dv))))) + .try_collect() + .await?; + + let metrics = ExecutionPlanMetricsSet::new(); + MetricBuilder::new(&metrics) + .global_counter("count_files_skipped") + .add(stream.metrics.num_skipped); + MetricBuilder::new(&metrics) + .global_counter("count_files_scanned") + .add(stream.metrics.num_scanned); + + let file_id_column = "__delta_rs_file_id".to_string(); + + // Convert the files into DataFusion's `PartitionedFile`s, grouped by the object store they are stored in; + // this is used to create a DataSourceExec plan for each store. + // To correlate the data with the original file, we add the file url as a partition value. + // This is required to apply the correct transform to the data in downstream processing. + let to_partitioned_file = |f: ScanFileContext| { + let file_path = Path::from_url_path(f.file_url.path())?; + let mut partitioned_file = PartitionedFile::new(file_path.to_string(), f.size) + .with_statistics(Arc::new(f.stats)); + partitioned_file.partition_values = + vec![ScalarValue::Utf8(Some(f.file_url.to_string()))]; + // NB: we need to reassign the location since the 'new' method does incompatible path encoding internally. + partitioned_file.object_meta.location = file_path; + Ok::<_, DataFusionError>(( + f.file_url.as_object_store_url(), + (partitioned_file, None::<Vec<ScalarValue>>), + )) + }; + + let files_by_store = files + .into_iter() + .flat_map(|fs| fs.into_iter().map(to_partitioned_file)) + .try_collect::<_, Vec<_>, _>() + .map_err(|e| DataFusionError::External(Box::new(e)))?
+ .into_iter() + .into_group_map(); + + let physical_schema = Arc::new(scan.physical_schema().as_ref().try_into_arrow()?); + let pq_plan = get_read_plan( + files_by_store, + &physical_schema, + session, + limit, + Field::new( + file_id_column.clone(), + DataType::Dictionary(DataType::UInt16.into(), DataType::Utf8.into()), + false, + ), + &metrics, + ) + .await?; + + let (projected_arrow_schema, _) = project_schema(self.read_schema(), projection)?; + let exec = DeltaScanExec::new( + projected_arrow_schema, + scan.logical_schema().clone(), + pq_plan, + Arc::new(transforms), + Arc::new(dvs), + file_id_column, + metrics, + ); + + Ok(Arc::new(exec)) + } +} + +#[async_trait::async_trait] +impl TableProvider for Snapshot { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.read_schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn get_table_definition(&self) -> Option<&str> { + None + } + + fn get_logical_plan(&self) -> Option> { + None + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let scan = self.kernel_scan(projection, filters)?; + let engine = DataFusionEngine::new_from_session(session); + let stream = scan.scan_metadata(engine.clone()); + self.execution_plan(session, scan, stream, engine, projection, limit) + .await + } + + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> Result> { + Ok(get_pushdown_filters( + filter, + self.metadata().partition_columns(), + )) + } + + /// Insert the data into the delta table + /// Insert operation is only supported for Append and Overwrite + /// Return the execution plan + async fn insert_into( + &self, + state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> Result> { + todo!("Implement insert_into method") + } +} + +#[async_trait::async_trait] +impl TableProvider for EagerSnapshot { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + TableProvider::schema(self.snapshot()) + } + + fn table_type(&self) -> TableType { + self.snapshot().table_type() + } + + fn get_table_definition(&self) -> Option<&str> { + self.snapshot().get_table_definition() + } + + fn get_logical_plan(&self) -> Option> { + self.snapshot().get_logical_plan() + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let scan = self.snapshot().kernel_scan(projection, filters)?; + let engine = DataFusionEngine::new_from_session(session); + let stream = if let Ok(files) = self.files() { + scan.scan_metadata_from( + engine.clone(), + self.snapshot().version() as u64, + Box::new(files.to_vec().into_iter()), + None, + ) + } else { + scan.scan_metadata(engine.clone()) + }; + self.snapshot() + .execution_plan(session, scan, stream, engine, projection, limit) + .await + } + + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> Result> { + self.snapshot().supports_filters_pushdown(filter) + } + + /// Insert the data into the delta table + /// Insert operation is only supported for Append and Overwrite + /// Return the execution plan + async fn insert_into( + &self, + state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> Result> { + self.snapshot().insert_into(state, input, insert_op).await + } +} + +async fn get_read_plan( + files_by_store: impl IntoIterator< + Item = (ObjectStoreUrl, Vec<(PartitionedFile, Option>)>), + >, + physical_schema: &SchemaRef, + state: &dyn Session, + 
limit: Option, + file_id_field: Field, + _metrics: &ExecutionPlanMetricsSet, +) -> Result> { + // TODO: update parquet source. + let source = ParquetSource::default(); + + let mut plans = Vec::new(); + + for (store_url, files) in files_by_store.into_iter() { + // state.ensure_object_store(store_url.as_ref()).await?; + + let store = state.runtime_env().object_store(&store_url)?; + let _reader_factory = source + .parquet_file_reader_factory() + .cloned() + .unwrap_or_else(|| Arc::new(DefaultParquetFileReaderFactory::new(store))); + + // let file_group = compute_parquet_access_plans(&reader_factory, files, &metrics).await?; + let file_group = files.into_iter().map(|file| file.0); + + // TODO: convert passed predicate to an expression in terms of physical columns + // and add it to the FileScanConfig + // let file_source = + // source.with_schema_adapter_factory(Arc::new(NestedSchemaAdapterFactory))?; + let file_source = Arc::new(source.clone()); + let config = FileScanConfigBuilder::new(store_url, physical_schema.clone(), file_source) + .with_file_group(file_group.into_iter().collect()) + .with_table_partition_cols(vec![file_id_field.clone()]) + .with_limit(limit) + .build(); + let plan: Arc = DataSourceExec::from_data_source(config); + plans.push(plan); + } + + let plan = match plans.len() { + 1 => plans.remove(0), + _ => Arc::new(UnionExec::new(plans)), + }; + Ok(match plan.with_fetch(limit) { + Some(limit) => limit, + None => plan, + }) +} + +fn project_schema( + schema: SchemaRef, + projection: Option<&Vec>, +) -> Result<(SchemaRef, KernelSchemaRef)> { + let projected_arrow_schema = match projection { + Some(p) => Arc::new(schema.project(p)?), + None => schema, + }; + let projected_kernel_schema: KernelSchemaRef = Arc::new( + projected_arrow_schema + .as_ref() + .try_into_kernel() + .map_err(DeltaTableError::from)?, + ); + Ok((projected_arrow_schema, projected_kernel_schema)) +} + +#[cfg(test)] +mod tests { + use datafusion::{ + datasource::{physical_plan::FileScanConfig, source::DataSource}, + physical_plan::{collect_partitioned, visit_execution_plan, ExecutionPlanVisitor}, + }; + + use crate::{ + assert_batches_sorted_eq, + delta_datafusion::create_test_session, + kernel::Snapshot, + test_utils::{TestResult, TestTables}, + }; + + use super::*; + + /// Extracts fields from the parquet scan + #[derive(Default)] + struct DeltaScanVisitor { + num_skipped: Option, + num_scanned: Option, + total_bytes_scanned: Option, + } + + impl DeltaScanVisitor { + fn pre_visit_delta_scan( + &mut self, + delta_scan_exec: &DeltaScanExec, + ) -> Result { + let Some(metrics) = delta_scan_exec.metrics() else { + return Ok(true); + }; + + self.num_skipped = metrics + .sum_by_name("count_files_skipped") + .map(|v| v.as_usize()); + self.num_scanned = metrics + .sum_by_name("count_files_scanned") + .map(|v| v.as_usize()); + + Ok(true) + } + + fn pre_visit_data_source( + &mut self, + datasource_exec: &DataSourceExec, + ) -> Result { + let Some(scan_config) = datasource_exec + .data_source() + .as_any() + .downcast_ref::() + else { + return Ok(true); + }; + + let pq_metrics = scan_config + .metrics() + .clone_inner() + .sum_by_name("bytes_scanned"); + self.total_bytes_scanned = pq_metrics.map(|v| v.as_usize()); + + // if let Some(parquet_source) = scan_config + // .file_source + // .as_any() + // .downcast_ref::() + // { + // parquet_source + // } + + Ok(true) + } + } + + impl ExecutionPlanVisitor for DeltaScanVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { 
+ if let Some(delta_scan_exec) = plan.as_any().downcast_ref::() { + return self.pre_visit_delta_scan(delta_scan_exec); + }; + + if let Some(datasource_exec) = plan.as_any().downcast_ref::() { + return self.pre_visit_data_source(datasource_exec); + } + + Ok(true) + } + } + + #[tokio::test] + async fn test_query_simple_table() -> TestResult { + let log_store = TestTables::Simple.table_builder()?.build_storage()?; + let snapshot = Arc::new(Snapshot::try_new(&log_store, Default::default(), None).await?); + + let session = Arc::new(create_test_session().into_inner()); + session.register_table("delta_table", snapshot).unwrap(); + + let df = session.sql("SELECT * FROM delta_table").await.unwrap(); + let batches = df.collect().await?; + + let expected = vec![ + "+----+", "| id |", "+----+", "| 5 |", "| 7 |", "| 9 |", "+----+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn test_scan_simple_table() -> TestResult { + let log_store = TestTables::Simple.table_builder()?.build_storage()?; + let snapshot = Snapshot::try_new(&log_store, Default::default(), None).await?; + + let session = Arc::new(create_test_session().into_inner()); + let state = session.state_ref().read().clone(); + + let plan = snapshot.scan(&state, None, &[], None).await?; + + let batches: Vec<_> = collect_partitioned(plan.clone(), session.task_ctx()) + .await? + .into_iter() + .flatten() + .collect(); + + let mut visitor = DeltaScanVisitor::default(); + visit_execution_plan(plan.as_ref(), &mut visitor).unwrap(); + + assert_eq!(visitor.num_scanned, Some(5)); + assert_eq!(visitor.num_skipped, Some(28)); + assert_eq!(visitor.total_bytes_scanned, Some(231)); + + let expected = vec![ + "+----+", "| id |", "+----+", "| 5 |", "| 7 |", "| 9 |", "+----+", + ]; + + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) + } +} diff --git a/crates/core/src/delta_datafusion/table_provider/next/replay.rs b/crates/core/src/delta_datafusion/table_provider/next/replay.rs new file mode 100644 index 0000000000..c66872aa4f --- /dev/null +++ b/crates/core/src/delta_datafusion/table_provider/next/replay.rs @@ -0,0 +1,354 @@ +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::array::BooleanArray; +use arrow::compute::filter_record_batch; +use datafusion::{ + common::{ + error::DataFusionErrorBuilder, stats::Precision, ColumnStatistics, HashMap, Statistics, + }, + error::DataFusionError, + scalar::ScalarValue, +}; +use delta_kernel::{ + engine::{arrow_conversion::TryIntoArrow, arrow_data::ArrowEngineData}, + expressions::{Scalar, StructData}, + scan::{ + state::{DvInfo, Stats}, + Scan as KernelScan, ScanMetadata, + }, + Engine, ExpressionRef, +}; +use futures::Stream; +use itertools::Itertools; +use pin_project_lite::pin_project; +use url::Url; + +use crate::{ + delta_datafusion::engine::scalar_to_df, + kernel::{ + arrow::engine_ext::stats_schema, parse_stats_column_with_schema, LogicalFileView, + ReceiverStreamBuilder, Scan, StructDataExt, + }, + DeltaResult, +}; + +#[derive(Debug)] +pub(crate) struct ReplayStats { + pub(crate) num_skipped: usize, + pub(crate) num_scanned: usize, +} + +impl ReplayStats { + fn new() -> Self { + Self { + num_skipped: 0, + num_scanned: 0, + } + } +} + +pin_project! 
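
For orientation, here is a minimal, self-contained usage sketch (not part of the diff) of how the new `TableProvider` implementation is meant to be consumed: a `Snapshot` is built for a table location and registered with a DataFusion `SessionContext`, after which plain SQL drives `scan` with projection, filter, and limit pushdown. The table path is a placeholder, and errors are unwrapped for brevity.

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use deltalake_core::delta_datafusion::engine::DataFusionEngine;
use deltalake_core::kernel::Snapshot;
use url::Url;

#[tokio::main]
async fn main() {
    // Placeholder path: any local Delta table directory works here.
    let path = std::fs::canonicalize("./path/to/delta/table").unwrap();
    let table_url = Url::from_directory_path(path).unwrap();

    let session = SessionContext::new();
    let engine = DataFusionEngine::new_from_session(&session.state());

    // `Snapshot` implements `TableProvider`, so it registers like any table.
    let snapshot = Snapshot::try_new_with_engine(engine, table_url, Default::default(), None)
        .await
        .unwrap();
    session
        .register_table("delta_table", Arc::new(snapshot))
        .unwrap();

    // Projection, filters, and limit are pushed down into the kernel scan.
    let df = session
        .sql("SELECT * FROM delta_table LIMIT 10")
        .await
        .unwrap();
    df.show().await.unwrap();
}
```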
diff --git a/crates/core/src/delta_datafusion/table_provider/next/replay.rs b/crates/core/src/delta_datafusion/table_provider/next/replay.rs
new file mode 100644
index 0000000000..c66872aa4f
--- /dev/null
+++ b/crates/core/src/delta_datafusion/table_provider/next/replay.rs
@@ -0,0 +1,354 @@
+use std::{
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+use arrow::array::BooleanArray;
+use arrow::compute::filter_record_batch;
+use datafusion::{
+    common::{
+        error::DataFusionErrorBuilder, stats::Precision, ColumnStatistics, HashMap, Statistics,
+    },
+    error::DataFusionError,
+    scalar::ScalarValue,
+};
+use delta_kernel::{
+    engine::{arrow_conversion::TryIntoArrow, arrow_data::ArrowEngineData},
+    expressions::{Scalar, StructData},
+    scan::{
+        state::{DvInfo, Stats},
+        Scan as KernelScan, ScanMetadata,
+    },
+    Engine, ExpressionRef,
+};
+use futures::Stream;
+use itertools::Itertools;
+use pin_project_lite::pin_project;
+use url::Url;
+
+use crate::{
+    delta_datafusion::engine::scalar_to_df,
+    kernel::{
+        arrow::engine_ext::stats_schema, parse_stats_column_with_schema, LogicalFileView,
+        ReceiverStreamBuilder, Scan, StructDataExt,
+    },
+    DeltaResult,
+};
+
+#[derive(Debug)]
+pub(crate) struct ReplayStats {
+    pub(crate) num_skipped: usize,
+    pub(crate) num_scanned: usize,
+}
+
+impl ReplayStats {
+    fn new() -> Self {
+        Self {
+            num_skipped: 0,
+            num_scanned: 0,
+        }
+    }
+}
+
+pin_project! {
+    pub(crate) struct ScanFileStream<S> {
+        pub(crate) metrics: ReplayStats,
+
+        engine: Arc<dyn Engine>,
+
+        table_root: Url,
+
+        kernel_scan: Arc<KernelScan>,
+
+        pub(crate) dv_stream: ReceiverStreamBuilder<(Url, Option<Vec<bool>>)>,
+
+        #[pin]
+        stream: S,
+    }
+}
+
+impl<S> ScanFileStream<S> {
+    pub(crate) fn new(engine: Arc<dyn Engine>, scan: &Arc<Scan>, stream: S) -> Self {
+        Self {
+            metrics: ReplayStats::new(),
+            dv_stream: ReceiverStreamBuilder::<(Url, Option<Vec<bool>>)>::new(100),
+            engine,
+            table_root: scan.table_root().clone(),
+            kernel_scan: scan.inner.clone(),
+            stream,
+        }
+    }
+}
+
+impl<S> Stream for ScanFileStream<S>
+where
+    S: Stream<Item = DeltaResult<ScanMetadata>>,
+{
+    type Item = DeltaResult<Vec<ScanFileContext>>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+        let physical_arrow = this
+            .kernel_scan
+            .physical_schema()
+            .as_ref()
+            .try_into_arrow()
+            .unwrap();
+        match this.stream.poll_next(cx) {
+            Poll::Ready(Some(Ok(scan_data))) => {
+                let mut ctx = ScanContext::new(this.table_root.clone());
+                ctx = match scan_data.visit_scan_files(ctx, visit_scan_file) {
+                    Ok(ctx) => ctx,
+                    Err(err) => return Poll::Ready(Some(Err(err.into()))),
+                };
+
+                // Spawn tasks to read the deletion vectors from disk.
+                for file in &ctx.files {
+                    let engine = this.engine.clone();
+                    let dv_info = file.dv_info.clone();
+                    let file_url = file.file_url.clone();
+                    let table_root = this.table_root.clone();
+                    let tx = this.dv_stream.tx();
+                    if dv_info.has_vector() {
+                        let load_dv = move || {
+                            let dv = dv_info.get_selection_vector(engine.as_ref(), &table_root)?;
+                            let _ = tx.blocking_send(Ok((file_url, dv)));
+                            Ok(())
+                        };
+                        this.dv_stream.spawn_blocking(load_dv);
+                    }
+                }
+
+                this.metrics.num_scanned += ctx.count;
+                this.metrics.num_skipped += scan_data
+                    .scan_files
+                    .selection_vector
+                    .len()
+                    .saturating_sub(ctx.count);
+
+                let batch =
+                    ArrowEngineData::try_from_engine_data(scan_data.scan_files.data)?.into();
+                let scan_files = filter_record_batch(
+                    &batch,
+                    &BooleanArray::from(scan_data.scan_files.selection_vector),
+                )?;
+
+                let stats_schema = Arc::new(stats_schema(
+                    this.kernel_scan.physical_schema(),
+                    this.kernel_scan.snapshot().table_properties(),
+                ));
+                let parsed_stats = parse_stats_column_with_schema(
+                    this.kernel_scan.snapshot().as_ref(),
+                    &scan_files,
+                    stats_schema,
+                )?;
+
+                let mut file_statistics = extract_file_statistics(&this.kernel_scan, parsed_stats);
+
+                Poll::Ready(Some(Ok(ctx
+                    .files
+                    .into_iter()
+                    .map(|ctx| {
+                        let stats = file_statistics
+                            .remove(&ctx.file_url)
+                            .unwrap_or_else(|| Statistics::new_unknown(&physical_arrow));
+                        ScanFileContext::new(ctx, stats)
+                    })
+                    .collect_vec())))
+            }
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.stream.size_hint()
+    }
+}
+
+fn extract_file_statistics(
+    scan: &KernelScan,
+    parsed_stats: arrow_array::RecordBatch,
+) -> HashMap<Url, Statistics> {
+    (0..parsed_stats.num_rows())
+        .map(move |idx| LogicalFileView::new(parsed_stats.clone(), idx))
+        .filter_map(|view| {
+            let num_rows = view
+                .num_records()
+                .map(|num| Precision::Exact(num))
+                .unwrap_or(Precision::Absent);
+            let total_byte_size = Precision::Exact(view.size() as usize);
+
+            let null_counts = extract_struct(view.null_counts());
+            let max_values = extract_struct(view.max_values());
+            let min_values = extract_struct(view.min_values());
+
+            let column_statistics = scan
+                .physical_schema()
+                .fields()
+                .map(|f| {
+                    let null_count = if let Some(field_index) =
+                        null_counts.as_ref().and_then(|v| v.index_of(f.name()))
+                    {
+                        null_counts
+                            .as_ref()
+                            .map(|v| match v.values()[field_index] {
+                                Scalar::Integer(int_val) => Precision::Exact(int_val as usize),
+                                Scalar::Long(long_val) => Precision::Exact(long_val as usize),
+                                _ => Precision::Absent,
+                            })
+                            .unwrap_or_default()
+                    } else {
+                        Precision::Absent
+                    };
+
+                    let max_value = extract_precision(&max_values, f.name());
+                    let min_value = extract_precision(&min_values, f.name());
+
+                    ColumnStatistics {
+                        null_count,
+                        max_value,
+                        min_value,
+                        sum_value: Precision::Absent,
+                        distinct_count: Precision::Absent,
+                    }
+                })
+                .collect_vec();
+
+            Some((
+                parse_path(&scan.snapshot().table_root(), view.path().as_ref()).ok()?,
+                Statistics {
+                    num_rows,
+                    total_byte_size,
+                    column_statistics,
+                },
+            ))
+        })
+        .collect()
+}
+
+fn extract_precision(data: &Option<StructData>, name: impl AsRef<str>) -> Precision<ScalarValue> {
+    if let Some(field_index) = data.as_ref().and_then(|v| v.index_of(name.as_ref())) {
+        data.as_ref()
+            .map(|v| match scalar_to_df(&v.values()[field_index]) {
+                Ok(df) => Precision::Exact(df),
+                _ => Precision::Absent,
+            })
+            .unwrap_or_default()
+    } else {
+        Precision::Absent
+    }
+}
+
+fn extract_struct(scalar: Option<Scalar>) -> Option<StructData> {
+    match scalar {
+        Some(Scalar::Struct(data)) => Some(data),
+        _ => None,
+    }
+}
+
+#[derive(Debug)]
+pub struct ScanFileContext {
+    /// Fully qualified URL of the file.
+    pub file_url: Url,
+    /// Size of the file on disk.
+    pub size: u64,
+    /// Selection vector to filter the data in the file.
+    // pub selection_vector: Option<Vec<bool>>,
+    /// Transformations to apply to the data in the file.
+    pub transform: Option<ExpressionRef>,
+    /// Statistics about the data in the file.
+    ///
+    /// The query engine may choose to use these statistics to further optimize the scan.
+    pub stats: Statistics,
+}
+
+impl ScanFileContext {
+    /// Create a new `ScanFileContext` with the given file URL, size, and statistics.
+    fn new(inner: ScanFileContextInner, stats: Statistics) -> Self {
+        Self {
+            file_url: inner.file_url,
+            size: inner.size,
+            transform: inner.transform,
+            stats,
+        }
+    }
+}
+
+/// Metadata to read a data file from object storage.
+struct ScanFileContextInner {
+    /// Fully qualified URL of the file.
+    pub file_url: Url,
+    /// Size of the file on disk.
+    pub size: u64,
+    /// Selection vector to filter the data in the file.
+    // pub selection_vector: Option<Vec<bool>>,
+    /// Transformations to apply to the data in the file.
+    pub transform: Option<ExpressionRef>,
+
+    pub dv_info: DvInfo,
+}
+
+struct ScanContext {
+    /// Table root URL
+    table_root: Url,
+    /// Files to be scanned.
+    files: Vec<ScanFileContextInner>,
+    /// Errors encountered during the scan.
+    errs: DataFusionErrorBuilder,
+    count: usize,
+}
+
+impl ScanContext {
+    fn new(table_root: Url) -> Self {
+        Self {
+            table_root,
+            files: Vec::new(),
+            errs: DataFusionErrorBuilder::new(),
+            count: 0,
+        }
+    }
+
+    fn parse_path(&self, path: &str) -> DeltaResult<Url> {
+        parse_path(&self.table_root, path)
+    }
+}
+
+fn parse_path(url: &Url, path: &str) -> DeltaResult<Url> {
+    Ok(match Url::parse(path) {
+        Ok(url) => url,
+        Err(_) => url
+            .join(path)
+            .map_err(|e| DataFusionError::External(Box::new(e)))?,
+    })
+}
+
+fn visit_scan_file(
+    ctx: &mut ScanContext,
+    path: &str,
+    size: i64,
+    _stats: Option<Stats>,
+    dv_info: DvInfo,
+    transform: Option<ExpressionRef>,
+    // NB: partition values are passed for backwards compatibility;
+    // all required transformations are now part of the transform field
+    _: std::collections::HashMap<String, String>,
+) {
+    let file_url = match ctx.parse_path(path) {
+        Ok(v) => v,
+        Err(e) => {
+            ctx.errs.add_error(e);
+            return;
+        }
+    };
+
+    ctx.files.push(ScanFileContextInner {
+        dv_info,
+        transform,
+        file_url,
+        size: size as u64,
+    });
+    ctx.count += 1;
+}
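
A note on `parse_path` above: add actions in the Delta log may carry either a relative path or a fully qualified URL, so the helper first tries `Url::parse` and only falls back to joining against the table root. A standalone sketch of that behavior (bucket and file names are made up):

```rust
use url::Url;

fn main() {
    let table_root = Url::parse("s3://bucket/table/").unwrap();

    // Relative paths, as typically recorded in the log, fail `Url::parse`
    // and are resolved against the table root instead.
    let relative = "part-00000-abc.snappy.parquet";
    assert!(Url::parse(relative).is_err());
    assert_eq!(
        table_root.join(relative).unwrap().as_str(),
        "s3://bucket/table/part-00000-abc.snappy.parquet"
    );

    // Fully qualified URLs parse directly and are used verbatim, which also
    // covers files living outside the table root.
    let absolute = "s3://other-bucket/data/part-00001-def.snappy.parquet";
    assert_eq!(Url::parse(absolute).unwrap().as_str(), absolute);
}
```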
diff --git a/crates/core/src/delta_datafusion/table_provider/next/scan.rs b/crates/core/src/delta_datafusion/table_provider/next/scan.rs
new file mode 100644
index 0000000000..3d5ae7fa6d
--- /dev/null
+++ b/crates/core/src/delta_datafusion/table_provider/next/scan.rs
@@ -0,0 +1,295 @@
+use std::any::Any;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::{ArrayAccessor, AsArray, RecordBatch, StringArray};
+use arrow::datatypes::{SchemaRef, UInt16Type};
+use datafusion::common::config::ConfigOptions;
+use datafusion::common::error::{DataFusionError, Result};
+use datafusion::common::HashMap;
+use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
+use datafusion::physical_plan::execution_plan::{CardinalityEffect, PlanProperties};
+use datafusion::physical_plan::filter_pushdown::{FilterDescription, FilterPushdownPhase};
+use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PhysicalExpr, Statistics,
+};
+use delta_kernel::engine::arrow_conversion::TryIntoKernel;
+use delta_kernel::schema::SchemaRef as KernelSchemaRef;
+use delta_kernel::{EvaluationHandler, ExpressionRef};
+use futures::stream::{Stream, StreamExt};
+
+use crate::kernel::arrow::engine_ext::ExpressionEvaluatorExt;
+use crate::kernel::ARROW_HANDLER;
+
+#[derive(Clone, Debug)]
+pub struct DeltaScanExec {
+    /// Output schema for processed data.
+    logical_schema: SchemaRef,
+    kernel_logical_schema: KernelSchemaRef,
+    /// Execution plan yielding the raw data read from data files.
+    input: Arc<dyn ExecutionPlan>,
+    /// Transforms to be applied to data emanating from individual files
+    transforms: Arc<HashMap<String, ExpressionRef>>,
+    /// Deletion vectors for the table
+    selection_vectors: Arc<HashMap<String, Vec<bool>>>,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+    file_id_column: String,
+}
+
+impl DisplayAs for DeltaScanExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        // TODO: actually implement formatting according to the type
+        match t {
+            DisplayFormatType::Default
+            | DisplayFormatType::Verbose
+            | DisplayFormatType::TreeRender => {
+                write!(f, "DeltaScanExec: ")
+            }
+        }
+    }
+}
+
+impl DeltaScanExec {
+    pub(crate) fn new(
+        logical_schema: SchemaRef,
+        kernel_logical_schema: KernelSchemaRef,
+        input: Arc<dyn ExecutionPlan>,
+        transforms: Arc<HashMap<String, ExpressionRef>>,
+        selection_vectors: Arc<HashMap<String, Vec<bool>>>,
+        file_id_column: String,
+        metrics: ExecutionPlanMetricsSet,
+    ) -> Self {
+        Self {
+            logical_schema,
+            kernel_logical_schema,
+            input,
+            transforms,
+            selection_vectors,
+            metrics,
+            file_id_column,
+        }
+    }
+}
+
+impl ExecutionPlan for DeltaScanExec {
+    fn name(&self) -> &'static str {
+        "DeltaScanExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        self.input.properties()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    // fn maintains_input_order(&self) -> Vec<bool> {
+    //     // Tell optimizer this operator doesn't reorder its input
+    //     vec![true]
+    // }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.len() != 1 {
+            return Err(DataFusionError::Plan(format!(
+                "DeltaScan: wrong number of children {}",
+                children.len()
+            )));
+        }
+        Ok(Arc::new(Self {
+            logical_schema: self.logical_schema.clone(),
+            kernel_logical_schema: self.kernel_logical_schema.clone(),
+            input: children[0].clone(),
+            transforms: self.transforms.clone(),
+            selection_vectors: self.selection_vectors.clone(),
+            metrics: self.metrics.clone(),
+            file_id_column: self.file_id_column.clone(),
+        }))
+    }
+
+    fn repartitioned(
+        &self,
+        target_partitions: usize,
+        config: &ConfigOptions,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        if let Some(new_pq) = self.input.repartitioned(target_partitions, config)? {
+            let mut new_plan = self.clone();
+            new_plan.input = new_pq;
+            Ok(Some(Arc::new(new_plan)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(DeltaScanStream {
+            schema: Arc::clone(&self.logical_schema),
+            kernel_schema: Arc::clone(&self.kernel_logical_schema),
+            input: self.input.execute(partition, context)?,
+            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
+            transforms: Arc::clone(&self.transforms),
+            selection_vectors: Arc::clone(&self.selection_vectors),
+            file_id_column: self.file_id_column.clone(),
+        }))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        self.input.partition_statistics(None)
+    }
+
+    fn supports_limit_pushdown(&self) -> bool {
+        true
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
+
+    fn fetch(&self) -> Option<usize> {
+        self.input.fetch()
+    }
+
+    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
+        if let Some(new_input) = self.input.with_fetch(limit) {
+            let mut new_plan = self.clone();
+            new_plan.input = new_input;
+            Some(Arc::new(new_plan))
+        } else {
+            None
+        }
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        self.input.partition_statistics(partition)
+    }
+
+    fn gather_filters_for_pushdown(
+        &self,
+        _phase: FilterPushdownPhase,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        _config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        // TODO(roeap): this will likely not do much for tables with column mapping enabled,
+        // since the default method determines this based on the existence of columns in child
+        // schemas. In the case of column mapping all columns will have a different name.
+        FilterDescription::from_children(parent_filters, &self.children())
+    }
+}
+
+/// Stream of RecordBatches read from the delta table.
+///
+/// The data returned by this stream represents the logical data contained in the table.
+/// This means all transformations according to the delta protocol are applied.
+struct DeltaScanStream {
+    schema: SchemaRef,
+    kernel_schema: KernelSchemaRef,
+    input: SendableRecordBatchStream,
+    baseline_metrics: BaselineMetrics,
+    /// Transforms to be applied to data read from individual files
+    transforms: Arc<HashMap<String, ExpressionRef>>,
+    /// Selection vectors to be applied to data read from individual files
+    selection_vectors: Arc<HashMap<String, Vec<bool>>>,
+    /// Column name for the file id
+    file_id_column: String,
+}
+
+impl DeltaScanStream {
+    fn batch_project(&self, mut batch: RecordBatch) -> Result<RecordBatch> {
+        // Records time on drop
+        let _timer = self.baseline_metrics.elapsed_compute().timer();
+
+        let (file_id, file_id_idx) = extract_file_id(&batch, &self.file_id_column)?;
+        batch.remove_column(file_id_idx);
+
+        let Some(transform) = self.transforms.get(&file_id) else {
+            let batch = RecordBatch::try_new(self.schema.clone(), batch.columns().to_vec())?;
+            return Ok(batch);
+        };
+
+        let input_schema = Arc::new(
+            batch
+                .schema()
+                .try_into_kernel()
+                .map_err(|e| DataFusionError::External(Box::new(e)))?,
+        );
+        let evaluator = ARROW_HANDLER.new_expression_evaluator(
+            input_schema,
+            transform.clone(),
+            self.kernel_schema.clone().into(),
+        );
+
+        let result = evaluator
+            .evaluate_arrow(batch)
+            .map_err(|e| DataFusionError::External(Box::new(e)))?;
+
+        Ok(result)
+    }
+}
+
+impl Stream for DeltaScanStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let poll = self.input.poll_next_unpin(cx).map(|x| match x {
+            Some(Ok(batch)) => Some(self.batch_project(batch)),
+            other => other,
+        });
+        self.baseline_metrics.record_poll(poll)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.input.size_hint()
+    }
+}
+
+impl RecordBatchStream for DeltaScanStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+fn extract_file_id(batch: &RecordBatch, file_id_column: &str) -> Result<(String, usize)> {
+    let file_id_idx = batch
+        .schema_ref()
+        .fields()
+        .iter()
+        .position(|f| f.name() == file_id_column)
+        .ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "Expected column '{}' to be present in the input",
+                file_id_column
+            ))
+        })?;
+
+    let file_id = batch
+        .column(file_id_idx)
+        .as_dictionary::<UInt16Type>()
+        .downcast_dict::<StringArray>()
+        .ok_or_else(|| {
+            DataFusionError::Internal(
+                "Expected file id column to be a dictionary of strings".to_string(),
+            )
+        })?
+        .value(0)
+        .to_string();
+
+    Ok((file_id, file_id_idx))
+}
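
The `__delta_rs_file_id` partition column that `extract_file_id` reads back is dictionary encoded, so carrying the file url on every row costs one `u16` key per row plus a single string value per batch. A standalone sketch (file url made up) of the shape the code expects:

```rust
use std::sync::Arc;

use arrow::array::{ArrayAccessor, DictionaryArray, StringArray, UInt16Array};
use arrow::datatypes::UInt16Type;

fn main() {
    // Every row of a batch points at the same single file url value.
    let values = StringArray::from(vec!["s3://bucket/table/part-00000.parquet"]);
    let keys = UInt16Array::from(vec![0u16; 4]);
    let file_ids = DictionaryArray::<UInt16Type>::try_new(keys, Arc::new(values)).unwrap();

    // Mirrors the `downcast_dict` + `value(0)` access used in `extract_file_id`.
    let dict = file_ids.downcast_dict::<StringArray>().unwrap();
    assert_eq!(dict.value(0), "s3://bucket/table/part-00000.parquet");
}
```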
diff --git a/crates/core/src/kernel/snapshot/iterators.rs b/crates/core/src/kernel/snapshot/iterators.rs
index cc1cce66bd..3d291a850e 100644
--- a/crates/core/src/kernel/snapshot/iterators.rs
+++ b/crates/core/src/kernel/snapshot/iterators.rs
@@ -20,6 +20,8 @@ use crate::kernel::scalars::ScalarExt;
 use crate::kernel::{Add, DeletionVectorDescriptor, Remove};
 use crate::{DeltaResult, DeltaTableError};
 
+#[cfg(feature = "datafusion")]
+pub(crate) use self::scan_row::parse_stats_column_with_schema;
 pub(crate) use self::scan_row::{scan_row_in_eval, ScanRowOutStream};
 pub use self::tombstones::TombstoneView;
 
diff --git a/crates/core/src/kernel/snapshot/iterators/scan_row.rs b/crates/core/src/kernel/snapshot/iterators/scan_row.rs
index d897d49aa5..c21ff1c4b0 100644
--- a/crates/core/src/kernel/snapshot/iterators/scan_row.rs
+++ b/crates/core/src/kernel/snapshot/iterators/scan_row.rs
@@ -13,8 +13,8 @@ use delta_kernel::engine::parse_json;
 use delta_kernel::expressions::Scalar;
 use delta_kernel::expressions::UnaryExpressionOp;
 use delta_kernel::scan::scan_row_schema;
-use delta_kernel::schema::DataType;
 use delta_kernel::schema::PrimitiveType;
+use delta_kernel::schema::{DataType, SchemaRef as KernelSchemaRef};
 use delta_kernel::snapshot::Snapshot as KernelSnapshot;
 use delta_kernel::{EvaluationHandler, Expression, ExpressionEvaluator};
 use futures::Stream;
@@ -89,7 +89,19 @@ pub(crate) fn scan_row_in_eval(
     Ok(ARROW_HANDLER.new_expression_evaluator(input_schema, EXPRESSION.clone(), OUT_TYPE.clone()))
 }
 
-fn parse_stats_column(sn: &KernelSnapshot, batch: &RecordBatch) -> DeltaResult<RecordBatch> {
+pub(crate) fn parse_stats_column(
+    sn: &KernelSnapshot,
+    batch: &RecordBatch,
+) -> DeltaResult<RecordBatch> {
+    let stats_schema = sn.stats_schema()?;
+    parse_stats_column_with_schema(sn, batch, stats_schema)
+}
+
+pub(crate) fn parse_stats_column_with_schema(
+    sn: &KernelSnapshot,
+    batch: &RecordBatch,
+    stats_schema: KernelSchemaRef,
+) -> DeltaResult<RecordBatch> {
     let Some((stats_idx, _)) = batch.schema_ref().column_with_name("stats") else {
         return Err(DeltaTableError::SchemaMismatch {
             msg: "stats column not found".to_string(),
@@ -99,7 +111,6 @@ fn parse_stats_column(sn: &KernelSnapshot, batch: &RecordBatch) -> DeltaResult<
 
-    inner: Arc<KernelScan>,
+    pub(crate) inner: Arc<KernelScan>,
 }
 
 impl From<KernelScan> for Scan {
diff --git a/crates/core/src/kernel/snapshot/serde.rs b/crates/core/src/kernel/snapshot/serde.rs
index a644a7f17d..b8aa513af0 100644
--- a/crates/core/src/kernel/snapshot/serde.rs
+++ b/crates/core/src/kernel/snapshot/serde.rs
@@ -51,6 +51,12 @@ impl Serialize for Snapshot {
             .latest_crc_file
             .as_ref()
             .map(|f| FileMetaSerde::from(&f.location));
+        // let latest_commit_file = self
+        //     .inner
+        //     .log_segment()
+        //     .latest_commit_file
+        //     .as_ref()
+        //     .map(|f| FileMetaSerde::from(&f.location));
 
         let mut seq = serializer.serialize_seq(None)?;
 
@@ -62,6 +68,7 @@ impl Serialize for Snapshot {
         seq.serialize_element(&ascending_compaction_files)?;
         seq.serialize_element(&checkpoint_parts)?;
         seq.serialize_element(&latest_crc_file)?;
+        // seq.serialize_element(&latest_commit_file)?;
 
         seq.serialize_element(&self.config)?;
 
@@ -172,11 +179,17 @@ impl<'de> Visitor<'de> for SnapshotVisitor {
             .transpose()?
             .flatten();
 
+        // let latest_commit_file = latest_crc_file
+        //     .map(|meta| ParsedLogPath::try_from(meta.into_kernel()).map_err(de::Error::custom))
+        //     .transpose()?
+        //     .flatten();
+
         let listed_log_files = ListedLogFiles::try_new(
             ascending_commit_files,
             ascending_compaction_files,
             checkpoint_parts,
             latest_crc_file,
+            // latest_commit_file,
         )
         .map_err(de::Error::custom)?;
diff --git a/crates/core/src/operations/load.rs b/crates/core/src/operations/load.rs
index 097ddd3e90..e5af1a9839 100644
--- a/crates/core/src/operations/load.rs
+++ b/crates/core/src/operations/load.rs
@@ -115,11 +115,13 @@ impl std::future::IntoFuture for LoadBuilder {
 
 #[cfg(test)]
 mod tests {
+    use crate::delta_datafusion::create_session;
     use crate::operations::{collect_sendable_stream, DeltaOps};
     use crate::writer::test_utils::{get_record_batch, TestResult};
     use crate::DeltaTableBuilder;
     use datafusion::assert_batches_sorted_eq;
     use std::path::Path;
+    use std::sync::Arc;
     use url::Url;
 
     #[tokio::test]
@@ -155,7 +157,16 @@ mod tests {
         let batch = get_record_batch(None, false);
         let table = DeltaOps::new_in_memory().write(vec![batch.clone()]).await?;
 
-        let (_table, stream) = DeltaOps(table).load().await?;
+        let session = create_session().into_inner();
+        session.runtime_env().register_object_store(
+            &url::Url::parse("memory:///")?,
+            table.log_store().object_store(None),
+        );
+
+        let (_table, stream) = DeltaOps(table)
+            .load()
+            .with_session_state(Arc::new(session.state()))
+            .await?;
         let data = collect_sendable_stream(stream).await?;
 
         let expected = vec![
@@ -186,7 +197,17 @@ mod tests {
         let batch = get_record_batch(None, false);
         let table = DeltaOps::new_in_memory().write(vec![batch.clone()]).await?;
 
-        let (_table, stream) = DeltaOps(table).load().with_columns(["id", "value"]).await?;
+        let session = create_session().into_inner();
+        session.runtime_env().register_object_store(
+            &url::Url::parse("memory:///")?,
+            table.log_store().object_store(None),
+        );
+
+        let (_table, stream) = DeltaOps(table)
+            .load()
+            .with_columns(["id", "value"])
+            .with_session_state(Arc::new(session.state()))
+            .await?;
         let data = collect_sendable_stream(stream).await?;
 
         let expected = vec![
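
The pattern the updated load tests introduce, registering the table's object store with the session runtime before loading, is needed because the new scan resolves file urls through DataFusion's `RuntimeEnv` (see `state.runtime_env().object_store(..)` in `get_read_plan`) rather than through the table instance. A condensed sketch with assumed urls:

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use object_store::memory::InMemory;
use url::Url;

fn main() {
    let ctx = SessionContext::new();
    let store = Arc::new(InMemory::new());

    // After this, scans planned from `ctx` can resolve `memory:///...` file
    // urls against `store`; the same wiring applies to `s3://`, `az://`, etc.
    ctx.runtime_env()
        .register_object_store(&Url::parse("memory:///").unwrap(), store);
}
```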
diff --git a/crates/core/src/operations/vacuum.rs b/crates/core/src/operations/vacuum.rs
index a78542d291..89ba1366a6 100644
--- a/crates/core/src/operations/vacuum.rs
+++ b/crates/core/src/operations/vacuum.rs
@@ -542,13 +542,14 @@ async fn get_stale_files(
 
 #[cfg(test)]
 mod tests {
+    use std::path::Path;
+    use std::{io::Read, time::SystemTime};
+
+    use object_store::{local::LocalFileSystem, memory::InMemory, PutPayload};
+    use url::Url;
 
     use super::*;
     use crate::{checkpoints::create_checkpoint, ensure_table_uri, open_table};
-    use std::path::Path;
-    use std::{io::Read, time::SystemTime};
-    use url::Url;
 
     #[tokio::test]
     async fn test_vacuum_full() -> DeltaResult<()> {
@@ -693,7 +694,7 @@ mod tests {
     async fn test_vacuum_keep_version_validity() {
         use datafusion::prelude::SessionContext;
         use object_store::GetResultPayload;
-        let store = InMemory::new();
+        let store = Arc::new(InMemory::new());
 
         let source = LocalFileSystem::new_with_prefix("../test/tests/data/simple_table").unwrap();
         let mut stream = source.list(None);
@@ -715,7 +716,7 @@ mod tests {
         let table_url = url::Url::parse("memory:///").unwrap();
         let mut table = crate::DeltaTableBuilder::from_uri(table_url.clone())
             .unwrap()
-            .with_storage_backend(Arc::new(store), table_url)
+            .with_storage_backend(store.clone(), table_url)
             .build()
             .unwrap();
         table.load().await.unwrap();
@@ -739,6 +740,8 @@ mod tests {
         assert_eq!(Some(6), table.version());
 
         let ctx = SessionContext::new();
+        ctx.runtime_env()
+            .register_object_store(&url::Url::parse("memory:///").unwrap(), store);
         ctx.register_table("test", Arc::new(table)).unwrap();
         let _batches = ctx
             .sql("SELECT * FROM test")
@@ -838,7 +841,7 @@ mod tests {
     /// This tests the fix for the race condition where concurrent writer's files could be deleted
     #[tokio::test]
     async fn test_vacuum_full_protects_recent_uncommitted_files() -> DeltaResult<()> {
-        use chrono::{DateTime, Utc};
+        use chrono::DateTime;
         use object_store::GetResultPayload;
 
         let store = InMemory::new();
diff --git a/crates/core/tests/read_delta_partitions_test.rs b/crates/core/tests/read_delta_partitions_test.rs
index cdad7ef645..3036a783a5 100644
--- a/crates/core/tests/read_delta_partitions_test.rs
+++ b/crates/core/tests/read_delta_partitions_test.rs
@@ -56,26 +56,18 @@ async fn read_null_partitions_from_checkpoint() {
 #[cfg(feature = "datafusion")]
 #[tokio::test]
 async fn load_from_delta_8_0_table_with_special_partition() {
-    use datafusion::physical_plan::SendableRecordBatchStream;
-    use deltalake_core::{DeltaOps, DeltaTable};
-    use futures::{future, StreamExt};
+    use deltalake_core::DeltaOps;
+    use futures::TryStreamExt;
 
-    let path = "../test/tests/data/delta-0.8.0-special-partition";
-    let table = deltalake_core::open_table(
-        Url::from_directory_path(std::fs::canonicalize(&path).unwrap()).unwrap(),
-    )
-    .await
-    .unwrap();
+    let path = std::fs::canonicalize("../test/tests/data/delta-0.8.0-special-partition").unwrap();
+    let table = deltalake_core::open_table(Url::from_directory_path(path).unwrap())
+        .await
+        .unwrap();
 
-    let (_, stream): (DeltaTable, SendableRecordBatchStream) = DeltaOps(table)
+    let (_, stream) = DeltaOps(table)
         .load()
         .with_columns(vec!["x", "y"])
         .await
         .unwrap();
-    stream
-        .for_each(|batch| {
-            assert!(batch.is_ok());
-            future::ready(())
-        })
-        .await;
+    let _res: Vec<_> = stream.try_collect().await.unwrap();
 }