Skip to content

Commit 3b8a277

Browse files
committed
test: simplify even more
1 parent d3235a9 commit 3b8a277

1 file changed

Lines changed: 53 additions & 122 deletions

File tree

  • datafusion/physical-plan/src/joins/hash_join

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 53 additions & 122 deletions
Original file line number | Diff line number | Diff line change
@@ -2146,7 +2146,6 @@ mod tests {
21462146
exec_err, internal_err,
21472147
};
21482148
use datafusion_execution::config::SessionConfig;
2149-
use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
21502149
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
21512150
use datafusion_expr::Operator;
21522151
use datafusion_functions_aggregate::count::count_udaf;
@@ -2504,52 +2503,21 @@ mod tests {
25042503
Ok((columns, batches, metrics))
25052504
}
25062505

2507-
fn memory_limited_aggregate_join_task_ctx(
2508-
batch_size: usize,
2509-
memory_limit: Option<usize>,
2510-
) -> Result<Arc<TaskContext>> {
2511-
let mut session_config = SessionConfig::default().with_batch_size(batch_size);
2512-
2513-
// Keep the repro focused on normal hash aggregation and hash join paths.
2514-
session_config
2515-
.options_mut()
2516-
.execution
2517-
.skip_partial_aggregation_probe_rows_threshold = usize::MAX;
2518-
session_config
2519-
.options_mut()
2520-
.execution
2521-
.perfect_hash_join_small_build_threshold = 0;
2522-
session_config
2523-
.options_mut()
2524-
.execution
2525-
.perfect_hash_join_min_key_density = f64::INFINITY;
2526-
2527-
let mut runtime_builder = RuntimeEnvBuilder::new().with_disk_manager_builder(
2528-
DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
2529-
);
2530-
if let Some(memory_limit) = memory_limit {
2531-
runtime_builder = runtime_builder.with_memory_limit(memory_limit, 1.0);
2532-
}
2533-
2534-
Ok(Arc::new(
2535-
TaskContext::default()
2536-
.with_session_config(session_config)
2537-
.with_runtime(runtime_builder.build_arc()?),
2538-
))
2506+
/// Maps a zero-based group index `i` to the `u32` join key used by this test:
/// the index is cast to `u32` and multiplied by 1000, so consecutive indices
/// yield keys 0, 1000, 2000, ...
///
/// The same mapping builds the aggregate's `group_key` column and the probe
/// side's `probe_key` values, so probe keys match build-side groups.
///
/// NOTE(review): presumably the 1000x spacing keeps the key range sparse so the
/// join stays on the regular hash-join path without custom session options
/// (TODO confirm against the perfect-hash-join density settings).
/// NOTE(review): `i as u32` truncates above `u32::MAX`, and the multiplication
/// overflows `u32` once `i` exceeds 4_294_967; the callers visible here stay at
/// or below 262_144.
fn aggregate_join_group_key(i: usize) -> u32 {
2507+
(i as u32) * 1000
25392508
}
25402509

2541-
async fn final_aggregate_build_side(
2542-
num_groups: usize,
2543-
batch_size: usize,
2544-
) -> Result<Arc<AggregateExec>> {
2510+
async fn final_aggregate_build_side(num_groups: usize) -> Result<Arc<AggregateExec>> {
25452511
let raw_schema = Arc::new(Schema::new(vec![
25462512
Field::new("group_key", DataType::UInt32, false),
25472513
Field::new("value", DataType::UInt64, false),
25482514
]));
25492515
let batch = RecordBatch::try_new(
25502516
Arc::clone(&raw_schema),
25512517
vec![
2552-
Arc::new(UInt32Array::from_iter_values(0..num_groups as u32)),
2518+
Arc::new(UInt32Array::from_iter_values(
2519+
(0..num_groups).map(aggregate_join_group_key),
2520+
)),
25532521
Arc::new(UInt64Array::from(vec![1; num_groups])),
25542522
],
25552523
)?;
@@ -2575,9 +2543,10 @@ mod tests {
25752543
Arc::clone(&raw_schema),
25762544
)?);
25772545
let partial_schema = partial_aggregate.schema();
2578-
let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?;
2579-
let partial_batches =
2580-
common::collect(partial_aggregate.execute(0, task_ctx)?).await?;
2546+
let partial_batches = common::collect(
2547+
partial_aggregate.execute(0, Arc::new(TaskContext::default()))?,
2548+
)
2549+
.await?;
25812550
let partial_input = TestMemoryExec::try_new_exec(
25822551
&[partial_batches],
25832552
Arc::clone(&partial_schema),
@@ -2594,104 +2563,66 @@ mod tests {
25942563
)?))
25952564
}
25962565

2597-
fn probe_side(num_groups: usize) -> Result<Arc<dyn ExecutionPlan>> {
2598-
let schema = Arc::new(Schema::new(vec![Field::new(
2566+
#[tokio::test]
2567+
async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> {
2568+
const BATCH_SIZE: usize = 8192;
2569+
const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1;
2570+
const EXPECTED_JOIN_ROWS: usize = 3;
2571+
2572+
let aggregate = final_aggregate_build_side(NUM_GROUPS).await?;
2573+
let aggregate_batches =
2574+
common::collect(aggregate.execute(0, Arc::new(TaskContext::default()))?)
2575+
.await?;
2576+
assert!(aggregate_batches.len() > 1);
2577+
assert_eq!(
2578+
aggregate_batches
2579+
.iter()
2580+
.map(RecordBatch::num_rows)
2581+
.sum::<usize>(),
2582+
NUM_GROUPS
2583+
);
2584+
let aggregate_batch = concat_batches(&aggregate.schema(), &aggregate_batches)?;
2585+
let memory_limit = get_record_batch_memory_size(&aggregate_batch) * 4;
2586+
2587+
let probe_schema = Arc::new(Schema::new(vec![Field::new(
25992588
"probe_key",
26002589
DataType::UInt32,
26012590
false,
26022591
)]));
2603-
let batch = RecordBatch::try_new(
2604-
Arc::clone(&schema),
2592+
let probe_batch = RecordBatch::try_new(
2593+
Arc::clone(&probe_schema),
26052594
vec![Arc::new(UInt32Array::from(vec![
2606-
0,
2607-
(num_groups / 2) as u32,
2608-
(num_groups - 1) as u32,
2595+
aggregate_join_group_key(0),
2596+
aggregate_join_group_key(NUM_GROUPS / 2),
2597+
aggregate_join_group_key(NUM_GROUPS - 1),
26092598
]))],
26102599
)?;
2600+
let probe: Arc<dyn ExecutionPlan> = TestMemoryExec::try_new_exec(
2601+
&[vec![probe_batch]],
2602+
Arc::clone(&probe_schema),
2603+
None,
2604+
)?;
26112605

2612-
let exec: Arc<dyn ExecutionPlan> =
2613-
TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?;
2614-
2615-
Ok(exec)
2616-
}
2617-
2618-
async fn final_aggregate_peak_mem_used(
2619-
aggregate: &Arc<AggregateExec>,
2620-
num_groups: usize,
2621-
batch_size: usize,
2622-
) -> Result<usize> {
2623-
let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?;
2624-
let batches = common::collect(aggregate.execute(0, task_ctx)?).await?;
2625-
2626-
assert!(
2627-
batches.len() > 1,
2628-
"expected final aggregate output to be split into multiple batches"
2629-
);
2630-
assert_eq!(
2631-
batches.iter().map(RecordBatch::num_rows).sum::<usize>(),
2632-
num_groups
2633-
);
2634-
2635-
let metrics = aggregate.metrics().expect("aggregate metrics");
2636-
let peak_mem_used = metrics
2637-
.sum_by_name("peak_mem_used")
2638-
.expect("peak_mem_used metric")
2639-
.as_usize();
2640-
assert!(
2641-
peak_mem_used > 0,
2642-
"expected non-zero final aggregate peak memory"
2643-
);
2644-
2645-
Ok(peak_mem_used)
2646-
}
2647-
2648-
async fn run_aggregate_build_side_join(
2649-
aggregate: Arc<AggregateExec>,
2650-
num_groups: usize,
2651-
batch_size: usize,
2652-
memory_limit: usize,
2653-
) -> Result<Vec<RecordBatch>> {
26542606
let aggregate: Arc<dyn ExecutionPlan> = aggregate;
2655-
let right = probe_side(num_groups)?;
2656-
let on = vec![(
2657-
Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _,
2658-
Arc::new(Column::new_with_schema("probe_key", &right.schema())?) as _,
2659-
)];
26602607
let join = HashJoinExec::try_new(
2661-
aggregate,
2662-
right,
2663-
on,
2608+
Arc::clone(&aggregate),
2609+
probe,
2610+
vec![(
2611+
Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _,
2612+
Arc::new(Column::new_with_schema("probe_key", &probe_schema)?) as _,
2613+
)],
26642614
None,
26652615
&JoinType::Inner,
26662616
None,
26672617
PartitionMode::CollectLeft,
26682618
NullEquality::NullEqualsNothing,
26692619
false,
26702620
)?;
2671-
2672-
let task_ctx =
2673-
memory_limited_aggregate_join_task_ctx(batch_size, Some(memory_limit))?;
2674-
common::collect(join.execute(0, task_ctx)?).await
2675-
}
2676-
2677-
#[tokio::test]
2678-
async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> {
2679-
const BATCH_SIZE: usize = 8192;
2680-
const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1;
2681-
const EXPECTED_JOIN_ROWS: usize = 3;
2682-
2683-
let aggregate = final_aggregate_build_side(NUM_GROUPS, BATCH_SIZE).await?;
2684-
let aggregate_peak_mem_used =
2685-
final_aggregate_peak_mem_used(&aggregate, NUM_GROUPS, BATCH_SIZE).await?;
2686-
let memory_limit = aggregate_peak_mem_used * 2;
2687-
2688-
let batches = run_aggregate_build_side_join(
2689-
aggregate,
2690-
NUM_GROUPS,
2691-
BATCH_SIZE,
2692-
memory_limit,
2693-
)
2694-
.await?;
2621+
let runtime = RuntimeEnvBuilder::new()
2622+
.with_memory_limit(memory_limit, 1.0)
2623+
.build_arc()?;
2624+
let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
2625+
let batches = common::collect(join.execute(0, task_ctx)?).await?;
26952626
assert_eq!(
26962627
batches.iter().map(RecordBatch::num_rows).sum::<usize>(),
26972628
EXPECTED_JOIN_ROWS

0 commit comments

Comments
 (0)