diff --git a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java index 25354bf62de..178f6ff19ab 100644 --- a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java +++ b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java @@ -33,8 +33,26 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.Utils; +/** + * Sort-based pusher for MapReduce shuffle data to Celeborn. + * + *

This implementation uses primitive int arrays to store record metadata (offsets, key lengths, + * value lengths) during data collection to minimize object allocation. During flush, temporary + * Record objects are created for sorting and immediately garbage collected in Young Gen. + * + *

To prevent memory pressure during sorting, the implementation triggers early spill when record + * count exceeds a threshold (5M records by default), ensuring temporary objects fit within Young + * Gen capacity. + */ public class CelebornSortBasedPusher extends OutputStream { private final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class); + + // Maximum number of records to accumulate before forcing an early spill. + // This limits Young Gen pressure during sorting by preventing too many temporary + // Record objects from being created at once. + // Record size ~24 bytes, so 5M records = 120MB at peak + private static final int MAX_SORT_RECORDS = 5_000_000; + private final int mapId; private final int attempt; private final int numMappers; @@ -48,10 +66,11 @@ public class CelebornSortBasedPusher extends OutputStream { private final AtomicReference exception = new AtomicReference<>(); private final Counters.Counter mapOutputByteCounter; private final Counters.Counter mapOutputRecordCounter; - private final Map> partitionedKVs; + private final Map partitionedKVBuffers; private int writePos; private byte[] serializedKV; private final int maxPushDataSize; + private int totalRecordCount; public CelebornSortBasedPusher( int numMappers, @@ -79,7 +98,7 @@ public CelebornSortBasedPusher( this.mapOutputRecordCounter = mapOutputRecordCounter; this.comparator = comparator; this.shuffleClient = shuffleClient; - partitionedKVs = new HashMap<>(); + partitionedKVBuffers = new HashMap<>(); serializedKV = new byte[maxIOBufferSize]; maxPushDataSize = (int) celebornConf.clientMrMaxPushData(); logger.info( @@ -102,6 +121,7 @@ public CelebornSortBasedPusher( public void insert(K key, V value, int partition) { try { + // Check if we should spill based on buffer size if (writePos >= spillIOBufferSize) { // needs to sort and flush data if (logger.isDebugEnabled()) { @@ -114,6 +134,20 @@ public void insert(K key, V value, int partition) { sortKVs(); 
sendKVAndUpdateWritePos(); } + + // Additional check: limit total record count to avoid memory pressure during sort + // If total records exceed safe threshold, force an early spill + if (totalRecordCount >= MAX_SORT_RECORDS && writePos > 0) { + if (logger.isDebugEnabled()) { + logger.debug( + "Record count {} exceeds safe threshold {}, forcing early spill", + totalRecordCount, + MAX_SORT_RECORDS); + } + sortKVs(); + sendKVAndUpdateWritePos(); + } + int dataLen = insertRecordInternal(key, value, partition); if (logger.isDebugEnabled()) { logger.debug( @@ -127,45 +161,64 @@ public void insert(K key, V value, int partition) { } private void sendKVAndUpdateWritePos() throws IOException { - Iterator>> entryIter = - partitionedKVs.entrySet().iterator(); + Iterator> entryIter = + partitionedKVBuffers.entrySet().iterator(); while (entryIter.hasNext()) { - Map.Entry> entry = entryIter.next(); + Map.Entry entry = entryIter.next(); entryIter.remove(); int partition = entry.getKey(); - List kvs = entry.getValue(); - List localKVs = new ArrayList<>(); + KVBufferInfo bufferInfo = entry.getValue(); int partitionKVTotalLen = 0; - // process buffers for specific partition - for (SerializedKV kv : kvs) { - partitionKVTotalLen += kv.kLen + kv.vLen; - localKVs.add(kv); - if (partitionKVTotalLen > maxPushDataSize) { - // limit max size of pushdata to avoid possible memory issue in Celeborn worker - // data layout - // pushdata header (16) + pushDataLen(4) + - // [varKeyLen+varValLen+serializedRecord(x)][...] 
- sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen); - localKVs.clear(); - partitionKVTotalLen = 0; + int batchStartIdx = 0; + + // Process buffers for specific partition (arrays are already sorted in-place) + for (int i = 0; i < bufferInfo.count; i++) { + int recordLen = bufferInfo.keyLens[i] + bufferInfo.valueLens[i]; + + // Check if adding this record would exceed the limit + // This ensures we never send batches larger than maxPushDataSize + if (partitionKVTotalLen + recordLen > maxPushDataSize && partitionKVTotalLen > 0) { + // Send the previous batch (before adding current record) + int batchLength = 0; + for (int j = batchStartIdx; j < i; j++) { + batchLength += bufferInfo.keyLens[j] + bufferInfo.valueLens[j]; + } + sendSortedBuffersPartition( + partition, bufferInfo, batchStartIdx, i - batchStartIdx, batchLength); + // Start new batch with current record + batchStartIdx = i; + partitionKVTotalLen = recordLen; + } else { + // Add record to current batch + partitionKVTotalLen += recordLen; } } - if (!localKVs.isEmpty()) { - sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen); + + // Send remaining records + if (batchStartIdx < bufferInfo.count) { + int batchLength = 0; + for (int i = batchStartIdx; i < bufferInfo.count; i++) { + batchLength += bufferInfo.keyLens[i] + bufferInfo.valueLens[i]; + } + sendSortedBuffersPartition( + partition, bufferInfo, batchStartIdx, bufferInfo.count - batchStartIdx, batchLength); } - kvs.clear(); + // Clear buffer info for reuse + bufferInfo.clear(); } - // all data sent - partitionedKVs.clear(); + // All data sent, reset counters + partitionedKVBuffers.clear(); writePos = 0; + totalRecordCount = 0; } private void sendSortedBuffersPartition( - int partition, List localKVs, int partitionKVTotalLen) throws IOException { + int partition, KVBufferInfo bufferInfo, int startIdx, int count, int partitionKVTotalLen) + throws IOException { int extraSize = 0; - for (SerializedKV localKV : localKVs) { - 
extraSize += WritableUtils.getVIntSize(localKV.kLen); - extraSize += WritableUtils.getVIntSize(localKV.vLen); + for (int i = startIdx; i < startIdx + count; i++) { + extraSize += WritableUtils.getVIntSize(bufferInfo.keyLens[i]); + extraSize += WritableUtils.getVIntSize(bufferInfo.valueLens[i]); } // copied from hadoop logic extraSize += WritableUtils.getVIntSize(-1); @@ -174,14 +227,16 @@ private void sendSortedBuffersPartition( byte[] pkvs = new byte[4 + extraSize + partitionKVTotalLen]; int pkvsPos = 4; Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + extraSize); - for (SerializedKV kv : localKVs) { - int recordLen = kv.kLen + kv.vLen; + for (int i = startIdx; i < startIdx + count; i++) { + int kLen = bufferInfo.keyLens[i]; + int vLen = bufferInfo.valueLens[i]; + int recordLen = kLen + vLen; // write key len - pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen); + pkvsPos = writeVLong(pkvs, pkvsPos, kLen); // write value len - pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen); + pkvsPos = writeVLong(pkvs, pkvsPos, vLen); // write serialized record - System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen); + System.arraycopy(serializedKV, bufferInfo.offsets[i], pkvs, pkvsPos, recordLen); pkvsPos += recordLen; } // finally write -1 two times @@ -245,13 +300,71 @@ private int writeVLong(byte[] data, int offset, long dataInt) { } private void sortKVs() { - for (Map.Entry> partitionKVEntry : partitionedKVs.entrySet()) { - partitionKVEntry - .getValue() - .sort( - (o1, o2) -> - comparator.compare( - serializedKV, o1.offset, o1.kLen, serializedKV, o2.offset, o2.kLen)); + for (Map.Entry partitionKVEntry : partitionedKVBuffers.entrySet()) { + KVBufferInfo bufferInfo = partitionKVEntry.getValue(); + if (bufferInfo.count <= 1) { + continue; + } + + // The early-spill mechanism in insert() should prevent us from ever exceeding + // MAX_SORT_RECORDS. + // If we do exceed it, this means the configuration allows too much data to accumulate. 
+ // We rely on the early-spill check to keep us within safe limits, so we can always do a full + // sort. + sortBatch(bufferInfo, 0, bufferInfo.count); + } + } + + /** Sort a batch of records from start (inclusive) to end (exclusive). */ + private void sortBatch(KVBufferInfo bufferInfo, int start, int end) { + int size = end - start; + + // Create temporary Record objects + Record[] records = new Record[size]; + for (int i = 0; i < size; i++) { + records[i] = + new Record( + serializedKV, + comparator, + bufferInfo.offsets[start + i], + bufferInfo.keyLens[start + i], + bufferInfo.valueLens[start + i]); + } + + // Sort using Arrays.sort + Arrays.sort(records); + + // Write back sorted results + for (int i = 0; i < size; i++) { + bufferInfo.offsets[start + i] = records[i].offset; + bufferInfo.keyLens[start + i] = records[i].kLen; + bufferInfo.valueLens[start + i] = records[i].vLen; + } + } + + /** + * Temporary record for sorting. These objects are created only during sort, then garbage + * collected in Young Gen. Static class to avoid holding reference to outer class instance. 
+ */ + private static class Record implements Comparable { + private final byte[] serializedKV; + private final RawComparator comparator; + final int offset; + final int kLen; + final int vLen; + + Record(byte[] serializedKV, RawComparator comparator, int offset, int kLen, int vLen) { + this.serializedKV = serializedKV; + this.comparator = comparator; + this.offset = offset; + this.kLen = kLen; + this.vLen = vLen; + } + + @Override + public int compareTo(Record other) { + return comparator.compare( + serializedKV, offset, kLen, other.serializedKV, other.offset, other.kLen); } } @@ -263,9 +376,12 @@ private int insertRecordInternal(K key, V value, int partition) throws IOExcepti keyLen = writePos - offset; vSer.serialize(value); valLen = writePos - keyLen - offset; - List serializedKVs = - partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>()); - serializedKVs.add(new SerializedKV(offset, keyLen, valLen)); + KVBufferInfo bufferInfo = + partitionedKVBuffers.computeIfAbsent( + partition, v -> new KVBufferInfo(1024)); // Initial capacity: 1024 records + // Store metadata directly in primitive arrays, no object allocation + bufferInfo.add(offset, keyLen, valLen); + totalRecordCount++; // Track total records across all partitions if (logger.isDebugEnabled()) { logger.debug( "Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}", @@ -273,7 +389,7 @@ private int insertRecordInternal(K key, V value, int partition) throws IOExcepti offset, keyLen, valLen, - partitionedKVs.size()); + partitionedKVBuffers.size()); } return keyLen + valLen; } @@ -320,19 +436,63 @@ public void close() { } catch (IOException e) { exception.compareAndSet(null, e); } - partitionedKVs.clear(); + partitionedKVBuffers.clear(); + totalRecordCount = 0; serializedKV = null; } - static class SerializedKV { - final int offset; - final int kLen; - final int vLen; + /** + * Buffer info to manage serialized key-value records for each partition. 
Uses primitive int + * arrays to store metadata instead of object arrays, significantly reducing memory overhead. + * + *

Memory comparison for 1 million records: - ArrayList: ~32MB (8MB references + + * 24MB objects) - This approach: ~12MB (4MB×3 int arrays) + * + *

Saves 62.5% memory! + */ + static class KVBufferInfo { + int[] offsets; // Store key offset in serializedKV buffer + int[] keyLens; // Store key length + int[] valueLens; // Store value length + int count; + int capacity; - public SerializedKV(int offset, int kLen, int vLen) { - this.offset = offset; - this.kLen = kLen; - this.vLen = vLen; + KVBufferInfo(int initialCapacity) { + this.offsets = new int[initialCapacity]; + this.keyLens = new int[initialCapacity]; + this.valueLens = new int[initialCapacity]; + this.capacity = initialCapacity; + this.count = 0; + } + + void add(int offset, int kLen, int vLen) { + if (count >= capacity) { + // Expand arrays with 2x growth strategy + int newCapacity = capacity * 2; + + int[] newOffsets = new int[newCapacity]; + int[] newKeyLens = new int[newCapacity]; + int[] newValueLens = new int[newCapacity]; + + System.arraycopy(offsets, 0, newOffsets, 0, count); + System.arraycopy(keyLens, 0, newKeyLens, 0, count); + System.arraycopy(valueLens, 0, newValueLens, 0, count); + + offsets = newOffsets; + keyLens = newKeyLens; + valueLens = newValueLens; + capacity = newCapacity; + } + offsets[count] = offset; + keyLens[count] = kLen; + valueLens[count] = vLen; + count++; + } + + void clear() { + count = 0; + // Note: We don't clear the arrays themselves to avoid overhead + // They will be overwritten as new data is added } } }