
Commit 5a50c71

[CELEBORN-2315] Add iterator fully-consumed validation after shuffle write
Adds a post-write safety check to HashBasedShuffleWriter and SortBasedShuffleWriter: after the write loop completes, verify the input iterator was fully consumed. If records remain, kill the task with TaskKilledException. This guards against silent data loss.
1 parent 69df893 commit 5a50c71

9 files changed: 165 additions & 46 deletions
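The guard in miniature: run the write loop, then probe the iterator once more before committing the map output. The sketch below is a simplified stand-in, not Celeborn's actual classes; the real writers delegate the throw to TaskInterruptedHelper so the Spark kill reason is preserved.

import java.util.Iterator;

class FullyConsumedGuardSketch {

  // Stand-in for fastWrite0/write0: if a wrapped iterator short-circuits and
  // this loop exits early, records are left behind without any error surfacing.
  static <T> void writeLoop(Iterator<T> records) {
    while (records.hasNext()) {
      records.next(); // push the record to the shuffle service
    }
  }

  static <T> void write(Iterator<T> records) {
    writeLoop(records);
    // Post-write validation: committing map output while records are still
    // pending would silently drop them, so fail the task loudly instead.
    if (records.hasNext()) {
      throw new IllegalStateException(
          "Shuffle write task finished but iterator was not fully consumed.");
    }
  }
}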


client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/TaskInterruptedHelper.java

Lines changed: 13 additions & 2 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.celeborn;
 
+import scala.Option;
+
 import org.apache.spark.TaskContext;
 import org.apache.spark.TaskKilledException;
 
@@ -30,8 +32,17 @@ public class TaskInterruptedHelper {
    * kill reason, so here we throw the TaskKilledException.
    */
   public static void throwTaskKillException() {
-    if (TaskContext.get().getKillReason().isDefined()) {
-      throw new TaskKilledException(TaskContext.get().getKillReason().get());
+    throwTaskKillException(null);
+  }
+
+  public static void throwTaskKillException(String message) {
+    Option<String> sparkReason = TaskContext.get().getKillReason();
+    if (sparkReason.isDefined() && message != null) {
+      throw new TaskKilledException(sparkReason.get() + "; " + message);
+    } else if (sparkReason.isDefined()) {
+      throw new TaskKilledException(sparkReason.get());
+    } else if (message != null) {
+      throw new TaskKilledException(message);
     } else {
       throw new TaskKilledException();
     }
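The new String overload folds the caller's message into Spark's own kill reason when one is defined. An illustration of the four branches (the kill reason "stage cancelled" is assumed for the example, not taken from the code):

// reason defined, message given -> TaskKilledException("stage cancelled; <message>")
// reason defined, message null  -> TaskKilledException("stage cancelled")  (same as the no-arg variant)
// reason absent,  message given -> TaskKilledException("<message>")
// reason absent,  message null  -> TaskKilledException()                   (no message)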

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java

Lines changed: 25 additions & 13 deletions
@@ -166,18 +166,8 @@ public HashBasedShuffleWriter(
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     boolean needCleanupPusher = true;
     try {
-      if (canUseFastWrite()) {
-        fastWrite0(records);
-      } else if (dep.mapSideCombine()) {
-        if (dep.aggregator().isEmpty()) {
-          throw new UnsupportedOperationException(
-              "When using map side combine, an aggregator must be specified.");
-        }
-        write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
-      } else {
-        write0(records);
-      }
-      close();
+      boolean iteratorHasNext = doWrite(records);
+      close(iteratorHasNext);
       needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
@@ -188,6 +178,26 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
     }
   }
 
+  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
+      throws IOException, InterruptedException {
+    if (canUseFastWrite()) {
+      fastWrite0(records);
+      return records.hasNext();
+    } else if (dep.mapSideCombine()) {
+      if (dep.aggregator().isEmpty()) {
+        throw new UnsupportedOperationException(
+            "When using map side combine, an aggregator must be specified.");
+      }
+      scala.collection.Iterator combinedIterator =
+          dep.aggregator().get().combineValuesByKey(records, taskContext);
+      write0(combinedIterator);
+      return combinedIterator.hasNext();
+    } else {
+      write0(records);
+      return records.hasNext();
+    }
+  }
+
   @VisibleForTesting
   boolean canUseFastWrite() {
     return unsafeRowFastWrite
@@ -331,7 +341,7 @@ private void cleanupPusher() throws IOException {
     }
   }
 
-  private void close() throws IOException, InterruptedException {
+  private void close(boolean iteratorHasNext) throws IOException, InterruptedException {
     // merge and push residual data to reduce network traffic
     // NB: since dataPusher thread have no in-flight data at this point,
     // we now push merged data by task thread will not introduce any contention
@@ -359,6 +369,8 @@ private void close() throws IOException, InterruptedException {
 
     updateMapStatus();
 
+    SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);
+
     sendBufferPool.returnBuffer(sendBuffers);
     sendBuffers = null;
     sendOffsets = null;
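One subtlety in the new doWrite: in the map-side-combine branch, the post-write probe is hasNext() on the combined iterator, not on the raw records iterator, since write0 consumes the combined one and a combiner may drain its source eagerly. A sketch of why the distinction matters, with plain java.util collections standing in for Spark's aggregator:

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

class CombinedIteratorProbeSketch {
  public static void main(String[] args) {
    Iterator<Integer> records = List.of(1, 2, 3).iterator();

    // Stand-in for combineValuesByKey: an eager combiner drains its source up
    // front (much as a spillable aggregation map would), then exposes its own
    // iterator over the combined result.
    List<Integer> buffered = new ArrayList<>();
    records.forEachRemaining(buffered::add);
    Iterator<Integer> combined = buffered.iterator();

    combined.next(); // simulate a write loop that stopped after one record

    System.out.println(records.hasNext());  // false: the source is already drained
    System.out.println(combined.hasNext()); // true: the write did NOT finish
  }
}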

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java

Lines changed: 24 additions & 13 deletions
@@ -147,18 +147,8 @@ public SortBasedShuffleWriter(
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     boolean needCleanupPusher = true;
     try {
-      if (canUseFastWrite()) {
-        fastWrite0(records);
-      } else if (dep.mapSideCombine()) {
-        if (dep.aggregator().isEmpty()) {
-          throw new UnsupportedOperationException(
-              "When using map side combine, an aggregator must be specified.");
-        }
-        write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
-      } else {
-        write0(records);
-      }
-      close();
+      boolean iteratorHasNext = doWrite(records);
+      close(iteratorHasNext);
       needCleanupPusher = false;
     } finally {
       if (needCleanupPusher) {
@@ -167,6 +157,25 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
     }
   }
 
+  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    if (canUseFastWrite()) {
+      fastWrite0(records);
+      return records.hasNext();
+    } else if (dep.mapSideCombine()) {
+      if (dep.aggregator().isEmpty()) {
+        throw new UnsupportedOperationException(
+            "When using map side combine, an aggregator must be specified.");
+      }
+      scala.collection.Iterator combinedIterator =
+          dep.aggregator().get().combineValuesByKey(records, taskContext);
+      write0(combinedIterator);
+      return combinedIterator.hasNext();
+    } else {
+      write0(records);
+      return records.hasNext();
+    }
+  }
+
   @VisibleForTesting
   boolean canUseFastWrite() {
     return unsafeRowFastWrite
@@ -304,7 +313,7 @@ private void cleanupPusher() throws IOException {
     }
   }
 
-  private void close() throws IOException {
+  private void close(boolean iteratorHasNext) throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData(false);
@@ -315,6 +324,8 @@ private void close() throws IOException {
 
     updateMapStatus();
 
+    SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);
+
     long waitStartTime = System.nanoTime();
     shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

Lines changed: 7 additions & 0 deletions
@@ -543,4 +543,11 @@ public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuff
       return null;
     });
   }
+
+  public static void assertIteratorFullyConsumed(boolean iteratorHasNext) {
+    if (iteratorHasNext) {
+      TaskInterruptedHelper.throwTaskKillException(
+          "Shuffle write task finished but iterator was not fully consumed.");
+    }
+  }
 }

client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java

Lines changed: 20 additions & 0 deletions
@@ -43,6 +43,8 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
+import org.apache.spark.TaskContext$;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
@@ -214,6 +216,24 @@ public void testGiantRecordAndMergeSmallBlockWithFastWrite() throws Exception {
     check(2 << 30, conf, serializer);
   }
 
+  @Test
+  public void testAssertIteratorFullyConsumed() {
+    SparkUtils.assertIteratorFullyConsumed(false);
+  }
+
+  @Test
+  public void testAssertIteratorFullyConsumedThrows() {
+    TaskContext$.MODULE$.setTaskContext(taskContext);
+    try {
+      SparkUtils.assertIteratorFullyConsumed(true);
+      fail("Expected TaskKilledException when iterator is not fully consumed");
+    } catch (TaskKilledException e) {
+      assertTrue(e.getMessage().contains("not fully consumed"));
+    } finally {
+      TaskContext$.MODULE$.setTaskContext(null);
+    }
+  }
+
   private void check(
       final int approximateSize, final CelebornConf conf, final Serializer serializer)
       throws Exception {
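A note on the test mechanics: TaskContext.get() reads a thread-local, so the throwing test must install the mocked taskContext before the call (otherwise getKillReason() would be invoked on a null context) and clear it in finally so no stale context leaks into later tests on the same thread. The shape of the pattern, with the suite's existing taskContext mock assumed:

TaskContext$.MODULE$.setTaskContext(taskContext); // install the thread-local
try {
  // ... code under test that internally calls TaskContext.get() ...
} finally {
  TaskContext$.MODULE$.setTaskContext(null); // always clear the thread-local
}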

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java

Lines changed: 25 additions & 13 deletions
@@ -162,18 +162,8 @@ public HashBasedShuffleWriter(
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     boolean needCleanupPusher = true;
     try {
-      if (canUseFastWrite()) {
-        fastWrite0(records);
-      } else if (dep.mapSideCombine()) {
-        if (dep.aggregator().isEmpty()) {
-          throw new UnsupportedOperationException(
-              "When using map side combine, an aggregator must be specified.");
-        }
-        write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
-      } else {
-        write0(records);
-      }
-      close();
+      boolean iteratorHasNext = doWrite(records);
+      close(iteratorHasNext);
       needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
@@ -184,6 +174,26 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
     }
   }
 
+  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
+      throws IOException, InterruptedException {
+    if (canUseFastWrite()) {
+      fastWrite0(records);
+      return records.hasNext();
+    } else if (dep.mapSideCombine()) {
+      if (dep.aggregator().isEmpty()) {
+        throw new UnsupportedOperationException(
+            "When using map side combine, an aggregator must be specified.");
+      }
+      scala.collection.Iterator combinedIterator =
+          dep.aggregator().get().combineValuesByKey(records, taskContext);
+      write0(combinedIterator);
+      return combinedIterator.hasNext();
+    } else {
+      write0(records);
+      return records.hasNext();
+    }
+  }
+
   @VisibleForTesting
   boolean canUseFastWrite() {
     boolean keyIsPartitionId = false;
@@ -366,14 +376,16 @@ private void cleanupPusher() throws IOException {
     }
   }
 
-  private void close() throws IOException, InterruptedException {
+  private void close(boolean iteratorHasNext) throws IOException, InterruptedException {
     // Send the remaining data in sendBuffer
     long pushMergedDataTime = System.nanoTime();
     closeWrite();
     shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
     updateRecordsWrittenMetrics();
 
+    SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);
+
     long waitStartTime = System.nanoTime();
     dataPusher.waitOnTermination();
     sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java

Lines changed: 14 additions & 5 deletions
@@ -212,26 +212,33 @@ public long getPeakMemoryUsedBytes() {
     return peakMemoryUsedBytes;
   }
 
-  void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+  // Returns true if the iterator still has records (i.e., not fully consumed)
+  @VisibleForTesting
+  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     if (canUseFastWrite()) {
       fastWrite0(records);
+      return records.hasNext();
     } else if (dep.mapSideCombine()) {
       if (dep.aggregator().isEmpty()) {
         throw new UnsupportedOperationException(
             "When using map side combine, an aggregator must be specified.");
       }
-      write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
+      scala.collection.Iterator combinedIterator =
+          dep.aggregator().get().combineValuesByKey(records, taskContext);
+      write0(combinedIterator);
+      return combinedIterator.hasNext();
     } else {
       write0(records);
+      return records.hasNext();
     }
   }
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     boolean needCleanupPusher = true;
     try {
-      doWrite(records);
-      close();
+      boolean iteratorHasNext = doWrite(records);
+      close(iteratorHasNext);
       needCleanupPusher = false;
     } finally {
       if (needCleanupPusher) {
@@ -371,7 +378,7 @@ private void cleanupPusher() throws IOException {
     }
   }
 
-  private void close() throws IOException {
+  private void close(boolean iteratorHasNext) throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData(false);
@@ -381,6 +388,8 @@ private void close() throws IOException {
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
     writeMetrics.incRecordsWritten(tmpRecordsWritten);
 
+    SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);
+
     long waitStartTime = System.nanoTime();
     shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
     writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

Lines changed: 14 additions & 0 deletions
@@ -716,4 +716,18 @@ public static boolean isLocalMaster(SparkConf conf) {
     String master = conf.get("spark.master", "");
     return master.equals("local") || master.startsWith("local[");
   }
+
+  /**
+   * Asserts that the shuffle writer's iterator has been fully consumed. Only call this when the
+   * shuffle writer finishes writing records. If records remain in the iterator, the task will be
+   * killed.
+   *
+   * @param iteratorHasNext true if the iterator still has records remaining
+   */
+  public static void assertIteratorFullyConsumed(boolean iteratorHasNext) {
+    if (iteratorHasNext) {
+      TaskInterruptedHelper.throwTaskKillException(
+          "Shuffle write task finished but iterator was not fully consumed.");
+    }
+  }
 }

client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java

Lines changed: 23 additions & 0 deletions
@@ -44,6 +44,8 @@
 import org.apache.spark.SparkEnv;
 import org.apache.spark.SparkVersionUtil;
 import org.apache.spark.TaskContext;
+import org.apache.spark.TaskContext$;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
@@ -148,6 +150,8 @@ public void beforeEach() {
     Mockito.doReturn(bmId).when(blockManager).shuffleServerId();
     Mockito.doReturn(blockManager).when(env).blockManager();
     Mockito.doReturn(sparkConf).when(env).conf();
+    Mockito.doReturn(false).when(dependency).mapSideCombine();
+    Mockito.doReturn(Option.empty()).when(dependency).aggregator();
     SparkEnv.set(env);
   }
 
@@ -213,6 +217,25 @@ public void testGiantRecordAndMergeSmallBlockWithFastWrite() throws Exception {
     check(2 << 30, conf, serializer);
   }
 
+  @Test
+  public void testAssertIteratorFullyConsumed() {
+    // Test that assertIteratorFullyConsumed does not throw when iterator is empty
+    SparkUtils.assertIteratorFullyConsumed(false);
+  }
+
+  @Test
+  public void testAssertIteratorFullyConsumedThrows() {
+    TaskContext$.MODULE$.setTaskContext(taskContext);
+    try {
+      SparkUtils.assertIteratorFullyConsumed(true);
+      fail("Expected TaskKilledException when iterator is not fully consumed");
+    } catch (TaskKilledException e) {
+      assertTrue(e.getMessage().contains("not fully consumed"));
+    } finally {
+      TaskContext$.MODULE$.setTaskContext(null);
+    }
+  }
+
   private void check(
       final int approximateSize, final CelebornConf conf, final Serializer serializer)
       throws Exception {