Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 130 additions & 13 deletions src/daft-local-execution/src/join/inner_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,66 @@ use daft_recordbatch::{GrowableRecordBatch, ProbeState};

use crate::join::hash_join::HashJoinParams;

/// Third argument passed to `GrowableRecordBatch::new` when build-side rows are gathered
/// one at a time (presumably an initial capacity hint; confirm against `GrowableRecordBatch`).
const DEFAULT_GROWABLE_SIZE: usize = 20;
/// `should_use_vectorized_take` rejects the vectorized path when a probe batch produced
/// fewer matches than this.
const MIN_MATCHES_FOR_VECTORIZED_TAKE: usize = 1024;
/// `should_use_vectorized_take` requires at least this many matches per matched probe row.
const MIN_FANOUT_FOR_VECTORIZED_TAKE: usize = 4;
/// `should_use_vectorized_take` requires at least this many matches per build-table run.
const MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE: usize = 8;

/// One join match, stored as `(probe_row_idx, build_rb_idx, build_row_idx)`.
type BuildMatch = (u64, usize, u64);

/// Summary counters collected while probing one input batch.
///
/// `should_use_vectorized_take` reads these to choose how build-side rows are materialized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatchStats {
    /// Number of probe rows that produced at least one match.
    matched_probe_rows: usize,
    /// Number of runs of consecutive matches that reference the same build-side record batch.
    build_table_runs: usize,
}

pub(crate) fn probe_inner(
input: &MicroPartition,
probe_state: &ProbeState,
params: &HashJoinParams,
) -> DaftResult<MicroPartition> {
let build_side_tables = probe_state.get_record_batches().iter().collect::<Vec<_>>();
const DEFAULT_GROWABLE_SIZE: usize = 20;

let input_tables = input.record_batches();
let result_tables = input_tables
.iter()
.map(|input_table| {
let mut build_side_growable =
GrowableRecordBatch::new(&build_side_tables, false, DEFAULT_GROWABLE_SIZE)?;
let mut probe_side_idxs = Vec::new();

let join_keys = input_table.eval_expression_list(&params.probe_on)?;
let idx_iter = probe_state.probe_indices(join_keys)?;
let mut matches = Vec::new();
let mut matched_probe_rows = 0;
let mut build_table_runs = 0;
let mut previous_build_table = None;
for (probe_row_idx, inner_iter) in idx_iter.enumerate() {
let probe_matches_start = matches.len();
if let Some(inner_iter) = inner_iter {
for (build_rb_idx, build_row_idx) in inner_iter {
build_side_growable.extend(
build_rb_idx as usize,
build_row_idx as usize,
1,
);
probe_side_idxs.push(probe_row_idx as u64);
let build_rb_idx = build_rb_idx as usize;
if previous_build_table != Some(build_rb_idx) {
build_table_runs += 1;
previous_build_table = Some(build_rb_idx);
}
matches.push((probe_row_idx as u64, build_rb_idx, build_row_idx));
}
}
if matches.len() > probe_matches_start {
matched_probe_rows += 1;
}
}
let match_stats = MatchStats {
matched_probe_rows,
build_table_runs,
};

let build_side_table = build_side_growable.build()?;
let build_side_table =
build_side_for_inner_matches(&build_side_tables, &matches, &match_stats)?;
let probe_side_table = {
let indices_arr = UInt64Array::from_vec("", probe_side_idxs);
let indices_arr = UInt64Array::from_vec(
"",
matches
.iter()
.map(|(probe_row_idx, _, _)| *probe_row_idx)
.collect::<Vec<_>>(),
);
input_table.take(&indices_arr)?
};

Expand Down Expand Up @@ -83,3 +109,94 @@ pub(crate) fn probe_inner(
None,
))
}

/// Materializes the build-side rows referenced by `matches`, in match order.
///
/// Delegates to `build_side_with_vectorized_take` when `should_use_vectorized_take`
/// approves the match pattern, and to `build_side_with_growable` otherwise.
fn build_side_for_inner_matches(
    build_side_tables: &[&daft_recordbatch::RecordBatch],
    matches: &[BuildMatch],
    match_stats: &MatchStats,
) -> DaftResult<daft_recordbatch::RecordBatch> {
    // Row-at-a-time gathering is the fallback whenever the heuristic declines.
    if !should_use_vectorized_take(matches, match_stats) {
        return build_side_with_growable(build_side_tables, matches);
    }

    build_side_with_vectorized_take(build_side_tables, matches, match_stats)
}

/// Decides whether the build side should be gathered with per-run `take` calls.
///
/// Returns `true` only when all of the following hold:
/// - there are at least `MIN_MATCHES_FOR_VECTORIZED_TAKE` matches and at least one
///   matched probe row;
/// - the match count is at least `matched_probe_rows * MIN_FANOUT_FOR_VECTORIZED_TAKE`;
/// - the match count is at least `build_table_runs * MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE`.
///
/// Both products saturate instead of overflowing.
fn should_use_vectorized_take(matches: &[BuildMatch], match_stats: &MatchStats) -> bool {
    let total_matches = matches.len();
    let probe_rows = match_stats.matched_probe_rows;

    let has_enough_matches =
        total_matches >= MIN_MATCHES_FOR_VECTORIZED_TAKE && probe_rows != 0;

    let fanout_floor = probe_rows.saturating_mul(MIN_FANOUT_FOR_VECTORIZED_TAKE);
    let run_len_floor = match_stats
        .build_table_runs
        .saturating_mul(MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE);

    has_enough_matches && total_matches >= fanout_floor && total_matches >= run_len_floor
}

/// Gathers the build-side rows for `matches` one row at a time through a
/// `GrowableRecordBatch`, preserving match order.
///
/// The probe row index stored in each match is ignored here.
fn build_side_with_growable(
    build_side_tables: &[&daft_recordbatch::RecordBatch],
    matches: &[BuildMatch],
) -> DaftResult<daft_recordbatch::RecordBatch> {
    let mut growable =
        GrowableRecordBatch::new(build_side_tables, false, DEFAULT_GROWABLE_SIZE)?;

    for &(_, table_idx, row_idx) in matches {
        // Append exactly one row (length 1) starting at `row_idx` of table `table_idx`.
        growable.extend(table_idx, row_idx as usize, 1);
    }

    growable.build()
}

/// Gathers the build-side rows for `matches` by splitting them into runs of consecutive
/// matches against the same build table, issuing one `take` per run via `push_taken_run`,
/// and concatenating the per-run results in order.
///
/// An empty `matches` slice is delegated to `build_side_with_growable`.
fn build_side_with_vectorized_take(
    build_side_tables: &[&daft_recordbatch::RecordBatch],
    matches: &[BuildMatch],
    match_stats: &MatchStats,
) -> DaftResult<daft_recordbatch::RecordBatch> {
    // The first match seeds the current run; with no matches there is nothing to take.
    let mut run_table_idx = match matches.first() {
        Some(&(_, table_idx, _)) => table_idx,
        None => return build_side_with_growable(build_side_tables, matches),
    };

    let mut pending_row_idxs = Vec::new();
    // One output batch is produced per run, so the run count is the capacity needed.
    let mut run_batches = Vec::with_capacity(match_stats.build_table_runs);

    for &(_, table_idx, row_idx) in matches {
        if table_idx != run_table_idx {
            // The build table changed: flush the rows collected for the previous run.
            push_taken_run(
                build_side_tables,
                &mut run_batches,
                run_table_idx,
                &mut pending_row_idxs,
            )?;
            run_table_idx = table_idx;
        }
        pending_row_idxs.push(row_idx);
    }

    // Flush the final run, which the loop never closes on its own.
    push_taken_run(
        build_side_tables,
        &mut run_batches,
        run_table_idx,
        &mut pending_row_idxs,
    )?;

    daft_recordbatch::RecordBatch::concat(run_batches)
}

/// Takes the rows listed in `run_row_idxs` from `build_side_tables[build_table_idx]` and
/// appends the resulting batch to `taken_tables`.
///
/// `run_row_idxs` is emptied by this call (its contents are moved into the index array),
/// so the caller can keep pushing into it for the next run.
fn push_taken_run(
    build_side_tables: &[&daft_recordbatch::RecordBatch],
    taken_tables: &mut Vec<daft_recordbatch::RecordBatch>,
    build_table_idx: usize,
    run_row_idxs: &mut Vec<u64>,
) -> DaftResult<()> {
    // Move the buffered indices out, leaving an empty vector behind.
    let drained_row_idxs = std::mem::take(run_row_idxs);
    let row_selector = UInt64Array::from_vec("", drained_row_idxs);

    let taken_run = build_side_tables[build_table_idx].take(&row_selector)?;
    taken_tables.push(taken_run);
    Ok(())
}
33 changes: 33 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,39 @@ def test_joins_all_same_key(join_strategy, join_type, make_df, n_partitions: int
}


@pytest.mark.parametrize("n_partitions", [1, 4])
def test_inner_join_high_build_side_fanout(make_df, n_partitions: int, with_default_morsel_size):
    """Inner hash join where each of 64 left keys matches 32 right-side rows."""
    num_keys = 64
    fanout = 32

    # Every (key, idx) combination on the right side, ordered by key and then idx.
    pairs = [(key, idx) for key in range(num_keys) for idx in range(fanout)]

    left_columns = {
        "A": list(range(num_keys)),
        "left_payload": [f"left-{i}" for i in range(num_keys)],
    }
    right_columns = {
        "A": [key for key, _ in pairs],
        "right_payload": [f"right-{key}-{idx:02d}" for key, idx in pairs],
    }

    left = make_df(left_columns, repartition=n_partitions, repartition_columns=["A"])
    right = make_df(right_columns, repartition=n_partitions, repartition_columns=["A"])

    joined = left.join(right, on="A", strategy="hash", how="inner").sort(["A", "right_payload"])
    joined_data = joined.to_pydict()

    # The zero-padded idx makes the string sort on right_payload agree with numeric idx order.
    expected_keys = [key for key, _ in pairs]
    expected_left_payload = [f"left-{key}" for key, _ in pairs]
    expected_right_payload = [f"right-{key}-{idx:02d}" for key, idx in pairs]

    assert joined_data["A"] == expected_keys
    assert joined_data["left_payload"] == expected_left_payload
    assert joined_data["right_payload"] == expected_right_payload


@pytest.mark.parametrize("n_partitions", get_n_partitions())
@pytest.mark.parametrize(
"join_strategy",
Expand Down
40 changes: 40 additions & 0 deletions tests/microbenchmarks/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,46 @@ def bench_join() -> DataFrame:
assert data[-big_factor:] == [str(small_length - 1)] * big_factor


@pytest.mark.benchmark(group="joins")
@pytest.mark.parametrize("num_partitions", [1, 10], ids=["1part", "10part"])
def test_inner_join_high_fanout(benchmark, num_partitions) -> None:
    """Benchmark an inner join in which every left key matches `fanout` right-side rows."""
    small_length = 1_000
    fanout = 32

    # Left keys are shuffled before the right keys, so the random draws happen in this order.
    left_arr = np.arange(small_length)
    np.random.shuffle(left_arr)
    right_arr = np.repeat(np.arange(small_length), fanout)
    np.random.shuffle(right_arr)

    left_table = daft.from_pydict({"keys": left_arr}).into_partitions(num_partitions).collect()

    right_columns = {
        "keys": right_arr,
        "right_payload": list(map(str, right_arr)),
    }
    right_table = daft.from_pydict(right_columns).collect()

    def bench_join() -> DataFrame:
        return left_table.join(right_table, on=["keys"], how="inner").collect()

    result = benchmark(bench_join)

    # Each of the `small_length` keys must contribute exactly `fanout` joined rows.
    assert len(result) == small_length * fanout
    counts_by_key = result.groupby("keys").agg(col("right_payload").count()).sort("keys").to_pydict()
    assert counts_by_key["right_payload"] == [fanout] * small_length


@pytest.mark.benchmark(group="joins")
@pytest.mark.parametrize(
"num_samples, num_partitions",
Expand Down
Loading