Skip to content

Commit 7f9556a

Browse files
committed
Reused Hash Table in HashBuild Phase for Gluten
1 parent 6280634 commit 7f9556a

File tree

5 files changed

+170
-5
lines changed

5 files changed

+170
-5
lines changed

velox/core/PlanNode.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ struct ArrowArrayStream;
2929

3030
namespace facebook::velox::core {
3131

32+
/// Referencing the exec hash table type directly from HashJoinNode would
33+
/// introduce a dependency cycle. OpaqueHashTable is only forward-declared
34+
/// here so that a plan node can carry a pre-built hash table without knowing
35+
/// its concrete type; the consumer casts it back to the table implementation.
36+
struct OpaqueHashTable;
37+
3238
class PlanNodeVisitor;
3339
class PlanNodeVisitorContext;
3440

@@ -3148,7 +3154,9 @@ class HashJoinNode : public AbstractJoinNode {
31483154
PlanNodePtr left,
31493155
PlanNodePtr right,
31503156
RowTypePtr outputType,
3151-
bool useHashTableCache = false)
3157+
bool useHashTableCache = false,
3158+
bool joinHasNullKeys = false,
3159+
std::shared_ptr<OpaqueHashTable> reusableHashTable = nullptr)
31523160
: AbstractJoinNode(
31533161
id,
31543162
joinType,
@@ -3159,7 +3167,9 @@ class HashJoinNode : public AbstractJoinNode {
31593167
std::move(right),
31603168
std::move(outputType)),
31613169
nullAware_{nullAware},
3162-
useHashTableCache_{useHashTableCache} {
3170+
useHashTableCache_{useHashTableCache},
3171+
joinHasNullKeys_{joinHasNullKeys},
3172+
reusableHashTable_(std::move(reusableHashTable)) {
31633173
validate();
31643174

31653175
if (nullAware) {
@@ -3197,6 +3207,17 @@ class HashJoinNode : public AbstractJoinNode {
31973207
return *this;
31983208
}
31993209

3210+
/// Records whether the join has null keys. build() forwards the value to
/// the HashJoinNode constructor and uses false when this was never set.
Builder& joinHasNullKeys(bool joinHasNullKeys) {
  joinHasNullKeys_.emplace(joinHasNullKeys);
  return *this;
}
3214+
3215+
/// Attaches a pre-built, type-erased hash table to the node being built.
/// build() forwards it to the HashJoinNode constructor and uses nullptr
/// when this was never set.
Builder& reusableHashTable(
    std::shared_ptr<OpaqueHashTable> opaqueHashTable) {
  reusableHashTable_.emplace(std::move(opaqueHashTable));
  return *this;
}
3220+
32003221
std::shared_ptr<HashJoinNode> build() const {
32013222
VELOX_USER_CHECK(id_.has_value(), "HashJoinNode id is not set");
32023223
VELOX_USER_CHECK(
@@ -3224,12 +3245,16 @@ class HashJoinNode : public AbstractJoinNode {
32243245
left_.value(),
32253246
right_.value(),
32263247
outputType_.value(),
3227-
useHashTableCache_.value_or(false));
3248+
useHashTableCache_.value_or(false),
3249+
joinHasNullKeys_.value_or(false),
3250+
reusableHashTable_.value_or(nullptr));
32283251
}
32293252

32303253
private:
32313254
std::optional<bool> nullAware_;
32323255
std::optional<bool> useHashTableCache_;
3256+
std::optional<bool> joinHasNullKeys_;
3257+
std::optional<std::shared_ptr<OpaqueHashTable>> reusableHashTable_;
32333258
};
32343259

32353260
std::string_view name() const override {
@@ -3258,6 +3283,14 @@ class HashJoinNode : public AbstractJoinNode {
32583283
return useHashTableCache_;
32593284
}
32603285

3286+
bool joinHasNullKeys() const {
3287+
return joinHasNullKeys_;
3288+
}
3289+
3290+
std::shared_ptr<OpaqueHashTable> reusableHashTable() const {
3291+
return reusableHashTable_;
3292+
}
3293+
32613294
folly::dynamic serialize() const override;
32623295

32633296
static PlanNodePtr create(const folly::dynamic& obj, void* context);
@@ -3267,6 +3300,8 @@ class HashJoinNode : public AbstractJoinNode {
32673300

32683301
const bool nullAware_;
32693302
const bool useHashTableCache_;
3303+
const bool joinHasNullKeys_;
3304+
std::shared_ptr<OpaqueHashTable> reusableHashTable_;
32703305
};
32713306

32723307
using HashJoinNodePtr = std::shared_ptr<const HashJoinNode>;

velox/exec/HashBuild.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,25 @@ BlockingReason fromStateToBlockingReason(HashBuild::State state) {
4747
}
4848
} // namespace
4949

50+
void HashBuild::setReusableHashTable(
51+
std::shared_ptr<core::OpaqueHashTable> opaqueHashTable) {
52+
joinBridge_->start();
53+
setState(State::kFinish);
54+
55+
if (joinNode_->joinHasNullKeys() && isAntiJoin(joinType_) && nullAware_ &&
56+
!joinNode_->filter()) {
57+
joinBridge_->setAntiJoinHasNullKeys();
58+
return;
59+
}
60+
61+
auto reusableHashTable =
62+
std::reinterpret_pointer_cast<exec::BaseHashTable>(opaqueHashTable);
63+
64+
joinBridge_->setHashTable(
65+
std::move(reusableHashTable), {}, joinNode_->joinHasNullKeys(), nullptr);
66+
reuseHashTable_ = true;
67+
}
68+
5069
HashBuild::HashBuild(
5170
int32_t operatorId,
5271
DriverCtx* driverCtx,
@@ -80,6 +99,12 @@ HashBuild::HashBuild(
8099

81100
joinBridge_->addBuilder();
82101

102+
if (auto opaqueHashTable = joinNode_->reusableHashTable()) {
103+
TestValue::adjust("facebook::velox::exec::HashBuild::HashBuild", this);
104+
setReusableHashTable(opaqueHashTable);
105+
return;
106+
}
107+
83108
const auto& inputType = joinNode_->sources()[1]->outputType();
84109

85110
const auto numKeys = joinNode_->rightKeys().size();
@@ -121,7 +146,7 @@ HashBuild::HashBuild(
121146
void HashBuild::initialize() {
122147
Operator::initialize();
123148

124-
if (setupCachedHashTable()) {
149+
if (setupCachedHashTable() || reuseHashTable_) {
125150
return;
126151
}
127152

velox/exec/HashBuild.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ class HashBuild final : public Operator {
108108
// Invoked to set up hash table to build.
109109
void setupTable();
110110

111+
// Publishes the pre-built hash table carried by the join node to the join bridge instead of building one from the build-side input.
112+
void setReusableHashTable(
113+
std::shared_ptr<core::OpaqueHashTable> opaqueHashTable);
114+
111115
// Sets up hash table caching if enabled. Returns true if the cached table
112116
// is already available or if this operator should wait for another task
113117
// to build it, in which case further initialization should be skipped.
@@ -395,6 +399,8 @@ class HashBuild final : public Operator {
395399
// Count the number of hash table input rows for building deduped
396400
// hash table. It will not be updated after abandonBuildNoDupHash_ is true.
397401
int64_t numHashInputRows_ = 0;
402+
403+
bool reuseHashTable_ = false;
398404
};
399405

400406
inline std::ostream& operator<<(std::ostream& os, HashBuild::State state) {

velox/exec/HashProbe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,8 @@ bool HashProbe::canSpill() const {
986986
// Hash table caching is incompatible with spilling. When the table is
987987
// cached and shared across tasks, clearing it after probe would corrupt
988988
// the cache for subsequent tasks.
989-
if (joinNode_->useHashTableCache()) {
989+
if (joinNode_->useHashTableCache() ||
990+
joinNode_->reusableHashTable() != nullptr) {
990991
return false;
991992
}
992993
if (operatorCtx_->task()->hasMixedExecutionGroupJoin(joinNode_.get())) {

velox/exec/tests/HashJoinTest.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8529,6 +8529,104 @@ TEST_P(HashJoinTest, arrayBasedLookupCustomComparisonType) {
85298529
EXPECT_EQ(result->size(), 1'024);
85308530
}
85318531

8532+
// Verifies that a hash table built outside of the HashBuild operator and
// attached to the HashJoinNode is picked up by HashBuild: the test value hook
// below is only reached on the constructor path that handles a reusable
// table, and the join result is checked against DuckDB.
DEBUG_ONLY_TEST_P(HashJoinTest, reuseHashTable) {
  // Create build and probe vectors. Both sides use the same key distribution.
  std::vector<RowVectorPtr> buildVectors = makeBatches(1, [&](int32_t) {
    return makeRowVector(
        {"u_0"},
        {
            makeFlatVector<int64_t>(100, [](auto row) { return row % 23; }),
        });
  });

  std::vector<RowVectorPtr> probeVectors = makeBatches(5, [&](int32_t) {
    return makeRowVector(
        {"t_0"},
        {
            makeFlatVector<int64_t>(100, [](auto row) { return row % 23; }),
        });
  });

  createDuckDbTable("t", probeVectors);
  createDuckDbTable("u", buildVectors);

  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
  auto buildPlanNode =
      PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode();
  auto probePlanNode =
      PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode();

  auto outputType = ROW({"t_0", "u_0"}, {BIGINT(), BIGINT()});

  // Build the hash table for the build side data by hand.
  std::vector<std::unique_ptr<VectorHasher>> hashers;
  hashers.push_back(std::make_unique<VectorHasher>(BIGINT(), 0));

  std::shared_ptr<facebook::velox::exec::BaseHashTable> table =
      HashTable<false>::createForJoin(
          std::move(hashers),
          {} /*dependentTypes*/,
          true /*allowDuplicates*/,
          true /*hasProbedFlag*/,
          1 /*minTableSizeForParallelJoinBuild*/,
          pool());

  // Copy every build row into the table's row container.
  auto* rowContainer = table->rows();
  const uint32_t numColumns = buildVectors[0]->childrenSize();
  std::vector<DecodedVector> decodedVectors;
  decodedVectors.reserve(numColumns);
  for (const auto& rowVector : buildVectors) {
    if (!rowVector || rowVector->size() == 0) {
      continue;
    }

    decodedVectors.clear();
    SelectivityVector rows(rowVector->size());
    for (auto& child : rowVector->children()) {
      decodedVectors.emplace_back(*child, rows);
    }

    for (auto i = 0; i < rowVector->size(); ++i) {
      auto* row = rowContainer->newRow();
      // Index type matches 'numColumns' to avoid a signed/unsigned compare.
      for (uint32_t j = 0; j < numColumns; ++j) {
        rowContainer->store(decodedVectors[j], i, row, j);
      }
    }
  }

  table->prepareJoinTable(
      {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000);

  // Type-erase the table while sharing ownership with 'table'. HashBuild
  // casts it back to exec::BaseHashTable.
  auto opaqueSharedHashTable = std::shared_ptr<core::OpaqueHashTable>(
      table, reinterpret_cast<core::OpaqueHashTable*>(table.get()));

  // The key expressions are typed BIGINT to match the t_0 / u_0 columns
  // created above.
  auto joinNode =
      core::HashJoinNode::Builder()
          .id(planNodeIdGenerator->next())
          .joinType(core::JoinType::kInner)
          .nullAware(false)
          .leftKeys(
              {std::make_shared<core::FieldAccessTypedExpr>(BIGINT(), "t_0")})
          .rightKeys(
              {std::make_shared<core::FieldAccessTypedExpr>(BIGINT(), "u_0")})
          .left(probePlanNode)
          .right(buildPlanNode)
          .outputType(outputType)
          .reusableHashTable(opaqueSharedHashTable)
          .build();

  // Set by the hook when HashBuild takes the reusable-table path.
  std::atomic_bool reusedHashTable{false};
  SCOPED_TESTVALUE_SET(
      "facebook::velox::exec::HashBuild::HashBuild",
      std::function<void(HashBuild*)>(
          [&](HashBuild* /*hashBuild*/) { reusedHashTable.store(true); }));

  AssertQueryBuilder(joinNode, duckDbQueryRunner_)
      .maxDrivers(1)
      .assertResults("SELECT t.t_0, u.u_0 FROM t, u WHERE t.t_0 = u.u_0");
  ASSERT_TRUE(reusedHashTable.load());
}
8629+
85328630
DEBUG_ONLY_TEST_P(
85338631
HashJoinTest,
85348632
hashProbeShouldYieldWhenFilterConsistentlyRejectAll) {

0 commit comments

Comments
 (0)