diff --git a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java
index 25354bf62de..178f6ff19ab 100644
--- a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java
+++ b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java
@@ -33,8 +33,26 @@
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.Utils;
+/**
+ * Sort-based pusher for MapReduce shuffle data to Celeborn.
+ *
+ *
 + * <p>This implementation uses primitive int arrays to store record metadata (offsets, key lengths,
+ * value lengths) during data collection to minimize object allocation. During flush, temporary
+ * Record objects are created for sorting and immediately garbage collected in Young Gen.
+ *
+ *
 + * <p>To prevent memory pressure during sorting, the implementation triggers early spill when record
+ * count exceeds a threshold (5M records by default), ensuring temporary objects fit within Young
+ * Gen capacity.
+ */
public class CelebornSortBasedPusher<K, V> extends OutputStream {
private final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class);
+
+ // Maximum number of records to accumulate before forcing an early spill.
+ // This limits Young Gen pressure during sorting by preventing too many temporary
+ // Record objects from being created at once.
+ // Record size ~24 bytes, so 5M records = 120MB at peak
+ private static final int MAX_SORT_RECORDS = 5_000_000;
+
private final int mapId;
private final int attempt;
private final int numMappers;
@@ -48,10 +66,11 @@ public class CelebornSortBasedPusher extends OutputStream {
private final AtomicReference<IOException> exception = new AtomicReference<>();
private final Counters.Counter mapOutputByteCounter;
private final Counters.Counter mapOutputRecordCounter;
- private final Map<Integer, List<SerializedKV>> partitionedKVs;
+ private final Map<Integer, KVBufferInfo> partitionedKVBuffers;
private int writePos;
private byte[] serializedKV;
private final int maxPushDataSize;
+ private int totalRecordCount;
public CelebornSortBasedPusher(
int numMappers,
@@ -79,7 +98,7 @@ public CelebornSortBasedPusher(
this.mapOutputRecordCounter = mapOutputRecordCounter;
this.comparator = comparator;
this.shuffleClient = shuffleClient;
- partitionedKVs = new HashMap<>();
+ partitionedKVBuffers = new HashMap<>();
serializedKV = new byte[maxIOBufferSize];
maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
logger.info(
@@ -102,6 +121,7 @@ public CelebornSortBasedPusher(
public void insert(K key, V value, int partition) {
try {
+ // Check if we should spill based on buffer size
if (writePos >= spillIOBufferSize) {
// needs to sort and flush data
if (logger.isDebugEnabled()) {
@@ -114,6 +134,20 @@ public void insert(K key, V value, int partition) {
sortKVs();
sendKVAndUpdateWritePos();
}
+
+ // Additional check: limit total record count to avoid memory pressure during sort
+ // If total records exceed safe threshold, force an early spill
+ if (totalRecordCount >= MAX_SORT_RECORDS && writePos > 0) {
+ if (logger.isDebugEnabled()) {
+ logger.debug(
+ "Record count {} exceeds safe threshold {}, forcing early spill",
+ totalRecordCount,
+ MAX_SORT_RECORDS);
+ }
+ sortKVs();
+ sendKVAndUpdateWritePos();
+ }
+
int dataLen = insertRecordInternal(key, value, partition);
if (logger.isDebugEnabled()) {
logger.debug(
@@ -127,45 +161,64 @@ public void insert(K key, V value, int partition) {
}
private void sendKVAndUpdateWritePos() throws IOException {
- Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter =
- partitionedKVs.entrySet().iterator();
+ Iterator<Map.Entry<Integer, KVBufferInfo>> entryIter =
+ partitionedKVBuffers.entrySet().iterator();
while (entryIter.hasNext()) {
- Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
+ Map.Entry<Integer, KVBufferInfo> entry = entryIter.next();
entryIter.remove();
int partition = entry.getKey();
- List<SerializedKV> kvs = entry.getValue();
- List<SerializedKV> localKVs = new ArrayList<>();
+ KVBufferInfo bufferInfo = entry.getValue();
int partitionKVTotalLen = 0;
- // process buffers for specific partition
- for (SerializedKV kv : kvs) {
- partitionKVTotalLen += kv.kLen + kv.vLen;
- localKVs.add(kv);
- if (partitionKVTotalLen > maxPushDataSize) {
- // limit max size of pushdata to avoid possible memory issue in Celeborn worker
- // data layout
- // pushdata header (16) + pushDataLen(4) +
- // [varKeyLen+varValLen+serializedRecord(x)][...]
- sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);
- localKVs.clear();
- partitionKVTotalLen = 0;
+ int batchStartIdx = 0;
+
+ // Process buffers for specific partition (arrays are already sorted in-place)
+ for (int i = 0; i < bufferInfo.count; i++) {
+ int recordLen = bufferInfo.keyLens[i] + bufferInfo.valueLens[i];
+
+ // Check if adding this record would exceed the limit
+ // This ensures we never send batches larger than maxPushDataSize
+ if (partitionKVTotalLen + recordLen > maxPushDataSize && partitionKVTotalLen > 0) {
+ // Send the previous batch (before adding current record)
+ int batchLength = 0;
+ for (int j = batchStartIdx; j < i; j++) {
+ batchLength += bufferInfo.keyLens[j] + bufferInfo.valueLens[j];
+ }
+ sendSortedBuffersPartition(
+ partition, bufferInfo, batchStartIdx, i - batchStartIdx, batchLength);
+ // Start new batch with current record
+ batchStartIdx = i;
+ partitionKVTotalLen = recordLen;
+ } else {
+ // Add record to current batch
+ partitionKVTotalLen += recordLen;
}
}
- if (!localKVs.isEmpty()) {
- sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);
+
+ // Send remaining records
+ if (batchStartIdx < bufferInfo.count) {
+ int batchLength = 0;
+ for (int i = batchStartIdx; i < bufferInfo.count; i++) {
+ batchLength += bufferInfo.keyLens[i] + bufferInfo.valueLens[i];
+ }
+ sendSortedBuffersPartition(
+ partition, bufferInfo, batchStartIdx, bufferInfo.count - batchStartIdx, batchLength);
}
- kvs.clear();
+ // Clear buffer info for reuse
+ bufferInfo.clear();
}
- // all data sent
- partitionedKVs.clear();
+ // All data sent, reset counters
+ partitionedKVBuffers.clear();
writePos = 0;
+ totalRecordCount = 0;
}
private void sendSortedBuffersPartition(
- int partition, List<SerializedKV> localKVs, int partitionKVTotalLen) throws IOException {
+ int partition, KVBufferInfo bufferInfo, int startIdx, int count, int partitionKVTotalLen)
+ throws IOException {
int extraSize = 0;
- for (SerializedKV localKV : localKVs) {
- extraSize += WritableUtils.getVIntSize(localKV.kLen);
- extraSize += WritableUtils.getVIntSize(localKV.vLen);
+ for (int i = startIdx; i < startIdx + count; i++) {
+ extraSize += WritableUtils.getVIntSize(bufferInfo.keyLens[i]);
+ extraSize += WritableUtils.getVIntSize(bufferInfo.valueLens[i]);
}
// copied from hadoop logic
extraSize += WritableUtils.getVIntSize(-1);
@@ -174,14 +227,16 @@ private void sendSortedBuffersPartition(
byte[] pkvs = new byte[4 + extraSize + partitionKVTotalLen];
int pkvsPos = 4;
Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + extraSize);
- for (SerializedKV kv : localKVs) {
- int recordLen = kv.kLen + kv.vLen;
+ for (int i = startIdx; i < startIdx + count; i++) {
+ int kLen = bufferInfo.keyLens[i];
+ int vLen = bufferInfo.valueLens[i];
+ int recordLen = kLen + vLen;
// write key len
- pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen);
+ pkvsPos = writeVLong(pkvs, pkvsPos, kLen);
// write value len
- pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen);
+ pkvsPos = writeVLong(pkvs, pkvsPos, vLen);
// write serialized record
- System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen);
+ System.arraycopy(serializedKV, bufferInfo.offsets[i], pkvs, pkvsPos, recordLen);
pkvsPos += recordLen;
}
// finally write -1 two times
@@ -245,13 +300,71 @@ private int writeVLong(byte[] data, int offset, long dataInt) {
}
private void sortKVs() {
- for (Map.Entry<Integer, List<SerializedKV>> partitionKVEntry : partitionedKVs.entrySet()) {
- partitionKVEntry
- .getValue()
- .sort(
- (o1, o2) ->
- comparator.compare(
- serializedKV, o1.offset, o1.kLen, serializedKV, o2.offset, o2.kLen));
+ for (Map.Entry<Integer, KVBufferInfo> partitionKVEntry : partitionedKVBuffers.entrySet()) {
+ KVBufferInfo bufferInfo = partitionKVEntry.getValue();
+ if (bufferInfo.count <= 1) {
+ continue;
+ }
+
+ // The early-spill mechanism in insert() should prevent us from ever exceeding
+ // MAX_SORT_RECORDS.
+ // If we do exceed it, this means the configuration allows too much data to accumulate.
+ // We rely on the early-spill check to keep us within safe limits, so we can always do a full
+ // sort.
+ sortBatch(bufferInfo, 0, bufferInfo.count);
+ }
+ }
+
+ /** Sort a batch of records from start (inclusive) to end (exclusive). */
+ private void sortBatch(KVBufferInfo bufferInfo, int start, int end) {
+ int size = end - start;
+
+ // Create temporary Record objects
+ Record[] records = new Record[size];
+ for (int i = 0; i < size; i++) {
+ records[i] =
+ new Record(
+ serializedKV,
+ comparator,
+ bufferInfo.offsets[start + i],
+ bufferInfo.keyLens[start + i],
+ bufferInfo.valueLens[start + i]);
+ }
+
+ // Sort using Arrays.sort
+ Arrays.sort(records);
+
+ // Write back sorted results
+ for (int i = 0; i < size; i++) {
+ bufferInfo.offsets[start + i] = records[i].offset;
+ bufferInfo.keyLens[start + i] = records[i].kLen;
+ bufferInfo.valueLens[start + i] = records[i].vLen;
+ }
+ }
+
+ /**
+ * Temporary record for sorting. These objects are created only during sort, then garbage
+ * collected in Young Gen. Static class to avoid holding reference to outer class instance.
+ */
+ private static class Record implements Comparable<Record> {
+ private final byte[] serializedKV;
+ private final RawComparator comparator;
+ final int offset;
+ final int kLen;
+ final int vLen;
+
+ Record(byte[] serializedKV, RawComparator comparator, int offset, int kLen, int vLen) {
+ this.serializedKV = serializedKV;
+ this.comparator = comparator;
+ this.offset = offset;
+ this.kLen = kLen;
+ this.vLen = vLen;
+ }
+
+ @Override
+ public int compareTo(Record other) {
+ return comparator.compare(
+ serializedKV, offset, kLen, other.serializedKV, other.offset, other.kLen);
}
}
@@ -263,9 +376,12 @@ private int insertRecordInternal(K key, V value, int partition) throws IOExcepti
keyLen = writePos - offset;
vSer.serialize(value);
valLen = writePos - keyLen - offset;
- List<SerializedKV> serializedKVs =
- partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>());
- serializedKVs.add(new SerializedKV(offset, keyLen, valLen));
+ KVBufferInfo bufferInfo =
+ partitionedKVBuffers.computeIfAbsent(
+ partition, v -> new KVBufferInfo(1024)); // Initial capacity: 1024 records
+ // Store metadata directly in primitive arrays, no object allocation
+ bufferInfo.add(offset, keyLen, valLen);
+ totalRecordCount++; // Track total records across all partitions
if (logger.isDebugEnabled()) {
logger.debug(
"Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}",
@@ -273,7 +389,7 @@ private int insertRecordInternal(K key, V value, int partition) throws IOExcepti
offset,
keyLen,
valLen,
- partitionedKVs.size());
+ partitionedKVBuffers.size());
}
return keyLen + valLen;
}
@@ -320,19 +436,63 @@ public void close() {
} catch (IOException e) {
exception.compareAndSet(null, e);
}
- partitionedKVs.clear();
+ partitionedKVBuffers.clear();
+ totalRecordCount = 0;
serializedKV = null;
}
- static class SerializedKV {
- final int offset;
- final int kLen;
- final int vLen;
+ /**
+ * Buffer info to manage serialized key-value records for each partition. Uses primitive int
+ * arrays to store metadata instead of object arrays, significantly reducing memory overhead.
+ *
+ * <p>Memory comparison for 1 million records: - ArrayList<SerializedKV>: ~32MB (8MB references +
+ * 24MB objects) - This approach: ~12MB (4MB×3 int arrays)
+ *
+ * Saves 62.5% memory!
+ */
+ static class KVBufferInfo {
+ int[] offsets; // Store key offset in serializedKV buffer
+ int[] keyLens; // Store key length
+ int[] valueLens; // Store value length
+ int count;
+ int capacity;
- public SerializedKV(int offset, int kLen, int vLen) {
- this.offset = offset;
- this.kLen = kLen;
- this.vLen = vLen;
+ KVBufferInfo(int initialCapacity) {
+ this.offsets = new int[initialCapacity];
+ this.keyLens = new int[initialCapacity];
+ this.valueLens = new int[initialCapacity];
+ this.capacity = initialCapacity;
+ this.count = 0;
+ }
+
+ void add(int offset, int kLen, int vLen) {
+ if (count >= capacity) {
+ // Expand arrays with 2x growth strategy
+ int newCapacity = capacity * 2;
+
+ int[] newOffsets = new int[newCapacity];
+ int[] newKeyLens = new int[newCapacity];
+ int[] newValueLens = new int[newCapacity];
+
+ System.arraycopy(offsets, 0, newOffsets, 0, count);
+ System.arraycopy(keyLens, 0, newKeyLens, 0, count);
+ System.arraycopy(valueLens, 0, newValueLens, 0, count);
+
+ offsets = newOffsets;
+ keyLens = newKeyLens;
+ valueLens = newValueLens;
+ capacity = newCapacity;
+ }
+ offsets[count] = offset;
+ keyLens[count] = kLen;
+ valueLens[count] = vLen;
+ count++;
+ }
+
+ void clear() {
+ count = 0;
+ // Note: We don't clear the arrays themselves to avoid overhead
+ // They will be overwritten as new data is added
}
}
}