diff --git a/Cargo.lock b/Cargo.lock
index d1742cd998..fe6ca6b9c2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -744,6 +744,7 @@ dependencies = [
"arrow-cast",
"arrow-json",
"arrow-schema",
+ "criterion 0.7.0",
"datafusion",
"insta",
"serde_json",
diff --git a/crates/arrow_tools/Cargo.toml b/crates/arrow_tools/Cargo.toml
index c92cf02dda..1e59078f45 100644
--- a/crates/arrow_tools/Cargo.toml
+++ b/crates/arrow_tools/Cargo.toml
@@ -20,3 +20,8 @@ tracing.workspace = true
insta = {workspace = true, features = ["json"]}
arrow-json.workspace = true
serde_json.workspace = true
+criterion = { version = "0.7", features = ["html_reports"] }
+
+[[bench]]
+name = "truncate"
+harness = false
diff --git a/crates/arrow_tools/benches/truncate.rs b/crates/arrow_tools/benches/truncate.rs
new file mode 100644
index 0000000000..7b54f889ac
--- /dev/null
+++ b/crates/arrow_tools/benches/truncate.rs
@@ -0,0 +1,462 @@
+/*
+Copyright 2024-2026 The Spice.ai OSS Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Criterion benchmarks for the truncation fast-paths and formatting helpers
+// in arrow_tools.
+//
+// Run with: cargo bench -p arrow_tools --bench truncate
+//
+// The benchmarks deliberately separate "fast path" (no actual work needed,
+// exercises the zero-copy `clone()` paths added in the audit) from
+// "actual truncation" (exercises the slice+concat / collect paths).
+
+use arrow::array::{
+ FixedSizeListArray, Int32Array, LargeListViewArray, ListArray, ListViewArray, StringArray,
+ StringViewBuilder,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::record_batch::RecordBatch;
+use arrow_tools::record_batch::{truncate_numeric_column_length, truncate_string_columns};
+use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
+use std::hint::black_box;
+use std::sync::Arc;
+
+// Creates a batch where *no* string needs truncation at the given limit.
+// This exercises the new UTF8 fast-path (cheap any() + early Arc::clone).
+fn make_all_short_string_batch(n: usize, _max_chars: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)]));
+ // All strings are ASCII and well under the limit
+ let short = "short".repeat(3); // ~15 chars
+ let strings: Vec> = (0..n).map(|_| Some(short.clone())).collect();
+ let arr = StringArray::from(strings);
+ RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch")
+}
+
+// Creates a batch where a significant fraction of strings *do* need
+// character truncation. Exercises the collect + truncation path.
+fn make_many_long_string_batch(n: usize, max_chars: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)]));
+ let long = "x".repeat(max_chars + 20); // will need truncation
+ let short = "s";
+ let strings: Vec > = (0..n)
+ .map(|i| {
+ if i % 3 == 0 {
+ Some(long.clone())
+ } else {
+ Some(short.to_string())
+ }
+ })
+ .collect();
+ let arr = StringArray::from(strings);
+ RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch")
+}
+
+// StringViewArray versions of the above (exercises the specific fast-path
+// arm for DataType::Utf8View that was added during the audit).
+fn make_all_short_string_view_batch(n: usize, _max_chars: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "text",
+ DataType::Utf8View,
+ true,
+ )]));
+ let short = "short".repeat(3);
+ let mut builder = StringViewBuilder::new();
+ for _ in 0..n {
+ builder.append_value(&short);
+ }
+ let arr = builder.finish();
+ RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch")
+}
+
+fn make_many_long_string_view_batch(n: usize, max_chars: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "text",
+ DataType::Utf8View,
+ true,
+ )]));
+ let long = "x".repeat(max_chars + 20);
+ let short = "s";
+ let mut builder = StringViewBuilder::new();
+ for i in 0..n {
+ if i % 3 == 0 {
+ builder.append_value(&long);
+ } else {
+ builder.append_value(short);
+ }
+ }
+ let arr = builder.finish();
+ RecordBatch::try_new(schema, vec![Arc::new(arr)]).expect("valid batch")
+}
+
+// Creates a List batch where no list exceeds the element limit.
+// Exercises the list fast-path (try_fold + clone).
+fn make_all_short_list_batch(n: usize, max_elems: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = max_elems.min(5);
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..=n)
+ .map(|i| i32::try_from(i * per_list).expect("offset fits in i32"))
+ .collect();
+ let list = ListArray::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ OffsetBuffer::::new(offsets.into()),
+ Arc::new(values),
+ None,
+ );
+ RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch")
+}
+
+// Creates a List batch where truncation is required for every list.
+fn make_long_list_batch(n: usize, truncate_to: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = truncate_to + 15; // will need truncation
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..=n)
+ .map(|i| i32::try_from(i * per_list).expect("offset fits in i32"))
+ .collect();
+ let list = ListArray::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ OffsetBuffer::::new(offsets.into()),
+ Arc::new(values),
+ None,
+ );
+ RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch")
+}
+
+// ListView versions (exercises the ListView truncation path, which has
+// more complex offset + size handling and a different fast-path decision
+// based on the explicit sizes buffer).
+fn make_all_short_list_view_batch(n: usize, max_elems: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = max_elems.min(5);
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..n)
+ .map(|i| i32::try_from(i * per_list).expect("offset fits in i32"))
+ .collect();
+ let per_list_i32 = i32::try_from(per_list).expect("per_list fits in i32");
+ let sizes: Vec = (0..n).map(|_| per_list_i32).collect();
+ let list_view = ListViewArray::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(offsets),
+ ScalarBuffer::::from(sizes),
+ Arc::new(values),
+ None,
+ )
+ .expect("ListViewArray construction for benchmark");
+ RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch")
+}
+
+fn make_long_list_view_batch(n: usize, truncate_to: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = truncate_to + 15;
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..n)
+ .map(|i| i32::try_from(i * per_list).expect("offset fits in i32"))
+ .collect();
+ let per_list_i32 = i32::try_from(per_list).expect("per_list fits in i32");
+ let sizes: Vec = (0..n).map(|_| per_list_i32).collect();
+ let list_view = ListViewArray::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(offsets),
+ ScalarBuffer::::from(sizes),
+ Arc::new(values),
+ None,
+ )
+ .expect("ListViewArray construction for benchmark");
+ RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch")
+}
+
+// FixedSizeList versions (exercises the FixedSizeList fast-path, which is
+// the cheapest of all — just a uniform size comparison — plus the stride-based
+// slicing + concat work path).
+fn make_all_short_fixed_size_list_batch(n: usize, _max_elems: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 5),
+ true,
+ )]));
+ let per_list = 5usize;
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let list = FixedSizeListArray::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ 5,
+ Arc::new(values),
+ None,
+ );
+ RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch")
+}
+
+fn make_long_fixed_size_list_batch(n: usize, _truncate_to: usize) -> RecordBatch {
+ // For FixedSizeList the "long" case is when the fixed size > truncate_to.
+ // We create lists of size 20 and truncate to 5.
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 20),
+ true,
+ )]));
+ let per_list = 20usize;
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let list = FixedSizeListArray::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ 20,
+ Arc::new(values),
+ None,
+ );
+ RecordBatch::try_new(schema, vec![Arc::new(list)]).expect("valid batch")
+}
+
+// LargeListView versions (completes benchmark coverage for all five list
+// variants that received full support during the type audit; i64 offsets/sizes
+// + more complex offset rebuild in the work path).
+fn make_all_short_large_list_view_batch(n: usize, max_elems: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = max_elems.min(5);
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..n)
+ .map(|i| {
+ let offset = i.checked_mul(per_list).expect("offset fits in usize");
+ i64::try_from(offset).expect("offset fits in i64")
+ })
+ .collect();
+ let per_list_i64 = i64::try_from(per_list).expect("list size fits in i64");
+ let sizes: Vec = std::iter::repeat_n(per_list_i64, n).collect();
+ let list_view = LargeListViewArray::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(offsets),
+ ScalarBuffer::::from(sizes),
+ Arc::new(values),
+ None,
+ )
+ .expect("LargeListViewArray construction for benchmark");
+ RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch")
+}
+
+fn make_long_large_list_view_batch(n: usize, truncate_to: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "nums",
+ DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))),
+ true,
+ )]));
+ let per_list = truncate_to + 15;
+ let total_values = n * per_list;
+ let values = Int32Array::from(
+ (0..total_values)
+ .map(|v| i32::try_from(v).expect("value fits in i32"))
+ .collect::>(),
+ );
+ let offsets: Vec = (0..n)
+ .map(|i| {
+ let offset = i.checked_mul(per_list).expect("offset fits in usize");
+ i64::try_from(offset).expect("offset fits in i64")
+ })
+ .collect();
+ let per_list_i64 = i64::try_from(per_list).expect("list size fits in i64");
+ let sizes: Vec = std::iter::repeat_n(per_list_i64, n).collect();
+ let list_view = LargeListViewArray::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(offsets),
+ ScalarBuffer::::from(sizes),
+ Arc::new(values),
+ None,
+ )
+ .expect("LargeListViewArray construction for benchmark");
+ RecordBatch::try_new(schema, vec![Arc::new(list_view)]).expect("valid batch")
+}
+
+fn bench_truncate(c: &mut Criterion) {
+ let mut group = c.benchmark_group("arrow_tools_truncate");
+
+ // String (Utf8) fast path
+ let short_strings = make_all_short_string_batch(2000, 50);
+ group.bench_function("string_fast_path_2000_rows_all_short", |b| {
+ b.iter(|| {
+ truncate_string_columns(black_box(&short_strings), 50).expect("truncate short strings");
+ });
+ });
+
+ // String (Utf8) actual work
+ let long_strings = make_many_long_string_batch(2000, 50);
+ group.bench_function("string_with_truncation_2000_rows_mixed", |b| {
+ b.iter(|| {
+ truncate_string_columns(black_box(&long_strings), 50).expect("truncate long strings");
+ });
+ });
+
+ // StringView fast path (exercises the specific Utf8View arm + is_some_and decision)
+ let short_views = make_all_short_string_view_batch(2000, 50);
+ group.bench_function("stringview_fast_path_2000_rows_all_short", |b| {
+ b.iter(|| {
+ truncate_string_columns(black_box(&short_views), 50)
+ .expect("truncate short string views");
+ });
+ });
+
+ // StringView actual truncation
+ let long_views = make_many_long_string_view_batch(2000, 50);
+ group.bench_function("stringview_with_truncation_2000_rows_mixed", |b| {
+ b.iter(|| {
+ truncate_string_columns(black_box(&long_views), 50)
+ .expect("truncate long string views");
+ });
+ });
+
+ // List (regular List) fast path
+ let short_lists = make_all_short_list_batch(800, 5);
+ group.bench_function("list_fast_path_800_rows_all_short", |b| {
+ b.iter(|| {
+ truncate_numeric_column_length(black_box(&short_lists), 5)
+ .expect("truncate short lists");
+ });
+ });
+
+ // List (regular List) actual truncation work
+ let long_lists = make_long_list_batch(800, 5);
+ group.bench_function("list_with_truncation_800_rows", |b| {
+ b.iter(|| {
+ truncate_numeric_column_length(black_box(&long_lists), 5).expect("truncate long lists");
+ });
+ });
+
+ // ListView fast path (exercises the sizes-based decision + clone)
+ // Use iter_batched to isolate setup cost (creating the ListView batch
+ // with non-contiguous offsets/sizes is non-trivial).
+ group.bench_function("listview_fast_path_800_rows_all_short", |b| {
+ b.iter_batched(
+ || make_all_short_list_view_batch(800, 5),
+ |batch| {
+ truncate_numeric_column_length(black_box(&batch), 5)
+ .expect("truncate short list views");
+ },
+ BatchSize::SmallInput,
+ );
+ });
+
+ // ListView actual truncation work (exercises the more complex offset/size rebuild)
+ let long_list_views = make_long_list_view_batch(800, 5);
+ group.bench_function("listview_with_truncation_800_rows", |b| {
+ b.iter(|| {
+ truncate_numeric_column_length(black_box(&long_list_views), 5)
+ .expect("truncate long list views");
+ });
+ });
+
+ // FixedSizeList fast path (cheapest decision: uniform size comparison).
+ // Use iter_batched for consistency with the view variants (even though
+ // setup is simpler, it keeps the benchmark structure uniform).
+ group.bench_function("fixed_size_list_fast_path_800_rows_all_short", |b| {
+ b.iter_batched(
+ || make_all_short_fixed_size_list_batch(800, 5),
+ |batch| {
+ truncate_numeric_column_length(black_box(&batch), 5)
+ .expect("truncate short fixed-size lists");
+ },
+ BatchSize::SmallInput,
+ );
+ });
+
+ // FixedSizeList actual truncation work (stride-based slicing + concat)
+ let long_fsl = make_long_fixed_size_list_batch(800, 5);
+ group.bench_function("fixed_size_list_with_truncation_800_rows", |b| {
+ b.iter(|| {
+ truncate_numeric_column_length(black_box(&long_fsl), 5)
+ .expect("truncate long fixed-size lists");
+ });
+ });
+
+ // LargeListView fast path (completes the five-variant benchmark coverage;
+ // i64 sizes scan + zero-copy clone).
+ // Use iter_batched to isolate setup cost (i64 ScalarBuffer creation).
+ group.bench_function("large_listview_fast_path_800_rows_all_short", |b| {
+ b.iter_batched(
+ || make_all_short_large_list_view_batch(800, 5),
+ |batch| {
+ truncate_numeric_column_length(black_box(&batch), 5)
+ .expect("truncate short large list views");
+ },
+ BatchSize::SmallInput,
+ );
+ });
+
+ // LargeListView actual truncation work (i64 offset/size rebuild path)
+ let long_large_list_views = make_long_large_list_view_batch(800, 5);
+ group.bench_function("large_listview_with_truncation_800_rows", |b| {
+ b.iter(|| {
+ truncate_numeric_column_length(black_box(&long_large_list_views), 5)
+ .expect("truncate long large list views");
+ });
+ });
+
+ group.finish();
+}
+
+criterion_group!(benches, bench_truncate);
+criterion_main!(benches);
diff --git a/crates/arrow_tools/src/format.rs b/crates/arrow_tools/src/format.rs
index 52c1dbc68f..6444a3d3c1 100644
--- a/crates/arrow_tools/src/format.rs
+++ b/crates/arrow_tools/src/format.rs
@@ -15,8 +15,11 @@ limitations under the License.
*/
use crate::schema::to_source_native_type_name;
-use arrow::array::{Array, ArrayRef, FixedSizeListArray, ListArray, RecordBatch, StructArray};
-use arrow::buffer::OffsetBuffer;
+use arrow::array::{
+ Array, ArrayRef, FixedSizeListArray, LargeListArray, LargeListViewArray, ListArray,
+ ListViewArray, RecordBatch, StructArray,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::compute::concat;
use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_schema::{ArrowError, DataType, Field, Schema};
@@ -52,6 +55,19 @@ pub(crate) fn format_column_data(
"Failed to downcast to StringViewArray".into(),
))?;
+ // Fast path: zero-copy when no string requires character truncation.
+ // Cheap byte-length filter first (char count only on candidates).
+ if !string_array.iter().any(|opt| {
+ opt.is_some_and(|s| {
+ // Short-circuit at the (max_characters)th boundary so we
+ // stop walking pathologically long strings as soon as we
+ // know they'll need truncation.
+ s.len() > max_characters && s.chars().nth(max_characters).is_some()
+ })
+ }) {
+ return Ok(Arc::clone(&column));
+ }
+
let truncated = string_array
.iter()
.map(|x| truncate_str(x, max_characters))
@@ -64,9 +80,22 @@ pub(crate) fn format_column_data(
.as_any()
.downcast_ref::()
.ok_or(ArrowError::CastError(
- "Failed to downcast to ListArray".into(),
+ "Failed to downcast to StringArray".into(),
))?;
+ // Fast path: zero-copy when no string requires character truncation.
+ // Cheap byte-length filter first (char count only on candidates).
+ if !string_array.iter().any(|opt| {
+ opt.is_some_and(|s| {
+ // Short-circuit at the (max_characters)th boundary so we
+ // stop walking pathologically long strings as soon as we
+ // know they'll need truncation.
+ s.len() > max_characters && s.chars().nth(max_characters).is_some()
+ })
+ }) {
+ return Ok(Arc::clone(&column));
+ }
+
let truncated = string_array
.iter()
.map(|x| truncate_str(x, max_characters))
@@ -80,36 +109,66 @@ pub(crate) fn format_column_data(
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _)
- | DataType::ListView(_),
+ | DataType::ListView(_)
+ | DataType::LargeListView(_),
Some(_),
),
) => {
- let array_ref = if let DataType::FixedSizeList(_, _) = column.data_type() {
- let fixed_list_array = column
- .as_any()
- .downcast_ref::()
- .ok_or_else(|| {
- ArrowError::CastError("Failed to downcast to FixedSizeListArray".into())
- })?;
- Arc::new(truncate_fixed_size_list_array(
- fixed_list_array,
- num_elements,
- )?) as ArrayRef
- } else {
- let list_array = column
- .as_any()
- .downcast_ref::()
- .ok_or_else(|| {
- ArrowError::CastError("Failed to downcast to ListArray".into())
- })?;
- Arc::new(truncate_list_array(list_array, num_elements)?) as ArrayRef
+ let array_ref = match column.data_type() {
+ DataType::FixedSizeList(_, _) => {
+ let fixed_list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to FixedSizeListArray".into())
+ })?;
+ Arc::new(truncate_fixed_size_list_array(
+ fixed_list_array,
+ num_elements,
+ )?) as ArrayRef
+ }
+ DataType::LargeList(_) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to LargeListArray".into())
+ })?;
+ Arc::new(truncate_large_list_array(list_array, num_elements)?) as ArrayRef
+ }
+ DataType::ListView(_) => {
+ let list_array =
+ column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to ListViewArray".into())
+ })?;
+ Arc::new(truncate_list_view_array(list_array, num_elements)?) as ArrayRef
+ }
+ DataType::LargeListView(_) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to LargeListViewArray".into())
+ })?;
+ Arc::new(truncate_large_list_view_array(list_array, num_elements)?) as ArrayRef
+ }
+ _ => {
+ let list_array =
+ column.as_any().downcast_ref::().ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to ListArray".into())
+ })?;
+ Arc::new(truncate_list_array(list_array, num_elements)?) as ArrayRef
+ }
};
Ok(array_ref)
}
(FormatOperation::TruncateUtf8Length(max_characters), (DataType::List(field), _)) => {
let list_array = column
.as_any()
- .downcast_ref::()
+ .downcast_ref::()
.ok_or_else(|| ArrowError::CastError("Failed to downcast to ListArray".into()))?;
let truncated_values = format_column_data(
@@ -118,24 +177,150 @@ pub(crate) fn format_column_data(
FormatOperation::TruncateUtf8Length(max_characters),
)?;
+ // Zero-copy fast path for the outer list when the inner values were
+ // not modified (common when text in lists is short).
+ if Arc::ptr_eq(&truncated_values, list_array.values()) {
+ return Ok(Arc::clone(&column));
+ }
+
let list = ListArray::new(
Arc::clone(&field),
- arrow::buffer::OffsetBuffer::new(
- arrow::buffer::Buffer::from_slice_ref(list_array.value_offsets()).into(),
- ),
+ list_array.offsets().clone(),
truncated_values,
- list_array.logical_nulls(),
+ list_array.nulls().cloned(),
);
Ok(Arc::new(list) as ArrayRef)
}
+ (FormatOperation::TruncateUtf8Length(max_characters), (DataType::LargeList(field), _)) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to LargeListArray".into())
+ })?;
+
+ let truncated_values = format_column_data(
+ Arc::clone(list_array.values()),
+ &field,
+ FormatOperation::TruncateUtf8Length(max_characters),
+ )?;
+
+ // Zero-copy fast path for the outer list when the inner values were
+ // not modified (common when text in lists is short).
+ if Arc::ptr_eq(&truncated_values, list_array.values()) {
+ return Ok(Arc::clone(&column));
+ }
+
+ let list = LargeListArray::new(
+ Arc::clone(&field),
+ list_array.offsets().clone(),
+ truncated_values,
+ list_array.nulls().cloned(),
+ );
+
+ Ok(Arc::new(list) as ArrayRef)
+ }
+ (
+ FormatOperation::TruncateUtf8Length(max_characters),
+ (DataType::FixedSizeList(field, size), _),
+ ) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to FixedSizeListArray".into())
+ })?;
+
+ let truncated_values = format_column_data(
+ Arc::clone(list_array.values()),
+ &field,
+ FormatOperation::TruncateUtf8Length(max_characters),
+ )?;
+
+ // Zero-copy fast path for the outer list when the inner values were
+ // not modified (common when text in lists is short).
+ if Arc::ptr_eq(&truncated_values, list_array.values()) {
+ return Ok(Arc::clone(&column));
+ }
+
+ let list = FixedSizeListArray::new(
+ Arc::clone(&field),
+ size,
+ truncated_values,
+ list_array.nulls().cloned(),
+ );
+
+ Ok(Arc::new(list) as ArrayRef)
+ }
+ (FormatOperation::TruncateUtf8Length(max_characters), (DataType::ListView(field), _)) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to ListViewArray".into())
+ })?;
+
+ let truncated_values = format_column_data(
+ Arc::clone(list_array.values()),
+ &field,
+ FormatOperation::TruncateUtf8Length(max_characters),
+ )?;
+
+ // Zero-copy fast path for the outer list when the inner values were
+ // not modified (common when text in lists is short).
+ if Arc::ptr_eq(&truncated_values, list_array.values()) {
+ return Ok(Arc::clone(&column));
+ }
+
+ ListViewArray::try_new(
+ Arc::clone(&field),
+ list_array.offsets().clone(),
+ list_array.sizes().clone(),
+ truncated_values,
+ list_array.nulls().cloned(),
+ )
+ .map(|list| Arc::new(list) as ArrayRef)
+ }
+ (
+ FormatOperation::TruncateUtf8Length(max_characters),
+ (DataType::LargeListView(field), _),
+ ) => {
+ let list_array = column
+ .as_any()
+ .downcast_ref::()
+ .ok_or_else(|| {
+ ArrowError::CastError("Failed to downcast to LargeListViewArray".into())
+ })?;
+
+ let truncated_values = format_column_data(
+ Arc::clone(list_array.values()),
+ &field,
+ FormatOperation::TruncateUtf8Length(max_characters),
+ )?;
+
+ // Zero-copy fast path for the outer list when the inner values were
+ // not modified (common when text in lists is short).
+ if Arc::ptr_eq(&truncated_values, list_array.values()) {
+ return Ok(Arc::clone(&column));
+ }
+
+ LargeListViewArray::try_new(
+ Arc::clone(&field),
+ list_array.offsets().clone(),
+ list_array.sizes().clone(),
+ truncated_values,
+ list_array.nulls().cloned(),
+ )
+ .map(|list| Arc::new(list) as ArrayRef)
+ }
(FormatOperation::TruncateUtf8Length(max_characters), (DataType::Struct(fields), _)) => {
let struct_array = column
.as_any()
.downcast_ref::()
.ok_or_else(|| ArrowError::CastError("Failed to downcast to StructArray".into()))?;
- let columns = fields
+ let columns: Vec<_> = fields
.iter()
.enumerate()
.map(|(i, field)| {
@@ -148,6 +333,15 @@ pub(crate) fn format_column_data(
})
.collect::, _>>()?;
+ // Zero-copy fast path for the whole struct when no field was modified.
+ let all_unchanged = columns
+ .iter()
+ .enumerate()
+ .all(|(i, c)| Arc::ptr_eq(c, struct_array.column(i)));
+ if all_unchanged {
+ return Ok(Arc::clone(&column));
+ }
+
let truncated_struct =
StructArray::from(fields.iter().cloned().zip(columns).collect::>());
Ok(Arc::new(truncated_struct) as ArrayRef)
@@ -161,9 +355,11 @@ fn get_possible_nested_list_datatype(f: &Arc) -> (DataType, Option {
- Some(f.data_type().clone())
- }
+ DataType::List(f)
+ | DataType::FixedSizeList(f, _)
+ | DataType::LargeList(f)
+ | DataType::ListView(f)
+ | DataType::LargeListView(f) => Some(f.data_type().clone()),
_ => None,
},
)
@@ -184,18 +380,78 @@ fn truncate_str(str: Option<&str>, max_characters: usize) -> Option<&str> {
})
}
-#[expect(
- clippy::cast_sign_loss,
- clippy::cast_possible_truncation,
- clippy::cast_possible_wrap
-)]
+fn value_to_usize(value: T, value_name: &str) -> Result
+where
+ T: TryInto + std::fmt::Display + Copy,
+{
+ value.try_into().map_err(|_| {
+ ArrowError::InvalidArgumentError(format!(
+ "{value_name} {value} cannot be represented as usize"
+ ))
+ })
+}
+
+fn value_to_i32(value: usize, value_name: &str) -> Result {
+ i32::try_from(value).map_err(|_| {
+ ArrowError::InvalidArgumentError(format!(
+ "{value_name} {value} cannot be represented as i32"
+ ))
+ })
+}
+
+fn list_slice_range(
+ start_offset: T,
+ end_offset: T,
+ array_name: &str,
+ start_offset_name: &str,
+ end_offset_name: &str,
+) -> Result<(usize, usize), ArrowError>
+where
+ T: TryInto + std::fmt::Display + Copy,
+{
+ let start = value_to_usize(start_offset, start_offset_name)?;
+ let end = value_to_usize(end_offset, end_offset_name)?;
+ let len = end.checked_sub(start).ok_or_else(|| {
+ ArrowError::InvalidArgumentError(format!(
+ "{array_name} end offset {end} is before start offset {start}"
+ ))
+ })?;
+
+ Ok((start, len))
+}
+
+fn list_element_field(data_type: &DataType, array_name: &str) -> Result, ArrowError> {
+ match data_type {
+ DataType::List(field)
+ | DataType::LargeList(field)
+ | DataType::ListView(field)
+ | DataType::LargeListView(field)
+ | DataType::FixedSizeList(field, _) => Ok(Arc::clone(field)),
+ _ => Err(ArrowError::InvalidArgumentError(format!(
+ "{array_name} data type is not a list type: {data_type}"
+ ))),
+ }
+}
+
fn truncate_fixed_size_list_array(
list_array: &FixedSizeListArray,
max_len: usize,
) -> Result {
+ if list_array.is_empty() {
+ return Ok(list_array.clone());
+ }
let child_array = list_array.values();
- let original_size = list_array.value_length() as usize;
- let truncated_size = max_len.min(original_size);
+ let original_size =
+ value_to_usize(list_array.value_length(), "FixedSizeListArray value length")?;
+ // Fast path: zero-copy clone when truncation would be a no-op. This is the
+ // common case for display formatting (generous max_len, small actual lists)
+ // and avoids per-row slice + concat allocation/copy entirely.
+ if max_len >= original_size {
+ return Ok(list_array.clone());
+ }
+ let truncated_size = max_len; // known to be < original_size
+ let truncated_size_i32 = value_to_i32(truncated_size, "FixedSizeListArray truncated size")?;
+ let element_field = list_element_field(list_array.data_type(), "FixedSizeListArray")?;
let sliced_arrays: Vec> = (0..list_array.len())
.map(|i| child_array.slice(i * original_size, truncated_size))
@@ -204,57 +460,272 @@ fn truncate_fixed_size_list_array(
let new_child_array = Arc::new(concat(
&sliced_arrays.iter().map(AsRef::as_ref).collect::>(),
)?);
- let nulls = new_child_array.nulls().cloned();
-
- FixedSizeListArray::try_new(
- Arc::new(Field::new(
- "item",
- child_array.data_type().clone(),
- child_array.is_nullable(),
- )),
- truncated_size as i32,
+ // Parent list-level nulls (which slots are NULL lists) are preserved as-is;
+ // they live separately from the child array's element-level null bitmap.
+ let nulls = list_array.nulls().cloned();
+
+ FixedSizeListArray::try_new(element_field, truncated_size_i32, new_child_array, nulls)
+}
+
+fn truncate_list_array(list_array: &ListArray, max_len: usize) -> Result {
+ if list_array.is_empty() {
+ return Ok(list_array.clone());
+ }
+ let child_array = list_array.values();
+ let offsets = list_array.value_offsets();
+ // Fast path: zero-copy clone when no list element exceeds max_len.
+ let mut needs_trunc = false;
+ for i in 0..list_array.len() {
+ let (_, len) = list_slice_range(
+ offsets[i],
+ offsets[i + 1],
+ "ListArray",
+ "ListArray start offset",
+ "ListArray end offset",
+ )?;
+ if len > max_len {
+ needs_trunc = true;
+ break;
+ }
+ }
+ if !needs_trunc {
+ return Ok(list_array.clone());
+ }
+ let element_field = list_element_field(list_array.data_type(), "ListArray")?;
+
+ let slice_ranges: Vec<(usize, usize)> = (0..list_array.len())
+ .map(|i| {
+ list_slice_range(
+ offsets[i],
+ offsets[i + 1],
+ "ListArray",
+ "ListArray start offset",
+ "ListArray end offset",
+ )
+ .map(|(start, len)| (start, max_len.min(len)))
+ })
+ .collect::>()?;
+ let new_lengths: Vec = slice_ranges.iter().map(|&(_, len)| len).collect();
+
+ let sliced_arrays: Vec> = new_lengths
+ .iter()
+ .zip(slice_ranges.iter())
+ .map(|(&len, &(start, _))| child_array.slice(start, len))
+ .collect();
+
+ let new_child_array = Arc::new(concat(
+ &sliced_arrays.iter().map(AsRef::as_ref).collect::>(),
+ )?);
+
+ let nulls = list_array.nulls().cloned();
+
+ ListArray::try_new(
+ element_field,
+ OffsetBuffer::from_lengths(new_lengths),
new_child_array,
nulls,
)
}
-#[expect(clippy::cast_sign_loss)]
-fn truncate_list_array(list_array: &ListArray, max_len: usize) -> Result {
+fn truncate_large_list_array(
+ list_array: &LargeListArray,
+ max_len: usize,
+) -> Result {
+ if list_array.is_empty() {
+ return Ok(list_array.clone());
+ }
let child_array = list_array.values();
let offsets = list_array.value_offsets();
+ // Fast path: zero-copy clone when truncation is unnecessary (common for
+ // result formatting). LargeList uses i64 offsets, so validate conversion
+ // before deciding whether the slow path is needed.
+ let mut needs_trunc = false;
+ for i in 0..list_array.len() {
+ let (_, len) = list_slice_range(
+ offsets[i],
+ offsets[i + 1],
+ "LargeListArray",
+ "LargeListArray start offset",
+ "LargeListArray end offset",
+ )?;
+ if len > max_len {
+ needs_trunc = true;
+ break;
+ }
+ }
+ if !needs_trunc {
+ return Ok(list_array.clone());
+ }
+ let element_field = list_element_field(list_array.data_type(), "LargeListArray")?;
- let new_lengths: Vec = (0..list_array.len())
+ let slice_ranges: Vec<(usize, usize)> = (0..list_array.len())
.map(|i| {
- let start = offsets[i] as usize;
- let end = offsets[i + 1] as usize;
- max_len.min(end - start)
+ list_slice_range(
+ offsets[i],
+ offsets[i + 1],
+ "LargeListArray",
+ "LargeListArray start offset",
+ "LargeListArray end offset",
+ )
+ .map(|(start, len)| (start, max_len.min(len)))
})
- .collect();
+ .collect::>()?;
+ let new_lengths: Vec = slice_ranges.iter().map(|&(_, len)| len).collect();
let sliced_arrays: Vec> = new_lengths
.iter()
- .enumerate()
- .map(|(i, &len)| child_array.slice(offsets[i] as usize, len))
+ .zip(slice_ranges.iter())
+ .map(|(&len, &(start, _))| child_array.slice(start, len))
.collect();
let new_child_array = Arc::new(concat(
&sliced_arrays.iter().map(AsRef::as_ref).collect::>(),
)?);
- let nulls = new_child_array.nulls().cloned();
+ let nulls = list_array.nulls().cloned();
- ListArray::try_new(
- Arc::new(Field::new(
- "item",
- child_array.data_type().clone(),
- child_array.is_nullable(),
- )),
+ LargeListArray::try_new(
+ element_field,
OffsetBuffer::from_lengths(new_lengths),
new_child_array,
nulls,
)
}
+fn truncate_list_view_array(
+ list_array: &ListViewArray,
+ max_len: usize,
+) -> Result {
+ if list_array.is_empty() {
+ return Ok(list_array.clone());
+ }
+ let child_array = list_array.values();
+ let sizes = list_array.value_sizes();
+ // Fast path for ListView: sizes are stored explicitly, cheap scan.
+ // When no element exceeds the limit we return the original (preserving
+ // whatever non-contiguous view layout the caller had).
+ let mut needs_trunc = false;
+ for &size in sizes {
+ if value_to_usize(size, "ListViewArray size")? > max_len {
+ needs_trunc = true;
+ break;
+ }
+ }
+ if !needs_trunc {
+ return Ok(list_array.clone());
+ }
+ let offsets = list_array.value_offsets();
+ let element_field = list_element_field(list_array.data_type(), "ListViewArray")?;
+
+ let slice_ranges: Vec<(usize, usize, i32)> = (0..list_array.len())
+ .map(|i| {
+ let start = value_to_usize(offsets[i], "ListViewArray offset")?;
+ let original_size = value_to_usize(sizes[i], "ListViewArray size")?;
+ let truncated_size = max_len.min(original_size);
+ let truncated_size_i32 = value_to_i32(truncated_size, "ListViewArray truncated size")?;
+
+ Ok((start, truncated_size, truncated_size_i32))
+ })
+ .collect::>()?;
+ let new_sizes: Vec = slice_ranges
+ .iter()
+ .map(|&(_, _, truncated_size_i32)| truncated_size_i32)
+ .collect();
+
+ let sliced_arrays: Vec> = (0..list_array.len())
+ .map(|i| child_array.slice(slice_ranges[i].0, slice_ranges[i].1))
+ .collect();
+
+ let new_child_array = Arc::new(concat(
+ &sliced_arrays.iter().map(AsRef::as_ref).collect::>(),
+ )?);
+
+ let nulls = list_array.nulls().cloned();
+
+ // After concat, the new values buffer is laid out contiguously, so we
+ // re-derive offsets from the truncated sizes. Bail on overflow rather
+ // than silently producing misaligned offsets.
+ let mut new_offsets: Vec = Vec::with_capacity(new_sizes.len());
+ let mut running: i32 = 0;
+ for &s in &new_sizes {
+ new_offsets.push(running);
+ running = running.checked_add(s).ok_or_else(|| {
+ ArrowError::InvalidArgumentError("ListViewArray cumulative size overflowed i32".into())
+ })?;
+ }
+
+ ListViewArray::try_new(
+ element_field,
+ ScalarBuffer::from(new_offsets),
+ ScalarBuffer::from(new_sizes),
+ new_child_array,
+ nulls,
+ )
+}
+
+fn truncate_large_list_view_array(
+ list_array: &LargeListViewArray,
+ max_len: usize,
+) -> Result {
+ if list_array.is_empty() {
+ return Ok(list_array.clone());
+ }
+ let child_array = list_array.values();
+ let sizes = list_array.value_sizes();
+ // Fast path for LargeListView (i64 sizes): cheap scan over the sizes buffer.
+ // When nothing needs truncation we avoid the expensive concat + offset rebuild.
+ let max_len_i64 = i64::try_from(max_len).map_err(|_| {
+ ArrowError::InvalidArgumentError(format!("max_len {max_len} cannot be represented as i64"))
+ })?;
+ if !sizes.iter().any(|&s| s > max_len_i64) {
+ return Ok(list_array.clone());
+ }
+ let offsets = list_array.value_offsets();
+ let element_field = list_element_field(list_array.data_type(), "LargeListViewArray")?;
+
+ let new_sizes: Vec = (0..list_array.len())
+ .map(|i| sizes[i].min(max_len_i64))
+ .collect();
+
+ let slice_ranges: Vec<(usize, usize)> = (0..list_array.len())
+ .map(|i| {
+ let start = value_to_usize(offsets[i], "LargeListViewArray offset")?;
+ let len = value_to_usize(new_sizes[i], "LargeListViewArray size")?;
+
+ Ok((start, len))
+ })
+ .collect::>()?;
+
+ let sliced_arrays: Vec> = (0..list_array.len())
+ .map(|i| child_array.slice(slice_ranges[i].0, slice_ranges[i].1))
+ .collect();
+
+ let new_child_array = Arc::new(concat(
+ &sliced_arrays.iter().map(AsRef::as_ref).collect::>(),
+ )?);
+
+ let nulls = list_array.nulls().cloned();
+
+ let mut new_offsets: Vec = Vec::with_capacity(new_sizes.len());
+ let mut running: i64 = 0;
+ for &s in &new_sizes {
+ new_offsets.push(running);
+ running = running.checked_add(s).ok_or_else(|| {
+ ArrowError::InvalidArgumentError(
+ "LargeListViewArray cumulative size overflowed i64".into(),
+ )
+ })?;
+ }
+
+ LargeListViewArray::try_new(
+ element_field,
+ ScalarBuffer::from(new_offsets),
+ ScalarBuffer::from(new_sizes),
+ new_child_array,
+ nulls,
+ )
+}
+
/// Creates a visual representation of record batches using markdown document format with additional header fields.
///
/// # Errors
@@ -413,6 +884,26 @@ pub fn pretty_print_schema(
write_data_type(field.data_type(), w)?;
w.write_char('>')
}
+ DataType::FixedSizeList(field, size) => {
+ w.write_str("fixed_size_list[{size}]")
+ }
+ DataType::ListView(field) => {
+ w.write_str("list_view')
+ }
+ DataType::LargeListView(field) => {
+ w.write_str("large_list_view')
+ }
+ DataType::Map(field, _) => {
+ w.write_str("map<")?;
+ write_data_type(field.data_type(), w)?;
+ w.write_char('>')
+ }
DataType::Struct(fields) => {
w.write_str("struct<")?;
for (i, f) in fields.iter().enumerate() {
@@ -634,6 +1125,519 @@ Cras venenatis euismod malesuada.",
}
}
+ #[test]
+ fn test_truncate_large_list_array() {
+ use arrow::array::{Int32Array, LargeListArray as LargeListArrayAlias};
+ use arrow::buffer::OffsetBuffer;
+
+ let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]);
+ let offsets = OffsetBuffer::::new(vec![0_i64, 3, 6, 8].into());
+ let input = LargeListArrayAlias::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ offsets,
+ Arc::new(values),
+ None,
+ );
+
+ let output =
+ truncate_large_list_array(&input, 2).expect("truncate_large_list_array failed");
+
+ assert_eq!(output.len(), 3);
+ // Each sublist should now have at most 2 elements.
+ let observed_lengths: Vec = (0..output.len())
+ .map(|i| {
+ let start =
+ usize::try_from(output.value_offsets()[i]).expect("offset should fit in usize");
+ let end = usize::try_from(output.value_offsets()[i + 1])
+ .expect("offset should fit in usize");
+ end - start
+ })
+ .collect();
+ assert_eq!(observed_lengths, vec![2, 2, 2]);
+ }
+
+ /// Parent list-level nulls (which slots are NULL lists) must survive
+ /// truncation. The bug we're guarding against: pulling the child array's
+ /// null bitmap and stuffing it into the parent's `nulls` slot, which can
+ /// flip valid slots to NULL or vice versa.
+ #[test]
+ fn test_truncate_list_array_preserves_parent_nulls() {
+ let input = ListArray::from_iter_primitive::(vec![
+ Some(vec![Some(0), Some(1), Some(2), Some(3)]),
+ None,
+ Some(vec![Some(4), Some(5)]),
+ ]);
+ let parent_nulls_before: Vec = (0..input.len()).map(|i| input.is_null(i)).collect();
+
+ let output = truncate_list_array(&input, 2).expect("truncate_list_array failed");
+ let parent_nulls_after: Vec = (0..output.len()).map(|i| output.is_null(i)).collect();
+
+ assert_eq!(parent_nulls_before, parent_nulls_after);
+ assert_eq!(parent_nulls_after, vec![false, true, false]);
+ }
+
+ #[test]
+ fn test_truncate_large_list_array_preserves_parent_nulls() {
+ use arrow::array::{Int32Array, LargeListArray as LargeListArrayAlias};
+ use arrow::buffer::{NullBuffer, OffsetBuffer};
+
+ let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5]);
+ let offsets = OffsetBuffer::::new(vec![0_i64, 3, 3, 6].into());
+ // Mark the middle slot as a NULL list.
+ let nulls = NullBuffer::from(vec![true, false, true]);
+ let input = LargeListArrayAlias::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ offsets,
+ Arc::new(values),
+ Some(nulls),
+ );
+
+ let output =
+ truncate_large_list_array(&input, 2).expect("truncate_large_list_array failed");
+
+ assert!(!output.is_null(0));
+ assert!(output.is_null(1));
+ assert!(!output.is_null(2));
+ }
+
+ /// Truncation must carry the original element `Field` through unchanged.
+ /// Building a fresh field with a hardcoded name ("item") and
+ /// `child_array.is_nullable()` (which reflects the *data*, not the
+ /// declared schema) would silently rewrite the resulting `DataType` and
+ /// drop any field-level metadata.
+ #[test]
+ fn test_truncate_helpers_preserve_element_field() {
+ use arrow::array::{
+ FixedSizeListArray as FixedSizeListArrayAlias, Int32Array,
+ LargeListArray as LargeListArrayAlias, LargeListViewArray as LargeListViewArrayAlias,
+ ListViewArray as ListViewArrayAlias,
+ };
+ use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+
+ let metadata: HashMap = [("origin".to_string(), "audit_test".to_string())]
+ .into_iter()
+ .collect();
+ // Original field has a custom name, declared-nullable=false, and metadata;
+ // none of which the truncation helper should be free to discard.
+ let element_field =
+ Arc::new(Field::new("value", DataType::Int32, false).with_metadata(metadata.clone()));
+
+ let values = || Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]);
+
+ let assert_field_preserved = |actual: &Arc| {
+ assert_eq!(actual.name(), "value");
+ assert!(!actual.is_nullable());
+ assert_eq!(actual.metadata(), &metadata);
+ };
+
+ // List
+ let list = ListArray::new(
+ Arc::clone(&element_field),
+ OffsetBuffer::::new(vec![0_i32, 3, 6, 8].into()),
+ Arc::new(values()),
+ None,
+ );
+ let output = truncate_list_array(&list, 2).expect("truncate List");
+ match output.data_type() {
+ DataType::List(field) => assert_field_preserved(field),
+ other => panic!("unexpected {other:?}"),
+ }
+
+ // LargeList
+ let large_list = LargeListArrayAlias::new(
+ Arc::clone(&element_field),
+ OffsetBuffer::::new(vec![0_i64, 3, 6, 8].into()),
+ Arc::new(values()),
+ None,
+ );
+ let output = truncate_large_list_array(&large_list, 2).expect("truncate LargeList");
+ match output.data_type() {
+ DataType::LargeList(field) => assert_field_preserved(field),
+ other => panic!("unexpected {other:?}"),
+ }
+
+ // FixedSizeList
+ let fixed_size_list = FixedSizeListArrayAlias::new(
+ Arc::clone(&element_field),
+ 4,
+ Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7])),
+ None,
+ );
+ let output =
+ truncate_fixed_size_list_array(&fixed_size_list, 2).expect("truncate FixedSizeList");
+ match output.data_type() {
+ DataType::FixedSizeList(field, _) => assert_field_preserved(field),
+ other => panic!("unexpected {other:?}"),
+ }
+
+ // ListView
+ let list_view = ListViewArrayAlias::try_new(
+ Arc::clone(&element_field),
+ ScalarBuffer::::from(vec![0_i32, 3, 6]),
+ ScalarBuffer::::from(vec![3_i32, 3, 2]),
+ Arc::new(values()),
+ None,
+ )
+ .expect("ListViewArray construction");
+ let output = truncate_list_view_array(&list_view, 2).expect("truncate ListView");
+ match output.data_type() {
+ DataType::ListView(field) => assert_field_preserved(field),
+ other => panic!("unexpected {other:?}"),
+ }
+
+ // LargeListView
+ let large_list_view = LargeListViewArrayAlias::try_new(
+ Arc::clone(&element_field),
+ ScalarBuffer::::from(vec![0_i64, 3, 6]),
+ ScalarBuffer::::from(vec![3_i64, 3, 2]),
+ Arc::new(values()),
+ None,
+ )
+ .expect("LargeListViewArray construction");
+ let output =
+ truncate_large_list_view_array(&large_list_view, 2).expect("truncate LargeListView");
+ match output.data_type() {
+ DataType::LargeListView(field) => assert_field_preserved(field),
+ other => panic!("unexpected {other:?}"),
+ }
+ }
+
+ /// `TruncateUtf8Length` must recurse into every list-like variant, not
+ /// just `List`. Otherwise strings nested under `LargeList` /
+ /// `FixedSizeList` / `ListView` / `LargeListView` silently skip
+ /// truncation when callers run `truncate_string_columns` over a record
+ /// batch.
+ #[test]
+ fn test_truncate_utf8_recurses_into_all_list_variants() {
+ use arrow::array::{
+ FixedSizeListArray as FixedSizeListArrayAlias, LargeListArray as LargeListArrayAlias,
+ LargeListViewArray as LargeListViewArrayAlias, ListViewArray as ListViewArrayAlias,
+ StringArray,
+ };
+ use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+
+ fn assert_first_string_truncated(arr: &ArrayRef, expected_first_char: &str) {
+ // Each variant stores its UTF8 elements contiguously in the child
+ // array; inspecting element 0 is enough to prove truncation ran.
+ let dt = arr.data_type().clone();
+ let child: ArrayRef = match &dt {
+ DataType::List(_) => Arc::clone(
+ arr.as_any()
+ .downcast_ref::()
+ .expect("ListArray")
+ .values(),
+ ),
+ DataType::LargeList(_) => Arc::clone(
+ arr.as_any()
+ .downcast_ref::()
+ .expect("LargeListArray")
+ .values(),
+ ),
+ DataType::FixedSizeList(_, _) => Arc::clone(
+ arr.as_any()
+ .downcast_ref::()
+ .expect("FixedSizeListArray")
+ .values(),
+ ),
+ DataType::ListView(_) => Arc::clone(
+ arr.as_any()
+ .downcast_ref::()
+ .expect("ListViewArray")
+ .values(),
+ ),
+ DataType::LargeListView(_) => Arc::clone(
+ arr.as_any()
+ .downcast_ref::()
+ .expect("LargeListViewArray")
+ .values(),
+ ),
+ other => panic!("unexpected outer type {other:?}"),
+ };
+ let strings = child
+ .as_any()
+ .downcast_ref::()
+ .expect("StringArray child");
+ assert_eq!(strings.value(0), expected_first_char);
+ }
+
+ let inner_field = Arc::new(Field::new("item", DataType::Utf8, true));
+ let strings = || StringArray::from(vec!["abcdef", "ghijkl"]);
+
+ // List
+ let list = ListArray::new(
+ Arc::clone(&inner_field),
+ OffsetBuffer::::new(vec![0_i32, 1, 2].into()),
+ Arc::new(strings()),
+ None,
+ );
+ let truncated = format_column_data(
+ Arc::new(list) as ArrayRef,
+ &Arc::new(Field::new(
+ "col",
+ DataType::List(Arc::clone(&inner_field)),
+ true,
+ )),
+ FormatOperation::TruncateUtf8Length(1),
+ )
+ .expect("List Utf8 truncate");
+ assert_first_string_truncated(&truncated, "a");
+
+ // LargeList
+ let large_list = LargeListArrayAlias::new(
+ Arc::clone(&inner_field),
+ OffsetBuffer::::new(vec![0_i64, 1, 2].into()),
+ Arc::new(strings()),
+ None,
+ );
+ let truncated = format_column_data(
+ Arc::new(large_list) as ArrayRef,
+ &Arc::new(Field::new(
+ "col",
+ DataType::LargeList(Arc::clone(&inner_field)),
+ true,
+ )),
+ FormatOperation::TruncateUtf8Length(1),
+ )
+ .expect("LargeList Utf8 truncate");
+ assert_first_string_truncated(&truncated, "a");
+
+ // FixedSizeList[1]
+ let fsl =
+ FixedSizeListArrayAlias::new(Arc::clone(&inner_field), 1, Arc::new(strings()), None);
+ let truncated = format_column_data(
+ Arc::new(fsl) as ArrayRef,
+ &Arc::new(Field::new(
+ "col",
+ DataType::FixedSizeList(Arc::clone(&inner_field), 1),
+ true,
+ )),
+ FormatOperation::TruncateUtf8Length(1),
+ )
+ .expect("FixedSizeList Utf8 truncate");
+ assert_first_string_truncated(&truncated, "a");
+
+ // ListView
+ let list_view = ListViewArrayAlias::try_new(
+ Arc::clone(&inner_field),
+ ScalarBuffer::::from(vec![0_i32, 1]),
+ ScalarBuffer::::from(vec![1_i32, 1]),
+ Arc::new(strings()),
+ None,
+ )
+ .expect("ListViewArray construction");
+ let truncated = format_column_data(
+ Arc::new(list_view) as ArrayRef,
+ &Arc::new(Field::new(
+ "col",
+ DataType::ListView(Arc::clone(&inner_field)),
+ true,
+ )),
+ FormatOperation::TruncateUtf8Length(1),
+ )
+ .expect("ListView Utf8 truncate");
+ assert_first_string_truncated(&truncated, "a");
+
+ // LargeListView
+ let large_list_view = LargeListViewArrayAlias::try_new(
+ Arc::clone(&inner_field),
+ ScalarBuffer::::from(vec![0_i64, 1]),
+ ScalarBuffer::::from(vec![1_i64, 1]),
+ Arc::new(strings()),
+ None,
+ )
+ .expect("LargeListViewArray construction");
+ let truncated = format_column_data(
+ Arc::new(large_list_view) as ArrayRef,
+ &Arc::new(Field::new(
+ "col",
+ DataType::LargeListView(Arc::clone(&inner_field)),
+ true,
+ )),
+ FormatOperation::TruncateUtf8Length(1),
+ )
+ .expect("LargeListView Utf8 truncate");
+ assert_first_string_truncated(&truncated, "a");
+ }
+
+ /// `arrow::compute::concat` errors when given an empty slice of arrays.
+ /// 0-row record batches are a realistic input (e.g. a sample query that
+ /// returns no rows), so each list truncation helper must short-circuit
+ /// before reaching `concat`.
+ #[test]
+ fn test_truncate_helpers_handle_empty_arrays() {
+ use arrow::array::{
+ FixedSizeListArray as FixedSizeListArrayAlias, Int32Array,
+ LargeListArray as LargeListArrayAlias, LargeListViewArray as LargeListViewArrayAlias,
+ ListViewArray as ListViewArrayAlias,
+ };
+ use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+
+ // Empty ListArray
+ let empty_list = ListArray::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ OffsetBuffer::::new(vec![0_i32].into()),
+ Arc::new(Int32Array::from(Vec::::new())),
+ None,
+ );
+ assert_eq!(
+ truncate_list_array(&empty_list, 4)
+ .expect("empty ListArray must round-trip")
+ .len(),
+ 0
+ );
+
+ // Empty LargeListArray
+ let empty_large_list = LargeListArrayAlias::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ OffsetBuffer::::new(vec![0_i64].into()),
+ Arc::new(Int32Array::from(Vec::::new())),
+ None,
+ );
+ assert_eq!(
+ truncate_large_list_array(&empty_large_list, 4)
+ .expect("empty LargeListArray must round-trip")
+ .len(),
+ 0
+ );
+
+ // Empty FixedSizeListArray
+ let empty_fsl = FixedSizeListArrayAlias::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ 3,
+ Arc::new(Int32Array::from(Vec::::new())),
+ None,
+ );
+ assert_eq!(
+ truncate_fixed_size_list_array(&empty_fsl, 2)
+ .expect("empty FixedSizeListArray must round-trip")
+ .len(),
+ 0
+ );
+
+ // Empty ListViewArray
+ let empty_list_view = ListViewArrayAlias::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(Vec::::new()),
+ ScalarBuffer::::from(Vec::::new()),
+ Arc::new(Int32Array::from(Vec::::new())),
+ None,
+ )
+ .expect("empty ListViewArray construction");
+ assert_eq!(
+ truncate_list_view_array(&empty_list_view, 4)
+ .expect("empty ListViewArray must round-trip")
+ .len(),
+ 0
+ );
+
+ // Empty LargeListViewArray
+ let empty_large_list_view = LargeListViewArrayAlias::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ ScalarBuffer::::from(Vec::::new()),
+ ScalarBuffer::::from(Vec::::new()),
+ Arc::new(Int32Array::from(Vec::::new())),
+ None,
+ )
+ .expect("empty LargeListViewArray construction");
+ assert_eq!(
+ truncate_large_list_view_array(&empty_large_list_view, 4)
+ .expect("empty LargeListViewArray must round-trip")
+ .len(),
+ 0
+ );
+ }
+
+ #[test]
+ fn test_truncate_fixed_size_list_array_preserves_parent_nulls() {
+ use arrow::array::{FixedSizeListArray as FixedSizeListArrayAlias, Int32Array};
+ use arrow::buffer::NullBuffer;
+
+ let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]);
+ let nulls = NullBuffer::from(vec![true, false, true]);
+ let input = FixedSizeListArrayAlias::new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ 3,
+ Arc::new(values),
+ Some(nulls),
+ );
+
+ let output = truncate_fixed_size_list_array(&input, 2)
+ .expect("truncate_fixed_size_list_array failed");
+
+ assert!(!output.is_null(0));
+ assert!(output.is_null(1));
+ assert!(!output.is_null(2));
+ }
+
+ #[test]
+ fn test_truncate_list_view_array() {
+ use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias};
+ use arrow::buffer::ScalarBuffer;
+
+ let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]);
+ let offsets = ScalarBuffer::::from(vec![0_i32, 3, 6]);
+ let sizes = ScalarBuffer::::from(vec![3_i32, 3, 2]);
+ let input = ListViewArrayAlias::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ offsets,
+ sizes,
+ Arc::new(values),
+ None,
+ )
+ .expect("ListViewArray::try_new");
+
+ let output = truncate_list_view_array(&input, 2).expect("truncate_list_view_array failed");
+
+ assert_eq!(output.len(), 3);
+ // After truncation each entry has at most 2 elements.
+ for size in output.value_sizes() {
+ assert!(*size <= 2, "size {size} exceeded max_len");
+ }
+ }
+
+ #[test]
+ fn test_truncate_large_list_view_array() {
+ use arrow::array::{Int32Array, LargeListViewArray as LargeListViewArrayAlias};
+ use arrow::buffer::ScalarBuffer;
+
+ let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]);
+ let offsets = ScalarBuffer::::from(vec![0_i64, 3, 6]);
+ let sizes = ScalarBuffer::::from(vec![3_i64, 3, 2]);
+ let input = LargeListViewArrayAlias::try_new(
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ offsets,
+ sizes,
+ Arc::new(values),
+ None,
+ )
+ .expect("LargeListViewArray::try_new");
+
+ let output = truncate_large_list_view_array(&input, 2)
+ .expect("truncate_large_list_view_array failed");
+
+ assert_eq!(output.len(), 3);
+ for size in output.value_sizes() {
+ assert!(*size <= 2, "size {size} exceeded max_len");
+ }
+ }
+
+ #[test]
+ fn test_get_possible_nested_list_datatype_view_variants() {
+ let inner = DataType::Int32;
+ let cases = vec![
+ DataType::List(Arc::new(Field::new("item", inner.clone(), true))),
+ DataType::LargeList(Arc::new(Field::new("item", inner.clone(), true))),
+ DataType::FixedSizeList(Arc::new(Field::new("item", inner.clone(), true)), 4),
+ DataType::ListView(Arc::new(Field::new("item", inner.clone(), true))),
+ DataType::LargeListView(Arc::new(Field::new("item", inner.clone(), true))),
+ ];
+ for dt in cases {
+ let f = Arc::new(Field::new("col", dt, true));
+ let (_, inner_dt) = get_possible_nested_list_datatype(&f);
+ assert_eq!(inner_dt.as_ref(), Some(&inner), "field {f:?}");
+ }
+ }
+
#[test]
fn test_truncate_fixed_size_list_array() {
let test_cases: Vec<(&str, usize, FixedSizeListArray)> = vec![
@@ -751,4 +1755,196 @@ Cras venenatis euismod malesuada.",
metadata: struct (nullable)
");
}
+
+ /// `max_len = 0` is a valid (if unusual) truncation request: every list
+ /// must become a 0-element list while parent nulls and the element Field
+ /// (including metadata/nullable) are preserved exactly.
+ #[test]
+ fn test_truncate_list_to_zero_elements() {
+ use arrow::array::{Int32Array, ListArray as ListArrayAlias};
+ use arrow::buffer::{NullBuffer, OffsetBuffer};
+
+ let element_field = Arc::new(
+ Field::new("value", DataType::Int32, false)
+ .with_metadata([("audit".to_string(), "zero_len".to_string())].into()),
+ );
+
+ // List with mixed lengths + one parent NULL
+ // lists: [0,1,2], [], [3,4,5], [6] (values 0..6)
+ let list = ListArrayAlias::new(
+ Arc::clone(&element_field),
+ OffsetBuffer::::new(vec![0_i32, 3, 3, 6, 7].into()),
+ Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])),
+ Some(NullBuffer::from(vec![true, false, true, true])),
+ );
+ let out = truncate_list_array(&list, 0).expect("truncate to 0");
+ assert_eq!(out.len(), 4);
+ assert!(out.is_null(1));
+ assert!(!out.is_null(0));
+ // All non-null lists must report length 0
+ for i in [0, 2, 3] {
+ assert!(!out.is_null(i));
+ assert_eq!(out.value(i).len(), 0);
+ }
+ // Element field preserved
+ if let DataType::List(f) = out.data_type() {
+ assert_eq!(f.name(), "value");
+ assert!(!f.is_nullable());
+ assert_eq!(f.metadata().get("audit"), Some(&"zero_len".to_string()));
+ } else {
+ panic!("expected List");
+ }
+ }
+
+ /// Fast-path + `max_len=0` for `ListView` (non-contiguous layout).
+ #[test]
+ fn test_truncate_list_view_to_zero_and_fast_path() {
+ use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias};
+ use arrow::buffer::{NullBuffer, ScalarBuffer};
+
+ let element_field = Arc::new(Field::new("item", DataType::Int32, true));
+ let values = Int32Array::from(vec![10, 20, 30, 40]);
+ // Two lists: [10,20,30] at offset 0 size 3, and a NULL list, and [40]
+ let offsets = ScalarBuffer::::from(vec![0_i32, 0, 3]);
+ let sizes = ScalarBuffer::::from(vec![3_i32, 0, 1]);
+ let nulls = NullBuffer::from(vec![true, false, true]);
+ let input = ListViewArrayAlias::try_new(
+ Arc::clone(&element_field),
+ offsets,
+ sizes,
+ Arc::new(values),
+ Some(nulls),
+ )
+ .expect("ListView construction");
+
+ // Fast path: max_len larger than any size (including the 0-size one)
+ let out_fast = truncate_list_view_array(&input, 100).expect("fast path");
+ assert_eq!(out_fast.len(), 3);
+ assert!(out_fast.is_null(1));
+ assert_eq!(out_fast.value_sizes(), &[3, 0, 1]); // unchanged layout
+
+ // max_len=0 path
+ let out_zero = truncate_list_view_array(&input, 0).expect("zero len view");
+ assert_eq!(out_zero.len(), 3);
+ assert!(out_zero.is_null(1));
+ for i in [0, 2] {
+ assert!(!out_zero.is_null(i));
+ assert_eq!(out_zero.value(i).len(), 0);
+ }
+ }
+
+ /// Using the public `truncate_numeric_column_length` API with a `RecordBatch`
+ /// that contains a mix of list columns that do and do not need truncation.
+ /// When the fast path triggers for some columns, the original column Arc
+ /// should be reused (identity) for those that didn't change.
+ #[test]
+ fn test_truncate_numeric_column_length_fast_path_and_mixed() {
+ use arrow::array::{Int32Array, ListArray as ListArrayAlias};
+ use arrow::buffer::OffsetBuffer;
+ use arrow::datatypes::Field as ArrowField;
+
+ use crate::record_batch::truncate_numeric_column_length;
+
+ let schema = Arc::new(Schema::new(vec![
+ ArrowField::new(
+ "short_lists",
+ DataType::List(Arc::new(Field::new("v", DataType::Int32, true))),
+ true,
+ ),
+ ArrowField::new(
+ "long_lists",
+ DataType::List(Arc::new(Field::new("v", DataType::Int32, true))),
+ true,
+ ),
+ ]));
+
+ let short = ListArrayAlias::new(
+ Arc::new(Field::new("v", DataType::Int32, true)),
+ OffsetBuffer::::new(vec![0_i32, 1, 2].into()),
+ Arc::new(Int32Array::from(vec![1, 2])),
+ None,
+ );
+ let long = ListArrayAlias::new(
+ Arc::new(Field::new("v", DataType::Int32, true)),
+ OffsetBuffer::::new(vec![0_i32, 5, 10].into()),
+ Arc::new(Int32Array::from((0..10).collect::>())),
+ None,
+ );
+
+ let batch =
+ RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(short), Arc::new(long)])
+ .expect("batch");
+
+ // max_elements=3 : short_lists hits fast path (no change), long_lists truncated
+ let processed =
+ truncate_numeric_column_length(&batch, 3).expect("truncate_numeric_column_length");
+
+ assert_eq!(processed.num_columns(), 2);
+ // First column should be the exact same Arc (fast path clone in practice
+ // returns the input, but RecordBatch construction may not dedup Arcs;
+ // we at least verify logical identity of data).
+ let short_in = batch.column(0);
+ let short_out = processed.column(0);
+ assert_eq!(short_in.data_type(), short_out.data_type());
+ assert_eq!(short_in.len(), short_out.len());
+ // Second column must have been truncated (lengths now <=3)
+ let long_out = processed
+ .column(1)
+ .as_any()
+ .downcast_ref::()
+ .expect("ListArray");
+ for i in 0..long_out.len() {
+ assert!(long_out.value(i).len() <= 3);
+ }
+ }
+
+ /// `max_len = 0` must work correctly for `ListView` (non-contiguous layout).
+ /// This exercises the `ListView` truncation path with the most aggressive
+ /// truncation limit, ensuring parent nulls and element Field are preserved
+ /// even when every list is reduced to zero elements.
+ #[test]
+ fn test_truncate_list_view_to_zero_elements() {
+ use arrow::array::{Int32Array, ListViewArray as ListViewArrayAlias};
+ use arrow::buffer::{NullBuffer, ScalarBuffer};
+
+ let element_field = Arc::new(
+ Field::new("value", DataType::Int32, false)
+ .with_metadata([("audit".to_string(), "zero_len_view".to_string())].into()),
+ );
+
+ // ListView with mixed lengths + one parent NULL (3 lists)
+ let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
+ let offsets = ScalarBuffer::::from(vec![0_i32, 3, 3]); // 3 lists
+ let sizes = ScalarBuffer::::from(vec![3_i32, 0, 3]);
+ let nulls = NullBuffer::from(vec![true, false, true]); // 3 entries
+ let input = ListViewArrayAlias::try_new(
+ Arc::clone(&element_field),
+ offsets,
+ sizes,
+ Arc::new(values),
+ Some(nulls),
+ )
+ .expect("ListView construction");
+
+ let out = truncate_list_view_array(&input, 0).expect("truncate ListView to 0");
+
+ assert_eq!(out.len(), 3);
+ assert!(out.is_null(1));
+ assert!(!out.is_null(0));
+ assert!(!out.is_null(2));
+ assert_eq!(out.value(0).len(), 0);
+ assert_eq!(out.value(2).len(), 0);
+
+ // Element field preserved
+ if let DataType::ListView(f) = out.data_type() {
+ assert_eq!(f.name(), "value");
+ assert!(!f.is_nullable());
+ assert_eq!(
+ f.metadata().get("audit"),
+ Some(&"zero_len_view".to_string())
+ );
+ } else {
+ panic!("expected ListView");
+ }
+ }
}
diff --git a/crates/cayenne/benches/mutation_writer.rs b/crates/cayenne/benches/mutation_writer.rs
index 8c84e77dfe..5b12a3d8e2 100644
--- a/crates/cayenne/benches/mutation_writer.rs
+++ b/crates/cayenne/benches/mutation_writer.rs
@@ -261,5 +261,82 @@ fn bench_inline_mutation_paths(c: &mut Criterion) {
pressure.finish();
}
-criterion_group!(benches, bench_append_roundtrip, bench_inline_mutation_paths);
+/// Benchmarks the directory durability primitives added for ACID correctness
+/// on local FS (parent-directory `sync_all` after `create_dir_all` for
+/// snapshot directories, _partitioned_wal/, and deletions/ subdirs).
+///
+/// These one-time-per-snapshot or per-table costs are the direct result of
+/// the durability hardening. The benchmark quantifies the "tax" for Q21
+/// workloads that trigger frequent compactions or cross-partition operations.
+fn bench_directory_durability_primitives(c: &mut Criterion) {
+ let rt = Runtime::new().expect("runtime");
+
+ let mut group = c.benchmark_group("directory_durability_sync_all");
+ // These are one-time operations; a smaller sample size is sufficient
+ // to get stable numbers without making the bench too slow.
+ group.sample_size(30);
+
+ group.bench_function("create_dir_all_plus_parent_sync", |b| {
+ b.iter_batched(
+ || {
+ let temp = tempfile::tempdir().expect("tempdir for bench");
+ let parent = temp.path().to_path_buf();
+ let child = parent.join("new_snapshot_or_wal_or_deletions_dir");
+ (temp, parent, child)
+ },
+ |(_keep_alive, parent, child)| {
+ rt.block_on(async {
+ // Replicate the exact hardened pattern used in
+ // ensure_snapshot_dir_exists, ensure_partitioned_wal_dir_and_sync_parent,
+ // and the deletions/ subdir creation in DeletionVectorWriter.
+ if !child.exists() {
+ tokio::fs::create_dir_all(&child)
+ .await
+ .expect("create_dir_all");
+ let p = parent.clone();
+ let _ = tokio::task::spawn_blocking(move || {
+ std::fs::File::open(&p).and_then(|f| f.sync_all())
+ })
+ .await;
+ }
+ black_box(child);
+ });
+ },
+ BatchSize::SmallInput,
+ );
+ });
+
+ // For comparison: the cost of create_dir_all *without* the parent sync.
+ // This shows the incremental cost of the durability guarantee.
+ group.bench_function("create_dir_all_without_sync", |b| {
+ b.iter_batched(
+ || {
+ let temp = tempfile::tempdir().expect("tempdir for bench");
+ let parent = temp.path().to_path_buf();
+ let child = parent.join("new_snapshot_without_sync");
+ (temp, parent, child)
+ },
+ |(_keep_alive, _parent, child)| {
+ rt.block_on(async {
+ if !child.exists() {
+ tokio::fs::create_dir_all(&child)
+ .await
+ .expect("create_dir_all");
+ }
+ black_box(child);
+ });
+ },
+ BatchSize::SmallInput,
+ );
+ });
+
+ group.finish();
+}
+
+criterion_group!(
+ benches,
+ bench_append_roundtrip,
+ bench_inline_mutation_paths,
+ bench_directory_durability_primitives
+);
criterion_main!(benches);
diff --git a/crates/cayenne/src/cayenne_catalog.rs b/crates/cayenne/src/cayenne_catalog.rs
index b6a5d9088d..3f24ecb235 100644
--- a/crates/cayenne/src/cayenne_catalog.rs
+++ b/crates/cayenne/src/cayenne_catalog.rs
@@ -379,6 +379,39 @@ impl MetadataCatalog for CayenneCatalog {
if !db_dir.exists() {
tokio::fs::create_dir_all(db_dir).await?;
+
+ // Best-effort sync of the parent directory so the db_dir entry
+ // itself is durable on local FS before we proceed to create the
+ // catalog DB file and initialize its schema.
+ //
+ // We keep this best-effort (with warning on failure) rather than
+ // fatal because:
+ // - Catalog DB directory creation is a one-time initialization
+ // event (not a hot write path).
+ // - It is immediately followed by DB file creation and schema
+ // initialization, which provide strong content durability.
+ // - The parent directory is frequently a stable, operator-
+ // managed volume root (e.g., K8s PersistentVolume) where
+ // directory entry durability is already handled at a higher
+ // level.
+ //
+ // This is still the right thing to do for consistency with the
+ // uniform durability contract used for all per-table mutable
+ // data paths, and it gives operators a clear warning if
+ // something unusual happens on a fresh deployment.
+ if let Some(parent) = db_dir.parent() {
+ let parent = parent.to_path_buf();
+ if let Err(e) = tokio::task::spawn_blocking(move || {
+ std::fs::File::open(&parent).and_then(|f| f.sync_all())
+ })
+ .await
+ {
+ tracing::warn!(
+ "Failed to sync parent of catalog DB directory {} (subsequent DB writes will still be durable; directory entry may not survive crash): {e}",
+ db_dir.display()
+ );
+ }
+ }
}
// Initialize schema using the appropriate metastore backend
@@ -476,6 +509,35 @@ impl MetadataCatalog for CayenneCatalog {
// Generate initial snapshot UUID
let initial_snapshot_id = uuid::Uuid::now_v7().to_string();
+ // Create the initial snapshot directory *before* inserting the table
+ // row into the metastore. This ensures the directory entry is durable
+ // (with parent sync of the table root) before the catalog "commits"
+ // the existence of a table pointing at this snapshot_id. This is the
+ // final piece of the uniform local-FS durability contract (snapshot
+ // dirs, _partitioned_wal/, deletions/, and now initial table creation).
+ // Matches the contract we enforce everywhere else in the write path.
+ if !base_path.starts_with("s3://") {
+ let table_root = std::path::PathBuf::from(&base_path).join(&table_id);
+ let snapshot_dir = table_root.join(&initial_snapshot_id);
+
+ if !snapshot_dir.exists() {
+ tokio::fs::create_dir_all(&snapshot_dir)
+ .await
+ .map_err(|e| CatalogError::Io { source: e })?;
+
+ // Sync the table root (parent of the new snapshot dir) so the
+ // subdir entry is durable on local FS. Best-effort on the sync
+ // itself (creation failure is already fatal above); this is
+ // the same pattern used for the first _partitioned_wal/ and
+ // first deletions/ subdirs.
+ let table_root_for_sync = table_root.clone();
+ let _ = tokio::task::spawn_blocking(move || {
+ let _ = std::fs::File::open(&table_root_for_sync).and_then(|f| f.sync_all());
+ })
+ .await;
+ }
+ }
+
// Serialize Vortex config to JSON
let vortex_config_json = serde_json::to_string(&options.vortex_config).map_err(|e| {
CatalogError::InvalidOperation {
@@ -534,18 +596,9 @@ impl MetadataCatalog for CayenneCatalog {
Err(e) => return Err(e),
}
- // Create the initial snapshot directory (only for local paths)
- // Directory structure: [base_path]/[table_id]/[snapshot_id]/
- // For S3 paths, directories are virtual and created when files are written
- if !base_path.starts_with("s3://") {
- let snapshot_dir = std::path::PathBuf::from(&base_path)
- .join(&table_id)
- .join(&initial_snapshot_id);
-
- tokio::fs::create_dir_all(&snapshot_dir)
- .await
- .map_err(|e| CatalogError::Io { source: e })?;
- }
+ // The initial snapshot directory was already created (with parent
+ // sync) before the metastore INSERT, so the catalog row now points
+ // at a durable directory. Nothing more to do here for local FS.
Ok(table_id)
}
@@ -1966,13 +2019,39 @@ async fn ensure_snapshot_directory_exists(table: &TableMetadata) -> CatalogResul
return Ok(());
}
- let snapshot_dir = std::path::PathBuf::from(&table.path)
- .join(&table.table_id)
- .join(&table.current_snapshot_id);
+ let table_root = std::path::PathBuf::from(&table.path).join(&table.table_id);
+ let snapshot_dir = table_root.join(&table.current_snapshot_id);
+
+ match tokio::fs::metadata(&snapshot_dir).await {
+ Ok(metadata) if metadata.is_dir() => return Ok(()),
+ Ok(_) => {
+ return Err(CatalogError::Io {
+ source: std::io::Error::new(
+ std::io::ErrorKind::AlreadyExists,
+ format!(
+ "snapshot path '{}' exists but is not a directory",
+ snapshot_dir.display()
+ ),
+ ),
+ });
+ }
+ Err(source) if source.kind() == std::io::ErrorKind::NotFound => {}
+ Err(source) => return Err(CatalogError::Io { source }),
+ }
tokio::fs::create_dir_all(&snapshot_dir)
.await
- .map_err(|e| CatalogError::Io { source: e })
+ .map_err(|source| CatalogError::Io { source })?;
+
+ // Sync parent (table root) for the same durability reason as the
+ // initial creation path above and all other new subdir creations.
+ let table_root_for_sync = table_root;
+ let _ = tokio::task::spawn_blocking(move || {
+ let _ = std::fs::File::open(&table_root_for_sync).and_then(|f| f.sync_all());
+ })
+ .await;
+
+ Ok(())
}
/// Checks if the existing stored configuration matches the new [`CreateTableOptions`].
diff --git a/crates/cayenne/src/lib.rs b/crates/cayenne/src/lib.rs
index 5280c70708..bad547db01 100644
--- a/crates/cayenne/src/lib.rs
+++ b/crates/cayenne/src/lib.rs
@@ -60,6 +60,7 @@ pub mod cayenne_catalog;
pub mod ddl;
#[cfg(feature = "partition-table-provider")]
pub use ddl::CayenneDdlHandler;
+pub mod logical_optimizer;
pub mod metadata;
pub mod metastore;
pub mod optimizer_rules;
diff --git a/crates/cayenne/src/logical_optimizer.rs b/crates/cayenne/src/logical_optimizer.rs
new file mode 100644
index 0000000000..c8d522dd8d
--- /dev/null
+++ b/crates/cayenne/src/logical_optimizer.rs
@@ -0,0 +1,1899 @@
+/*
+Copyright 2026 The Spice.ai OSS Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+//! Logical optimizer rules for Cayenne.
+//!
+//! The flagship rule here is [`CayennePropagateFilterAcrossEquiJoinKeys`], the
+//! plan-time predicate transitive closure used to unblock chbench q21 (see
+//! `crates/cayenne/src/optimizer_rules.rs` module docs for the broader
+//! no-spill strategy this fits into).
+//!
+//! `DataFusion`'s stock `infer_join_predicates` (in `push_down_filter`) already
+//! propagates predicates that *directly* reference a join-key column:
+//! `WHERE nation.n_nationkey = 5 AND nation.n_nationkey = supplier.s_nationkey`
+//! is transformed into `WHERE supplier.s_nationkey = 5 AND ...`. That covers
+//! the `n_nationkey = $const` shape but misses the q21 shape, where the
+//! selective filter is on a *non-key* column (`n_name = 'CHINA'`). The
+//! cardinality bound the dim-table filter implies for the equi-joined key
+//! column never reaches the fact-table scans, so by the time the planner
+//! orders joins from the SQL `FROM` clause, `(supplier, order_line, …)`
+//! has already been chosen with no nation filter pushed through.
+//!
+//! ## What the rule does
+//!
+//! For every `LogicalPlan::Join` with `JoinType::Inner`, `JoinType::LeftSemi`,
+//! `JoinType::RightSemi`, `JoinType::Left`, or `JoinType::Right`, default SQL
+//! NULL equality (`NULL != NULL`), and one or more equi-key pairs whose data
+//! types match, the rule inspects each side for a non-trivial `Filter` that
+//! references at least one column other than each candidate join key. If one
+//! side is dim-like and has a projectable column key, it wraps the *opposite*
+//! side with
+//!
+//! ```text
+//! Filter(other_side.key IN (SELECT this_side.key FROM this_side_subtree))
+//! ```
+//!
+//! The inserted subquery re-projects the join key through whatever filters
+//! already exist on the original side, so `DataFusion`'s
+//! `decorrelate_predicate_subquery` and `push_down_filter` can then plant a
+//! `LeftSemi` join (or, after pushdown, a partition-pruning predicate) on
+//! the fact-table scan. For q21 this turns
+//! `nation ⋈ supplier ⋈ order_line` into a shape where `supplier.s_nationkey
+//! IN (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA')` is visible
+//! while the join graph is being costed.
+//!
+//! Semi-join coverage is what makes chained propagation work: after
+//! `decorrelate_predicate_subquery` rewrites a propagated `InSubquery` into a
+//! `LeftSemi` join, the next optimizer pass can keep propagating across
+//! adjacent inner joins (e.g. `region → nation → supplier → fact`) instead of
+//! halting at the semi-join boundary. Propagation correctness on
+//! `LeftSemi`/`RightSemi` follows from the join's existing key-domain
+//! semantics: wrapping either input with `IN (SELECT key FROM other_side)`
+//! produces a subset of rows that the semi-join would already retain.
+//!
+//! For outer joins (`Left`, `Right`) the rule fires *only* in the
+//! preserved-side → lookup-side direction. Filtering the lookup side narrows
+//! matches the outer join would already drop (and substitute `NULL` for);
+//! filtering the preserved side would silently delete rows the outer join is
+//! supposed to emit as `NULL`-padded, which would change the output.
+//! `FullOuter` is excluded — both sides are preserved, so neither direction is
+//! safe.
+//!
+//! ## Termination
+//!
+//! Each introduced subquery is wrapped in a `SubqueryAlias` whose name
+//! starts with [`PROPAGATED_FILTER_ALIAS_PREFIX`]. Before firing, the rule
+//! walks the candidate side's filter chain and refuses to re-introduce a
+//! propagated filter for the same target key. This prevents the rule from
+//! oscillating with itself when the optimizer iterates to fixed point, while
+//! still allowing composite joins to receive one derived filter per key.
+//!
+//! ## Conservatism
+//!
+//! The rule only fires when the side providing the filter is dim-like: a small
+//! subtree with at most [`MAX_DIM_LIKE_TABLE_SCANS`] table scans behind
+//! identity-preserving operators and inner joins. Joining a non-trivial subtree
+//! would risk duplicate-executing a large plan inside the subquery, since
+//! `DataFusion` does not currently de-duplicate plan-level common subexpressions
+//! across the outer plan and an `InSubquery`. The dim-table-filter shape
+//! (`Filter(n_name='CHINA') → TableScan(nation)`) and small dimension snowflakes
+//! are cheap to re-execute.
+//!
+//! Two cardinality gates further suppress propagations that wouldn't pay off
+//! at runtime, when the underlying [`TableSource`]s expose row counts via
+//! `TableProvider::statistics`:
+//!
+//! * [`MIN_DIM_ROWS_FOR_PROPAGATION`] — skip when the dim subtree's known
+//! upper-bound row count is below the threshold. Very small dims (≪ 1k
+//! rows) already participate in fast hash builds; the extra `InSubquery →
+//! LeftSemi` shape we'd introduce doesn't recover its own decorrelation /
+//! planning cost.
+//! * [`MIN_FACT_ROWS_FOR_PROPAGATION`] — skip when the receiving fact
+//! subtree's known upper-bound row count is below the threshold. Below it
+//! there isn't enough probe-side cardinality for the filter to save
+//! meaningful work, and the plain hash join wins.
+//!
+//! Both gates only fire when stats are present (`Precision::Exact` or
+//! `Precision::Inexact`); missing stats fall back to the structural behavior.
+
+use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion::common::{Column, DataFusionError, NullEquality, Result, Spans, TableReference};
+use datafusion::logical_expr::{
+ Filter, Join, JoinType, LogicalPlan, Projection, Subquery, SubqueryAlias,
+};
+use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion_expr::Expr;
+use datafusion_expr::ExprSchemable;
+use datafusion_expr::expr::InSubquery;
+use std::{collections::BTreeSet, sync::Arc};
+
+/// Prefix for [`SubqueryAlias`] names introduced by
+/// [`CayennePropagateFilterAcrossEquiJoinKeys`].
+///
+/// Used both as a sentinel for key-scoped cycle detection (the rule refuses to
+/// add another propagated filter for a target key that already has one) and as
+/// a marker in explain output so the rewrite is recognizable when reading plans.
+pub const PROPAGATED_FILTER_ALIAS_PREFIX: &str = "__cayenne_xclos__";
+
+/// Logical optimizer rule that, for each `Inner`, `LeftSemi`, or `RightSemi`
+/// join with default SQL NULL equality and a simple equi-key
+/// `(left.a = right.b)`, introduces
+/// `Filter(other_side.key IN (SELECT this_side.key FROM this_side_subtree))`
+/// on the side opposite a non-key filter.
+///
+/// See the module-level docs for the full design and the q21 motivation.
+#[derive(Default)]
+pub struct CayennePropagateFilterAcrossEquiJoinKeys;
+
+impl CayennePropagateFilterAcrossEquiJoinKeys {
+ /// Create a new instance of the rule.
+ #[must_use]
+ pub fn new() -> Self {
+ Self
+ }
+}
+
+impl std::fmt::Debug for CayennePropagateFilterAcrossEquiJoinKeys {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("CayennePropagateFilterAcrossEquiJoinKeys")
+ .finish()
+ }
+}
+
+impl OptimizerRule for CayennePropagateFilterAcrossEquiJoinKeys {
+ fn name(&self) -> &'static str {
+ "cayenne_propagate_filter_across_equi_join_keys"
+ }
+
+ fn apply_order(&self) -> Option {
+ // TopDown: process outer joins first so the propagation seeds reach
+ // inner joins on the next pass.
+ Some(ApplyOrder::TopDown)
+ }
+
+ fn rewrite(
+ &self,
+ plan: LogicalPlan,
+ config: &dyn OptimizerConfig,
+ ) -> Result, DataFusionError> {
+ let join = match plan {
+ LogicalPlan::Join(j) => j,
+ other => return Ok(Transformed::no(other)),
+ };
+ if !matches!(
+ join.join_type,
+ JoinType::Inner
+ | JoinType::LeftSemi
+ | JoinType::RightSemi
+ | JoinType::Left
+ | JoinType::Right,
+ ) {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
+ if join.null_equality != NullEquality::NullEqualsNothing {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
+ if matches!(join.join_type, JoinType::LeftSemi)
+ && right_side_carries_propagation_marker(&join.right)
+ {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
+ // For outer joins, propagation is only safe in the *preserved-side →
+ // lookup-side* direction. Filtering the lookup side can only narrow
+ // matches that the join would already drop; filtering the preserved
+ // side would drop output rows that the outer join would have emitted
+ // as `NULL`-padded. Inner and semi joins are unrestricted.
+ let allow_left_to_right = matches!(
+ join.join_type,
+ JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Left,
+ );
+ let allow_right_to_left = matches!(
+ join.join_type,
+ JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right,
+ );
+
+ let equijoin_keys = matching_equijoin_keys(&join);
+ if equijoin_keys.is_empty() {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
+
+ let mut left_analysis = analyze_logical_side(&join.left);
+ let mut right_analysis = analyze_logical_side(&join.right);
+
+ let mut new_left: Arc = Arc::clone(&join.left);
+ let mut new_right: Arc = Arc::clone(&join.right);
+ let mut changed = false;
+
+ for key in &equijoin_keys {
+ match key {
+ EquiKey::BothColumns { left, right } => {
+ // Propagate the LEFT-side filtered key domain → the RIGHT side.
+ if allow_left_to_right
+ && left_analysis.is_dim_like
+ && left_analysis.has_non_key_filter(&left.name)
+ && key_preserved_through_summaries(&join.left, left)
+ && !skip_propagation_by_cardinality(&join.left, &join.right)
+ && !right_analysis.has_propagated_filter_target(&column_expr(right))
+ {
+ let subquery_plan = build_key_projection_subquery(
+ Arc::clone(&join.left),
+ left,
+ config.alias_generator(),
+ )?;
+ let target = column_expr(right);
+ let wrapped = wrap_with_in_subquery_filter_expr(
+ Arc::clone(&new_right),
+ &target,
+ subquery_plan,
+ )?;
+ new_right = Arc::new(wrapped);
+ right_analysis.add_propagated_filter_target(&target);
+ changed = true;
+ }
+
+ // Propagate the RIGHT-side filtered key domain → the LEFT side.
+ if allow_right_to_left
+ && right_analysis.is_dim_like
+ && right_analysis.has_non_key_filter(&right.name)
+ && key_preserved_through_summaries(&join.right, right)
+ && !skip_propagation_by_cardinality(&join.right, &join.left)
+ && !left_analysis.has_propagated_filter_target(&column_expr(left))
+ {
+ let subquery_plan = build_key_projection_subquery(
+ Arc::clone(&join.right),
+ right,
+ config.alias_generator(),
+ )?;
+ let target = column_expr(left);
+ let wrapped = wrap_with_in_subquery_filter_expr(
+ Arc::clone(&new_left),
+ &target,
+ subquery_plan,
+ )?;
+ new_left = Arc::new(wrapped);
+ left_analysis.add_propagated_filter_target(&target);
+ changed = true;
+ }
+ }
+ EquiKey::LeftColumnRightExpr {
+ left_col,
+ right_expr,
+ } => {
+ // Only LEFT-dim → RIGHT-expr direction can fire: the right
+ // side has an expression key, so the fact-side filter
+ // target must be that expression. Propagation in the other
+ // direction would require projecting an expression
+ // (potentially referencing fact-side rows) inside the dim
+ // subquery, which would no longer be a cheap re-execution.
+ if allow_left_to_right
+ && left_analysis.is_dim_like
+ && left_analysis.has_non_key_filter(&left_col.name)
+ && key_preserved_through_summaries(&join.left, left_col)
+ && !skip_propagation_by_cardinality(&join.left, &join.right)
+ && !right_analysis.has_propagated_filter_target(right_expr)
+ {
+ let subquery_plan = build_key_projection_subquery(
+ Arc::clone(&join.left),
+ left_col,
+ config.alias_generator(),
+ )?;
+ let wrapped = wrap_with_in_subquery_filter_expr(
+ Arc::clone(&new_right),
+ right_expr,
+ subquery_plan,
+ )?;
+ new_right = Arc::new(wrapped);
+ right_analysis.add_propagated_filter_target(right_expr);
+ changed = true;
+ }
+ }
+ EquiKey::LeftExprRightColumn {
+ left_expr,
+ right_col,
+ } => {
+ // Symmetric: only RIGHT-dim → LEFT-expr direction.
+ if allow_right_to_left
+ && right_analysis.is_dim_like
+ && right_analysis.has_non_key_filter(&right_col.name)
+ && key_preserved_through_summaries(&join.right, right_col)
+ && !skip_propagation_by_cardinality(&join.right, &join.left)
+ && !left_analysis.has_propagated_filter_target(left_expr)
+ {
+ let subquery_plan = build_key_projection_subquery(
+ Arc::clone(&join.right),
+ right_col,
+ config.alias_generator(),
+ )?;
+ let wrapped = wrap_with_in_subquery_filter_expr(
+ Arc::clone(&new_left),
+ left_expr,
+ subquery_plan,
+ )?;
+ new_left = Arc::new(wrapped);
+ left_analysis.add_propagated_filter_target(left_expr);
+ changed = true;
+ }
+ }
+ }
+ }
+
+ if !changed {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
+
+ let new_join = Join::try_new(
+ new_left,
+ new_right,
+ join.on,
+ join.filter,
+ join.join_type,
+ join.join_constraint,
+ join.null_equality,
+ )?;
+
+ Ok(Transformed::yes(LogicalPlan::Join(new_join)))
+ }
+}
+
+#[derive(Default)]
+struct SideAnalysis {
+ is_dim_like: bool,
+ filter_columns: BTreeSet,
+ /// Targets of already-propagated `InSubquery` filters on this side, keyed
+ /// by the `Display` form of the target expression. Used for cycle
+ /// prevention — the same target should not be wrapped twice. Tracks both
+ /// pure-column and expression targets uniformly, so the chbench
+ /// `ascii(substr(c_state,1,1)) - 65` shape is also cycle-guarded.
+ propagated_filter_targets: BTreeSet,
+}
+
+impl SideAnalysis {
+ fn has_non_key_filter(&self, key_name: &str) -> bool {
+ self.filter_columns.iter().any(|column| column != key_name)
+ }
+
+ fn has_propagated_filter_target(&self, target: &Expr) -> bool {
+ self.propagated_filter_targets.contains(&target.to_string())
+ }
+
+ fn add_propagated_filter_target(&mut self, target: &Expr) {
+ self.propagated_filter_targets.insert(target.to_string());
+ }
+}
+
+fn column_expr(column: &Column) -> Expr {
+ Expr::Column(column.clone())
+}
+
+fn analyze_logical_side(plan: &LogicalPlan) -> SideAnalysis {
+ let mut analysis = SideAnalysis {
+ is_dim_like: is_dim_like_subtree(plan),
+ ..SideAnalysis::default()
+ };
+
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::Filter(filter) = node {
+ collect_filter_column_names(&filter.predicate, &mut analysis.filter_columns);
+ collect_propagated_filter_targets(
+ &filter.predicate,
+ &mut analysis.propagated_filter_targets,
+ );
+ }
+ // Post-decorrelation cycle detection: `decorrelate_predicate_subquery`
+ // rewrites our propagated `InSubquery` into a `LeftSemi` join with the
+ // marker `SubqueryAlias` as its right child. Without this branch the
+ // rule's cycle guard misses the marker (it only walked Filter
+ // predicates), and the optimizer would re-propagate on every iteration
+ // until hitting `max_passes`, stacking redundant `LeftSemi` layers.
+ if let LogicalPlan::Join(join) = node
+ && matches!(join.join_type, JoinType::LeftSemi)
+ && right_side_carries_propagation_marker(&join.right)
+ {
+ for (left_expr, _) in &join.on {
+ analysis
+ .propagated_filter_targets
+ .insert(left_expr.to_string());
+ }
+ }
+
+ Ok(TreeNodeRecursion::Continue)
+ });
+
+ analysis
+}
+
+/// Returns `true` if `plan` is — possibly behind a chain of `Projection` or
+/// `SubqueryAlias` wrappers added by later optimizer rules — a `SubqueryAlias`
+/// whose name starts with [`PROPAGATED_FILTER_ALIAS_PREFIX`].
+fn right_side_carries_propagation_marker(plan: &LogicalPlan) -> bool {
+ match plan {
+ LogicalPlan::SubqueryAlias(alias) => {
+ if alias
+ .alias
+ .table()
+ .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX)
+ {
+ return true;
+ }
+ right_side_carries_propagation_marker(&alias.input)
+ }
+ LogicalPlan::Projection(p) => right_side_carries_propagation_marker(&p.input),
+ _ => false,
+ }
+}
+
+/// An equi-join key from `Join::on`, classified by which sides are pure
+/// columns. Propagation requires the *dim* side to be a `Column` so the IN
+/// subquery has a cheap, projectable key; the *fact* side may be an arbitrary
+/// expression (e.g. the chbench `ascii(substr(c_state,1,1)) - 65` pattern).
+enum EquiKey {
+ /// Both join keys are columns. The rule may fire in either direction
+ /// depending on which side is dim-like.
+ BothColumns { left: Column, right: Column },
+ /// Left key is a column, right key is an expression. Only the
+ /// `LEFT → RIGHT` propagation direction is supported.
+ LeftColumnRightExpr { left_col: Column, right_expr: Expr },
+ /// Right key is a column, left key is an expression. Only the
+ /// `RIGHT → LEFT` propagation direction is supported.
+ LeftExprRightColumn { left_expr: Expr, right_col: Column },
+}
+
+/// Return the equi-join keys from `join.on` whose data types match. Drops
+/// pairs where both sides are expressions (no dim-like column to project) and
+/// pairs whose types differ (the `IN` subquery would need an implicit cast we
+/// don't insert here).
+fn matching_equijoin_keys(join: &Join) -> Vec {
+ join.on
+ .iter()
+ .filter_map(|(left, right)| {
+ if !join_key_types_match(left, right, &join.left, &join.right) {
+ return None;
+ }
+
+ match (left, right) {
+ (Expr::Column(l), Expr::Column(r)) => Some(EquiKey::BothColumns {
+ left: l.clone(),
+ right: r.clone(),
+ }),
+ (Expr::Column(l), other) => Some(EquiKey::LeftColumnRightExpr {
+ left_col: l.clone(),
+ right_expr: other.clone(),
+ }),
+ (other, Expr::Column(r)) => Some(EquiKey::LeftExprRightColumn {
+ left_expr: other.clone(),
+ right_col: r.clone(),
+ }),
+ // Both sides are non-trivial expressions — no cheap projection
+ // target on either side, skip.
+ _ => None,
+ }
+ })
+ .collect()
+}
+
+fn join_key_types_match(
+ left: &Expr,
+ right: &Expr,
+ left_plan: &LogicalPlan,
+ right_plan: &LogicalPlan,
+) -> bool {
+ let Ok(left_type) = left.get_type(left_plan.schema()) else {
+ return false;
+ };
+ let Ok(right_type) = right.get_type(right_plan.schema()) else {
+ return false;
+ };
+
+ left_type == right_type
+}
+
+/// Maximum number of `TableScan` leaves allowed inside a dim-like subtree.
+///
+/// Chosen to cover the canonical chbench / TPC-H dimension snowflake
+/// (`region ⋈ nation ⋈ supplier`, three leaves) without admitting arbitrarily
+/// large dim joins whose re-execution under an `InSubquery` would be expensive.
+const MAX_DIM_LIKE_TABLE_SCANS: usize = 3;
+
+/// Skip propagation when the dim subtree's known upper-bound row count is
+/// below this threshold. Below it the dim is already small enough that the
+/// stock hash build is fast, and the `InSubquery → LeftSemi` decorrelation +
+/// planning cost outweighs the saved probe work.
+const MIN_DIM_ROWS_FOR_PROPAGATION: usize = 1_000;
+
+/// Skip propagation when the receiving fact subtree's known upper-bound row
+/// count is below this threshold. Below it there isn't enough probe
+/// cardinality for the filter to recoup the propagation overhead.
+const MIN_FACT_ROWS_FOR_PROPAGATION: usize = 100_000;
+
+/// Returns `true` if `plan` is a "dim-like" subtree — a small snowflake of
+/// dimensions composed of at most [`MAX_DIM_LIKE_TABLE_SCANS`] `TableScan`s
+/// connected through identity-preserving operators (`Projection`,
+/// `SubqueryAlias`, `Filter`, `Limit`), inner equi-joins with default SQL
+/// NULL equality, `Aggregate`, or `Distinct`.
+///
+/// The conservatism here keeps the duplicated subquery cheap: `DataFusion` will
+/// execute it independently of the outer join, so we only fire on subtrees
+/// where re-running the scan(s) + filter(s) is cheap. Unions, windows, sorts,
+/// and any non-inner / null-equal join terminate the walk.
+///
+/// `Aggregate` and `Distinct` are *structurally* allowed here, but the rule's
+/// caller must additionally verify the join key is preserved through any
+/// aggregations via [`key_preserved_through_summaries`] — an aggregate that
+/// does not group by the key does not preserve its domain and cannot be the
+/// source of a propagated subquery on that key.
+fn is_dim_like_subtree(plan: &LogicalPlan) -> bool {
+ count_dim_like_table_scans(plan).is_some_and(|n| n <= MAX_DIM_LIKE_TABLE_SCANS)
+}
+
+fn count_dim_like_table_scans(plan: &LogicalPlan) -> Option {
+ match plan {
+ LogicalPlan::TableScan(_) => Some(1),
+ LogicalPlan::Projection(p) => count_dim_like_table_scans(&p.input),
+ LogicalPlan::SubqueryAlias(a) => count_dim_like_table_scans(&a.input),
+ LogicalPlan::Filter(f) => count_dim_like_table_scans(&f.input),
+ LogicalPlan::Limit(l) => count_dim_like_table_scans(&l.input),
+ LogicalPlan::Aggregate(a) => count_dim_like_table_scans(&a.input),
+ LogicalPlan::Distinct(d) => count_dim_like_table_scans(distinct_input(d)),
+ LogicalPlan::Join(j)
+ if j.join_type == JoinType::Inner
+ && j.null_equality == NullEquality::NullEqualsNothing =>
+ {
+ let l = count_dim_like_table_scans(&j.left)?;
+ let r = count_dim_like_table_scans(&j.right)?;
+ Some(l + r)
+ }
+ _ => None,
+ }
+}
+
+/// Returns the single input plan of a `Distinct` regardless of variant.
+fn distinct_input(distinct: &datafusion::logical_expr::Distinct) -> &LogicalPlan {
+ use datafusion::logical_expr::Distinct;
+ match distinct {
+ Distinct::All(input) => input,
+ Distinct::On(on) => &on.input,
+ }
+}
+
+/// Sum of known upper-bound row counts of all `TableScan`s reachable from
+/// `plan`. Returns `None` if any reachable `TableScan` is missing stats — the
+/// caller falls back to the structural gates in that case.
+///
+/// The walk follows every `LogicalPlan` child (not just dim-like wrappers) so
+/// fact-side subtrees with joins, aggregates, etc. are summed too. The result
+/// is a loose *upper bound* — filter selectivity isn't accounted for — which
+/// is the right direction for the "skip if known small" gate (a true upper
+/// bound below the threshold guarantees the subtree is actually small).
+fn subtree_upper_bound_rows(plan: &LogicalPlan) -> Option {
+ use datafusion::common::stats::Precision;
+ use datafusion::datasource::DefaultTableSource;
+
+ let mut total: usize = 0;
+ let mut any_unknown = false;
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::TableScan(scan) = node {
+ let rows = scan
+ .source
+ .as_any()
+ .downcast_ref::()
+ .and_then(|default| default.table_provider.statistics())
+ .and_then(|stats| match stats.num_rows {
+ Precision::Exact(n) | Precision::Inexact(n) => Some(n),
+ Precision::Absent => None,
+ });
+ if let Some(n) = rows {
+ total = total.saturating_add(n);
+ } else {
+ any_unknown = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ if any_unknown { None } else { Some(total) }
+}
+
+/// `true` when stats prove the dim side is below
+/// [`MIN_DIM_ROWS_FOR_PROPAGATION`] *or* the fact side is below
+/// [`MIN_FACT_ROWS_FOR_PROPAGATION`]. Missing stats on either side fall back
+/// to the structural gates: this function returns `false` (allow propagation),
+/// matching the rule's behavior before the cardinality gates were added.
+fn skip_propagation_by_cardinality(dim_side: &LogicalPlan, fact_side: &LogicalPlan) -> bool {
+ if matches!(
+ subtree_upper_bound_rows(dim_side),
+ Some(n) if n < MIN_DIM_ROWS_FOR_PROPAGATION
+ ) {
+ return true;
+ }
+ if matches!(
+ subtree_upper_bound_rows(fact_side),
+ Some(n) if n < MIN_FACT_ROWS_FOR_PROPAGATION
+ ) {
+ return true;
+ }
+ false
+}
+
+/// Returns `true` if `key` retains its scan-level domain through every
+/// `Aggregate` / `Distinct` reachable in `plan`.
+///
+/// * `Aggregate` preserves a column's domain only when it appears in
+/// `group_expr` as a plain `Expr::Column` reference.
+/// * `Distinct::All` preserves every projected column (deduplication keeps
+/// value identity).
+/// * `Distinct::On(distinct_on)` preserves only the columns named in its `on`
+/// list; for safety we conservatively require the key to appear there.
+///
+/// The walk follows only identity-preserving operators plus inner equi-joins —
+/// the same shape `is_dim_like_subtree` accepts. Anything outside that vocab
+/// (`Sort`, `Window`, etc.) is conservatively rejected by returning `false`.
+fn key_preserved_through_summaries(plan: &LogicalPlan, key: &Column) -> bool {
+ fn key_for_input_schema(input: &LogicalPlan, key: &Column) -> Option {
+ input
+ .schema()
+ .qualified_field_with_unqualified_name(&key.name)
+ .ok()
+ .map(|(qualifier, field)| Column::new(qualifier.cloned(), field.name().clone()))
+ }
+
+ fn walk(plan: &LogicalPlan, key: &Column) -> bool {
+ match plan {
+ LogicalPlan::TableScan(_) => plan.schema().has_column(key),
+ LogicalPlan::Projection(p) => plan.schema().has_column(key) && walk(&p.input, key),
+ LogicalPlan::SubqueryAlias(a) => {
+ let relation_matches_alias = match key.relation.as_ref() {
+ Some(relation) => relation == &a.alias,
+ None => true,
+ };
+ relation_matches_alias
+ && key_for_input_schema(&a.input, key)
+ .is_some_and(|input_key| walk(&a.input, &input_key))
+ }
+ LogicalPlan::Filter(f) => plan.schema().has_column(key) && walk(&f.input, key),
+ LogicalPlan::Limit(l) => plan.schema().has_column(key) && walk(&l.input, key),
+ LogicalPlan::Aggregate(a) => {
+ let key_in_group = a
+ .group_expr
+ .iter()
+ .any(|expr| matches!(expr, Expr::Column(column) if column == key));
+ key_in_group && plan.schema().has_column(key) && walk(&a.input, key)
+ }
+ LogicalPlan::Distinct(distinct) => {
+ use datafusion::logical_expr::Distinct;
+ let key_kept = match distinct {
+ Distinct::All(_) => true,
+ Distinct::On(on) => on
+ .on_expr
+ .iter()
+ .any(|expr| matches!(expr, Expr::Column(column) if column == key)),
+ };
+ key_kept && plan.schema().has_column(key) && walk(distinct_input(distinct), key)
+ }
+ LogicalPlan::Join(j)
+ if j.join_type == JoinType::Inner
+ && j.null_equality == NullEquality::NullEqualsNothing =>
+ {
+ plan.schema().has_column(key) && (walk(&j.left, key) || walk(&j.right, key))
+ }
+ _ => false,
+ }
+ }
+
+ walk(plan, key)
+}
+
+fn collect_filter_column_names(expr: &Expr, columns: &mut BTreeSet) {
+ let _ = expr.apply(|e| {
+ if let Expr::Column(column) = e {
+ columns.insert(column.name.clone());
+ }
+
+ Ok(TreeNodeRecursion::Continue)
+ });
+}
+
+fn collect_propagated_filter_targets(expr: &Expr, targets: &mut BTreeSet) {
+ let _ = expr.apply(|e| {
+ if let Expr::InSubquery(InSubquery {
+ expr: target_expr,
+ subquery,
+ ..
+ }) = e
+ && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref()
+ && alias
+ .alias
+ .table()
+ .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX)
+ {
+ targets.insert(target_expr.to_string());
+ return Ok(TreeNodeRecursion::Jump);
+ }
+
+ Ok(TreeNodeRecursion::Continue)
+ });
+}
+
+/// Returns `true` if `plan` already contains a propagated-filter marker.
+#[cfg(test)]
+fn subtree_has_propagated_filter(plan: &LogicalPlan) -> bool {
+ let mut found = false;
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::SubqueryAlias(alias) = node
+ && alias
+ .alias
+ .table()
+ .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX)
+ {
+ found = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ if let LogicalPlan::Filter(f) = node
+ && expr_has_propagated_filter(&f.predicate)
+ {
+ found = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ found
+}
+
+/// Returns `true` if `expr` contains an [`InSubquery`] whose inner plan
+/// starts with a [`SubqueryAlias`] named with
+/// [`PROPAGATED_FILTER_ALIAS_PREFIX`].
+#[must_use]
+#[cfg(test)]
+fn expr_has_propagated_filter(expr: &Expr) -> bool {
+ let mut found = false;
+ let _ = expr.apply(|e| {
+ if let Expr::InSubquery(InSubquery { subquery, .. }) = e
+ && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref()
+ && alias
+ .alias
+ .table()
+ .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX)
+ {
+ found = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ found
+}
+
+/// Build a `SubqueryAlias(__cayenne_xclos__N, Projection([key_col], subtree))`
+/// suitable for use as the inner plan of a [`Subquery`] referenced by an
+/// [`InSubquery`] expression.
+///
+/// The alias name uses [`PROPAGATED_FILTER_ALIAS_PREFIX`] plus a unique id
+/// from [`OptimizerConfig::alias_generator`], so each invocation produces a
+/// distinct marker. The marker doubles as the cycle-detection sentinel
+/// scanned by [`analyze_logical_side`].
+fn build_key_projection_subquery(
+ subtree: Arc,
+ key_col: &Column,
+ alias_gen: &Arc,
+) -> Result {
+ let key_expr = Expr::Column(key_col.clone());
+ let projection = LogicalPlan::Projection(Projection::try_new(vec![key_expr], subtree)?);
+ let alias_name = alias_gen.next(PROPAGATED_FILTER_ALIAS_PREFIX);
+ let aliased = SubqueryAlias::try_new(Arc::new(projection), TableReference::bare(alias_name))?;
+ Ok(LogicalPlan::SubqueryAlias(aliased))
+}
+
+/// Wrap `input` with `Filter(target IN (subquery))` using the `subquery_plan`
+/// (which must already be a `SubqueryAlias` named with
+/// [`PROPAGATED_FILTER_ALIAS_PREFIX`]) as the right-hand side. `target` may be
+/// a column or any expression whose columns all resolve in `input`'s schema —
+/// the chbench `ascii(substr(c_state,1,1)) - 65` shape is supported through
+/// this entry point.
+fn wrap_with_in_subquery_filter_expr(
+ input: Arc,
+ target: &Expr,
+ subquery_plan: LogicalPlan,
+) -> Result {
+ let predicate = Expr::InSubquery(InSubquery::new(
+ Box::new(target.clone()),
+ Subquery {
+ subquery: Arc::new(subquery_plan),
+ outer_ref_columns: vec![],
+ spans: Spans::default(),
+ },
+ false,
+ ));
+ let filter = Filter::try_new(predicate, input)?;
+ Ok(LogicalPlan::Filter(filter))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use datafusion::arrow::datatypes::{DataType, Field, Schema};
+ use datafusion::catalog::MemTable;
+ use datafusion::prelude::SessionContext;
+ use std::sync::Arc;
+
+ fn rule() -> CayennePropagateFilterAcrossEquiJoinKeys {
+ CayennePropagateFilterAcrossEquiJoinKeys::new()
+ }
+
+ fn make_ctx() -> Result {
+ let ctx = SessionContext::new();
+ // dim-like nation table — gains an `n_regionkey` so the multi-hop
+ // `region ⋈ nation` propagation tests can join through it.
+ let nation_schema = Arc::new(Schema::new(vec![
+ Field::new("n_nationkey", DataType::Int64, false),
+ Field::new("n_name", DataType::Utf8, true),
+ Field::new("n_regionkey", DataType::Int64, false),
+ ]));
+ // dim-like region table for multi-hop tests.
+ let region_schema = Arc::new(Schema::new(vec![
+ Field::new("r_regionkey", DataType::Int64, false),
+ Field::new("r_name", DataType::Utf8, true),
+ ]));
+ // fact-like supplier table
+ let supplier_schema = Arc::new(Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ ]));
+ // fact-like customer table for expression-equi-key tests
+ // (chbench `ascii(substr(c_state, 1, 1)) - 65` nation mapping).
+ let customer_schema = Arc::new(Schema::new(vec![
+ Field::new("c_id", DataType::Int64, false),
+ Field::new("c_state", DataType::Utf8, true),
+ ]));
+ ctx.register_table(
+ "nation",
+ Arc::new(MemTable::try_new(Arc::clone(&nation_schema), vec![vec![]])?),
+ )?;
+ ctx.register_table(
+ "region",
+ Arc::new(MemTable::try_new(Arc::clone(®ion_schema), vec![vec![]])?),
+ )?;
+ ctx.register_table(
+ "supplier",
+ Arc::new(MemTable::try_new(
+ Arc::clone(&supplier_schema),
+ vec![vec![]],
+ )?),
+ )?;
+ ctx.register_table(
+ "customer",
+ Arc::new(MemTable::try_new(
+ Arc::clone(&customer_schema),
+ vec![vec![]],
+ )?),
+ )?;
+ Ok(ctx)
+ }
+
+ /// Walk a `LogicalPlan` to find the first `Join` and return whichever
+ /// side's plan tree contains a `SubqueryAlias` whose name starts with
+ /// [`PROPAGATED_FILTER_ALIAS_PREFIX`].
+ fn find_propagated_side(plan: &LogicalPlan) -> Option<&'static str> {
+ let mut result: Option<&'static str> = None;
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::Join(j) = node {
+ if subtree_has_propagated_filter(j.left.as_ref()) {
+ result = Some("left");
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ if subtree_has_propagated_filter(j.right.as_ref()) {
+ result = Some("right");
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ result
+ }
+
+ fn count_propagated_filter_exprs(plan: &LogicalPlan) -> usize {
+ let mut count = 0;
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::Filter(f) = node {
+ let _ = f.predicate.apply(|expr| {
+ if let Expr::InSubquery(InSubquery { subquery, .. }) = expr
+ && let LogicalPlan::SubqueryAlias(alias) = subquery.subquery.as_ref()
+ && alias
+ .alias
+ .table()
+ .starts_with(PROPAGATED_FILTER_ALIAS_PREFIX)
+ {
+ count += 1;
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ count
+ }
+
+ #[test]
+ fn rule_metadata() {
+ assert_eq!(
+ rule().name(),
+ "cayenne_propagate_filter_across_equi_join_keys"
+ );
+ assert_eq!(rule().apply_order(), Some(ApplyOrder::TopDown));
+ }
+
+ #[tokio::test]
+ async fn non_inner_join_is_unchanged() -> Result<()> {
+ // Use `IS NULL` on the right side so `eliminate_outer_join` doesn't
+ // promote the LEFT JOIN to an INNER JOIN, otherwise we'd be testing
+ // the wrong thing.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier LEFT JOIN nation \
+ ON s_nationkey = n_nationkey WHERE n_name IS NULL",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ !changed,
+ "LEFT JOIN must be skipped by the rule; plan was:\n{plan}"
+ );
+ Ok(())
+ }
+
+ /// Run the rule against every `LogicalPlan::Join` reachable from `plan`,
+ /// returning the transformed plan and a flag indicating whether at least
+ /// one invocation made a change.
+ ///
+ /// Mirrors what `DataFusion`'s optimizer driver does for an
+ /// `ApplyOrder::TopDown` rule, but without spinning up the rest of the
+ /// rule pipeline — keeps the tests focused on this rule's behavior in
+ /// isolation.
+ fn apply_rule_to_all_joins(
+ rule: &CayennePropagateFilterAcrossEquiJoinKeys,
+ plan: LogicalPlan,
+ cfg: &datafusion::optimizer::OptimizerContext,
+ ) -> Result<(LogicalPlan, bool)> {
+ let mut any_changed = false;
+ let transformed = plan.transform_down(|node| {
+ if matches!(node, LogicalPlan::Join(_)) {
+ let r = rule.rewrite(node, cfg)?;
+ if r.transformed {
+ any_changed = true;
+ }
+ Ok(r)
+ } else {
+ Ok(Transformed::no(node))
+ }
+ })?;
+ Ok((transformed.data, any_changed))
+ }
+
+ #[tokio::test]
+ async fn inner_join_with_dim_filter_propagates_via_subquery() -> Result<()> {
+ // The canonical q21 shape (reduced):
+ // FROM supplier, nation
+ // WHERE s_nationkey = n_nationkey AND n_name = 'CHINA'
+ //
+ // After PushDownFilter, `n_name = 'CHINA'` lives in a Filter directly
+ // above the nation TableScan. The rule should then wrap supplier with
+ // `Filter(s_nationkey IN (SELECT n_nationkey FROM nation
+ // WHERE n_name = 'CHINA'))`.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier, nation \
+ WHERE s_nationkey = n_nationkey AND n_name = 'CHINA'",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+
+ // Depending on `DataFusion`'s planner the join's `left`/`right` may be
+ // either order. We don't care which side gets the InSubquery, only
+ // that exactly one of them does, and that it carries the marker.
+ let propagated = find_propagated_side(&transformed_plan);
+ assert!(
+ changed,
+ "rule should fire on inner join with dim-side non-key filter; plan was:\n{plan}"
+ );
+ assert!(
+ propagated.is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+
+ // Cycle prevention: running the rule a second time on the
+ // already-transformed plan must be a no-op.
+ let (second_plan, changed2) = apply_rule_to_all_joins(&r, transformed_plan.clone(), &cfg)?;
+ assert!(
+ !changed2,
+ "second pass must not re-propagate (cycle guard); plan was:\n{second_plan}"
+ );
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn left_semi_join_with_dim_filter_propagates_via_subquery() -> Result<()> {
+ // The `IN (subquery)` shape that `decorrelate_predicate_subquery`
+ // rewrites into a `LeftSemi` join. The propagation rule must still
+ // fire on the resulting semi-join so the dim filter reaches the fact
+ // side across chained joins.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier \
+ WHERE s_nationkey IN \
+ (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA')",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ // Sanity-check that decorrelation produced a semi-join shape; if it
+ // didn't, this test is testing the wrong thing.
+ let mut semi_seen = false;
+ let _ = plan.apply(|node| {
+ if let LogicalPlan::Join(j) = node
+ && matches!(j.join_type, JoinType::LeftSemi | JoinType::RightSemi)
+ {
+ semi_seen = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ });
+ assert!(
+ semi_seen,
+ "expected decorrelation to produce a semi-join; plan was:\n{plan}"
+ );
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ changed,
+ "rule should fire on semi-join with dim-side non-key filter; plan was:\n{plan}"
+ );
+ assert!(
+ find_propagated_side(&transformed_plan).is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn left_outer_join_propagates_only_left_to_right() -> Result<()> {
+ // `supplier LEFT JOIN nation ON s_nationkey = n_nationkey WHERE
+ // s_name = 'X'`. The LEFT side (supplier) has a non-key filter; it is
+ // the preserved side. Propagating to the lookup side (nation) is safe.
+ //
+ // Note: `eliminate_outer_join` will rewrite the LEFT JOIN to an INNER
+ // JOIN only if the WHERE clause forces the right side to be non-null
+ // — using a filter on the LEFT side instead preserves the outer
+ // semantics, which is what we want for this test.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier LEFT JOIN nation \
+ ON s_nationkey = n_nationkey WHERE s_suppkey > 5",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ // The supplier-side filter (`s_suppkey > 5`) is a non-key filter on
+ // the LEFT/preserved side. Direction LEFT→RIGHT is allowed; the rule
+ // should propagate `n_nationkey IN (SELECT s_nationkey FROM filtered_supplier)`
+ // onto nation.
+ assert!(
+ changed,
+ "rule should fire LEFT→RIGHT for LEFT OUTER; plan was:\n{plan}"
+ );
+ assert!(
+ find_propagated_side(&transformed_plan).is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn left_outer_join_blocks_right_to_left_propagation() -> Result<()> {
+ // Filter on the RIGHT (lookup) side of a LEFT OUTER must NOT cause
+ // propagation onto the LEFT (preserved) side: doing so would drop
+ // left rows the outer join should emit as `(left, NULL...)`.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier LEFT JOIN nation \
+ ON s_nationkey = n_nationkey \
+ WHERE n_name = 'CHINA' OR n_name IS NULL",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ // The filter is on the RIGHT side. RIGHT→LEFT propagation is forbidden
+ // for LEFT OUTER. LEFT→RIGHT is allowed but there's no LEFT-side filter
+ // to propagate. So the rule must be a no-op here.
+ assert!(
+ !changed,
+ "RIGHT→LEFT propagation must not fire on LEFT OUTER; plan was:\n{plan}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn rule_does_not_re_fire_on_post_decorrelation_left_semi() -> Result<()> {
+ // Regression test for the cycle-detection bug across optimizer
+ // iterations: after Pass 1 wraps the receiving side with an
+ // `InSubquery`, `decorrelate_predicate_subquery` rewrites that into a
+ // `LeftSemi` join with the marker `SubqueryAlias` as its right child.
+ // If the rule's cycle detection only sees `InSubquery` markers (and
+ // not the structural `LeftSemi`-with-marker shape), Pass 2 sees no
+ // marker on the receiving side and re-propagates, producing nested
+ // LeftSemi joins on every subsequent optimizer pass.
+ //
+ // The fix detects the post-decorrelation shape and records the
+ // already-propagated target so the rule's cycle guard short-circuits
+ // on subsequent passes.
+ use datafusion::common::NullEquality;
+ use datafusion::logical_expr::JoinConstraint;
+ use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit};
+
+ let dim_schema = Arc::new(Schema::new(vec![
+ Field::new("n_nationkey", DataType::Int64, false),
+ Field::new("n_name", DataType::Utf8, true),
+ ]));
+ let fact_schema = Arc::new(Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ ]));
+
+ // Build the dim subquery: `Filter(n_name='CHINA') → TableScan(nation)`
+ // wrapped in the propagated-filter alias the rule would have produced.
+ let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?;
+ let nation_filter = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")),
+ Arc::new(nation_scan),
+ )?);
+ let nation_projection = LogicalPlan::Projection(Projection::try_new(
+ vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))],
+ Arc::new(nation_filter),
+ )?);
+ let dim_subquery_alias = format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1");
+ let dim_subquery = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
+ Arc::new(nation_projection),
+ TableReference::bare(dim_subquery_alias),
+ )?);
+
+ // Build supplier scan (the receiving fact side).
+ let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?;
+
+ // Compose the post-decorrelation shape: `LeftSemi(supplier, dim_subquery)`
+ // on `s_nationkey = n_nationkey`.
+ let semi_join_input = LogicalPlanBuilder::from(supplier_scan)
+ .join_with_expr_keys(
+ dim_subquery,
+ JoinType::LeftSemi,
+ (
+ vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))],
+ vec![Expr::Column(Column::new(
+ Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")),
+ "n_nationkey",
+ ))],
+ ),
+ None,
+ )?
+ .build()?;
+
+ // Now build an outer `Inner Join` between the *original* nation_filtered
+ // and this `LeftSemi` subtree on the same equi-key — the exact shape an
+ // optimizer pass would see after the rule already fired + decorrelated.
+ let dim_filter_again_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?;
+ let dim_filter_again = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")),
+ Arc::new(dim_filter_again_scan),
+ )?);
+
+ let outer_join = LogicalPlan::Join(Join::try_new(
+ Arc::new(dim_filter_again),
+ Arc::new(semi_join_input),
+ vec![(
+ Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")),
+ Expr::Column(Column::new(Some("supplier"), "s_nationkey")),
+ )],
+ None,
+ JoinType::Inner,
+ JoinConstraint::On,
+ NullEquality::NullEqualsNothing,
+ )?);
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?;
+ assert!(
+ !changed,
+ "rule must not re-fire when the receiving side already contains a \
+ post-decorrelation LeftSemi propagation marker; plan was:\n{outer_join}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn rule_re_fires_when_receiving_side_has_non_marker_subquery_alias() -> Result<()> {
+ // Devil's-advocate edge case: a `LeftSemi` whose right side is a
+ // `SubqueryAlias` with a *non-marker* name should NOT block
+ // propagation (the marker prefix is the unique signal that this rule
+ // already fired). Guards against the cycle guard being too aggressive.
+ use datafusion::common::NullEquality;
+ use datafusion::logical_expr::JoinConstraint;
+ use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit};
+
+ let dim_schema = Arc::new(Schema::new(vec![
+ Field::new("n_nationkey", DataType::Int64, false),
+ Field::new("n_name", DataType::Utf8, true),
+ ]));
+ let fact_schema = Arc::new(Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ ]));
+
+ let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?;
+ let nation_filter = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")),
+ Arc::new(nation_scan),
+ )?);
+ let nation_projection = LogicalPlan::Projection(Projection::try_new(
+ vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))],
+ Arc::new(nation_filter),
+ )?);
+ let user_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
+ Arc::new(nation_projection),
+ TableReference::bare("some_user_alias"),
+ )?);
+
+ let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?;
+ let semi_join_input = LogicalPlanBuilder::from(supplier_scan)
+ .join_with_expr_keys(
+ user_alias,
+ JoinType::LeftSemi,
+ (
+ vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))],
+ vec![Expr::Column(Column::new(
+ Some("some_user_alias".to_string()),
+ "n_nationkey",
+ ))],
+ ),
+ None,
+ )?
+ .build()?;
+
+ let outer_dim_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?;
+ let outer_dim_filter = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")),
+ Arc::new(outer_dim_scan),
+ )?);
+
+ let outer_join = LogicalPlan::Join(Join::try_new(
+ Arc::new(outer_dim_filter),
+ Arc::new(semi_join_input),
+ vec![(
+ Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")),
+ Expr::Column(Column::new(Some("supplier"), "s_nationkey")),
+ )],
+ None,
+ JoinType::Inner,
+ JoinConstraint::On,
+ NullEquality::NullEqualsNothing,
+ )?);
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?;
+ assert!(
+ changed,
+ "rule should still fire when the receiving LeftSemi's alias is \
+ user-supplied (not the propagation marker); plan was:\n{outer_join}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn rule_detects_marker_through_projection_wrapper() -> Result<()> {
+ // Subsequent optimizer rules (`MergeProjection`, etc.) may wrap the
+ // marker `SubqueryAlias` in a `Projection`. The cycle guard must still
+ // detect the marker through this wrapping.
+ use datafusion::common::NullEquality;
+ use datafusion::logical_expr::JoinConstraint;
+ use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit};
+
+ let dim_schema = Arc::new(Schema::new(vec![
+ Field::new("n_nationkey", DataType::Int64, false),
+ Field::new("n_name", DataType::Utf8, true),
+ ]));
+ let fact_schema = Arc::new(Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ ]));
+
+ let nation_scan = table_scan(Some("nation"), &dim_schema, None)?.build()?;
+ let nation_filter = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation"), "n_name")).eq(lit("CHINA")),
+ Arc::new(nation_scan),
+ )?);
+ let inner_projection = LogicalPlan::Projection(Projection::try_new(
+ vec![Expr::Column(Column::new(Some("nation"), "n_nationkey"))],
+ Arc::new(nation_filter),
+ )?);
+ let marker_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
+ Arc::new(inner_projection),
+ TableReference::bare(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")),
+ )?);
+ let wrapped_marker = LogicalPlan::Projection(Projection::try_new(
+ vec![Expr::Column(Column::new(
+ Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")),
+ "n_nationkey",
+ ))],
+ Arc::new(marker_alias),
+ )?);
+
+ let supplier_scan = table_scan(Some("supplier"), &fact_schema, None)?.build()?;
+ let semi_join_input = LogicalPlanBuilder::from(supplier_scan)
+ .join_with_expr_keys(
+ wrapped_marker,
+ JoinType::LeftSemi,
+ (
+ vec![Expr::Column(Column::new(Some("supplier"), "s_nationkey"))],
+ vec![Expr::Column(Column::new(
+ Some(format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1")),
+ "n_nationkey",
+ ))],
+ ),
+ None,
+ )?
+ .build()?;
+
+ let outer_dim_scan = table_scan(Some("nation_outer"), &dim_schema, None)?.build()?;
+ let outer_dim_filter = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("nation_outer"), "n_name")).eq(lit("CHINA")),
+ Arc::new(outer_dim_scan),
+ )?);
+
+ let outer_join = LogicalPlan::Join(Join::try_new(
+ Arc::new(outer_dim_filter),
+ Arc::new(semi_join_input),
+ vec![(
+ Expr::Column(Column::new(Some("nation_outer"), "n_nationkey")),
+ Expr::Column(Column::new(Some("supplier"), "s_nationkey")),
+ )],
+ None,
+ JoinType::Inner,
+ JoinConstraint::On,
+ NullEquality::NullEqualsNothing,
+ )?);
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, outer_join.clone(), &cfg)?;
+ assert!(
+ !changed,
+ "cycle guard must detect a marker wrapped in an outer Projection; \
+ plan was:\n{outer_join}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn inner_join_without_filter_is_noop() -> Result<()> {
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier, nation \
+ WHERE s_nationkey = n_nationkey",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ !changed,
+ "rule must not fire when neither side has a non-key filter; plan was:\n{plan}"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn inner_join_with_expression_fact_key_propagates_dim_filter() -> Result<()> {
+ // The canonical chbench Q5/Q7/Q10 shape: a non-trivial expression on
+ // the fact side and a pure column on the dim side, with the dim side
+ // carrying the selective non-key filter.
+ //
+ // The rule must fire on `(Column, Expr)` (or `(Expr, Column)`) equi-key
+ // pairs even though neither side is a pure column-column join.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT c_id FROM customer, nation \
+ WHERE ascii(substr(c_state, 1, 1)) - 65 = n_nationkey \
+ AND n_name = 'CHINA'",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ changed,
+ "rule should fire on expression-vs-column equi-key; plan was:\n{plan}"
+ );
+ assert!(
+ find_propagated_side(&transformed_plan).is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+
+ // Cycle prevention: running the rule a second time must be a no-op
+ // (the unified Display-keyed cycle guard tracks the InSubquery target
+ // expression, not just column targets).
+ let (_, changed2) = apply_rule_to_all_joins(&r, transformed_plan, &cfg)?;
+ assert!(
+ !changed2,
+ "second pass must not re-propagate (cycle guard) on expression target"
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn multi_hop_dim_subtree_propagates_through_region_nation() -> Result<()> {
+ // The canonical Q5 shape: `region ⋈ nation ⋈ supplier` with a
+ // selective filter on `region.r_name`. With the multi-hop dim
+ // detector the `region ⋈ nation` subtree counts as dim-like, so the
+ // rule can propagate the filtered `n_nationkey` domain to `supplier`
+ // in a single pass instead of waiting for the optimizer's fixed
+ // point.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier, nation, region \
+ WHERE s_nationkey = n_nationkey \
+ AND n_regionkey = r_regionkey \
+ AND r_name = 'ASIA'",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ changed,
+ "rule should propagate r_name filter through the multi-hop dim subtree; \
+ plan was:\n{plan}"
+ );
+ assert!(
+ find_propagated_side(&transformed_plan).is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn key_preserved_through_summaries_accepts_distinct_all() -> Result<()> {
+ // `Distinct::All` deduplicates whole rows but preserves every column's
+ // values (it can only remove duplicate rows), so any join key survives.
+ use datafusion::logical_expr::Distinct;
+ use datafusion_expr::builder::table_scan;
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Int64, false),
+ ]));
+ let scan = table_scan(Some("t"), &schema, None)?.build()?;
+ let distinct = LogicalPlan::Distinct(Distinct::All(Arc::new(scan)));
+
+ let key_a = Column::new(Some("t"), "a");
+ let key_b = Column::new(Some("t"), "b");
+
+ assert!(key_preserved_through_summaries(&distinct, &key_a));
+ assert!(key_preserved_through_summaries(&distinct, &key_b));
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn aggregate_dim_propagates_when_key_is_in_group_by() -> Result<()> {
+ // Pre-aggregated dim: `SELECT n_nationkey, count(*) FROM nation
+ // WHERE n_name = 'CHINA' GROUP BY n_nationkey` joined against
+ // supplier. The aggregate's GROUP BY includes `n_nationkey`, so the
+ // key's domain is preserved through the aggregation and the rule
+ // should still propagate to supplier.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier, \
+ (SELECT n_nationkey FROM nation WHERE n_name = 'CHINA' \
+ GROUP BY n_nationkey) AS n_agg \
+ WHERE s_nationkey = n_nationkey",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ changed,
+ "rule should fire when dim has Aggregate(GROUP BY key); plan was:\n{plan}"
+ );
+ assert!(
+ find_propagated_side(&transformed_plan).is_some(),
+ "rule fired but produced no propagated-filter marker; plan was:\n{transformed_plan}"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn key_preserved_through_summaries_rejects_aggregate_without_key_in_group() -> Result<()> {
+ // Sanity-check the helper: an aggregate that does NOT group by `a`
+ // must report the key as not preserved.
+ use datafusion::logical_expr::Aggregate;
+ use datafusion_expr::builder::table_scan;
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Int64, false),
+ ]));
+ let scan = table_scan(Some("t"), &schema, None)?.build()?;
+ let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(scan),
+ vec![Expr::Column(Column::new(Some("t"), "b"))],
+ vec![],
+ )?);
+
+ let key_a = Column::new(Some("t"), "a");
+ let key_b = Column::new(Some("t"), "b");
+
+ assert!(
+ !key_preserved_through_summaries(&agg, &key_a),
+ "`a` aggregated away, must not be preserved"
+ );
+ assert!(
+ key_preserved_through_summaries(&agg, &key_b),
+ "`b` is in GROUP BY, must be preserved"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn subtree_upper_bound_rows_sums_stats_across_dim_subtree() -> Result<()> {
+ use datafusion::catalog::{Session, TableProvider};
+ use datafusion::common::stats::Precision;
+ use datafusion::datasource::DefaultTableSource;
+ use datafusion::logical_expr::{TableType, dml::InsertOp};
+ use datafusion::physical_plan::ExecutionPlan;
+ use datafusion_common::Statistics;
+ use datafusion_expr::Expr as ExprAlias;
+ use datafusion_expr::LogicalPlanBuilder;
+
+ /// `TableProvider` that returns a constant row count from `statistics()`.
+ #[derive(Debug)]
+ struct FixedStatsProvider {
+ schema: arrow::datatypes::SchemaRef,
+ num_rows: usize,
+ }
+
+ #[async_trait::async_trait]
+ impl TableProvider for FixedStatsProvider {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+ fn schema(&self) -> arrow::datatypes::SchemaRef {
+ Arc::clone(&self.schema)
+ }
+ fn table_type(&self) -> TableType {
+ TableType::Base
+ }
+ async fn scan(
+ &self,
+ _state: &dyn Session,
+ _projection: Option<&Vec>,
+ _filters: &[ExprAlias],
+ _limit: Option,
+ ) -> Result> {
+ Err(datafusion::common::DataFusionError::NotImplemented(
+ "FixedStatsProvider scan not used in this test".to_string(),
+ ))
+ }
+ fn statistics(&self) -> Option {
+ let mut stats = Statistics::new_unknown(self.schema.as_ref());
+ stats.num_rows = Precision::Exact(self.num_rows);
+ Some(stats)
+ }
+ async fn insert_into(
+ &self,
+ _state: &dyn Session,
+ _input: Arc,
+ _insert_op: InsertOp,
+ ) -> Result> {
+ Err(datafusion::common::DataFusionError::NotImplemented(
+ "FixedStatsProvider insert not used".to_string(),
+ ))
+ }
+ }
+
+ fn fixed_table_scan(rows: usize) -> Result {
+ let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int64, false)]));
+ let provider = Arc::new(FixedStatsProvider {
+ schema: Arc::clone(&schema),
+ num_rows: rows,
+ });
+ let source = Arc::new(DefaultTableSource::new(provider));
+ LogicalPlanBuilder::scan("t", source, None)?.build()
+ }
+
+ // Single scan: row count is reported directly.
+ let small = fixed_table_scan(500)?;
+ assert_eq!(subtree_upper_bound_rows(&small), Some(500));
+
+ // Below the dim threshold → gate fires (skip propagation).
+ let fact = fixed_table_scan(1_000_000)?;
+ assert!(skip_propagation_by_cardinality(&small, &fact));
+
+ // Above the dim threshold + above the fact threshold → gate is silent.
+ let big_dim = fixed_table_scan(5_000)?;
+ assert!(!skip_propagation_by_cardinality(&big_dim, &fact));
+
+ // Below the fact threshold → gate fires from the fact side.
+ let tiny_fact = fixed_table_scan(50_000)?;
+ assert!(skip_propagation_by_cardinality(&big_dim, &tiny_fact));
+
+ Ok(())
+ }
+
+ #[test]
+ fn skip_propagation_by_cardinality_silent_when_stats_absent() -> Result<()> {
+ // MemTable doesn't expose row counts via `TableProvider::statistics()`,
+ // so the gate must fall back to the structural behavior (no skip).
+ use datafusion::catalog::MemTable;
+ use datafusion::datasource::DefaultTableSource;
+ use datafusion_expr::LogicalPlanBuilder;
+
+ let schema = Arc::new(Schema::new(vec![Field::new("k", DataType::Int64, false)]));
+ let provider = Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![]])?);
+ let source = Arc::new(DefaultTableSource::new(provider));
+ let scan = LogicalPlanBuilder::scan("t", source, None)?.build()?;
+
+ assert_eq!(subtree_upper_bound_rows(&scan), None);
+ assert!(!skip_propagation_by_cardinality(&scan, &scan));
+ Ok(())
+ }
+
+ #[test]
+ fn key_preserved_through_summaries_rejects_same_name_different_relation() -> Result<()> {
+ use datafusion::logical_expr::{Aggregate, Distinct, DistinctOn};
+ use datafusion_expr::builder::table_scan;
+
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
+ let scan = table_scan(Some("t2"), &schema, None)?.build()?;
+ let t1_key = Column::new(Some("t1"), "a");
+ let t2_key = Column::new(Some("t2"), "a");
+
+ let aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(scan.clone()),
+ vec![Expr::Column(t2_key.clone())],
+ vec![],
+ )?);
+ assert!(
+ !key_preserved_through_summaries(&aggregate, &t1_key),
+ "same-name GROUP BY columns from a different relation must not preserve the key"
+ );
+
+ let distinct_on = LogicalPlan::Distinct(Distinct::On(DistinctOn::try_new(
+ vec![Expr::Column(t2_key.clone())],
+ vec![Expr::Column(t2_key)],
+ None,
+ Arc::new(scan),
+ )?));
+ assert!(
+ !key_preserved_through_summaries(&distinct_on, &t1_key),
+ "same-name DISTINCT ON columns from a different relation must not preserve the key"
+ );
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn inner_join_with_key_only_filter_is_noop() -> Result<()> {
+ // `n_nationkey = 22` references only the join key — `DataFusion`'s
+ // stock `infer_join_predicates` already handles this case, so our
+ // rule must NOT fire and create a redundant subquery.
+ let ctx = make_ctx()?;
+ let plan = ctx
+ .sql(
+ "SELECT s_suppkey FROM supplier, nation \
+ WHERE s_nationkey = n_nationkey AND n_nationkey = 22",
+ )
+ .await?
+ .into_optimized_plan()?;
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, plan.clone(), &cfg)?;
+ assert!(
+ !changed,
+ "rule must not fire when filter references only the join key; plan was:\n{plan}"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn null_equal_inner_join_is_noop() -> Result<()> {
+ use datafusion::logical_expr::JoinConstraint;
+ use datafusion_expr::{builder::table_scan, lit};
+
+ let left_schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("c", DataType::Utf8, true),
+ ]));
+ let right_schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
+
+ let left_scan = table_scan(Some("l"), &left_schema, None)?.build()?;
+ let left = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("l"), "c")).eq(lit("v")),
+ Arc::new(left_scan),
+ )?);
+ let right = table_scan(Some("r"), &right_schema, None)?.build()?;
+
+ let join = LogicalPlan::Join(Join::try_new(
+ Arc::new(left),
+ Arc::new(right),
+ vec![(
+ Expr::Column(Column::new(Some("l"), "a")),
+ Expr::Column(Column::new(Some("r"), "x")),
+ )],
+ None,
+ JoinType::Inner,
+ JoinConstraint::On,
+ NullEquality::NullEqualsNull,
+ )?);
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (_, changed) = apply_rule_to_all_joins(&r, join, &cfg)?;
+
+ assert!(
+ !changed,
+ "rule must not introduce SQL IN filters for null-equal joins"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn composite_join_receives_one_filter_per_non_key_constrained_key() -> Result<()> {
+ use datafusion::common::NullEquality;
+ use datafusion::logical_expr::JoinConstraint;
+ use datafusion_expr::{builder::table_scan, lit};
+
+ let left_schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Int64, false),
+ Field::new("c", DataType::Utf8, true),
+ ]));
+ let right_schema = Arc::new(Schema::new(vec![
+ Field::new("x", DataType::Int64, false),
+ Field::new("y", DataType::Int64, false),
+ ]));
+
+ let left_scan = table_scan(Some("l"), &left_schema, None)?.build()?;
+ let left = LogicalPlan::Filter(Filter::try_new(
+ Expr::Column(Column::new(Some("l"), "c")).eq(lit("v")),
+ Arc::new(left_scan),
+ )?);
+ let right = table_scan(Some("r"), &right_schema, None)?.build()?;
+
+ let join = LogicalPlan::Join(Join::try_new(
+ Arc::new(left),
+ Arc::new(right),
+ vec![
+ (
+ Expr::Column(Column::new(Some("l"), "a")),
+ Expr::Column(Column::new(Some("r"), "x")),
+ ),
+ (
+ Expr::Column(Column::new(Some("l"), "b")),
+ Expr::Column(Column::new(Some("r"), "y")),
+ ),
+ ],
+ None,
+ JoinType::Inner,
+ JoinConstraint::On,
+ NullEquality::NullEqualsNothing,
+ )?);
+
+ let r = rule();
+ let cfg = datafusion::optimizer::OptimizerContext::new();
+ let (transformed_plan, changed) = apply_rule_to_all_joins(&r, join, &cfg)?;
+
+ assert!(
+ changed,
+ "rule should fire on composite inner join with side-local non-key filter"
+ );
+ assert_eq!(
+ count_propagated_filter_exprs(&transformed_plan),
+ 2,
+ "each matching composite key should get one propagated filter; plan was:\n{transformed_plan}"
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn expr_has_propagated_filter_detects_marker_alias() -> Result<()> {
+ use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit};
+
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
+ let scan = table_scan(Some("t"), &schema, None)?.build()?;
+ let projection = LogicalPlanBuilder::from(scan)
+ .project(vec![Expr::Column(Column::new(Some("t"), "a"))])?
+ .build()?;
+
+ let alias_name = format!("{PROPAGATED_FILTER_ALIAS_PREFIX}1");
+ let aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
+ Arc::new(projection),
+ TableReference::bare(alias_name),
+ )?);
+
+ let in_subquery = Expr::InSubquery(InSubquery::new(
+ Box::new(lit(1i64)),
+ Subquery {
+ subquery: Arc::new(aliased),
+ outer_ref_columns: vec![],
+ spans: Spans::default(),
+ },
+ false,
+ ));
+
+ assert!(expr_has_propagated_filter(&in_subquery));
+ assert!(!expr_has_propagated_filter(&lit(5i64)));
+ Ok(())
+ }
+
+ #[test]
+ fn is_dim_like_subtree_handles_simple_scan() -> Result<()> {
+ use datafusion_expr::{LogicalPlanBuilder, builder::table_scan, lit};
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("x", DataType::Utf8, true),
+ ]));
+ let scan = table_scan(Some("t"), &schema, None)?.build()?;
+ assert!(is_dim_like_subtree(&scan));
+
+ let filtered = LogicalPlanBuilder::from(scan)
+ .filter(Expr::Column(Column::new(Some("t"), "x")).eq(lit("v")))?
+ .build()?;
+ assert!(is_dim_like_subtree(&filtered));
+ Ok(())
+ }
+}
diff --git a/crates/cayenne/src/metastore/sqlite.rs b/crates/cayenne/src/metastore/sqlite.rs
index 42e9c2c8f3..b594331778 100644
--- a/crates/cayenne/src/metastore/sqlite.rs
+++ b/crates/cayenne/src/metastore/sqlite.rs
@@ -120,6 +120,35 @@ impl SqliteMetastore {
if !db_dir.exists() {
tokio::fs::create_dir_all(db_dir).await?;
+
+ // Best-effort parent directory sync (defense-in-depth with
+ // the sync already performed in CayenneCatalog::init).
+ // Ensures the db_dir entry is durable before opening the
+ // SQLite connection and initializing the schema.
+ //
+ // We keep this best-effort (with warning on failure) for
+ // the same reasons as in CayenneCatalog::init: one-time
+ // initialization, followed by DB file + schema creation,
+ // and the parent is often a stable operator-managed
+ // volume root.
+ if let Some(parent) = db_dir.parent() {
+ let parent_for_sync = parent.to_path_buf();
+ let parent_display = parent_for_sync.display().to_string();
+ let db_dir_display = db_dir.display().to_string();
+ match tokio::task::spawn_blocking(move || {
+ std::fs::File::open(&parent_for_sync).and_then(|f| f.sync_all())
+ })
+ .await
+ {
+ Ok(Ok(())) => {}
+ Ok(Err(error)) => tracing::warn!(
+ "Failed to sync parent directory {parent_display} after creating SQLite catalog DB directory {db_dir_display} (subsequent DB writes will still be durable): {error}"
+ ),
+ Err(error) => tracing::warn!(
+ "Failed to join SQLite catalog DB parent directory sync task for {parent_display}: {error}"
+ ),
+ }
+ }
}
// Open connection with tokio-rusqlite
diff --git a/crates/cayenne/src/metastore/turso.rs b/crates/cayenne/src/metastore/turso.rs
index c3fc9ea771..e7e24ffdbc 100644
--- a/crates/cayenne/src/metastore/turso.rs
+++ b/crates/cayenne/src/metastore/turso.rs
@@ -82,6 +82,35 @@ impl TursoMetastore {
if !db_dir.exists() {
tokio::fs::create_dir_all(db_dir).await?;
+
+ // Best-effort parent directory sync (defense-in-depth with
+ // the sync already performed in CayenneCatalog::init).
+ // Ensures the db_dir entry is durable before opening the
+ // Turso connection and initializing the schema.
+ //
+ // We keep this best-effort (with warning on failure) for
+ // the same reasons as in CayenneCatalog::init: one-time
+ // initialization, followed by DB file + schema creation,
+ // and the parent is often a stable operator-managed
+ // volume root.
+ if let Some(parent) = db_dir.parent() {
+ let parent_for_sync = parent.to_path_buf();
+ let parent_display = parent_for_sync.display().to_string();
+ let db_dir_display = db_dir.display().to_string();
+ match tokio::task::spawn_blocking(move || {
+ std::fs::File::open(&parent_for_sync).and_then(|f| f.sync_all())
+ })
+ .await
+ {
+ Ok(Ok(())) => {}
+ Ok(Err(error)) => tracing::warn!(
+ "Failed to sync parent directory {parent_display} after creating Turso catalog DB directory {db_dir_display} (subsequent DB writes will still be durable): {error}"
+ ),
+ Err(error) => tracing::warn!(
+ "Failed to join Turso catalog DB parent directory sync task for {parent_display}: {error}"
+ ),
+ }
+ }
}
let db = Builder::new_local(db_path).build().await.map_err(|e| {
diff --git a/crates/cayenne/src/optimizer_rules.rs b/crates/cayenne/src/optimizer_rules.rs
index 91b223fe72..8f3e515741 100644
--- a/crates/cayenne/src/optimizer_rules.rs
+++ b/crates/cayenne/src/optimizer_rules.rs
@@ -15,24 +15,110 @@ limitations under the License.
*/
//! Physical optimizer rules for Cayenne execution plans.
+//!
+//! # No-spill build-side memory strategy (q21 / chbench multi-way joins)
+//!
+//! `DataFusion`'s `HashJoinExec` build side is non-spillable. Under the runtime
+//! memory pool (`GreedyMemoryPool` wrapped in `TrackConsumersPool`), wide chbench
+//! shapes such as q21 (a 5-way join feeding a correlated `NOT EXISTS` self-join
+//! over `order_line`) exhaust the `HashJoinInput[N]` reservations because each
+//! build-side hash table independently materializes its full keyspace.
+//!
+//! The q21 fix is layered so each optimizer rule handles the part `DataFusion`
+//! cannot currently spill or infer on its own:
+//!
+//! 1. **Logical predicate propagation.**
+//! [`crate::logical_optimizer::CayennePropagateFilterAcrossEquiJoinKeys`]
+//! introduces explicit `InSubquery` filters for equi-join keys when the
+//! selective predicate is on a non-key column. `DataFusion`'s stock
+//! `infer_join_predicates` only fires when the predicate already references
+//! a join key (`WHERE n_nationkey = 5` → `WHERE s_nationkey = 5`). For q21
+//! the filter is `n_name = 'CHINA'`, so the Cayenne rule exposes the
+//! `nation → supplier → stock/order_line` cardinality bound before
+//! `push_down_filter` plants it into scans.
+//!
+//! 2. **Cross-scan dynamic filter sharing.** When a join's
+//! `Arc` is pushed into one
+//! `CayenneAccelerationExec`, [`CayenneDynamicFilterSharing`] installs the
+//! same `Arc` on sibling `CayenneAccelerationExec`s backed by the same
+//! underlying table and equi-joined column set. The shared `Arc` carries the
+//! same `Arc>` state, so all sibling scans observe the exact
+//! filter values as soon as the producing join accumulates them. Applies to
+//! `Inner`, `LeftSemi`, and `RightSemi` parent joins (anti joins are
+//! excluded — their semantics require the *absence* of a match, so sharing
+//! the filter would drop rows the anti-join is supposed to preserve).
+//!
+//! 3. **Same-source anti / semi-join sort-merge rewrite.** `DataFusion` does not
+//! create dynamic filters for anti joins, and q21's `NOT EXISTS` self-join
+//! can leave large `HashJoinInput[N]` reservations behind.
+//! [`CayenneAntiJoinSortMergeRewriter`] rewrites same-source Cayenne
+//! `LeftAnti` / `RightAnti` / `LeftSemi` / `RightSemi` `HashJoinExec` nodes
+//! to `SortMergeJoinExec` with explicit spillable `SortExec` inputs above a
+//! 10M-row build-side threshold. Sort-merge preserves the join semantics for
+//! each of these types without materializing a full non-spillable hash table
+//! on the LEFT input (`HashJoinExec`'s build side, regardless of join type).
+//!
+//! [`CayenneJoinRewriter`] still handles the ordinary inner-join probe side by
+//! swapping the default in-list accumulator for [`ExactLeftAccumulator`], which
+//! produces a precise dynamic filter (or falls back to `RangeBounds` +
+//! `BloomFilter`) that `DataFusion`'s filter-pushdown phase plants into the
+//! right-side `CayenneAccelerationExec`'s `FileSource`.
+//!
+//! ## Audit notes (verified 2026-05-14 against the q21 explain snapshot at
+//! `crates/test-framework/src/snapshot/snapshots/explain/test_framework__snapshot__file[parquet]-cayenne[file]-indexes_tpch_q21_explain.snap`)
+//!
+//! * **Cayenne table statistics are `Exact` at the physical-plan boundary.**
+//! The chain `CayenneTableProvider::statistics`
+//! → [`crate::stats::file_statistics_to_df`] returns
+//! `Precision::Exact(num_rows)` whenever the persisted `i64` row count is
+//! non-negative. Per-file `Statistics` are also `Exact` because
+//! `VortexFormat::infer_stats` reads `row_count` from the file footer, and
+//! `SessionConfig::default().collect_statistics()` is `true`, so
+//! `ListingTable::do_collect_statistics` is exercised for every scan.
+//! `CayenneAccelerationExec::partition_statistics` simply delegates to the
+//! inner `DataSourceExec`, so the value reaches `JoinSelection`. The q21
+//! explain plan confirms `should_swap_join_order` picks the smaller side as
+//! build at every level (nation/supplier on the LEFT, lineitem on the
+//! RIGHT), so the residual OOM is *not* attributable to fuzzy stats — it is
+//! the **logical** join order locking in the SQL `FROM` order and applying
+//! the nation filter last.
+//!
+//! * **Build-side projections are minimal.** Every `CayenneAccelerationExec`
+//! in the snapshot terminates in a `DataSourceExec` whose `projection=[...]`
+//! lists only the join keys and the columns referenced above the join.
+//! `DataFusion`'s stock projection pushdown already prunes wider scans down to
+//! `[s_suppkey, s_name, s_nationkey]`, `[o_orderkey, o_orderstatus]`,
+//! `[l_orderkey, l_suppkey]`, etc. No additional `ProjectionExec` insertion
+//! above the build side is required.
+//!
+//! With these layers active, q21 is included in
+//! `test_framework::queries::get_chbench_test_queries`.
-use datafusion::common::NullEquality;
+use arrow::compute::SortOptions;
+use arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion::common::{JoinType, NullEquality};
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
+use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
-use datafusion::physical_plan::joins::HashJoinExec;
+use datafusion::physical_plan::joins::{HashJoinExec, SortMergeJoinExec};
+use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::{error::Result, physical_plan::projection::ProjectionExec};
+use datafusion_common::stats::Precision;
+use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use runtime_datafusion::execution_plan::schema_cast::SchemaCastScanExec;
use runtime_datafusion::extension::bytes_processed::BytesProcessedExec;
use runtime_datafusion::join_accumulator::ExactLeftAccumulator;
+use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
use crate::provider::CayenneAccelerationExec;
-use crate::provider::scan::IsCayenneAccelerationExec;
+use crate::provider::scan::{IsCayenneAccelerationExec, ScanDynamicFilter, ScanIdentity};
/// Optimizer rule that rewrites `HashJoinExec` nodes to use `ExactLeftAccumulator`
/// when the probe side is a `CayenneAccelerationExec`.
@@ -47,6 +133,530 @@ impl CayenneJoinRewriter {
}
}
+/// Shares already-pushed hash-join dynamic filters between same-source Cayenne
+/// scans when the current hash join proves the relevant columns are equi-joined.
+#[derive(Default)]
+pub struct CayenneDynamicFilterSharing;
+
+impl CayenneDynamicFilterSharing {
+ /// Create a new `CayenneDynamicFilterSharing` optimizer rule.
+ #[must_use]
+ pub fn new() -> Self {
+ Self
+ }
+}
+
+impl std::fmt::Debug for CayenneDynamicFilterSharing {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("CayenneDynamicFilterSharing").finish()
+ }
+}
+
+/// Rewrites same-source Cayenne anti and semi joins from hash join to
+/// sort-merge join when the build side is large enough to risk OOM.
+///
+/// `DataFusion`'s `HashJoinExec` always materializes its left input as the
+/// non-spillable build side regardless of join type. For q21's correlated
+/// `NOT EXISTS` self-join (a `LeftAnti`) that build side can be a large
+/// multi-way `order_line` result; the same shape arises in `EXISTS` /
+/// `IN (subquery)` constructs that decorrelate into `LeftSemi`. Sort-merge
+/// preserves the join semantics for each of these types while keeping the
+/// build side spillable.
+#[derive(Default)]
+pub struct CayenneAntiJoinSortMergeRewriter;
+
+/// Only rewrite same-source anti or semi joins whose LEFT (build) input has
+/// `Precision::Exact` row count exceeding this threshold. Below it the
+/// in-memory hash table is faster than two explicit sort buffers.
+const ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS: usize = 10_000_000;
+
+impl CayenneAntiJoinSortMergeRewriter {
+ /// Create a new `CayenneAntiJoinSortMergeRewriter` optimizer rule.
+ #[must_use]
+ pub fn new() -> Self {
+ Self
+ }
+}
+
+impl std::fmt::Debug for CayenneAntiJoinSortMergeRewriter {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("CayenneAntiJoinSortMergeRewriter").finish()
+ }
+}
+
+impl PhysicalOptimizerRule for CayenneAntiJoinSortMergeRewriter {
+ fn name(&self) -> &'static str {
+ "CayenneAntiJoinSortMergeRewriter"
+ }
+
+ fn schema_check(&self) -> bool {
+ false
+ }
+
+ fn optimize(
+ &self,
+ plan: Arc,
+ _config: &ConfigOptions,
+ ) -> Result, DataFusionError> {
+ plan.transform_down(|node| {
+ let Some(hash_join) = node.as_any().downcast_ref::() else {
+ return Ok(Transformed::no(node));
+ };
+
+ let Some(sort_merge_join) = try_rewrite_same_source_anti_join(hash_join)? else {
+ return Ok(Transformed::no(node));
+ };
+
+ Ok(Transformed::yes(sort_merge_join))
+ })
+ .data()
+ }
+}
+
+impl PhysicalOptimizerRule for CayenneDynamicFilterSharing {
+ fn name(&self) -> &'static str {
+ "CayenneDynamicFilterSharing"
+ }
+
+ fn schema_check(&self) -> bool {
+ false
+ }
+
+ fn optimize(
+ &self,
+ plan: Arc,
+ config: &ConfigOptions,
+ ) -> Result, DataFusionError> {
+ plan.transform_down(|node| {
+ let Some(hash_join) = node.as_any().downcast_ref::() else {
+ return Ok(Transformed::no(node));
+ };
+
+ let (left_additions, right_additions) = filter_additions_for_join(hash_join);
+ if left_additions.is_empty() && right_additions.is_empty() {
+ return Ok(Transformed::no(node));
+ }
+
+ let (left, left_changed) =
+ apply_filter_additions(Arc::clone(hash_join.left()), &left_additions, config)?;
+ let (right, right_changed) =
+ apply_filter_additions(Arc::clone(hash_join.right()), &right_additions, config)?;
+
+ if !left_changed && !right_changed {
+ return Ok(Transformed::no(node));
+ }
+
+ let new_node = node.with_new_children(vec![left, right])?;
+ Ok(Transformed::yes(new_node))
+ })
+ .data()
+ }
+}
+
+#[derive(Clone)]
+struct CayenneScanSummary {
+ identity: Arc,
+ columns: BTreeSet,
+ schema_fields: Vec<(String, DataType)>,
+ dynamic_filters: Vec,
+}
+
+#[derive(Clone)]
+struct FilterAddition {
+ identity: Arc,
+ schema_fields: Vec<(String, DataType)>,
+ filter: Arc,
+}
+
+fn filter_additions_for_join(
+ hash_join: &HashJoinExec,
+) -> (Vec, Vec) {
+ // `Inner`, `LeftSemi`, and `RightSemi` all preserve the equi-key domain:
+ // a dynamic filter built from one side is also a valid filter for an
+ // equi-joined same-source scan on the other side. `LeftAnti`/`RightAnti`
+ // do not — their output requires the absence of a match, so propagating
+ // the filter would drop rows that should be retained.
+ if !matches!(
+ *hash_join.join_type(),
+ JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi,
+ ) {
+ return (Vec::new(), Vec::new());
+ }
+
+ let left_scans = collect_cayenne_scans(hash_join.left());
+ let right_scans = collect_cayenne_scans(hash_join.right());
+ if left_scans.is_empty() || right_scans.is_empty() {
+ return (Vec::new(), Vec::new());
+ }
+ let right_scans_by_identity = scans_by_identity(&right_scans);
+
+ let mut pair_columns: HashMap<(usize, usize), BTreeSet> = HashMap::new();
+ for (left_key, right_key) in hash_join.on() {
+ let Some(left_column) = physical_column_name(left_key) else {
+ continue;
+ };
+ let Some(right_column) = physical_column_name(right_key) else {
+ continue;
+ };
+
+ if left_column != right_column {
+ continue;
+ }
+
+ let matching_pairs = same_source_pairs_for_column(
+ &left_scans,
+ &right_scans,
+ &right_scans_by_identity,
+ left_column,
+ right_column,
+ );
+ let [(left_index, right_index)] = matching_pairs.as_slice() else {
+ continue;
+ };
+ if left_scans[*left_index].schema_fields != right_scans[*right_index].schema_fields {
+ continue;
+ }
+
+ pair_columns
+ .entry((*left_index, *right_index))
+ .or_default()
+ .insert(left_column.to_string());
+ }
+
+ let mut left_additions = Vec::new();
+ let mut right_additions = Vec::new();
+
+ for ((left_index, right_index), shared_columns) in pair_columns {
+ let left_scan = &left_scans[left_index];
+ let right_scan = &right_scans[right_index];
+
+ for filter in &left_scan.dynamic_filters {
+ if filter.columns().is_subset(&shared_columns) {
+ push_filter_addition(
+ &mut right_additions,
+ Arc::clone(&right_scan.identity),
+ right_scan.schema_fields.clone(),
+ Arc::clone(filter.filter()),
+ );
+ }
+ }
+
+ for filter in &right_scan.dynamic_filters {
+ if filter.columns().is_subset(&shared_columns) {
+ push_filter_addition(
+ &mut left_additions,
+ Arc::clone(&left_scan.identity),
+ left_scan.schema_fields.clone(),
+ Arc::clone(filter.filter()),
+ );
+ }
+ }
+ }
+
+ (left_additions, right_additions)
+}
+
+fn try_rewrite_same_source_anti_join(
+ hash_join: &HashJoinExec,
+) -> Result>, DataFusionError> {
+ // Same-source `LeftAnti`/`RightAnti`/`LeftSemi`/`RightSemi` joins all
+ // share the relevant property: `HashJoinExec` builds the LEFT input into
+ // a non-spillable hash table, so a large build side risks OOM. Sort-merge
+ // is spillable and preserves the join semantics for each of these types.
+ if !matches!(
+ hash_join.join_type(),
+ JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi | JoinType::RightSemi,
+ ) {
+ return Ok(None);
+ }
+
+ if hash_join.contains_projection() || hash_join.on().is_empty() {
+ return Ok(None);
+ }
+
+ if !has_single_same_source_pair_for_all_join_keys(hash_join) {
+ return Ok(None);
+ }
+
+ let Some(build_row_count) = spillable_rewrite_build_input_exact_rows(hash_join) else {
+ return Ok(None);
+ };
+ if build_row_count <= ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS {
+ return Ok(None);
+ }
+
+ let sort_options = vec![SortOptions::default(); hash_join.on().len()];
+ let Some(left_ordering) = join_key_ordering(
+ hash_join
+ .on()
+ .iter()
+ .map(|(left_key, _)| Arc::clone(left_key)),
+ &sort_options,
+ ) else {
+ return Ok(None);
+ };
+ let Some(right_ordering) = join_key_ordering(
+ hash_join
+ .on()
+ .iter()
+ .map(|(_, right_key)| Arc::clone(right_key)),
+ &sort_options,
+ ) else {
+ return Ok(None);
+ };
+
+ let left: Arc =
+ Arc::new(SortExec::new(left_ordering, Arc::clone(hash_join.left())));
+ let right: Arc =
+ Arc::new(SortExec::new(right_ordering, Arc::clone(hash_join.right())));
+
+ let join = SortMergeJoinExec::try_new(
+ left,
+ right,
+ hash_join.on().to_vec(),
+ hash_join.filter().cloned(),
+ *hash_join.join_type(),
+ sort_options,
+ hash_join.null_equality(),
+ )?;
+
+ tracing::debug!(
+ join_type = ?hash_join.join_type(),
+ build_row_count,
+ threshold = ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS,
+ "Replacing same-source Cayenne anti/semi HashJoinExec with SortMergeJoinExec"
+ );
+
+ Ok(Some(Arc::new(join)))
+}
+
+fn spillable_rewrite_build_input_exact_rows(hash_join: &HashJoinExec) -> Option {
+ // `HashJoinExec` materializes the LEFT input as the (non-spillable) build
+ // hash table regardless of join type — including `*Anti` and `*Semi`.
+ let build_input = match hash_join.join_type() {
+ JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi | JoinType::RightSemi => {
+ hash_join.left()
+ }
+ _ => return None,
+ };
+
+ match build_input.partition_statistics(None).ok()?.num_rows {
+ Precision::Exact(row_count) => Some(row_count),
+ Precision::Inexact(_) | Precision::Absent => None,
+ }
+}
+
+fn join_key_ordering(
+ keys: impl Iterator- >,
+ sort_options: &[SortOptions],
+) -> Option
{
+ let sort_exprs = keys
+ .zip(sort_options.iter().copied())
+ .map(|(expr, options)| PhysicalSortExpr { expr, options })
+ .collect::>();
+
+ LexOrdering::new(sort_exprs)
+}
+
+fn has_single_same_source_pair_for_all_join_keys(hash_join: &HashJoinExec) -> bool {
+ let left_scans = collect_cayenne_scans(hash_join.left());
+ let right_scans = collect_cayenne_scans(hash_join.right());
+ if left_scans.is_empty() || right_scans.is_empty() {
+ return false;
+ }
+ let right_scans_by_identity = scans_by_identity(&right_scans);
+
+ let mut matched_pair = None;
+ for (left_key, right_key) in hash_join.on() {
+ let Some(left_column) = physical_column_name(left_key) else {
+ return false;
+ };
+ let Some(right_column) = physical_column_name(right_key) else {
+ return false;
+ };
+
+ if left_column != right_column {
+ return false;
+ }
+
+ let pairs = same_source_pairs_for_column(
+ &left_scans,
+ &right_scans,
+ &right_scans_by_identity,
+ left_column,
+ right_column,
+ );
+ let [(left_index, right_index)] = pairs.as_slice() else {
+ return false;
+ };
+ let pair = (*left_index, *right_index);
+
+ if matched_pair.is_some_and(|previous_pair| previous_pair != pair) {
+ return false;
+ }
+ matched_pair = Some(pair);
+ }
+
+ matched_pair.is_some()
+}
+
+fn collect_cayenne_scans(plan: &Arc) -> Vec {
+ let mut scans = Vec::new();
+ collect_cayenne_scans_inner(plan, &mut scans);
+ scans
+}
+
+fn collect_cayenne_scans_inner(plan: &Arc, scans: &mut Vec) {
+ if let Some(cayenne) = plan.as_any().downcast_ref::()
+ && let Some(identity) = cayenne.scan_identity()
+ {
+ let schema_fields = plan_schema_fields(&cayenne.schema());
+ let columns = schema_fields.iter().map(|(name, _)| name.clone()).collect();
+ scans.push(CayenneScanSummary {
+ identity,
+ columns,
+ schema_fields,
+ dynamic_filters: cayenne.dynamic_filters(),
+ });
+ return;
+ }
+
+ for child in plan.children() {
+ collect_cayenne_scans_inner(child, scans);
+ }
+}
+
+fn physical_column_name(expr: &Arc) -> Option<&str> {
+ expr.as_any().downcast_ref::().map(Column::name)
+}
+
+fn scans_by_identity(scans: &[CayenneScanSummary]) -> HashMap, Vec> {
+ let mut by_identity: HashMap, Vec> = HashMap::new();
+ for (index, scan) in scans.iter().enumerate() {
+ by_identity
+ .entry(Arc::clone(&scan.identity))
+ .or_default()
+ .push(index);
+ }
+ by_identity
+}
+
+fn same_source_pairs_for_column(
+ left_scans: &[CayenneScanSummary],
+ right_scans: &[CayenneScanSummary],
+ right_scans_by_identity: &HashMap, Vec>,
+ left_column: &str,
+ right_column: &str,
+) -> Vec<(usize, usize)> {
+ let mut pairs = Vec::new();
+
+ for (left_index, left_scan) in left_scans.iter().enumerate() {
+ if !left_scan.columns.contains(left_column) {
+ continue;
+ }
+
+ let Some(right_indices) = right_scans_by_identity.get(&left_scan.identity) else {
+ continue;
+ };
+
+ for &right_index in right_indices {
+ if right_scans[right_index].columns.contains(right_column) {
+ pairs.push((left_index, right_index));
+ }
+ }
+ }
+
+ pairs
+}
+
+fn push_filter_addition(
+ additions: &mut Vec,
+ identity: Arc,
+ schema_fields: Vec<(String, DataType)>,
+ filter: Arc,
+) {
+ if additions.iter().any(|addition| {
+ addition.identity == identity
+ && addition.schema_fields == schema_fields
+ && Arc::ptr_eq(&addition.filter, &filter)
+ }) {
+ return;
+ }
+
+ additions.push(FilterAddition {
+ identity,
+ schema_fields,
+ filter,
+ });
+}
+
+fn plan_schema_fields(schema: &SchemaRef) -> Vec<(String, DataType)> {
+ schema
+ .fields()
+ .iter()
+ .map(|field| (field.name().clone(), field.data_type().clone()))
+ .collect()
+}
+
+fn apply_filter_additions(
+ plan: Arc,
+ additions: &[FilterAddition],
+ config: &ConfigOptions,
+) -> Result<(Arc, bool), DataFusionError> {
+ if additions.is_empty() {
+ return Ok((plan, false));
+ }
+
+ if let Some(cayenne) = plan.as_any().downcast_ref::() {
+ let Some(identity) = cayenne.scan_identity() else {
+ return Ok((plan, false));
+ };
+ let schema_fields = plan_schema_fields(&cayenne.schema());
+ let existing = cayenne.dynamic_filters();
+ let filters = additions
+ .iter()
+ .filter(|addition| addition.identity == identity)
+ .filter(|addition| addition.schema_fields == schema_fields)
+ .filter(|addition| {
+ !existing
+ .iter()
+ .any(|filter| Arc::ptr_eq(filter.filter(), &addition.filter))
+ })
+ .map(|addition| Arc::clone(&addition.filter))
+ .collect::>();
+
+ let Some(new_plan) = cayenne.with_additional_dynamic_filters(&filters, config)? else {
+ return Ok((plan, false));
+ };
+
+ return Ok((new_plan, true));
+ }
+
+ let children = plan
+ .children()
+ .into_iter()
+ .map(Arc::clone)
+ .collect::>();
+ if children.is_empty() {
+ return Ok((plan, false));
+ }
+
+ let mut changed = false;
+ let mut new_children = Vec::with_capacity(children.len());
+ for child in children {
+ let (new_child, child_changed) = apply_filter_additions(child, additions, config)?;
+ changed |= child_changed;
+ new_children.push(new_child);
+ }
+
+ if !changed {
+ return Ok((plan, false));
+ }
+
+ plan.with_new_children(new_children)
+ .map(|plan| (plan, true))
+}
+
impl std::fmt::Debug for CayenneJoinRewriter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CayenneJoinRewriter").finish()
@@ -195,20 +805,135 @@ impl PhysicalOptimizerRule for CayenneJoinRewriter {
#[cfg(test)]
mod tests {
- use super::CayenneJoinRewriter;
+ use super::{
+ ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS, CayenneAntiJoinSortMergeRewriter,
+ CayenneDynamicFilterSharing, CayenneJoinRewriter, FilterAddition, apply_filter_additions,
+ plan_schema_fields,
+ };
use crate::provider::CayenneAccelerationExec;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{JoinType, NullEquality};
use datafusion::config::ConfigOptions;
use datafusion::datasource::memory::MemorySourceConfig;
+ use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
- use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
+ use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
use datafusion::physical_plan::projection::ProjectionExec;
+ use datafusion::physical_plan::sorts::sort::SortExec;
+ use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::{ExecutionPlan, displayable};
- use datafusion_physical_expr::expressions::col;
+ use datafusion_common::stats::Precision;
+ use datafusion_common::{DataFusionError, Result as DFResult, Statistics};
+ use datafusion_datasource::file::FileSource;
+ use datafusion_datasource::file_groups::FileGroup;
+ use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
+ use datafusion_datasource::file_stream::FileOpener;
+ use datafusion_datasource::source::DataSourceExec;
+ use datafusion_datasource::{PartitionedFile, TableSchema};
+ use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, col, lit};
+ use datafusion_physical_expr::projection::ProjectionExprs;
+ use datafusion_physical_expr::{PhysicalExpr, conjunction};
+ use datafusion_physical_plan::DisplayFormatType;
+ use datafusion_physical_plan::filter_pushdown::{FilterPushdownPropagation, PushedDown};
+ use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;
+ use object_store::ObjectMeta;
+ use object_store::ObjectStore;
+ use object_store::path::Path;
use runtime_datafusion::join_accumulator::ExactLeftAccumulator;
+ use std::any::Any;
use std::sync::Arc;
+ #[derive(Clone)]
+ struct TestFileSource {
+ table_schema: TableSchema,
+ filter: Option>,
+ metrics: ExecutionPlanMetricsSet,
+ }
+
+ impl TestFileSource {
+ fn new(table_schema: TableSchema, filter: Option>) -> Self {
+ Self {
+ table_schema,
+ filter,
+ metrics: ExecutionPlanMetricsSet::new(),
+ }
+ }
+ }
+
+ impl FileSource for TestFileSource {
+ fn create_file_opener(
+ &self,
+ _object_store: Arc,
+ _base_config: &datafusion_datasource::file_scan_config::FileScanConfig,
+ _partition: usize,
+ ) -> DFResult> {
+ Err(DataFusionError::NotImplemented(
+ "test source cannot open files".to_string(),
+ ))
+ }
+
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn table_schema(&self) -> &TableSchema {
+ &self.table_schema
+ }
+
+ fn with_batch_size(&self, _batch_size: usize) -> Arc {
+ Arc::new(self.clone())
+ }
+
+ fn filter(&self) -> Option> {
+ self.filter.clone()
+ }
+
+ fn projection(&self) -> Option<&ProjectionExprs> {
+ None
+ }
+
+ fn metrics(&self) -> &ExecutionPlanMetricsSet {
+ &self.metrics
+ }
+
+ fn file_type(&self) -> &'static str {
+ "test"
+ }
+
+ fn fmt_extra(
+ &self,
+ _t: DisplayFormatType,
+ _f: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ Ok(())
+ }
+
+ fn try_pushdown_filters(
+ &self,
+ filters: Vec>,
+ _config: &ConfigOptions,
+ ) -> DFResult>> {
+ let filter_count = filters.len();
+ let filter = match &self.filter {
+ Some(existing) => Some(conjunction(
+ std::iter::once(Arc::clone(existing)).chain(filters),
+ )),
+ None => Some(conjunction(filters)),
+ };
+ let source = Self {
+ table_schema: self.table_schema.clone(),
+ filter,
+ metrics: ExecutionPlanMetricsSet::new(),
+ };
+
+ Ok(FilterPushdownPropagation::with_parent_pushdown_result(vec![
+ PushedDown::Yes;
+ filter_count
+ ])
+ .with_updated_node(Arc::new(source)))
+ }
+ }
+
fn memory_exec(column_name: &str) -> Arc {
let schema = Arc::new(Schema::new(vec![Field::new(
column_name,
@@ -219,6 +944,118 @@ mod tests {
.expect("memory exec should be valid")
}
+ fn file_exec(
+ schema: &Arc,
+ path: &str,
+ filter: Option>,
+ ) -> Arc {
+ file_exec_with_statistics(schema, path, filter, Statistics::new_unknown(schema))
+ }
+
+ fn file_exec_with_statistics(
+ schema: &Arc,
+ path: &str,
+ filter: Option>,
+ statistics: Statistics,
+ ) -> Arc {
+ let table_schema = TableSchema::new(Arc::clone(schema), Vec::new());
+ let source = Arc::new(TestFileSource::new(table_schema, filter));
+ let file = PartitionedFile::from(ObjectMeta {
+ location: Path::from(path),
+ last_modified: chrono::DateTime::UNIX_EPOCH,
+ size: 1_024,
+ e_tag: None,
+ version: None,
+ });
+ let config = FileScanConfigBuilder::new(
+ ObjectStoreUrl::parse("file:///").expect("object store url should parse"),
+ source,
+ )
+ .with_file_group(FileGroup::new(vec![file]))
+ .with_statistics(statistics)
+ .build();
+
+ DataSourceExec::from_data_source(config) as Arc
+ }
+
+ fn cayenne_file_exec(
+ schema: &Arc,
+ path: &str,
+ filter: Option>,
+ ) -> Arc {
+ Arc::new(CayenneAccelerationExec::new(file_exec(
+ schema, path, filter,
+ )))
+ }
+
+ fn inlined_exec(schema: &Arc) -> Arc {
+ MemorySourceConfig::try_new_exec(&[vec![]], Arc::clone(schema), None)
+ .expect("inlined memory exec should be valid")
+ }
+
+ fn cayenne_file_with_inlined_exec(
+ schema: &Arc,
+ path: &str,
+ filter: Option>,
+ ) -> Arc {
+ Arc::new(CayenneAccelerationExec::new(
+ UnionExec::try_new(vec![file_exec(schema, path, filter), inlined_exec(schema)])
+ .expect("mixed file and inlined union should be valid"),
+ ))
+ }
+
+ fn cayenne_file_exec_with_num_rows(
+ schema: &Arc,
+ path: &str,
+ row_count: Precision,
+ ) -> Arc {
+ Arc::new(CayenneAccelerationExec::new(file_exec_with_statistics(
+ schema,
+ path,
+ None,
+ Statistics::new_unknown(schema).with_num_rows(row_count),
+ )))
+ }
+
+ fn large_exact_cayenne_file_exec(schema: &Arc, path: &str) -> Arc {
+ cayenne_file_exec_with_num_rows(
+ schema,
+ path,
+ Precision::Exact(ANTI_JOIN_SORT_MERGE_MIN_EXACT_ROWS + 1),
+ )
+ }
+
+ fn order_line_schema() -> Arc {
+ Arc::new(Schema::new(vec![
+ Field::new("order_id", DataType::Int64, false),
+ Field::new("warehouse_id", DataType::Int64, false),
+ Field::new("line_number", DataType::Int64, false),
+ ]))
+ }
+
+ fn reordered_order_line_schema() -> Arc {
+ Arc::new(Schema::new(vec![
+ Field::new("warehouse_id", DataType::Int64, false),
+ Field::new("order_id", DataType::Int64, false),
+ Field::new("line_number", DataType::Int64, false),
+ ]))
+ }
+
+ fn order_line_schema_with_different_non_key_type() -> Arc {
+ Arc::new(Schema::new(vec![
+ Field::new("order_id", DataType::Int64, false),
+ Field::new("warehouse_id", DataType::Int64, false),
+ Field::new("line_number", DataType::UInt64, false),
+ ]))
+ }
+
+ fn dynamic_filter_for(column_name: &str, schema: &Schema) -> Arc {
+ Arc::new(DynamicFilterPhysicalExpr::new(
+ vec![col(column_name, schema).expect("filter column should exist")],
+ lit(true),
+ ))
+ }
+
fn hash_join(
left: Arc