Skip to content

Commit da4ba4c

Browse files
committed
[CELEBORN-2312] Support committing uncommitted partitions for graceful shutdown
1 parent a56f69a commit da4ba4c

7 files changed

Lines changed: 301 additions & 2 deletions

File tree

common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
13711371
get(WORKER_GRACEFUL_SHUTDOWN_SAVE_COMMITTED_FILEINFO_SYNC)
13721372
def workerGracefulShutdownDbDeleteFailurePolicy: String =
13731373
get(WORKER_GRACEFUL_SHUTDOWN_DB_DELETE_FAILURE_POLICY)
1374+
// Whether a gracefully shutting-down worker should proactively commit its own
// uncommitted partitions (see WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED).
def workerGracefulShutdownCommitUncommittedPartitionsEnabled: Boolean =
1375+
get(WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED)
13741376

13751377
// //////////////////////////////////////////////////////
13761378
// Flusher //
@@ -4003,6 +4005,15 @@ object CelebornConf extends Logging {
40034005
.checkValues(Set("THROW", "EXIT", "IGNORE"))
40044006
.createWithDefault("IGNORE")
40054007

4008+
// Feature flag (default: false, since 0.7.0): when enabled, a worker in graceful
// shutdown commits the partitions it still tracks as uncommitted itself instead of
// passively waiting for the LifecycleManager to send CommitFiles RPCs.
val WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED: ConfigEntry[Boolean] =
4009+
buildConf("celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled")
4010+
.categories("worker")
4011+
.doc("When true, during graceful shutdown the worker commits uncommitted " +
4012+
"partitions instead of waiting for LifecycleManager to send CommitFiles RPCs.")
4013+
.version("0.7.0")
4014+
.booleanConf
4015+
.createWithDefault(false)
4016+
4016+
40064017
val WORKER_DISKTIME_SLIDINGWINDOW_SIZE: ConfigEntry[Int] =
40074018
buildConf("celeborn.worker.flusher.diskTime.slidingWindow.size")
40084019
.withAlternative("celeborn.worker.flusher.avgFlushTime.slidingWindow.size")

common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,23 @@ class WorkerPartitionLocationInfo extends Logging {
173173
} else null
174174
}
175175

176+
/**
 * Takes a point-in-time snapshot of the uncommitted partition unique IDs, grouped by
 * shuffle key. Iteration over the underlying ConcurrentHashMaps is weakly consistent,
 * so mutations made concurrently with the snapshot may or may not be visible in it.
 *
 * @return a (primary, replica) pair, each mapping shuffleKey -> list of unique IDs
 */
def snapshotUncommittedUniqueIds
    : (Map[String, util.List[String]], Map[String, util.List[String]]) = {
  val primary = snapshotIds(primaryPartitionLocations)
  val replica = snapshotIds(replicaPartitionLocations)
  (primary, replica)
}

/** Copies the unique IDs of every non-empty shuffle key in `locations` into fresh lists. */
private def snapshotIds(locations: PartitionInfo): Map[String, util.List[String]] =
  locations.asScala
    .filter { case (_, partitions) => !partitions.isEmpty }
    .map { case (shuffleKey, partitions) =>
      // Copy the key set so the snapshot is detached from later mutations.
      shuffleKey -> (new util.ArrayList[String](partitions.keySet()): util.List[String])
    }
    .toMap
192+
176193
def isEmpty: Boolean = {
177194
(primaryPartitionLocations.isEmpty ||
178195
primaryPartitionLocations.asScala.values.forall(_.isEmpty)) &&

common/src/test/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfoSuite.scala

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,68 @@ class WorkerPartitionLocationInfoSuite extends CelebornFunSuite {
6464
assertEquals(workerPartitionLocationInfo.isEmpty, true)
6565
}
6666

67+
test("snapshotUncommittedUniqueIds - empty info returns empty maps") {
  // A freshly constructed info object has no partitions at all.
  val locationInfo = new WorkerPartitionLocationInfo
  val snapshots = locationInfo.snapshotUncommittedUniqueIds
  assert(snapshots._1.isEmpty)
  assert(snapshots._2.isEmpty)
}
73+
74+
test("snapshotUncommittedUniqueIds - captures correct IDs across shuffles") {
  val locationInfo = new WorkerPartitionLocationInfo
  val firstShuffle = "app1-0"
  val secondShuffle = "app2-1"

  // Two primaries under the first shuffle, one under the second.
  val firstPrimary = new util.ArrayList[PartitionLocation]()
  firstPrimary.add(mockPartition(0, 0))
  firstPrimary.add(mockPartition(1, 0))
  locationInfo.addPrimaryPartitions(firstShuffle, firstPrimary)

  val secondPrimary = new util.ArrayList[PartitionLocation]()
  secondPrimary.add(mockPartition(2, 0))
  locationInfo.addPrimaryPartitions(secondShuffle, secondPrimary)

  // One replica under the first shuffle only.
  val firstReplica = new util.ArrayList[PartitionLocation]()
  firstReplica.add(mockPartition(3, 0))
  locationInfo.addReplicaPartitions(firstShuffle, firstReplica)

  val (primary, replica) = locationInfo.snapshotUncommittedUniqueIds
  assert(primary.size == 2)
  assert(primary(firstShuffle).size() == 2)
  assert(primary(firstShuffle).contains("0-0"))
  assert(primary(firstShuffle).contains("1-0"))
  assert(primary(secondShuffle).size() == 1)
  assert(primary(secondShuffle).contains("2-0"))
  assert(replica.size == 1)
  assert(replica(firstShuffle).size() == 1)
  assert(replica(firstShuffle).contains("3-0"))
}
99+
100+
test("snapshotUncommittedUniqueIds - filters empty shuffle keys") {
  val locationInfo = new WorkerPartitionLocationInfo
  val key = "app1-0"
  val added = new util.ArrayList[PartitionLocation]()
  (0 to 1).foreach(id => added.add(mockPartition(id, 0)))
  locationInfo.addPrimaryPartitions(key, added)
  // Remove everything again so the shuffle key now maps to an empty partition map.
  locationInfo.removePrimaryPartitions(key, added.asScala.map(_.getUniqueId).asJava)
  val (primary, _) = locationInfo.snapshotUncommittedUniqueIds
  // Emptied keys must not show up in the snapshot at all.
  assert(!primary.contains(key))
}
111+
112+
test("snapshotUncommittedUniqueIds - snapshot is a point-in-time copy") {
  val locationInfo = new WorkerPartitionLocationInfo
  val key = "app1-0"
  val initial = new util.ArrayList[PartitionLocation]()
  initial.add(mockPartition(0, 0))
  locationInfo.addPrimaryPartitions(key, initial)

  val (primary, _) = locationInfo.snapshotUncommittedUniqueIds
  assert(primary(key).size() == 1)

  // Mutate after taking the snapshot; the snapshot must not observe the change.
  val late = new util.ArrayList[PartitionLocation]()
  late.add(mockPartition(1, 0))
  late.add(mockPartition(2, 0))
  locationInfo.addPrimaryPartitions(key, late)

  assert(primary(key).size() == 1)
}
128+
67129
private def mockPartition(partitionId: Int, epoch: Int): PartitionLocation = {
68130
new PartitionLocation(
69131
partitionId,

docs/configuration/worker.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ license: |
102102
| celeborn.worker.flusher.threads | 16 | false | Flusher's thread count per disk for unknown-type disks. | 0.2.0 | |
103103
| celeborn.worker.graceful.shutdown.checkSlotsFinished.interval | 1s | false | The wait interval of checking whether all released slots to be committed or destroyed during worker graceful shutdown | 0.2.0 | |
104104
| celeborn.worker.graceful.shutdown.checkSlotsFinished.timeout | 480s | false | The wait time of waiting for the released slots to be committed or destroyed during worker graceful shutdown. | 0.2.0 | |
105+
| celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled | false | false | When true, during graceful shutdown the worker commits uncommitted partitions instead of waiting for LifecycleManager to send CommitFiles RPCs. | 0.7.0 | |
105106
| celeborn.worker.graceful.shutdown.dbDeleteFailurePolicy | IGNORE | false | Policy for handling DB delete failures during graceful shutdown. THROW: throw exception, EXIT: trigger graceful shutdown, IGNORE: log error and continue (default). | 0.7.0 | |
106107
| celeborn.worker.graceful.shutdown.enabled | false | false | When true, during worker shutdown, the worker will wait for all released slots to be committed or destroyed. | 0.2.0 | |
107108
| celeborn.worker.graceful.shutdown.partitionSorter.shutdownTimeout | 120s | false | The wait time of waiting for sorting partition files during worker graceful shutdown. | 0.2.0 | |

worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,78 @@ private[deploy] class Controller(
459459
}
460460
}
461461

462+
/**
 * Commits every partition this worker still tracks as uncommitted, without waiting for
 * the LifecycleManager to drive the commit via CommitFiles RPCs. Invoked during graceful
 * shutdown when celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled is
 * set.
 *
 * Partitions whose commit explicitly fails are retained in partitionLocationInfo so the
 * regular LifecycleManager-driven path can still retry them; everything else is removed
 * and its slots released.
 */
private[worker] def commitUncommittedPartitions(): Unit = {
  // Weakly-consistent point-in-time view of the uncommitted unique IDs per shuffle key.
  val (primarySnapshot, replicaSnapshot) = partitionLocationInfo.snapshotUncommittedUniqueIds
  if (primarySnapshot.isEmpty && replicaSnapshot.isEmpty) {
    logInfo("No uncommitted partitions.")
    return
  }
  val shuffleKeys = primarySnapshot.keySet ++ replicaSnapshot.keySet
  val primaryTotal = primarySnapshot.values.map(_.size()).sum
  val replicaTotal = replicaSnapshot.values.map(_.size()).sum
  logInfo(s"Committing uncommitted partitions across ${shuffleKeys.size} shuffles ($primaryTotal primary, $replicaTotal replica).")
  // Accumulators shared by the async commit tasks; commitFiles populates them.
  val committedIds = ConcurrentHashMap.newKeySet[String]()
  val failedIds = ConcurrentHashMap.newKeySet[String]()
  val emptyFileIds = ConcurrentHashMap.newKeySet[String]()
  val storageInfos = JavaUtils.newConcurrentHashMap[String, StorageInfo]()
  val mapIdBitMap = JavaUtils.newConcurrentHashMap[String, RoaringBitmap]()
  val partitionSizes = new LinkedBlockingQueue[Long]()
  val emptyIds = java.util.Collections.emptyList[String]()
  val futures = ArrayBuffer[CompletableFuture[Void]]()
  for (shuffleKey <- shuffleKeys) {
    val primaryIds = primarySnapshot.getOrElse(shuffleKey, emptyIds)
    val replicaIds = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
    val (primaryFuture, _) = commitFiles(
      shuffleKey,
      primaryIds,
      committedIds,
      emptyFileIds,
      failedIds,
      storageInfos,
      mapIdBitMap,
      partitionSizes)
    val (replicaFuture, _) = commitFiles(
      shuffleKey,
      replicaIds,
      committedIds,
      emptyFileIds,
      failedIds,
      storageInfos,
      mapIdBitMap,
      partitionSizes,
      isPrimary = false)
    if (primaryFuture != null) { futures += primaryFuture }
    if (replicaFuture != null) { futures += replicaFuture }
  }
  if (futures.nonEmpty) {
    try {
      CompletableFuture.allOf(futures: _*).get(shuffleCommitTimeout, TimeUnit.MILLISECONDS)
      val failedMsg = if (failedIds.size() > 0) s", ${failedIds.size()} failed" else ""
      logInfo(s"Committed ${committedIds.size()} partitions$failedMsg.")
    } catch {
      // FIX: the previous catch-all `case e: Exception` logged every failure (e.g. an
      // ExecutionException from a commit task) as "Commit timed out", and silently
      // swallowed InterruptedException. Distinguish the cases and restore the interrupt
      // flag so the shutdown thread still observes it.
      case e: java.util.concurrent.TimeoutException =>
        logWarning(
          s"Commit timed out after ${shuffleCommitTimeout}ms " +
            s"across ${shuffleKeys.size} shuffles " +
            s"(${committedIds.size()} committed, ${failedIds.size()} failed). " +
            s"Shuffle keys: ${shuffleKeys.mkString(", ")}",
          e)
      case e: InterruptedException =>
        Thread.currentThread().interrupt()
        logWarning(
          s"Commit interrupted (${committedIds.size()} committed, ${failedIds.size()} failed).",
          e)
      case scala.util.control.NonFatal(e) =>
        logWarning(
          s"Commit failed across ${shuffleKeys.size} shuffles " +
            s"(${committedIds.size()} committed, ${failedIds.size()} failed).",
          e)
    }
  }
  // Remove everything that did not explicitly fail and release its slots; failed IDs are
  // kept for the LifecycleManager's passive-wait retry. NOTE(review): after a timeout,
  // commits still in flight are in neither committedIds nor failedIds and get removed
  // here too — confirm that matches the intended timeout semantics.
  for (shuffleKey <- shuffleKeys) {
    val primaryIds = primarySnapshot.getOrElse(shuffleKey, emptyIds)
      .asScala.filterNot(failedIds.contains).asJava
    val replicaIds = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
      .asScala.filterNot(failedIds.contains).asJava
    val (primarySlots, _) =
      partitionLocationInfo.removePrimaryPartitions(shuffleKey, primaryIds)
    val (replicaSlots, _) =
      partitionLocationInfo.removeReplicaPartitions(shuffleKey, replicaIds)
    workerInfo.releaseSlots(shuffleKey, primarySlots)
    workerInfo.releaseSlots(shuffleKey, replicaSlots)
  }
}
533+
462534
private def handleCommitFiles(
463535
context: RpcCallContext,
464536
shuffleKey: String,

worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,17 @@ private[celeborn] class Worker(
985985
e)
986986
}
987987
shutdown.set(true)
988+
989+
if (conf.workerGracefulShutdownCommitUncommittedPartitionsEnabled) {
990+
// Commit uncommitted partitions instead of waiting for LifecycleManager to send CommitFiles RPCs.
991+
try {
992+
controller.commitUncommittedPartitions()
993+
} catch {
994+
case e: Throwable =>
995+
logError("Failed to commit uncommitted partitions during graceful shutdown", e)
996+
}
997+
}
998+
988999
val interval = conf.workerGracefulShutdownCheckSlotsFinishedInterval
9891000
val timeout = conf.workerGracefulShutdownCheckSlotsFinishedTimeoutMs
9901001
var waitTimes = 0

worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.celeborn.service.deploy.worker
1919

20-
import java.io.File
20+
import java.io.{File, IOException}
2121
import java.nio.file.{Files, Paths}
2222
import java.util
2323
import java.util.{HashSet => JHashSet}
@@ -33,7 +33,7 @@ import org.scalatest.funsuite.AnyFunSuite
3333

3434
import org.apache.celeborn.common.CelebornConf
3535
import org.apache.celeborn.common.identity.UserIdentifier
36-
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType}
36+
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
3737
import org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse
3838
import org.apache.celeborn.common.protocol.message.StatusCode
3939
import org.apache.celeborn.common.quota.ResourceConsumption
@@ -303,4 +303,129 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
303303
assert(shuffleCommitTime.get(shuffleKey).get(epoch2) == null)
304304
assert(epochCommitMap.get(epoch2).response.status == StatusCode.SUCCESS)
305305
}
306+
307+
test("commitUncommittedPartitions - commits primary and replica partitions") {
  conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
  worker = new Worker(conf, workerArgs)
  val controller = worker.controller
  controller.init(worker)

  val shuffleKey = "app1-0"
  val primaryWriterA = mockWriter(100L)
  val primaryWriterB = mockWriter(200L)
  val replicaWriter = mockWriter(50L)

  val primaries = new util.ArrayList[PartitionLocation]()
  primaries.add(mockWorkingPartition(0, primaryWriterA))
  primaries.add(mockWorkingPartition(1, primaryWriterB))
  worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaries)

  val replicas = new util.ArrayList[PartitionLocation]()
  replicas.add(mockWorkingPartition(2, replicaWriter))
  worker.partitionLocationInfo.addReplicaPartitions(shuffleKey, replicas)

  assert(!worker.partitionLocationInfo.isEmpty)
  controller.commitUncommittedPartitions()

  // Every registered writer must have been closed and all bookkeeping cleared.
  Seq(primaryWriterA, primaryWriterB, replicaWriter).foreach(w => verify(w).close())
  assert(worker.partitionLocationInfo.isEmpty)
}
330+
331+
test("commitUncommittedPartitions - no-op when no partitions") {
  conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
  worker = new Worker(conf, workerArgs)
  val controller = worker.controller
  controller.init(worker)

  assert(worker.partitionLocationInfo.isEmpty)
  // Nothing registered: the call must complete quietly without touching any state.
  controller.commitUncommittedPartitions()
  assert(worker.partitionLocationInfo.isEmpty)
}
340+
341+
test("commitUncommittedPartitions - idempotent on double call") {
  conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
  worker = new Worker(conf, workerArgs)
  val controller = worker.controller
  controller.init(worker)

  val shuffleKey = "app1-0"
  val writer = mockWriter(100L)
  val partitions = new util.ArrayList[PartitionLocation]()
  partitions.add(mockWorkingPartition(0, writer))
  worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, partitions)

  controller.commitUncommittedPartitions()
  assert(worker.partitionLocationInfo.isEmpty)

  // The second invocation finds nothing to commit and must not close the writer again.
  controller.commitUncommittedPartitions()
  assert(worker.partitionLocationInfo.isEmpty)
  verify(writer, times(1)).close()
}
358+
359+
test("commitUncommittedPartitions - retains failed partitions for passive wait") {
  conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
  worker = new Worker(conf, workerArgs)
  val controller = worker.controller
  controller.init(worker)

  val shuffleKey = "app1-0"
  val healthyWriter = mockWriter(100L)
  // A writer whose close() blows up, simulating a disk failure during commit.
  val brokenWriter = mock[PartitionDataWriter]
  when(brokenWriter.close()).thenThrow(new IOException("disk error"))
  when(brokenWriter.getStorageInfo).thenReturn(new StorageInfo("/tmp", StorageInfo.Type.HDD, 1))
  when(brokenWriter.getMapIdBitMap).thenReturn(null)
  when(brokenWriter.getMetaHandler).thenReturn(null)

  val partitions = new util.ArrayList[PartitionLocation]()
  partitions.add(mockWorkingPartition(0, healthyWriter))
  partitions.add(mockWorkingPartition(1, brokenWriter))
  worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, partitions)

  controller.commitUncommittedPartitions()

  // The failing partition (1-0) stays registered so the LifecycleManager can retry it;
  // the successful one (0-0) is gone.
  assert(worker.partitionLocationInfo.getPrimaryLocation(shuffleKey, "1-0") != null)
  assert(worker.partitionLocationInfo.getPrimaryLocation(shuffleKey, "0-0") == null)
}
380+
381+
test("commitUncommittedPartitions - commits across multiple shuffle keys") {
  conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
  worker = new Worker(conf, workerArgs)
  val controller = worker.controller
  controller.init(worker)

  val firstShuffle = "app1-0"
  val secondShuffle = "app2-1"
  val firstWriter = mockWriter(100L)
  val secondWriter = mockWriter(200L)
  val replicaWriter = mockWriter(50L)

  // One primary per shuffle key, plus a replica under the first key.
  val firstPrimaries = new util.ArrayList[PartitionLocation]()
  firstPrimaries.add(mockWorkingPartition(0, firstWriter))
  worker.partitionLocationInfo.addPrimaryPartitions(firstShuffle, firstPrimaries)

  val secondPrimaries = new util.ArrayList[PartitionLocation]()
  secondPrimaries.add(mockWorkingPartition(1, secondWriter))
  worker.partitionLocationInfo.addPrimaryPartitions(secondShuffle, secondPrimaries)

  val replicas = new util.ArrayList[PartitionLocation]()
  replicas.add(mockWorkingPartition(2, replicaWriter))
  worker.partitionLocationInfo.addReplicaPartitions(firstShuffle, replicas)

  assert(!worker.partitionLocationInfo.isEmpty)
  controller.commitUncommittedPartitions()

  Seq(firstWriter, secondWriter, replicaWriter).foreach(w => verify(w).close())
  assert(worker.partitionLocationInfo.isEmpty)
}
407+
408+
/** Builds a mocked writer whose close() reports the given committed byte count. */
private def mockWriter(bytesOnClose: Long): PartitionDataWriter = {
  val mocked = mock[PartitionDataWriter]
  when(mocked.close()).thenReturn(bytesOnClose)
  when(mocked.getStorageInfo).thenReturn(new StorageInfo("/tmp", StorageInfo.Type.HDD, 1))
  when(mocked.getMapIdBitMap).thenReturn(null)
  when(mocked.getMetaHandler).thenReturn(null)
  mocked
}
416+
417+
/** Wraps a mocked writer in a primary-mode WorkingPartition at epoch 0. */
private def mockWorkingPartition(
    partitionId: Int,
    writer: PartitionDataWriter): WorkingPartition = {
  val location =
    new PartitionLocation(partitionId, 0, "host", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY)
  new WorkingPartition(location, writer)
}
306431
}

0 commit comments

Comments
 (0)