Skip to content

Commit 8bc4d75

Browse files
authored
Merge pull request #28 from spiceai/peasee/260421-topk-improvements
fix: Improve TopK filter pushdown
2 parents c601f86 + 42f17f8 commit 8bc4d75

6 files changed

Lines changed: 142 additions & 14 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ballista/executor/src/execution_engine.rs

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,13 @@ use ballista_core::execution_plans::sort_shuffle::SortShuffleWriterExec;
2727
use ballista_core::serde::protobuf::ShuffleWritePartition;
2828
use ballista_core::utils;
2929
use datafusion::common::tree_node::{Transformed, TreeNode};
30+
use datafusion::config::ConfigOptions;
3031
use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
3132
use datafusion::datasource::source::DataSourceExec;
3233
use datafusion::error::{DataFusionError, Result};
3334
use datafusion::execution::context::TaskContext;
35+
use datafusion::physical_optimizer::PhysicalOptimizerRule;
36+
use datafusion::physical_optimizer::filter_pushdown::FilterPushdown;
3437
use datafusion::physical_plan::ExecutionPlan;
3538
use datafusion::physical_plan::metrics::MetricsSet;
3639
use std::fmt::{Debug, Display};
@@ -53,6 +56,7 @@ pub trait ExecutionEngine: Sync + Send {
5356
stage_id: usize,
5457
plan: Arc<dyn ExecutionPlan>,
5558
work_dir: &str,
59+
config: &ConfigOptions,
5660
) -> Result<Arc<dyn QueryStageExecutor>>;
5761
}
5862

@@ -135,7 +139,14 @@ impl ExecutionEngine for DefaultExecutionEngine {
135139
stage_id: usize,
136140
plan: Arc<dyn ExecutionPlan>,
137141
work_dir: &str,
142+
config: &ConfigOptions,
138143
) -> Result<Arc<dyn QueryStageExecutor>> {
144+
// Re-run FilterPushdown(Post) to re-establish dynamic filter links
145+
// (e.g., TopK → DataSourceExec) that are lost during protobuf
146+
// serialization/deserialization between scheduler and executor.
147+
let filter_pushdown = FilterPushdown::new_post_optimization();
148+
let plan = filter_pushdown.optimize(plan, config)?;
149+
139150
// Fix ParquetSource metadata_size_hint lost during serialization
140151
let plan = fix_parquet_metadata_size_hint(plan)?;
141152

ballista/executor/src/execution_loop.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -418,6 +418,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
418418
stage_id as usize,
419419
plan,
420420
&executor.work_dir,
421+
task_context.session_config().options(),
421422
)?;
422423
dedicated_executor.spawn(async move {
423424
use std::panic::AssertUnwindSafe;

ballista/executor/src/executor_server.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
378378
stage_id,
379379
plan,
380380
&self.executor.work_dir,
381+
task.session_config.options(),
381382
)
382383
.unwrap();
383384

ballista/scheduler/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,3 +80,4 @@ tonic-prost-build = { workspace = true }
8080

8181
[dev-dependencies]
8282
rstest = { workspace = true }
83+
tempfile = { workspace = true }

ballista/scheduler/src/planner.rs

Lines changed: 127 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -169,23 +169,51 @@ impl DefaultDistributedPlanner {
169169
with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?,
170170
stages,
171171
))
172-
} else if let Some(_sort_preserving_merge) = execution_plan
172+
} else if let Some(sort_preserving_merge) = execution_plan
173173
.as_any()
174174
.downcast_ref::<SortPreservingMergeExec>(
175175
) {
176-
let shuffle_writer = create_shuffle_writer_with_config(
177-
job_id,
178-
self.next_stage_id(),
179-
children[0].clone(),
180-
None,
181-
config,
182-
)?;
183-
let unresolved_shuffle = create_unresolved_shuffle(shuffle_writer.as_ref());
184-
stages.push(shuffle_writer);
185-
Ok((
186-
with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?,
187-
stages,
188-
))
176+
// For TopK queries (SortPreservingMergeExec with a small fetch/limit),
177+
// skip the stage break and keep the merge in the same stage as its children.
178+
// This avoids the overhead of shuffle write/read for a small number of rows,
179+
// which dominates execution time for TopK queries in distributed mode.
180+
//
181+
// Note on parallelism: because SortPreservingMergeExec has an output
182+
// partitioning of 1, the entire stage becomes a single task assigned to
183+
// one executor (ShuffleWriterExec::input_partition_count() == 1).
184+
// This does sacrifice cluster-level parallelism (no cross-executor
185+
// distribution). However, within that executor the child partitions
186+
// still execute as parallel async streams, so intra-executor parallelism
187+
// is preserved. For small fetch values this trade-off is worthwhile as
188+
// the shuffle coordination overhead far exceeds the merge cost.
189+
const TOPK_FETCH_THRESHOLD: usize = 1000;
190+
if sort_preserving_merge
191+
.fetch()
192+
.is_some_and(|f| f <= TOPK_FETCH_THRESHOLD)
193+
{
194+
Ok((
195+
with_new_children_if_necessary(execution_plan, children)?,
196+
stages,
197+
))
198+
} else {
199+
let shuffle_writer = create_shuffle_writer_with_config(
200+
job_id,
201+
self.next_stage_id(),
202+
children[0].clone(),
203+
None,
204+
config,
205+
)?;
206+
let unresolved_shuffle =
207+
create_unresolved_shuffle(shuffle_writer.as_ref());
208+
stages.push(shuffle_writer);
209+
Ok((
210+
with_new_children_if_necessary(
211+
execution_plan,
212+
vec![unresolved_shuffle],
213+
)?,
214+
stages,
215+
))
216+
}
189217
} else if let Some(repart) =
190218
execution_plan.as_any().downcast_ref::<RepartitionExec>()
191219
{
@@ -940,6 +968,91 @@ order by
940968
Ok(result_exec_plan)
941969
}
942970

971+
/// Verifies that TopK queries (ORDER BY ... LIMIT N, where N is small)
972+
/// do NOT create a stage break at SortPreservingMergeExec, avoiding
973+
/// shuffle overhead for small result sets.
974+
#[tokio::test]
975+
async fn test_topk_avoids_stage_break() -> Result<(), BallistaError> {
976+
use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext};
977+
use std::io::Write;
978+
979+
let tmp_dir = tempfile::tempdir().unwrap();
980+
let schema = "id,value\n";
981+
for i in 0..4 {
982+
let path = tmp_dir.path().join(format!("part{i:02}.csv"));
983+
let mut f = std::fs::File::create(&path).unwrap();
984+
write!(f, "{schema}").unwrap();
985+
for j in 0..10 {
986+
writeln!(f, "{},{}", i * 10 + j, (i * 10 + j) * 100).unwrap();
987+
}
988+
}
989+
990+
let config = SessionConfig::new().with_target_partitions(4);
991+
let ctx = SessionContext::new_with_config(config);
992+
ctx.register_csv(
993+
"test_table",
994+
tmp_dir.path().to_str().unwrap(),
995+
CsvReadOptions::new(),
996+
)
997+
.await?;
998+
999+
// TopK query with small LIMIT — should produce a single stage
1000+
let df = ctx
1001+
.sql("SELECT id, value FROM test_table ORDER BY value DESC LIMIT 10")
1002+
.await?;
1003+
let plan = df.into_optimized_plan()?;
1004+
let plan = ctx.state().create_physical_plan(&plan).await?;
1005+
1006+
let mut planner = DefaultDistributedPlanner::new();
1007+
let stages = planner.plan_query_stages(
1008+
"job-topk",
1009+
plan,
1010+
ctx.state().config().options(),
1011+
)?;
1012+
1013+
for (i, stage) in stages.iter().enumerate() {
1014+
println!(
1015+
"TopK Stage {i}:\n{}",
1016+
displayable(stage.as_ref()).indent(false)
1017+
);
1018+
}
1019+
1020+
// Should be a single stage (no shuffle for TopK with small limit)
1021+
assert_eq!(
1022+
1,
1023+
stages.len(),
1024+
"TopK with small LIMIT should produce 1 stage, got {}",
1025+
stages.len()
1026+
);
1027+
1028+
// The single stage should contain SortPreservingMergeExec
1029+
let root = stages[0].children()[0].clone();
1030+
let _merge = downcast_exec!(root, SortPreservingMergeExec);
1031+
1032+
// Without LIMIT, the same query should produce 2 stages (with shuffle)
1033+
let df_no_limit = ctx
1034+
.sql("SELECT id, value FROM test_table ORDER BY value DESC")
1035+
.await?;
1036+
let plan_no_limit = df_no_limit.into_optimized_plan()?;
1037+
let plan_no_limit = ctx.state().create_physical_plan(&plan_no_limit).await?;
1038+
1039+
let mut planner2 = DefaultDistributedPlanner::new();
1040+
let stages_no_limit = planner2.plan_query_stages(
1041+
"job-no-limit",
1042+
plan_no_limit,
1043+
ctx.state().config().options(),
1044+
)?;
1045+
1046+
assert_eq!(
1047+
2,
1048+
stages_no_limit.len(),
1049+
"ORDER BY without LIMIT should produce 2 stages, got {}",
1050+
stages_no_limit.len()
1051+
);
1052+
1053+
Ok(())
1054+
}
1055+
9431056
fn memory_exec(
9441057
schema: Arc<Schema>,
9451058
partition_count: usize,

0 commit comments

Comments
 (0)