Skip to content

Commit c0b60dd

Browse files
cfmcgrady陈福
authored and committed
[CELEBORN] Optimize RegisterShuffle for large partition counts
1. Replace transmission of partitionIdList (an ArrayList<Integer>) with a single numPartitions integer, carried by the new PbRequestSlotsV2 message type. This eliminates a ~10MB protobuf payload for 2M-partition shuffles. The old PbRequestSlots message is preserved for backward compatibility.
2. Optimize SlotsAllocator.roundRobin():
   - Pre-compute per-worker usable slots into long[] arrays, replacing O(N*W) haveUsableSlots() stream calls with O(1) array lookups.
   - Replace LinkedList iterator + remove with index-based traversal, eliminating O(N^2) element shifting overhead that dominated CPU (90% in flame graph for 2M partitions).
1 parent a56f69a commit c0b60dd

16 files changed

Lines changed: 286 additions & 140 deletions

File tree

client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,9 @@ class ChangePartitionManager(
305305
|| (unavailableWorkerRatio >= dynamicResourceUnavailableFactor)) {
306306

307307
// get new available workers for the request partition ids
308-
val partitionIds = new util.ArrayList[Integer](
309-
changePartitions.map(_.partitionId).map(Integer.valueOf).toList.asJava)
310308
// The partition id value is not important here because we're just trying to get the workers to use
311309
val requestSlotsRes =
312-
lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, partitionIds)
310+
lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, changePartitions.size)
313311

314312
requestSlotsRes.status match {
315313
case StatusCode.REQUEST_FAILED =>

client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
776776
}
777777

778778
// First, request to get allocated slots from Primary
779-
val ids = new util.ArrayList[Integer](numPartitions)
780-
(0 until numPartitions).foreach(idx => ids.add(Integer.valueOf(idx)))
781-
val res = requestMasterRequestSlotsWithRetry(shuffleId, ids)
779+
val res = requestMasterRequestSlotsWithRetry(shuffleId, numPartitions)
782780

783781
res.status match {
784782
case StatusCode.REQUEST_FAILED =>
@@ -1832,7 +1830,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
18321830

18331831
def requestMasterRequestSlotsWithRetry(
18341832
shuffleId: Int,
1835-
ids: util.ArrayList[Integer]): RequestSlotsResponse = {
1833+
numPartitions: Int): RequestSlotsResponse = {
18361834
val excludedWorkerSet =
18371835
if (excludedWorkersFilter) {
18381836
workerStatusTracker.excludedWorkers.asScala.keys.toSet
@@ -1845,7 +1843,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
18451843
RequestSlots(
18461844
appUniqueId,
18471845
shuffleId,
1848-
ids,
1846+
numPartitions,
18491847
lifecycleHost,
18501848
pushReplicateEnabled,
18511849
pushRackAwareEnabled,

common/src/main/proto/TransportMessages.proto

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ enum MessageType {
117117
READ_REDUCER_PARTITION_END = 94;
118118
READ_REDUCER_PARTITION_END_RESPONSE = 95;
119119
REGISTER_APPLICATION_INFO = 96;
120+
121+
REQUEST_SLOTS_V2 = 97;
120122
}
121123

122124
enum StreamType {
@@ -325,6 +327,22 @@ message PbRequestSlots {
325327
string tagsExpr = 14;
326328
}
327329

330+
message PbRequestSlotsV2 {
331+
string applicationId = 1;
332+
int32 shuffleId = 2;
333+
int32 numPartitions = 3;
334+
string hostname = 4;
335+
bool shouldReplicate = 5;
336+
string requestId = 6;
337+
PbUserIdentifier userIdentifier = 7;
338+
bool shouldRackAware = 8;
339+
int32 maxWorkers = 9;
340+
int32 availableStorageTypes = 10;
341+
repeated PbWorkerInfo excludedWorkerSet = 11;
342+
bool packed = 12;
343+
string tagsExpr = 13;
344+
}
345+
328346
message PbSlotInfo {
329347
map<string, int32> slot = 1;
330348
}

common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ object ControlMessages extends Logging {
164164
case class RequestSlots(
165165
applicationId: String,
166166
shuffleId: Int,
167-
partitionIdList: util.ArrayList[Integer],
167+
numPartitions: Int,
168168
hostname: String,
169169
shouldReplicate: Boolean,
170170
shouldRackAware: Boolean,
@@ -650,7 +650,7 @@ object ControlMessages extends Logging {
650650
case RequestSlots(
651651
applicationId,
652652
shuffleId,
653-
partitionIdList,
653+
numPartitions,
654654
hostname,
655655
shouldReplicate,
656656
shouldRackAware,
@@ -661,10 +661,10 @@ object ControlMessages extends Logging {
661661
packed,
662662
tagsExpr,
663663
requestId) =>
664-
val payload = PbRequestSlots.newBuilder()
664+
val payload = PbRequestSlotsV2.newBuilder()
665665
.setApplicationId(applicationId)
666666
.setShuffleId(shuffleId)
667-
.addAllPartitionIdList(partitionIdList)
667+
.setNumPartitions(numPartitions)
668668
.setHostname(hostname)
669669
.setShouldReplicate(shouldReplicate)
670670
.setShouldRackAware(shouldRackAware)
@@ -677,7 +677,7 @@ object ControlMessages extends Logging {
677677
.setPacked(packed)
678678
.setTagsExpr(tagsExpr)
679679
.build().toByteArray
680-
new TransportMessage(MessageType.REQUEST_SLOTS, payload)
680+
new TransportMessage(MessageType.REQUEST_SLOTS_V2, payload)
681681

682682
case RequestSlotsResponse(status, workerResource, packed) =>
683683
val builder = PbRequestSlotsResponse.newBuilder()
@@ -1151,7 +1151,7 @@ object ControlMessages extends Logging {
11511151
RequestSlots(
11521152
pbRequestSlots.getApplicationId,
11531153
pbRequestSlots.getShuffleId,
1154-
new util.ArrayList[Integer](pbRequestSlots.getPartitionIdListList),
1154+
pbRequestSlots.getPartitionIdListList.size(),
11551155
pbRequestSlots.getHostname,
11561156
pbRequestSlots.getShouldReplicate,
11571157
pbRequestSlots.getShouldRackAware,
@@ -1163,6 +1163,26 @@ object ControlMessages extends Logging {
11631163
pbRequestSlots.getTagsExpr,
11641164
pbRequestSlots.getRequestId)
11651165

1166+
case REQUEST_SLOTS_V2_VALUE =>
1167+
val pb = PbRequestSlotsV2.parseFrom(message.getPayload)
1168+
val userIdentifier = PbSerDeUtils.fromPbUserIdentifier(pb.getUserIdentifier)
1169+
val excludedWorkerInfoSet =
1170+
pb.getExcludedWorkerSetList.asScala.map(PbSerDeUtils.fromPbWorkerInfo).toSet
1171+
RequestSlots(
1172+
pb.getApplicationId,
1173+
pb.getShuffleId,
1174+
pb.getNumPartitions,
1175+
pb.getHostname,
1176+
pb.getShouldReplicate,
1177+
pb.getShouldRackAware,
1178+
userIdentifier,
1179+
pb.getMaxWorkers,
1180+
pb.getAvailableStorageTypes,
1181+
excludedWorkerInfoSet,
1182+
pb.getPacked,
1183+
pb.getTagsExpr,
1184+
pb.getRequestId)
1185+
11661186
case REQUEST_SLOTS_RESPONSE_VALUE =>
11671187
val pbRequestSlotsResponse = PbRequestSlotsResponse.parseFrom(message.getPayload)
11681188
val workerResource =
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
================================================================================================
2+
200 workers, 10K partitions, no replication
3+
================================================================================================
4+
5+
OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4
6+
Apple M2 Pro
7+
200 workers, 10K partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
8+
---------------------------------------------------------------------------------------------------------------------------
9+
offerSlotsRoundRobin 1 1 0 15.6 64.3 1.0X
10+
11+
12+
================================================================================================
13+
200 workers, 100K partitions, no replication
14+
================================================================================================
15+
16+
OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4
17+
Apple M2 Pro
18+
200 workers, 100K partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
19+
----------------------------------------------------------------------------------------------------------------------------
20+
offerSlotsRoundRobin 6 7 0 15.8 63.4 1.0X
21+
22+
23+
================================================================================================
24+
500 workers, 100K partitions, with replication
25+
================================================================================================
26+
27+
OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4
28+
Apple M2 Pro
29+
500 workers, 100K partitions, with replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
30+
------------------------------------------------------------------------------------------------------------------------------
31+
offerSlotsRoundRobin 12 15 2 8.0 124.5 1.0X
32+
33+
34+
================================================================================================
35+
500 workers, 2M partitions, no replication
36+
================================================================================================
37+
38+
OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4
39+
Apple M2 Pro
40+
500 workers, 2M partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
41+
--------------------------------------------------------------------------------------------------------------------------
42+
offerSlotsRoundRobin 252 351 102 7.9 126.1 1.0X
43+
44+
45+
================================================================================================
46+
1000 workers, 500K partitions, with replication
47+
================================================================================================
48+
49+
OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4
50+
Apple M2 Pro
51+
1000 workers, 500K partitions, with replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
52+
-------------------------------------------------------------------------------------------------------------------------------
53+
offerSlotsRoundRobin 77 159 46 6.5 154.7 1.0X
54+
55+

master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ private static List<Integer> roundRobin(
517517
}
518518
// workerInfo -> (diskIndexForPrimaryAndReplica)
519519
Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
520-
List<Integer> partitionIdList = new LinkedList<>(partitionIds);
521520

522521
final int primaryWorkersSize = primaryWorkers.size();
523522
final int replicaWorkersSize = replicaWorkers.size();
@@ -533,19 +532,27 @@ private static List<Integer> roundRobin(
533532
replicaIndex = -1;
534533
}
535534

536-
ListIterator<Integer> iter = partitionIdList.listIterator(partitionIdList.size());
537-
// Iterate from the end to preserve O(1) removal of processed partitions.
538-
// This is important when we have a high number of concurrent apps that have a
539-
// high number of partitions.
535+
// Pre-compute each worker's total usable slots once (a single O(W) pass) so the per-partition loop does O(1) array lookups instead of repeated stream operations costing O(N*W) overall
536+
long[] primaryUsableSlots = null;
537+
long[] replicaUsableSlots = null;
538+
if (slotsRestrictions != null && !slotsRestrictions.isEmpty()) {
539+
primaryUsableSlots = computeUsableSlots(primaryWorkers, slotsRestrictions);
540+
if (shouldReplicate) {
541+
replicaUsableSlots = computeUsableSlots(replicaWorkers, slotsRestrictions);
542+
}
543+
}
544+
545+
// Use index-based iteration to avoid O(N^2) LinkedList.remove() overhead.
546+
int allocatedCount = 0;
540547
outer:
541-
while (iter.hasPrevious()) {
548+
for (int pidIdx = 0; pidIdx < partitionIds.size(); pidIdx++) {
542549
int nextPrimaryInd = primaryIndex;
543550

544-
int partitionId = iter.previous();
551+
int partitionId = partitionIds.get(pidIdx);
545552
StorageInfo storageInfo;
546-
if (slotsRestrictions != null && !slotsRestrictions.isEmpty()) {
553+
if (primaryUsableSlots != null) {
547554
// this means that we'll select a mount point
548-
while (!haveUsableSlots(slotsRestrictions, primaryWorkers, nextPrimaryInd)) {
555+
while (primaryUsableSlots[nextPrimaryInd] <= 0) {
549556
nextPrimaryInd = primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd);
550557
if (nextPrimaryInd == primaryIndex) {
551558
break outer;
@@ -558,6 +565,7 @@ private static List<Integer> roundRobin(
558565
slotsRestrictions,
559566
workerDiskIndex,
560567
availableStorageTypes);
568+
primaryUsableSlots[nextPrimaryInd]--;
561569
} else {
562570
if (StorageInfo.localDiskAvailable(availableStorageTypes)) {
563571
while (!primaryWorkers.get(nextPrimaryInd).haveDisk()) {
@@ -576,9 +584,9 @@ private static List<Integer> roundRobin(
576584

577585
if (shouldReplicate) {
578586
int nextReplicaInd = replicaIndex;
579-
if (slotsRestrictions != null) {
587+
if (replicaUsableSlots != null) {
580588
while ((nextReplicaInd == nextPrimaryInd && skipLocationsOnSameWorkerCheck)
581-
|| !haveUsableSlots(slotsRestrictions, replicaWorkers, nextReplicaInd)
589+
|| replicaUsableSlots[nextReplicaInd] <= 0
582590
|| !satisfyRackAware(
583591
shouldRackAware,
584592
primaryWorkers,
@@ -597,6 +605,7 @@ private static List<Integer> roundRobin(
597605
slotsRestrictions,
598606
workerDiskIndex,
599607
availableStorageTypes);
608+
replicaUsableSlots[nextReplicaInd]--;
600609
} else if (shouldRackAware) {
601610
while ((nextReplicaInd == nextPrimaryInd && skipLocationsOnSameWorkerCheck)
602611
|| !satisfyRackAware(
@@ -642,9 +651,26 @@ private static List<Integer> roundRobin(
642651
v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>()));
643652
locations._1.add(primaryPartition);
644653
primaryIndex = primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd);
645-
iter.remove();
654+
allocatedCount++;
655+
}
656+
if (allocatedCount == partitionIds.size()) {
657+
return Collections.emptyList();
646658
}
647-
return partitionIdList;
659+
return new ArrayList<>(partitionIds.subList(allocatedCount, partitionIds.size()));
660+
}
661+
662+
private static long[] computeUsableSlots(
663+
List<WorkerInfo> workers, Map<WorkerInfo, List<UsableDiskInfo>> restrictions) {
664+
long[] slots = new long[workers.size()];
665+
for (int i = 0; i < workers.size(); i++) {
666+
List<UsableDiskInfo> disks = restrictions.get(workers.get(i));
667+
if (disks != null) {
668+
for (UsableDiskInfo d : disks) {
669+
slots[i] += d.usableSlots;
670+
}
671+
}
672+
}
673+
return slots;
648674
}
649675

650676
private static boolean haveUsableSlots(

master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ private[celeborn] class Master(
930930
}
931931

932932
def handleRequestSlots(context: RpcCallContext, requestSlots: RequestSlots): Unit = {
933-
val numReducers = requestSlots.partitionIdList.size()
933+
val numReducers = requestSlots.numPartitions
934934
val shuffleKey = Utils.makeShuffleKey(requestSlots.applicationId, requestSlots.shuffleId)
935935

936936
var availableWorkers = workersAvailable(requestSlots.excludedWorkerSet)
@@ -966,14 +966,21 @@ private[celeborn] class Master(
966966
0,
967967
startIndex + numWorkers - numAvailableWorkers))
968968
}
969+
// Build partitionIds list locally from numPartitions
970+
val partitionIds = new util.ArrayList[Integer](numReducers)
971+
var i = 0
972+
while (i < numReducers) {
973+
partitionIds.add(Integer.valueOf(i))
974+
i += 1
975+
}
969976
// offer slots
970977
val slots =
971978
masterSource.sample(MasterSource.OFFER_SLOTS_TIME, s"offerSlots-${Random.nextInt()}") {
972979
statusSystem.workersMap.synchronized {
973980
if (slotsAssignPolicy == SlotsAssignPolicy.LOADAWARE) {
974981
SlotsAllocator.offerSlotsLoadAware(
975982
selectedWorkers,
976-
requestSlots.partitionIdList,
983+
partitionIds,
977984
requestSlots.shouldReplicate,
978985
requestSlots.shouldRackAware,
979986
slotsAssignLoadAwareDiskGroupNum,
@@ -986,7 +993,7 @@ private[celeborn] class Master(
986993
} else {
987994
SlotsAllocator.offerSlotsRoundRobin(
988995
selectedWorkers,
989-
requestSlots.partitionIdList,
996+
partitionIds,
990997
requestSlots.shouldReplicate,
991998
requestSlots.shouldRackAware,
992999
requestSlots.availableStorageTypes,

master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ class MasterSuite extends AnyFunSuite
177177
val requestSlots = RequestSlots(
178178
"app1",
179179
0,
180-
new util.ArrayList[Integer](),
180+
0,
181181
"localhost",
182182
shouldReplicate = false,
183183
shouldRackAware = false,

0 commit comments

Comments
 (0)