diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..3463869 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + groups: + rust-dependencies: + patterns: + - "*" + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/src/physical_plan/exec/fetch.rs b/src/physical_plan/exec/fetch.rs index 71d668b..c4f776e 100644 --- a/src/physical_plan/exec/fetch.rs +++ b/src/physical_plan/exec/fetch.rs @@ -40,10 +40,11 @@ use datafusion::physical_plan::{ use futures::stream::{Stream, StreamExt}; use crate::physical_plan::exec::index::IndexScanExec; +use crate::physical_plan::exec::sequential_union::SequentialUnionExec; use crate::physical_plan::fetcher::RecordFetcher; use crate::physical_plan::joins::try_create_index_lookup_join; use crate::physical_plan::{create_index_schema, ROW_ID_COLUMN_NAME}; -use crate::types::{IndexFilter, IndexFilters}; +use crate::types::{IndexFilter, IndexFilters, UnionMode}; use datafusion::arrow::datatypes::Schema; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use datafusion::physical_plan::empty::EmptyExec; @@ -68,15 +69,25 @@ pub struct RecordFetchExec { input: Arc, metrics: ExecutionPlanMetricsSet, schema: SchemaRef, + /// Controls how union operations are executed for OR conditions. + union_mode: UnionMode, } impl RecordFetchExec { /// Create a new `RecordFetchExec` plan. 
+ /// + /// # Arguments + /// * `indexes` - Index filters to use for scanning + /// * `limit` - Optional limit on the number of rows + /// * `record_fetcher` - The fetcher to retrieve records by row ID + /// * `schema` - Output schema + /// * `union_mode` - Controls whether OR conditions use parallel or sequential union pub fn try_new( indexes: Vec, limit: Option, record_fetcher: Arc, schema: SchemaRef, + union_mode: UnionMode, ) -> Result { if indexes.is_empty() { return Err(DataFusionError::Plan( @@ -91,7 +102,7 @@ impl RecordFetchExec { } let input = match indexes.first() { - Some(index_filter) => Self::build_scan_exec(index_filter, limit)?, + Some(index_filter) => Self::build_scan_exec(index_filter, limit, union_mode)?, None => { return Err(DataFusionError::Plan( "RecordFetchExec requires at least one index".to_string(), @@ -114,6 +125,7 @@ impl RecordFetchExec { input, metrics: ExecutionPlanMetricsSet::new(), schema, + union_mode, }) } @@ -179,6 +191,7 @@ impl RecordFetchExec { /// # Arguments /// * `index_filter` - The [`IndexFilter`] tree specifying which indexes to scan and how to combine them /// * `limit` - Optional limit on the number of rows to return, passed through to individual index scans + /// * `union_mode` - Controls whether OR conditions use parallel or sequential union /// /// # Returns /// An [`Arc`] that produces a stream of row IDs matching the filter criteria. 
@@ -192,6 +205,7 @@ impl RecordFetchExec { fn build_scan_exec( index_filter: &IndexFilter, limit: Option, + union_mode: UnionMode, ) -> Result> { match index_filter { IndexFilter::Single { index, filter } => { @@ -217,7 +231,7 @@ impl RecordFetchExec { IndexFilter::And(filters) => { let mut plans = filters .iter() - .map(|f| Self::build_scan_exec(f, limit)) + .map(|f| Self::build_scan_exec(f, limit, union_mode)) .collect::>>()?; if plans.is_empty() { @@ -236,7 +250,7 @@ impl RecordFetchExec { IndexFilter::Or(filters) => { let original_plans = filters .iter() - .map(|f| Self::build_scan_exec(f, limit)) + .map(|f| Self::build_scan_exec(f, limit, union_mode)) .collect::>>()?; if original_plans.is_empty() { @@ -284,8 +298,13 @@ impl RecordFetchExec { } } - // Now all plans have identical schemas, UnionExec will work - let union_input = UnionExec::try_new(normalized_plans)?; + // Now all plans have identical schemas - create union based on mode + let union_input: Arc = match union_mode { + UnionMode::Parallel => UnionExec::try_new(normalized_plans)?, + UnionMode::Sequential => { + Arc::new(SequentialUnionExec::try_new(normalized_plans)?) 
+ } + }; // Create aggregate to deduplicate row IDs let group_expr = PhysicalGroupBy::new_single(vec![( @@ -377,6 +396,7 @@ impl ExecutionPlan for RecordFetchExec { input: children[0].clone(), metrics: self.metrics.clone(), schema: self.schema.clone(), + union_mode: self.union_mode, })) } @@ -949,8 +969,14 @@ mod tests { #[tokio::test] async fn test_record_fetch_exec_no_indexes() { let fetcher = Arc::new(MockRecordFetcher::new()); - let err = - RecordFetchExec::try_new(vec![], None, fetcher, Arc::new(Schema::empty())).unwrap_err(); + let err = RecordFetchExec::try_new( + vec![], + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + ) + .unwrap_err(); assert!( matches!(err, DataFusionError::Plan(ref msg) if msg == "RecordFetchExec requires at least one index"), "Unexpected error: {err:?}" @@ -970,7 +996,13 @@ mod tests { }]; let fetcher = Arc::new(MockRecordFetcher::new()); - let exec = RecordFetchExec::try_new(indexes, None, fetcher, Arc::new(Schema::empty()))?; + let exec = RecordFetchExec::try_new( + indexes, + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + )?; // The input plan should be just the IndexScanExec assert_eq!(exec.input.name(), "IndexScanExec"); @@ -1004,7 +1036,13 @@ mod tests { ])]; let fetcher = Arc::new(MockRecordFetcher::new()); - let exec = RecordFetchExec::try_new(indexes, None, fetcher, Arc::new(Schema::empty()))?; + let exec = RecordFetchExec::try_new( + indexes, + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + )?; // The input plan should be a HashJoinExec assert_eq!(exec.input.name(), "HashJoinExec"); @@ -1027,7 +1065,13 @@ mod tests { let indexes = vec![IndexFilter::And(indexes_vec)]; let fetcher = Arc::new(MockRecordFetcher::new()); - let exec = RecordFetchExec::try_new(indexes, None, fetcher, Arc::new(Schema::empty()))?; + let exec = RecordFetchExec::try_new( + indexes, + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + )?; // The input plan should 
be a tree of HashJoinExecs assert_eq!(exec.input.name(), "HashJoinExec"); @@ -1063,7 +1107,8 @@ mod tests { let schema = fetcher.schema(); // 2. Create exec plan - let exec = RecordFetchExec::try_new(indexes, None, fetcher, schema.clone())?; + let exec = + RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?; // 3. Execute and collect results let task_ctx = Arc::new(TaskContext::default()); @@ -1100,7 +1145,13 @@ mod tests { let fetcher = Arc::new(MockRecordFetcher::new().with_data()); // 2. Create exec plan - let exec = RecordFetchExec::try_new(indexes, None, fetcher, Arc::new(Schema::empty()))?; + let exec = RecordFetchExec::try_new( + indexes, + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + )?; // 3. Execute and collect results let task_ctx = Arc::new(TaskContext::default()); @@ -1136,7 +1187,8 @@ mod tests { let schema = fetcher.schema(); // 2. Create exec plan - let exec = RecordFetchExec::try_new(indexes, None, fetcher, schema.clone())?; + let exec = + RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?; // 3. Execute and collect results let task_ctx = Arc::new(TaskContext::default()); @@ -1193,7 +1245,13 @@ mod tests { let fetcher = Arc::new(ErrorFetcher); // 2. Create exec plan - let exec = RecordFetchExec::try_new(indexes, None, fetcher, Arc::new(Schema::empty()))?; + let exec = RecordFetchExec::try_new( + indexes, + None, + fetcher, + Arc::new(Schema::empty()), + UnionMode::Parallel, + )?; // 3. Execute and expect an error let task_ctx = Arc::new(TaskContext::default()); diff --git a/src/physical_plan/exec/mod.rs b/src/physical_plan/exec/mod.rs index bd124b6..f1b75e1 100644 --- a/src/physical_plan/exec/mod.rs +++ b/src/physical_plan/exec/mod.rs @@ -18,3 +18,4 @@ //! Physical `ExecutionPlan` operators. 
pub mod fetch; pub mod index; +pub mod sequential_union; diff --git a/src/physical_plan/exec/sequential_union.rs b/src/physical_plan/exec/sequential_union.rs new file mode 100644 index 0000000..64a374e --- /dev/null +++ b/src/physical_plan/exec/sequential_union.rs @@ -0,0 +1,372 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sequential union execution plan that processes inputs without spawning tasks. +//! +//! This module provides [`SequentialUnionExec`], an alternative to DataFusion's +//! [`UnionExec`](datafusion::physical_plan::union::UnionExec) that reports a single partition +//! and processes all input partitions sequentially. This avoids the task spawning that occurs +//! when `CoalescePartitionsExec` is inserted for multi-partition plans.
+ +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::common::Result; +use datafusion::error::DataFusionError; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::calculate_union; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, +}; +use futures::Stream; + +/// A union execution plan that processes all inputs sequentially in a single partition. +/// +/// Unlike DataFusion's [`UnionExec`](datafusion::physical_plan::union::UnionExec), which exposes +/// every input partition (so `CoalescePartitionsExec` is inserted and spawns Tokio tasks), this +/// operator always reports exactly 1 partition and chains all input partition streams sequentially. +/// +/// ## Use Case +/// This operator is required for non-Tokio async runtimes (e.g., custom async executors) +/// where task spawning via `JoinSet::spawn()` would panic. +/// +/// ## Behavior +/// - Reports 1 output partition regardless of input partition counts +/// - Processes all partitions from all inputs sequentially in order +/// - Validates that all inputs have identical schemas +/// - Computes common ordering properties across all inputs +#[derive(Debug)] +pub struct SequentialUnionExec { + /// Input execution plans to union + inputs: Vec>, + /// Schema of the output (validated to match all inputs) + schema: SchemaRef, + /// Cached plan properties + properties: PlanProperties, +} + +impl SequentialUnionExec { + /// Creates a new `SequentialUnionExec` with the given inputs. + /// + /// # Arguments + /// * `inputs` - Execution plans to union. Must be non-empty and have identical schemas.
+ /// + /// # Returns + /// A new `SequentialUnionExec` or an error if: + /// - `inputs` is empty + /// - Input schemas are incompatible + pub fn try_new(inputs: Vec>) -> Result { + if inputs.is_empty() { + return Err(DataFusionError::Plan( + "SequentialUnionExec requires at least one input".to_string(), + )); + } + + let schema = inputs[0].schema(); + + // Validate all schemas match + for (i, input) in inputs.iter().enumerate().skip(1) { + let input_schema = input.schema(); + if input_schema != schema { + return Err(DataFusionError::Plan(format!( + "SequentialUnionExec schema mismatch: input 0 has schema {schema:?}, \ + but input {i} has schema {input_schema:?}" + ))); + } + } + + // Compute common equivalence properties across all inputs + let children_eqps: Vec<_> = inputs + .iter() + .map(|p| p.properties().equivalence_properties().clone()) + .collect(); + let eq_properties = calculate_union(children_eqps, Arc::clone(&schema))?; + + let properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), // KEY: Always 1 partition + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Ok(Self { + inputs, + schema, + properties, + }) + } +} + +impl DisplayAs for SequentialUnionExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SequentialUnionExec") + } + DisplayFormatType::TreeRender => { + write!(f, "SequentialUnionExec") + } + } + } +} + +impl ExecutionPlan for SequentialUnionExec { + fn name(&self) -> &str { + "SequentialUnionExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::try_new(children)?)) + } + + fn required_input_distribution(&self) -> Vec { + // 
Require single partition inputs to prevent the optimizer from inserting + // RepartitionExec which would create multi-partition streams that deadlock + // when polled sequentially. + vec![Distribution::SinglePartition; self.inputs.len()] + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "SequentialUnionExec only supports partition 0, got {partition}" + ))); + } + + // Create all streams upfront (follows DataFusion's InterleaveExec pattern) + let mut streams = Vec::new(); + for input in &self.inputs { + let partition_count = input.output_partitioning().partition_count(); + for p in 0..partition_count { + streams.push(input.execute(p, Arc::clone(&context))?); + } + } + + Ok(Box::pin(SequentialUnionStream { + streams, + current_index: 0, + schema: Arc::clone(&self.schema), + })) + } +} + +/// Stream that sequentially processes partitions from multiple inputs. +struct SequentialUnionStream { + /// All input streams to process in order (created upfront) + streams: Vec, + /// Index of the current stream being polled + current_index: usize, + /// Output schema + schema: SchemaRef, +} + +impl RecordBatchStream for SequentialUnionStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Stream for SequentialUnionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let idx = self.current_index; + if idx >= self.streams.len() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.streams[idx]).poll_next(cx) { + Poll::Ready(Some(batch)) => return Poll::Ready(Some(batch)), + Poll::Ready(None) => self.current_index += 1, + Poll::Pending => return Poll::Pending, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int64Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use 
datafusion::datasource::memory::MemorySourceConfig; + use futures::StreamExt; + + fn create_test_batch(values: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap() + } + + fn create_memory_exec(batches: Vec) -> Arc { + let schema = batches[0].schema(); + MemorySourceConfig::try_new_exec(&[batches], schema, None).unwrap() + } + + #[tokio::test] + async fn test_sequential_union_single_input() -> Result<()> { + let batch = create_test_batch(vec![1, 2, 3]); + let input = create_memory_exec(vec![batch.clone()]); + + let union = SequentialUnionExec::try_new(vec![input])?; + + assert_eq!( + union.properties().output_partitioning().partition_count(), + 1 + ); + + let ctx = Arc::new(TaskContext::default()); + let mut stream = union.execute(0, ctx)?; + + let result = stream.next().await.unwrap()?; + assert_eq!(result.num_rows(), 3); + + assert!(stream.next().await.is_none()); + Ok(()) + } + + #[tokio::test] + async fn test_sequential_union_multiple_inputs() -> Result<()> { + let batch1 = create_test_batch(vec![1, 2]); + let batch2 = create_test_batch(vec![3, 4]); + let batch3 = create_test_batch(vec![5, 6]); + + let input1 = create_memory_exec(vec![batch1]); + let input2 = create_memory_exec(vec![batch2]); + let input3 = create_memory_exec(vec![batch3]); + + let union = SequentialUnionExec::try_new(vec![input1, input2, input3])?; + + assert_eq!( + union.properties().output_partitioning().partition_count(), + 1 + ); + + let ctx = Arc::new(TaskContext::default()); + let mut stream = union.execute(0, ctx)?; + + // Should get all batches in order + let r1 = stream.next().await.unwrap()?; + assert_eq!( + r1.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[1, 2] + ); + + let r2 = stream.next().await.unwrap()?; + assert_eq!( + r2.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[3, 4] + ); + + let r3 = 
stream.next().await.unwrap()?; + assert_eq!( + r3.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[5, 6] + ); + + assert!(stream.next().await.is_none()); + Ok(()) + } + + #[test] + fn test_sequential_union_empty_inputs_error() { + let result = SequentialUnionExec::try_new(vec![]); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("at least one input")); + } + + #[test] + fn test_sequential_union_schema_mismatch_error() { + let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let schema2 = Arc::new(Schema::new(vec![Field::new("b", DataType::Int64, false)])); + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema1), + vec![Arc::new(Int64Array::from(vec![1]))], + ) + .unwrap(); + let batch2 = RecordBatch::try_new( + Arc::clone(&schema2), + vec![Arc::new(Int64Array::from(vec![2]))], + ) + .unwrap(); + + let input1 = MemorySourceConfig::try_new_exec(&[vec![batch1]], schema1, None).unwrap(); + let input2 = MemorySourceConfig::try_new_exec(&[vec![batch2]], schema2, None).unwrap(); + + let result = SequentialUnionExec::try_new(vec![input1, input2]); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("schema mismatch")); + } + + #[tokio::test] + async fn test_sequential_union_invalid_partition() { + let batch = create_test_batch(vec![1]); + let input = create_memory_exec(vec![batch]); + let union = SequentialUnionExec::try_new(vec![input]).unwrap(); + + let ctx = Arc::new(TaskContext::default()); + let result = union.execute(1, ctx); + match result { + Ok(_) => panic!("Expected error for invalid partition"), + Err(e) => assert!( + e.to_string().contains("only supports partition 0"), + "Unexpected error: {e}" + ), + } + } +} diff --git a/src/provider.rs b/src/provider.rs index ffaeced..a4d0bd4 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -18,7 +18,7 @@ use crate::physical_plan::exec::fetch::RecordFetchExec; use 
crate::physical_plan::fetcher::RecordFetcher; use crate::physical_plan::Index; -use crate::types::{IndexFilter, IndexFilters}; +use crate::types::{IndexFilter, IndexFilters, UnionMode}; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::catalog::TableProvider; @@ -122,6 +122,20 @@ pub trait IndexedTableProvider: TableProvider + Sync + Send { Ok(None) } + /// Returns the union mode to use for OR conditions in index scans. + /// + /// # Default implementation + /// Returns [`UnionMode::Parallel`], which uses DataFusion's standard `UnionExec` + /// and may spawn Tokio tasks for parallel execution. + /// + /// # When to override + /// Override this method to return [`UnionMode::Sequential`] if your runtime + /// does not support Tokio task spawning (e.g., custom async executors or + /// single-threaded runtimes). + fn union_mode(&self) -> UnionMode { + UnionMode::default() + } + /// Returns whether the filters can be pushed down to the index. /// This method can be used in `TableProvider::supports_filters_pushdown`. 
/// @@ -175,6 +189,7 @@ pub trait IndexedTableProvider: TableProvider + Sync + Send { limit, Arc::clone(&mapper), schema.clone(), + self.union_mode(), )?)) } } @@ -188,10 +203,7 @@ mod tests { use datafusion::catalog::Session; use datafusion::common::Statistics; use datafusion::datasource::TableType; - use datafusion::execution::TaskContext; - use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, - }; + use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::{col, lit}; use datafusion_common::{DataFusionError, Result}; use std::any::Any; @@ -241,49 +253,6 @@ mod tests { } } - // Mock ExecutionPlan - #[derive(Debug)] - struct MockExec; - - impl DisplayAs for MockExec { - fn fmt_as(&self, _t: DisplayFormatType, _f: &mut std::fmt::Formatter) -> std::fmt::Result { - unimplemented!() - } - } - - impl ExecutionPlan for MockExec { - fn name(&self) -> &str { - "MockExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - unimplemented!() - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _children: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unimplemented!() - } - } - // Mock IndexedTableProvider #[derive(Debug)] struct MockTableProvider { diff --git a/src/types.rs b/src/types.rs index e464978..ca06a3e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -74,6 +74,30 @@ pub enum IndexFilter { /// serves as input to `create_execution_plan_with_indexes()` for physical plan generation. pub type IndexFilters = Vec; +/// Controls how union operations combine multiple index scans for OR conditions. +/// +/// When executing disjunctive (OR) queries across multiple indexes, results must be +/// combined. 
This enum allows choosing between parallel and sequential execution +/// strategies based on runtime requirements. +#[derive(Debug, Clone, Copy, Default)] +pub enum UnionMode { + /// Use standard `UnionExec` with parallel execution. + /// + /// This mode spawns Tokio tasks via `JoinSet::spawn()` to process partitions + /// concurrently. Best for Tokio-based runtimes with good parallelism support. + /// + /// **Warning**: This mode will panic in non-Tokio async runtimes. + #[default] + Parallel, + + /// Use `SequentialUnionExec` with single-threaded sequential execution. + /// + /// This mode processes all input partitions sequentially in a single stream + /// without spawning any tasks. Required for non-Tokio runtimes such as + /// custom async executors that don't support Tokio task spawning. + Sequential, +} + impl fmt::Display for IndexFilter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/tests/common/employee_provider.rs b/tests/common/employee_provider.rs index d21988c..7515726 100644 --- a/tests/common/employee_provider.rs +++ b/tests/common/employee_provider.rs @@ -13,7 +13,7 @@ use datafusion_common::DataFusionError; use datafusion_index_provider::physical_plan::exec::fetch::RecordFetchExec; use datafusion_index_provider::physical_plan::Index; use datafusion_index_provider::provider::IndexedTableProvider; -use datafusion_index_provider::types::IndexFilter; +use datafusion_index_provider::types::{IndexFilter, UnionMode}; use crate::common::age_index::AgeIndex; use crate::common::department_index::DepartmentIndex; @@ -26,6 +26,7 @@ pub struct EmployeeTableProvider { age_index: Arc, department_index: Arc, mapper: Arc, + union_mode: UnionMode, } impl Default for EmployeeTableProvider { @@ -70,8 +71,15 @@ impl EmployeeTableProvider { age_index: Arc::new(AgeIndex::new(&age_array, &id_array)), department_index: Arc::new(DepartmentIndex::new(&department_array, &id_array)), mapper: Arc::new(BatchMapper::new(vec![batch])), + 
union_mode: UnionMode::Parallel, } } + + /// Set the union mode for OR condition handling. + pub fn with_union_mode(mut self, mode: UnionMode) -> Self { + self.union_mode = mode; + self + } } #[async_trait] @@ -126,6 +134,10 @@ impl IndexedTableProvider for EmployeeTableProvider { fn indexes(&self) -> Result>, DataFusionError> { Ok(vec![self.age_index.clone(), self.department_index.clone()]) } + + fn union_mode(&self) -> UnionMode { + self.union_mode + } } impl EmployeeTableProvider { @@ -142,6 +154,7 @@ impl EmployeeTableProvider { limit, self.mapper.clone(), self.schema.clone(), + self.union_mode, )?)) } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index a62c66e..eff4b50 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,12 +5,23 @@ pub mod record_fetcher; use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; use datafusion::execution::context::SessionContext; +use datafusion_index_provider::types::UnionMode; use employee_provider::EmployeeTableProvider; use std::collections::HashSet; use std::sync::Arc; -/// Helper function to setup test environment +/// Helper function to setup test environment with parallel union mode (default). pub async fn setup_test_env() -> SessionContext { + setup_test_env_with_mode(UnionMode::Parallel).await +} + +/// Helper function to setup test environment with sequential union mode. +pub async fn setup_test_env_sequential() -> SessionContext { + setup_test_env_with_mode(UnionMode::Sequential).await +} + +/// Helper function to setup test environment with a specific union mode. 
+async fn setup_test_env_with_mode(mode: UnionMode) -> SessionContext { let _ = env_logger::builder() .filter_level(log::LevelFilter::Debug) .is_test(true) @@ -18,7 +29,7 @@ pub async fn setup_test_env() -> SessionContext { let ctx = SessionContext::new(); - let provider = EmployeeTableProvider::default(); + let provider = EmployeeTableProvider::new().with_union_mode(mode); ctx.register_table("employees", Arc::new(provider)).unwrap(); ctx diff --git a/tests/common/record_fetcher.rs b/tests/common/record_fetcher.rs index c58f870..b06ab47 100644 --- a/tests/common/record_fetcher.rs +++ b/tests/common/record_fetcher.rs @@ -28,7 +28,10 @@ impl fmt::Debug for BatchMapper { #[async_trait] impl RecordFetcher for BatchMapper { fn schema(&self) -> SchemaRef { - self.batches[0].schema() + self.batches + .first() + .expect("BatchMapper requires at least one batch") + .schema() } async fn fetch(&self, index_batch: RecordBatch) -> Result { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 822ebb4..65df8b8 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,6 +1,8 @@ mod common; use crate::common::assert_names; -use common::{assert_ages, assert_departments, extract_names, setup_test_env}; +use common::{ + assert_ages, assert_departments, extract_names, setup_test_env, setup_test_env_sequential, +}; // +----+-------+--------+----------+ // | id | name | age | department| @@ -552,3 +554,102 @@ async fn test_employee_table_filter_all_simple_or_conditions() { assert_names(&results, &["Alice", "Bob", "David", "Eve"]); assert_ages(&results, &[25, 30, 28, 32]); } + +// ============================================================================= +// Sequential Union Mode Tests +// ============================================================================= +// These tests verify that UnionMode::Sequential works correctly by using +// SequentialUnionExec instead of the parallel UnionExec for OR conditions. 
+ +#[tokio::test] +async fn test_sequential_union_mode_simple_or() { + let ctx = setup_test_env_sequential().await; + + // Simple OR query that triggers SequentialUnionExec + let df = ctx + .sql("SELECT name, age FROM employees WHERE age = 25 OR age = 35") + .await + .unwrap(); + let results = df.collect().await.unwrap(); + + assert_names(&results, &["Alice", "Charlie"]); + assert_ages(&results, &[25, 35]); +} + +#[tokio::test] +async fn test_sequential_union_mode_multiple_or() { + let ctx = setup_test_env_sequential().await; + + // Multiple OR conditions + let df = ctx + .sql("SELECT name, age FROM employees WHERE age = 25 OR age = 30 OR age = 35") + .await + .unwrap(); + let results = df.collect().await.unwrap(); + + assert_names(&results, &["Alice", "Bob", "Charlie"]); + assert_ages(&results, &[25, 30, 35]); +} + +#[tokio::test] +async fn test_sequential_union_mode_mixed_indexes() { + let ctx = setup_test_env_sequential().await; + + // OR across different indexes (age and department) + let df = ctx + .sql("SELECT name, age FROM employees WHERE age = 25 OR department = 'Sales'") + .await + .unwrap(); + let results = df.collect().await.unwrap(); + + assert_names(&results, &["Alice", "Bob", "Eve"]); + assert_ages(&results, &[25, 30, 32]); +} + +#[tokio::test] +async fn test_sequential_union_mode_complex_and_or() { + let ctx = setup_test_env_sequential().await; + + // Complex query with AND conditions inside OR - tests schema normalization + let df = ctx + .sql( + "SELECT name, age, department FROM employees WHERE + (age >= 25 AND department = 'Engineering') OR + (age >= 30 AND department = 'Sales')", + ) + .await + .unwrap(); + let results = df.collect().await.unwrap(); + + // Should return: + // - Alice (25, Engineering) and David (28, Engineering) from first AND + // - Bob (30, Sales) and Eve (32, Sales) from second AND + let total_rows = results.iter().map(|b| b.num_rows()).sum::(); + assert_eq!(total_rows, 4, "Expected 4 rows, got {total_rows}"); + + 
assert_names(&results, &["Alice", "Bob", "David", "Eve"]); + assert_ages(&results, &[25, 30, 28, 32]); + assert_departments(&results, &["Engineering", "Sales"]); +} + +#[tokio::test] +async fn test_sequential_union_mode_deduplication() { + let ctx = setup_test_env_sequential().await; + + // Query with overlapping conditions - tests that deduplication works + let df = ctx + .sql("SELECT name, age FROM employees WHERE age = 25 OR age < 29") + .await + .unwrap(); + let results = df.collect().await.unwrap(); + + // Alice (25) matches both conditions, should appear only once + let total_rows = results.iter().map(|b| b.num_rows()).sum::(); + assert_eq!( + total_rows, 2, + "Expected 2 rows after deduplication, got {total_rows}" + ); + + assert_names(&results, &["Alice", "David"]); + assert_ages(&results, &[25, 28]); +}