diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index a24e06d5a68..84d74f8c145 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -52,6 +52,14 @@ public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentE } } + public static String encodeAppShuffleIdentifier(int appShuffleId, TaskContext context) { + return appShuffleId + "-" + context.stageId() + "-" + context.stageAttemptNumber(); + } + + public static String[] decodeAppShuffleIdentifier(String appShuffleIdentifier) { + return appShuffleIdentifier.split("-"); + } + public static int getEncodedAttemptNumber(TaskContext context) { return (context.stageAttemptNumber() << 16) | context.attemptNumber(); } diff --git a/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala new file mode 100644 index 00000000000..e88f6f640be --- /dev/null +++ b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.celeborn.spark + +import java.util +import java.util.concurrent.{LinkedBlockingQueue, ScheduledExecutorService, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.shuffle.celeborn.SparkCommonUtils + +import org.apache.celeborn.client.LifecycleManager +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.util.ThreadUtils + +private[celeborn] class FailedShuffleCleaner(lifecycleManager: LifecycleManager) extends Logging { + + // in celeborn ids + private val shufflesToBeCleaned = new LinkedBlockingQueue[Int]() + private val cleanedShuffleIds = new mutable.HashSet[Int] + + private lazy val cleanInterval = + lifecycleManager.conf.clientFetchCleanFailedShuffleIntervalMS + + // for test + def reset(): Unit = { + shufflesToBeCleaned.clear() + cleanedShuffleIds.clear() + if (cleanerThreadPool != null) { + cleanerThreadPool.shutdownNow() + cleanerThreadPool = null + } + } + + def addShuffleIdToBeCleaned(appShuffleIdentifier: String): Unit = { + val Array(appShuffleId, _, _) = SparkCommonUtils.decodeAppShuffleIdentifier( + appShuffleIdentifier) + lifecycleManager.getShuffleIdMapping.get(appShuffleId.toInt).foreach { + case (_, (celebornShuffleId, _)) => shufflesToBeCleaned.put(celebornShuffleId) + } + } + + def init(): Unit = { + cleanerThreadPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "failedShuffleCleanerThreadPool") + cleanerThreadPool.scheduleWithFixedDelay( + new Runnable { + override def run(): Unit = { + try { + val allShuffleIds = new util.ArrayList[Int] + shufflesToBeCleaned.drainTo(allShuffleIds) + allShuffleIds.asScala.foreach { shuffleId => + if (!cleanedShuffleIds.contains(shuffleId)) { + lifecycleManager.unregisterShuffle(shuffleId) + logInfo( + s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)") + cleanedShuffleIds += shuffleId + } + } + } catch { + case e: Exception => + logError("unexpected exception in cleaner thread", e) + } + } + }, + cleanInterval, + cleanInterval, + TimeUnit.MILLISECONDS) + } + + init() + + def removeCleanedShuffleId(celebornShuffleId: Int): Unit = { + cleanedShuffleIds.remove(celebornShuffleId) + } + + private var cleanerThreadPool: ScheduledExecutorService = _ +} diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 8c099b29f3f..80ea5c256a1 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -37,6 +37,7 @@ import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.reflect.DynMethods; +import org.apache.celeborn.spark.FailedShuffleCleaner; /** * In order to support Spark Stage resubmit with ShuffleReader FetchFails, Celeborn shuffleId has to @@ -84,6 +85,8 @@ public class SparkShuffleManager implements ShuffleManager { ConcurrentHashMap.newKeySet(); private final CelebornShuffleFallbackPolicyRunner fallbackPolicyRunner; + private FailedShuffleCleaner failedShuffleCleaner = null; + private long sendBufferPoolCheckInterval; private long sendBufferPoolExpireTimeout; @@ -158,6 +161,23 @@ private void initializeLifecycleManager(String appId) { } } + if (lifecycleManager.conf().clientFetchCleanFailedShuffle()) { + if (!lifecycleManager.conf().clientStageRerunEnabled()) { + throw new IllegalArgumentException( + CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key() + + " has to be " + + "enabled, when " + + CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE().key() + + " is set to true"); + } + failedShuffleCleaner = new FailedShuffleCleaner(lifecycleManager); + lifecycleManager.registerValidateCelebornShuffleIdForCleanCallback( + (appShuffleIdentifier) -> + SparkUtils.addWriterShuffleIdsToBeCleaned(this, appShuffleIdentifier)); + lifecycleManager.registerUnregisterShuffleCallback( + (celebornShuffleId) -> SparkUtils.removeCleanedShuffleId(this, celebornShuffleId)); + } + if (celebornConf.getReducerFileGroupBroadcastEnabled()) { lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback( (shuffleId, getReducerFileGroupResponse) -> @@ -249,6 +269,9 @@ public void stop() { _sortShuffleManager.stop(); _sortShuffleManager = null; } + if (celebornConf.clientFetchCleanFailedShuffle()) { + failedShuffleCleaner.reset(); + } } @Override @@ -470,4 +493,8 @@ private void checkUserClassPathFirst(ShuffleHandle handle) { public LifecycleManager getLifecycleManager() { return this.lifecycleManager; } + + public FailedShuffleCleaner getFailedShuffleCleaner() { + return this.failedShuffleCleaner; + } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index b2e64565ec8..fc5d605d8ac 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -128,17 +128,14 @@ public static String appUniqueId(SparkContext context) { .getOrElse(context::applicationId); } - public static String getAppShuffleIdentifier(int appShuffleId, TaskContext context) { - return appShuffleId + "-" + context.stageId() + "-" + context.stageAttemptNumber(); - } - public static int celebornShuffleId( ShuffleClient client, CelebornShuffleHandle handle, TaskContext context, Boolean isWriter) { if (handle.throwsFetchFailure()) { - String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context); + String appShuffleIdentifier = + SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId(), context); Tuple2 res = client.getShuffleId( handle.shuffleId(), @@ -327,7 +324,8 @@ public static void addFailureListenerIfBarrierTask( if (!(taskContext instanceof BarrierTaskContext)) return; int appShuffleId = handle.shuffleId(); - String appShuffleIdentifier = SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext); + String appShuffleIdentifier = + SparkCommonUtils.encodeAppShuffleIdentifier(appShuffleId, taskContext); BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext; barrierContext.addTaskFailureListener( @@ -625,4 +623,14 @@ public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuff return null; }); } + + public static void addWriterShuffleIdsToBeCleaned( + SparkShuffleManager sparkShuffleManager, String appShuffleIdentifier) { + sparkShuffleManager.getFailedShuffleCleaner().addShuffleIdToBeCleaned(appShuffleIdentifier); + } + + public static void removeCleanedShuffleId( + SparkShuffleManager sparkShuffleManager, int celebornShuffleId) { + sparkShuffleManager.getFailedShuffleCleaner().removeCleanedShuffleId(celebornShuffleId); + } } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index ae692f9089a..c0eb4dfb84c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -931,6 +931,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logInfo(s"reuse existing shuffleId $id for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") id } else { + // this branch means it is a redo of previous write stage if (isBarrierStage) { // unregister previous shuffle(s) which are still valid val mapUpdates = shuffleIds.filter(_._2._2).map { kv => @@ -941,6 +942,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } val newShuffleId = shuffleIdGenerator.getAndIncrement() logInfo(s"generate new shuffleId $newShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier") + validateCelebornShuffleIdForClean.foreach(callback => + callback.accept(appShuffleIdentifier)) shuffleIds.put(appShuffleIdentifier, (newShuffleId, true)) newShuffleId } @@ -954,11 +957,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } else { shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find( areAllMapTasksEnd) match { - case Some(shuffleId) => + case Some(celebornShuffleId) => val pbGetShuffleIdResponse = { logDebug( - s"get shuffleId $shuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") - PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() + s"get shuffleId $celebornShuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") + PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).setSuccess( + true).build() } context.reply(pbGetShuffleIdResponse) case None => @@ -1160,6 +1164,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.values.map { case (shuffleId, _) => unregisterShuffle(shuffleId) + unregisterShuffleCallback.foreach(c => c.accept(shuffleId)) }) } } else { @@ -1850,6 +1855,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends appShuffleTrackerCallback = Some(callback) } + // expecting celeborn shuffle id and application shuffle identifier + @volatile private var validateCelebornShuffleIdForClean: Option[Consumer[String]] = + None + def registerValidateCelebornShuffleIdForCleanCallback( + callback: Consumer[String]): Unit = { + validateCelebornShuffleIdForClean = Some(callback) + } + + @volatile private var unregisterShuffleCallback: Option[Consumer[Integer]] = None + def registerUnregisterShuffleCallback(callback: Consumer[Integer]): Unit = { + unregisterShuffleCallback = Some(callback) + } + def registerAppShuffleDeterminate(appShuffleId: Int, determinate: Boolean): Unit = { appShuffleDeterminateMap.put(appShuffleId, determinate) } @@ -1943,4 +1961,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } }) } + + def getShuffleIdMapping = shuffleIdMapping } diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 45c371ee538..6ee3b1ace0c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -70,7 +70,7 @@ class ReducePartitionCommitHandler( private val getReducerFileGroupRequest = JavaUtils.newConcurrentHashMap[Int, util.Set[MultiSerdeVersionRpcContext]]() - private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]() + private[celeborn] val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]() private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]() private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]() private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, Array[Int]]() diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index d33028a9ab5..86515916794 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -997,6 +997,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientFetchMaxRetriesForEachReplica: Int = get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA) def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED) + def clientFetchCleanFailedShuffle: Boolean = get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE) + def clientFetchCleanFailedShuffleIntervalMS: Long = + get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL) def clientFetchExcludeWorkerOnFailureEnabled: Boolean = get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED) def clientFetchExcludedWorkerExpireTimeout: Long = @@ -4813,6 +4816,23 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(true) + val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE: ConfigEntry[Boolean] = + buildConf("celeborn.client.spark.fetch.cleanFailedShuffle") + .categories("client") + .version("0.6.0") + .doc("whether to clean those disk space occupied by shuffles which cannot be fetched") + .booleanConf + .createWithDefault(false) + + val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL: ConfigEntry[Long] = + buildConf("celeborn.client.spark.fetch.cleanFailedShuffleInterval") + .categories("client") + .version("0.6.0") + .doc("the interval to clean the failed-to-fetch shuffle files, only valid when" + + s" ${CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key} is enabled") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("1s") + val CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.client.fetch.excludeWorkerOnFailure.enabled") .categories("client") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 6c0ff752d3e..e4e8e0e83ee 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -111,6 +111,8 @@ license: | | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | false | Whether to filter excluded worker when register shuffle. | 0.4.0 | | | celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether to revise lost shuffles. | 0.6.0 | | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 | | +| celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | +| celeborn.client.spark.fetch.cleanFailedShuffleInterval | 1s | false | the interval to clean the failed-to-fetch shuffle files, only valid when celeborn.client.spark.fetch.cleanFailedShuffle is enabled | 0.6.0 | | | celeborn.client.spark.push.dynamicWriteMode.enabled | false | false | Whether to dynamically switch push write mode based on conditions.If true, shuffle mode will be only determined by partition count | 0.5.0 | | | celeborn.client.spark.push.dynamicWriteMode.partitionNum.threshold | 2000 | false | Threshold of shuffle partition number for dynamically switching push writer mode. When the shuffle partition number is greater than this value, use the sort-based shuffle writer for memory efficiency; otherwise use the hash-based shuffle writer for speed. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is true. | 0.5.0 | | | celeborn.client.spark.push.sort.memory.maxMemoryFactor | 0.4 | false | the max portion of executor memory which can be used for SortBasedWriter buffer (only valid when celeborn.client.spark.push.sort.memory.useAdaptiveThreshold is enabled | 0.5.0 | | diff --git a/pom.xml b/pom.xml index 31a38cda12a..e57cb72a95a 100644 --- a/pom.xml +++ b/pom.xml @@ -907,7 +907,7 @@ file:src/test/resources/log4j.properties src/test/resources/log4j2-test.xml ${project.build.directory}/tmp - 1g + 8g ${spark.shuffle.plugin.class} @@ -946,7 +946,7 @@ file:src/test/resources/log4j.properties src/test/resources/log4j2-test.xml ${project.build.directory}/tmp - 1g + 8g ${spark.shuffle.plugin.class} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala new file mode 100644 index 00000000000..936ea696113 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.celeborn.tests.spark + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.sql.SparkSession +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.service.deploy.worker.Worker +import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks + +class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterEach { + + override def beforeAll(): Unit = { + logInfo("test initialized , setup Celeborn mini cluster") + setupMiniClusterWithRandomPorts(workerNum = 1) + } + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + override def createWorker(map: Map[String, String]): Worker = { + val storageDir = createTmpDir() + workerDirs = workerDirs :+ storageDir + super.createWorker(map ++ Map("celeborn.master.heartbeat.worker.timeout" -> "10s"), storageDir) + } + + test("celeborn spark integration test - the failed shuffle file is cleaned up correctly") { + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.stageRerun.enabled", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .config("spark.celeborn.client.spark.fetch.cleanFailedShuffle", "true") + .getOrCreate() + + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderGetHooks( + celebornConf, + workerDirs, + shuffleIdToBeDeleted = Seq(0)) + TestCelebornShuffleManager.registerReaderGetHook(hook) + val checkingThread = + triggerStorageCheckThread(Seq(0), Seq(1), sparkSession) + val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2) + .map { i => (i, i) }.groupByKey(4).collect() + checkStorageValidation(checkingThread) + // verify result + assert(hook.executed.get()) + assert(tuples.length == 10000) + for (elem <- tuples) { + elem._2.foreach(i => assert(i.equals(elem._1))) + } + sparkSession.stop() + } + } + + class CheckingThread( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession) + extends Thread { + var exception: Exception = _ + + protected def checkDirStatus(): Boolean = { + val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => { + workerDirs.forall(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + !new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => { + workerDirs.exists(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()) + }) + val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => { + shuffleId.toString + ":" + + workerDirs.map(dir => + new File(s"$dir/celeborn-worker/shuffle_data/" + + s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList + }).mkString(",") + println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" + + s"shuffle-to-be-created status: $createdSuccessfullyString") + deletedSuccessfully && createdSuccessfully + } + + override def run(): Unit = { + var allDataInShape = checkDirStatus() + while (!allDataInShape) { + Thread.sleep(1000) + allDataInShape = checkDirStatus() + } + } + } + + protected def triggerStorageCheckThread( + shuffleIdShouldNotExist: Seq[Int], + shuffleIdMustExist: Seq[Int], + sparkSession: SparkSession): CheckingThread = { + val checkingThread = + new CheckingThread(shuffleIdShouldNotExist, shuffleIdMustExist, sparkSession) + checkingThread.setDaemon(true) + checkingThread.start() + checkingThread + } + + protected def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 1000): Unit = { + val checkingThread = thread.asInstanceOf[CheckingThread] + checkingThread.join(timeout) + if (checkingThread.isAlive || checkingThread.exception != null) { + throw new IllegalStateException("the storage checking status failed," + + s"${checkingThread.isAlive} ${if (checkingThread.exception != null) checkingThread.exception.getMessage + else "NULL"}") + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala index dd0f3840149..9db3912a78f 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.protocol.ShuffleMode +import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks class CelebornFetchFailureSuite extends AnyFunSuite with SparkTestBase @@ -57,7 +58,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs) TestCelebornShuffleManager.registerReaderGetHook(hook) val value = Range(1, 10000).mkString(",") @@ -130,7 +131,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs) TestCelebornShuffleManager.registerReaderGetHook(hook) import sparkSession.implicits._ @@ -161,7 +162,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext @@ -201,7 +202,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index e29b21a0c71..41cbe072b32 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -116,46 +116,4 @@ trait SparkTestBase extends AnyFunSuite val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap outMap } - - class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook { - var executed: AtomicBoolean = new AtomicBoolean(false) - val lock = new Object - - override def exec( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): Unit = { - if (executed.get() == true) return - - lock.synchronized { - handle match { - case h: CelebornShuffleHandle[_, _, _] => { - val appUniqueId = h.appUniqueId - val shuffleClient = ShuffleClient.get( - h.appUniqueId, - h.lifecycleManagerHost, - h.lifecycleManagerPort, - conf, - h.userIdentifier, - h.extension) - val celebornShuffleId = - SparkUtils.celebornShuffleId(shuffleClient, h, context, false) - val allFiles = workerDirs.map(dir => { - new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") - }) - val datafile = allFiles.filter(_.exists()) - .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption - datafile match { - case Some(file) => file.delete() - case None => throw new RuntimeException("unexpected, there must be some data file" + - s" under ${workerDirs.mkString(",")}") - } - } - case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") - } - executed.set(true) - } - } - } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala new file mode 100644 index 00000000000..adac14242bd --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark.fetch.failure + +import java.io.File +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf + +class ShuffleReaderGetHooks( + conf: CelebornConf, + workerDirs: Seq[String], + shuffleIdToBeDeleted: Seq[Int] = Seq(), + triggerStageId: Option[Int] = None) + extends ShuffleManagerHook { + + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + private def deleteDataFile(appUniqueId: String, celebornShuffleId: Int): Unit = { + val datafile = + workerDirs.map(dir => { + new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") + }).filter(_.exists()) + .flatMap(_.listFiles().iterator).headOption + datafile match { + case Some(file) => { + file.delete() + } + case None => throw new RuntimeException("unexpected, there must be some data file") + } + } + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + if (executed.get()) { + return + } + lock.synchronized { + handle match { + case h: CelebornShuffleHandle[_, _, _] => { + val appUniqueId = h.appUniqueId + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + val appShuffleIdentifier = + SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context) + val Array(_, stageId, _) = appShuffleIdentifier.split('-') + if (triggerStageId.isEmpty || triggerStageId.get == stageId.toInt) { + if (shuffleIdToBeDeleted.isEmpty) { + deleteDataFile(appUniqueId, celebornShuffleId) + } else { + shuffleIdToBeDeleted.foreach { shuffleId => + deleteDataFile(appUniqueId, shuffleId) + } + } + executed.set(true) + } + } + case x => throw new RuntimeException(s"unexpected, only support RssShuffleHandle here," + + s" but get ${x.getClass.getCanonicalName}") + } + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 83a5e12f60b..293ef080eb0 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -33,6 +33,7 @@ import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.tests.spark.SparkTestBase +import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks class SparkUtilsSuite extends AnyFunSuite with SparkTestBase @@ -60,7 +61,7 @@ class SparkUtilsSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) + val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs) TestCelebornShuffleManager.registerReaderGetHook(hook) try {