Skip to content

Commit 5eede25

Browse files
committed
physical plan: move shared resources from exec nodes
To reuse a physical plan we need to move all shared resources from them. Some execs contain shared futures to initialize them lazy once across all partitions. For example, `HashJoinExec` contain a future that once collects all data from the build sie. This patch moves such resources to the task context. Now each plan can save to the task context its own state to reuse across partitions.
1 parent 66220fb commit 5eede25

File tree

9 files changed

+223
-74
lines changed

9 files changed

+223
-74
lines changed

datafusion/core/src/datasource/physical_plan/parquet/mod.rs

-1
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,6 @@ impl ExecutionPlan for ParquetExec {
709709
.clone()
710710
.unwrap_or_else(|| Arc::new(DefaultSchemaAdapterFactory::default()));
711711

712-
println!("parquet exec registered metrics...");
713712
let metrics = ctx.get_or_register_metric_set_with_default(self, || {
714713
ExecutionPlanMetricsSet::with_inner(self.base_metrics.clone_inner())
715714
});

datafusion/execution/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ pub use disk_manager::DiskManager;
3939
pub use metrics::Metric;
4040
pub use registry::FunctionRegistry;
4141
pub use stream::{RecordBatchStream, SendableRecordBatchStream};
42-
pub use task::TaskContext;
42+
pub use task::{PlanState, TaskContext};

datafusion/execution/src/task.rs

+54
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use std::{
1919
any::Any,
2020
collections::{HashMap, HashSet},
21+
fmt::Debug,
2122
sync::Arc,
2223
};
2324

@@ -60,6 +61,9 @@ pub struct TaskContext {
6061
/// Metrics associated with a execution plan address.
6162
/// std mutex is used because too concurrent access is not assumed.
6263
metrics: std::sync::Mutex<HashMap<usize, ExecutionPlanMetricsSet>>,
64+
/// Session plans state by an execution plan address.
65+
/// Resources that shared across execution partitions.
66+
plan_state: std::sync::Mutex<HashMap<usize, Arc<dyn PlanState>>>,
6367
}
6468

6569
impl Default for TaskContext {
@@ -79,10 +83,16 @@ impl Default for TaskContext {
7983
runtime,
8084
param_values: None,
8185
metrics: Default::default(),
86+
plan_state: Default::default(),
8287
}
8388
}
8489
}
8590

91+
/// Generic plan state.
92+
pub trait PlanState: Debug + Any + Send + Sync {
93+
fn as_any(&self) -> &dyn Any;
94+
}
95+
8696
fn plan_addr(plan: &dyn Any) -> usize {
8797
plan as *const _ as *const () as usize
8898
}
@@ -112,6 +122,30 @@ impl TaskContext {
112122
runtime,
113123
param_values: None,
114124
metrics: Default::default(),
125+
plan_state: Default::default(),
126+
}
127+
}
128+
129+
/// Fork a task context.
130+
///
131+
/// Forked context contains the same
132+
/// * session related attributes (id, udfs, etc),
133+
/// * runtime environment
134+
///
135+
/// But an empty metrics and plan state.
136+
///
137+
pub fn fork(&self) -> Self {
138+
Self {
139+
task_id: self.task_id(),
140+
session_id: self.session_id(),
141+
session_config: self.session_config.clone(),
142+
scalar_functions: self.scalar_functions.clone(),
143+
aggregate_functions: self.aggregate_functions.clone(),
144+
window_functions: self.window_functions.clone(),
145+
runtime: Arc::clone(&self.runtime),
146+
param_values: self.param_values.clone(),
147+
metrics: Default::default(),
148+
plan_state: Default::default(),
115149
}
116150
}
117151

@@ -206,6 +240,26 @@ impl TaskContext {
206240
metric_set
207241
}
208242
}
243+
244+
/// Get state for specific plan or register a new state.
245+
pub fn get_or_register_plan_state<F>(
246+
&self,
247+
plan: &dyn Any,
248+
f: F,
249+
) -> Arc<dyn PlanState>
250+
where
251+
F: FnOnce() -> Arc<dyn PlanState>,
252+
{
253+
let addr = plan_addr(plan);
254+
let mut plan_state = self.plan_state.lock().unwrap();
255+
if let Some(state) = plan_state.get(&addr) {
256+
Arc::clone(state)
257+
} else {
258+
let state = f();
259+
plan_state.insert(addr, Arc::clone(&state));
260+
state
261+
}
262+
}
209263
}
210264

211265
impl FunctionRegistry for TaskContext {

datafusion/physical-plan/src/joins/cross_join.rs

+36-11
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use arrow_array::RecordBatchOptions;
3838
use datafusion_common::stats::Precision;
3939
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
4040
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
41-
use datafusion_execution::TaskContext;
41+
use datafusion_execution::{PlanState, TaskContext};
4242
use datafusion_physical_expr::equivalence::join_equivalence_properties;
4343

4444
use async_trait::async_trait;
@@ -57,9 +57,28 @@ pub struct CrossJoinExec {
5757
pub right: Arc<dyn ExecutionPlan>,
5858
/// The schema once the join is applied
5959
schema: SchemaRef,
60+
cache: PlanProperties,
61+
}
62+
63+
/// Exec state shared across partitions per one execution.
64+
#[derive(Debug)]
65+
struct CrossJoinExecState {
6066
/// Build-side data
6167
left_fut: OnceAsync<JoinLeftData>,
62-
cache: PlanProperties,
68+
}
69+
70+
impl CrossJoinExecState {
71+
fn new() -> Self {
72+
Self {
73+
left_fut: Default::default(),
74+
}
75+
}
76+
}
77+
78+
impl PlanState for CrossJoinExecState {
79+
fn as_any(&self) -> &dyn Any {
80+
self
81+
}
6382
}
6483

6584
impl CrossJoinExec {
@@ -87,7 +106,6 @@ impl CrossJoinExec {
87106
left,
88107
right,
89108
schema,
90-
left_fut: Default::default(),
91109
cache,
92110
}
93111
}
@@ -239,14 +257,21 @@ impl ExecutionPlan for CrossJoinExec {
239257
let reservation =
240258
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
241259

242-
let left_fut = self.left_fut.once(|| {
243-
load_left_input(
244-
Arc::clone(&self.left),
245-
Arc::clone(&context),
246-
join_metrics.clone(),
247-
reservation,
248-
)
249-
});
260+
let state = context
261+
.get_or_register_plan_state(self, || Arc::new(CrossJoinExecState::new()));
262+
let left_fut = state
263+
.as_any()
264+
.downcast_ref::<CrossJoinExecState>()
265+
.unwrap()
266+
.left_fut
267+
.once(|| {
268+
load_left_input(
269+
Arc::clone(&self.left),
270+
Arc::clone(&context),
271+
join_metrics.clone(),
272+
reservation,
273+
)
274+
});
250275

251276
Ok(Box::pin(CrossJoinStream {
252277
schema: Arc::clone(&self.schema),

datafusion/physical-plan/src/joins/hash_join.rs

+49-20
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use datafusion_common::{
6363
JoinSide, JoinType, Result,
6464
};
6565
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
66-
use datafusion_execution::TaskContext;
66+
use datafusion_execution::{PlanState, TaskContext};
6767
use datafusion_physical_expr::equivalence::{
6868
join_equivalence_properties, ProjectionMapping,
6969
};
@@ -306,8 +306,6 @@ pub struct HashJoinExec {
306306
/// The schema after join. Please be careful when using this schema,
307307
/// if there is a projection, the schema isn't the same as the output schema.
308308
join_schema: SchemaRef,
309-
/// Future that consumes left input and builds the hash table
310-
left_fut: OnceAsync<JoinLeftData>,
311309
/// Shared the `RandomState` for the hashing algorithm
312310
random_state: RandomState,
313311
/// Partitioning mode to use
@@ -325,6 +323,27 @@ pub struct HashJoinExec {
325323
cache: PlanProperties,
326324
}
327325

326+
/// Exec state shared across partitions per one execution.
327+
#[derive(Debug)]
328+
struct HashJoinExecState {
329+
/// Future that consumes left input and builds the hash table
330+
left_fut: OnceAsync<JoinLeftData>,
331+
}
332+
333+
impl HashJoinExecState {
334+
fn new() -> Self {
335+
Self {
336+
left_fut: Default::default(),
337+
}
338+
}
339+
}
340+
341+
impl PlanState for HashJoinExecState {
342+
fn as_any(&self) -> &dyn Any {
343+
self
344+
}
345+
}
346+
328347
impl HashJoinExec {
329348
/// Tries to create a new [HashJoinExec].
330349
///
@@ -376,7 +395,6 @@ impl HashJoinExec {
376395
filter,
377396
join_type: *join_type,
378397
join_schema,
379-
left_fut: Default::default(),
380398
random_state,
381399
mode: partition_mode,
382400
projection,
@@ -689,21 +707,32 @@ impl ExecutionPlan for HashJoinExec {
689707
let metrics = context.get_or_register_metric_set(self);
690708
let join_metrics = BuildProbeJoinMetrics::new(partition, &metrics);
691709
let left_fut = match self.mode {
692-
PartitionMode::CollectLeft => self.left_fut.once(|| {
693-
let reservation =
694-
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
695-
collect_left_input(
696-
None,
697-
self.random_state.clone(),
698-
Arc::clone(&self.left),
699-
on_left.clone(),
700-
Arc::clone(&context),
701-
join_metrics.clone(),
702-
reservation,
703-
need_produce_result_in_final(self.join_type),
704-
self.right().output_partitioning().partition_count(),
705-
)
706-
}),
710+
PartitionMode::CollectLeft => {
711+
let state = context.get_or_register_plan_state(self, || {
712+
Arc::new(HashJoinExecState::new())
713+
});
714+
715+
state
716+
.as_any()
717+
.downcast_ref::<HashJoinExecState>()
718+
.unwrap()
719+
.left_fut
720+
.once(|| {
721+
let reservation = MemoryConsumer::new("HashJoinInput")
722+
.register(context.memory_pool());
723+
collect_left_input(
724+
None,
725+
self.random_state.clone(),
726+
Arc::clone(&self.left),
727+
on_left.clone(),
728+
Arc::clone(&context),
729+
join_metrics.clone(),
730+
reservation,
731+
need_produce_result_in_final(self.join_type),
732+
self.right().output_partitioning().partition_count(),
733+
)
734+
})
735+
}
707736
PartitionMode::Partitioned => {
708737
let reservation =
709738
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
@@ -3427,7 +3456,6 @@ mod tests {
34273456
/// Test for parallelised HashJoinExec with PartitionMode::CollectLeft
34283457
#[tokio::test]
34293458
async fn test_collect_left_multiple_partitions_join() -> Result<()> {
3430-
let task_ctx = Arc::new(TaskContext::default());
34313459
let left = build_table(
34323460
("a1", &vec![1, 2, 3]),
34333461
("b1", &vec![4, 5, 7]),
@@ -3522,6 +3550,7 @@ mod tests {
35223550
];
35233551

35243552
for (join_type, expected) in test_cases {
3553+
let task_ctx = Arc::new(TaskContext::default());
35253554
let (_, batches) = join_collect_with_partition_mode(
35263555
Arc::clone(&left),
35273556
Arc::clone(&right),

datafusion/physical-plan/src/joins/nested_loop_join.rs

+40-13
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ use arrow::util::bit_util;
4747
use arrow_array::PrimitiveArray;
4848
use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics};
4949
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
50-
use datafusion_execution::TaskContext;
50+
use datafusion_execution::{PlanState, TaskContext};
5151
use datafusion_expr::JoinType;
5252
use datafusion_physical_expr::equivalence::join_equivalence_properties;
5353

@@ -147,14 +147,33 @@ pub struct NestedLoopJoinExec {
147147
pub(crate) join_type: JoinType,
148148
/// The schema once the join is applied
149149
schema: SchemaRef,
150-
/// Build-side data
151-
inner_table: OnceAsync<JoinLeftData>,
152150
/// Information of index and left / right placement of columns
153151
column_indices: Vec<ColumnIndex>,
154152
/// Cache holding plan properties like equivalences, output partitioning etc.
155153
cache: PlanProperties,
156154
}
157155

156+
/// Exec state shared across partitions per one execution.
157+
#[derive(Debug)]
158+
struct NestedLoopJoinExecState {
159+
/// Build-side data.
160+
inner_table: OnceAsync<JoinLeftData>,
161+
}
162+
163+
impl NestedLoopJoinExecState {
164+
fn new() -> Self {
165+
Self {
166+
inner_table: Default::default(),
167+
}
168+
}
169+
}
170+
171+
impl PlanState for NestedLoopJoinExecState {
172+
fn as_any(&self) -> &dyn Any {
173+
self
174+
}
175+
}
176+
158177
impl NestedLoopJoinExec {
159178
/// Try to create a new [`NestedLoopJoinExec`]
160179
pub fn try_new(
@@ -178,7 +197,6 @@ impl NestedLoopJoinExec {
178197
filter,
179198
join_type: *join_type,
180199
schema,
181-
inner_table: Default::default(),
182200
column_indices,
183201
cache,
184202
})
@@ -303,17 +321,26 @@ impl ExecutionPlan for NestedLoopJoinExec {
303321
MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
304322
.register(context.memory_pool());
305323

306-
let inner_table = self.inner_table.once(|| {
307-
collect_left_input(
308-
Arc::clone(&self.left),
309-
Arc::clone(&context),
310-
join_metrics.clone(),
311-
load_reservation,
312-
need_produce_result_in_final(self.join_type),
313-
self.right().output_partitioning().partition_count(),
314-
)
324+
let state = context.get_or_register_plan_state(self, || {
325+
Arc::new(NestedLoopJoinExecState::new())
315326
});
316327

328+
let inner_table = state
329+
.as_any()
330+
.downcast_ref::<NestedLoopJoinExecState>()
331+
.unwrap()
332+
.inner_table
333+
.once(|| {
334+
collect_left_input(
335+
Arc::clone(&self.left),
336+
Arc::clone(&context),
337+
join_metrics.clone(),
338+
load_reservation,
339+
need_produce_result_in_final(self.join_type),
340+
self.right().output_partitioning().partition_count(),
341+
)
342+
});
343+
317344
// Resolve placeholders in filter.
318345
let resolved_filter = if let Some(ref filter) = self.filter {
319346
Some(filter.resolve_placeholders(context.param_values())?)

datafusion/physical-plan/src/joins/symmetric_hash_join.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,7 @@ mod tests {
16421642
Arc::clone(&task_ctx),
16431643
)
16441644
.await?;
1645+
let task_ctx = Arc::new(task_ctx.fork());
16451646
let second_batches = partitioned_hash_join_with_filter(
16461647
left, right, on, filter, &join_type, false, task_ctx,
16471648
)

0 commit comments

Comments
 (0)