@@ -63,7 +63,7 @@ use datafusion_common::{
63
63
JoinSide , JoinType , Result ,
64
64
} ;
65
65
use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
66
- use datafusion_execution:: TaskContext ;
66
+ use datafusion_execution:: { PlanState , TaskContext } ;
67
67
use datafusion_physical_expr:: equivalence:: {
68
68
join_equivalence_properties, ProjectionMapping ,
69
69
} ;
@@ -306,8 +306,6 @@ pub struct HashJoinExec {
306
306
/// The schema after join. Please be careful when using this schema,
307
307
/// if there is a projection, the schema isn't the same as the output schema.
308
308
join_schema : SchemaRef ,
309
- /// Future that consumes left input and builds the hash table
310
- left_fut : OnceAsync < JoinLeftData > ,
311
309
/// Shared `RandomState` for the hashing algorithm
312
310
random_state : RandomState ,
313
311
/// Partitioning mode to use
@@ -325,6 +323,27 @@ pub struct HashJoinExec {
325
323
cache : PlanProperties ,
326
324
}
327
325
326
+ /// Execution state of a `HashJoinExec`, shared across partitions within a single execution.
327
+ #[ derive( Debug ) ]
328
+ struct HashJoinExecState {
329
+ /// Future that consumes the left input and builds the hash table (used by the `CollectLeft` mode).
330
+ left_fut : OnceAsync < JoinLeftData > ,
331
+ }
332
+
333
+ impl HashJoinExecState {
334
+ fn new ( ) -> Self { // Builds a fresh state; takes no arguments and only default-initializes its single field.
335
+ Self {
336
+ left_fut : Default :: default ( ) , // default `OnceAsync`; the `CollectLeft` branch later calls `.once(...)` on it
337
+ }
338
+ }
339
+ }
340
+
341
+ impl PlanState for HashJoinExecState { // NOTE(review): presumably required so the value can be passed to `get_or_register_plan_state` — confirm against the trait definition
342
+ fn as_any ( & self ) -> & dyn Any { // Exposes the state as `&dyn Any` so the `CollectLeft` branch can `downcast_ref` it back to `HashJoinExecState`.
343
+ self
344
+ }
345
+ }
346
+
328
347
impl HashJoinExec {
329
348
/// Tries to create a new [HashJoinExec].
330
349
///
@@ -376,7 +395,6 @@ impl HashJoinExec {
376
395
filter,
377
396
join_type : * join_type,
378
397
join_schema,
379
- left_fut : Default :: default ( ) ,
380
398
random_state,
381
399
mode : partition_mode,
382
400
projection,
@@ -689,21 +707,32 @@ impl ExecutionPlan for HashJoinExec {
689
707
let metrics = context. get_or_register_metric_set ( self ) ;
690
708
let join_metrics = BuildProbeJoinMetrics :: new ( partition, & metrics) ;
691
709
let left_fut = match self . mode {
692
- PartitionMode :: CollectLeft => self . left_fut . once ( || {
693
- let reservation =
694
- MemoryConsumer :: new ( "HashJoinInput" ) . register ( context. memory_pool ( ) ) ;
695
- collect_left_input (
696
- None ,
697
- self . random_state . clone ( ) ,
698
- Arc :: clone ( & self . left ) ,
699
- on_left. clone ( ) ,
700
- Arc :: clone ( & context) ,
701
- join_metrics. clone ( ) ,
702
- reservation,
703
- need_produce_result_in_final ( self . join_type ) ,
704
- self . right ( ) . output_partitioning ( ) . partition_count ( ) ,
705
- )
706
- } ) ,
710
+ PartitionMode :: CollectLeft => {
711
+ let state = context. get_or_register_plan_state ( self , || {
712
+ Arc :: new ( HashJoinExecState :: new ( ) )
713
+ } ) ;
714
+
715
+ state
716
+ . as_any ( )
717
+ . downcast_ref :: < HashJoinExecState > ( )
718
+ . unwrap ( )
719
+ . left_fut
720
+ . once ( || {
721
+ let reservation = MemoryConsumer :: new ( "HashJoinInput" )
722
+ . register ( context. memory_pool ( ) ) ;
723
+ collect_left_input (
724
+ None ,
725
+ self . random_state . clone ( ) ,
726
+ Arc :: clone ( & self . left ) ,
727
+ on_left. clone ( ) ,
728
+ Arc :: clone ( & context) ,
729
+ join_metrics. clone ( ) ,
730
+ reservation,
731
+ need_produce_result_in_final ( self . join_type ) ,
732
+ self . right ( ) . output_partitioning ( ) . partition_count ( ) ,
733
+ )
734
+ } )
735
+ }
707
736
PartitionMode :: Partitioned => {
708
737
let reservation =
709
738
MemoryConsumer :: new ( format ! ( "HashJoinInput[{partition}]" ) )
@@ -3427,7 +3456,6 @@ mod tests {
3427
3456
/// Test for parallelised HashJoinExec with PartitionMode::CollectLeft
3428
3457
#[ tokio:: test]
3429
3458
async fn test_collect_left_multiple_partitions_join ( ) -> Result < ( ) > {
3430
- let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
3431
3459
let left = build_table (
3432
3460
( "a1" , & vec ! [ 1 , 2 , 3 ] ) ,
3433
3461
( "b1" , & vec ! [ 4 , 5 , 7 ] ) ,
@@ -3522,6 +3550,7 @@ mod tests {
3522
3550
] ;
3523
3551
3524
3552
for ( join_type, expected) in test_cases {
3553
+ let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
3525
3554
let ( _, batches) = join_collect_with_partition_mode (
3526
3555
Arc :: clone ( & left) ,
3527
3556
Arc :: clone ( & right) ,
0 commit comments