Skip to content

Commit 697826b

Browse files
committed
feat: Enable the hash join to accept a pre-built hash table for joining
Signed-off-by: Yuan <[email protected]>
1 parent 430c3fa commit 697826b

File tree

11 files changed

+320
-107
lines changed

11 files changed

+320
-107
lines changed

velox/core/PlanNode.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,7 +3107,8 @@ class HashJoinNode : public AbstractJoinNode {
31073107
TypedExprPtr filter,
31083108
PlanNodePtr left,
31093109
PlanNodePtr right,
3110-
RowTypePtr outputType)
3110+
RowTypePtr outputType,
3111+
void* reusedHashTableAddress = nullptr)
31113112
: AbstractJoinNode(
31123113
id,
31133114
joinType,
@@ -3117,9 +3118,9 @@ class HashJoinNode : public AbstractJoinNode {
31173118
std::move(left),
31183119
std::move(right),
31193120
std::move(outputType)),
3120-
nullAware_{nullAware} {
3121+
nullAware_{nullAware},
3122+
reusedHashTableAddress_(reusedHashTableAddress) {
31213123
validate();
3122-
31233124
if (nullAware) {
31243125
VELOX_USER_CHECK(
31253126
isNullAwareSupported(joinType),
@@ -3202,6 +3203,10 @@ class HashJoinNode : public AbstractJoinNode {
32023203
return nullAware_;
32033204
}
32043205

3206+
/// Returns the untyped address that was passed as the constructor's
/// 'reusedHashTableAddress' argument, or nullptr (the default) if none was
/// supplied.
/// NOTE(review): nothing in the visible code manages the pointee's lifetime;
/// confirm that the caller keeps it alive for as long as this node is used.
void* reusedHashTableAddress() const {
3207+
return reusedHashTableAddress_;
3208+
}
3209+
32053210
folly::dynamic serialize() const override;
32063211

32073212
static PlanNodePtr create(const folly::dynamic& obj, void* context);
@@ -3210,6 +3215,8 @@ class HashJoinNode : public AbstractJoinNode {
32103215
void addDetails(std::stringstream& stream) const override;
32113216

32123217
const bool nullAware_;
3218+
3219+
void* reusedHashTableAddress_;
32133220
};
32143221

32153222
using HashJoinNodePtr = std::shared_ptr<const HashJoinNode>;

velox/exec/HashBuild.cpp

Lines changed: 77 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,46 +66,83 @@ HashBuild::HashBuild(
6666
joinBridge_(operatorCtx_->task()->getHashJoinBridgeLocked(
6767
operatorCtx_->driverCtx()->splitGroupId,
6868
planNodeId())),
69-
keyChannelMap_(joinNode_->rightKeys().size()) {
69+
keyChannelMap_(joinNode_->rightKeys().size()),
70+
reusedHashTableAddress_(joinNode_->reusedHashTableAddress()) {
7071
VELOX_CHECK(pool()->trackUsage());
7172
VELOX_CHECK_NOT_NULL(joinBridge_);
7273

7374
joinBridge_->addBuilder();
7475

75-
const auto& inputType = joinNode_->sources()[1]->outputType();
76-
77-
const auto numKeys = joinNode_->rightKeys().size();
78-
keyChannels_.reserve(numKeys);
79-
80-
for (int i = 0; i < numKeys; ++i) {
81-
auto& key = joinNode_->rightKeys()[i];
82-
auto channel = exprToChannel(key.get(), inputType);
83-
keyChannelMap_[channel] = i;
84-
keyChannels_.emplace_back(channel);
85-
}
86-
87-
// Identify the non-key build side columns and make a decoder for each.
88-
const int32_t numDependents = inputType->size() - numKeys;
89-
if (numDependents > 0) {
90-
// Number of join keys (numKeys) may be less then number of input columns
91-
// (inputType->size()). In this case numDependents is negative and cannot be
92-
// used to call 'reserve'. This happens when we join different probe side
93-
// keys with the same build side key: SELECT * FROM t LEFT JOIN u ON t.k1 =
94-
// u.k AND t.k2 = u.k.
95-
dependentChannels_.reserve(numDependents);
96-
decoders_.reserve(numDependents);
97-
}
98-
for (auto i = 0; i < inputType->size(); ++i) {
99-
if (keyChannelMap_.find(i) == keyChannelMap_.end()) {
100-
dependentChannels_.emplace_back(i);
101-
decoders_.emplace_back(std::make_unique<DecodedVector>());
76+
if (reusedHashTableAddress_ != nullptr) {
77+
auto baseHashTable =
78+
reinterpret_cast<exec::BaseHashTable*>(reusedHashTableAddress_);
79+
80+
VELOX_CHECK_NOT_NULL(joinBridge_);
81+
joinBridge_->start();
82+
83+
if (baseHashTable->joinHasNullKeys() && isAntiJoin(joinType_) &&
84+
nullAware_ && !joinNode_->filter()) {
85+
joinBridge_->setAntiJoinHasNullKeys();
86+
} else {
87+
baseHashTable->prepareJoinTable(
88+
{}, BaseHashTable::kNoSpillInputStartPartitionBit);
89+
90+
VELOX_CHECK_NOT_NULL(joinBridge_);
91+
std::unique_ptr<
92+
exec::BaseHashTable,
93+
std::function<void(exec::BaseHashTable*)>>
94+
hashTable(nullptr, [](exec::BaseHashTable* ptr) { /* Do nothing */ });
95+
96+
if (auto hasheTableWithNullKeys =
97+
dynamic_cast<exec::HashTable<true>*>(baseHashTable)) {
98+
hashTable.reset(hasheTableWithNullKeys);
99+
} else if (
100+
auto hasheTableWithoutNullKeys =
101+
dynamic_cast<exec::HashTable<false>*>(baseHashTable)) {
102+
hashTable.reset(hasheTableWithoutNullKeys);
103+
} else {
104+
VELOX_UNREACHABLE("Unexpected HashTable {}", baseHashTable->toString());
105+
}
106+
joinBridge_->setHashTable(
107+
std::move(hashTable), hashTable->joinHasNullKeys());
102108
}
103-
}
104109

105-
tableType_ = hashJoinTableType(joinNode_);
106-
setupTable();
107-
setupSpiller();
108-
stateCleared_ = false;
110+
} else {
111+
const auto& inputType = joinNode_->sources()[1]->outputType();
112+
113+
const auto numKeys = joinNode_->rightKeys().size();
114+
keyChannels_.reserve(numKeys);
115+
116+
for (int i = 0; i < numKeys; ++i) {
117+
auto& key = joinNode_->rightKeys()[i];
118+
auto channel = exprToChannel(key.get(), inputType);
119+
keyChannelMap_[channel] = i;
120+
keyChannels_.emplace_back(channel);
121+
}
122+
123+
// Identify the non-key build side columns and make a decoder for each.
124+
const int32_t numDependents = inputType->size() - numKeys;
125+
if (numDependents > 0) {
126+
// Number of join keys (numKeys) may be greater than the number of input columns
127+
// (inputType->size()). In this case numDependents is negative and cannot
128+
// be used to call 'reserve'. This happens when we join different probe
129+
// side keys with the same build side key: SELECT * FROM t LEFT JOIN u ON
130+
// t.k1 = u.k AND t.k2 = u.k.
131+
dependentChannels_.reserve(numDependents);
132+
decoders_.reserve(numDependents);
133+
}
134+
for (auto i = 0; i < inputType->size(); ++i) {
135+
if (keyChannelMap_.find(i) == keyChannelMap_.end()) {
136+
dependentChannels_.emplace_back(i);
137+
decoders_.emplace_back(std::make_unique<DecodedVector>());
138+
}
139+
}
140+
141+
tableType_ = hashJoinTableType(joinNode_);
142+
setupTable();
143+
setupSpiller();
144+
stateCleared_ = false;
145+
}
109146
}
110147

111148
void HashBuild::initialize() {
@@ -622,6 +659,10 @@ void HashBuild::spillPartition(
622659
}
623660

624661
void HashBuild::noMoreInput() {
662+
if (reusedHashTableAddress_ != nullptr) {
663+
return;
664+
}
665+
625666
checkRunning();
626667

627668
if (noMoreInput_) {
@@ -995,6 +1036,9 @@ BlockingReason HashBuild::isBlocked(ContinueFuture* future) {
9951036
}
9961037

9971038
// Reports whether this operator is finished. When a pre-built hash table
// address was supplied ('reusedHashTableAddress_' is non-null) this always
// returns true, regardless of 'state_'. Otherwise the operator is finished
// once 'state_' equals State::kFinish.
bool HashBuild::isFinished() {
1039+
if (reusedHashTableAddress_ != nullptr) {
1040+
return true;
1041+
}
9981042
return state_ == State::kFinish;
9991043
}
10001044

velox/exec/HashBuild.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class HashBuild final : public Operator {
6868
}
6969

7070
// Returns false when a pre-built hash table address was supplied
// ('reusedHashTableAddress_' is non-null). Otherwise returns true until
// 'noMoreInput_' has been set.
bool needsInput() const override {
71+
if (reusedHashTableAddress_ != nullptr) {
72+
return false;
73+
}
7174
return !noMoreInput_;
7275
}
7376

@@ -310,6 +313,8 @@ class HashBuild final : public Operator {
310313

311314
// Maps key channel in 'input_' to channel in key.
312315
folly::F14FastMap<column_index_t, column_index_t> keyChannelMap_;
316+
317+
void* reusedHashTableAddress_;
313318
};
314319

315320
inline std::ostream& operator<<(std::ostream& os, HashBuild::State state) {

velox/exec/HashJoinBridge.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,24 @@ void HashJoinBridge::setHashTable(
256256
notify(std::move(promises));
257257
}
258258

259+
// Overload of setHashTable() that takes a BaseHashTable directly. Stores
// 'table' together with 'hasNullKeys' as the build result and then notifies
// the promises that were pending on this bridge.
//
// Fails a check if 'table' is null, if the bridge has not been started, if a
// build result has already been set, or if 'restoringSpillShards_' is not
// empty.
void HashJoinBridge::setHashTable(
260+
std::shared_ptr<BaseHashTable> table,
261+
bool hasNullKeys) {
262+
VELOX_CHECK_NOT_NULL(table, "setHashTable called with null table");
263+
264+
std::vector<ContinuePromise> promises;
265+
{
266+
// The build result is published and the pending promises are taken while
// 'mutex_' is held; notify() below runs after the lock has been released.
std::lock_guard<std::mutex> l(mutex_);
267+
VELOX_CHECK(started_);
268+
VELOX_CHECK(!buildResult_.has_value());
269+
VELOX_CHECK(restoringSpillShards_.empty());
270+
271+
buildResult_ = HashBuildResult(std::move(table), hasNullKeys);
272+
promises = std::move(promises_);
273+
}
274+
notify(std::move(promises));
275+
}
276+
259277
void HashJoinBridge::appendSpilledHashTablePartitions(
260278
SpillPartitionSet spillPartitionSet) {
261279
VELOX_CHECK(

velox/exec/HashJoinBridge.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class HashJoinBridge : public JoinBridge {
6363
std::shared_ptr<wave::HashTableHolder> table,
6464
bool hasNullKeys);
6565

66+
void setHashTable(std::shared_ptr<BaseHashTable> table, bool hasNullKeys);
67+
6668
/// Invoked by the probe operator to append the spilled hash table partitions
6769
/// while probing. The function appends the spilled table partitions into
6870
/// 'spillPartitionSets_' stack. This only applies if the disk spilling is
@@ -93,6 +95,9 @@ class HashJoinBridge : public JoinBridge {
9395
bool _hasNullKeys)
9496
: hasNullKeys(_hasNullKeys), waveTable(std::move(_table)) {}
9597

98+
// Constructs a build result from a BaseHashTable: 'table' takes over
// '_table' and 'hasNullKeys' records '_hasNullKeys'.
HashBuildResult(std::shared_ptr<BaseHashTable> _table, bool _hasNullKeys)
99+
: hasNullKeys(_hasNullKeys), table(std::move(_table)) {}
100+
96101
bool hasNullKeys;
97102
std::shared_ptr<BaseHashTable> table;
98103

velox/exec/HashProbe.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ void HashProbe::prepareForSpillRestore() {
484484

485485
// Reset the internal states which are relevant to the previous probe run.
486486
noMoreSpillInput_ = false;
487-
if (lastProber_) {
487+
if (lastProber_ && !table_->reused()) {
488488
table_->clear(true);
489489
}
490490
table_.reset();
@@ -1071,7 +1071,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
10711071
}
10721072
} else {
10731073
joinBridge_->probeFinished();
1074-
if (table_ != nullptr) {
1074+
if (table_ != nullptr && !table_->reused()) {
10751075
table_->clear(true);
10761076
}
10771077
}
@@ -1209,7 +1209,9 @@ bool HashProbe::maybeReadSpillOutput() {
12091209
return false;
12101210
}
12111211

1212-
VELOX_DCHECK_EQ(table_->numDistinct(), 0);
1212+
if (!table_->reused()) {
1213+
VELOX_DCHECK_EQ(table_->numDistinct(), 0);
1214+
}
12131215

12141216
if (!spillOutputReader_->nextBatch(output_)) {
12151217
spillOutputReader_.reset();
@@ -1885,18 +1887,21 @@ void HashProbe::reclaim(
18851887
spillOutput(probeOps);
18861888

18871889
SpillPartitionSet spillPartitionSet;
1888-
if (hasMoreProbeInput) {
1889-
// Only spill hash table if any hash probe operators still has input probe
1890-
// data, otherwise we skip this step.
1891-
spillPartitionSet = spillHashJoinTable(
1892-
table_,
1893-
restoringPartitionId_,
1894-
tableSpillHashBits_,
1895-
joinNode_,
1896-
spillConfig(),
1897-
spillStats_.get());
1898-
VELOX_CHECK(!spillPartitionSet.empty());
1890+
if (!table_->reused()) {
1891+
if (hasMoreProbeInput) {
1892+
// Only spill hash table if any hash probe operators still has input probe
1893+
// data, otherwise we skip this step.
1894+
spillPartitionSet = spillHashJoinTable(
1895+
table_,
1896+
restoringPartitionId_,
1897+
tableSpillHashBits_,
1898+
joinNode_,
1899+
spillConfig(),
1900+
spillStats_.get());
1901+
VELOX_CHECK(!spillPartitionSet.empty());
1902+
}
18991903
}
1904+
19001905
const auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);
19011906

19021907
for (auto* probeOp : probeOps) {
@@ -1911,12 +1916,15 @@ void HashProbe::reclaim(
19111916
probeOp->pool()->release();
19121917
}
19131918

1914-
// Clears memory resources held by the built hash table.
1915-
table_->clear(true);
1919+
if (!table_->reused()) {
1920+
// Clears memory resources held by the built hash table.
1921+
table_->clear(true);
19161922

1917-
// Sets the spilled hash table in the join bridge.
1918-
if (!spillPartitionIdSet.empty()) {
1919-
joinBridge_->appendSpilledHashTablePartitions(std::move(spillPartitionSet));
1923+
// Sets the spilled hash table in the join bridge.
1924+
if (!spillPartitionIdSet.empty()) {
1925+
joinBridge_->appendSpilledHashTablePartitions(
1926+
std::move(spillPartitionSet));
1927+
}
19201928
}
19211929
}
19221930

0 commit comments

Comments
 (0)