diff --git a/Cargo.lock b/Cargo.lock index d1742cd998..fe6ca6b9c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -744,6 +744,7 @@ dependencies = [ "arrow-cast", "arrow-json", "arrow-schema", + "criterion 0.7.0", "datafusion", "insta", "serde_json", diff --git a/crates/arrow_tools/Cargo.toml b/crates/arrow_tools/Cargo.toml index c92cf02dda..1e59078f45 100644 --- a/crates/arrow_tools/Cargo.toml +++ b/crates/arrow_tools/Cargo.toml @@ -20,3 +20,8 @@ tracing.workspace = true insta = {workspace = true, features = ["json"]} arrow-json.workspace = true serde_json.workspace = true +criterion = { version = "0.7", features = ["html_reports"] } + +[[bench]] +name = "truncate" +harness = false diff --git a/crates/arrow_tools/benches/truncate.rs b/crates/arrow_tools/benches/truncate.rs new file mode 100644 index 0000000000..7b54f889ac --- /dev/null +++ b/crates/arrow_tools/benches/truncate.rs @@ -0,0 +1,462 @@ +/* +Copyright 2024-2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Criterion benchmarks for the truncation fast-paths and formatting helpers +// in arrow_tools. +// +// Run with: cargo bench -p arrow_tools --bench truncate +// +// The benchmarks deliberately separate "fast path" (no actual work needed, +// exercises the zero-copy `clone()` paths added in the audit) from +// "actual truncation" (exercises the slice+concat / collect paths). 
+ +use arrow::array::{ + FixedSizeListArray, Int32Array, LargeListViewArray, ListArray, ListViewArray, StringArray, + StringViewBuilder, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_tools::record_batch::{truncate_numeric_column_length, truncate_string_columns}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use std::hint::black_box; +use std::sync::Arc; + +// Creates a batch where *no* string needs truncation at the given limit. +// This exercises the new UTF8 fast-path (cheap any() + early Arc::clone). +fn make_all_short_string_batch(n: usize, _max_chars: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)])); + // All strings are ASCII and well under the limit + let short = "short".repeat(3); // ~15 chars + let strings: Vec> = (0..n).map(|_| Some(short.clone())).collect(); + let arr = StringArray::from(strings); + RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch") +} + +// Creates a batch where a significant fraction of strings *do* need +// character truncation. Exercises the collect + truncation path. +fn make_many_long_string_batch(n: usize, max_chars: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)])); + let long = "x".repeat(max_chars + 20); // will need truncation + let short = "s"; + let strings: Vec> = (0..n) + .map(|i| { + if i % 3 == 0 { + Some(long.clone()) + } else { + Some(short.to_string()) + } + }) + .collect(); + let arr = StringArray::from(strings); + RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch") +} + +// StringViewArray versions of the above (exercises the specific fast-path +// arm for DataType::Utf8View that was added during the audit). 
+fn make_all_short_string_view_batch(n: usize, _max_chars: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "text", + DataType::Utf8View, + true, + )])); + let short = "short".repeat(3); + let mut builder = StringViewBuilder::new(); + for _ in 0..n { + builder.append_value(&short); + } + let arr = builder.finish(); + RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch") +} + +fn make_many_long_string_view_batch(n: usize, max_chars: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "text", + DataType::Utf8View, + true, + )])); + let long = "x".repeat(max_chars + 20); + let short = "s"; + let mut builder = StringViewBuilder::new(); + for i in 0..n { + if i % 3 == 0 { + builder.append_value(&long); + } else { + builder.append_value(short); + } + } + let arr = builder.finish(); + RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch") +} + +// Creates a List batch where no list exceeds the element limit. +// Exercises the list fast-path (try_fold + clone). +fn make_all_short_list_batch(n: usize, max_elems: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )])); + let per_list = max_elems.min(5); + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let offsets: Vec = (0..=n) + .map(|i| i32::try_from(i * per_list).expect("offset fits in i32")) + .collect(); + let list = ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::::new(offsets.into()), + Arc::new(values), + None, + ); + RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch") +} + +// Creates a List batch where truncation is required for every list. 
+fn make_long_list_batch(n: usize, truncate_to: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )])); + let per_list = truncate_to + 15; // will need truncation + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let offsets: Vec = (0..=n) + .map(|i| i32::try_from(i * per_list).expect("offset fits in i32")) + .collect(); + let list = ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::::new(offsets.into()), + Arc::new(values), + None, + ); + RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch") +} + +// ListView versions (exercises the ListView truncation path, which has +// more complex offset + size handling and a different fast-path decision +// based on the explicit sizes buffer). +fn make_all_short_list_view_batch(n: usize, max_elems: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )])); + let per_list = max_elems.min(5); + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let offsets: Vec = (0..n) + .map(|i| i32::try_from(i * per_list).expect("offset fits in i32")) + .collect(); + let per_list_i32 = i32::try_from(per_list).expect("per_list fits in i32"); + let sizes: Vec = (0..n).map(|_| per_list_i32).collect(); + let list_view = ListViewArray::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::::from(offsets), + ScalarBuffer::::from(sizes), + Arc::new(values), + None, + ) + .expect("ListViewArray construction for benchmark"); + RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch") +} 
+ +fn make_long_list_view_batch(n: usize, truncate_to: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )])); + let per_list = truncate_to + 15; + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let offsets: Vec = (0..n) + .map(|i| i32::try_from(i * per_list).expect("offset fits in i32")) + .collect(); + let per_list_i32 = i32::try_from(per_list).expect("per_list fits in i32"); + let sizes: Vec = (0..n).map(|_| per_list_i32).collect(); + let list_view = ListViewArray::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::::from(offsets), + ScalarBuffer::::from(sizes), + Arc::new(values), + None, + ) + .expect("ListViewArray construction for benchmark"); + RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch") +} + +// FixedSizeList versions (exercises the FixedSizeList fast-path, which is +// the cheapest of all — just a uniform size comparison — plus the stride-based +// slicing + concat work path). 
+fn make_all_short_fixed_size_list_batch(n: usize, _max_elems: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 5), + true, + )])); + let per_list = 5usize; + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let list = FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 5, + Arc::new(values), + None, + ); + RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch") +} + +fn make_long_fixed_size_list_batch(n: usize, _truncate_to: usize) -> RecordBatch { + // For FixedSizeList the "long" case is when the fixed size > truncate_to. + // We create lists of size 20 and truncate to 5. + let schema = Arc::new(Schema::new(vec![Field::new( + "nums", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 20), + true, + )])); + let per_list = 20usize; + let total_values = n * per_list; + let values = Int32Array::from( + (0..total_values) + .map(|v| i32::try_from(v).expect("value fits in i32")) + .collect::>(), + ); + let list = FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 20, + Arc::new(values), + None, + ); + RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch") +} + +// LargeListView versions (completes benchmark coverage for all five list +// variants that received full support during the type audit; i64 offsets/sizes +// + more complex offset rebuild in the work path). 
// Shared builder: a `nums: LargeListView<Int32>` batch of `n` rows, each view
// covering `per_list` consecutive Int32 values laid out back to back.
fn make_large_list_view_batch(n: usize, per_list: usize) -> RecordBatch {
    // One element field shared by the schema and the array so they cannot drift.
    let item = Arc::new(Field::new("item", DataType::Int32, true));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "nums",
        DataType::LargeListView(Arc::clone(&item)),
        true,
    )]));
    let values: Vec<i32> = (0..n * per_list)
        .map(|v| i32::try_from(v).expect("value fits in i32"))
        .collect();
    // LargeListView stores one i64 offset and one i64 size per row.
    let offsets: Vec<i64> = (0..n)
        .map(|i| {
            let offset = i.checked_mul(per_list).expect("offset fits in usize");
            i64::try_from(offset).expect("offset fits in i64")
        })
        .collect();
    let per_list_i64 = i64::try_from(per_list).expect("list size fits in i64");
    let sizes: Vec<i64> = std::iter::repeat_n(per_list_i64, n).collect();
    let list_view = LargeListViewArray::try_new(
        item,
        ScalarBuffer::<i64>::from(offsets),
        ScalarBuffer::<i64>::from(sizes),
        Arc::new(Int32Array::from(values)),
        None,
    )
    .expect("LargeListViewArray construction for benchmark");
    RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch")
}

// All-short variant: every view covers `min(max_elems, 5)` values, so nothing
// exceeds the element limit.
fn make_all_short_large_list_view_batch(n: usize, max_elems: usize) -> RecordBatch {
    make_large_list_view_batch(n, max_elems.min(5))
}

// Long variant: every view covers `truncate_to + 15` values, so every row
// needs truncation.
fn make_long_large_list_view_batch(n: usize, truncate_to: usize) -> RecordBatch {
    make_large_list_view_batch(n, truncate_to + 15)
}

// Registers every truncation benchmark under the `arrow_tools_truncate` group.
//
// Each data type gets a "fast path" case (nothing to cut) and a "with
// truncation" case. Every routine returns its result so Criterion black-boxes
// it and the call cannot be optimised away.
//
// The batched fast-path cases use `iter_batched_ref`: the input batch is only
// borrowed by the routine, so dropping the 800-row input (and the returned
// output) happens outside the timed section. With by-value `iter_batched` the
// input was dropped inside the routine and its deallocation was measured
// together with the fast path itself.
fn bench_truncate(c: &mut Criterion) {
    let mut group = c.benchmark_group("arrow_tools_truncate");

    // String (Utf8) fast path
    let short_strings = make_all_short_string_batch(2000, 50);
    group.bench_function("string_fast_path_2000_rows_all_short", |b| {
        b.iter(|| {
            truncate_string_columns(black_box(&short_strings), 50).expect("truncate short strings")
        });
    });

    // String (Utf8) actual work
    let long_strings = make_many_long_string_batch(2000, 50);
    group.bench_function("string_with_truncation_2000_rows_mixed", |b| {
        b.iter(|| {
            truncate_string_columns(black_box(&long_strings), 50).expect("truncate long strings")
        });
    });

    // StringView fast path (exercises the specific Utf8View arm + is_some_and decision)
    let short_views = make_all_short_string_view_batch(2000, 50);
    group.bench_function("stringview_fast_path_2000_rows_all_short", |b| {
        b.iter(|| {
            truncate_string_columns(black_box(&short_views), 50)
                .expect("truncate short string views")
        });
    });

    // StringView actual truncation
    let long_views = make_many_long_string_view_batch(2000, 50);
    group.bench_function("stringview_with_truncation_2000_rows_mixed", |b| {
        b.iter(|| {
            truncate_string_columns(black_box(&long_views), 50)
                .expect("truncate long string views")
        });
    });

    // List (regular List) fast path
    let short_lists = make_all_short_list_batch(800, 5);
    group.bench_function("list_fast_path_800_rows_all_short", |b| {
        b.iter(|| {
            truncate_numeric_column_length(black_box(&short_lists), 5)
                .expect("truncate short lists")
        });
    });

    // List (regular List) actual truncation work
    let long_lists = make_long_list_batch(800, 5);
    group.bench_function("list_with_truncation_800_rows", |b| {
        b.iter(|| {
            truncate_numeric_column_length(black_box(&long_lists), 5).expect("truncate long lists")
        });
    });

    // ListView fast path (exercises the sizes-based decision + clone).
    // A fresh batch is built per iteration outside the timed section; the
    // routine only borrows it, so neither setup nor teardown is measured.
    group.bench_function("listview_fast_path_800_rows_all_short", |b| {
        b.iter_batched_ref(
            || make_all_short_list_view_batch(800, 5),
            |batch| {
                truncate_numeric_column_length(black_box(&*batch), 5)
                    .expect("truncate short list views")
            },
            BatchSize::SmallInput,
        );
    });

    // ListView actual truncation work (exercises the more complex offset/size rebuild)
    let long_list_views = make_long_list_view_batch(800, 5);
    group.bench_function("listview_with_truncation_800_rows", |b| {
        b.iter(|| {
            truncate_numeric_column_length(black_box(&long_list_views), 5)
                .expect("truncate long list views")
        });
    });

    // FixedSizeList fast path (cheapest decision: uniform size comparison).
    // Batched like the view variants to keep the fast-path benchmarks uniform.
    group.bench_function("fixed_size_list_fast_path_800_rows_all_short", |b| {
        b.iter_batched_ref(
            || make_all_short_fixed_size_list_batch(800, 5),
            |batch| {
                truncate_numeric_column_length(black_box(&*batch), 5)
                    .expect("truncate short fixed-size lists")
            },
            BatchSize::SmallInput,
        );
    });

    // FixedSizeList actual truncation work (stride-based slicing + concat)
    let long_fsl = make_long_fixed_size_list_batch(800, 5);
    group.bench_function("fixed_size_list_with_truncation_800_rows", |b| {
        b.iter(|| {
            truncate_numeric_column_length(black_box(&long_fsl), 5)
                .expect("truncate long fixed-size lists")
        });
    });

    // LargeListView fast path (completes the five-variant benchmark coverage;
    // i64 sizes scan + zero-copy clone). Batched by reference so the i64
    // buffer construction and teardown stay outside the timed section.
    group.bench_function("large_listview_fast_path_800_rows_all_short", |b| {
        b.iter_batched_ref(
            || make_all_short_large_list_view_batch(800, 5),
            |batch| {
                truncate_numeric_column_length(black_box(&*batch), 5)
                    .expect("truncate short large list views")
            },
            BatchSize::SmallInput,
        );
    });

    // LargeListView actual truncation work (i64 offset/size rebuild path)
    let long_large_list_views = make_long_large_list_view_batch(800, 5);
    group.bench_function("large_listview_with_truncation_800_rows", |b| {
        b.iter(|| {
            truncate_numeric_column_length(black_box(&long_large_list_views), 5)
                .expect("truncate long large list views")
        });
    });

    group.finish();
}

criterion_group!(benches, bench_truncate);
criterion_main!(benches);
diff --git a/crates/arrow_tools/src/format.rs b/crates/arrow_tools/src/format.rs index 52c1dbc68f..6444a3d3c1 100644 --- a/crates/arrow_tools/src/format.rs +++ b/crates/arrow_tools/src/format.rs @@ -15,8 +15,11 @@ limitations under the License. */ use crate::schema::to_source_native_type_name; -use arrow::array::{Array, ArrayRef, FixedSizeListArray, ListArray, RecordBatch, StructArray}; -use arrow::buffer::OffsetBuffer; +use arrow::array::{ + Array, ArrayRef, FixedSizeListArray, LargeListArray, LargeListViewArray, ListArray, + ListViewArray, RecordBatch, StructArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::compute::concat; use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_schema::{ArrowError, DataType, Field, Schema}; @@ -52,6 +55,19 @@ pub(crate) fn format_column_data( "Failed to downcast to StringViewArray".into(), ))?; + // Fast path: zero-copy when no string requires character truncation. + // Cheap byte-length filter first (char count only on candidates). + if !string_array.iter().any(|opt| { + opt.is_some_and(|s| { + // Short-circuit at the (max_characters)th boundary so we + // stop walking pathologically long strings as soon as we + // know they'll need truncation. 
+ s.len() > max_characters && s.chars().nth(max_characters).is_some() + }) + }) { + return Ok(Arc::clone(&column)); + } + let truncated = string_array .iter() .map(|x| truncate_str(x, max_characters)) @@ -64,9 +80,22 @@ pub(crate) fn format_column_data( .as_any() .downcast_ref::() .ok_or(ArrowError::CastError( - "Failed to downcast to ListArray".into(), + "Failed to downcast to StringArray".into(), ))?; + // Fast path: zero-copy when no string requires character truncation. + // Cheap byte-length filter first (char count only on candidates). + if !string_array.iter().any(|opt| { + opt.is_some_and(|s| { + // Short-circuit at the (max_characters)th boundary so we + // stop walking pathologically long strings as soon as we + // know they'll need truncation. + s.len() > max_characters && s.chars().nth(max_characters).is_some() + }) + }) { + return Ok(Arc::clone(&column)); + } + let truncated = string_array .iter() .map(|x| truncate_str(x, max_characters)) @@ -80,36 +109,66 @@ pub(crate) fn format_column_data( DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) - | DataType::ListView(_), + | DataType::ListView(_) + | DataType::LargeListView(_), Some(_), ), ) => { - let array_ref = if let DataType::FixedSizeList(_, _) = column.data_type() { - let fixed_list_array = column - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError("Failed to downcast to FixedSizeListArray".into()) - })?; - Arc::new(truncate_fixed_size_list_array( - fixed_list_array, - num_elements, - )?) as ArrayRef - } else { - let list_array = column - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::CastError("Failed to downcast to ListArray".into()) - })?; - Arc::new(truncate_list_array(list_array, num_elements)?) 
as ArrayRef + let array_ref = match column.data_type() { + DataType::FixedSizeList(_, _) => { + let fixed_list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to FixedSizeListArray".into()) + })?; + Arc::new(truncate_fixed_size_list_array( + fixed_list_array, + num_elements, + )?) as ArrayRef + } + DataType::LargeList(_) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to LargeListArray".into()) + })?; + Arc::new(truncate_large_list_array(list_array, num_elements)?) as ArrayRef + } + DataType::ListView(_) => { + let list_array = + column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to ListViewArray".into()) + })?; + Arc::new(truncate_list_view_array(list_array, num_elements)?) as ArrayRef + } + DataType::LargeListView(_) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to LargeListViewArray".into()) + })?; + Arc::new(truncate_large_list_view_array(list_array, num_elements)?) as ArrayRef + } + _ => { + let list_array = + column.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::CastError("Failed to downcast to ListArray".into()) + })?; + Arc::new(truncate_list_array(list_array, num_elements)?) as ArrayRef + } }; Ok(array_ref) } (FormatOperation::TruncateUtf8Length(max_characters), (DataType::List(field), _)) => { let list_array = column .as_any() - .downcast_ref::() + .downcast_ref::() .ok_or_else(|| ArrowError::CastError("Failed to downcast to ListArray".into()))?; let truncated_values = format_column_data( @@ -118,24 +177,150 @@ pub(crate) fn format_column_data( FormatOperation::TruncateUtf8Length(max_characters), )?; + // Zero-copy fast path for the outer list when the inner values were + // not modified (common when text in lists is short). 
+ if Arc::ptr_eq(&truncated_values, list_array.values()) { + return Ok(Arc::clone(&column)); + } + let list = ListArray::new( Arc::clone(&field), - arrow::buffer::OffsetBuffer::new( - arrow::buffer::Buffer::from_slice_ref(list_array.value_offsets()).into(), - ), + list_array.offsets().clone(), truncated_values, - list_array.logical_nulls(), + list_array.nulls().cloned(), ); Ok(Arc::new(list) as ArrayRef) } + (FormatOperation::TruncateUtf8Length(max_characters), (DataType::LargeList(field), _)) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to LargeListArray".into()) + })?; + + let truncated_values = format_column_data( + Arc::clone(list_array.values()), + &field, + FormatOperation::TruncateUtf8Length(max_characters), + )?; + + // Zero-copy fast path for the outer list when the inner values were + // not modified (common when text in lists is short). + if Arc::ptr_eq(&truncated_values, list_array.values()) { + return Ok(Arc::clone(&column)); + } + + let list = LargeListArray::new( + Arc::clone(&field), + list_array.offsets().clone(), + truncated_values, + list_array.nulls().cloned(), + ); + + Ok(Arc::new(list) as ArrayRef) + } + ( + FormatOperation::TruncateUtf8Length(max_characters), + (DataType::FixedSizeList(field, size), _), + ) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to FixedSizeListArray".into()) + })?; + + let truncated_values = format_column_data( + Arc::clone(list_array.values()), + &field, + FormatOperation::TruncateUtf8Length(max_characters), + )?; + + // Zero-copy fast path for the outer list when the inner values were + // not modified (common when text in lists is short). 
+ if Arc::ptr_eq(&truncated_values, list_array.values()) { + return Ok(Arc::clone(&column)); + } + + let list = FixedSizeListArray::new( + Arc::clone(&field), + size, + truncated_values, + list_array.nulls().cloned(), + ); + + Ok(Arc::new(list) as ArrayRef) + } + (FormatOperation::TruncateUtf8Length(max_characters), (DataType::ListView(field), _)) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to ListViewArray".into()) + })?; + + let truncated_values = format_column_data( + Arc::clone(list_array.values()), + &field, + FormatOperation::TruncateUtf8Length(max_characters), + )?; + + // Zero-copy fast path for the outer list when the inner values were + // not modified (common when text in lists is short). + if Arc::ptr_eq(&truncated_values, list_array.values()) { + return Ok(Arc::clone(&column)); + } + + ListViewArray::try_new( + Arc::clone(&field), + list_array.offsets().clone(), + list_array.sizes().clone(), + truncated_values, + list_array.nulls().cloned(), + ) + .map(|list| Arc::new(list) as ArrayRef) + } + ( + FormatOperation::TruncateUtf8Length(max_characters), + (DataType::LargeListView(field), _), + ) => { + let list_array = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to downcast to LargeListViewArray".into()) + })?; + + let truncated_values = format_column_data( + Arc::clone(list_array.values()), + &field, + FormatOperation::TruncateUtf8Length(max_characters), + )?; + + // Zero-copy fast path for the outer list when the inner values were + // not modified (common when text in lists is short). 
+ if Arc::ptr_eq(&truncated_values, list_array.values()) { + return Ok(Arc::clone(&column)); + } + + LargeListViewArray::try_new( + Arc::clone(&field), + list_array.offsets().clone(), + list_array.sizes().clone(), + truncated_values, + list_array.nulls().cloned(), + ) + .map(|list| Arc::new(list) as ArrayRef) + } (FormatOperation::TruncateUtf8Length(max_characters), (DataType::Struct(fields), _)) => { let struct_array = column .as_any() .downcast_ref::() .ok_or_else(|| ArrowError::CastError("Failed to downcast to StructArray".into()))?; - let columns = fields + let columns: Vec<_> = fields .iter() .enumerate() .map(|(i, field)| { @@ -148,6 +333,15 @@ pub(crate) fn format_column_data( }) .collect::, _>>()?; + // Zero-copy fast path for the whole struct when no field was modified. + let all_unchanged = columns + .iter() + .enumerate() + .all(|(i, c)| Arc::ptr_eq(c, struct_array.column(i))); + if all_unchanged { + return Ok(Arc::clone(&column)); + } + let truncated_struct = StructArray::from(fields.iter().cloned().zip(columns).collect::>()); Ok(Arc::new(truncated_struct) as ArrayRef) @@ -161,9 +355,11 @@ fn get_possible_nested_list_datatype(f: &Arc) -> (DataType, Option { - Some(f.data_type().clone()) - } + DataType::List(f) + | DataType::FixedSizeList(f, _) + | DataType::LargeList(f) + | DataType::ListView(f) + | DataType::LargeListView(f) => Some(f.data_type().clone()), _ => None, }, ) @@ -184,18 +380,78 @@ fn truncate_str(str: Option<&str>, max_characters: usize) -> Option<&str> { }) } -#[expect( - clippy::cast_sign_loss, - clippy::cast_possible_truncation, - clippy::cast_possible_wrap -)] +fn value_to_usize(value: T, value_name: &str) -> Result +where + T: TryInto + std::fmt::Display + Copy, +{ + value.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "{value_name} {value} cannot be represented as usize" + )) + }) +} + +fn value_to_i32(value: usize, value_name: &str) -> Result { + i32::try_from(value).map_err(|_| { + 
ArrowError::InvalidArgumentError(format!( + "{value_name} {value} cannot be represented as i32" + )) + }) +} + +fn list_slice_range( + start_offset: T, + end_offset: T, + array_name: &str, + start_offset_name: &str, + end_offset_name: &str, +) -> Result<(usize, usize), ArrowError> +where + T: TryInto + std::fmt::Display + Copy, +{ + let start = value_to_usize(start_offset, start_offset_name)?; + let end = value_to_usize(end_offset, end_offset_name)?; + let len = end.checked_sub(start).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "{array_name} end offset {end} is before start offset {start}" + )) + })?; + + Ok((start, len)) +} + +fn list_element_field(data_type: &DataType, array_name: &str) -> Result, ArrowError> { + match data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) + | DataType::FixedSizeList(field, _) => Ok(Arc::clone(field)), + _ => Err(ArrowError::InvalidArgumentError(format!( + "{array_name} data type is not a list type: {data_type}" + ))), + } +} + fn truncate_fixed_size_list_array( list_array: &FixedSizeListArray, max_len: usize, ) -> Result { + if list_array.is_empty() { + return Ok(list_array.clone()); + } let child_array = list_array.values(); - let original_size = list_array.value_length() as usize; - let truncated_size = max_len.min(original_size); + let original_size = + value_to_usize(list_array.value_length(), "FixedSizeListArray value length")?; + // Fast path: zero-copy clone when truncation would be a no-op. This is the + // common case for display formatting (generous max_len, small actual lists) + // and avoids per-row slice + concat allocation/copy entirely. 
+ if max_len >= original_size { + return Ok(list_array.clone()); + } + let truncated_size = max_len; // known to be < original_size + let truncated_size_i32 = value_to_i32(truncated_size, "FixedSizeListArray truncated size")?; + let element_field = list_element_field(list_array.data_type(), "FixedSizeListArray")?; let sliced_arrays: Vec> = (0..list_array.len()) .map(|i| child_array.slice(i * original_size, truncated_size)) @@ -204,57 +460,272 @@ fn truncate_fixed_size_list_array( let new_child_array = Arc::new(concat( &sliced_arrays.iter().map(AsRef::as_ref).collect::>(), )?); - let nulls = new_child_array.nulls().cloned(); - - FixedSizeListArray::try_new( - Arc::new(Field::new( - "item", - child_array.data_type().clone(), - child_array.is_nullable(), - )), - truncated_size as i32, + // Parent list-level nulls (which slots are NULL lists) are preserved as-is; + // they live separately from the child array's element-level null bitmap. + let nulls = list_array.nulls().cloned(); + + FixedSizeListArray::try_new(element_field, truncated_size_i32, new_child_array, nulls) +} + +fn truncate_list_array(list_array: &ListArray, max_len: usize) -> Result { + if list_array.is_empty() { + return Ok(list_array.clone()); + } + let child_array = list_array.values(); + let offsets = list_array.value_offsets(); + // Fast path: zero-copy clone when no list element exceeds max_len. 
+ let mut needs_trunc = false; + for i in 0..list_array.len() { + let (_, len) = list_slice_range( + offsets[i], + offsets[i + 1], + "ListArray", + "ListArray start offset", + "ListArray end offset", + )?; + if len > max_len { + needs_trunc = true; + break; + } + } + if !needs_trunc { + return Ok(list_array.clone()); + } + let element_field = list_element_field(list_array.data_type(), "ListArray")?; + + let slice_ranges: Vec<(usize, usize)> = (0..list_array.len()) + .map(|i| { + list_slice_range( + offsets[i], + offsets[i + 1], + "ListArray", + "ListArray start offset", + "ListArray end offset", + ) + .map(|(start, len)| (start, max_len.min(len))) + }) + .collect::>()?; + let new_lengths: Vec = slice_ranges.iter().map(|&(_, len)| len).collect(); + + let sliced_arrays: Vec> = new_lengths + .iter() + .zip(slice_ranges.iter()) + .map(|(&len, &(start, _))| child_array.slice(start, len)) + .collect(); + + let new_child_array = Arc::new(concat( + &sliced_arrays.iter().map(AsRef::as_ref).collect::>(), + )?); + + let nulls = list_array.nulls().cloned(); + + ListArray::try_new( + element_field, + OffsetBuffer::from_lengths(new_lengths), new_child_array, nulls, ) } -#[expect(clippy::cast_sign_loss)] -fn truncate_list_array(list_array: &ListArray, max_len: usize) -> Result { +fn truncate_large_list_array( + list_array: &LargeListArray, + max_len: usize, +) -> Result { + if list_array.is_empty() { + return Ok(list_array.clone()); + } let child_array = list_array.values(); let offsets = list_array.value_offsets(); + // Fast path: zero-copy clone when truncation is unnecessary (common for + // result formatting). LargeList uses i64 offsets, so validate conversion + // before deciding whether the slow path is needed. 
+ let mut needs_trunc = false; + for i in 0..list_array.len() { + let (_, len) = list_slice_range( + offsets[i], + offsets[i + 1], + "LargeListArray", + "LargeListArray start offset", + "LargeListArray end offset", + )?; + if len > max_len { + needs_trunc = true; + break; + } + } + if !needs_trunc { + return Ok(list_array.clone()); + } + let element_field = list_element_field(list_array.data_type(), "LargeListArray")?; - let new_lengths: Vec = (0..list_array.len()) + let slice_ranges: Vec<(usize, usize)> = (0..list_array.len()) .map(|i| { - let start = offsets[i] as usize; - let end = offsets[i + 1] as usize; - max_len.min(end - start) + list_slice_range( + offsets[i], + offsets[i + 1], + "LargeListArray", + "LargeListArray start offset", + "LargeListArray end offset", + ) + .map(|(start, len)| (start, max_len.min(len))) }) - .collect(); + .collect::>()?; + let new_lengths: Vec = slice_ranges.iter().map(|&(_, len)| len).collect(); let sliced_arrays: Vec> = new_lengths .iter() - .enumerate() - .map(|(i, &len)| child_array.slice(offsets[i] as usize, len)) + .zip(slice_ranges.iter()) + .map(|(&len, &(start, _))| child_array.slice(start, len)) .collect(); let new_child_array = Arc::new(concat( &sliced_arrays.iter().map(AsRef::as_ref).collect::>(), )?); - let nulls = new_child_array.nulls().cloned(); + let nulls = list_array.nulls().cloned(); - ListArray::try_new( - Arc::new(Field::new( - "item", - child_array.data_type().clone(), - child_array.is_nullable(), - )), + LargeListArray::try_new( + element_field, OffsetBuffer::from_lengths(new_lengths), new_child_array, nulls, ) } +fn truncate_list_view_array( + list_array: &ListViewArray, + max_len: usize, +) -> Result { + if list_array.is_empty() { + return Ok(list_array.clone()); + } + let child_array = list_array.values(); + let sizes = list_array.value_sizes(); + // Fast path for ListView: sizes are stored explicitly, cheap scan. 
+ // When no element exceeds the limit we return the original (preserving + // whatever non-contiguous view layout the caller had). + let mut needs_trunc = false; + for &size in sizes { + if value_to_usize(size, "ListViewArray size")? > max_len { + needs_trunc = true; + break; + } + } + if !needs_trunc { + return Ok(list_array.clone()); + } + let offsets = list_array.value_offsets(); + let element_field = list_element_field(list_array.data_type(), "ListViewArray")?; + + let slice_ranges: Vec<(usize, usize, i32)> = (0..list_array.len()) + .map(|i| { + let start = value_to_usize(offsets[i], "ListViewArray offset")?; + let original_size = value_to_usize(sizes[i], "ListViewArray size")?; + let truncated_size = max_len.min(original_size); + let truncated_size_i32 = value_to_i32(truncated_size, "ListViewArray truncated size")?; + + Ok((start, truncated_size, truncated_size_i32)) + }) + .collect::>()?; + let new_sizes: Vec = slice_ranges + .iter() + .map(|&(_, _, truncated_size_i32)| truncated_size_i32) + .collect(); + + let sliced_arrays: Vec> = (0..list_array.len()) + .map(|i| child_array.slice(slice_ranges[i].0, slice_ranges[i].1)) + .collect(); + + let new_child_array = Arc::new(concat( + &sliced_arrays.iter().map(AsRef::as_ref).collect::>(), + )?); + + let nulls = list_array.nulls().cloned(); + + // After concat, the new values buffer is laid out contiguously, so we + // re-derive offsets from the truncated sizes. Bail on overflow rather + // than silently producing misaligned offsets. 
+ let mut new_offsets: Vec = Vec::with_capacity(new_sizes.len()); + let mut running: i32 = 0; + for &s in &new_sizes { + new_offsets.push(running); + running = running.checked_add(s).ok_or_else(|| { + ArrowError::InvalidArgumentError("ListViewArray cumulative size overflowed i32".into()) + })?; + } + + ListViewArray::try_new( + element_field, + ScalarBuffer::from(new_offsets), + ScalarBuffer::from(new_sizes), + new_child_array, + nulls, + ) +} + +fn truncate_large_list_view_array( + list_array: &LargeListViewArray, + max_len: usize, +) -> Result { + if list_array.is_empty() { + return Ok(list_array.clone()); + } + let child_array = list_array.values(); + let sizes = list_array.value_sizes(); + // Fast path for LargeListView (i64 sizes): cheap scan over the sizes buffer. + // When nothing needs truncation we avoid the expensive concat + offset rebuild. + let max_len_i64 = i64::try_from(max_len).map_err(|_| { + ArrowError::InvalidArgumentError(format!("max_len {max_len} cannot be represented as i64")) + })?; + if !sizes.iter().any(|&s| s > max_len_i64) { + return Ok(list_array.clone()); + } + let offsets = list_array.value_offsets(); + let element_field = list_element_field(list_array.data_type(), "LargeListViewArray")?; + + let new_sizes: Vec = (0..list_array.len()) + .map(|i| sizes[i].min(max_len_i64)) + .collect(); + + let slice_ranges: Vec<(usize, usize)> = (0..list_array.len()) + .map(|i| { + let start = value_to_usize(offsets[i], "LargeListViewArray offset")?; + let len = value_to_usize(new_sizes[i], "LargeListViewArray size")?; + + Ok((start, len)) + }) + .collect::>()?; + + let sliced_arrays: Vec> = (0..list_array.len()) + .map(|i| child_array.slice(slice_ranges[i].0, slice_ranges[i].1)) + .collect(); + + let new_child_array = Arc::new(concat( + &sliced_arrays.iter().map(AsRef::as_ref).collect::>(), + )?); + + let nulls = list_array.nulls().cloned(); + + let mut new_offsets: Vec = Vec::with_capacity(new_sizes.len()); + let mut running: i64 = 0; + for &s in 
&new_sizes { + new_offsets.push(running); + running = running.checked_add(s).ok_or_else(|| { + ArrowError::InvalidArgumentError( + "LargeListViewArray cumulative size overflowed i64".into(), + ) + })?; + } + + LargeListViewArray::try_new( + element_field, + ScalarBuffer::from(new_offsets), + ScalarBuffer::from(new_sizes), + new_child_array, + nulls, + ) +} + /// Creates a visual representation of record batches using markdown document format with additional header fields. /// /// # Errors @@ -413,6 +884,26 @@ pub fn pretty_print_schema( write_data_type(field.data_type(), w)?; w.write_char('>') } + DataType::FixedSizeList(field, size) => { + w.write_str("fixed_size_list[{size}]") + } + DataType::ListView(field) => { + w.write_str("list_view') + } + DataType::LargeListView(field) => { + w.write_str("large_list_view') + } + DataType::Map(field, _) => { + w.write_str("map<")?; + write_data_type(field.data_type(), w)?; + w.write_char('>') + } DataType::Struct(fields) => { w.write_str("struct<")?; for (i, f) in fields.iter().enumerate() { @@ -634,6 +1125,519 @@ Cras venenatis euismod malesuada.", } } + #[test] + fn test_truncate_large_list_array() { + use arrow::array::{Int32Array, LargeListArray as LargeListArrayAlias}; + use arrow::buffer::OffsetBuffer; + + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = OffsetBuffer::::new(vec![0_i64, 3, 6, 8].into()); + let input = LargeListArrayAlias::new( + Arc::new(Field::new("item", DataType::Int32, true)), + offsets, + Arc::new(values), + None, + ); + + let output = + truncate_large_list_array(&input, 2).expect("truncate_large_list_array failed"); + + assert_eq!(output.len(), 3); + // Each sublist should now have at most 2 elements. 
+ let observed_lengths: Vec = (0..output.len()) + .map(|i| { + let start = + usize::try_from(output.value_offsets()[i]).expect("offset should fit in usize"); + let end = usize::try_from(output.value_offsets()[i + 1]) + .expect("offset should fit in usize"); + end - start + }) + .collect(); + assert_eq!(observed_lengths, vec![2, 2, 2]); + } + + /// Parent list-level nulls (which slots are NULL lists) must survive + /// truncation. The bug we're guarding against: pulling the child array's + /// null bitmap and stuffing it into the parent's `nulls` slot, which can + /// flip valid slots to NULL or vice versa. + #[test] + fn test_truncate_list_array_preserves_parent_nulls() { + let input = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + let parent_nulls_before: Vec = (0..input.len()).map(|i| input.is_null(i)).collect(); + + let output = truncate_list_array(&input, 2).expect("truncate_list_array failed"); + let parent_nulls_after: Vec = (0..output.len()).map(|i| output.is_null(i)).collect(); + + assert_eq!(parent_nulls_before, parent_nulls_after); + assert_eq!(parent_nulls_after, vec![false, true, false]); + } + + #[test] + fn test_truncate_large_list_array_preserves_parent_nulls() { + use arrow::array::{Int32Array, LargeListArray as LargeListArrayAlias}; + use arrow::buffer::{NullBuffer, OffsetBuffer}; + + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5]); + let offsets = OffsetBuffer::::new(vec![0_i64, 3, 3, 6].into()); + // Mark the middle slot as a NULL list. 
+ let nulls = NullBuffer::from(vec![true, false, true]); + let input = LargeListArrayAlias::new( + Arc::new(Field::new("item", DataType::Int32, true)), + offsets, + Arc::new(values), + Some(nulls), + ); + + let output = + truncate_large_list_array(&input, 2).expect("truncate_large_list_array failed"); + + assert!(!output.is_null(0)); + assert!(output.is_null(1)); + assert!(!output.is_null(2)); + } + + /// Truncation must carry the original element `Field` through unchanged. + /// Building a fresh field with a hardcoded name ("item") and + /// `child_array.is_nullable()` (which reflects the *data*, not the + /// declared schema) would silently rewrite the resulting `DataType` and + /// drop any field-level metadata. + #[test] + fn test_truncate_helpers_preserve_element_field() { + use arrow::array::{ + FixedSizeListArray as FixedSizeListArrayAlias, Int32Array, + LargeListArray as LargeListArrayAlias, LargeListViewArray as LargeListViewArrayAlias, + ListViewArray as ListViewArrayAlias, + }; + use arrow::buffer::{OffsetBuffer, ScalarBuffer}; + + let metadata: HashMap = [("origin".to_string(), "audit_test".to_string())] + .into_iter() + .collect(); + // Original field has a custom name, declared-nullable=false, and metadata; + // none of which the truncation helper should be free to discard. 
+ let element_field = + Arc::new(Field::new("value", DataType::Int32, false).with_metadata(metadata.clone())); + + let values = || Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + + let assert_field_preserved = |actual: &Arc| { + assert_eq!(actual.name(), "value"); + assert!(!actual.is_nullable()); + assert_eq!(actual.metadata(), &metadata); + }; + + // List + let list = ListArray::new( + Arc::clone(&element_field), + OffsetBuffer::::new(vec![0_i32, 3, 6, 8].into()), + Arc::new(values()), + None, + ); + let output = truncate_list_array(&list, 2).expect("truncate List"); + match output.data_type() { + DataType::List(field) => assert_field_preserved(field), + other => panic!("unexpected {other:?}"), + } + + // LargeList + let large_list = LargeListArrayAlias::new( + Arc::clone(&element_field), + OffsetBuffer::::new(vec![0_i64, 3, 6, 8].into()), + Arc::new(values()), + None, + ); + let output = truncate_large_list_array(&large_list, 2).expect("truncate LargeList"); + match output.data_type() { + DataType::LargeList(field) => assert_field_preserved(field), + other => panic!("unexpected {other:?}"), + } + + // FixedSizeList + let fixed_size_list = FixedSizeListArrayAlias::new( + Arc::clone(&element_field), + 4, + Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7])), + None, + ); + let output = + truncate_fixed_size_list_array(&fixed_size_list, 2).expect("truncate FixedSizeList"); + match output.data_type() { + DataType::FixedSizeList(field, _) => assert_field_preserved(field), + other => panic!("unexpected {other:?}"), + } + + // ListView + let list_view = ListViewArrayAlias::try_new( + Arc::clone(&element_field), + ScalarBuffer::::from(vec![0_i32, 3, 6]), + ScalarBuffer::::from(vec![3_i32, 3, 2]), + Arc::new(values()), + None, + ) + .expect("ListViewArray construction"); + let output = truncate_list_view_array(&list_view, 2).expect("truncate ListView"); + match output.data_type() { + DataType::ListView(field) => assert_field_preserved(field), + other => 
panic!("unexpected {other:?}"), + } + + // LargeListView + let large_list_view = LargeListViewArrayAlias::try_new( + Arc::clone(&element_field), + ScalarBuffer::::from(vec![0_i64, 3, 6]), + ScalarBuffer::::from(vec![3_i64, 3, 2]), + Arc::new(values()), + None, + ) + .expect("LargeListViewArray construction"); + let output = + truncate_large_list_view_array(&large_list_view, 2).expect("truncate LargeListView"); + match output.data_type() { + DataType::LargeListView(field) => assert_field_preserved(field), + other => panic!("unexpected {other:?}"), + } + } + + /// `TruncateUtf8Length` must recurse into every list-like variant, not + /// just `List`. Otherwise strings nested under `LargeList` / + /// `FixedSizeList` / `ListView` / `LargeListView` silently skip + /// truncation when callers run `truncate_string_columns` over a record + /// batch. + #[test] + fn test_truncate_utf8_recurses_into_all_list_variants() { + use arrow::array::{ + FixedSizeListArray as FixedSizeListArrayAlias, LargeListArray as LargeListArrayAlias, + LargeListViewArray as LargeListViewArrayAlias, ListViewArray as ListViewArrayAlias, + StringArray, + }; + use arrow::buffer::{OffsetBuffer, ScalarBuffer}; + + fn assert_first_string_truncated(arr: &ArrayRef, expected_first_char: &str) { + // Each variant stores its UTF8 elements contiguously in the child + // array; inspecting element 0 is enough to prove truncation ran. 
+ let dt = arr.data_type().clone(); + let child: ArrayRef = match &dt { + DataType::List(_) => Arc::clone( + arr.as_any() + .downcast_ref::() + .expect("ListArray") + .values(), + ), + DataType::LargeList(_) => Arc::clone( + arr.as_any() + .downcast_ref::() + .expect("LargeListArray") + .values(), + ), + DataType::FixedSizeList(_, _) => Arc::clone( + arr.as_any() + .downcast_ref::() + .expect("FixedSizeListArray") + .values(), + ), + DataType::ListView(_) => Arc::clone( + arr.as_any() + .downcast_ref::() + .expect("ListViewArray") + .values(), + ), + DataType::LargeListView(_) => Arc::clone( + arr.as_any() + .downcast_ref::() + .expect("LargeListViewArray") + .values(), + ), + other => panic!("unexpected outer type {other:?}"), + }; + let strings = child + .as_any() + .downcast_ref::() + .expect("StringArray child"); + assert_eq!(strings.value(0), expected_first_char); + } + + let inner_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let strings = || StringArray::from(vec!["abcdef", "ghijkl"]); + + // List + let list = ListArray::new( + Arc::clone(&inner_field), + OffsetBuffer::::new(vec![0_i32, 1, 2].into()), + Arc::new(strings()), + None, + ); + let truncated = format_column_data( + Arc::new(list) as ArrayRef, + &Arc::new(Field::new( + "col", + DataType::List(Arc::clone(&inner_field)), + true, + )), + FormatOperation::TruncateUtf8Length(1), + ) + .expect("List Utf8 truncate"); + assert_first_string_truncated(&truncated, "a"); + + // LargeList + let large_list = LargeListArrayAlias::new( + Arc::clone(&inner_field), + OffsetBuffer::::new(vec![0_i64, 1, 2].into()), + Arc::new(strings()), + None, + ); + let truncated = format_column_data( + Arc::new(large_list) as ArrayRef, + &Arc::new(Field::new( + "col", + DataType::LargeList(Arc::clone(&inner_field)), + true, + )), + FormatOperation::TruncateUtf8Length(1), + ) + .expect("LargeList Utf8 truncate"); + assert_first_string_truncated(&truncated, "a"); + + // FixedSizeList[1] + let fsl = + 
FixedSizeListArrayAlias::new(Arc::clone(&inner_field), 1, Arc::new(strings()), None); + let truncated = format_column_data( + Arc::new(fsl) as ArrayRef, + &Arc::new(Field::new( + "col", + DataType::FixedSizeList(Arc::clone(&inner_field), 1), + true, + )), + FormatOperation::TruncateUtf8Length(1), + ) + .expect("FixedSizeList Utf8 truncate"); + assert_first_string_truncated(&truncated, "a"); + + // ListView + let list_view = ListViewArrayAlias::try_new( + Arc::clone(&inner_field), + ScalarBuffer::::from(vec![0_i32, 1]), + ScalarBuffer::::from(vec![1_i32, 1]), + Arc::new(strings()), + None, + ) + .expect("ListViewArray construction"); + let truncated = format_column_data( + Arc::new(list_view) as ArrayRef, + &Arc::new(Field::new( + "col", + DataType::ListView(Arc::clone(&inner_field)), + true, + )), + FormatOperation::TruncateUtf8Length(1), + ) + .expect("ListView Utf8 truncate"); + assert_first_string_truncated(&truncated, "a"); + + // LargeListView + let large_list_view = LargeListViewArrayAlias::try_new( + Arc::clone(&inner_field), + ScalarBuffer::::from(vec![0_i64, 1]), + ScalarBuffer::::from(vec![1_i64, 1]), + Arc::new(strings()), + None, + ) + .expect("LargeListViewArray construction"); + let truncated = format_column_data( + Arc::new(large_list_view) as ArrayRef, + &Arc::new(Field::new( + "col", + DataType::LargeListView(Arc::clone(&inner_field)), + true, + )), + FormatOperation::TruncateUtf8Length(1), + ) + .expect("LargeListView Utf8 truncate"); + assert_first_string_truncated(&truncated, "a"); + } + + /// `arrow::compute::concat` errors when given an empty slice of arrays. + /// 0-row record batches are a realistic input (e.g. a sample query that + /// returns no rows), so each list truncation helper must short-circuit + /// before reaching `concat`. 
+ #[test] + fn test_truncate_helpers_handle_empty_arrays() { + use arrow::array::{ + FixedSizeListArray as FixedSizeListArrayAlias, Int32Array, + LargeListArray as LargeListArrayAlias, LargeListViewArray as LargeListViewArrayAlias, + ListViewArray as ListViewArrayAlias, + }; + use arrow::buffer::{OffsetBuffer, ScalarBuffer}; + + // Empty ListArray + let empty_list = ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::::new(vec![0_i32].into()), + Arc::new(Int32Array::from(Vec::::new())), + None, + ); + assert_eq!( + truncate_list_array(&empty_list, 4) + .expect("empty ListArray must round-trip") + .len(), + 0 + ); + + // Empty LargeListArray + let empty_large_list = LargeListArrayAlias::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::::new(vec![0_i64].into()), + Arc::new(Int32Array::from(Vec::::new())), + None, + ); + assert_eq!( + truncate_large_list_array(&empty_large_list, 4) + .expect("empty LargeListArray must round-trip") + .len(), + 0 + ); + + // Empty FixedSizeListArray + let empty_fsl = FixedSizeListArrayAlias::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(Int32Array::from(Vec::::new())), + None, + ); + assert_eq!( + truncate_fixed_size_list_array(&empty_fsl, 2) + .expect("empty FixedSizeListArray must round-trip") + .len(), + 0 + ); + + // Empty ListViewArray + let empty_list_view = ListViewArrayAlias::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::::from(Vec::::new()), + ScalarBuffer::::from(Vec::::new()), + Arc::new(Int32Array::from(Vec::::new())), + None, + ) + .expect("empty ListViewArray construction"); + assert_eq!( + truncate_list_view_array(&empty_list_view, 4) + .expect("empty ListViewArray must round-trip") + .len(), + 0 + ); + + // Empty LargeListViewArray + let empty_large_list_view = LargeListViewArrayAlias::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::::from(Vec::::new()), + 
ScalarBuffer::::from(Vec::::new()), + Arc::new(Int32Array::from(Vec::::new())), + None, + ) + .expect("empty LargeListViewArray construction"); + assert_eq!( + truncate_large_list_view_array(&empty_large_list_view, 4) + .expect("empty LargeListViewArray must round-trip") + .len(), + 0 + ); + } + + #[test] + fn test_truncate_fixed_size_list_array_preserves_parent_nulls() { + use arrow::array::{FixedSizeListArray as FixedSizeListArrayAlias, Int32Array}; + use arrow::buffer::NullBuffer; + + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]); + let nulls = NullBuffer::from(vec![true, false, true]); + let input = FixedSizeListArrayAlias::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(values), + Some(nulls), + ); + + let output = truncate_fixed_size_list_array(&input, 2) + .expect("truncate_fixed_size_list_array failed"); + + assert!(!output.is_null(0)); + assert!(output.is_null(1)); + assert!(!output.is_null(2)); + } + + #[test] + fn test_truncate_list_view_array() { + use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias}; + use arrow::buffer::ScalarBuffer; + + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = ScalarBuffer::::from(vec![0_i32, 3, 6]); + let sizes = ScalarBuffer::::from(vec![3_i32, 3, 2]); + let input = ListViewArrayAlias::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + offsets, + sizes, + Arc::new(values), + None, + ) + .expect("ListViewArray::try_new"); + + let output = truncate_list_view_array(&input, 2).expect("truncate_list_view_array failed"); + + assert_eq!(output.len(), 3); + // After truncation each entry has at most 2 elements. 
+ for size in output.value_sizes() { + assert!(*size <= 2, "size {size} exceeded max_len"); + } + } + + #[test] + fn test_truncate_large_list_view_array() { + use arrow::array::{Int32Array, LargeListViewArray as LargeListViewArrayAlias}; + use arrow::buffer::ScalarBuffer; + + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = ScalarBuffer::::from(vec![0_i64, 3, 6]); + let sizes = ScalarBuffer::::from(vec![3_i64, 3, 2]); + let input = LargeListViewArrayAlias::try_new( + Arc::new(Field::new("item", DataType::Int32, true)), + offsets, + sizes, + Arc::new(values), + None, + ) + .expect("LargeListViewArray::try_new"); + + let output = truncate_large_list_view_array(&input, 2) + .expect("truncate_large_list_view_array failed"); + + assert_eq!(output.len(), 3); + for size in output.value_sizes() { + assert!(*size <= 2, "size {size} exceeded max_len"); + } + } + + #[test] + fn test_get_possible_nested_list_datatype_view_variants() { + let inner = DataType::Int32; + let cases = vec![ + DataType::List(Arc::new(Field::new("item", inner.clone(), true))), + DataType::LargeList(Arc::new(Field::new("item", inner.clone(), true))), + DataType::FixedSizeList(Arc::new(Field::new("item", inner.clone(), true)), 4), + DataType::ListView(Arc::new(Field::new("item", inner.clone(), true))), + DataType::LargeListView(Arc::new(Field::new("item", inner.clone(), true))), + ]; + for dt in cases { + let f = Arc::new(Field::new("col", dt, true)); + let (_, inner_dt) = get_possible_nested_list_datatype(&f); + assert_eq!(inner_dt.as_ref(), Some(&inner), "field {f:?}"); + } + } + #[test] fn test_truncate_fixed_size_list_array() { let test_cases: Vec<(&str, usize, FixedSizeListArray)> = vec![ @@ -751,4 +1755,196 @@ Cras venenatis euismod malesuada.", metadata: struct (nullable) "); } + + /// `max_len = 0` is a valid (if unusual) truncation request: every list + /// must become a 0-element list while parent nulls and the element Field + /// (including metadata/nullable) are 
preserved exactly. + #[test] + fn test_truncate_list_to_zero_elements() { + use arrow::array::{Int32Array, ListArray as ListArrayAlias}; + use arrow::buffer::{NullBuffer, OffsetBuffer}; + + let element_field = Arc::new( + Field::new("value", DataType::Int32, false) + .with_metadata([("audit".to_string(), "zero_len".to_string())].into()), + ); + + // List with mixed lengths + one parent NULL + // lists: [0,1,2], [], [3,4,5], [6] (values 0..6) + let list = ListArrayAlias::new( + Arc::clone(&element_field), + OffsetBuffer::::new(vec![0_i32, 3, 3, 6, 7].into()), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let out = truncate_list_array(&list, 0).expect("truncate to 0"); + assert_eq!(out.len(), 4); + assert!(out.is_null(1)); + assert!(!out.is_null(0)); + // All non-null lists must report length 0 + for i in [0, 2, 3] { + assert!(!out.is_null(i)); + assert_eq!(out.value(i).len(), 0); + } + // Element field preserved + if let DataType::List(f) = out.data_type() { + assert_eq!(f.name(), "value"); + assert!(!f.is_nullable()); + assert_eq!(f.metadata().get("audit"), Some(&"zero_len".to_string())); + } else { + panic!("expected List"); + } + } + + /// Fast-path + `max_len=0` for `ListView` (non-contiguous layout). 
+ #[test] + fn test_truncate_list_view_to_zero_and_fast_path() { + use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias}; + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + let element_field = Arc::new(Field::new("item", DataType::Int32, true)); + let values = Int32Array::from(vec![10, 20, 30, 40]); + // Two lists: [10,20,30] at offset 0 size 3, and a NULL list, and [40] + let offsets = ScalarBuffer::::from(vec![0_i32, 0, 3]); + let sizes = ScalarBuffer::::from(vec![3_i32, 0, 1]); + let nulls = NullBuffer::from(vec![true, false, true]); + let input = ListViewArrayAlias::try_new( + Arc::clone(&element_field), + offsets, + sizes, + Arc::new(values), + Some(nulls), + ) + .expect("ListView construction"); + + // Fast path: max_len larger than any size (including the 0-size one) + let out_fast = truncate_list_view_array(&input, 100).expect("fast path"); + assert_eq!(out_fast.len(), 3); + assert!(out_fast.is_null(1)); + assert_eq!(out_fast.value_sizes(), &[3, 0, 1]); // unchanged layout + + // max_len=0 path + let out_zero = truncate_list_view_array(&input, 0).expect("zero len view"); + assert_eq!(out_zero.len(), 3); + assert!(out_zero.is_null(1)); + for i in [0, 2] { + assert!(!out_zero.is_null(i)); + assert_eq!(out_zero.value(i).len(), 0); + } + } + + /// Using the public `truncate_numeric_column_length` API with a `RecordBatch` + /// that contains a mix of list columns that do and do not need truncation. + /// When the fast path triggers for some columns, the original column Arc + /// should be reused (identity) for those that didn't change. 
+ #[test] + fn test_truncate_numeric_column_length_fast_path_and_mixed() { + use arrow::array::{Int32Array, ListArray as ListArrayAlias}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::Field as ArrowField; + + use crate::record_batch::truncate_numeric_column_length; + + let schema = Arc::new(Schema::new(vec![ + ArrowField::new( + "short_lists", + DataType::List(Arc::new(Field::new("v", DataType::Int32, true))), + true, + ), + ArrowField::new( + "long_lists", + DataType::List(Arc::new(Field::new("v", DataType::Int32, true))), + true, + ), + ])); + + let short = ListArrayAlias::new( + Arc::new(Field::new("v", DataType::Int32, true)), + OffsetBuffer::::new(vec![0_i32, 1, 2].into()), + Arc::new(Int32Array::from(vec![1, 2])), + None, + ); + let long = ListArrayAlias::new( + Arc::new(Field::new("v", DataType::Int32, true)), + OffsetBuffer::::new(vec![0_i32, 5, 10].into()), + Arc::new(Int32Array::from((0..10).collect::>())), + None, + ); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(short), Arc::new(long)]) + .expect("batch"); + + // max_elements=3 : short_lists hits fast path (no change), long_lists truncated + let processed = + truncate_numeric_column_length(&batch, 3).expect("truncate_numeric_column_length"); + + assert_eq!(processed.num_columns(), 2); + // First column should be the exact same Arc (fast path clone in practice + // returns the input, but RecordBatch construction may not dedup Arcs; + // we at least verify logical identity of data). 
+ let short_in = batch.column(0); + let short_out = processed.column(0); + assert_eq!(short_in.data_type(), short_out.data_type()); + assert_eq!(short_in.len(), short_out.len()); + // Second column must have been truncated (lengths now <=3) + let long_out = processed + .column(1) + .as_any() + .downcast_ref::() + .expect("ListArray"); + for i in 0..long_out.len() { + assert!(long_out.value(i).len() <= 3); + } + } + + /// `max_len = 0` must work correctly for `ListView` (non-contiguous layout). + /// This exercises the `ListView` truncation path with the most aggressive + /// truncation limit, ensuring parent nulls and element Field are preserved + /// even when every list is reduced to zero elements. + #[test] + fn test_truncate_list_view_to_zero_elements() { + use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias}; + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + let element_field = Arc::new( + Field::new("value", DataType::Int32, false) + .with_metadata([("audit".to_string(), "zero_len_view".to_string())].into()), + ); + + // ListView with mixed lengths + one parent NULL (3 lists) + let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let offsets = ScalarBuffer::::from(vec![0_i32, 3, 3]); // 3 lists + let sizes = ScalarBuffer::::from(vec![3_i32, 0, 3]); + let nulls = NullBuffer::from(vec![true, false, true]); // 3 entries + let input = ListViewArrayAlias::try_new( + Arc::clone(&element_field), + offsets, + sizes, + Arc::new(values), + Some(nulls), + ) + .expect("ListView construction"); + + let out = truncate_list_view_array(&input, 0).expect("truncate ListView to 0"); + + assert_eq!(out.len(), 3); + assert!(out.is_null(1)); + assert!(!out.is_null(0)); + assert!(!out.is_null(2)); + assert_eq!(out.value(0).len(), 0); + assert_eq!(out.value(2).len(), 0); + + // Element field preserved + if let DataType::ListView(f) = out.data_type() { + assert_eq!(f.name(), "value"); + assert!(!f.is_nullable()); + assert_eq!( + f.metadata().get("audit"), + 
Some(&"zero_len_view".to_string()) + ); + } else { + panic!("expected ListView"); + } + } } diff --git a/crates/cayenne/benches/mutation_writer.rs b/crates/cayenne/benches/mutation_writer.rs index 8c84e77dfe..5b12a3d8e2 100644 --- a/crates/cayenne/benches/mutation_writer.rs +++ b/crates/cayenne/benches/mutation_writer.rs @@ -261,5 +261,82 @@ fn bench_inline_mutation_paths(c: &mut Criterion) { pressure.finish(); } -criterion_group!(benches, bench_append_roundtrip, bench_inline_mutation_paths); +/// Benchmarks the directory durability primitives added for ACID correctness +/// on local FS (parent-directory `sync_all` after `create_dir_all` for +/// snapshot directories, _partitioned_wal/, and deletions/ subdirs). +/// +/// These one-time-per-snapshot or per-table costs are the direct result of +/// the durability hardening. The benchmark quantifies the "tax" for Q21 +/// workloads that trigger frequent compactions or cross-partition operations. +fn bench_directory_durability_primitives(c: &mut Criterion) { + let rt = Runtime::new().expect("runtime"); + + let mut group = c.benchmark_group("directory_durability_sync_all"); + // These are one-time operations; a smaller sample size is sufficient + // to get stable numbers without making the bench too slow. + group.sample_size(30); + + group.bench_function("create_dir_all_plus_parent_sync", |b| { + b.iter_batched( + || { + let temp = tempfile::tempdir().expect("tempdir for bench"); + let parent = temp.path().to_path_buf(); + let child = parent.join("new_snapshot_or_wal_or_deletions_dir"); + (temp, parent, child) + }, + |(_keep_alive, parent, child)| { + rt.block_on(async { + // Replicate the exact hardened pattern used in + // ensure_snapshot_dir_exists, ensure_partitioned_wal_dir_and_sync_parent, + // and the deletions/ subdir creation in DeletionVectorWriter. 
+ if !child.exists() { + tokio::fs::create_dir_all(&child) + .await + .expect("create_dir_all"); + let p = parent.clone(); + let _ = tokio::task::spawn_blocking(move || { + std::fs::File::open(&p).and_then(|f| f.sync_all()) + }) + .await; + } + black_box(child); + }); + }, + BatchSize::SmallInput, + ); + }); + + // For comparison: the cost of create_dir_all *without* the parent sync. + // This shows the incremental cost of the durability guarantee. + group.bench_function("create_dir_all_without_sync", |b| { + b.iter_batched( + || { + let temp = tempfile::tempdir().expect("tempdir for bench"); + let parent = temp.path().to_path_buf(); + let child = parent.join("new_snapshot_without_sync"); + (temp, parent, child) + }, + |(_keep_alive, _parent, child)| { + rt.block_on(async { + if !child.exists() { + tokio::fs::create_dir_all(&child) + .await + .expect("create_dir_all"); + } + black_box(child); + }); + }, + BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_append_roundtrip, + bench_inline_mutation_paths, + bench_directory_durability_primitives +); criterion_main!(benches); diff --git a/crates/cayenne/src/cayenne_catalog.rs b/crates/cayenne/src/cayenne_catalog.rs index b6a5d9088d..3f24ecb235 100644 --- a/crates/cayenne/src/cayenne_catalog.rs +++ b/crates/cayenne/src/cayenne_catalog.rs @@ -379,6 +379,39 @@ impl MetadataCatalog for CayenneCatalog { if !db_dir.exists() { tokio::fs::create_dir_all(db_dir).await?; + + // Best-effort sync of the parent directory so the db_dir entry + // itself is durable on local FS before we proceed to create the + // catalog DB file and initialize its schema. + // + // We keep this best-effort (with warning on failure) rather than + // fatal because: + // - Catalog DB directory creation is a one-time initialization + // event (not a hot write path). + // - It is immediately followed by DB file creation and schema + // initialization, which provide strong content durability. 
+ // - The parent directory is frequently a stable, operator- + // managed volume root (e.g., K8s PersistentVolume) where + // directory entry durability is already handled at a higher + // level. + // + // This is still the right thing to do for consistency with the + // uniform durability contract used for all per-table mutable + // data paths, and it gives operators a clear warning if + // something unusual happens on a fresh deployment. + if let Some(parent) = db_dir.parent() { + let parent = parent.to_path_buf(); + if let Err(e) = tokio::task::spawn_blocking(move || { + std::fs::File::open(&parent).and_then(|f| f.sync_all()) + }) + .await + { + tracing::warn!( + "Failed to sync parent of catalog DB directory {} (subsequent DB writes will still be durable; directory entry may not survive crash): {e}", + db_dir.display() + ); + } + } } // Initialize schema using the appropriate metastore backend @@ -476,6 +509,35 @@ impl MetadataCatalog for CayenneCatalog { // Generate initial snapshot UUID let initial_snapshot_id = uuid::Uuid::now_v7().to_string(); + // Create the initial snapshot directory *before* inserting the table + // row into the metastore. This ensures the directory entry is durable + // (with parent sync of the table root) before the catalog "commits" + // the existence of a table pointing at this snapshot_id. This is the + // final piece of the uniform local-FS durability contract (snapshot + // dirs, _partitioned_wal/, deletions/, and now initial table creation). + // Matches the contract we enforce everywhere else in the write path. + if !base_path.starts_with("s3://") { + let table_root = std::path::PathBuf::from(&base_path).join(&table_id); + let snapshot_dir = table_root.join(&initial_snapshot_id); + + if !snapshot_dir.exists() { + tokio::fs::create_dir_all(&snapshot_dir) + .await + .map_err(|e| CatalogError::Io { source: e })?; + + // Sync the table root (parent of the new snapshot dir) so the + // subdir entry is durable on local FS. 
Best-effort on the sync + // itself (creation failure is already fatal above); this is + // the same pattern used for the first _partitioned_wal/ and + // first deletions/ subdirs. + let table_root_for_sync = table_root.clone(); + let _ = tokio::task::spawn_blocking(move || { + let _ = std::fs::File::open(&table_root_for_sync).and_then(|f| f.sync_all()); + }) + .await; + } + } + // Serialize Vortex config to JSON let vortex_config_json = serde_json::to_string(&options.vortex_config).map_err(|e| { CatalogError::InvalidOperation { @@ -534,18 +596,9 @@ impl MetadataCatalog for CayenneCatalog { Err(e) => return Err(e), } - // Create the initial snapshot directory (only for local paths) - // Directory structure: [base_path]/[table_id]/[snapshot_id]/ - // For S3 paths, directories are virtual and created when files are written - if !base_path.starts_with("s3://") { - let snapshot_dir = std::path::PathBuf::from(&base_path) - .join(&table_id) - .join(&initial_snapshot_id); - - tokio::fs::create_dir_all(&snapshot_dir) - .await - .map_err(|e| CatalogError::Io { source: e })?; - } + // The initial snapshot directory was already created (with parent + // sync) before the metastore INSERT, so the catalog row now points + // at a durable directory. Nothing more to do here for local FS. 
Ok(table_id) } @@ -1966,13 +2019,39 @@ async fn ensure_snapshot_directory_exists(table: &TableMetadata) -> CatalogResul return Ok(()); } - let snapshot_dir = std::path::PathBuf::from(&table.path) - .join(&table.table_id) - .join(&table.current_snapshot_id); + let table_root = std::path::PathBuf::from(&table.path).join(&table.table_id); + let snapshot_dir = table_root.join(&table.current_snapshot_id); + + match tokio::fs::metadata(&snapshot_dir).await { + Ok(metadata) if metadata.is_dir() => return Ok(()), + Ok(_) => { + return Err(CatalogError::Io { + source: std::io::Error::new( + std::io::ErrorKind::AlreadyExists, + format!( + "snapshot path '{}' exists but is not a directory", + snapshot_dir.display() + ), + ), + }); + } + Err(source) if source.kind() == std::io::ErrorKind::NotFound => {} + Err(source) => return Err(CatalogError::Io { source }), + } tokio::fs::create_dir_all(&snapshot_dir) .await - .map_err(|e| CatalogError::Io { source: e }) + .map_err(|source| CatalogError::Io { source })?; + + // Sync parent (table root) for the same durability reason as the + // initial creation path above and all other new subdir creations. + let table_root_for_sync = table_root; + let _ = tokio::task::spawn_blocking(move || { + let _ = std::fs::File::open(&table_root_for_sync).and_then(|f| f.sync_all()); + }) + .await; + + Ok(()) } /// Checks if the existing stored configuration matches the new [`CreateTableOptions`]. 
diff --git a/crates/cayenne/src/lib.rs b/crates/cayenne/src/lib.rs index 5280c70708..bad547db01 100644 --- a/crates/cayenne/src/lib.rs +++ b/crates/cayenne/src/lib.rs @@ -60,6 +60,7 @@ pub mod cayenne_catalog; pub mod ddl; #[cfg(feature = "partition-table-provider")] pub use ddl::CayenneDdlHandler; +pub mod logical_optimizer; pub mod metadata; pub mod metastore; pub mod optimizer_rules; diff --git a/crates/cayenne/src/logical_optimizer.rs b/crates/cayenne/src/logical_optimizer.rs new file mode 100644 index 0000000000..c8d522dd8d --- /dev/null +++ b/crates/cayenne/src/logical_optimizer.rs @@ -0,0 +1,1899 @@ +/* +Copyright 2026 The Spice.ai OSS Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Logical optimizer rules for Cayenne. +//! +//! The flagship rule here is [`CayennePropagateFilterAcrossEquiJoinKeys`], the +//! plan-time predicate transitive closure used to unblock chbench q21 (see +//! `crates/cayenne/src/optimizer_rules.rs` module docs for the broader +//! no-spill strategy this fits into). +//! +//! `DataFusion`'s stock `infer_join_predicates` (in `push_down_filter`) already +//! propagates predicates that *directly* reference a join-key column: +//! `WHERE nation.n_nationkey = 5 AND nation.n_nationkey = supplier.s_nationkey` +//! is transformed into `WHERE supplier.s_nationkey = 5 AND ...`. That covers +//! the `n_nationkey = $const` shape but misses the q21 shape, where the +//! selective filter is on a *non-key* column (`n_name = 'CHINA'`). The +//! 
cardinality bound the dim-table filter implies for the equi-joined key +//! column never reaches the fact-table scans, so by the time the planner +//! orders joins from the SQL `FROM` clause, `(supplier, order_line, …)` +//! has already been chosen with no nation filter pushed through. +//! +//! ## What the rule does +//! +//! For every `LogicalPlan::Join` with `JoinType::Inner`, `JoinType::LeftSemi`, +//! `JoinType::RightSemi`, `JoinType::Left`, or `JoinType::Right`, default SQL +//! NULL equality (`NULL != NULL`), and one or more equi-key pairs whose data +//! types match, the rule inspects each side for a non-trivial `Filter` that +//! references at least one column other than each candidate join key. If one +//! side is dim-like and has a projectable column key, it wraps the *opposite* +//! side with +//! +//! ```text +//! Filter(other_side.key IN (SELECT this_side.key FROM this_side_subtree)) +//! ``` +//! +//! The inserted subquery re-projects the join key through whatever filters +//! already exist on the original side, so `DataFusion`'s +//! `decorrelate_predicate_subquery` and `push_down_filter` can then plant a +//! `LeftSemi` join (or, after pushdown, a partition-pruning predicate) on +//! the fact-table scan. For q21 this turns +//! `nation ⋈ supplier ⋈ order_line` into a shape where `supplier.s_nationkey +//! IN (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA')` is visible +//! while the join graph is being costed. +//! +//! Semi-join coverage is what makes chained propagation work: after +//! `decorrelate_predicate_subquery` rewrites a propagated `InSubquery` into a +//! `LeftSemi` join, the next optimizer pass can keep propagating across +//! adjacent inner joins (e.g. `region → nation → supplier → fact`) instead of +//! halting at the semi-join boundary. Propagation correctness on +//! `LeftSemi`/`RightSemi` follows from the join's existing key-domain +//! semantics: wrapping either input with `IN (SELECT key FROM other_side)` +//! 
produces a subset of rows that the semi-join would already retain. +//! +//! For outer joins (`Left`, `Right`) the rule fires *only* in the +//! preserved-side → lookup-side direction. Filtering the lookup side narrows +//! matches the outer join would already drop (and substitute `NULL` for); +//! filtering the preserved side would silently delete rows the outer join is +//! supposed to emit as `NULL`-padded, which would change the output. +//! `FullOuter` is excluded — both sides are preserved, so neither direction is +//! safe. +//! +//! ## Termination +//! +//! Each introduced subquery is wrapped in a `SubqueryAlias` whose name +//! starts with [`PROPAGATED_FILTER_ALIAS_PREFIX`]. Before firing, the rule +//! walks the candidate side's filter chain and refuses to re-introduce a +//! propagated filter for the same target key. This prevents the rule from +//! oscillating with itself when the optimizer iterates to fixed point, while +//! still allowing composite joins to receive one derived filter per key. +//! +//! ## Conservatism +//! +//! The rule only fires when the side providing the filter is dim-like: a small +//! subtree with at most [`MAX_DIM_LIKE_TABLE_SCANS`] table scans behind +//! identity-preserving operators and inner joins. Joining a non-trivial subtree +//! would risk duplicate-executing a large plan inside the subquery, since +//! `DataFusion` does not currently de-duplicate plan-level common subexpressions +//! across the outer plan and an `InSubquery`. The dim-table-filter shape +//! (`Filter(n_name='CHINA') → TableScan(nation)`) and small dimension snowflakes +//! are cheap to re-execute. +//! +//! Two cardinality gates further suppress propagations that wouldn't pay off +//! at runtime, when the underlying [`TableSource`]s expose row counts via +//! `TableProvider::statistics`: +//! +//! * [`MIN_DIM_ROWS_FOR_PROPAGATION`] — skip when the dim subtree's known +//! upper-bound row count is below the threshold. Very small dims (≪ 1k +//! 
rows) already participate in fast hash builds; the extra `InSubquery → +//! LeftSemi` shape we'd introduce doesn't recover its own decorrelation / +//! planning cost. +//! * [`MIN_FACT_ROWS_FOR_PROPAGATION`] — skip when the receiving fact +//! subtree's known upper-bound row count is below the threshold. Below it +//! there isn't enough probe-side cardinality for the filter to save +//! meaningful work, and the plain hash join wins. +//! +//! Both gates only fire when stats are present (`Precision::Exact` or +//! `Precision::Inexact`); missing stats fall back to the structural behavior. + +use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion::common::{Column, DataFusionError, NullEquality, Result, Spans, TableReference}; +use datafusion::logical_expr::{ + Filter, Join, JoinType, LogicalPlan, Projection, Subquery, SubqueryAlias, +}; +use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use datafusion_expr::Expr; +use datafusion_expr::ExprSchemable; +use datafusion_expr::expr::InSubquery; +use std::{collections::BTreeSet, sync::Arc}; + +/// Prefix for [`SubqueryAlias`] names introduced by +/// [`CayennePropagateFilterAcrossEquiJoinKeys`]. +/// +/// Used both as a sentinel for key-scoped cycle detection (the rule refuses to +/// add another propagated filter for a target key that already has one) and as +/// a marker in explain output so the rewrite is recognizable when reading plans. +pub const PROPAGATED_FILTER_ALIAS_PREFIX: &str = "__cayenne_xclos__"; + +/// Logical optimizer rule that, for each `Inner`, `LeftSemi`, or `RightSemi` +/// join with default SQL NULL equality and a simple equi-key +/// `(left.a = right.b)`, introduces +/// `Filter(other_side.key IN (SELECT this_side.key FROM this_side_subtree))` +/// on the side opposite a non-key filter. +/// +/// See the module-level docs for the full design and the q21 motivation. 
+#[derive(Default)] +pub struct CayennePropagateFilterAcrossEquiJoinKeys; + +impl CayennePropagateFilterAcrossEquiJoinKeys { + /// Create a new instance of the rule. + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl std::fmt::Debug for CayennePropagateFilterAcrossEquiJoinKeys { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CayennePropagateFilterAcrossEquiJoinKeys") + .finish() + } +} + +impl OptimizerRule for CayennePropagateFilterAcrossEquiJoinKeys { + fn name(&self) -> &'static str { + "cayenne_propagate_filter_across_equi_join_keys" + } + + fn apply_order(&self) -> Option { + // TopDown: process outer joins first so the propagation seeds reach + // inner joins on the next pass. + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + let join = match plan { + LogicalPlan::Join(j) => j, + other => return Ok(Transformed::no(other)), + }; + if !matches!( + join.join_type, + JoinType::Inner + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::Left + | JoinType::Right, + ) { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + if join.null_equality != NullEquality::NullEqualsNothing { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + if matches!(join.join_type, JoinType::LeftSemi) + && right_side_carries_propagation_marker(&join.right) + { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + // For outer joins, propagation is only safe in the *preserved-side → + // lookup-side* direction. Filtering the lookup side can only narrow + // matches that the join would already drop; filtering the preserved + // side would drop output rows that the outer join would have emitted + // as `NULL`-padded. Inner and semi joins are unrestricted. 
+ let allow_left_to_right = matches!( + join.join_type, + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Left, + ); + let allow_right_to_left = matches!( + join.join_type, + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right, + ); + + let equijoin_keys = matching_equijoin_keys(&join); + if equijoin_keys.is_empty() { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + let mut left_analysis = analyze_logical_side(&join.left); + let mut right_analysis = analyze_logical_side(&join.right); + + let mut new_left: Arc = Arc::clone(&join.left); + let mut new_right: Arc = Arc::clone(&join.right); + let mut changed = false; + + for key in &equijoin_keys { + match key { + EquiKey::BothColumns { left, right } => { + // Propagate the LEFT-side filtered key domain → the RIGHT side. + if allow_left_to_right + && left_analysis.is_dim_like + && left_analysis.has_non_key_filter(&left.name) + && key_preserved_through_summaries(&join.left, left) + && !skip_propagation_by_cardinality(&join.left, &join.right) + && !right_analysis.has_propagated_filter_target(&column_expr(right)) + { + let subquery_plan = build_key_projection_subquery( + Arc::clone(&join.left), + left, + config.alias_generator(), + )?; + let target = column_expr(right); + let wrapped = wrap_with_in_subquery_filter_expr( + Arc::clone(&new_right), + &target, + subquery_plan, + )?; + new_right = Arc::new(wrapped); + right_analysis.add_propagated_filter_target(&target); + changed = true; + } + + // Propagate the RIGHT-side filtered key domain → the LEFT side. 
+ if allow_right_to_left + && right_analysis.is_dim_like + && right_analysis.has_non_key_filter(&right.name) + && key_preserved_through_summaries(&join.right, right) + && !skip_propagation_by_cardinality(&join.right, &join.left) + && !left_analysis.has_propagated_filter_target(&column_expr(left)) + { + let subquery_plan = build_key_projection_subquery( + Arc::clone(&join.right), + right, + config.alias_generator(), + )?; + let target = column_expr(left); + let wrapped = wrap_with_in_subquery_filter_expr( + Arc::clone(&new_left), + &target, + subquery_plan, + )?; + new_left = Arc::new(wrapped); + left_analysis.add_propagated_filter_target(&target); + changed = true; + } + } + EquiKey::LeftColumnRightExpr { + left_col, + right_expr, + } => { + // Only LEFT-dim → RIGHT-expr direction can fire: the right + // side has an expression key, so the fact-side filter + // target must be that expression. Propagation in the other + // direction would require projecting an expression + // (potentially referencing fact-side rows) inside the dim + // subquery, which would no longer be a cheap re-execution. + if allow_left_to_right + && left_analysis.is_dim_like + && left_analysis.has_non_key_filter(&left_col.name) + && key_preserved_through_summaries(&join.left, left_col) + && !skip_propagation_by_cardinality(&join.left, &join.right) + && !right_analysis.has_propagated_filter_target(right_expr) + { + let subquery_plan = build_key_projection_subquery( + Arc::clone(&join.left), + left_col, + config.alias_generator(), + )?; + let wrapped = wrap_with_in_subquery_filter_expr( + Arc::clone(&new_right), + right_expr, + subquery_plan, + )?; + new_right = Arc::new(wrapped); + right_analysis.add_propagated_filter_target(right_expr); + changed = true; + } + } + EquiKey::LeftExprRightColumn { + left_expr, + right_col, + } => { + // Symmetric: only RIGHT-dim → LEFT-expr direction. 
+ if allow_right_to_left + && right_analysis.is_dim_like + && right_analysis.has_non_key_filter(&right_col.name) + && key_preserved_through_summaries(&join.right, right_col) + && !skip_propagation_by_cardinality(&join.right, &join.left) + && !left_analysis.has_propagated_filter_target(left_expr) + { + let subquery_plan = build_key_projection_subquery( + Arc::clone(&join.right), + right_col, + config.alias_generator(), + )?; + let wrapped = wrap_with_in_subquery_filter_expr( + Arc::clone(&new_left), + left_expr, + subquery_plan, + )?; + new_left = Arc::new(wrapped); + left_analysis.add_propagated_filter_target(left_expr); + changed = true; + } + } + } + } + + if !changed { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + let new_join = Join::try_new( + new_left, + new_right, + join.on, + join.filter, + join.join_type, + join.join_constraint, + join.null_equality, + )?; + + Ok(Transformed::yes(LogicalPlan::Join(new_join))) + } +} + +#[derive(Default)] +struct SideAnalysis { + is_dim_like: bool, + filter_columns: BTreeSet, + /// Targets of already-propagated `InSubquery` filters on this side, keyed + /// by the `Display` form of the target expression. Used for cycle + /// prevention — the same target should not be wrapped twice. Tracks both + /// pure-column and expression targets uniformly, so the chbench + /// `ascii(substr(c_state,1,1)) - 65` shape is also cycle-guarded. 
+ propagated_filter_targets: BTreeSet, +} + +impl SideAnalysis { + fn has_non_key_filter(&self, key_name: &str) -> bool { + self.filter_columns.iter().any(|column| column != key_name) + } + + fn has_propagated_filter_target(&self, target: &Expr) -> bool { + self.propagated_filter_targets.contains(&target.to_string()) + } + + fn add_propagated_filter_target(&mut self, target: &Expr) { + self.propagated_filter_targets.insert(target.to_string()); + } +} + +fn column_expr(column: &Column) -> Expr { + Expr::Column(column.clone()) +} + +fn analyze_logical_side(plan: &LogicalPlan) -> SideAnalysis { + let mut analysis = SideAnalysis { + is_dim_like: is_dim_like_subtree(plan), + ..SideAnalysis::default() + }; + + let _ = plan.apply(|node| { + if let LogicalPlan::Filter(filter) = node { + collect_filter_column_names(&filter.predicate, &mut analysis.filter_columns); + collect_propagated_filter_targets( + &filter.predicate, + &mut analysis.propagated_filter_targets, + ); + } + // Post-decorrelation cycle detection: `decorrelate_predicate_subquery` + // rewrites our propagated `InSubquery` into a `LeftSemi` join with the + // marker `SubqueryAlias` as its right child. Without this branch the + // rule's cycle guard misses the marker (it only walked Filter + // predicates), and the optimizer would re-propagate on every iteration + // until hitting `max_passes`, stacking redundant `LeftSemi` layers. + if let LogicalPlan::Join(join) = node + && matches!(join.join_type, JoinType::LeftSemi) + && right_side_carries_propagation_marker(&join.right) + { + for (left_expr, _) in &join.on { + analysis + .propagated_filter_targets + .insert(left_expr.to_string()); + } + } + + Ok(TreeNodeRecursion::Continue) + }); + + analysis +} + +/// Returns `true` if `plan` is — possibly behind a chain of `Projection` or +/// `SubqueryAlias` wrappers added by later optimizer rules — a `SubqueryAlias` +/// whose name starts with [`PROPAGATED_FILTER_ALIAS_PREFIX`]. 
+fn right_side_carries_propagation_marker(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::SubqueryAlias(alias) => { + if alias + .alias + .table() + .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX) + { + return true; + } + right_side_carries_propagation_marker(&alias.input) + } + LogicalPlan::Projection(p) => right_side_carries_propagation_marker(&p.input), + _ => false, + } +} + +/// An equi-join key from `Join::on`, classified by which sides are pure +/// columns. Propagation requires the *dim* side to be a `Column` so the IN +/// subquery has a cheap, projectable key; the *fact* side may be an arbitrary +/// expression (e.g. the chbench `ascii(substr(c_state,1,1)) - 65` pattern). +enum EquiKey { + /// Both join keys are columns. The rule may fire in either direction + /// depending on which side is dim-like. + BothColumns { left: Column, right: Column }, + /// Left key is a column, right key is an expression. Only the + /// `LEFT → RIGHT` propagation direction is supported. + LeftColumnRightExpr { left_col: Column, right_expr: Expr }, + /// Right key is a column, left key is an expression. Only the + /// `RIGHT → LEFT` propagation direction is supported. + LeftExprRightColumn { left_expr: Expr, right_col: Column }, +} + +/// Return the equi-join keys from `join.on` whose data types match. Drops +/// pairs where both sides are expressions (no dim-like column to project) and +/// pairs whose types differ (the `IN` subquery would need an implicit cast we +/// don't insert here). 
+fn matching_equijoin_keys(join: &Join) -> Vec { + join.on + .iter() + .filter_map(|(left, right)| { + if !join_key_types_match(left, right, &join.left, &join.right) { + return None; + } + + match (left, right) { + (Expr::Column(l), Expr::Column(r)) => Some(EquiKey::BothColumns { + left: l.clone(), + right: r.clone(), + }), + (Expr::Column(l), other) => Some(EquiKey::LeftColumnRightExpr { + left_col: l.clone(), + right_expr: other.clone(), + }), + (other, Expr::Column(r)) => Some(EquiKey::LeftExprRightColumn { + left_expr: other.clone(), + right_col: r.clone(), + }), + // Both sides are non-trivial expressions — no cheap projection + // target on either side, skip. + _ => None, + } + }) + .collect() +} + +fn join_key_types_match( + left: &Expr, + right: &Expr, + left_plan: &LogicalPlan, + right_plan: &LogicalPlan, +) -> bool { + let Ok(left_type) = left.get_type(left_plan.schema()) else { + return false; + }; + let Ok(right_type) = right.get_type(right_plan.schema()) else { + return false; + }; + + left_type == right_type +} + +/// Maximum number of `TableScan` leaves allowed inside a dim-like subtree. +/// +/// Chosen to cover the canonical chbench / TPC-H dimension snowflake +/// (`region ⋈ nation ⋈ supplier`, three leaves) without admitting arbitrarily +/// large dim joins whose re-execution under an `InSubquery` would be expensive. +const MAX_DIM_LIKE_TABLE_SCANS: usize = 3; + +/// Skip propagation when the dim subtree's known upper-bound row count is +/// below this threshold. Below it the dim is already small enough that the +/// stock hash build is fast, and the `InSubquery → LeftSemi` decorrelation + +/// planning cost outweighs the saved probe work. +const MIN_DIM_ROWS_FOR_PROPAGATION: usize = 1_000; + +/// Skip propagation when the receiving fact subtree's known upper-bound row +/// count is below this threshold. Below it there isn't enough probe +/// cardinality for the filter to recoup the propagation overhead. 
+const MIN_FACT_ROWS_FOR_PROPAGATION: usize = 100_000; + +/// Returns `true` if `plan` is a "dim-like" subtree — a small snowflake of +/// dimensions composed of at most [`MAX_DIM_LIKE_TABLE_SCANS`] `TableScan`s +/// connected through identity-preserving operators (`Projection`, +/// `SubqueryAlias`, `Filter`, `Limit`), inner equi-joins with default SQL +/// NULL equality, `Aggregate`, or `Distinct`. +/// +/// The conservatism here keeps the duplicated subquery cheap: `DataFusion` will +/// execute it independently of the outer join, so we only fire on subtrees +/// where re-running the scan(s) + filter(s) is cheap. Unions, windows, sorts, +/// and any non-inner / null-equal join terminate the walk. +/// +/// `Aggregate` and `Distinct` are *structurally* allowed here, but the rule's +/// caller must additionally verify the join key is preserved through any +/// aggregations via [`key_preserved_through_summaries`] — an aggregate that +/// does not group by the key does not preserve its domain and cannot be the +/// source of a propagated subquery on that key. 
+fn is_dim_like_subtree(plan: &LogicalPlan) -> bool { + count_dim_like_table_scans(plan).is_some_and(|n| n <= MAX_DIM_LIKE_TABLE_SCANS) +} + +fn count_dim_like_table_scans(plan: &LogicalPlan) -> Option { + match plan { + LogicalPlan::TableScan(_) => Some(1), + LogicalPlan::Projection(p) => count_dim_like_table_scans(&p.input), + LogicalPlan::SubqueryAlias(a) => count_dim_like_table_scans(&a.input), + LogicalPlan::Filter(f) => count_dim_like_table_scans(&f.input), + LogicalPlan::Limit(l) => count_dim_like_table_scans(&l.input), + LogicalPlan::Aggregate(a) => count_dim_like_table_scans(&a.input), + LogicalPlan::Distinct(d) => count_dim_like_table_scans(distinct_input(d)), + LogicalPlan::Join(j) + if j.join_type == JoinType::Inner + && j.null_equality == NullEquality::NullEqualsNothing => + { + let l = count_dim_like_table_scans(&j.left)?; + let r = count_dim_like_table_scans(&j.right)?; + Some(l + r) + } + _ => None, + } +} + +/// Returns the single input plan of a `Distinct` regardless of variant. +fn distinct_input(distinct: &datafusion::logical_expr::Distinct) -> &LogicalPlan { + use datafusion::logical_expr::Distinct; + match distinct { + Distinct::All(input) => input, + Distinct::On(on) => &on.input, + } +} + +/// Sum of known upper-bound row counts of all `TableScan`s reachable from +/// `plan`. Returns `None` if any reachable `TableScan` is missing stats — the +/// caller falls back to the structural gates in that case. +/// +/// The walk follows every `LogicalPlan` child (not just dim-like wrappers) so +/// fact-side subtrees with joins, aggregates, etc. are summed too. The result +/// is a loose *upper bound* — filter selectivity isn't accounted for — which +/// is the right direction for the "skip if known small" gate (a true upper +/// bound below the threshold guarantees the subtree is actually small). 
+fn subtree_upper_bound_rows(plan: &LogicalPlan) -> Option { + use datafusion::common::stats::Precision; + use datafusion::datasource::DefaultTableSource; + + let mut total: usize = 0; + let mut any_unknown = false; + let _ = plan.apply(|node| { + if let LogicalPlan::TableScan(scan) = node { + let rows = scan + .source + .as_any() + .downcast_ref::() + .and_then(|default| default.table_provider.statistics()) + .and_then(|stats| match stats.num_rows { + Precision::Exact(n) | Precision::Inexact(n) => Some(n), + Precision::Absent => None, + }); + if let Some(n) = rows { + total = total.saturating_add(n); + } else { + any_unknown = true; + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }); + if any_unknown { None } else { Some(total) } +} + +/// `true` when stats prove the dim side is below +/// [`MIN_DIM_ROWS_FOR_PROPAGATION`] *or* the fact side is below +/// [`MIN_FACT_ROWS_FOR_PROPAGATION`]. Missing stats on either side fall back +/// to the structural gates: this function returns `false` (allow propagation), +/// matching the rule's behavior before the cardinality gates were added. +fn skip_propagation_by_cardinality(dim_side: &LogicalPlan, fact_side: &LogicalPlan) -> bool { + if matches!( + subtree_upper_bound_rows(dim_side), + Some(n) if n < MIN_DIM_ROWS_FOR_PROPAGATION + ) { + return true; + } + if matches!( + subtree_upper_bound_rows(fact_side), + Some(n) if n < MIN_FACT_ROWS_FOR_PROPAGATION + ) { + return true; + } + false +} + +/// Returns `true` if `key` retains its scan-level domain through every +/// `Aggregate` / `Distinct` reachable in `plan`. +/// +/// * `Aggregate` preserves a column's domain only when it appears in +/// `group_expr` as a plain `Expr::Column` reference. +/// * `Distinct::All` preserves every projected column (deduplication keeps +/// value identity). +/// * `Distinct::On(distinct_on)` preserves only the columns named in its `on` +/// list; for safety we conservatively require the key to appear there. 
+/// +/// The walk follows only identity-preserving operators plus inner equi-joins — +/// the same shape `is_dim_like_subtree` accepts. Anything outside that vocab +/// (`Sort`, `Window`, etc.) is conservatively rejected by returning `false`. +fn key_preserved_through_summaries(plan: &LogicalPlan, key: &Column) -> bool { + fn key_for_input_schema(input: &LogicalPlan, key: &Column) -> Option { + input + .schema() + .qualified_field_with_unqualified_name(&key.name) + .ok() + .map(|(qualifier, field)| Column::new(qualifier.cloned(), field.name().clone())) + } + + fn walk(plan: &LogicalPlan, key: &Column) -> bool { + match plan { + LogicalPlan::TableScan(_) => plan.schema().has_column(key), + LogicalPlan::Projection(p) => plan.schema().has_column(key) && walk(&p.input, key), + LogicalPlan::SubqueryAlias(a) => { + let relation_matches_alias = match key.relation.as_ref() { + Some(relation) => relation == &a.alias, + None => true, + }; + relation_matches_alias + && key_for_input_schema(&a.input, key) + .is_some_and(|input_key| walk(&a.input, &input_key)) + } + LogicalPlan::Filter(f) => plan.schema().has_column(key) && walk(&f.input, key), + LogicalPlan::Limit(l) => plan.schema().has_column(key) && walk(&l.input, key), + LogicalPlan::Aggregate(a) => { + let key_in_group = a + .group_expr + .iter() + .any(|expr| matches!(expr, Expr::Column(column) if column == key)); + key_in_group && plan.schema().has_column(key) && walk(&a.input, key) + } + LogicalPlan::Distinct(distinct) => { + use datafusion::logical_expr::Distinct; + let key_kept = match distinct { + Distinct::All(_) => true, + Distinct::On(on) => on + .on_expr + .iter() + .any(|expr| matches!(expr, Expr::Column(column) if column == key)), + }; + key_kept && plan.schema().has_column(key) && walk(distinct_input(distinct), key) + } + LogicalPlan::Join(j) + if j.join_type == JoinType::Inner + && j.null_equality == NullEquality::NullEqualsNothing => + { + plan.schema().has_column(key) && (walk(&j.left, key) || 
walk(&j.right, key)) + } + _ => false, + } + } + + walk(plan, key) +} + +fn collect_filter_column_names(expr: &Expr, columns: &mut BTreeSet) { + let _ = expr.apply(|e| { + if let Expr::Column(column) = e { + columns.insert(column.name.clone()); + } + + Ok(TreeNodeRecursion::Continue) + }); +} + +fn collect_propagated_filter_targets(expr: &Expr, targets: &mut BTreeSet) { + let _ = expr.apply(|e| { + if let Expr::InSubquery(InSubquery { + expr: target_expr, + subquery, + .. + }) = e + && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref() + && alias + .alias + .table() + .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX) + { + targets.insert(target_expr.to_string()); + return Ok(TreeNodeRecursion::Jump); + } + + Ok(TreeNodeRecursion::Continue) + }); +} + +/// Returns `true` if `plan` already contains a propagated-filter marker. +#[cfg(test)] +fn subtree_has_propagated_filter(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply(|node| { + if let LogicalPlan::SubqueryAlias(alias) = node + && alias + .alias + .table() + .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX) + { + found = true; + return Ok(TreeNodeRecursion::Stop); + } + if let LogicalPlan::Filter(f) = node + && expr_has_propagated_filter(&f.predicate) + { + found = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + }); + found +} + +/// Returns `true` if `expr` contains an [`InSubquery`] whose inner plan +/// starts with a [`SubqueryAlias`] named with +/// [`PROPAGATED_FILTER_ALIAS_PREFIX`]. +#[must_use] +#[cfg(test)] +fn expr_has_propagated_filter(expr: &Expr) -> bool { + let mut found = false; + let _ = expr.apply(|e| { + if let Expr::InSubquery(InSubquery { subquery, .. 
}) = e + && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref() + && alias + .alias + .table() + .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX) + { + found = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + }); + found +} + +/// Build a `SubqueryAlias(__cayenne_xclos__N, Projection([key_col], subtree))` +/// suitable for use as the inner plan of a [`Subquery`] referenced by an +/// [`InSubquery`] expression. +/// +/// The alias name uses [`PROPAGATED_FILTER_ALIAS_PREFIX`] plus a unique id +/// from [`OptimizerConfig::alias_generator`], so each invocation produces a +/// distinct marker. The marker doubles as the cycle-detection sentinel +/// scanned by [`analyze_logical_side`]. +fn build_key_projection_subquery( + subtree: Arc, + key_col: &Column, + alias_gen: &Arc, +) -> Result { + let key_expr = Expr::Column(key_col.clone()); + let projection = LogicalPlan::Projection(Projection::try_new(vec![key_expr], subtree)?); + let alias_name = alias_gen.next(PROPAGATED_FILTER_ALIAS_PREFIX); + let aliased = SubqueryAlias::try_new(Arc::new(projection), TableReference::bare(alias_name))?; + Ok(LogicalPlan::SubqueryAlias(aliased)) +} + +/// Wrap `input` with `Filter(target IN (subquery))` using the `subquery_plan` +/// (which must already be a `SubqueryAlias` named with +/// [`PROPAGATED_FILTER_ALIAS_PREFIX`]) as the right-hand side. `target` may be +/// a column or any expression whose columns all resolve in `input`'s schema — +/// the chbench `ascii(substr(c_state,1,1)) - 65` shape is supported through +/// this entry point. 
+fn wrap_with_in_subquery_filter_expr( + input: Arc, + target: &Expr, + subquery_plan: LogicalPlan, +) -> Result { + let predicate = Expr::InSubquery(InSubquery::new( + Box::new(target.clone()), + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::default(), + }, + false, + )); + let filter = Filter::try_new(predicate, input)?; + Ok(LogicalPlan::Filter(filter)) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::catalog::MemTable; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + fn rule() -> CayennePropagateFilterAcrossEquiJoinKeys { + CayennePropagateFilterAcrossEquiJoinKeys::new() + } + + fn make_ctx() -> Result { + let ctx = SessionContext::new(); + // dim-like nation table — gains an `n_regionkey` so the multi-hop + // `region ⋈ nation` propagation tests can join through it. + let nation_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + Field::new("n_regionkey", DataType::Int64, false), + ])); + // dim-like region table for multi-hop tests. + let region_schema = Arc::new(Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, true), + ])); + // fact-like supplier table + let supplier_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + // fact-like customer table for expression-equi-key tests + // (chbench `ascii(substr(c_state, 1, 1)) - 65` nation mapping). 
+ let customer_schema = Arc::new(Schema::new(vec![ + Field::new("c_id", DataType::Int64, false), + Field::new("c_state", DataType::Utf8, true), + ])); + ctx.register_table( + "nation", + Arc::new(MemTable::try_new(Arc::clone(&nation_schema), vec![vec![]])?), + )?; + ctx.register_table( + "region", + Arc::new(MemTable::try_new(Arc::clone(®ion_schema), vec![vec![]])?), + )?; + ctx.register_table( + "supplier", + Arc::new(MemTable::try_new( + Arc::clone(&supplier_schema), + vec![vec![]], + )?), + )?; + ctx.register_table( + "customer", + Arc::new(MemTable::try_new( + Arc::clone(&customer_schema), + vec![vec![]], + )?), + )?; + Ok(ctx) + } + + /// Walk a `LogicalPlan` to find the first `Join` and return whichever + /// side's plan tree contains a `SubqueryAlias` whose name starts with + /// [`PROPAGATED_FILTER_ALIAS_PREFIX`]. + fn find_propagated_side(plan: &LogicalPlan) -> Option<&'static str> { + let mut result: Option<&'static str> = None; + let _ = plan.apply(|node| { + if let LogicalPlan::Join(j) = node { + if subtree_has_propagated_filter(j.left.as_ref()) { + result = Some("left"); + return Ok(TreeNodeRecursion::Stop); + } + if subtree_has_propagated_filter(j.right.as_ref()) { + result = Some("right"); + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }); + result + } + + fn count_propagated_filter_exprs(plan: &LogicalPlan) -> usize { + let mut count = 0; + let _ = plan.apply(|node| { + if let LogicalPlan::Filter(f) = node { + let _ = f.predicate.apply(|expr| { + if let Expr::InSubquery(InSubquery { subquery, .. 
}) = expr + && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref() + && alias + .alias + .table() + .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX) + { + count += 1; + } + Ok(TreeNodeRecursion::Continue) + }); + } + Ok(TreeNodeRecursion::Continue) + }); + count + } + + #[test] + fn rule_metadata() { + assert_eq!( + rule().name(), + "cayenne_propagate_filter_across_equi_join_keys" + ); + assert_eq!(rule().apply_order(), Some(ApplyOrder::TopDown)); + } + + #[tokio::test] + async fn non_inner_join_is_unchanged() -> Result<()> { + // Use `IS NULL` on the right side so `eliminate_outer_join` doesn't + // promote the LEFT JOIN to an INNER JOIN, otherwise we'd be testing + // the wrong thing. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier LEFT JOIN nation \ + ON s_nationkey = n_nationkey WHERE n_name IS NULL", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + !changed, + "LEFT JOIN must be skipped by the rule; plan was:\n{plan}" + ); + Ok(()) + } + + /// Run the rule against every `LogicalPlan::Join` reachable from `plan`, + /// returning the transformed plan and a flag indicating whether at least + /// one invocation made a change. + /// + /// Mirrors what `DataFusion`'s optimizer driver does for an + /// `ApplyOrder::TopDown` rule, but without spinning up the rest of the + /// rule pipeline — keeps the tests focused on this rule's behavior in + /// isolation. 
+ fn apply_rule_to_all_joins( + rule: &CayennePropagateFilterAcrossEquiJoinKeys, + plan: LogicalPlan, + cfg: &datafusion::optimizer::OptimizerContext, + ) -> Result<(LogicalPlan, bool)> { + let mut any_changed = false; + let transformed = plan.transform_down(|node| { + if matches!(node, LogicalPlan::Join(_)) { + let r = rule.rewrite(node, cfg)?; + if r.transformed { + any_changed = true; + } + Ok(r) + } else { + Ok(Transformed::no(node)) + } + })?; + Ok((transformed.data, any_changed)) + } + + #[tokio::test] + async fn inner_join_with_dim_filter_propagates_via_subquery() -> Result<()> { + // The canonical q21 shape (reduced): + // FROM supplier, nation + // WHERE s_nationkey = n_nationkey AND n_name = 'CHINA' + // + // After PushDownFilter, `n_name = 'CHINA'` lives in a Filter directly + // above the nation TableScan. The rule should then wrap supplier with + // `Filter(s_nationkey IN (SELECT n_nationkey FROM nation + // WHERE n_name = 'CHINA'))`. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier, nation \ + WHERE s_nationkey = n_nationkey AND n_name = 'CHINA'", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + + // Depending on `DataFusion`'s planner the join's `left`/`right` may be + // either order. We don't care which side gets the InSubquery, only + // that exactly one of them does, and that it carries the marker. + let propagated = find_propagated_side(&transformed_plan); + assert!( + changed, + "rule should fire on inner join with dim-side non-key filter; plan was:\n{plan}" + ); + assert!( + propagated.is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + + // Cycle prevention: running the rule a second time on the + // already-transformed plan must be a no-op. 
+ let (second_plan, changed2) = apply_rule_to_all_joins(&r, transformed_plan.clone(), &cfg)?; + assert!( + !changed2, + "second pass must not re-propagate (cycle guard); plan was:\n{second_plan}" + ); + + Ok(()) + } + + #[tokio::test] + async fn left_semi_join_with_dim_filter_propagates_via_subquery() -> Result<()> { + // The `IN (subquery)` shape that `decorrelate_predicate_subquery` + // rewrites into a `LeftSemi` join. The propagation rule must still + // fire on the resulting semi-join so the dim filter reaches the fact + // side across chained joins. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier \ + WHERE s_nationkey IN \ + (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA')", + ) + .await? + .into_optimized_plan()?; + + // Sanity-check that decorrelation produced a semi-join shape; if it + // didn't, this test is testing the wrong thing. + let mut semi_seen = false; + let _ = plan.apply(|node| { + if let LogicalPlan::Join(j) = node + && matches!(j.join_type, JoinType::LeftSemi | JoinType::RightSemi) + { + semi_seen = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + }); + assert!( + semi_seen, + "expected decorrelation to produce a semi-join; plan was:\n{plan}" + ); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + changed, + "rule should fire on semi-join with dim-side non-key filter; plan was:\n{plan}" + ); + assert!( + find_propagated_side(&transformed_plan).is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + Ok(()) + } + + #[tokio::test] + async fn left_outer_join_propagates_only_left_to_right() -> Result<()> { + // `supplier LEFT JOIN nation ON s_nationkey = n_nationkey WHERE + // s_name = 'X'`. The LEFT side (supplier) has a non-key filter; it is + // the preserved side. 
Propagating to the lookup side (nation) is safe. + // + // Note: `eliminate_outer_join` will rewrite the LEFT JOIN to an INNER + // JOIN only if the WHERE clause forces the right side to be non-null + // — using a filter on the LEFT side instead preserves the outer + // semantics, which is what we want for this test. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier LEFT JOIN nation \ + ON s_nationkey = n_nationkey WHERE s_suppkey > 5", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + // The supplier-side filter (`s_suppkey > 5`) is a non-key filter on + // the LEFT/preserved side. Direction LEFT→RIGHT is allowed; the rule + // should propagate `n_nationkey IN (SELECT s_nationkey FROM filtered_supplier)` + // onto nation. + assert!( + changed, + "rule should fire LEFT→RIGHT for LEFT OUTER; plan was:\n{plan}" + ); + assert!( + find_propagated_side(&transformed_plan).is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + Ok(()) + } + + #[tokio::test] + async fn left_outer_join_blocks_right_to_left_propagation() -> Result<()> { + // Filter on the RIGHT (lookup) side of a LEFT OUTER must NOT cause + // propagation onto the LEFT (preserved) side: doing so would drop + // left rows the outer join should emit as `(left, NULL...)`. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier LEFT JOIN nation \ + ON s_nationkey = n_nationkey \ + WHERE n_name = 'CHINA' OR n_name IS NULL", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + // The filter is on the RIGHT side. RIGHT→LEFT propagation is forbidden + // for LEFT OUTER. 
LEFT→RIGHT is allowed but there's no LEFT-side filter + // to propagate. So the rule must be a no-op here. + assert!( + !changed, + "RIGHT→LEFT propagation must not fire on LEFT OUTER; plan was:\n{plan}" + ); + Ok(()) + } + + #[tokio::test] + async fn rule_does_not_re_fire_on_post_decorrelation_left_semi() -> Result<()> { + // Regression test for the cycle-detection bug across optimizer + // iterations: after Pass 1 wraps the receiving side with an + // `InSubquery`, `decorrelate_predicate_subquery` rewrites that into a + // `LeftSemi` join with the marker `SubqueryAlias` as its right child. + // If the rule's cycle detection only sees `InSubquery` markers (and + // not the structural `LeftSemi`-with-marker shape), Pass 2 sees no + // marker on the receiving side and re-propagates, producing nested + // LeftSemi joins on every subsequent optimizer pass. + // + // The fix detects the post-decorrelation shape and records the + // already-propagated target so the rule's cycle guard short-circuits + // on subsequent passes. + use datafusion::common::NullEquality; + use datafusion::logical_expr::JoinConstraint; + use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit}; + + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + ])); + let fact_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + + // Build the dim subquery: `Filter(n_name='CHINA') → TableScan(nation)` + // wrapped in the propagated-filter alias the rule would have produced. 
+ let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?; + let nation_filter = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")), + Arc::new(nation_scan), + )?); + let nation_projection = LogicalPlan::Projection(Projection::try_new( + vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))], + Arc::new(nation_filter), + )?); + let dim_subquery_alias = format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1"); + let dim_subquery = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(nation_projection), + TableReference::bare(dim_subquery_alias), + )?); + + // Build supplier scan (the receiving fact side). + let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?; + + // Compose the post-decorrelation shape: `LeftSemi(supplier, dim_subquery)` + // on `s_nationkey = n_nationkey`. + let semi_join_input = LogicalPlanBuilder::from(supplier_scan) + .join_with_expr_keys( + dim_subquery, + JoinType::LeftSemi, + ( + vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))], + vec![Expr::Column(Column::new( + Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")), + "n_nationkey", + ))], + ), + None, + )? + .build()?; + + // Now build an outer `Inner Join` between the *original* nation_filtered + // and this `LeftSemi` subtree on the same equi-key — the exact shape an + // optimizer pass would see after the rule already fired + decorrelated. 
+ let dim_filter_again_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?; + let dim_filter_again = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")), + Arc::new(dim_filter_again_scan), + )?); + + let outer_join = LogicalPlan::Join(Join::try_new( + Arc::new(dim_filter_again), + Arc::new(semi_join_input), + vec![( + Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")), + Expr::Column(Column::new(Some("supplier"), "s_nationkey")), + )], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?; + assert!( + !changed, + "rule must not re-fire when the receiving side already contains a \ + post-decorrelation LeftSemi propagation marker; plan was:\n{outer_join}" + ); + Ok(()) + } + + #[tokio::test] + async fn rule_re_fires_when_receiving_side_has_non_marker_subquery_alias() -> Result<()> { + // Devil's-advocate edge case: a `LeftSemi` whose right side is a + // `SubqueryAlias` with a *non-marker* name should NOT block + // propagation (the marker prefix is the unique signal that this rule + // already fired). Guards against the cycle guard being too aggressive. 
+ use datafusion::common::NullEquality; + use datafusion::logical_expr::JoinConstraint; + use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit}; + + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + ])); + let fact_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + + let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?; + let nation_filter = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")), + Arc::new(nation_scan), + )?); + let nation_projection = LogicalPlan::Projection(Projection::try_new( + vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))], + Arc::new(nation_filter), + )?); + let user_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(nation_projection), + TableReference::bare("some_user_alias"), + )?); + + let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?; + let semi_join_input = LogicalPlanBuilder::from(supplier_scan) + .join_with_expr_keys( + user_alias, + JoinType::LeftSemi, + ( + vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))], + vec![Expr::Column(Column::new( + Some("some_user_alias".to_string()), + "n_nationkey", + ))], + ), + None, + )? 
+ .build()?; + + let outer_dim_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?; + let outer_dim_filter = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")), + Arc::new(outer_dim_scan), + )?); + + let outer_join = LogicalPlan::Join(Join::try_new( + Arc::new(outer_dim_filter), + Arc::new(semi_join_input), + vec![( + Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")), + Expr::Column(Column::new(Some("supplier"), "s_nationkey")), + )], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?; + assert!( + changed, + "rule should still fire when the receiving LeftSemi's alias is \ + user-supplied (not the propagation marker); plan was:\n{outer_join}" + ); + Ok(()) + } + + #[tokio::test] + async fn rule_detects_marker_through_projection_wrapper() -> Result<()> { + // Subsequent optimizer rules (`MergeProjection`, etc.) may wrap the + // marker `SubqueryAlias` in a `Projection`. The cycle guard must still + // detect the marker through this wrapping. 
+ use datafusion::common::NullEquality; + use datafusion::logical_expr::JoinConstraint; + use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit}; + + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + ])); + let fact_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + + let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?; + let nation_filter = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")), + Arc::new(nation_scan), + )?); + let inner_projection = LogicalPlan::Projection(Projection::try_new( + vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))], + Arc::new(nation_filter), + )?); + let marker_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(inner_projection), + TableReference::bare(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")), + )?); + let wrapped_marker = LogicalPlan::Projection(Projection::try_new( + vec![Expr::Column(Column::new( + Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")), + "n_nationkey", + ))], + Arc::new(marker_alias), + )?); + + let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?; + let semi_join_input = LogicalPlanBuilder::from(supplier_scan) + .join_with_expr_keys( + wrapped_marker, + JoinType::LeftSemi, + ( + vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))], + vec![Expr::Column(Column::new( + Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")), + "n_nationkey", + ))], + ), + None, + )? 
+ .build()?; + + let outer_dim_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?; + let outer_dim_filter = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")), + Arc::new(outer_dim_scan), + )?); + + let outer_join = LogicalPlan::Join(Join::try_new( + Arc::new(outer_dim_filter), + Arc::new(semi_join_input), + vec![( + Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")), + Expr::Column(Column::new(Some("supplier"), "s_nationkey")), + )], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?; + assert!( + !changed, + "cycle guard must detect a marker wrapped in an outer Projection; \ + plan was:\n{outer_join}" + ); + Ok(()) + } + + #[tokio::test] + async fn inner_join_without_filter_is_noop() -> Result<()> { + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier, nation \ + WHERE s_nationkey = n_nationkey", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + !changed, + "rule must not fire when neither side has a non-key filter; plan was:\n{plan}" + ); + Ok(()) + } + + #[tokio::test] + async fn inner_join_with_expression_fact_key_propagates_dim_filter() -> Result<()> { + // The canonical chbench Q5/Q7/Q10 shape: a non-trivial expression on + // the fact side and a pure column on the dim side, with the dim side + // carrying the selective non-key filter. + // + // The rule must fire on `(Column, Expr)` (or `(Expr, Column)`) equi-key + // pairs even though neither side is a pure column-column join. 
+ let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT c_id FROM customer, nation \ + WHERE ascii(substr(c_state, 1, 1)) - 65 = n_nationkey \ + AND n_name = 'CHINA'", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + changed, + "rule should fire on expression-vs-column equi-key; plan was:\n{plan}" + ); + assert!( + find_propagated_side(&transformed_plan).is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + + // Cycle prevention: running the rule a second time must be a no-op + // (the unified Display-keyed cycle guard tracks the InSubquery target + // expression, not just column targets). + let (_, changed2) = apply_rule_to_all_joins(&r, transformed_plan, &cfg)?; + assert!( + !changed2, + "second pass must not re-propagate (cycle guard) on expression target" + ); + Ok(()) + } + + #[tokio::test] + async fn multi_hop_dim_subtree_propagates_through_region_nation() -> Result<()> { + // The canonical Q5 shape: `region ⋈ nation ⋈ supplier` with a + // selective filter on `region.r_name`. With the multi-hop dim + // detector the `region ⋈ nation` subtree counts as dim-like, so the + // rule can propagate the filtered `n_nationkey` domain to `supplier` + // in a single pass instead of waiting for the optimizer's fixed + // point. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier, nation, region \ + WHERE s_nationkey = n_nationkey \ + AND n_regionkey = r_regionkey \ + AND r_name = 'ASIA'", + ) + .await? 
+ .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + changed, + "rule should propagate r_name filter through the multi-hop dim subtree; \ + plan was:\n{plan}" + ); + assert!( + find_propagated_side(&transformed_plan).is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + Ok(()) + } + + #[test] + fn key_preserved_through_summaries_accepts_distinct_all() -> Result<()> { + // `Distinct::All` deduplicates whole rows but preserves every column's + // values (it can only remove duplicate rows), so any join key survives. + use datafusion::logical_expr::Distinct; + use datafusion_expr::builder::table_scan; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + ])); + let scan = table_scan(Some("t"), &schema, None)?.build()?; + let distinct = LogicalPlan::Distinct(Distinct::All(Arc::new(scan))); + + let key_a = Column::new(Some("t"), "a"); + let key_b = Column::new(Some("t"), "b"); + + assert!(key_preserved_through_summaries(&distinct, &key_a)); + assert!(key_preserved_through_summaries(&distinct, &key_b)); + Ok(()) + } + + #[tokio::test] + async fn aggregate_dim_propagates_when_key_is_in_group_by() -> Result<()> { + // Pre-aggregated dim: `SELECT n_nationkey, count(*) FROM nation + // WHERE n_name = 'CHINA' GROUP BY n_nationkey` joined against + // supplier. The aggregate's GROUP BY includes `n_nationkey`, so the + // key's domain is preserved through the aggregation and the rule + // should still propagate to supplier. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier, \ + (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA' \ + GROUP BY n_nationkey) AS n_agg \ + WHERE s_nationkey = n_nationkey", + ) + .await? 
+ .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + changed, + "rule should fire when dim has Aggregate(GROUP BY key); plan was:\n{plan}" + ); + assert!( + find_propagated_side(&transformed_plan).is_some(), + "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}" + ); + Ok(()) + } + + #[test] + fn key_preserved_through_summaries_rejects_aggregate_without_key_in_group() -> Result<()> { + // Sanity-check the helper: an aggregate that does NOT group by `a` + // must report the key as not preserved. + use datafusion::logical_expr::Aggregate; + use datafusion_expr::builder::table_scan; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + ])); + let scan = table_scan(Some("t"), &schema, None)?.build()?; + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(scan), + vec![Expr::Column(Column::new(Some("t"), "b"))], + vec![], + )?); + + let key_a = Column::new(Some("t"), "a"); + let key_b = Column::new(Some("t"), "b"); + + assert!( + !key_preserved_through_summaries(&agg, &key_a), + "`a` aggregated away, must not be preserved" + ); + assert!( + key_preserved_through_summaries(&agg, &key_b), + "`b` is in GROUP BY, must be preserved" + ); + Ok(()) + } + + #[test] + fn subtree_upper_bound_rows_sums_stats_across_dim_subtree() -> Result<()> { + use datafusion::catalog::{Session, TableProvider}; + use datafusion::common::stats::Precision; + use datafusion::datasource::DefaultTableSource; + use datafusion::logical_expr::{TableType, dml::InsertOp}; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_common::Statistics; + use datafusion_expr::Expr as ExprAlias; + use datafusion_expr::LogicalPlanBuilder; + + /// `TableProvider` that returns a constant row count from `statistics()`. 
+ #[derive(Debug)] + struct FixedStatsProvider { + schema: arrow::datatypes::SchemaRef, + num_rows: usize, + } + + #[async_trait::async_trait] + impl TableProvider for FixedStatsProvider { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn schema(&self) -> arrow::datatypes::SchemaRef { + Arc::clone(&self.schema) + } + fn table_type(&self) -> TableType { + TableType::Base + } + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[ExprAlias], + _limit: Option, + ) -> Result> { + Err(datafusion::common::DataFusionError::NotImplemented( + "FixedStatsProvider scan not used in this test".to_string(), + )) + } + fn statistics(&self) -> Option { + let mut stats = Statistics::new_unknown(self.schema.as_ref()); + stats.num_rows = Precision::Exact(self.num_rows); + Some(stats) + } + async fn insert_into( + &self, + _state: &dyn Session, + _input: Arc, + _insert_op: InsertOp, + ) -> Result> { + Err(datafusion::common::DataFusionError::NotImplemented( + "FixedStatsProvider insert not used".to_string(), + )) + } + } + + fn fixed_table_scan(rows: usize) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int64, false)])); + let provider = Arc::new(FixedStatsProvider { + schema: Arc::clone(&schema), + num_rows: rows, + }); + let source = Arc::new(DefaultTableSource::new(provider)); + LogicalPlanBuilder::scan("t", source, None)?.build() + } + + // Single scan: row count is reported directly. + let small = fixed_table_scan(500)?; + assert_eq!(subtree_upper_bound_rows(&small), Some(500)); + + // Below the dim threshold → gate fires (skip propagation). + let fact = fixed_table_scan(1_000_000)?; + assert!(skip_propagation_by_cardinality(&small, &fact)); + + // Above the dim threshold + above the fact threshold → gate is silent. + let big_dim = fixed_table_scan(5_000)?; + assert!(!skip_propagation_by_cardinality(&big_dim, &fact)); + + // Below the fact threshold → gate fires from the fact side. 
+ let tiny_fact = fixed_table_scan(50_000)?; + assert!(skip_propagation_by_cardinality(&big_dim, &tiny_fact)); + + Ok(()) + } + + #[test] + fn skip_propagation_by_cardinality_silent_when_stats_absent() -> Result<()> { + // MemTable doesn't expose row counts via `TableProvider::statistics()`, + // so the gate must fall back to the structural behavior (no skip). + use datafusion::catalog::MemTable; + use datafusion::datasource::DefaultTableSource; + use datafusion_expr::LogicalPlanBuilder; + + let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int64, false)])); + let provider = Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![]])?); + let source = Arc::new(DefaultTableSource::new(provider)); + let scan = LogicalPlanBuilder::scan("t", source, None)?.build()?; + + assert_eq!(subtree_upper_bound_rows(&scan), None); + assert!(!skip_propagation_by_cardinality(&scan, &scan)); + Ok(()) + } + + #[test] + fn key_preserved_through_summaries_rejects_same_name_different_relation() -> Result<()> { + use datafusion::logical_expr::{Aggregate, Distinct, DistinctOn}; + use datafusion_expr::builder::table_scan; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let scan = table_scan(Some("t2"), &schema, None)?.build()?; + let t1_key = Column::new(Some("t1"), "a"); + let t2_key = Column::new(Some("t2"), "a"); + + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(scan.clone()), + vec![Expr::Column(t2_key.clone())], + vec![], + )?); + assert!( + !key_preserved_through_summaries(&aggregate, &t1_key), + "same-name GROUP BY columns from a different relation must not preserve the key" + ); + + let distinct_on = LogicalPlan::Distinct(Distinct::On(DistinctOn::try_new( + vec![Expr::Column(t2_key.clone())], + vec![Expr::Column(t2_key)], + None, + Arc::new(scan), + )?)); + assert!( + !key_preserved_through_summaries(&distinct_on, &t1_key), + "same-name DISTINCT ON columns from a different relation must not preserve 
the key" + ); + + Ok(()) + } + + #[tokio::test] + async fn inner_join_with_key_only_filter_is_noop() -> Result<()> { + // `n_nationkey = 22` references only the join key — `DataFusion`'s + // stock `infer_join_predicates` already handles this case, so our + // rule must NOT fire and create a redundant subquery. + let ctx = make_ctx()?; + let plan = ctx + .sql( + "SELECT s_suppkey FROM supplier, nation \ + WHERE s_nationkey = n_nationkey AND n_nationkey = 22", + ) + .await? + .into_optimized_plan()?; + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?; + assert!( + !changed, + "rule must not fire when filter references only the join key; plan was:\n{plan}" + ); + Ok(()) + } + + #[test] + fn null_equal_inner_join_is_noop() -> Result<()> { + use datafusion::logical_expr::JoinConstraint; + use datafusion_expr::{builder::table_scan, lit}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("c", DataType::Utf8, true), + ])); + let right_schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)])); + + let left_scan = table_scan(Some("l"), &left_schema, None)?.build()?; + let left = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("l"), "c")).eq(lit("v")), + Arc::new(left_scan), + )?); + let right = table_scan(Some("r"), &right_schema, None)?.build()?; + + let join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + vec![( + Expr::Column(Column::new(Some("l"), "a")), + Expr::Column(Column::new(Some("r"), "x")), + )], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNull, + )?); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (_, changed) = apply_rule_to_all_joins(&r, join, &cfg)?; + + assert!( + !changed, + "rule must not introduce SQL IN filters for null-equal joins" + ); + Ok(()) + } + + #[test] + fn 
composite_join_receives_one_filter_per_non_key_constrained_key() -> Result<()> { + use datafusion::common::NullEquality; + use datafusion::logical_expr::JoinConstraint; + use datafusion_expr::{builder::table_scan, lit}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + Field::new("c", DataType::Utf8, true), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Int64, false), + ])); + + let left_scan = table_scan(Some("l"), &left_schema, None)?.build()?; + let left = LogicalPlan::Filter(Filter::try_new( + Expr::Column(Column::new(Some("l"), "c")).eq(lit("v")), + Arc::new(left_scan), + )?); + let right = table_scan(Some("r"), &right_schema, None)?.build()?; + + let join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + vec![ + ( + Expr::Column(Column::new(Some("l"), "a")), + Expr::Column(Column::new(Some("r"), "x")), + ), + ( + Expr::Column(Column::new(Some("l"), "b")), + Expr::Column(Column::new(Some("r"), "y")), + ), + ], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?); + + let r = rule(); + let cfg = datafusion::optimizer::OptimizerContext::new(); + let (transformed_plan, changed) = apply_rule_to_all_joins(&r, join, &cfg)?; + + assert!( + changed, + "rule should fire on composite inner join with side-local non-key filter" + ); + assert_eq!( + count_propagated_filter_exprs(&transformed_plan), + 2, + "each matching composite key should get one propagated filter; plan was:\n{transformed_plan}" + ); + Ok(()) + } + + #[test] + fn expr_has_propagated_filter_detects_marker_alias() -> Result<()> { + use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit}; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let scan = table_scan(Some("t"), &schema, None)?.build()?; + let projection = 
LogicalPlanBuilder::from(scan) + .project(vec![Expr::Column(Column::new(Some("t"), "a"))])? + .build()?; + + let alias_name = format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1"); + let aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(projection), + TableReference::bare(alias_name), + )?); + + let in_subquery = Expr::InSubquery(InSubquery::new( + Box::new(lit(1i64)), + Subquery { + subquery: Arc::new(aliased), + outer_ref_columns: vec![], + spans: Spans::default(), + }, + false, + )); + + assert!(expr_has_propagated_filter(&in_subquery)); + assert!(!expr_has_propagated_filter(&lit(5i64))); + Ok(()) + } + + #[test] + fn is_dim_like_subtree_handles_simple_scan() -> Result<()> { + use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("x", DataType::Utf8, true), + ])); + let scan = table_scan(Some("t"), &schema, None)?.build()?; + assert!(is_dim_like_subtree(&scan)); + + let filtered = LogicalPlanBuilder::from(scan) + .filter(Expr::Column(Column::new(Some("t"), "x")).eq(lit("v")))? + .build()?; + assert!(is_dim_like_subtree(&filtered)); + Ok(()) + } +} diff --git a/crates/cayenne/src/metastore/sqlite.rs b/crates/cayenne/src/metastore/sqlite.rs index 42e9c2c8f3..b594331778 100644 --- a/crates/cayenne/src/metastore/sqlite.rs +++ b/crates/cayenne/src/metastore/sqlite.rs @@ -120,6 +120,35 @@ impl SqliteMetastore { if !db_dir.exists() { tokio::fs::create_dir_all(db_dir).await?; + + // Best-effort parent directory sync (defense-in-depth with + // the sync already performed in CayenneCatalog::init). + // Ensures the db_dir entry is durable before opening the + // SQLite connection and initializing the schema. 
+ // + // We keep this best-effort (with warning on failure) for + // the same reasons as in CayenneCatalog::init: one-time + // initialization, followed by DB file + schema creation, + // and the parent is often a stable operator-managed + // volume root. + if let Some(parent) = db_dir.parent() { + let parent_for_sync = parent.to_path_buf(); + let parent_display = parent_for_sync.display().to_string(); + let db_dir_display = db_dir.display().to_string(); + match tokio::task::spawn_blocking(move || { + std::fs::File::open(&parent_for_sync).and_then(|f| f.sync_all()) + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(error)) => tracing::warn!( + "Failed to sync parent directory {parent_display} after creating SQLite catalog DB directory {db_dir_display} (subsequent DB writes will still be durable): {error}" + ), + Err(error) => tracing::warn!( + "Failed to join SQLite catalog DB parent directory sync task for {parent_display}: {error}" + ), + } + } } // Open connection with tokio-rusqlite diff --git a/crates/cayenne/src/metastore/turso.rs b/crates/cayenne/src/metastore/turso.rs index c3fc9ea771..e7e24ffdbc 100644 --- a/crates/cayenne/src/metastore/turso.rs +++ b/crates/cayenne/src/metastore/turso.rs @@ -82,6 +82,35 @@ impl TursoMetastore { if !db_dir.exists() { tokio::fs::create_dir_all(db_dir).await?; + + // Best-effort parent directory sync (defense-in-depth with + // the sync already performed in CayenneCatalog::init). + // Ensures the db_dir entry is durable before opening the + // Turso connection and initializing the schema. + // + // We keep this best-effort (with warning on failure) for + // the same reasons as in CayenneCatalog::init: one-time + // initialization, followed by DB file + schema creation, + // and the parent is often a stable operator-managed + // volume root. 
+ if let Some(parent) = db_dir.parent() { + let parent_for_sync = parent.to_path_buf(); + let parent_display = parent_for_sync.display().to_string(); + let db_dir_display = db_dir.display().to_string(); + match tokio::task::spawn_blocking(move || { + std::fs::File::open(&parent_for_sync).and_then(|f| f.sync_all()) + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(error)) => tracing::warn!( + "Failed to sync parent directory {parent_display} after creating Turso catalog DB directory {db_dir_display} (subsequent DB writes will still be durable): {error}" + ), + Err(error) => tracing::warn!( + "Failed to join Turso catalog DB parent directory sync task for {parent_display}: {error}" + ), + } + } } let db = Builder::new_local(db_path).build().await.map_err(|e| { diff --git a/crates/cayenne/src/optimizer_rules.rs b/crates/cayenne/src/optimizer_rules.rs index 91b223fe72..8f3e515741 100644 --- a/crates/cayenne/src/optimizer_rules.rs +++ b/crates/cayenne/src/optimizer_rules.rs @@ -15,24 +15,110 @@ limitations under the License. */ //! Physical optimizer rules for Cayenne execution plans. +//! +//! # No-spill build-side memory strategy (q21 / chbench multi-way joins) +//! +//! `DataFusion`'s `HashJoinExec` build side is non-spillable. Under the runtime +//! memory pool (`GreedyMemoryPool` wrapped in `TrackConsumersPool`), wide chbench +//! shapes such as q21 (a 5-way join feeding a correlated `NOT EXISTS` self-join +//! over `order_line`) exhaust the `HashJoinInput[N]` reservations because each +//! build-side hash table independently materializes its full keyspace. +//! +//! The q21 fix is layered so each optimizer rule handles the part `DataFusion` +//! cannot currently spill or infer on its own: +//! +//! 1. **Logical predicate propagation.** +//! [`crate::logical_optimizer::CayennePropagateFilterAcrossEquiJoinKeys`] +//! introduces explicit `InSubquery` filters for equi-join keys when the +//! selective predicate is on a non-key column. `DataFusion`'s stock +//! 
`infer_join_predicates` only fires when the predicate already references +//! a join key (`WHERE n_nationkey = 5` → `WHERE s_nationkey = 5`). For q21 +//! the filter is `n_name = 'CHINA'`, so the Cayenne rule exposes the +//! `nation → supplier → stock/order_line` cardinality bound before +//! `push_down_filter` plants it into scans. +//! +//! 2. **Cross-scan dynamic filter sharing.** When a join's +//! `Arc` is pushed into one +//! `CayenneAccelerationExec`, [`CayenneDynamicFilterSharing`] installs the +//! same `Arc` on sibling `CayenneAccelerationExec`s backed by the same +//! underlying table and equi-joined column set. The shared `Arc` carries the +//! same `Arc>` state, so all sibling scans observe the exact +//! filter values as soon as the producing join accumulates them. Applies to +//! `Inner`, `LeftSemi`, and `RightSemi` parent joins (anti joins are +//! excluded — their semantics require the *absence* of a match, so sharing +//! the filter would drop rows the anti-join is supposed to preserve). +//! +//! 3. **Same-source anti / semi-join sort-merge rewrite.** `DataFusion` does not +//! create dynamic filters for anti joins, and q21's `NOT EXISTS` self-join +//! can leave large `HashJoinInput[N]` reservations behind. +//! [`CayenneAntiJoinSortMergeRewriter`] rewrites same-source Cayenne +//! `LeftAnti` / `RightAnti` / `LeftSemi` / `RightSemi` `HashJoinExec` nodes +//! to `SortMergeJoinExec` with explicit spillable `SortExec` inputs above a +//! 10M-row build-side threshold. Sort-merge preserves the join semantics for +//! each of these types without materializing a full non-spillable hash table +//! on the LEFT input (`HashJoinExec`'s build side, regardless of join type). +//! +//! [`CayenneJoinRewriter`] still handles the ordinary inner-join probe side by +//! swapping the default in-list accumulator for [`ExactLeftAccumulator`], which +//! produces a precise dynamic filter (or falls back to `RangeBounds` + +//! 
`BloomFilter`) that `DataFusion`'s filter-pushdown phase plants into the +//! right-side `CayenneAccelerationExec`'s `FileSource`. +//! +//! ## Audit notes (verified 2026-05-14 against the q21 explain snapshot at +//! `crates/test-framework/src/snapshot/snapshots/explain/test_framework__snapshot__file[parquet]-cayenne[file]-indexes_tpch_q21_explain.snap`) +//! +//! * **Cayenne table statistics are `Exact` at the physical-plan boundary.** +//! The chain `CayenneTableProvider::statistics` +//! → [`crate::stats::file_statistics_to_df`] returns +//! `Precision::Exact(num_rows)` whenever the persisted `i64` row count is +//! non-negative. Per-file `Statistics` are also `Exact` because +//! `VortexFormat::infer_stats` reads `row_count` from the file footer, and +//! `SessionConfig::default().collect_statistics()` is `true`, so +//! `ListingTable::do_collect_statistics` is exercised for every scan. +//! `CayenneAccelerationExec::partition_statistics` simply delegates to the +//! inner `DataSourceExec`, so the value reaches `JoinSelection`. The q21 +//! explain plan confirms `should_swap_join_order` picks the smaller side as +//! build at every level (nation/supplier on the LEFT, lineitem on the +//! RIGHT), so the residual OOM is *not* attributable to fuzzy stats — it is +//! the **logical** join order locking in the SQL `FROM` order and applying +//! the nation filter last. +//! +//! * **Build-side projections are minimal.** Every `CayenneAccelerationExec` +//! in the snapshot terminates in a `DataSourceExec` whose `projection=[...]` +//! lists only the join keys and the columns referenced above the join. +//! `DataFusion`'s stock projection pushdown already prunes wider scans down to +//! `[s_suppkey, s_name, s_nationkey]`, `[o_orderkey, o_orderstatus]`, +//! `[l_orderkey, l_suppkey]`, etc. No additional `ProjectionExec` insertion +//! above the build side is required. +//! +//! With these layers active, q21 is included in +//! 
`test_framework::queries::get_chbench_test_queries`. -use datafusion::common::NullEquality; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::{JoinType, NullEquality}; use datafusion::config::ConfigOptions; use datafusion::error::DataFusionError; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::joins::HashJoinExec; +use datafusion::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; +use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::{error::Result, physical_plan::projection::ProjectionExec}; +use datafusion_common::stats::Precision; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::repartition::RepartitionExec; use runtime_datafusion::execution_plan::schema_cast::SchemaCastScanExec; use runtime_datafusion::extension::bytes_processed::BytesProcessedExec; use runtime_datafusion::join_accumulator::ExactLeftAccumulator; +use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; use crate::provider::CayenneAccelerationExec; -use crate::provider::scan::IsCayenneAccelerationExec; +use crate::provider::scan::{IsCayenneAccelerationExec, ScanDynamicFilter, ScanIdentity}; /// Optimizer rule that rewrites `HashJoinExec` nodes to use `ExactLeftAccumulator` /// when the probe side is a `CayenneAccelerationExec`. @@ -47,6 +133,530 @@ impl CayenneJoinRewriter { } } +/// Shares already-pushed hash-join dynamic filters between same-source Cayenne +/// scans when the current hash join proves the relevant columns are equi-joined. 
+#[derive(Default)] +pub struct CayenneDynamicFilterSharing; + +impl CayenneDynamicFilterSharing { + /// Create a new `CayenneDynamicFilterSharing` optimizer rule. + #[must_use] + pub fn new() -> Self { + Self + } +} + +impl std::fmt::Debug for CayenneDynamicFilterSharing { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CayenneDynamicFilterSharing").finish() + } +} + +/// Rewrites same-source Cayenne anti and semi joins from hash join to +/// sort-merge join when the build side is large enough to risk OOM. +/// +/// `DataFusion`'s `HashJoinExec` always materializes its left input as the +/// non-spillable build side regardless of join type. For q21's correlated +/// `NOT EXISTS` self-join (a `LeftAnti`) that build side can be a large +/// multi-way `order_line` result; the same shape arises in `EXISTS` / +/// `IN (subquery)` constructs that decorrelate into `LeftSemi`. Sort-merge +/// preserves the join semantics for each of these types while keeping the +/// build side spillable. +#[derive(Default)] +pub struct CayenneAntiJoinSortMergeRewriter; + +/// Only rewrite same-source anti or semi joins whose LEFT (build) input has +/// `Precision::Exact` row count exceeding this threshold. Below it the +/// in-memory hash table is faster than two explicit sort buffers. +const ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS: usize = 10_000_000; + +impl CayenneAntiJoinSortMergeRewriter { + /// Create a new `CayenneAntiJoinSortMergeRewriter` optimizer rule. 
+ #[must_use] + pub fn new() -> Self { + Self + } +} + +impl std::fmt::Debug for CayenneAntiJoinSortMergeRewriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CayenneAntiJoinSortMergeRewriter").finish() + } +} + +impl PhysicalOptimizerRule for CayenneAntiJoinSortMergeRewriter { + fn name(&self) -> &'static str { + "CayenneAntiJoinSortMergeRewriter" + } + + fn schema_check(&self) -> bool { + false + } + + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result, DataFusionError> { + plan.transform_down(|node| { + let Some(hash_join) = node.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + + let Some(sort_merge_join) = try_rewrite_same_source_anti_join(hash_join)? else { + return Ok(Transformed::no(node)); + }; + + Ok(Transformed::yes(sort_merge_join)) + }) + .data() + } +} + +impl PhysicalOptimizerRule for CayenneDynamicFilterSharing { + fn name(&self) -> &'static str { + "CayenneDynamicFilterSharing" + } + + fn schema_check(&self) -> bool { + false + } + + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result, DataFusionError> { + plan.transform_down(|node| { + let Some(hash_join) = node.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + + let (left_additions, right_additions) = filter_additions_for_join(hash_join); + if left_additions.is_empty() && right_additions.is_empty() { + return Ok(Transformed::no(node)); + } + + let (left, left_changed) = + apply_filter_additions(Arc::clone(hash_join.left()), &left_additions, config)?; + let (right, right_changed) = + apply_filter_additions(Arc::clone(hash_join.right()), &right_additions, config)?; + + if !left_changed && !right_changed { + return Ok(Transformed::no(node)); + } + + let new_node = node.with_new_children(vec![left, right])?; + Ok(Transformed::yes(new_node)) + }) + .data() + } +} + +#[derive(Clone)] +struct CayenneScanSummary { + identity: Arc, + columns: BTreeSet, + 
schema_fields: Vec<(String, DataType)>, + dynamic_filters: Vec, +} + +#[derive(Clone)] +struct FilterAddition { + identity: Arc, + schema_fields: Vec<(String, DataType)>, + filter: Arc, +} + +fn filter_additions_for_join( + hash_join: &HashJoinExec, +) -> (Vec, Vec) { + // `Inner`, `LeftSemi`, and `RightSemi` all preserve the equi-key domain: + // a dynamic filter built from one side is also a valid filter for an + // equi-joined same-source scan on the other side. `LeftAnti`/`RightAnti` + // do not — their output requires the absence of a match, so propagating + // the filter would drop rows that should be retained. + if !matches!( + *hash_join.join_type(), + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi, + ) { + return (Vec::new(), Vec::new()); + } + + let left_scans = collect_cayenne_scans(hash_join.left()); + let right_scans = collect_cayenne_scans(hash_join.right()); + if left_scans.is_empty() || right_scans.is_empty() { + return (Vec::new(), Vec::new()); + } + let right_scans_by_identity = scans_by_identity(&right_scans); + + let mut pair_columns: HashMap<(usize, usize), BTreeSet> = HashMap::new(); + for (left_key, right_key) in hash_join.on() { + let Some(left_column) = physical_column_name(left_key) else { + continue; + }; + let Some(right_column) = physical_column_name(right_key) else { + continue; + }; + + if left_column != right_column { + continue; + } + + let matching_pairs = same_source_pairs_for_column( + &left_scans, + &right_scans, + &right_scans_by_identity, + left_column, + right_column, + ); + let [(left_index, right_index)] = matching_pairs.as_slice() else { + continue; + }; + if left_scans[*left_index].schema_fields != right_scans[*right_index].schema_fields { + continue; + } + + pair_columns + .entry((*left_index, *right_index)) + .or_default() + .insert(left_column.to_string()); + } + + let mut left_additions = Vec::new(); + let mut right_additions = Vec::new(); + + for ((left_index, right_index), shared_columns) in pair_columns 
{ + let left_scan = &left_scans[left_index]; + let right_scan = &right_scans[right_index]; + + for filter in &left_scan.dynamic_filters { + if filter.columns().is_subset(&shared_columns) { + push_filter_addition( + &mut right_additions, + Arc::clone(&right_scan.identity), + right_scan.schema_fields.clone(), + Arc::clone(filter.filter()), + ); + } + } + + for filter in &right_scan.dynamic_filters { + if filter.columns().is_subset(&shared_columns) { + push_filter_addition( + &mut left_additions, + Arc::clone(&left_scan.identity), + left_scan.schema_fields.clone(), + Arc::clone(filter.filter()), + ); + } + } + } + + (left_additions, right_additions) +} + +fn try_rewrite_same_source_anti_join( + hash_join: &HashJoinExec, +) -> Result>, DataFusionError> { + // Same-source `LeftAnti`/`RightAnti`/`LeftSemi`/`RightSemi` joins all + // share the relevant property: `HashJoinExec` builds the LEFT input into + // a non-spillable hash table, so a large build side risks OOM. Sort-merge + // is spillable and preserves the join semantics for each of these types. 
+ if !matches!( + hash_join.join_type(), + JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi | JoinType::RightSemi, + ) { + return Ok(None); + } + + if hash_join.contains_projection() || hash_join.on().is_empty() { + return Ok(None); + } + + if !has_single_same_source_pair_for_all_join_keys(hash_join) { + return Ok(None); + } + + let Some(build_row_count) = spillable_rewrite_build_input_exact_rows(hash_join) else { + return Ok(None); + }; + if build_row_count <= ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS { + return Ok(None); + } + + let sort_options = vec![SortOptions::default(); hash_join.on().len()]; + let Some(left_ordering) = join_key_ordering( + hash_join + .on() + .iter() + .map(|(left_key, _)| Arc::clone(left_key)), + &sort_options, + ) else { + return Ok(None); + }; + let Some(right_ordering) = join_key_ordering( + hash_join + .on() + .iter() + .map(|(_, right_key)| Arc::clone(right_key)), + &sort_options, + ) else { + return Ok(None); + }; + + let left: Arc = + Arc::new(SortExec::new(left_ordering, Arc::clone(hash_join.left()))); + let right: Arc = + Arc::new(SortExec::new(right_ordering, Arc::clone(hash_join.right()))); + + let join = SortMergeJoinExec::try_new( + left, + right, + hash_join.on().to_vec(), + hash_join.filter().cloned(), + *hash_join.join_type(), + sort_options, + hash_join.null_equality(), + )?; + + tracing::debug!( + join_type = ?hash_join.join_type(), + build_row_count, + threshold = ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS, + "Replacing same-source Cayenne anti/semi HashJoinExec with SortMergeJoinExec" + ); + + Ok(Some(Arc::new(join))) +} + +fn spillable_rewrite_build_input_exact_rows(hash_join: &HashJoinExec) -> Option { + // `HashJoinExec` materializes the LEFT input as the (non-spillable) build + // hash table regardless of join type — including `*Anti` and `*Semi`. 
+ let build_input = match hash_join.join_type() { + JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi | JoinType::RightSemi => { + hash_join.left() + } + _ => return None, + }; + + match build_input.partition_statistics(None).ok()?.num_rows { + Precision::Exact(row_count) => Some(row_count), + Precision::Inexact(_) | Precision::Absent => None, + } +} + +fn join_key_ordering( + keys: impl Iterator>, + sort_options: &[SortOptions], +) -> Option { + let sort_exprs = keys + .zip(sort_options.iter().copied()) + .map(|(expr, options)| PhysicalSortExpr { expr, options }) + .collect::>(); + + LexOrdering::new(sort_exprs) +} + +fn has_single_same_source_pair_for_all_join_keys(hash_join: &HashJoinExec) -> bool { + let left_scans = collect_cayenne_scans(hash_join.left()); + let right_scans = collect_cayenne_scans(hash_join.right()); + if left_scans.is_empty() || right_scans.is_empty() { + return false; + } + let right_scans_by_identity = scans_by_identity(&right_scans); + + let mut matched_pair = None; + for (left_key, right_key) in hash_join.on() { + let Some(left_column) = physical_column_name(left_key) else { + return false; + }; + let Some(right_column) = physical_column_name(right_key) else { + return false; + }; + + if left_column != right_column { + return false; + } + + let pairs = same_source_pairs_for_column( + &left_scans, + &right_scans, + &right_scans_by_identity, + left_column, + right_column, + ); + let [(left_index, right_index)] = pairs.as_slice() else { + return false; + }; + let pair = (*left_index, *right_index); + + if matched_pair.is_some_and(|previous_pair| previous_pair != pair) { + return false; + } + matched_pair = Some(pair); + } + + matched_pair.is_some() +} + +fn collect_cayenne_scans(plan: &Arc) -> Vec { + let mut scans = Vec::new(); + collect_cayenne_scans_inner(plan, &mut scans); + scans +} + +fn collect_cayenne_scans_inner(plan: &Arc, scans: &mut Vec) { + if let Some(cayenne) = plan.as_any().downcast_ref::() + && let Some(identity) 
= cayenne.scan_identity() + { + let schema_fields = plan_schema_fields(&cayenne.schema()); + let columns = schema_fields.iter().map(|(name, _)| name.clone()).collect(); + scans.push(CayenneScanSummary { + identity, + columns, + schema_fields, + dynamic_filters: cayenne.dynamic_filters(), + }); + return; + } + + for child in plan.children() { + collect_cayenne_scans_inner(child, scans); + } +} + +fn physical_column_name(expr: &Arc) -> Option<&str> { + expr.as_any().downcast_ref::().map(Column::name) +} + +fn scans_by_identity(scans: &[CayenneScanSummary]) -> HashMap, Vec> { + let mut by_identity: HashMap, Vec> = HashMap::new(); + for (index, scan) in scans.iter().enumerate() { + by_identity + .entry(Arc::clone(&scan.identity)) + .or_default() + .push(index); + } + by_identity +} + +fn same_source_pairs_for_column( + left_scans: &[CayenneScanSummary], + right_scans: &[CayenneScanSummary], + right_scans_by_identity: &HashMap, Vec>, + left_column: &str, + right_column: &str, +) -> Vec<(usize, usize)> { + let mut pairs = Vec::new(); + + for (left_index, left_scan) in left_scans.iter().enumerate() { + if !left_scan.columns.contains(left_column) { + continue; + } + + let Some(right_indices) = right_scans_by_identity.get(&left_scan.identity) else { + continue; + }; + + for &right_index in right_indices { + if right_scans[right_index].columns.contains(right_column) { + pairs.push((left_index, right_index)); + } + } + } + + pairs +} + +fn push_filter_addition( + additions: &mut Vec, + identity: Arc, + schema_fields: Vec<(String, DataType)>, + filter: Arc, +) { + if additions.iter().any(|addition| { + addition.identity == identity + && addition.schema_fields == schema_fields + && Arc::ptr_eq(&addition.filter, &filter) + }) { + return; + } + + additions.push(FilterAddition { + identity, + schema_fields, + filter, + }); +} + +fn plan_schema_fields(schema: &SchemaRef) -> Vec<(String, DataType)> { + schema + .fields() + .iter() + .map(|field| (field.name().clone(), 
field.data_type().clone())) + .collect() +} + +fn apply_filter_additions( + plan: Arc, + additions: &[FilterAddition], + config: &ConfigOptions, +) -> Result<(Arc, bool), DataFusionError> { + if additions.is_empty() { + return Ok((plan, false)); + } + + if let Some(cayenne) = plan.as_any().downcast_ref::() { + let Some(identity) = cayenne.scan_identity() else { + return Ok((plan, false)); + }; + let schema_fields = plan_schema_fields(&cayenne.schema()); + let existing = cayenne.dynamic_filters(); + let filters = additions + .iter() + .filter(|addition| addition.identity == identity) + .filter(|addition| addition.schema_fields == schema_fields) + .filter(|addition| { + !existing + .iter() + .any(|filter| Arc::ptr_eq(filter.filter(), &addition.filter)) + }) + .map(|addition| Arc::clone(&addition.filter)) + .collect::>(); + + let Some(new_plan) = cayenne.with_additional_dynamic_filters(&filters, config)? else { + return Ok((plan, false)); + }; + + return Ok((new_plan, true)); + } + + let children = plan + .children() + .into_iter() + .map(Arc::clone) + .collect::>(); + if children.is_empty() { + return Ok((plan, false)); + } + + let mut changed = false; + let mut new_children = Vec::with_capacity(children.len()); + for child in children { + let (new_child, child_changed) = apply_filter_additions(child, additions, config)?; + changed |= child_changed; + new_children.push(new_child); + } + + if !changed { + return Ok((plan, false)); + } + + plan.with_new_children(new_children) + .map(|plan| (plan, true)) +} + impl std::fmt::Debug for CayenneJoinRewriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CayenneJoinRewriter").finish() @@ -195,20 +805,135 @@ impl PhysicalOptimizerRule for CayenneJoinRewriter { #[cfg(test)] mod tests { - use super::CayenneJoinRewriter; + use super::{ + ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS, CayenneAntiJoinSortMergeRewriter, + CayenneDynamicFilterSharing, CayenneJoinRewriter, FilterAddition, 
apply_filter_additions, + plan_schema_fields, + }; use crate::provider::CayenneAccelerationExec; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::{JoinType, NullEquality}; use datafusion::config::ConfigOptions; use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::physical_optimizer::PhysicalOptimizerRule; - use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; + use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; use datafusion::physical_plan::projection::ProjectionExec; + use datafusion::physical_plan::sorts::sort::SortExec; + use datafusion::physical_plan::union::UnionExec; use datafusion::physical_plan::{ExecutionPlan, displayable}; - use datafusion_physical_expr::expressions::col; + use datafusion_common::stats::Precision; + use datafusion_common::{DataFusionError, Result as DFResult, Statistics}; + use datafusion_datasource::file::FileSource; + use datafusion_datasource::file_groups::FileGroup; + use datafusion_datasource::file_scan_config::FileScanConfigBuilder; + use datafusion_datasource::file_stream::FileOpener; + use datafusion_datasource::source::DataSourceExec; + use datafusion_datasource::{PartitionedFile, TableSchema}; + use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, col, lit}; + use datafusion_physical_expr::projection::ProjectionExprs; + use datafusion_physical_expr::{PhysicalExpr, conjunction}; + use datafusion_physical_plan::DisplayFormatType; + use datafusion_physical_plan::filter_pushdown::{FilterPushdownPropagation, PushedDown}; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use object_store::ObjectMeta; + use object_store::ObjectStore; + use object_store::path::Path; use runtime_datafusion::join_accumulator::ExactLeftAccumulator; + use std::any::Any; use std::sync::Arc; + #[derive(Clone)] + struct TestFileSource { + table_schema: TableSchema, + 
filter: Option>, + metrics: ExecutionPlanMetricsSet, + } + + impl TestFileSource { + fn new(table_schema: TableSchema, filter: Option>) -> Self { + Self { + table_schema, + filter, + metrics: ExecutionPlanMetricsSet::new(), + } + } + } + + impl FileSource for TestFileSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &datafusion_datasource::file_scan_config::FileScanConfig, + _partition: usize, + ) -> DFResult> { + Err(DataFusionError::NotImplemented( + "test source cannot open files".to_string(), + )) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_schema(&self) -> &TableSchema { + &self.table_schema + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn filter(&self) -> Option> { + self.filter.clone() + } + + fn projection(&self) -> Option<&ProjectionExprs> { + None + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn file_type(&self) -> &'static str { + "test" + } + + fn fmt_extra( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + Ok(()) + } + + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> DFResult>> { + let filter_count = filters.len(); + let filter = match &self.filter { + Some(existing) => Some(conjunction( + std::iter::once(Arc::clone(existing)).chain(filters), + )), + None => Some(conjunction(filters)), + }; + let source = Self { + table_schema: self.table_schema.clone(), + filter, + metrics: ExecutionPlanMetricsSet::new(), + }; + + Ok(FilterPushdownPropagation::with_parent_pushdown_result(vec![ + PushedDown::Yes; + filter_count + ]) + .with_updated_node(Arc::new(source))) + } + } + fn memory_exec(column_name: &str) -> Arc { let schema = Arc::new(Schema::new(vec![Field::new( column_name, @@ -219,6 +944,118 @@ mod tests { .expect("memory exec should be valid") } + fn file_exec( + schema: &Arc, + path: &str, + filter: Option>, + ) -> Arc { + 
file_exec_with_statistics(schema, path, filter, Statistics::new_unknown(schema)) + } + + fn file_exec_with_statistics( + schema: &Arc, + path: &str, + filter: Option>, + statistics: Statistics, + ) -> Arc { + let table_schema = TableSchema::new(Arc::clone(schema), Vec::new()); + let source = Arc::new(TestFileSource::new(table_schema, filter)); + let file = PartitionedFile::from(ObjectMeta { + location: Path::from(path), + last_modified: chrono::DateTime::UNIX_EPOCH, + size: 1_024, + e_tag: None, + version: None, + }); + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("file:///").expect("object store url should parse"), + source, + ) + .with_file_group(FileGroup::new(vec![file])) + .with_statistics(statistics) + .build(); + + DataSourceExec::from_data_source(config) as Arc + } + + fn cayenne_file_exec( + schema: &Arc, + path: &str, + filter: Option>, + ) -> Arc { + Arc::new(CayenneAccelerationExec::new(file_exec( + schema, path, filter, + ))) + } + + fn inlined_exec(schema: &Arc) -> Arc { + MemorySourceConfig::try_new_exec(&[vec![]], Arc::clone(schema), None) + .expect("inlined memory exec should be valid") + } + + fn cayenne_file_with_inlined_exec( + schema: &Arc, + path: &str, + filter: Option>, + ) -> Arc { + Arc::new(CayenneAccelerationExec::new( + UnionExec::try_new(vec![file_exec(schema, path, filter), inlined_exec(schema)]) + .expect("mixed file and inlined union should be valid"), + )) + } + + fn cayenne_file_exec_with_num_rows( + schema: &Arc, + path: &str, + row_count: Precision, + ) -> Arc { + Arc::new(CayenneAccelerationExec::new(file_exec_with_statistics( + schema, + path, + None, + Statistics::new_unknown(schema).with_num_rows(row_count), + ))) + } + + fn large_exact_cayenne_file_exec(schema: &Arc, path: &str) -> Arc { + cayenne_file_exec_with_num_rows( + schema, + path, + Precision::Exact(ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS + 1), + ) + } + + fn order_line_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("order_id", 
DataType::Int64, false), + Field::new("warehouse_id", DataType::Int64, false), + Field::new("line_number", DataType::Int64, false), + ])) + } + + fn reordered_order_line_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("warehouse_id", DataType::Int64, false), + Field::new("order_id", DataType::Int64, false), + Field::new("line_number", DataType::Int64, false), + ])) + } + + fn order_line_schema_with_different_non_key_type() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("order_id", DataType::Int64, false), + Field::new("warehouse_id", DataType::Int64, false), + Field::new("line_number", DataType::UInt64, false), + ])) + } + + fn dynamic_filter_for(column_name: &str, schema: &Schema) -> Arc { + Arc::new(DynamicFilterPhysicalExpr::new( + vec![col(column_name, schema).expect("filter column should exist")], + lit(true), + )) + } + fn hash_join( left: Arc, right: Arc, @@ -241,15 +1078,57 @@ mod tests { right_column: &str, null_equality: NullEquality, ) -> HashJoinExec { - let left_key = col(left_column, &left.schema()).expect("left join key should exist"); - let right_key = col(right_column, &right.schema()).expect("right join key should exist"); + hash_join_with_join_type( + left, + right, + left_column, + right_column, + JoinType::Inner, + null_equality, + ) + } + + fn hash_join_with_join_type( + left: Arc, + right: Arc, + left_column: &str, + right_column: &str, + join_type: JoinType, + null_equality: NullEquality, + ) -> HashJoinExec { + hash_join_with_join_type_on( + left, + right, + &[(left_column, right_column)], + join_type, + null_equality, + ) + } + + fn hash_join_with_join_type_on( + left: Arc, + right: Arc, + columns: &[(&str, &str)], + join_type: JoinType, + null_equality: NullEquality, + ) -> HashJoinExec { + let on = columns + .iter() + .map(|(left_column, right_column)| { + let left_key = + col(left_column, &left.schema()).expect("left join key should exist"); + let right_key = + col(right_column, &right.schema()).expect("right join key 
should exist"); + (left_key, right_key) + }) + .collect(); HashJoinExec::try_new( left, right, - vec![(left_key, right_key)], + on, None, - &JoinType::Inner, + &join_type, None, PartitionMode::Partitioned, null_equality, @@ -267,6 +1146,18 @@ mod tests { .expect("optimizer should succeed") } + fn optimize_filter_sharing(plan: Arc) -> Arc { + CayenneDynamicFilterSharing::new() + .optimize(plan, &ConfigOptions::default()) + .expect("filter sharing optimizer should succeed") + } + + fn optimize_anti_join_sort_merge(plan: Arc) -> Arc { + CayenneAntiJoinSortMergeRewriter::new() + .optimize(plan, &ConfigOptions::default()) + .expect("anti join sort-merge optimizer should succeed") + } + fn plan_snapshot(plan: &Arc) -> String { displayable(plan.as_ref()).indent(true).to_string() } @@ -374,6 +1265,488 @@ mod tests { ); } + #[test] + fn shares_dynamic_filter_across_same_source_equi_joined_cayenne_scans() { + let schema = order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &schema); + let left = cayenne_file_exec( + &schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join(left, right, "order_id", "order_id")); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + let filters = right.dynamic_filters(); + + assert_eq!(1, filters.len()); + assert!(Arc::ptr_eq(filters[0].filter(), &source_filter)); + } + + #[test] + fn shares_dynamic_filter_with_vortex_branch_of_mixed_inlined_scan() { + let schema = order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &schema); + let left = cayenne_file_with_inlined_exec( + &schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = 
cayenne_file_with_inlined_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join(left, right, "order_id", "order_id")); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + let filters = right.dynamic_filters(); + + assert_eq!(1, filters.len()); + assert!(Arc::ptr_eq(filters[0].filter(), &source_filter)); + } + + #[test] + fn does_not_share_dynamic_filter_when_join_does_not_cover_filter_columns() { + let schema = order_line_schema(); + let source_filter = dynamic_filter_for("line_number", &schema); + let left = cayenne_file_exec( + &schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join(left, right, "order_id", "order_id")); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + + assert!(right.dynamic_filters().is_empty()); + } + + #[test] + fn does_not_share_dynamic_filter_across_different_projection_order() { + let left_schema = order_line_schema(); + let right_schema = reordered_order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &left_schema); + let left = cayenne_file_exec( + &left_schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&right_schema, "order_line.vortex", None); + let join = Arc::new(hash_join(left, right, "order_id", "order_id")); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + 
.right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + + assert!(right.dynamic_filters().is_empty()); + } + + #[test] + fn does_not_share_dynamic_filter_across_different_schema_types() { + let left_schema = order_line_schema(); + let right_schema = order_line_schema_with_different_non_key_type(); + let source_filter = dynamic_filter_for("order_id", &left_schema); + let left = cayenne_file_exec( + &left_schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&right_schema, "order_line.vortex", None); + let join = Arc::new(hash_join(left, right, "order_id", "order_id")); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + + assert!(right.dynamic_filters().is_empty()); + } + + #[test] + fn does_not_apply_filter_addition_to_same_identity_different_projection_order() { + // `apply_filter_additions` must not push a filter into a scan whose + // schema fields don't match the source scan exactly (different column + // ordering / types means the filter's column-by-position indices + // would refer to wrong columns). 
+ let source_schema = order_line_schema(); + let target_schema = reordered_order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &source_schema); + let source = CayenneAccelerationExec::new(file_exec( + &source_schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + )); + let addition = FilterAddition { + identity: source + .scan_identity() + .expect("source scan should have file identity"), + schema_fields: plan_schema_fields(&source.schema()), + filter: Arc::clone(&source_filter), + }; + let target = cayenne_file_exec(&target_schema, "order_line.vortex", None); + + let (optimized, changed) = + apply_filter_additions(Arc::clone(&target), &[addition], &ConfigOptions::default()) + .expect("filter addition should be evaluated"); + + assert!(!changed); + let target = optimized + .as_any() + .downcast_ref::() + .expect("target should remain Cayenne"); + assert!(target.dynamic_filters().is_empty()); + } + + #[test] + fn does_not_share_dynamic_filter_for_anti_join() { + // `*Anti` joins must not receive a shared dynamic filter: their + // output requires the *absence* of a match, so filtering the probe + // side would drop rows that the anti-join should preserve. 
+ let schema = order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &schema); + let left = cayenne_file_exec( + &schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + + assert!(right.dynamic_filters().is_empty()); + } + + #[test] + fn shares_dynamic_filter_for_left_semi_join() { + // `LeftSemi` preserves the equi-key domain: a dynamic filter built + // from the left side is also valid on a same-source equi-joined + // right scan, since the semi join's output is a subset of the left. 
+ let schema = order_line_schema(); + let source_filter = dynamic_filter_for("order_id", &schema); + let left = cayenne_file_exec( + &schema, + "order_line.vortex", + Some(Arc::clone(&source_filter)), + ); + let right = cayenne_file_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_filter_sharing(join); + let join = optimized + .as_any() + .downcast_ref::() + .expect("optimized plan should remain a hash join"); + let right = join + .right() + .as_any() + .downcast_ref::() + .expect("right side should remain Cayenne"); + let filters = right.dynamic_filters(); + + assert_eq!( + 1, + filters.len(), + "semi join should propagate same-source filter" + ); + assert!(Arc::ptr_eq(filters[0].filter(), &source_filter)); + } + + #[test] + fn rewrites_same_source_left_semi_hash_join_to_sort_merge() { + // Same memory concern as `LeftAnti`: `HashJoinExec` materializes the + // LEFT input as a non-spillable hash table, and a large same-source + // semi-join build side risks OOM. 
+ let schema = order_line_schema(); + let left = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + assert!( + optimized + .as_any() + .downcast_ref::() + .is_some(), + "same-source Cayenne LeftSemi join should use sort-merge join" + ); + } + + #[test] + fn rewrites_same_source_left_anti_hash_join_to_sort_merge() { + let schema = order_line_schema(); + let left = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + let sort_merge = optimized + .as_any() + .downcast_ref::() + .expect("same-source Cayenne anti join should use sort-merge join"); + + assert_eq!(JoinType::LeftAnti, sort_merge.join_type()); + assert!( + sort_merge + .left() + .as_any() + .downcast_ref::() + .is_some(), + "left anti-join input should be explicitly sorted" + ); + assert!( + sort_merge + .right() + .as_any() + .downcast_ref::() + .is_some(), + "right anti-join input should be explicitly sorted" + ); + } + + #[test] + fn rewrites_same_source_multi_key_left_anti_hash_join_to_sort_merge() { + let schema = order_line_schema(); + let left = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type_on( + left, + right, + &[("order_id", "order_id"), ("warehouse_id", "warehouse_id")], + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = 
optimize_anti_join_sort_merge(join); + + assert!( + optimized + .as_any() + .downcast_ref::() + .is_some(), + "multi-key same-source Cayenne anti join should use sort-merge join" + ); + } + + #[test] + fn leaves_unrelated_left_anti_hash_join_unchanged() { + let schema = order_line_schema(); + let left = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let right = large_exact_cayenne_file_exec(&schema, "other_order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + + assert!( + optimized.as_any().downcast_ref::().is_some(), + "anti joins over unrelated sources should stay as hash joins" + ); + } + + #[test] + fn leaves_exact_small_same_source_left_anti_hash_join_unchanged() { + let schema = order_line_schema(); + let left = cayenne_file_exec_with_num_rows( + &schema, + "order_line.vortex", + Precision::Exact(ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS), + ); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + + assert!( + optimized.as_any().downcast_ref::().is_some(), + "same-source anti joins at or below the large-input threshold should stay as hash joins" + ); + } + + #[test] + fn leaves_inexact_same_source_left_anti_hash_join_unchanged() { + let schema = order_line_schema(); + let left = cayenne_file_exec_with_num_rows( + &schema, + "order_line.vortex", + Precision::Inexact(ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS + 1), + ); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + 
+ let optimized = optimize_anti_join_sort_merge(join); + + assert!( + optimized.as_any().downcast_ref::().is_some(), + "same-source anti joins with inexact preserved-side stats should stay as hash joins" + ); + } + + #[test] + fn leaves_unknown_same_source_left_anti_hash_join_unchanged() { + let schema = order_line_schema(); + let left = cayenne_file_exec(&schema, "order_line.vortex", None); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + + assert!( + optimized.as_any().downcast_ref::().is_some(), + "same-source anti joins with unknown preserved-side stats should stay as hash joins" + ); + } + + #[test] + fn rewrites_right_anti_hash_join_when_build_side_stats_are_exact_large() { + let schema = order_line_schema(); + let left = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let right = cayenne_file_exec(&schema, "order_line.vortex", None); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + + assert!( + optimized + .as_any() + .downcast_ref::() + .is_some(), + "RightAnti should gate on the left build side, not the right preserved side" + ); + } + + #[test] + fn leaves_right_anti_hash_join_when_build_side_stats_are_unknown() { + let schema = order_line_schema(); + let left = cayenne_file_exec(&schema, "order_line.vortex", None); + let right = large_exact_cayenne_file_exec(&schema, "order_line.vortex"); + let join = Arc::new(hash_join_with_join_type( + left, + right, + "order_id", + "order_id", + JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )); + + let optimized = optimize_anti_join_sort_merge(join); + + assert!( + 
optimized.as_any().downcast_ref::().is_some(), + "RightAnti should stay hash join when the left build side has unknown stats" + ); + } + #[test] fn snapshots_cayenne_probe_join_explain_plan() { let right = Arc::new(CayenneAccelerationExec::new(memory_exec("right_id"))); diff --git a/crates/cayenne/src/partition_creator.rs b/crates/cayenne/src/partition_creator.rs index 213174bd50..336b8e8ea9 100644 --- a/crates/cayenne/src/partition_creator.rs +++ b/crates/cayenne/src/partition_creator.rs @@ -210,10 +210,36 @@ impl PartitionCreator for CayennePartitionCreator { } tracing::debug!("creating Cayenne partition at {partition_path}"); - std::fs::create_dir_all(&partition_dir) + tokio::fs::create_dir_all(&partition_dir) + .await .boxed() .context(creator::CreatePartitionSnafu)?; + // For local FS, sync the parent (table base_path) after creating a new + // partition sub-directory so its directory entry is durable before we + // record the partition in the catalog via add_partition. This follows + // the same uniform contract as snapshot directories, _partitioned_wal/, + // deletions/ subdirs, and initial table creation. 
+ if self.object_store_config.is_none() + && let Some(parent) = partition_dir.parent() + { + let parent = parent.to_path_buf(); + let parent_display = parent.display().to_string(); + match tokio::task::spawn_blocking(move || { + std::fs::File::open(&parent).and_then(|f| f.sync_all()) + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(error)) => tracing::warn!( + "Failed to sync Cayenne partition parent directory {parent_display}: {error}" + ), + Err(error) => tracing::warn!( + "Failed to join Cayenne partition parent directory sync task for {parent_display}: {error}" + ), + } + } + let partition_column_names = self.partition_column_labels(); let partition_key = partition_value_strings.join("/"); diff --git a/crates/cayenne/src/provider/delete/vector_io.rs b/crates/cayenne/src/provider/delete/vector_io.rs index 284a281661..0a1d1a3f4a 100644 --- a/crates/cayenne/src/provider/delete/vector_io.rs +++ b/crates/cayenne/src/provider/delete/vector_io.rs @@ -174,7 +174,47 @@ impl<'a> DeletionVectorWriter<'a> { } let deletion_dir = self.table_snapshot_deletion_dir(); - tokio::fs::create_dir_all(&deletion_dir).await?; + let snapshot_dir = deletion_dir + .parent() + .map(Path::to_path_buf) + .ok_or_else(|| Error::Internal { + table: self.table.path.clone(), + message: format!( + "Deletion vector directory '{}' has no snapshot parent", + deletion_dir.display() + ), + })?; + + // Ensure the deletions/ subdirectory exists. + // If we just created it, sync its parent (the snapshot directory) + // so the subdir entry is durable on local FS. + // + // This is required for the same contract we now enforce for + // snapshot directories themselves (ensure_snapshot_dir_exists) + // and for the _partitioned_wal/ coordination directory: + // on POSIX, mkdir in a directory updates the parent's metadata. 
+ // A crash immediately after this create_dir_all but before the + // subsequent file write + file fsync + catalog record could + // otherwise leave a catalog entry pointing at a deletions/ + // directory whose creation was lost. + // + // The sync is one-time per snapshot (first deletion vector + // written to it). Subsequent deletions reuse the directory. + let sync_snapshot_parent = match tokio::fs::create_dir(&deletion_dir).await { + Ok(()) => true, + Err(source) if source.kind() == std::io::ErrorKind::AlreadyExists => false, + Err(source) if source.kind() == std::io::ErrorKind::NotFound => { + tokio::fs::create_dir_all(&deletion_dir).await?; + true + } + Err(source) => return Err(Error::IoError { source }), + }; + if sync_snapshot_parent { + let table = self.table.path.clone(); + tokio::task::spawn_blocking(move || std::fs::File::open(&snapshot_dir)?.sync_all()) + .await + .map_err(|source| Error::TaskPanicked { table, source })??; + } let file_path = Self::deletion_file_path(&deletion_dir); @@ -474,6 +514,17 @@ async fn write_deletion_file( writer.write(&batch)?; writer.finish()?; + // Ensure the deletion vector file content is durable before we record + // a pointer to it in the catalog. A crash without this sync could leave + // a zero-length or partial .arrow file while the catalog transaction + // that references it has committed (or is about to). On recovery, + // readers would then hit a missing/corrupt deletion vector for a + // "committed" delete — either erroring or (worse) returning deleted rows. + // This is the exact durability requirement we enforce for data files + // and WAL markers in the append path. 
+ let f = std::fs::OpenOptions::new().write(true).open(&output_path)?; + f.sync_all()?; + let metadata = std::fs::metadata(&output_path)?; Ok(metadata.len()) diff --git a/crates/cayenne/src/provider/partitioned_wal.rs b/crates/cayenne/src/provider/partitioned_wal.rs index 761c33d82c..37956ef2f5 100644 --- a/crates/cayenne/src/provider/partitioned_wal.rs +++ b/crates/cayenne/src/provider/partitioned_wal.rs @@ -106,6 +106,48 @@ impl PartitionedWal { } } + /// Ensure the `_partitioned_wal/` subdirectory exists. + /// If we just created it, sync its parent (the table root) so the + /// subdirectory entry itself is durable on local FS. + /// + /// This is required for the same reason as the parent-directory sync + /// in `ensure_snapshot_dir_exists`: on POSIX, creating a subdirectory + /// updates the parent's directory metadata. Without the parent sync, + /// a crash can make the `_partitioned_wal/` directory "disappear" even + /// though we are about to write a coordination record inside it. + /// + /// This is the last piece of the local-FS durability puzzle for the + /// cross-partition coordination infrastructure (the write side of + /// `PartitionedWal` now has the same treatment as the removal side + /// and as all snapshot directory creation paths). + /// + /// Note for S3 tables: the `_partitioned_wal/` directory and the + /// `PartitionedWal` JSON file are still local files on the writer's + /// machine (coordination is local to the writer process). The per- + /// partition "staging WAL" on S3 is an object in the staging prefix, + /// and its removal is a best-effort object delete. The local FS + /// durability fixes apply to the coordination records on the writer. 
+ async fn ensure_partitioned_wal_dir_and_sync_parent( + table_root: &Path, + wal_dir: &Path, + ) -> Result<()> { + match tokio::fs::create_dir(wal_dir).await { + Ok(()) => { + let parent = table_root.to_path_buf(); // the table root + let table = table_root.display().to_string(); + + // Sync the table root so the _partitioned_wal/ subdir entry is durable. + tokio::task::spawn_blocking(move || std::fs::File::open(&parent)?.sync_all()) + .await + .map_err(|source| Error::TaskPanicked { table, source })??; + } + Err(source) if source.kind() == std::io::ErrorKind::AlreadyExists => {} + Err(source) => return Err(Error::IoError { source }), + } + + Ok(()) + } + /// Return the on-disk path for this WAL under the given table root. #[must_use] pub fn path_under(&self, table_root: &Path) -> PathBuf { @@ -133,7 +175,12 @@ impl PartitionedWal { /// be written, serialization fails, or the atomic rename / fsync fails. pub async fn write_to(&self, table_root: &Path) -> Result { let wal_dir = table_root.join(PARTITIONED_WAL_DIR); - tokio::fs::create_dir_all(&wal_dir).await?; + + // Ensure the _partitioned_wal/ subdirectory exists and, if we just + // created it, sync its parent (the table root) so the subdirectory + // entry itself is durable. This is the same durability requirement we + // now enforce for new snapshot directories in ensure_snapshot_dir_exists. + Self::ensure_partitioned_wal_dir_and_sync_parent(table_root, &wal_dir).await?; let wal_path = wal_dir.join(format!("{}.json", self.commit_id)); let tmp_path = wal_dir.join(format!("{}.json.tmp", self.commit_id)); @@ -194,6 +241,41 @@ impl PartitionedWal { "Removed partitioned WAL at {} (commit {commit_id})", path.display(), ); + + // Best-effort directory sync so that the absence of the + // cross-partition coordination marker is durable. This aligns + // the removal of the top-level `PartitionedWal` with the + // per-partition staging WAL removal (which now also syncs its + // directory). 
A crash without this sync could leave the marker + // visible, causing a conservative "incomplete cross-partition + // commit" detection on the next open (safe, but noisy). + // The actual data durability is already guaranteed by the + // per-partition move + WAL removal steps that ran under the + // held barrier before this call. + let wal_dir = table_root.join(PARTITIONED_WAL_DIR); + let wal_dir_display = wal_dir.display().to_string(); + match tokio::task::spawn_blocking(move || { + std::fs::File::open(&wal_dir).and_then(|f| f.sync_all()) + }) + .await + { + Ok(Ok(())) => {} + Ok(Err(e)) => { + tracing::warn!( + "Failed to sync {} after removing PartitionedWal {} (data is safe; may see stale marker on restart): {e}", + wal_dir_display, + commit_id, + ); + } + Err(join_err) => { + tracing::warn!( + "Join error while syncing {} after removing PartitionedWal {} (data is safe): {join_err}", + wal_dir_display, + commit_id, + ); + } + } + Ok(()) } Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()), diff --git a/crates/cayenne/src/provider/scan.rs b/crates/cayenne/src/provider/scan.rs index ede464e384..a02f46a0fb 100644 --- a/crates/cayenne/src/provider/scan.rs +++ b/crates/cayenne/src/provider/scan.rs @@ -14,14 +14,21 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -use std::{any::Any, sync::Arc}; +use std::{ + any::Any, + collections::BTreeSet, + sync::{Arc, OnceLock}, +}; use arrow_schema::SchemaRef; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::{DataFusionError, Statistics}; +use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_datasource::source::DataSourceExec; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{Distribution, OrderingRequirements, PhysicalExpr}; use futures::TryStreamExt; @@ -36,6 +43,7 @@ use datafusion_physical_plan::{ metrics::MetricsSet, projection::ProjectionExec, repartition::RepartitionExec, + union::UnionExec, }; /// Wrapper for Cayenne acceleration execution plans. @@ -43,16 +51,349 @@ use datafusion_physical_plan::{ #[derive(Debug)] pub struct CayenneAccelerationExec { inner: Arc, + scan_identity: OnceLock>>, } impl CayenneAccelerationExec { /// Creates a new `CayenneAccelerationExec` wrapping the given execution plan. #[must_use] pub fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + scan_identity: OnceLock::new(), + } + } + + /// Returns a stable identity for the underlying scan source, derived from + /// the `FileScanConfig`'s `object_store_url` plus the sorted set of file + /// paths backing the inner `DataSourceExec`. + /// + /// Two `CayenneAccelerationExec` nodes that scan the same set of physical + /// files return the same identity, which is the precondition for sharing a + /// runtime dynamic filter across them (see the cross-scan filter sharing + /// workstream documented in `crates/cayenne/src/optimizer_rules.rs`). + /// + /// Returns `None` if the inner plan does not contain a `DataSourceExec` + /// whose `DataSource` is a `FileScanConfig` with at least one file. 
Mixed + /// inlined-data scans use a `UnionExec`; their in-memory branch is ignored + /// and the identity is derived from the file-backed branch. The identity + /// intentionally ignores ordering of files within partitions and projection + /// differences — it is purely a per-table fingerprint. + /// + /// The `object_store_url` is required to disambiguate two stores that + /// happen to contain the same relative paths (e.g. two different S3 + /// buckets both with `part-000.vortex`). Without it the identity would + /// silently collide when paths are stored as relative locations. + #[must_use] + pub(crate) fn scan_identity(&self) -> Option> { + self.scan_identity + .get_or_init(|| compute_scan_identity(&self.inner)) + .as_ref() + .map(Arc::clone) + } + + /// Returns the dynamic filters currently pushed into this Cayenne scan. + /// + /// These filters originate from `DataFusion`'s hash-join dynamic-filter + /// pass. They are safe to share only when an optimizer has proven the target + /// scan is equi-joined on every referenced column. + #[must_use] + pub(crate) fn dynamic_filters(&self) -> Vec { + let mut filters = Vec::new(); + for file_scan_config in file_scan_configs(&self.inner) { + if let Some(filter) = file_scan_config.file_source().filter() { + collect_dynamic_filters(&filter, &mut filters); + } + } + filters + } + + /// Push additional dynamic filters into the underlying file source. + /// + /// Returns `Ok(None)` when the scan source declined all filters or the inner + /// plan is not a simple file scan. + /// + /// # Errors + /// + /// Returns an error when rebuilding the underlying `DataSourceExec` with the + /// additional filters fails. + pub(crate) fn with_additional_dynamic_filters( + &self, + filters: &[Arc], + config: &ConfigOptions, + ) -> Result>> { + let Some(inner) = + push_dynamic_filters_to_data_source(Arc::clone(&self.inner), filters, config)? 
+ else { + return Ok(None); + }; + + Ok(Some(Arc::new(Self::new(inner)))) + } +} + +fn compute_scan_identity(plan: &Arc) -> Option> { + let file_scan_configs = file_scan_configs(plan); + let first_file_scan_config = file_scan_configs.first()?; + let object_store_url = first_file_scan_config.object_store_url.as_str(); + if file_scan_configs + .iter() + .any(|file_scan_config| file_scan_config.object_store_url.as_str() != object_store_url) + { + return None; + } + + let mut paths: Vec = file_scan_configs + .iter() + .flat_map(|file_scan_config| file_scan_config.file_groups.iter()) + .flat_map(datafusion_datasource::file_groups::FileGroup::iter) + .map(|pf| pf.object_meta.location.to_string()) + .collect(); + + if paths.is_empty() { + return None; + } + + paths.sort(); + paths.dedup(); + Some(Arc::new(ScanIdentity { + object_store_url: Arc::from(object_store_url), + paths: Arc::from(paths), + })) +} + +/// Stable identifier for a Cayenne scan source, derived from the +/// `FileScanConfig`'s `object_store_url` plus the sorted set of file paths +/// backing the underlying `DataSourceExec`. +/// +/// Equality and hashing are content-based on both the `object_store_url` and +/// the path set, so two `CayenneAccelerationExec` instances over the same +/// logical table compare equal regardless of projection, partitioning, or +/// wrapper-plan differences — and two scans over different stores that happen +/// to share a relative path (e.g. two S3 buckets each with `part-000.vortex`) +/// do *not* collide. The path set is reference-counted so copying a scan +/// identity during optimizer rewrites does not clone every file path. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct ScanIdentity { + object_store_url: Arc, + paths: Arc<[String]>, +} + +/// A dynamic filter currently attached to a Cayenne scan, plus the scan-local +/// column names the filter references. 
+#[derive(Clone)] +pub(crate) struct ScanDynamicFilter { + filter: Arc, + columns: BTreeSet, +} + +impl ScanDynamicFilter { + /// Returns the shared dynamic filter expression. + #[must_use] + pub(crate) fn filter(&self) -> &Arc { + &self.filter + } + + /// Returns the scan-local column names referenced by this filter. + #[must_use] + pub(crate) fn columns(&self) -> &BTreeSet { + &self.columns + } +} + +fn file_scan_configs(plan: &Arc) -> Vec<&FileScanConfig> { + let mut configs = Vec::new(); + collect_file_scan_configs(plan, &mut configs); + configs +} + +/// Walks `plan` looking for underlying file-backed `DataSourceExec` nodes, +/// descending only through a whitelist of operators that are known to preserve +/// scan identity, plus `UnionExec` for mixed file + inlined-memory scans. +/// +/// Cayenne plans typically wrap the data source in transparent or +/// near-transparent operators: `ProjectionExec`, `RepartitionExec`, +/// `CoalesceBatchesExec`, `CoalescePartitionsExec`, plus the runtime's +/// `BytesProcessedExec` / `SchemaCastScanExec` and the cayenne-internal +/// `InexactStatsExec`. Any one of those may sit between +/// `CayenneAccelerationExec` and the `DataSourceExec`. +/// +/// Cayenne tables with inlined rows add a `UnionExec` whose file-backed branch +/// should still participate in dynamic-filter sharing. Non-file children such +/// as `MemoryExec` are ignored; they stay unfiltered because inline batches are +/// intentionally small. +/// +/// Anything else with a single child (e.g. `FilterExec`, `SortExec`, +/// `LimitExec`, an unfamiliar custom node) is *not* identity-preserving for +/// our purposes — it may change cardinality, ordering, or the file-set +/// semantics the identity relies on. Collecting no file scans is safer than +/// misattributing identity: the worst that happens is dynamic-filter sharing is +/// conservatively disabled. 
+fn collect_file_scan_configs<'a>( + plan: &'a Arc, + configs: &mut Vec<&'a FileScanConfig>, +) { + if let Some(data_source_exec) = plan.as_any().downcast_ref::() { + if let Some(file_scan_config) = data_source_exec + .data_source() + .as_any() + .downcast_ref::() + { + configs.push(file_scan_config); + } + return; + } + + if plan.as_any().downcast_ref::().is_some() { + for child in plan.children() { + collect_file_scan_configs(child, configs); + } + return; + } + + if !is_identity_preserving_wrapper(plan) { + return; + } + + let children = plan.children(); + if children.len() != 1 { + return; + } + + collect_file_scan_configs(children[0], configs); +} + +/// Returns `true` if `plan` is a known transparent / near-transparent wrapper +/// that preserves the underlying scan's identity (same file set, same logical +/// rows, just resharded / renamed / instrumented). +/// +/// The check is by-type for the wrappers we have in-scope, and by `name()` for +/// the ones that live in other crates or are crate-private. Adding a new +/// wrapper requires touching this function explicitly — that's intentional; +/// it stops a future operator from silently being treated as transparent. +fn is_identity_preserving_wrapper(plan: &Arc) -> bool { + let any = plan.as_any(); + if any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + || any + .downcast_ref::() + .is_some() + || any + .downcast_ref::() + .is_some() + || any.downcast_ref::().is_some() + { + return true; + } + + // Cross-crate / crate-private wrappers we can't downcast to without + // pulling in their concrete types: match by the stable `name()` string. 
+ matches!( + plan.name(), + "BytesProcessedExec" | "SchemaCastScanExec" | "InexactStatsExec" + ) +} + +fn collect_dynamic_filters(expr: &Arc, filters: &mut Vec) { + if let Some(dynamic_filter) = expr.as_any().downcast_ref::() { + if let Some(columns) = dynamic_filter_column_names(dynamic_filter) { + filters.push(ScanDynamicFilter { + filter: Arc::clone(expr), + columns, + }); + } + return; + } + + for child in expr.children() { + collect_dynamic_filters(child, filters); + } +} + +fn dynamic_filter_column_names( + dynamic_filter: &DynamicFilterPhysicalExpr, +) -> Option> { + let mut columns = BTreeSet::new(); + for child in dynamic_filter.children() { + let column = child.as_any().downcast_ref::()?; + columns.insert(column.name().to_string()); + } + + if columns.is_empty() { + None + } else { + Some(columns) } } +fn push_dynamic_filters_to_data_source( + plan: Arc, + filters: &[Arc], + optimizer_config: &ConfigOptions, +) -> Result>> { + if filters.is_empty() { + return Ok(None); + } + + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && let Some(file_scan_config) = data_source_exec + .data_source() + .as_any() + .downcast_ref::() + { + let filters = filters.iter().map(Arc::clone).collect(); + let propagation = file_scan_config + .file_source() + .try_pushdown_filters(filters, optimizer_config)?; + + let Some(updated_source) = propagation.updated_node else { + return Ok(None); + }; + + let mut updated_config = file_scan_config.clone(); + updated_config.file_source = updated_source; + let updated_exec = data_source_exec + .clone() + .with_data_source(Arc::new(updated_config)); + return Ok(Some(Arc::new(updated_exec))); + } + + let children = plan + .children() + .into_iter() + .map(Arc::clone) + .collect::>(); + if children.is_empty() { + return Ok(None); + } + + let is_union = plan.as_any().downcast_ref::().is_some(); + if !is_union && !is_identity_preserving_wrapper(&plan) { + return Ok(None); + } + if !is_union && children.len() != 1 { + return 
Ok(None); + } + + let mut changed = false; + let mut new_children = Vec::with_capacity(children.len()); + for child in children { + match push_dynamic_filters_to_data_source(Arc::clone(&child), filters, optimizer_config)? { + Some(updated_child) => { + changed = true; + new_children.push(updated_child); + } + None => new_children.push(child), + } + } + + if !changed { + return Ok(None); + } + + plan.with_new_children(new_children).map(Some) +} + pub(crate) fn round_robin_repartition_if_needed( plan: Arc, target_partitions: usize, @@ -357,4 +698,83 @@ mod tests { "projection-swapped Cayenne plan should stay wrapped for optimizer identification" ); } + + #[test] + fn scan_identity_returns_none_for_non_file_data_source() { + // MemorySourceConfig is not a FileScanConfig, so scan_identity must + // return None rather than misattributing identity. + let exec = CayenneAccelerationExec::new(one_partition_plan()); + assert!(exec.scan_identity().is_none()); + } + + #[test] + fn scan_identity_returns_none_when_inner_wraps_unknown_multi_child_plan() { + // A plan with multiple children (e.g. a join) cannot have a single + // unambiguous scan identity; find_data_source_exec must bail. + let left = one_partition_plan(); + let right = one_partition_plan(); + let schema = left.schema(); + let projection_expr = col("id", &schema).expect("id column should exist"); + + // Construct a 2-child wrapper via UnionExec to exercise the + // `children.len() != 1` early return without depending on join wiring. + let union = datafusion::physical_plan::union::UnionExec::try_new(vec![left, right]) + .expect("union exec should be created"); + + // Wrap in a projection so the top isn't a DataSourceExec. 
+ let projection = ProjectionExec::try_new(vec![(projection_expr, "id".to_string())], union) + .expect("projection exec should be created"); + let exec = CayenneAccelerationExec::new(Arc::new(projection)); + assert!(exec.scan_identity().is_none()); + } + + #[test] + fn scan_identity_equality_and_hashing_are_path_based() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let a = ScanIdentity { + object_store_url: Arc::from("s3://bucket/"), + paths: Arc::from(vec!["a.parquet".to_string(), "b.parquet".to_string()]), + }; + let b = ScanIdentity { + object_store_url: Arc::from("s3://bucket/"), + paths: Arc::from(vec!["a.parquet".to_string(), "b.parquet".to_string()]), + }; + let c = ScanIdentity { + object_store_url: Arc::from("s3://bucket/"), + paths: Arc::from(vec!["a.parquet".to_string()]), + }; + + assert_eq!(a, b, "same path set must compare equal"); + assert_ne!(a, c, "different path sets must not compare equal"); + + let mut ha = DefaultHasher::new(); + a.hash(&mut ha); + let mut hb = DefaultHasher::new(); + b.hash(&mut hb); + // Verify Hash compiles and is content-based (we don't assert exact + // equality of finish() between distinct hashers, but both use the + // same content; the trait must be derivable from the inner fields). + let _ = (ha.finish(), hb.finish()); + + assert_eq!(a.object_store_url.as_ref(), "s3://bucket/"); + assert_eq!(a.paths.as_ref(), &["a.parquet", "b.parquet"]); + } + + #[test] + fn scan_identity_does_not_collide_across_object_stores() { + // Same relative paths across two different stores must produce + // distinct identities — otherwise cross-scan dynamic filters could + // mistakenly share state across unrelated tables. 
+ let bucket_a = ScanIdentity { + object_store_url: Arc::from("s3://bucket-a/"), + paths: Arc::from(vec!["part-000.vortex".to_string()]), + }; + let bucket_b = ScanIdentity { + object_store_url: Arc::from("s3://bucket-b/"), + paths: Arc::from(vec!["part-000.vortex".to_string()]), + }; + assert_ne!(bucket_a, bucket_b); + } } diff --git a/crates/cayenne/src/provider/staging_wal.rs b/crates/cayenne/src/provider/staging_wal.rs index e1e9827718..6bde0d010f 100644 --- a/crates/cayenne/src/provider/staging_wal.rs +++ b/crates/cayenne/src/provider/staging_wal.rs @@ -54,6 +54,7 @@ limitations under the License. //! The legacy one-shot [`CayenneStagedAppend::commit`] is reimplemented in terms //! of this lifecycle and remains observably identical to the previous behavior. +use super::PartitionedWal; use super::Result; use super::constants::{STAGING_DIR_NAME, STAGING_WAL_FILENAME}; use super::table::CayenneTableProvider; @@ -539,10 +540,19 @@ impl CayenneTableProvider { })?; tokio::fs::write(&wal_path, content.as_bytes()).await?; - // fsync the WAL file to ensure it is durable before we begin moving files. + // fsync the WAL file content. let file = tokio::fs::File::open(&wal_path).await?; file.sync_all().await?; + // fsync the staging directory so that the directory entry for the newly + // written WAL file (and any data files previously written to this staging + // dir by `write_to_snapshot`) are durably persisted. This completes the + // "prepare" phase durability: the staging WAL record that lists the files + // to be moved is only considered durably written after its own directory + // entry is safe. Matches the full tmp+rename+dir-fsync pattern used for + // `PartitionedWal` and the syncs we perform after move and after WAL removal. 
+ Self::sync_snapshot_dir(&staging_dir).await?; + tracing::debug!( "Wrote staging WAL for table {} with {} file(s) targeting snapshot {target_snapshot}", self.table_name(), @@ -649,14 +659,32 @@ impl CayenneTableProvider { let staging_dir = Self::snapshot_dir_path(self.table_path(), self.table_id(), STAGING_DIR_NAME); let wal_path = staging_dir.join(STAGING_WAL_FILENAME); - match tokio::fs::remove_file(&wal_path).await { - Ok(()) => {} - Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + let removed = match tokio::fs::remove_file(&wal_path).await { + Ok(()) => true, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => true, // already gone = success state Err(e) => { tracing::warn!( "Failed to remove staging WAL for table {}: {e}", self.table_name(), ); + false + } + }; + + if removed { + // Durability: after removing the WAL marker (the "commit success" signal), + // fsync the staging directory so the unlink is persisted. A crash without + // this sync could make the removal non-durable, causing a false-positive + // "incomplete write" detection on the next open even though the data move + // succeeded and was synced. This completes the "WAL absent = durably + // committed" contract for local FS staged appends (symmetric to the + // sync after data file moves). + if let Err(e) = Self::sync_snapshot_dir(&staging_dir).await { + tracing::warn!( + "Failed to sync staging dir after WAL removal for table {}: {e} (data is safe; may see stale WAL on restart)", + self.table_name(), + ); + // Non-fatal: data files are already durable. A lingering WAL is conservative. } } } @@ -682,13 +710,32 @@ impl CayenneTableProvider { if let Some((wal, wal_location)) = wal { // Automated recovery attempt will be implemented in the future — for now we just error with details to help the operator resolve the issue. + // Best-effort enrichment: if this per-partition incomplete write was part + // of a cross-partition commit (i.e. 
a `PartitionedWal` record references + // this partition's table_id), include the commit_id in the error message. + // This helps operators correlate "incomplete write" errors across multiple + // partitions of the same logical table and points them at the + // `_partitioned_wal/` directory for manual resolution. + let mut extra = String::new(); + if let Ok(all_pw) = + PartitionedWal::read_all_in(std::path::Path::new(self.table_path())).await + { + for (pw, _) in all_pw { + if pw.partitions.iter().any(|e| e.table_id == self.table_id()) { + extra = format!(" (part of cross-partition commit {})", pw.commit_id); + break; + } + } + } + return Err(Error::IncompleteWrite { table: self.table_name().to_string(), message: format!( - "A previous write was interrupted while moving {} file(s) to '{}' (started at {}). Some files may have been partially written and require manual resolution. The WAL file is located at '{wal_location}'.", + "A previous write was interrupted while moving {} file(s) to '{}' (started at {}). Some files may have been partially written and require manual resolution. 
The WAL file is located at '{wal_location}'.{}", wal.staged_files.len(), wal.target_snapshot, wal.created_at, + extra, ), }); } diff --git a/crates/cayenne/src/provider/table.rs b/crates/cayenne/src/provider/table.rs index 9bc7ddfd80..1226eb5020 100644 --- a/crates/cayenne/src/provider/table.rs +++ b/crates/cayenne/src/provider/table.rs @@ -35,10 +35,24 @@ use crate::metadata::{ use crate::provider::scan::{CayenneAccelerationExec, round_robin_repartition_if_needed}; use crate::provider::sink::CayenneDataSink; use crate::provider::{Error, Result}; -use arrow::array::Array; +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, + FixedSizeBinaryArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, + Int64Array, LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, +}; +use arrow::compute::kernels::aggregate; +use arrow::datatypes::{ + Date32Type, Date64Type, Decimal128Type, Int8Type, Int16Type, Int32Type, Int64Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; use arrow::record_batch::RecordBatch; use arrow_row::{OwnedRow, RowConverter, SortField}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, SchemaRef, TimeUnit}; use async_trait::async_trait; use data_components::delete::{DeletionExec, DeletionSink}; use datafusion::datasource::file_format::FileFormat; @@ -52,7 +66,7 @@ use datafusion::optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use 
datafusion_catalog::{Session, TableProvider}; use datafusion_common::tree_node::TreeNode; -use datafusion_common::{ColumnStatistics, Constraints, DFSchema, Statistics}; +use datafusion_common::{ColumnStatistics, Constraints, DFSchema, ScalarValue, Statistics}; use datafusion_execution::cache::TableScopedPath; use datafusion_execution::config::SessionConfig; use datafusion_expr::dml::InsertOp; @@ -206,6 +220,22 @@ impl ColumnStatsAccumulator { }; } + let (batch_min, batch_max) = + Self::fast_column_min_max(col).unwrap_or_else(|| Self::scalar_column_min_max(col)); + + datafusion_common::ColumnStatistics { + null_count, + min_value: batch_min.map_or(Precision::Absent, Precision::Exact), + max_value: batch_max.map_or(Precision::Absent, Precision::Exact), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + } + } + + fn scalar_column_min_max( + col: &dyn arrow::array::Array, + ) -> (Option, Option) { // O(n) linear scan to find min/max using `ScalarValue` comparison. // NaN values are skipped entirely so stats remain deterministic. let mut batch_min: Option = None; @@ -246,16 +276,207 @@ impl ColumnStatsAccumulator { }); } - datafusion_common::ColumnStatistics { - null_count, - min_value: batch_min.map_or(Precision::Absent, Precision::Exact), - max_value: batch_max.map_or(Precision::Absent, Precision::Exact), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - byte_size: Precision::Absent, + (batch_min, batch_max) + } + + fn fast_column_min_max( + col: &dyn arrow::array::Array, + ) -> Option<(Option, Option)> { + macro_rules! primitive_min_max { + ($array_ty:ty, $arrow_ty:ty, |$value:ident| $scalar:expr) => {{ + let array = col.as_any().downcast_ref::<$array_ty>()?; + let min_value = aggregate::min::<$arrow_ty>(array).map(|$value| $scalar); + let max_value = aggregate::max::<$arrow_ty>(array).map(|$value| $scalar); + Some((min_value, max_value)) + }}; + } + + macro_rules! 
byte_min_max { + ($array_ty:ty, $min_fn:ident, $max_fn:ident, |$value:ident| $scalar:expr) => {{ + let array = col.as_any().downcast_ref::<$array_ty>()?; + let min_value = aggregate::$min_fn(array).map(|$value| $scalar); + let max_value = aggregate::$max_fn(array).map(|$value| $scalar); + Some((min_value, max_value)) + }}; + } + + match col.data_type() { + DataType::Boolean => { + let array = col.as_any().downcast_ref::()?; + Some(( + aggregate::min_boolean(array).map(|value| ScalarValue::Boolean(Some(value))), + aggregate::max_boolean(array).map(|value| ScalarValue::Boolean(Some(value))), + )) + } + DataType::Int8 => primitive_min_max!(Int8Array, Int8Type, |value| { + ScalarValue::Int8(Some(value)) + }), + DataType::Int16 => primitive_min_max!(Int16Array, Int16Type, |value| { + ScalarValue::Int16(Some(value)) + }), + DataType::Int32 => primitive_min_max!(Int32Array, Int32Type, |value| { + ScalarValue::Int32(Some(value)) + }), + DataType::Int64 => primitive_min_max!(Int64Array, Int64Type, |value| { + ScalarValue::Int64(Some(value)) + }), + DataType::UInt8 => primitive_min_max!(UInt8Array, UInt8Type, |value| { + ScalarValue::UInt8(Some(value)) + }), + DataType::UInt16 => primitive_min_max!(UInt16Array, UInt16Type, |value| { + ScalarValue::UInt16(Some(value)) + }), + DataType::UInt32 => primitive_min_max!(UInt32Array, UInt32Type, |value| { + ScalarValue::UInt32(Some(value)) + }), + DataType::UInt64 => primitive_min_max!(UInt64Array, UInt64Type, |value| { + ScalarValue::UInt64(Some(value)) + }), + DataType::Float32 => { + let array = col.as_any().downcast_ref::()?; + let (min_value, max_value) = Self::float32_min_max(array); + Some(( + min_value.map(|value| ScalarValue::Float32(Some(value))), + max_value.map(|value| ScalarValue::Float32(Some(value))), + )) + } + DataType::Float64 => { + let array = col.as_any().downcast_ref::()?; + let (min_value, max_value) = Self::float64_min_max(array); + Some(( + min_value.map(|value| ScalarValue::Float64(Some(value))), + 
max_value.map(|value| ScalarValue::Float64(Some(value))), + )) + } + DataType::Decimal128(precision, scale) => { + primitive_min_max!(Decimal128Array, Decimal128Type, |value| { + ScalarValue::Decimal128(Some(value), *precision, *scale) + }) + } + DataType::Utf8 => byte_min_max!(StringArray, min_string, max_string, |value| { + ScalarValue::Utf8(Some(value.to_string())) + }), + DataType::LargeUtf8 => { + byte_min_max!(LargeStringArray, min_string, max_string, |value| { + ScalarValue::LargeUtf8(Some(value.to_string())) + }) + } + DataType::Utf8View => { + byte_min_max!(StringViewArray, min_string_view, max_string_view, |value| { + ScalarValue::Utf8View(Some(value.to_string())) + }) + } + DataType::Binary => byte_min_max!(BinaryArray, min_binary, max_binary, |value| { + ScalarValue::Binary(Some(value.to_vec())) + }), + DataType::LargeBinary => { + byte_min_max!(LargeBinaryArray, min_binary, max_binary, |value| { + ScalarValue::LargeBinary(Some(value.to_vec())) + }) + } + DataType::BinaryView => { + byte_min_max!(BinaryViewArray, min_binary_view, max_binary_view, |value| { + ScalarValue::BinaryView(Some(value.to_vec())) + }) + } + DataType::FixedSizeBinary(size) => byte_min_max!( + FixedSizeBinaryArray, + min_fixed_size_binary, + max_fixed_size_binary, + |value| { ScalarValue::FixedSizeBinary(*size, Some(value.to_vec())) } + ), + DataType::Date32 => primitive_min_max!(Date32Array, Date32Type, |value| { + ScalarValue::Date32(Some(value)) + }), + DataType::Date64 => primitive_min_max!(Date64Array, Date64Type, |value| { + ScalarValue::Date64(Some(value)) + }), + DataType::Time32(TimeUnit::Second) => { + primitive_min_max!(Time32SecondArray, Time32SecondType, |value| { + ScalarValue::Time32Second(Some(value)) + }) + } + DataType::Time32(TimeUnit::Millisecond) => { + primitive_min_max!(Time32MillisecondArray, Time32MillisecondType, |value| { + ScalarValue::Time32Millisecond(Some(value)) + }) + } + DataType::Time64(TimeUnit::Microsecond) => { + 
primitive_min_max!(Time64MicrosecondArray, Time64MicrosecondType, |value| { + ScalarValue::Time64Microsecond(Some(value)) + }) + } + DataType::Time64(TimeUnit::Nanosecond) => { + primitive_min_max!(Time64NanosecondArray, Time64NanosecondType, |value| { + ScalarValue::Time64Nanosecond(Some(value)) + }) + } + DataType::Timestamp(TimeUnit::Second, tz) => { + primitive_min_max!(TimestampSecondArray, TimestampSecondType, |value| { + ScalarValue::TimestampSecond(Some(value), tz.clone()) + }) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => primitive_min_max!( + TimestampMillisecondArray, + TimestampMillisecondType, + |value| { ScalarValue::TimestampMillisecond(Some(value), tz.clone()) } + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => primitive_min_max!( + TimestampMicrosecondArray, + TimestampMicrosecondType, + |value| { ScalarValue::TimestampMicrosecond(Some(value), tz.clone()) } + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + primitive_min_max!(TimestampNanosecondArray, TimestampNanosecondType, |value| { + ScalarValue::TimestampNanosecond(Some(value), tz.clone()) + }) + } + _ => None, } } + fn float32_min_max(array: &Float32Array) -> (Option, Option) { + let mut min_value: Option = None; + let mut max_value: Option = None; + + for value in array.iter().flatten() { + if value.is_nan() { + continue; + } + min_value = Some(match min_value { + Some(current) if current <= value => current, + _ => value, + }); + max_value = Some(match max_value { + Some(current) if current >= value => current, + _ => value, + }); + } + + (min_value, max_value) + } + + fn float64_min_max(array: &Float64Array) -> (Option, Option) { + let mut min_value: Option = None; + let mut max_value: Option = None; + + for value in array.iter().flatten() { + if value.is_nan() { + continue; + } + min_value = Some(match min_value { + Some(current) if current <= value => current, + _ => value, + }); + max_value = Some(match max_value { + Some(current) if current >= value => 
current, + _ => value, + }); + } + + (min_value, max_value) + } + /// Get the total accumulated row count. pub(crate) fn row_count(&self) -> i64 { self.row_count.load(std::sync::atomic::Ordering::Relaxed) @@ -1200,9 +1421,11 @@ impl CayenneTableProvider { // Create listing options for Vortex format. /// - /// Only wraps the `VortexFormat` with `DeletionFilteringVortexFormat` for - /// `PositionBased` strategy. PK-based strategies (`Int64Pk`, `RowConverterBased`) - /// filter at the `ExecutionPlan` level, not during file reading. + /// Always wraps the `VortexFormat` so Cayenne-specific Vortex predicate + /// pushdown guards apply to every scan. `PositionBased` additionally + /// attaches deletion vectors during file reading; PK-based strategies + /// (`Int64Pk`, `RowConverterBased`) still filter at the `ExecutionPlan` + /// level. fn create_listing_options( vortex_format: &Arc, strategy: &PkDeletionStrategyWithCache, @@ -1216,9 +1439,9 @@ impl CayenneTableProvider { Arc::clone(cached_deleted_row_ids), )), PkDeletionStrategyWithCache::Int64Pk { .. } - | PkDeletionStrategyWithCache::RowConverterBased { .. } => { - Arc::clone(vortex_format) as Arc - } + | PkDeletionStrategyWithCache::RowConverterBased { .. } => Arc::new( + DeletionFilteringVortexFormat::without_deletion_vectors(Arc::clone(vortex_format)), + ), }; ListingOptions::new(file_format).with_session_config_options(session_config) } @@ -1254,7 +1477,25 @@ impl CayenneTableProvider { snapshot_dir: &std::path::Path, ) -> std::io::Result<()> { if !snapshot_dir.exists() { + // Capture the parent before creation so we can sync it afterwards. + let parent = snapshot_dir.parent().map(std::path::Path::to_path_buf); tokio::fs::create_dir_all(snapshot_dir).await?; + + // Make the *creation of the new snapshot directory itself* durable. + // On POSIX, creating a subdirectory updates the parent's directory + // metadata. 
Without syncing the parent, a crash can make the new + // snapshot directory "disappear" from the filesystem even though + // we later write files into it and commit the catalog to point at it. + // This is the same durability requirement we enforce for file + // creation, renames, and WAL marker removal elsewhere in the code. + if let Some(parent) = parent { + tokio::task::spawn_blocking(move || { + let f = std::fs::File::open(&parent)?; + f.sync_all() + }) + .await + .map_err(std::io::Error::other)??; + } } Ok(()) } @@ -1357,6 +1598,13 @@ impl CayenneTableProvider { table_name = self.table_metadata.table_name, ); + // Durability: fsync the target snapshot directory so that the rename operations + // are persisted before the caller removes the staging WAL. This ensures that + // "WAL absent" truly means the data files are durable on disk (ACID Durability + // for the staged append path on local filesystems). Matches the sync performed + // in the sort-rewrite / compaction path before metadata commit. + Self::sync_snapshot_dir(&target_dir).await?; + Ok(()) } @@ -1847,6 +2095,17 @@ impl CayenneTableProvider { ) .await?; + // Sync the new snapshot directory for durability before recording the + // sequence number in the catalog. This is required for the same reason + // as in the sort-rewrite and normal append paths: the Vortex files must + // be durably present before the catalog metadata that makes them + // visible (via sequence number / protected snapshot) is committed. 
+ let is_s3 = self.table_metadata.path.starts_with("s3://"); + if !is_s3 { + let snapshot_dir = self.snapshot_dir_path_for(&new_snapshot_id); + Self::sync_snapshot_dir(&snapshot_dir).await?; + } + tracing::debug!( "Insert to new snapshot {} completed, wrote {} rows to Vortex in {} chunk(s)", new_snapshot_id, @@ -6185,6 +6444,70 @@ mod tests { ); } + #[test] + fn compute_column_stats_uses_typed_min_max_for_int64() { + let array = Int64Array::from(vec![Some(10), None, Some(-4), Some(7)]); + + let stats = ColumnStatsAccumulator::compute_column_stats(&array); + + assert_eq!( + stats.null_count, + datafusion_common::stats::Precision::Exact(1) + ); + assert_eq!( + stats.min_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Int64(Some(-4))) + ); + assert_eq!( + stats.max_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Int64(Some(10))) + ); + } + + #[test] + fn compute_column_stats_skips_float_nan_values() { + let array = Float64Array::from(vec![Some(f64::NAN), Some(5.0), None, Some(-2.0)]); + + let stats = ColumnStatsAccumulator::compute_column_stats(&array); + + assert_eq!( + stats.null_count, + datafusion_common::stats::Precision::Exact(1) + ); + assert_eq!( + stats.min_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Float64(Some(-2.0))) + ); + assert_eq!( + stats.max_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Float64(Some(5.0))) + ); + } + + #[test] + fn compute_column_stats_uses_typed_min_max_for_utf8_view() { + let array = StringViewArray::from(vec![Some("beta"), Some("alpha"), None]); + + let stats = ColumnStatsAccumulator::compute_column_stats(&array); + + assert_eq!( + stats.null_count, + datafusion_common::stats::Precision::Exact(1) + ); + assert_eq!( + stats.min_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Utf8View(Some( + "alpha".to_string() + ))) + ); + assert_eq!( + stats.max_value, + datafusion_common::stats::Precision::Exact(ScalarValue::Utf8View(Some( + 
"beta".to_string() + ))) + ); + } + #[test] fn statistics_to_inexact_downgrades_exact_values_for_mutable_overlays() { let stats = Statistics { diff --git a/crates/cayenne/src/provider/vortex_format.rs b/crates/cayenne/src/provider/vortex_format.rs index 637a8dc070..dabab2107d 100644 --- a/crates/cayenne/src/provider/vortex_format.rs +++ b/crates/cayenne/src/provider/vortex_format.rs @@ -29,19 +29,30 @@ limitations under the License. use std::any::Any; use std::collections::HashMap; +use std::fmt::Formatter; use std::sync::Arc; use arc_swap::ArcSwap; +use arrow_schema::{DataType, Schema}; use async_trait::async_trait; use datafusion::datasource::file_format::FileFormat; use datafusion_catalog::Session; use datafusion_common::Result as DFResult; use datafusion_common::Statistics; +use datafusion_common::config::ConfigOptions; use datafusion_datasource::PartitionedFile; use datafusion_datasource::TableSchema; +use datafusion_datasource::file::FileSource; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_datasource::file_stream::FileOpener; +use datafusion_datasource::source::DataSourceExec; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions as df_expr; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_plan::filter_pushdown::{FilterPushdownPropagation, PushedDown}; +use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use object_store::{ObjectMeta, ObjectStore}; use roaring::{RoaringBitmap, RoaringTreemap}; @@ -187,6 +198,13 @@ impl DeletionFilteringVortexFormat { } } + /// Create a wrapper that installs Cayenne Vortex predicate-pushdown guards + /// without applying any deletion vectors. 
+ #[must_use] + pub fn without_deletion_vectors(inner: Arc) -> Self { + Self::new(inner, Arc::new(ArcSwap::from_pointee(HashMap::new()))) + } + /// Attach `VortexAccessPlan` extensions to files with deletion vectors. /// /// This is a convenience method that delegates to [`attach_deletion_vectors_to_config`]. @@ -198,7 +216,7 @@ impl DeletionFilteringVortexFormat { #[async_trait] impl FileFormat for DeletionFilteringVortexFormat { fn as_any(&self) -> &dyn Any { - self + self.inner.as_any() } fn compression_type(&self) -> Option { @@ -274,6 +292,7 @@ impl FileFormat for DeletionFilteringVortexFormat { .inner .create_physical_plan(state, modified_config) .await?; + let plan = wrap_vortex_file_sources(plan)?; // If there are deletions, wrap the plan to force inexact statistics. // This prevents AggregateStatistics optimizer from short-circuiting @@ -306,6 +325,229 @@ impl FileFormat for DeletionFilteringVortexFormat { } } +fn wrap_vortex_file_sources(plan: Arc) -> DFResult> { + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && let Some(file_scan_config) = data_source_exec + .data_source() + .as_any() + .downcast_ref::() + { + let mut wrapped_config = file_scan_config.clone(); + wrapped_config.file_source = Arc::new(CayenneVortexFileSource::new(Arc::clone( + file_scan_config.file_source(), + ))); + + let new_exec = data_source_exec + .clone() + .with_data_source(Arc::new(wrapped_config)); + return Ok(Arc::new(new_exec)); + } + + let children = plan.children(); + if children.is_empty() { + return Ok(plan); + } + + let new_children = children + .into_iter() + .map(|child| wrap_vortex_file_sources(Arc::clone(child))) + .collect::>>()?; + + plan.with_new_children(new_children) +} + +#[derive(Clone)] +struct CayenneVortexFileSource { + inner: Arc, +} + +impl CayenneVortexFileSource { + fn new(inner: Arc) -> Self { + Self { inner } + } +} + +impl std::fmt::Debug for CayenneVortexFileSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { + f.debug_struct("CayenneVortexFileSource") + .field("file_type", &self.inner.file_type()) + .finish() + } +} + +impl FileSource for CayenneVortexFileSource { + fn create_file_opener( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> DFResult> { + self.inner + .create_file_opener(object_store, base_config, partition) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_schema(&self) -> &TableSchema { + self.inner.table_schema() + } + + fn with_batch_size(&self, batch_size: usize) -> Arc { + Arc::new(Self::new(self.inner.with_batch_size(batch_size))) + } + + fn filter(&self) -> Option> { + self.inner.filter() + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.inner.projection() + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + self.inner.metrics() + } + + fn file_type(&self) -> &str { + self.inner.file_type() + } + + fn fmt_extra(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result { + self.inner.fmt_extra(t, f) + } + + fn supports_repartitioning(&self) -> bool { + self.inner.supports_repartitioning() + } + + fn repartitioned( + &self, + target_partitions: usize, + repartition_file_min_size: usize, + output_ordering: Option, + config: &FileScanConfig, + ) -> DFResult> { + self.inner.repartitioned( + target_partitions, + repartition_file_min_size, + output_ordering, + config, + ) + } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> DFResult>> { + let schema = self.inner.table_schema().file_schema(); + let mut safe_filters = Vec::new(); + let mut safe_filter_indexes = Vec::new(); + let mut pushdown_results = vec![PushedDown::No; filters.len()]; + + for (index, filter) in filters.into_iter().enumerate() { + if contains_decimal_to_floating_cast(filter.as_ref(), schema) { + tracing::debug!( + %filter, + "Skipping Vortex predicate pushdown for decimal-to-floating cast" + ); + continue; + } + + safe_filter_indexes.push(index); 
+ safe_filters.push(filter); + } + + if safe_filters.is_empty() { + return Ok(FilterPushdownPropagation::with_parent_pushdown_result( + pushdown_results, + )); + } + + let inner_propagation = self.inner.try_pushdown_filters(safe_filters, config)?; + + for (safe_index, result) in safe_filter_indexes + .into_iter() + .zip(inner_propagation.filters.into_iter()) + { + pushdown_results[safe_index] = result; + } + + let mut propagation = + FilterPushdownPropagation::with_parent_pushdown_result(pushdown_results); + if let Some(updated_node) = inner_propagation.updated_node { + propagation = propagation.with_updated_node(Arc::new(Self::new(updated_node)) as _); + } + + Ok(propagation) + } + + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> DFResult>> { + let schema = self.inner.table_schema().file_schema(); + if contains_decimal_to_floating_projection(projection, schema) { + tracing::debug!( + %projection, + "Skipping Vortex projection pushdown for decimal-to-floating cast" + ); + return Ok(None); + } + + self.inner + .try_pushdown_projection(projection) + .map(|source| source.map(|source| Arc::new(Self::new(source)) as _)) + } +} + +fn contains_decimal_to_floating_projection(projection: &ProjectionExprs, schema: &Schema) -> bool { + projection + .iter() + .any(|expr| contains_decimal_to_floating_cast(expr.expr.as_ref(), schema)) +} + +fn contains_decimal_to_floating_cast(expr: &dyn PhysicalExpr, schema: &Schema) -> bool { + if let Some(cast) = expr.as_any().downcast_ref::() { + let casts_to_floating = matches!(cast.cast_type(), DataType::Float32 | DataType::Float64); + // Resolve the input type of the cast using the provided schema. + // If resolution fails (e.g. Column index mismatch because the schema passed + // is the raw file_schema while the filter expr was built against a projected + // or wrapper-adjusted schema), conservatively treat it as a bad cast. 
+ // This prevents accidentally pushing a decimal→float cast filter to Vortex, + // which can produce wrong comparison/NULL results due to precision differences. + let casts_from_decimal = match cast.expr().data_type(schema) { + Ok(data_type) => matches!( + data_type, + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ), + Err(_) => true, // cannot prove safe → skip pushdown (correctness first) + }; + + if casts_to_floating && casts_from_decimal { + return true; + } + } + + if let Some(dynamic_filter) = expr + .as_any() + .downcast_ref::() + && let Ok(current) = dynamic_filter.current() + && contains_decimal_to_floating_cast(current.as_ref(), schema) + { + return true; + } + + expr.children() + .into_iter() + .any(|child| contains_decimal_to_floating_cast(child.as_ref(), schema)) +} + /// A wrapper execution plan that forces inexact row count statistics. /// /// This is used to wrap scan plans when there are deletions, preventing @@ -314,13 +556,11 @@ impl FileFormat for DeletionFilteringVortexFormat { #[derive(Debug)] struct InexactStatsExec { inner: Arc, - properties: PlanProperties, } impl InexactStatsExec { fn new(inner: Arc) -> Self { - let properties = inner.properties().clone(); - Self { inner, properties } + Self { inner } } } @@ -346,7 +586,7 @@ impl ExecutionPlan for InexactStatsExec { } fn properties(&self) -> &PlanProperties { - &self.properties + self.inner.properties() } fn children(&self) -> Vec<&Arc> { @@ -382,4 +622,164 @@ impl ExecutionPlan for InexactStatsExec { column_statistics: stats.column_statistics, }) } + + #[expect(deprecated)] + fn statistics(&self) -> DFResult { + // Delegate and then force inexact row count for safety, in case any code path + // still uses the deprecated statistics() method. 
+ let stats = self.inner.statistics()?; + Ok(Statistics { + num_rows: stats.num_rows.to_inexact(), + total_byte_size: stats.total_byte_size, + column_statistics: stats.column_statistics, + }) + } + + fn metrics(&self) -> Option { + self.inner.metrics() + } + + fn supports_limit_pushdown(&self) -> bool { + self.inner.supports_limit_pushdown() + } + + fn fetch(&self) -> Option { + self.inner.fetch() + } + + fn with_fetch(&self, limit: Option) -> Option> { + self.inner + .with_fetch(limit) + .map(|plan| Arc::new(Self::new(plan)) as Arc) + } + + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> DFResult>> { + self.inner + .repartitioned(target_partitions, config) + .map(|plan| plan.map(|plan| Arc::new(Self::new(plan)) as Arc)) + } + + fn try_swapping_with_projection( + &self, + projection: &datafusion_physical_plan::projection::ProjectionExec, + ) -> DFResult>> { + self.inner + .try_swapping_with_projection(projection) + .map(|plan| plan.map(|plan| Arc::new(Self::new(plan)) as _)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{ + BinaryExpr, CastExpr, Column, DynamicFilterPhysicalExpr, Literal, + }; + use datafusion_physical_expr::expressions::{col, lit}; + use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; + + #[test] + fn detects_decimal_to_floating_cast_predicate() { + let schema = Schema::new(vec![Field::new( + "amount", + DataType::Decimal128(15, 2), + true, + )]); + let amount = Arc::new(Column::new("amount", 0)) as Arc; + let cast = + Arc::new(CastExpr::new(amount, DataType::Float64, None)) as Arc; + let literal = + Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))) as Arc; + let predicate = BinaryExpr::new(cast, Operator::Lt, literal); + + assert!(contains_decimal_to_floating_cast(&predicate, &schema)); + } + + #[test] + fn 
allows_decimal_to_decimal_predicate() { + let schema = Schema::new(vec![Field::new( + "amount", + DataType::Decimal128(15, 2), + true, + )]); + let amount = col("amount", &schema).expect("amount column should exist"); + let literal = lit(ScalarValue::Decimal128(Some(100), 15, 2)); + let predicate = BinaryExpr::new(amount, Operator::Lt, literal); + + assert!(!contains_decimal_to_floating_cast(&predicate, &schema)); + } + + #[test] + fn allows_dynamic_filter_without_decimal_to_floating_current_predicate() { + let schema = Schema::new(vec![Field::new("amount", DataType::Int64, true)]); + let amount = col("amount", &schema).expect("amount column should exist"); + let dynamic_filter = DynamicFilterPhysicalExpr::new(vec![amount], lit(true)); + + assert!(!contains_decimal_to_floating_cast(&dynamic_filter, &schema)); + } + + #[test] + fn detects_dynamic_filter_decimal_to_floating_current_predicate() { + let schema = Schema::new(vec![Field::new( + "amount", + DataType::Decimal128(15, 2), + true, + )]); + let amount = Arc::new(Column::new("amount", 0)) as Arc; + let cast = Arc::new(CastExpr::new(amount, DataType::Float64, None)); + let literal = Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))); + let predicate = Arc::new(BinaryExpr::new(cast, Operator::Lt, literal)); + let dynamic_filter = DynamicFilterPhysicalExpr::new( + vec![col("amount", &schema).expect("amount column should exist")], + predicate, + ); + + assert!(contains_decimal_to_floating_cast(&dynamic_filter, &schema)); + } + + #[test] + fn detects_decimal_to_floating_cast_projection() { + let schema = Schema::new(vec![Field::new( + "amount", + DataType::Decimal128(15, 2), + true, + )]); + let amount = Arc::new(Column::new("amount", 0)) as Arc; + let cast = Arc::new(CastExpr::new(amount, DataType::Float64, None)); + let projection = ProjectionExprs::new([ProjectionExpr { + expr: cast, + alias: "amount_f64".to_string(), + }]); + + assert!(contains_decimal_to_floating_projection( + &projection, + &schema + )); 
+ } + + #[test] + fn allows_plain_decimal_projection() { + let schema = Schema::new(vec![Field::new( + "amount", + DataType::Decimal128(15, 2), + true, + )]); + let amount = Arc::new(Column::new("amount", 0)) as Arc; + let projection = ProjectionExprs::new([ProjectionExpr { + expr: amount, + alias: "amount".to_string(), + }]); + + assert!(!contains_decimal_to_floating_projection( + &projection, + &schema + )); + } } diff --git a/crates/cayenne/tests/cross_partition_overwrite_test.rs b/crates/cayenne/tests/cross_partition_overwrite_test.rs index 7d5904ba5f..021936edcb 100644 --- a/crates/cayenne/tests/cross_partition_overwrite_test.rs +++ b/crates/cayenne/tests/cross_partition_overwrite_test.rs @@ -23,6 +23,38 @@ You may obtain a copy of the License at //! atomically — either every partition advances or none do. //! - Rolling back the shared transaction (or surfacing an error from //! `apply_in_txn`) leaves every partition at its prior snapshot pointer. +//! +//! Durability note (as of the fixes in this branch): +//! All local-FS directory creation points that are part of the write + +//! crash-recovery infrastructure now perform the required parent-directory +//! sync after `create_dir_all` (snapshot directories via +//! `ensure_snapshot_dir_exists` (including initial table creation before +//! metastore INSERT), the `_partitioned_wal/` coordination directory via the +//! helper in `PartitionedWal::write_to`, `deletions/` subdirectories under +//! snapshots via `DeletionVectorWriter`, and partition value subdirectories +//! via `CayennePartitionCreator` before `add_partition`). +//! The catalog DB directory creation in `CayenneCatalog::init` also +//! receives a best-effort parent sync for completeness of the system +//! initialization path. +//! Combined with the per-partition staging WAL, deletion vector file +//! `sync_all`, and directory syncs in the delete sinks, a successful +//! cross-partition operation (append or overwrite, including any +//! 
concurrent or pending deletions or new partitions) leaves a fully +//! durable set of coordination records and data files on local FS. +//! The existing fault-injection and restart tests in this file, together +//! with the per-partition durability tests (deletion vector restart, +//! staged-append restart, `acid_compliance`, `data_inlining`, catalog +//! concurrency with partitions, and `shared_metastore_concurrency_test` +//! which exercises fresh catalog DB directory creation in both +//! `CayenneCatalog::init` and the SQLite/Turso metastore backends with +//! best-effort parent sync + warning), +//! provide comprehensive regression coverage for this property, +//! including the edge cases of the very first cross-partition write on a +//! brand-new table (first creation of the `_partitioned_wal/` directory), +//! the first deletion vector written to a snapshot, the first discovery +//! of a new partition value, and first-time catalog initialization on a +//! brand-new data directory (including defense-in-depth in the +//! connection setup paths). #![expect( clippy::expect_used, diff --git a/crates/data_components/src/elasticsearch/query_table.rs b/crates/data_components/src/elasticsearch/query_table.rs index 2f37a86dda..81c938bc11 100644 --- a/crates/data_components/src/elasticsearch/query_table.rs +++ b/crates/data_components/src/elasticsearch/query_table.rs @@ -27,7 +27,7 @@ use chrono::DateTime; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, LargeStringBuilder, ListBuilder, RecordBatch, StringArray, StringBuilder, - TimestampMicrosecondArray, + TimestampMicrosecondArray, UInt64Array, }; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use async_trait::async_trait; @@ -430,6 +430,24 @@ fn build_array_from_hits( .collect(); Ok(Arc::new(Int64Array::from(values)) as ArrayRef) } + // ES `unsigned_long` maps to Arrow `UInt64` in schema.rs. 
Decode using + // `as_u64` so values up to u64::MAX round-trip without being clipped + // through i64. JS clients commonly serialize values > 2^53-1 as digit + // strings (since JSON `number` can't represent them safely), and ES + // preserves that representation in `_source`, so also accept a numeric + // string. Values outside u64 range (incl. negative numerics) yield NULL. + DataType::UInt64 => { + let values: Vec> = hits + .iter() + .map(|h| { + extract_field(&h.source, field_name).and_then(|v| { + v.as_u64() + .or_else(|| v.as_str().and_then(|s| s.parse::().ok())) + }) + }) + .collect(); + Ok(Arc::new(UInt64Array::from(values)) as ArrayRef) + } DataType::Int32 => { let values: Vec> = hits .iter() @@ -954,6 +972,44 @@ mod tests { assert_eq!(closed.as_slice(), ["pit-1"]); } + // ── UInt64 (ES `unsigned_long`) ──────────────────────────────────────────── + + /// schema.rs maps ES `unsigned_long` to Arrow `UInt64`. Without a dedicated + /// decoder arm, the schema would say `UInt64` while the decoder fell into + /// the JSON-string fallback, blowing up at `RecordBatch` construction with a + /// schema/data type mismatch. + #[test] + fn test_unsigned_long_decodes_to_uint64() { + use arrow::array::UInt64Array; + + let schema = Arc::new(Schema::new(vec![Field::new("big", DataType::UInt64, true)])); + // u64::MAX would silently lose the high bit if we routed through i64; + // include it explicitly to lock in the as_u64 decoding path. + let max = u64::MAX; + let hits = vec![ + make_hit(json!({"big": 0_u64})), + make_hit(json!({"big": max})), + make_hit(json!({})), // missing → null + make_hit(json!({"big": -1_i64})), // negative → null (out of u64 range) + // JS-style stringified large values land in _source as strings. 
+ make_hit(json!({"big": "18446744073709551614"})), + make_hit(json!({"big": "not a number"})), // unparseable → null + ]; + let batch = hits_to_record_batch(&hits, &schema).expect("hits_to_record_batch failed"); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("column 0 should be UInt64Array"); + + assert_eq!(col.value(0), 0); + assert_eq!(col.value(1), max); + assert!(col.is_null(2)); + assert!(col.is_null(3)); + assert_eq!(col.value(4), 18_446_744_073_709_551_614_u64); + assert!(col.is_null(5)); + } + // ── Timestamp ────────────────────────────────────────────────────────────── #[test] diff --git a/crates/data_components/src/elasticsearch/schema.rs b/crates/data_components/src/elasticsearch/schema.rs index c4a1b40ebe..46f0ae1f6c 100644 --- a/crates/data_components/src/elasticsearch/schema.rs +++ b/crates/data_components/src/elasticsearch/schema.rs @@ -73,6 +73,8 @@ fn es_type_to_arrow(mapping: &FieldMapping) -> DataType { DataType::Utf8 } Some("long") => DataType::Int64, + // unsigned_long covers the full u64 range; Int64 would silently overflow values > i64::MAX. 
+ Some("unsigned_long") => DataType::UInt64, Some("integer") => DataType::Int32, Some("short") => DataType::Int16, Some("byte") => DataType::Int8, @@ -152,4 +154,22 @@ mod tests { DataType::FixedSizeList(_, 384) )); } + + #[test] + fn test_unsigned_long_mapping() { + let mut properties = HashMap::new(); + properties.insert( + "big".to_string(), + FieldMapping { + field_type: Some("unsigned_long".to_string()), + properties: None, + dims: None, + similarity: None, + }, + ); + + let schema = mapping_to_schema(&properties); + let big = schema.field_with_name("big").expect("big field"); + assert_eq!(big.data_type(), &DataType::UInt64); + } } diff --git a/crates/libnfs/src/lib.rs b/crates/libnfs/src/lib.rs index 987ce45f4f..e7d223c206 100644 --- a/crates/libnfs/src/lib.rs +++ b/crates/libnfs/src/lib.rs @@ -48,6 +48,8 @@ compile_error!("libnfs bindings require a 64-bit target pointer width"); mod sys { //! Raw FFI bindings generated by bindgen. + // Bindgen emits declarations for libnfs symbols that the safe wrapper may not call on every platform/libnfs version. 
+ #![allow(dead_code)] // Allow deprecated items from generated bindings #![allow(deprecated)] #![allow(clippy::upper_case_acronyms)] diff --git a/crates/runtime/src/datafusion/builder.rs b/crates/runtime/src/datafusion/builder.rs index 8ae439533b..a74f0032fc 100644 --- a/crates/runtime/src/datafusion/builder.rs +++ b/crates/runtime/src/datafusion/builder.rs @@ -30,7 +30,13 @@ use crate::{config::ClusterRole, metrics::telemetry::track_bytes_processed, stat use crate::{dataaccelerator::AcceleratorEngineRegistry, datafusion::SPICE_SCP_SCHEMA}; use cache::Caching; #[cfg(not(windows))] -use cayenne::optimizer_rules::CayenneJoinRewriter; +use cayenne::logical_optimizer::CayennePropagateFilterAcrossEquiJoinKeys; +#[cfg(not(windows))] +use cayenne::optimizer_rules::{ + CayenneAntiJoinSortMergeRewriter, CayenneDynamicFilterSharing, CayenneJoinRewriter, +}; +#[cfg(not(windows))] +use datafusion::optimizer::{Optimizer, OptimizerRule}; use datafusion::{ catalog::{CatalogProvider, MemoryCatalogProvider}, execution::{ @@ -383,7 +389,11 @@ impl DataFusionBuilder { // and accumulator budget are only configured for supported targets. // Windows keeps DataFusion's standard hash-join dynamic filters. 
clamp_maximum_shared_inlist_memory_bytes(exact_join_filter_memory_limit); - state = state.with_physical_optimizer_rule(Arc::new(CayenneJoinRewriter::new())); + state = with_cayenne_logical_optimizer(state); + state = state + .with_physical_optimizer_rule(Arc::new(CayenneDynamicFilterSharing::new())) + .with_physical_optimizer_rule(Arc::new(CayenneAntiJoinSortMergeRewriter::new())) + .with_physical_optimizer_rule(Arc::new(CayenneJoinRewriter::new())); } #[cfg(windows)] { @@ -599,6 +609,43 @@ impl DataFusionBuilder { } } +#[cfg(not(windows))] +fn with_cayenne_logical_optimizer(mut state: SessionStateBuilder) -> SessionStateBuilder { + let trailing_rules = state.optimizer_rules().take().unwrap_or_default(); + let mut optimizer_rules = state + .optimizer() + .take() + .map_or_else(|| Optimizer::new().rules, |optimizer| optimizer.rules); + + insert_cayenne_logical_optimizer_rule(&mut optimizer_rules); + optimizer_rules.extend(trailing_rules); + state.with_optimizer_rules(optimizer_rules) +} + +#[cfg(not(windows))] +fn insert_cayenne_logical_optimizer_rule(rules: &mut Vec>) { + if rules + .iter() + .any(|rule| rule.name() == "cayenne_propagate_filter_across_equi_join_keys") + { + return; + } + + let insert_at = rules + .iter() + .position(|rule| rule.name() == "decorrelate_predicate_subquery") + .unwrap_or_else(|| { + rules + .iter() + .position(|rule| rule.name() == "push_down_filter") + .unwrap_or(rules.len()) + }); + rules.insert( + insert_at, + Arc::new(CayennePropagateFilterAcrossEquiJoinKeys::new()), + ); +} + pub struct AnalyzerRulesBuilder { include_federation: bool, extra_rules: Vec>, @@ -777,6 +824,8 @@ mod tests { }; #[cfg(not(windows))] use cayenne::provider::CayenneAccelerationExec; + #[cfg(not(windows))] + use datafusion::catalog::MemTable; use datafusion::optimizer::Analyzer; #[cfg(not(windows))] use datafusion::{ @@ -914,6 +963,206 @@ mod tests { ); } + #[test] + #[cfg(not(windows))] + fn 
test_built_datafusion_registers_cayenne_logical_rule_before_subquery_decorrelation() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("tokio runtime"); + let handle = rt.handle().clone(); + + let df = DataFusionBuilder::new( + status::RuntimeStatus::new(), + Arc::new(AcceleratorEngineRegistry::default()), + handle, + ) + .build(); + + let state = df.ctx.state(); + let rule_names: Vec<&str> = state.optimizers().iter().map(|r| r.name()).collect(); + let cayenne_position = rule_names + .iter() + .position(|name| *name == "cayenne_propagate_filter_across_equi_join_keys") + .expect("Cayenne logical filter propagation rule should be registered"); + let decorrelate_position = rule_names + .iter() + .position(|name| *name == "decorrelate_predicate_subquery") + .expect("DataFusion decorrelate_predicate_subquery rule should be registered"); + let push_down_position = rule_names + .iter() + .position(|name| *name == "push_down_filter") + .expect("DataFusion push_down_filter rule should be registered"); + + assert!( + cayenne_position < decorrelate_position, + "Cayenne logical filter propagation must run before decorrelate_predicate_subquery so generated InSubquery predicates cannot reach physical planning" + ); + assert!( + decorrelate_position < push_down_position, + "DataFusion decorrelate_predicate_subquery must run before push_down_filter" + ); + assert_eq!( + rule_names + .iter() + .filter(|name| **name == "cayenne_propagate_filter_across_equi_join_keys") + .count(), + 1, + "Cayenne logical filter propagation rule should be registered exactly once" + ); + } + + #[test] + #[cfg(not(windows))] + fn test_built_datafusion_decorrelates_cayenne_propagated_subquery() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("tokio runtime"); + let handle = rt.handle().clone(); + + let df = DataFusionBuilder::new( + status::RuntimeStatus::new(), + Arc::new(AcceleratorEngineRegistry::default()), 
+ handle, + ) + .build(); + + rt.block_on(async { + let nation_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + ])); + let supplier_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + + df.ctx + .register_table( + "nation", + Arc::new( + MemTable::try_new(Arc::clone(&nation_schema), vec![vec![]]) + .expect("nation mem table should be valid"), + ), + ) + .expect("nation table should register"); + df.ctx + .register_table( + "supplier", + Arc::new( + MemTable::try_new(Arc::clone(&supplier_schema), vec![vec![]]) + .expect("supplier mem table should be valid"), + ), + ) + .expect("supplier table should register"); + + let dataframe = df + .ctx + .sql( + "SELECT s_suppkey FROM supplier, nation \ + WHERE s_nationkey = n_nationkey AND n_name = 'CHINA'", + ) + .await + .expect("q21-shaped query should create a dataframe"); + let optimized_plan = dataframe + .clone() + .into_optimized_plan() + .expect("q21-shaped query should optimize"); + let optimized_plan = optimized_plan.to_string(); + + assert!( + !optimized_plan.contains("InSubquery"), + "Cayenne propagated subqueries must be decorrelated before physical planning: {optimized_plan}" + ); + + dataframe + .create_physical_plan() + .await + .expect("q21-shaped query should create a physical plan"); + }); + } + + /// Regression test for the post-decorrelation re-propagation bug + /// (`cayenne::logical_optimizer`): after the rule wraps a Filter with + /// `InSubquery` and `DataFusion` decorrelates it to `LeftSemi`, the + /// optimizer iterates the rule pipeline to fixed point. Without the + /// cycle-detection fix in `analyze_logical_side`, the rule would re-fire + /// each pass and stack one redundant `LeftSemi` per iteration up to + /// `max_passes`. 
This integration test runs the full optimizer pipeline + /// and asserts the final plan has at most one `LeftSemi` for the q21 + /// shape — proving the cycle guard holds across decorrelation. + #[test] + #[cfg(not(windows))] + fn test_built_datafusion_does_not_stack_redundant_left_semi_after_decorrelation() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("tokio runtime"); + let handle = rt.handle().clone(); + + let df = DataFusionBuilder::new( + status::RuntimeStatus::new(), + Arc::new(AcceleratorEngineRegistry::default()), + handle, + ) + .build(); + + rt.block_on(async { + let nation_schema = Arc::new(Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, true), + ])); + let supplier_schema = Arc::new(Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_nationkey", DataType::Int64, false), + ])); + + df.ctx + .register_table( + "nation", + Arc::new( + MemTable::try_new(Arc::clone(&nation_schema), vec![vec![]]) + .expect("nation mem table should be valid"), + ), + ) + .expect("nation table should register"); + df.ctx + .register_table( + "supplier", + Arc::new( + MemTable::try_new(Arc::clone(&supplier_schema), vec![vec![]]) + .expect("supplier mem table should be valid"), + ), + ) + .expect("supplier table should register"); + + let dataframe = df + .ctx + .sql( + "SELECT s_suppkey FROM supplier, nation \ + WHERE s_nationkey = n_nationkey AND n_name = 'CHINA'", + ) + .await + .expect("q21-shaped query should create a dataframe"); + let optimized_plan = dataframe + .into_optimized_plan() + .expect("q21-shaped query should optimize"); + let plan_text = optimized_plan.to_string(); + + // The optimizer iterates rules to fixed point. Before the cycle + // guard, every iteration would add another `LeftSemi Join` on the + // fact side. With the guard in place we expect exactly one (the + // single decorrelated propagation). 
+ let left_semi_count = plan_text.matches("LeftSemi Join").count(); + assert!( + left_semi_count <= 1, + "post-decorrelation re-propagation is stacking redundant LeftSemi joins \ + (count={left_semi_count}); plan was:\n{plan_text}" + ); + }); + } + /// Cayenne rewrites `HashJoinExec` to use a custom accumulator type, so it /// must run after `DataFusion`'s built-in physical optimizer rules that /// downcast to the default `HashJoinExec` type. @@ -947,11 +1196,35 @@ mod tests { .iter() .position(|name| *name == "CayenneJoinRewriter") .expect("Cayenne join rewriter should be registered"); + let cayenne_filter_sharing_position = rule_names + .iter() + .position(|name| *name == "CayenneDynamicFilterSharing") + .expect("Cayenne dynamic filter sharing rule should be registered"); + let cayenne_anti_sort_merge_position = rule_names + .iter() + .position(|name| *name == "CayenneAntiJoinSortMergeRewriter") + .expect("Cayenne anti join sort-merge rewriter should be registered"); assert!( sanity_check_position < cayenne_rewriter_position, "CayenneJoinRewriter must run after DataFusion's built-in physical optimizer rules" ); + assert!( + sanity_check_position < cayenne_filter_sharing_position, + "CayenneDynamicFilterSharing must run after DataFusion's built-in physical optimizer rules" + ); + assert!( + cayenne_filter_sharing_position < cayenne_rewriter_position, + "CayenneDynamicFilterSharing must run before CayenneJoinRewriter so it can inspect DataFusion's default HashJoinExec nodes" + ); + assert!( + cayenne_filter_sharing_position < cayenne_anti_sort_merge_position, + "CayenneDynamicFilterSharing must run before CayenneAntiJoinSortMergeRewriter so anti joins can still receive shared scan filters" + ); + assert!( + cayenne_anti_sort_merge_position < cayenne_rewriter_position, + "CayenneAntiJoinSortMergeRewriter must run before CayenneJoinRewriter so anti joins are not recreated with the hash-join accumulator" + ); } #[cfg(not(windows))] diff --git 
a/crates/spicepod/src/component/function.rs b/crates/spicepod/src/component/function.rs index 4f6ed2e9ee..43925f111a 100644 --- a/crates/spicepod/src/component/function.rs +++ b/crates/spicepod/src/component/function.rs @@ -122,6 +122,13 @@ pub struct Function { } /// Function kind. +/// +/// A future `HigherOrder` variant will land when `DataFusion`'s +/// `HigherOrderUDF` trait is absorbed (see apache/datafusion#21679). It lowers +/// to a separate trait, not a decoration on `ScalarUDF`, so it must be a +/// distinct kind here as well. Adding it later is backward-compatible: serde +/// rejects unknown values today, so no spicepod can accidentally take the +/// name. #[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)] #[cfg_attr(feature = "schemars", derive(JsonSchema))] #[serde(rename_all = "lowercase")] diff --git a/crates/test-framework/src/queries/mod.rs b/crates/test-framework/src/queries/mod.rs index d93b36977a..b97a105111 100644 --- a/crates/test-framework/src/queries/mod.rs +++ b/crates/test-framework/src/queries/mod.rs @@ -601,6 +601,7 @@ impl QuerySet { "chbench_q18", "chbench_q19", "chbench_q20", + "chbench_q21", "chbench_q22", ] } @@ -1214,10 +1215,9 @@ pub fn get_clickbench_test_queries(overrides: Option) -> Vec) -> Vec { let queries = generate_chbench_queries!( - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22 ); // q15 excluded: requires a `revenue1` view - // q21 excluded: multi-way JOIN + anti-join exhausts HashJoin memory (can't spill) match overrides { // No engine-specific overrides yet