[DO NOT REVIEW] async shuffle write#13325
Conversation
Signed-off-by: Hongbin Ma (Mahone) <mahongbin@apache.org>
|
build |
There was a problem hiding this comment.
Pull Request Overview
This PR introduces asynchronous write functionality for GPU shuffle operations in the RAPIDS plugin for Spark. The implementation adds background thread-based prefetching to overlap data processing with downstream operations during shuffle writes.
- Adds configuration option
spark.rapids.sql.asyncWrite.shuffle.enabledto control async write behavior - Implements background prefetching using
ThrottlingExecutorwith proper task completion callbacks - Modifies shuffle dependency creation to support async write parameters
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| GpuShuffleExchangeExecBase.scala | Implements async shuffle write logic with background thread prefetching and executor management |
| RapidsConf.scala | Adds configuration option for enabling/disabling async shuffle writes |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
| } | ||
| } | ||
| // Use targetBatchSize as memory estimate for the prefetch task | ||
| prefetchFuture = Some(executor.get.submit(callable, targetBatchSize)) |
There was a problem hiding this comment.
Using .get on an Option without checking if it's defined could throw NoSuchElementException. The condition executor.isDefined is checked earlier, but the Option could theoretically become None between the check and this line in a concurrent context.
| prefetchFuture = Some(executor.get.submit(callable, targetBatchSize)) | |
| executor.foreach { exec => | |
| prefetchFuture = Some(exec.submit(callable, targetBatchSize)) | |
| } |
| // Get the prefetched batch | ||
| val batchOpt = prefetchFuture.get.get() | ||
| prefetchFuture = None | ||
| batchOpt |
There was a problem hiding this comment.
Double .get call is unsafe - first .get extracts from Option and second .get() waits for Future result. If prefetchFuture is None, this will throw NoSuchElementException. Consider using pattern matching or explicit checks.
| batchOpt | |
| val batch = if (asyncWriteEnabled) { | |
| prefetchFuture match { | |
| case Some(future) => | |
| val batchOpt = future.get() | |
| prefetchFuture = None | |
| batchOpt | |
| case None => | |
| // Synchronous path | |
| if (iter.hasNext) { | |
| var batch = iter.next() | |
| while (batch.numRows == 0 && iter.hasNext) { | |
| batch.close() | |
| batch = iter.next() | |
| } | |
| if (batch.numRows > 0) { | |
| Some(batch) | |
| } else { | |
| batch.close() | |
| None | |
| } | |
| } else { | |
| None | |
| } | |
| } |
| } | ||
|
|
||
| def close(): Unit = { | ||
| executor.foreach(_.shutdownNow(10, TimeUnit.SECONDS)) |
There was a problem hiding this comment.
The shutdownNow method returns a list of tasks that never started execution, but this return value is ignored. Consider handling the case where tasks don't shutdown gracefully within the timeout period.
| executor.foreach(_.shutdownNow(10, TimeUnit.SECONDS)) | |
| executor.foreach { ex => | |
| val notStarted = ex.shutdownNow(10, TimeUnit.SECONDS) | |
| if (notStarted != null && !notStarted.isEmpty) { | |
| // Log a warning about tasks that never started | |
| GpuShuffleExchangeExecBase.logWarning( | |
| s"GpuShuffleExchangeExecBase: ${notStarted.size} async shuffle write task(s) never started and were cancelled during shutdown.") | |
| } | |
| } |
| } | ||
| } | ||
| // Use targetBatchSize as memory estimate for the prefetch task | ||
| prefetchFuture = Some(executor.get.submit(callable, targetBatchSize)) |
There was a problem hiding this comment.
The startPrefetch method is not thread-safe. Multiple threads could potentially call this method concurrently, leading to race conditions when checking and setting prefetchFuture. Consider adding synchronization or using atomic operations.
| prefetchFuture = Some(executor.get.submit(callable, targetBatchSize)) | |
| prefetchLock.synchronized { | |
| if (asyncWriteEnabled && executor.isDefined && prefetchFuture.isEmpty) { | |
| val callable = new Callable[Option[ColumnarBatch]]() { | |
| override def call(): Option[ColumnarBatch] = { | |
| if (iter.hasNext) { | |
| var batch = iter.next() | |
| while (batch.numRows == 0 && iter.hasNext) { | |
| batch.close() | |
| batch = iter.next() | |
| } | |
| if (batch.numRows > 0) { | |
| Some(batch) | |
| } else { | |
| batch.close() | |
| None | |
| } | |
| } else { | |
| None | |
| } | |
| } | |
| } | |
| // Use targetBatchSize as memory estimate for the prefetch task | |
| prefetchFuture = Some(executor.get.submit(callable, targetBatchSize)) | |
| } |
|
NOTE: release/25.12 has been created from main. Please retarget your PR to release/25.12 if it should be included in the release. |
under construction