@@ -33,8 +33,26 @@
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.Utils;

/**
* Sort-based pusher for MapReduce shuffle data to Celeborn.
*
* <p>This implementation uses primitive int arrays to store record metadata (offsets, key lengths,
* value lengths) during data collection, minimizing object allocation. During flush, temporary
* Record objects are created for sorting and become eligible for young-generation collection
* immediately afterwards.
*
* <p>To limit memory pressure during sorting, the implementation triggers an early spill once the
* record count exceeds a fixed threshold (5 million records), keeping the temporary objects within
* young-generation capacity.
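*
* <p>Illustrative usage (a sketch; construction details and the Writable key/value types shown
* are illustrative):
*
* <pre>{@code
* CelebornSortBasedPusher<Text, Text> pusher = ...; // configured from CelebornConf
* pusher.insert(key, value, partition); // serializes and buffers the record, spilling as needed
* pusher.close(); // flushes remaining data and releases buffers
* }</pre>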
*/
public class CelebornSortBasedPusher<K, V> extends OutputStream {
private final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class);

// Maximum number of records to accumulate before forcing an early spill. This limits
// young-generation pressure during sorting by capping how many temporary Record objects
// exist at once: at ~24 bytes per Record, 5M records peaks at roughly 120MB.
private static final int MAX_SORT_RECORDS = 5_000_000;

private final int mapId;
private final int attempt;
private final int numMappers;
@@ -48,10 +66,11 @@ public class CelebornSortBasedPusher<K, V> extends OutputStream {
private final AtomicReference<Exception> exception = new AtomicReference<>();
private final Counters.Counter mapOutputByteCounter;
private final Counters.Counter mapOutputRecordCounter;
private final Map<Integer, List<SerializedKV>> partitionedKVs;
private final Map<Integer, KVBufferInfo> partitionedKVBuffers;
private int writePos;
private byte[] serializedKV;
private final int maxPushDataSize;
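// Records buffered across all partitions since the last spill; checked against MAX_SORT_RECORDS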
private int totalRecordCount;

public CelebornSortBasedPusher(
int numMappers,
@@ -79,7 +98,7 @@ public CelebornSortBasedPusher(
this.mapOutputRecordCounter = mapOutputRecordCounter;
this.comparator = comparator;
this.shuffleClient = shuffleClient;
partitionedKVs = new HashMap<>();
partitionedKVBuffers = new HashMap<>();
serializedKV = new byte[maxIOBufferSize];
maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
logger.info(
@@ -102,6 +121,7 @@ public CelebornSortBasedPusher(

public void insert(K key, V value, int partition) {
try {
// Check if we should spill based on buffer size
if (writePos >= spillIOBufferSize) {
// needs to sort and flush data
if (logger.isDebugEnabled()) {
@@ -114,6 +134,20 @@ public void insert(K key, V value, int partition) {
sortKVs();
sendKVAndUpdateWritePos();
}

// Additional check: cap the total record count to avoid memory pressure during the sort.
// If the total exceeds the safe threshold, force an early spill.
if (totalRecordCount >= MAX_SORT_RECORDS && writePos > 0) {
if (logger.isDebugEnabled()) {
logger.debug(
"Record count {} exceeds safe threshold {}, forcing early spill",
totalRecordCount,
MAX_SORT_RECORDS);
}
sortKVs();
sendKVAndUpdateWritePos();
}

int dataLen = insertRecordInternal(key, value, partition);
if (logger.isDebugEnabled()) {
logger.debug(
@@ -127,45 +161,64 @@ }
}

private void sendKVAndUpdateWritePos() throws IOException {
Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter =
partitionedKVs.entrySet().iterator();
Iterator<Map.Entry<Integer, KVBufferInfo>> entryIter =
partitionedKVBuffers.entrySet().iterator();
while (entryIter.hasNext()) {
Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
Map.Entry<Integer, KVBufferInfo> entry = entryIter.next();
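// Detach the entry so the map does not retain this partition's buffers while they are flushed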
entryIter.remove();
int partition = entry.getKey();
List<SerializedKV> kvs = entry.getValue();
List<SerializedKV> localKVs = new ArrayList<>();
KVBufferInfo bufferInfo = entry.getValue();
int partitionKVTotalLen = 0;
// process buffers for specific partition
for (SerializedKV kv : kvs) {
partitionKVTotalLen += kv.kLen + kv.vLen;
localKVs.add(kv);
if (partitionKVTotalLen > maxPushDataSize) {
// limit max size of pushdata to avoid possible memory issue in Celeborn worker
// data layout
// pushdata header (16) + pushDataLen(4) +
// [varKeyLen+varValLen+serializedRecord(x)][...]
sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);
localKVs.clear();
partitionKVTotalLen = 0;
int batchStartIdx = 0;

// Process buffers for this partition (the metadata arrays were already sorted in place)
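// Example (hypothetical values): with maxPushDataSize = 64 and record lengths [40, 40, 30],
// the loop below emits three batches of 40, 40 and 30 bytes respectively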
for (int i = 0; i < bufferInfo.count; i++) {
int recordLen = bufferInfo.keyLens[i] + bufferInfo.valueLens[i];

// Check whether adding this record would exceed the limit. This keeps every batch within
// maxPushDataSize, except that a single record larger than the limit is still sent alone.
if (partitionKVTotalLen + recordLen > maxPushDataSize && partitionKVTotalLen > 0) {
// Send the accumulated batch [batchStartIdx, i); partitionKVTotalLen is exactly its length
sendSortedBuffersPartition(
partition, bufferInfo, batchStartIdx, i - batchStartIdx, partitionKVTotalLen);
// Start new batch with current record
batchStartIdx = i;
partitionKVTotalLen = recordLen;
} else {
// Add record to current batch
partitionKVTotalLen += recordLen;
}
}
if (!localKVs.isEmpty()) {
sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);

// Send remaining records; partitionKVTotalLen now holds the final batch's length
if (batchStartIdx < bufferInfo.count) {
sendSortedBuffersPartition(
partition, bufferInfo, batchStartIdx, bufferInfo.count - batchStartIdx, partitionKVTotalLen);
}
kvs.clear();
// Clear buffer info for reuse
bufferInfo.clear();
}
// all data sent
partitionedKVs.clear();
// All data sent, reset counters
partitionedKVBuffers.clear();
writePos = 0;
totalRecordCount = 0;
}

private void sendSortedBuffersPartition(
int partition, List<SerializedKV> localKVs, int partitionKVTotalLen) throws IOException {
int partition, KVBufferInfo bufferInfo, int startIdx, int count, int partitionKVTotalLen)
throws IOException {
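// Batch layout: pushDataLen(4) + [varKeyLen + varValLen + serializedRecord]... followed by two
// vint -1 end-of-stream markers, mirroring Hadoop's IFile record framing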
int extraSize = 0;
for (SerializedKV localKV : localKVs) {
extraSize += WritableUtils.getVIntSize(localKV.kLen);
extraSize += WritableUtils.getVIntSize(localKV.vLen);
for (int i = startIdx; i < startIdx + count; i++) {
extraSize += WritableUtils.getVIntSize(bufferInfo.keyLens[i]);
extraSize += WritableUtils.getVIntSize(bufferInfo.valueLens[i]);
}
// copied from Hadoop IFile logic
extraSize += WritableUtils.getVIntSize(-1);
@@ -174,14 +227,16 @@ private void sendSortedBuffersPartition(
byte[] pkvs = new byte[4 + extraSize + partitionKVTotalLen];
int pkvsPos = 4;
Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + extraSize);
for (SerializedKV kv : localKVs) {
int recordLen = kv.kLen + kv.vLen;
for (int i = startIdx; i < startIdx + count; i++) {
int kLen = bufferInfo.keyLens[i];
int vLen = bufferInfo.valueLens[i];
int recordLen = kLen + vLen;
// write key len
pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen);
pkvsPos = writeVLong(pkvs, pkvsPos, kLen);
// write value len
pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen);
pkvsPos = writeVLong(pkvs, pkvsPos, vLen);
// write serialized record
System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen);
System.arraycopy(serializedKV, bufferInfo.offsets[i], pkvs, pkvsPos, recordLen);
pkvsPos += recordLen;
}
// finally write the vint -1 twice as the end-of-stream marker
Expand Down Expand Up @@ -245,13 +300,71 @@ private int writeVLong(byte[] data, int offset, long dataInt) {
}

private void sortKVs() {
for (Map.Entry<Integer, List<SerializedKV>> partitionKVEntry : partitionedKVs.entrySet()) {
partitionKVEntry
.getValue()
.sort(
(o1, o2) ->
comparator.compare(
serializedKV, o1.offset, o1.kLen, serializedKV, o2.offset, o2.kLen));
for (Map.Entry<Integer, KVBufferInfo> partitionKVEntry : partitionedKVBuffers.entrySet()) {
KVBufferInfo bufferInfo = partitionKVEntry.getValue();
if (bufferInfo.count <= 1) {
continue;
}

// The early-spill mechanism in insert() should prevent the record count from ever exceeding
// MAX_SORT_RECORDS; if it is exceeded anyway, the configuration allows too much data to
// accumulate. Since the early-spill check keeps us within safe limits, a full sort is always
// safe here.
sortBatch(bufferInfo, 0, bufferInfo.count);
}
}

/** Sort a batch of records from start (inclusive) to end (exclusive). */
private void sortBatch(KVBufferInfo bufferInfo, int start, int end) {
int size = end - start;

// Create temporary Record objects
Record[] records = new Record[size];
for (int i = 0; i < size; i++) {
records[i] =
new Record(
serializedKV,
comparator,
bufferInfo.offsets[start + i],
bufferInfo.keyLens[start + i],
bufferInfo.valueLens[start + i]);
}

// Sort by raw key bytes via Record.compareTo
Arrays.sort(records);

// Write back sorted results
for (int i = 0; i < size; i++) {
bufferInfo.offsets[start + i] = records[i].offset;
bufferInfo.keyLens[start + i] = records[i].kLen;
bufferInfo.valueLens[start + i] = records[i].vLen;
}
}

/**
* Temporary record used for sorting. These objects are created only during the sort and become
* unreachable immediately afterwards, so they die in the young generation. The class is static
* to avoid holding a reference to the enclosing instance.
*/
private static class Record implements Comparable<Record> {
private final byte[] serializedKV;
private final RawComparator comparator;
final int offset;
final int kLen;
final int vLen;

Record(byte[] serializedKV, RawComparator comparator, int offset, int kLen, int vLen) {
this.serializedKV = serializedKV;
this.comparator = comparator;
this.offset = offset;
this.kLen = kLen;
this.vLen = vLen;
}

@Override
public int compareTo(Record other) {
return comparator.compare(
serializedKV, offset, kLen, other.serializedKV, other.offset, other.kLen);
}
}

@@ -263,17 +376,20 @@ private int insertRecordInternal(K key, V value, int partition) throws IOException
keyLen = writePos - offset;
vSer.serialize(value);
valLen = writePos - keyLen - offset;
List<SerializedKV> serializedKVs =
partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>());
serializedKVs.add(new SerializedKV(offset, keyLen, valLen));
KVBufferInfo bufferInfo =
partitionedKVBuffers.computeIfAbsent(
partition, v -> new KVBufferInfo(1024)); // Initial capacity: 1024 records
// Store metadata directly in primitive arrays, no object allocation
bufferInfo.add(offset, keyLen, valLen);
totalRecordCount++; // Track total records across all partitions
if (logger.isDebugEnabled()) {
logger.debug(
"Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}",
partition,
offset,
keyLen,
valLen,
partitionedKVs.size());
partitionedKVBuffers.size());
}
return keyLen + valLen;
}
Expand Down Expand Up @@ -320,19 +436,63 @@ public void close() {
} catch (IOException e) {
exception.compareAndSet(null, e);
}
partitionedKVs.clear();
partitionedKVBuffers.clear();
totalRecordCount = 0;
serializedKV = null;
}

static class SerializedKV {
final int offset;
final int kLen;
final int vLen;
/**
* Buffer info managing the serialized key-value records of one partition. Uses primitive int
* arrays to store metadata instead of an object list, significantly reducing memory overhead.
*
* <p>Memory comparison for 1 million records:
*
* <ul>
*   <li>{@code ArrayList<SerializedKV>}: ~32MB (8MB of references + 24MB of objects)
*   <li>This approach: ~12MB (three 4MB int arrays)
* </ul>
*
* <p>That is roughly 62.5% less metadata memory.
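*
* <p>Illustrative use (hypothetical values):
*
* <pre>{@code
* KVBufferInfo info = new KVBufferInfo(1024);
* info.add(0, 8, 16); // record at offset 0 with an 8-byte key and a 16-byte value
* int recordLen = info.keyLens[0] + info.valueLens[0]; // 24
* info.clear(); // resets count; the arrays are reused for the next fill
* }</pre>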
*/
static class KVBufferInfo {
int[] offsets; // Store key offset in serializedKV buffer
int[] keyLens; // Store key length
int[] valueLens; // Store value length
int count;
int capacity;

public SerializedKV(int offset, int kLen, int vLen) {
this.offset = offset;
this.kLen = kLen;
this.vLen = vLen;
KVBufferInfo(int initialCapacity) {
this.offsets = new int[initialCapacity];
this.keyLens = new int[initialCapacity];
this.valueLens = new int[initialCapacity];
this.capacity = initialCapacity;
this.count = 0;
}

void add(int offset, int kLen, int vLen) {
if (count >= capacity) {
// Expand arrays with 2x growth strategy
int newCapacity = capacity * 2;

int[] newOffsets = new int[newCapacity];
int[] newKeyLens = new int[newCapacity];
int[] newValueLens = new int[newCapacity];

System.arraycopy(offsets, 0, newOffsets, 0, count);
System.arraycopy(keyLens, 0, newKeyLens, 0, count);
System.arraycopy(valueLens, 0, newValueLens, 0, count);

offsets = newOffsets;
keyLens = newKeyLens;
valueLens = newValueLens;
capacity = newCapacity;
}
offsets[count] = offset;
keyLens[count] = kLen;
valueLens[count] = vLen;
count++;
}

void clear() {
count = 0;
// Note: We don't clear the arrays themselves to avoid overhead
// They will be overwritten as new data is added
}
}
}