Skip to content

Commit 89a77b6

Browse files
authored
[GLUTEN-11888] [VL] Parallel build hash table to improve bhj performance (#11889)
1 parent 9b268b5 commit 89a77b6

File tree

10 files changed

+84
-45
lines changed

10 files changed

+84
-45
lines changed

backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
480480
mode: BroadcastMode,
481481
child: SparkPlan,
482482
numOutputRows: SQLMetric,
483-
dataSize: SQLMetric): BuildSideRelation = {
483+
dataSize: SQLMetric,
484+
buildThreads: SQLMetric): BuildSideRelation = {
484485

485486
val (buildKeys, isNullAware) = mode match {
486487
case mode1: HashedRelationBroadcastMode =>

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,8 @@ class VeloxMetricsApi extends MetricsApi with Logging {
554554
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
555555
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
556556
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
557-
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")
557+
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"),
558+
"buildThreads" -> SQLMetrics.createMetric(sparkContext, "build threads")
558559
)
559560

560561
override def genColumnarSubqueryBroadcastMetrics(
@@ -667,7 +668,10 @@ class VeloxMetricsApi extends MetricsApi with Logging {
667668
"numOutputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of output bytes"),
668669
"loadLazyVectorTime" -> SQLMetrics.createNanoTimingMetric(
669670
sparkContext,
670-
"time of loading lazy vectors")
671+
"time of loading lazy vectors"),
672+
"buildHashTableTime" -> SQLMetrics.createTimingMetric(
673+
sparkContext,
674+
"time to build hash table")
671675
)
672676

673677
override def genHashJoinTransformerMetricsUpdater(

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
706706
mode: BroadcastMode,
707707
child: SparkPlan,
708708
numOutputRows: SQLMetric,
709-
dataSize: SQLMetric): BuildSideRelation = {
709+
dataSize: SQLMetric,
710+
buildThreads: SQLMetric): BuildSideRelation = {
710711

711712
val buildKeys = mode match {
712713
case mode1: HashedRelationBroadcastMode =>
@@ -851,22 +852,31 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
851852
numOutputRows += serialized.map(_.numRows).sum
852853
dataSize += rawSize
853854

855+
val rawThreads =
856+
math
857+
.ceil(dataSize.value.toDouble / VeloxConfig.get.veloxBroadcastHashTableBuildTargetBytes)
858+
.toInt
859+
val buildThreadsValue = if (rawThreads < 1) 1 else rawThreads
860+
buildThreads += buildThreadsValue
861+
854862
if (useOffheapBroadcastBuildRelation) {
855863
TaskResources.runUnsafe {
856864
UnsafeColumnarBuildSideRelation(
857865
newOutput,
858866
serialized.flatMap(_.offHeapData().asScala),
859867
mode,
860868
newBuildKeys,
861-
offload)
869+
offload,
870+
buildThreadsValue)
862871
}
863872
} else {
864873
ColumnarBuildSideRelation(
865874
newOutput,
866875
serialized.flatMap(_.onHeapData().asScala).toArray,
867876
mode,
868877
newBuildKeys,
869-
offload)
878+
offload,
879+
buildThreadsValue)
870880
}
871881
}
872882

backends-velox/src/main/scala/org/apache/gluten/config/VeloxConfig.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
6262
def enableBroadcastBuildOncePerExecutor: Boolean =
6363
getConf(VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR)
6464

65-
def veloxBroadcastHashTableBuildThreads: Int =
66-
getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS)
65+
def veloxBroadcastHashTableBuildTargetBytes: Long =
66+
getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_TARGET_BYTES)
6767

6868
def veloxOrcScanEnabled: Boolean =
6969
getConf(VELOX_ORC_SCAN_ENABLED)
@@ -206,12 +206,14 @@ object VeloxConfig extends ConfigRegistry {
206206
.intConf
207207
.createOptional
208208

209-
val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS =
210-
buildStaticConf("spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads")
211-
.doc("The number of threads used to build the broadcast hash table. " +
212-
"If not set or set to 0, it will use the default number of threads (available processors).")
213-
.intConf
214-
.createWithDefault(1)
209+
val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_TARGET_BYTES =
210+
buildStaticConf("spark.gluten.velox.broadcast.build.targetBytesPerThread")
211+
.doc(
212+
"It is used to calculate the number of hash table build threads. Based on our testing" +
213+
" across various thresholds (1MB to 128MB), we recommend a value of 32MB or 64MB," +
214+
" as these consistently provided the most significant performance gains.")
215+
.bytesConf(ByteUnit.BYTE)
216+
.createWithDefaultString("32MB")
215217

216218
val COLUMNAR_VELOX_ASYNC_TIMEOUT =
217219
buildStaticConf("spark.gluten.sql.columnar.backend.velox.asyncTimeoutOnTaskStopping")

backends-velox/src/main/scala/org/apache/gluten/execution/HashJoinExecTransformer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, BuildSide}
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
2727
import org.apache.spark.sql.execution.joins.BuildSideRelation
28+
import org.apache.spark.sql.execution.metric.SQLMetric
2829
import org.apache.spark.sql.vectorized.ColumnarBatch
2930

3031
import io.substrait.proto.JoinRel
@@ -158,7 +159,7 @@ case class BroadcastHashJoinExecTransformer(
158159
buildBroadcastTableId,
159160
isNullAwareAntiJoin,
160161
bloomFilterPushdownSize,
161-
VeloxConfig.get.veloxBroadcastHashTableBuildThreads
162+
metrics.get("buildHashTableTime")
162163
)
163164
val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast, context)
164165
// FIXME: Do we have to make build side a RDD?
@@ -176,4 +177,4 @@ case class BroadcastHashJoinContext(
176177
buildHashTableId: String,
177178
isNullAwareAntiJoin: Boolean = false,
178179
bloomFilterPushdownSize: Long,
179-
broadcastHashTableBuildThreads: Int)
180+
buildHashTableTimeMetric: Option[SQLMetric] = None)

backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ object ColumnarBuildSideRelation {
5151
batches: Array[Array[Byte]],
5252
mode: BroadcastMode,
5353
newBuildKeys: Seq[Expression] = Seq.empty,
54-
offload: Boolean = false): ColumnarBuildSideRelation = {
54+
offload: Boolean = false,
55+
buildThreads: Int = 1): ColumnarBuildSideRelation = {
5556
val boundMode = mode match {
5657
case HashedRelationBroadcastMode(keys, isNullAware) =>
5758
// Bind each key to the build-side output so simple cols become BoundReference
@@ -66,7 +67,8 @@ object ColumnarBuildSideRelation {
6667
batches,
6768
BroadcastModeUtils.toSafe(boundMode),
6869
newBuildKeys,
69-
offload)
70+
offload,
71+
buildThreads)
7072
}
7173
}
7274

@@ -75,7 +77,8 @@ case class ColumnarBuildSideRelation(
7577
batches: Array[Array[Byte]],
7678
safeBroadcastMode: SafeBroadcastMode,
7779
newBuildKeys: Seq[Expression],
78-
offload: Boolean)
80+
offload: Boolean,
81+
buildThreads: Int)
7982
extends BuildSideRelation
8083
with Logging
8184
with KnownSizeEstimation {
@@ -156,6 +159,7 @@ case class ColumnarBuildSideRelation(
156159
broadcastContext: BroadcastHashJoinContext): (Long, ColumnarBuildSideRelation) =
157160
synchronized {
158161
if (hashTableData == 0) {
162+
val startTime = System.nanoTime()
159163
val runtime = Runtimes.contextInstance(
160164
BackendsApiManager.getBackendName,
161165
"ColumnarBuildSideRelation#buildHashTable")
@@ -215,10 +219,15 @@ case class ColumnarBuildSideRelation(
215219
SubstraitUtil.toNameStruct(newOutput).toByteArray,
216220
broadcastContext.isNullAwareAntiJoin,
217221
broadcastContext.bloomFilterPushdownSize,
218-
broadcastContext.broadcastHashTableBuildThreads
222+
buildThreads
219223
)
220224

221225
jniWrapper.close(serializeHandle)
226+
227+
// Update build hash table time metric
228+
val elapsedTime = System.nanoTime() - startTime
229+
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)
230+
222231
(hashTableData, this)
223232
} else {
224233
(HashJoinBuilder.cloneHashTable(hashTableData), null)

backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ object UnsafeColumnarBuildSideRelation {
5656
batches: Seq[UnsafeByteArray],
5757
mode: BroadcastMode,
5858
newBuildKeys: Seq[Expression] = Seq.empty,
59-
offload: Boolean = false): UnsafeColumnarBuildSideRelation = {
59+
offload: Boolean = false,
60+
buildThreads: Int = 1): UnsafeColumnarBuildSideRelation = {
6061
val boundMode = mode match {
6162
case HashedRelationBroadcastMode(keys, isNullAware) =>
6263
// Bind each key to the build-side output so simple cols become BoundReference
@@ -71,7 +72,8 @@ object UnsafeColumnarBuildSideRelation {
7172
batches,
7273
BroadcastModeUtils.toSafe(boundMode),
7374
newBuildKeys,
74-
offload)
75+
offload,
76+
buildThreads)
7577
}
7678
}
7779

@@ -91,7 +93,8 @@ class UnsafeColumnarBuildSideRelation(
9193
private var batches: Seq[UnsafeByteArray],
9294
private var safeBroadcastMode: SafeBroadcastMode,
9395
private var newBuildKeys: Seq[Expression],
94-
private var offload: Boolean)
96+
private var offload: Boolean,
97+
private var buildThreads: Int)
9598
extends BuildSideRelation
9699
with Externalizable
97100
with Logging
@@ -113,7 +116,7 @@ class UnsafeColumnarBuildSideRelation(
113116

114117
/** needed for serialization. */
115118
def this() = {
116-
this(null, null, null, Seq.empty, false)
119+
this(null, null, null, Seq.empty, false, 1)
117120
}
118121

119122
private[unsafe] def getBatches(): Seq[UnsafeByteArray] = {
@@ -125,6 +128,7 @@ class UnsafeColumnarBuildSideRelation(
125128
def buildHashTable(broadcastContext: BroadcastHashJoinContext): (Long, BuildSideRelation) =
126129
synchronized {
127130
if (hashTableData == 0) {
131+
val startTime = System.nanoTime()
128132
val runtime = Runtimes.contextInstance(
129133
BackendsApiManager.getBackendName,
130134
"UnsafeColumnarBuildSideRelation#buildHashTable")
@@ -185,10 +189,15 @@ class UnsafeColumnarBuildSideRelation(
185189
SubstraitUtil.toNameStruct(newOutput).toByteArray,
186190
broadcastContext.isNullAwareAntiJoin,
187191
broadcastContext.bloomFilterPushdownSize,
188-
broadcastContext.broadcastHashTableBuildThreads
192+
buildThreads
189193
)
190194

191195
jniWrapper.close(serializeHandle)
196+
197+
// Update build hash table time metric
198+
val elapsedTime = System.nanoTime() - startTime
199+
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)
200+
192201
(hashTableData, this)
193202
} else {
194203
(HashJoinBuilder.cloneHashTable(hashTableData), null)
@@ -205,6 +214,7 @@ class UnsafeColumnarBuildSideRelation(
205214
out.writeObject(batches.toArray)
206215
out.writeObject(newBuildKeys)
207216
out.writeBoolean(offload)
217+
out.writeInt(buildThreads)
208218
}
209219

210220
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
@@ -213,6 +223,7 @@ class UnsafeColumnarBuildSideRelation(
213223
kryo.writeClassAndObject(out, batches.toArray)
214224
kryo.writeClassAndObject(out, newBuildKeys)
215225
out.writeBoolean(offload)
226+
out.writeInt(buildThreads)
216227
}
217228

218229
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -221,6 +232,7 @@ class UnsafeColumnarBuildSideRelation(
221232
batches = in.readObject().asInstanceOf[Array[UnsafeByteArray]].toSeq
222233
newBuildKeys = in.readObject().asInstanceOf[Seq[Expression]]
223234
offload = in.readBoolean()
235+
buildThreads = in.readInt()
224236
}
225237

226238
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
@@ -229,6 +241,7 @@ class UnsafeColumnarBuildSideRelation(
229241
batches = kryo.readClassAndObject(in).asInstanceOf[Array[UnsafeByteArray]].toSeq
230242
newBuildKeys = kryo.readClassAndObject(in).asInstanceOf[Seq[Expression]]
231243
offload = in.readBoolean()
244+
buildThreads = in.readInt()
232245
}
233246

234247
private def transformProjection: UnsafeProjection = safeBroadcastMode match {

cpp/velox/jni/VeloxJniWrapper.cc

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <jni/JniCommon.h>
2222
#include <velox/connectors/hive/PartitionIdGenerator.h>
2323
#include <velox/exec/OperatorUtils.h>
24+
#include <folly/futures/Future.h>
25+
#include <folly/executors/CPUThreadPoolExecutor.h>
2426

2527
#include <exception>
2628
#include "JniUdf.h"
@@ -946,7 +948,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
946948
jbyteArray namedStruct,
947949
jboolean isNullAwareAntiJoin,
948950
jlong bloomFilterPushdownSize,
949-
jint broadcastHashTableBuildThreads) {
951+
jint numThreads) {
950952
JNI_METHOD_START
951953
const auto hashTableId = jStringToCString(env, tableId);
952954

@@ -985,17 +987,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
985987
cb.push_back(ObjectStore::retrieve<ColumnarBatch>(handle));
986988
}
987989

988-
size_t maxThreads = broadcastHashTableBuildThreads > 0
989-
? std::min((size_t)broadcastHashTableBuildThreads, (size_t)32)
990-
: std::min((size_t)std::thread::hardware_concurrency(), (size_t)32);
991-
992-
// Heuristic: Each thread should process at least a certain number of batches to justify parallelism overhead.
993-
// 32 batches is roughly 128k rows, which is a reasonable granularity for a single thread.
994-
constexpr size_t kMinBatchesPerThread = 32;
995-
size_t numThreads = std::min(maxThreads, (handleCount + kMinBatchesPerThread - 1) / kMinBatchesPerThread);
996-
numThreads = std::max((size_t)1, numThreads);
997-
998-
if (numThreads <= 1) {
990+
if (numThreads == 1) {
999991
auto builder = nativeHashTableBuild(
1000992
hashJoinKeys,
1001993
names,
@@ -1020,16 +1012,20 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
10201012
return gluten::getHashTableObjStore()->save(builder);
10211013
}
10221014

1023-
std::vector<std::thread> threads;
1024-
1015+
// Use thread pool (executor) instead of creating threads directly
1016+
auto executor = VeloxBackend::get()->executor();
1017+
10251018
std::vector<std::shared_ptr<gluten::HashTableBuilder>> hashTableBuilders(numThreads);
10261019
std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> otherTables(numThreads);
1020+
std::vector<folly::Future<folly::Unit>> futures;
1021+
futures.reserve(numThreads);
10271022

10281023
for (size_t t = 0; t < numThreads; ++t) {
10291024
size_t start = (handleCount * t) / numThreads;
10301025
size_t end = (handleCount * (t + 1)) / numThreads;
10311026

1032-
threads.emplace_back([&, t, start, end]() {
1027+
// Submit task to thread pool
1028+
auto future = folly::via(executor, [&, t, start, end]() {
10331029
std::vector<std::shared_ptr<gluten::ColumnarBatch>> threadBatches;
10341030
for (size_t i = start; i < end; ++i) {
10351031
threadBatches.push_back(cb[i]);
@@ -1050,11 +1046,12 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
10501046
hashTableBuilders[t] = std::move(builder);
10511047
otherTables[t] = std::move(hashTableBuilders[t]->uniqueTable());
10521048
});
1049+
1050+
futures.push_back(std::move(future));
10531051
}
10541052

1055-
for (auto& thread : threads) {
1056-
thread.join();
1057-
}
1053+
// Wait for all tasks to complete
1054+
folly::collectAll(futures).wait();
10581055

10591056
auto mainTable = std::move(otherTables[0]);
10601057
std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> tables;

gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,8 @@ trait SparkPlanExecApi {
435435
mode: BroadcastMode,
436436
child: SparkPlan,
437437
numOutputRows: SQLMetric,
438-
dataSize: SQLMetric): BuildSideRelation
438+
dataSize: SQLMetric,
439+
buildThreads: SQLMetric = null): BuildSideRelation
439440

440441
def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
441442
mode.canonicalized

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
7575
mode,
7676
child,
7777
longMetric("numOutputRows"),
78-
longMetric("dataSize"))
78+
longMetric("dataSize"),
79+
metrics.getOrElse("buildThreads", null))
7980
}
8081

8182
val broadcasted = GlutenTimeMetric.millis(longMetric("broadcastTime")) {

0 commit comments

Comments (0)