Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
get(WORKER_GRACEFUL_SHUTDOWN_SAVE_COMMITTED_FILEINFO_SYNC)
def workerGracefulShutdownDbDeleteFailurePolicy: String =
get(WORKER_GRACEFUL_SHUTDOWN_DB_DELETE_FAILURE_POLICY)
// Whether the worker, during graceful shutdown, proactively commits
// partitions that have not yet been committed instead of waiting for the
// LifecycleManager's CommitFiles RPCs. Backed by
// WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED (default: false).
def workerGracefulShutdownCommitUncommittedPartitionsEnabled: Boolean =
get(WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED)

// //////////////////////////////////////////////////////
// Flusher //
Expand Down Expand Up @@ -4003,6 +4005,15 @@ object CelebornConf extends Logging {
.checkValues(Set("THROW", "EXIT", "IGNORE"))
.createWithDefault("IGNORE")

// Opt-in switch for the active commit path during graceful shutdown.
// When disabled (the default), the worker relies on the passive path:
// waiting for the LifecycleManager to drive commits via CommitFiles RPCs.
val WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled")
.categories("worker")
.doc("When true, during graceful shutdown the worker commits uncommitted " +
"partitions instead of waiting for LifecycleManager to send CommitFiles RPCs.")
.version("0.7.0")
.booleanConf
.createWithDefault(false)

val WORKER_DISKTIME_SLIDINGWINDOW_SIZE: ConfigEntry[Int] =
buildConf("celeborn.worker.flusher.diskTime.slidingWindow.size")
.withAlternative("celeborn.worker.flusher.avgFlushTime.slidingWindow.size")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,24 @@ class WorkerPartitionLocationInfo extends Logging {
} else null
}

/**
 * Snapshot uncommitted partition unique IDs grouped by shuffle key.
 *
 * The returned snapshot is a best-effort view because ConcurrentHashMap
 * iteration is weakly consistent: concurrent mutations may or may not
 * be visible. Callers must treat the result as a point-in-time copy.
 *
 * @return (primaryIds, replicaIds) — each a Map[shuffleKey, List[uniqueId]]
 */
def snapshotUncommittedUniqueIds
  : (Map[String, util.List[String]], Map[String, util.List[String]]) =
  (snapshotIds(primaryPartitionLocations), snapshotIds(replicaPartitionLocations))

/**
 * Collects, for each shuffle key that still has live partitions, a copied
 * list of the partition unique IDs currently registered under that key.
 * Shuffle keys whose partition map is empty are omitted from the result.
 */
private def snapshotIds(partInfo: PartitionInfo): Map[String, util.List[String]] = {
  val snapshot = Map.newBuilder[String, util.List[String]]
  partInfo.asScala.foreach { case (shuffleKey, partitions) =>
    if (!partitions.isEmpty) {
      // Copy the key set so the snapshot is decoupled from later mutations.
      snapshot += shuffleKey -> new util.ArrayList[String](partitions.keySet())
    }
  }
  snapshot.result()
}

def isEmpty: Boolean = {
(primaryPartitionLocations.isEmpty ||
primaryPartitionLocations.asScala.values.forall(_.isEmpty)) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,68 @@ class WorkerPartitionLocationInfoSuite extends CelebornFunSuite {
assertEquals(workerPartitionLocationInfo.isEmpty, true)
}

test("snapshotUncommittedUniqueIds - empty info returns empty maps") {
  val locationInfo = new WorkerPartitionLocationInfo
  // A freshly constructed info holds no partitions on either side.
  val (primaryIds, replicaIds) = locationInfo.snapshotUncommittedUniqueIds
  assert(primaryIds.isEmpty)
  assert(replicaIds.isEmpty)
}

test("snapshotUncommittedUniqueIds - captures correct IDs across shuffles") {
  val locationInfo = new WorkerPartitionLocationInfo
  val firstShuffle = "app1-0"
  val secondShuffle = "app2-1"

  // Local helper: wrap partitions into the java.util.List the API expects.
  def javaList(partitions: PartitionLocation*): util.ArrayList[PartitionLocation] = {
    val list = new util.ArrayList[PartitionLocation]()
    partitions.foreach(list.add)
    list
  }

  locationInfo.addPrimaryPartitions(firstShuffle, javaList(mockPartition(0, 0), mockPartition(1, 0)))
  locationInfo.addPrimaryPartitions(secondShuffle, javaList(mockPartition(2, 0)))
  locationInfo.addReplicaPartitions(firstShuffle, javaList(mockPartition(3, 0)))

  val (primaryIds, replicaIds) = locationInfo.snapshotUncommittedUniqueIds
  // Two primary shuffles, one replica shuffle; uniqueId format is "partitionId-epoch".
  assert(primaryIds.size == 2)
  assert(primaryIds(firstShuffle).size() == 2)
  assert(primaryIds(firstShuffle).contains("0-0"))
  assert(primaryIds(firstShuffle).contains("1-0"))
  assert(primaryIds(secondShuffle).size() == 1)
  assert(primaryIds(secondShuffle).contains("2-0"))
  assert(replicaIds.size == 1)
  assert(replicaIds(firstShuffle).size() == 1)
  assert(replicaIds(firstShuffle).contains("3-0"))
}

test("snapshotUncommittedUniqueIds - filters empty shuffle keys") {
  val locationInfo = new WorkerPartitionLocationInfo
  val shuffleKey = "app1-0"
  val partitions = new util.ArrayList[PartitionLocation]()
  partitions.add(mockPartition(0, 0))
  partitions.add(mockPartition(1, 0))
  locationInfo.addPrimaryPartitions(shuffleKey, partitions)
  // Removing every partition leaves an empty map under the key; the
  // snapshot must drop such keys rather than report an empty list.
  val uniqueIds = partitions.asScala.map(_.getUniqueId).asJava
  locationInfo.removePrimaryPartitions(shuffleKey, uniqueIds)
  val (primaryIds, _) = locationInfo.snapshotUncommittedUniqueIds
  assert(!primaryIds.contains(shuffleKey))
}

test("snapshotUncommittedUniqueIds - snapshot is a point-in-time copy") {
  val locationInfo = new WorkerPartitionLocationInfo
  val shuffleKey = "app1-0"
  val initial = new util.ArrayList[PartitionLocation]()
  initial.add(mockPartition(0, 0))
  locationInfo.addPrimaryPartitions(shuffleKey, initial)

  val (primaryIds, _) = locationInfo.snapshotUncommittedUniqueIds
  assert(primaryIds(shuffleKey).size() == 1)

  // Mutations performed after the snapshot was taken...
  val added = new util.ArrayList[PartitionLocation]()
  added.add(mockPartition(1, 0))
  added.add(mockPartition(2, 0))
  locationInfo.addPrimaryPartitions(shuffleKey, added)

  // ...must not be reflected in the previously captured snapshot.
  assert(primaryIds(shuffleKey).size() == 1)
}

private def mockPartition(partitionId: Int, epoch: Int): PartitionLocation = {
new PartitionLocation(
partitionId,
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ license: |
| celeborn.worker.flusher.threads | 16 | false | Flusher's thread count per disk for unknown-type disks. | 0.2.0 | |
| celeborn.worker.graceful.shutdown.checkSlotsFinished.interval | 1s | false | The wait interval of checking whether all released slots to be committed or destroyed during worker graceful shutdown | 0.2.0 | |
| celeborn.worker.graceful.shutdown.checkSlotsFinished.timeout | 480s | false | The wait time of waiting for the released slots to be committed or destroyed during worker graceful shutdown. | 0.2.0 | |
| celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled | false | false | When true, during graceful shutdown the worker commits uncommitted partitions instead of waiting for LifecycleManager to send CommitFiles RPCs. | 0.7.0 | |
| celeborn.worker.graceful.shutdown.dbDeleteFailurePolicy | IGNORE | false | Policy for handling DB delete failures during graceful shutdown. THROW: throw exception, EXIT: trigger graceful shutdown, IGNORE: log error and continue (default). | 0.7.0 | |
| celeborn.worker.graceful.shutdown.enabled | false | false | When true, during worker shutdown, the worker will wait for all released slots to be committed or destroyed. | 0.2.0 | |
| celeborn.worker.graceful.shutdown.partitionSorter.shutdownTimeout | 120s | false | The wait time of waiting for sorting partition files during worker graceful shutdown. | 0.2.0 | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,103 @@ private[deploy] class Controller(
}
}

/**
 * Proactively commits all uncommitted partitions during graceful shutdown.
 *
 * <p>Commit results are tracked per-shuffle because uniqueId ({@code partitionId-epoch})
 * is not namespaced by shuffleKey — different shuffles can share the same uniqueId.
 *
 * <p>Only successfully committed or empty-file partitions are removed and their slots
 * released. Failed or in-flight (timed-out) partitions are retained for the passive
 * LifecycleManager CommitFiles retry path.
 */
private[worker] def commitUncommittedPartitions(): Unit = {
  val (primarySnapshot, replicaSnapshot) = partitionLocationInfo.snapshotUncommittedUniqueIds
  if (primarySnapshot.isEmpty && replicaSnapshot.isEmpty) {
    logInfo("No uncommitted partitions.")
    return
  }
  val shuffleKeys = primarySnapshot.keySet ++ replicaSnapshot.keySet
  val primaryTotal = primarySnapshot.values.map(_.size()).sum
  val replicaTotal = replicaSnapshot.values.map(_.size()).sum
  logInfo(s"Committing uncommitted partitions across ${shuffleKeys.size} shuffles " +
    s"($primaryTotal primary, $replicaTotal replica).")
  // Shared immutable empty list reused wherever a shuffle has no IDs on one side.
  val emptyIds = java.util.Collections.emptyList[String]()
  val futures = ArrayBuffer[CompletableFuture[Void]]()
  val tasks = ArrayBuffer[CompletableFuture[Void]]()
  // Per-shuffle result sets, keyed by shuffleKey to avoid uniqueId collisions
  // across shuffles (see scaladoc above).
  val committedPerShuffle = JavaUtils.newConcurrentHashMap[String, jSet[String]]()
  val emptyPerShuffle = JavaUtils.newConcurrentHashMap[String, jSet[String]]()
  for (shuffleKey <- shuffleKeys) {
    val committedIds = ConcurrentHashMap.newKeySet[String]()
    val emptyFileIds = ConcurrentHashMap.newKeySet[String]()
    val failedIds = ConcurrentHashMap.newKeySet[String]()
    val storageInfos = JavaUtils.newConcurrentHashMap[String, StorageInfo]()
    val mapIdBitMap = JavaUtils.newConcurrentHashMap[String, RoaringBitmap]()
    val partitionSizes = new LinkedBlockingQueue[Long]()
    committedPerShuffle.put(shuffleKey, committedIds)
    emptyPerShuffle.put(shuffleKey, emptyFileIds)
    val primaryIds = primarySnapshot.getOrElse(shuffleKey, emptyIds)
    val replicaIds = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
    val (primaryFuture, primaryTasks) = commitFiles(
      shuffleKey,
      primaryIds,
      committedIds,
      emptyFileIds,
      failedIds,
      storageInfos,
      mapIdBitMap,
      partitionSizes)
    val (replicaFuture, replicaTasks) = commitFiles(
      shuffleKey,
      replicaIds,
      committedIds,
      emptyFileIds,
      failedIds,
      storageInfos,
      mapIdBitMap,
      partitionSizes,
      isPrimary = false)
    if (primaryFuture != null) { futures += primaryFuture }
    if (replicaFuture != null) { futures += replicaFuture }
    tasks ++= primaryTasks
    tasks ++= replicaTasks
  }
  if (futures.nonEmpty) {
    try {
      CompletableFuture.allOf(futures.toArray: _*).get(
        shuffleCommitTimeout,
        TimeUnit.MILLISECONDS)
    } catch {
      case e: Exception =>
        // Cancel both the composite per-shuffle futures and the underlying
        // tasks; partitions still in flight stay registered for the passive
        // CommitFiles retry path.
        futures.foreach(_.cancel(true))
        tasks.foreach(_.cancel(true))
        logWarning(
          s"Commit timed out after ${shuffleCommitTimeout}ms across " +
            s"${shuffleKeys.size} shuffles: ${shuffleKeys.mkString(", ")}",
          e)
    }
  }
  var primaryCommitted = 0
  var replicaCommitted = 0
  for (shuffleKey <- shuffleKeys) {
    val committed = committedPerShuffle.get(shuffleKey)
    val empty = emptyPerShuffle.get(shuffleKey)
    // A partition counts as done when its data was committed or it had no data.
    def isCommitted(id: String): Boolean = committed.contains(id) || empty.contains(id)
    val primaryToRemove = primarySnapshot.getOrElse(shuffleKey, emptyIds)
      .asScala.filter(isCommitted).asJava
    val replicaToRemove = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
      .asScala.filter(isCommitted).asJava
    val (primarySlots, _) =
      partitionLocationInfo.removePrimaryPartitions(shuffleKey, primaryToRemove)
    val (replicaSlots, _) =
      partitionLocationInfo.removeReplicaPartitions(shuffleKey, replicaToRemove)
    workerInfo.releaseSlots(shuffleKey, primarySlots)
    workerInfo.releaseSlots(shuffleKey, replicaSlots)
    primaryCommitted += primaryToRemove.size()
    replicaCommitted += replicaToRemove.size()
  }
  logInfo(
    s"Committed ${primaryCommitted + replicaCommitted} partitions " +
      s"($primaryCommitted primary, $replicaCommitted replica) " +
      s"across ${shuffleKeys.size} shuffles.")
}

private def handleCommitFiles(
context: RpcCallContext,
shuffleKey: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,17 @@ private[celeborn] class Worker(
e)
}
shutdown.set(true)

if (conf.workerGracefulShutdownCommitUncommittedPartitionsEnabled) {
// Commit uncommitted partitions instead of waiting for LifecycleManager to send CommitFiles RPCs.
try {
controller.commitUncommittedPartitions()
} catch {
case e: Throwable =>
logError("Failed to commit uncommitted partitions during graceful shutdown", e)
}
}
Comment thread
SteNicholas marked this conversation as resolved.

val interval = conf.workerGracefulShutdownCheckSlotsFinishedInterval
val timeout = conf.workerGracefulShutdownCheckSlotsFinishedTimeoutMs
var waitTimes = 0
Expand Down
Loading
Loading