@@ -2146,7 +2146,6 @@ mod tests {
21462146 exec_err, internal_err,
21472147 } ;
21482148 use datafusion_execution:: config:: SessionConfig ;
2149- use datafusion_execution:: disk_manager:: { DiskManagerBuilder , DiskManagerMode } ;
21502149 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
21512150 use datafusion_expr:: Operator ;
21522151 use datafusion_functions_aggregate:: count:: count_udaf;
@@ -2504,52 +2503,21 @@ mod tests {
25042503 Ok ( ( columns, batches, metrics) )
25052504 }
25062505
2507- fn memory_limited_aggregate_join_task_ctx (
2508- batch_size : usize ,
2509- memory_limit : Option < usize > ,
2510- ) -> Result < Arc < TaskContext > > {
2511- let mut session_config = SessionConfig :: default ( ) . with_batch_size ( batch_size) ;
2512-
2513- // Keep the repro focused on normal hash aggregation and hash join paths.
2514- session_config
2515- . options_mut ( )
2516- . execution
2517- . skip_partial_aggregation_probe_rows_threshold = usize:: MAX ;
2518- session_config
2519- . options_mut ( )
2520- . execution
2521- . perfect_hash_join_small_build_threshold = 0 ;
2522- session_config
2523- . options_mut ( )
2524- . execution
2525- . perfect_hash_join_min_key_density = f64:: INFINITY ;
2526-
2527- let mut runtime_builder = RuntimeEnvBuilder :: new ( ) . with_disk_manager_builder (
2528- DiskManagerBuilder :: default ( ) . with_mode ( DiskManagerMode :: Disabled ) ,
2529- ) ;
2530- if let Some ( memory_limit) = memory_limit {
2531- runtime_builder = runtime_builder. with_memory_limit ( memory_limit, 1.0 ) ;
2532- }
2533-
2534- Ok ( Arc :: new (
2535- TaskContext :: default ( )
2536- . with_session_config ( session_config)
2537- . with_runtime ( runtime_builder. build_arc ( ) ?) ,
2538- ) )
/// Maps a zero-based group index to the `UInt32` join key for that group.
///
/// Keys are spaced 1000 apart (`0, 1000, 2000, ...`). The same mapping is used
/// to generate the aggregate input's `group_key` column and the probe side's
/// `probe_key` values, so both sides of the join agree on which keys exist.
///
/// # Panics
///
/// Panics if `i * 1000` does not fit in a `u32`. A bare `as` cast would
/// truncate silently and the multiplication could wrap in builds without
/// overflow checks, producing colliding keys and an invalid group count.
fn aggregate_join_group_key(i: usize) -> u32 {
    u32::try_from(i)
        .ok()
        .and_then(|index| index.checked_mul(1000))
        .unwrap_or_else(|| panic!("group index {} does not map to a u32 join key", i))
}
25402509
2541- async fn final_aggregate_build_side (
2542- num_groups : usize ,
2543- batch_size : usize ,
2544- ) -> Result < Arc < AggregateExec > > {
2510+ async fn final_aggregate_build_side ( num_groups : usize ) -> Result < Arc < AggregateExec > > {
25452511 let raw_schema = Arc :: new ( Schema :: new ( vec ! [
25462512 Field :: new( "group_key" , DataType :: UInt32 , false ) ,
25472513 Field :: new( "value" , DataType :: UInt64 , false ) ,
25482514 ] ) ) ;
25492515 let batch = RecordBatch :: try_new (
25502516 Arc :: clone ( & raw_schema) ,
25512517 vec ! [
2552- Arc :: new( UInt32Array :: from_iter_values( 0 ..num_groups as u32 ) ) ,
2518+ Arc :: new( UInt32Array :: from_iter_values(
2519+ ( 0 ..num_groups) . map( aggregate_join_group_key) ,
2520+ ) ) ,
25532521 Arc :: new( UInt64Array :: from( vec![ 1 ; num_groups] ) ) ,
25542522 ] ,
25552523 ) ?;
@@ -2575,9 +2543,10 @@ mod tests {
25752543 Arc :: clone ( & raw_schema) ,
25762544 ) ?) ;
25772545 let partial_schema = partial_aggregate. schema ( ) ;
2578- let task_ctx = memory_limited_aggregate_join_task_ctx ( batch_size, None ) ?;
2579- let partial_batches =
2580- common:: collect ( partial_aggregate. execute ( 0 , task_ctx) ?) . await ?;
2546+ let partial_batches = common:: collect (
2547+ partial_aggregate. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?,
2548+ )
2549+ . await ?;
25812550 let partial_input = TestMemoryExec :: try_new_exec (
25822551 & [ partial_batches] ,
25832552 Arc :: clone ( & partial_schema) ,
@@ -2594,104 +2563,66 @@ mod tests {
25942563 ) ?) )
25952564 }
25962565
2597- fn probe_side ( num_groups : usize ) -> Result < Arc < dyn ExecutionPlan > > {
2598- let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
2566+ #[ tokio:: test]
2567+ async fn build_side_final_aggregate_respects_grouped_memory_limit ( ) -> Result < ( ) > {
2568+ const BATCH_SIZE : usize = 8192 ;
2569+ const NUM_GROUPS : usize = BATCH_SIZE * 32 + 1 ;
2570+ const EXPECTED_JOIN_ROWS : usize = 3 ;
2571+
2572+ let aggregate = final_aggregate_build_side ( NUM_GROUPS ) . await ?;
2573+ let aggregate_batches =
2574+ common:: collect ( aggregate. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?)
2575+ . await ?;
2576+ assert ! ( aggregate_batches. len( ) > 1 ) ;
2577+ assert_eq ! (
2578+ aggregate_batches
2579+ . iter( )
2580+ . map( RecordBatch :: num_rows)
2581+ . sum:: <usize >( ) ,
2582+ NUM_GROUPS
2583+ ) ;
2584+ let aggregate_batch = concat_batches ( & aggregate. schema ( ) , & aggregate_batches) ?;
2585+ let memory_limit = get_record_batch_memory_size ( & aggregate_batch) * 4 ;
2586+
2587+ let probe_schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
25992588 "probe_key" ,
26002589 DataType :: UInt32 ,
26012590 false ,
26022591 ) ] ) ) ;
2603- let batch = RecordBatch :: try_new (
2604- Arc :: clone ( & schema ) ,
2592+ let probe_batch = RecordBatch :: try_new (
2593+ Arc :: clone ( & probe_schema ) ,
26052594 vec ! [ Arc :: new( UInt32Array :: from( vec![
2606- 0 ,
2607- ( num_groups / 2 ) as u32 ,
2608- ( num_groups - 1 ) as u32 ,
2595+ aggregate_join_group_key ( 0 ) ,
2596+ aggregate_join_group_key ( NUM_GROUPS / 2 ) ,
2597+ aggregate_join_group_key ( NUM_GROUPS - 1 ) ,
26092598 ] ) ) ] ,
26102599 ) ?;
2600+ let probe: Arc < dyn ExecutionPlan > = TestMemoryExec :: try_new_exec (
2601+ & [ vec ! [ probe_batch] ] ,
2602+ Arc :: clone ( & probe_schema) ,
2603+ None ,
2604+ ) ?;
26112605
2612- let exec: Arc < dyn ExecutionPlan > =
2613- TestMemoryExec :: try_new_exec ( & [ vec ! [ batch] ] , schema, None ) ?;
2614-
2615- Ok ( exec)
2616- }
2617-
2618- async fn final_aggregate_peak_mem_used (
2619- aggregate : & Arc < AggregateExec > ,
2620- num_groups : usize ,
2621- batch_size : usize ,
2622- ) -> Result < usize > {
2623- let task_ctx = memory_limited_aggregate_join_task_ctx ( batch_size, None ) ?;
2624- let batches = common:: collect ( aggregate. execute ( 0 , task_ctx) ?) . await ?;
2625-
2626- assert ! (
2627- batches. len( ) > 1 ,
2628- "expected final aggregate output to be split into multiple batches"
2629- ) ;
2630- assert_eq ! (
2631- batches. iter( ) . map( RecordBatch :: num_rows) . sum:: <usize >( ) ,
2632- num_groups
2633- ) ;
2634-
2635- let metrics = aggregate. metrics ( ) . expect ( "aggregate metrics" ) ;
2636- let peak_mem_used = metrics
2637- . sum_by_name ( "peak_mem_used" )
2638- . expect ( "peak_mem_used metric" )
2639- . as_usize ( ) ;
2640- assert ! (
2641- peak_mem_used > 0 ,
2642- "expected non-zero final aggregate peak memory"
2643- ) ;
2644-
2645- Ok ( peak_mem_used)
2646- }
2647-
2648- async fn run_aggregate_build_side_join (
2649- aggregate : Arc < AggregateExec > ,
2650- num_groups : usize ,
2651- batch_size : usize ,
2652- memory_limit : usize ,
2653- ) -> Result < Vec < RecordBatch > > {
26542606 let aggregate: Arc < dyn ExecutionPlan > = aggregate;
2655- let right = probe_side ( num_groups) ?;
2656- let on = vec ! [ (
2657- Arc :: new( Column :: new_with_schema( "group_key" , & aggregate. schema( ) ) ?) as _,
2658- Arc :: new( Column :: new_with_schema( "probe_key" , & right. schema( ) ) ?) as _,
2659- ) ] ;
26602607 let join = HashJoinExec :: try_new (
2661- aggregate,
2662- right,
2663- on,
2608+ Arc :: clone ( & aggregate) ,
2609+ probe,
2610+ vec ! [ (
2611+ Arc :: new( Column :: new_with_schema( "group_key" , & aggregate. schema( ) ) ?) as _,
2612+ Arc :: new( Column :: new_with_schema( "probe_key" , & probe_schema) ?) as _,
2613+ ) ] ,
26642614 None ,
26652615 & JoinType :: Inner ,
26662616 None ,
26672617 PartitionMode :: CollectLeft ,
26682618 NullEquality :: NullEqualsNothing ,
26692619 false ,
26702620 ) ?;
2671-
2672- let task_ctx =
2673- memory_limited_aggregate_join_task_ctx ( batch_size, Some ( memory_limit) ) ?;
2674- common:: collect ( join. execute ( 0 , task_ctx) ?) . await
2675- }
2676-
2677- #[ tokio:: test]
2678- async fn build_side_final_aggregate_respects_grouped_memory_limit ( ) -> Result < ( ) > {
2679- const BATCH_SIZE : usize = 8192 ;
2680- const NUM_GROUPS : usize = BATCH_SIZE * 32 + 1 ;
2681- const EXPECTED_JOIN_ROWS : usize = 3 ;
2682-
2683- let aggregate = final_aggregate_build_side ( NUM_GROUPS , BATCH_SIZE ) . await ?;
2684- let aggregate_peak_mem_used =
2685- final_aggregate_peak_mem_used ( & aggregate, NUM_GROUPS , BATCH_SIZE ) . await ?;
2686- let memory_limit = aggregate_peak_mem_used * 2 ;
2687-
2688- let batches = run_aggregate_build_side_join (
2689- aggregate,
2690- NUM_GROUPS ,
2691- BATCH_SIZE ,
2692- memory_limit,
2693- )
2694- . await ?;
2621+ let runtime = RuntimeEnvBuilder :: new ( )
2622+ . with_memory_limit ( memory_limit, 1.0 )
2623+ . build_arc ( ) ?;
2624+ let task_ctx = Arc :: new ( TaskContext :: default ( ) . with_runtime ( runtime) ) ;
2625+ let batches = common:: collect ( join. execute ( 0 , task_ctx) ?) . await ?;
26952626 assert_eq ! (
26962627 batches. iter( ) . map( RecordBatch :: num_rows) . sum:: <usize >( ) ,
26972628 EXPECTED_JOIN_ROWS
0 commit comments