diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
index 13e291e887cc3..5760cc64e9041 100644
--- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java
@@ -422,6 +422,10 @@ private static void assertProcessingExceptionHandlerInputs(final ErrorHandlerCon
         assertTrue(Arrays.asList("ID123-A2", "ID123-A5").contains((String) record.value()));
         assertEquals("TOPIC_NAME", context.topic());
         assertEquals("KSTREAM-PROCESSOR-0000000003", context.processorNodeId());
+        assertTrue(Arrays.equals("ID123-2-ERR".getBytes(), context.sourceRawKey())
+            || Arrays.equals("ID123-5-ERR".getBytes(), context.sourceRawKey()));
+        assertTrue(Arrays.equals("ID123-A2".getBytes(), context.sourceRawValue())
+            || Arrays.equals("ID123-A5".getBytes(), context.sourceRawValue()));
         assertEquals(TIMESTAMP.toEpochMilli(), context.timestamp());
         assertTrue(exception.getMessage().contains("Exception should be handled by processing exception handler"));
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java b/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java
index d471673a48ed4..c2e212566a7fb 100644
--- a/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java
@@ -147,4 +147,38 @@ public interface ErrorHandlerContext {
      * @return The timestamp.
      */
     long timestamp();
+
+    /**
+     * Return the non-deserialized byte[] of the input message key if the context has been triggered by a message.
+     *
+     * <p>If this method is invoked within a {@link Punctuator#punctuate(long)
+     * punctuation callback}, or while processing a record that was forwarded by a punctuation
+     * callback, it will return null.
+     *
+     * <p>If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent
+     * to the repartition topic.
+     *
+     * <p>Always returns null if this method is invoked within a
+     * ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception) callback.
+     *
+     * @return the raw bytes of the key of the source message
+     */
+    byte[] sourceRawKey();
+
+    /**
+     * Return the non-deserialized byte[] of the input message value if the context has been triggered by a message.
+     *
+     * <p>If this method is invoked within a {@link Punctuator#punctuate(long)
+     * punctuation callback}, or while processing a record that was forwarded by a punctuation
+     * callback, it will return null.
+     *
+     * <p>If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent
+     * to the repartition topic.
+     *
+     * <p>Always returns null if this method is invoked within a
+     * ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception) callback.
+     *
+     * @return the raw bytes of the value of the source message
+     */
+    byte[] sourceRawValue();
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
index efaa6d57e7acc..0e85ce68c0369 100644
--- a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java
@@ -33,6 +33,8 @@ public class DefaultErrorHandlerContext implements ErrorHandlerContext {
     private final Headers headers;
     private final String processorNodeId;
     private final TaskId taskId;
+    private final byte[] sourceRawKey;
+    private final byte[] sourceRawValue;
     private final long timestamp;

     private final ProcessorContext processorContext;
@@ -44,7 +46,9 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext,
                                       final Headers headers,
                                       final String processorNodeId,
                                       final TaskId taskId,
-                                      final long timestamp) {
+                                      final long timestamp,
+                                      final byte[] sourceRawKey,
+                                      final byte[] sourceRawValue) {
         this.topic = topic;
         this.partition = partition;
         this.offset = offset;
@@ -53,6 +57,8 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext,
         this.taskId = taskId;
         this.processorContext = processorContext;
         this.timestamp = timestamp;
+        this.sourceRawKey = sourceRawKey;
+        this.sourceRawValue = sourceRawValue;
     }

     @Override
@@ -90,6 +96,14 @@ public long timestamp() {
         return timestamp;
     }

+    public byte[] sourceRawKey() {
+        return sourceRawKey;
+    }
+
+    public byte[] sourceRawValue() {
+        return sourceRawValue;
+    }
+
     @Override
     public String toString() {
         // we do exclude headers on purpose, to not accidentally log user data
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java
index 6b6fd91c85355..012c89247e498 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java
@@ -110,4 +110,37 @@ public interface RecordContext {
      */
     Headers headers();

+    /**
+     * Return the non-deserialized byte[] of the input message key if the context has been triggered by a message.
+     *
+     * <p>If this method is invoked within a {@link Punctuator#punctuate(long)
+     * punctuation callback}, or while processing a record that was forwarded by a punctuation
+     * callback, it will return null.
+     *
+     * <p>If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent
+     * to the repartition topic.
+     *
+     * <p>Always returns null if this method is invoked within a
+     * ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception) callback.
+     *
+     * @return the raw bytes of the key of the source message
+     */
+    byte[] sourceRawKey();
+
+    /**
+     * Return the non-deserialized byte[] of the input message value if the context has been triggered by a message.
+     *
+     * <p>If this method is invoked within a {@link Punctuator#punctuate(long)
+     * punctuation callback}, or while processing a record that was forwarded by a punctuation
+     * callback, it will return null.
+     *
+     * <p>If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent
+     * to the repartition topic.
+     *
+     * <p>Always returns null if this method is invoked within a
+     * ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception) callback.
+     *
+     * @return the raw bytes of the value of the source message
+     */
+    byte[] sourceRawValue();
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 8f739d0c0566a..93961daf97b79 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -260,7 +260,10 @@ public <K, V> void forward(final Record<K, V> record, final String childName) {
                 recordContext.offset(),
                 recordContext.partition(),
                 recordContext.topic(),
-                record.headers());
+                record.headers(),
+                recordContext.sourceRawKey(),
+                recordContext.sourceRawValue()
+            );
         }

         if (childName == null) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
index 5d245ef5f303e..1dddc55ca3c26 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java
@@ -215,7 +215,9 @@ public void process(final Record<KIn, VIn> record) {
                 internalProcessorContext.recordContext().headers(),
                 internalProcessorContext.currentNode().name(),
                 internalProcessorContext.taskId(),
-                internalProcessorContext.recordContext().timestamp()
+                internalProcessorContext.recordContext().timestamp(),
+                internalProcessorContext.recordContext().sourceRawKey(),
+                internalProcessorContext.recordContext().sourceRawValue()
             );

             final ProcessingExceptionHandler.ProcessingHandlerResponse response;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
index 839baaad87528..a8937b4de8008 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
@@ -37,6 +37,8 @@ public class ProcessorRecordContext implements RecordContext, RecordMetadata {
     private final String topic;
     private final int partition;
     private final Headers headers;
+    private byte[] sourceRawKey;
+    private byte[] sourceRawValue;

     public ProcessorRecordContext(final long timestamp,
                                   final long offset,
@@ -48,6 +50,24 @@ public ProcessorRecordContext(final long timestamp,
         this.topic = topic;
         this.partition = partition;
         this.headers = Objects.requireNonNull(headers);
+        this.sourceRawKey = null;
+        this.sourceRawValue = null;
+    }
+
+    public ProcessorRecordContext(final long timestamp,
+                                  final long offset,
+                                  final int partition,
+                                  final String topic,
+                                  final Headers headers,
+                                  final byte[] sourceRawKey,
+                                  final byte[] sourceRawValue) {
+        this.timestamp = timestamp;
+        this.offset = offset;
+        this.topic = topic;
+        this.partition = partition;
+        this.headers = Objects.requireNonNull(headers);
+        this.sourceRawKey = sourceRawKey;
+        this.sourceRawValue = sourceRawValue;
     }

     @Override
@@ -75,6 +95,16 @@ public Headers headers() {
         return headers;
     }

+    @Override
+    public byte[] sourceRawKey() {
+        return sourceRawKey;
+    }
+
+    @Override
+    public byte[] sourceRawValue() {
+        return sourceRawValue;
+    }
+
     public long residentMemorySizeEstimate() {
         long size = 0;
         size += Long.BYTES; // value.context.timestamp
@@ -176,6 +206,11 @@ public static ProcessorRecordContext deserialize(final ByteBuffer buffer) {
         return new ProcessorRecordContext(timestamp, offset, partition, topic, headers);
     }

+    public void freeRawRecord() {
+        this.sourceRawKey = null;
+        this.sourceRawValue = null;
+    }
+
     @Override
     public boolean equals(final Object o) {
         if (this == o) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index d47db7ea94261..89cbf4d4c7d4e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -259,6 +259,10 @@ public <K, V> void send(final String topic,

         final ProducerRecord<byte[], byte[]> serializedRecord = new ProducerRecord<>(topic, partition, timestamp, keyBytes, valBytes, headers);

+        // As many records could be in flight, free the raw source data
+        // in the record context to reduce memory pressure
+        freeRawInputRecordFromContext(context);
+
         streamsProducer.send(serializedRecord, (metadata, exception) -> {
             try {
                 // if there's already an exception record, skip logging offsets or new exceptions
@@ -311,6 +315,12 @@ public <K, V> void send(final String topic,
         });
     }

+    private static void freeRawInputRecordFromContext(final InternalProcessorContext<Void, Void> context) {
+        if (context != null && context.recordContext() != null) {
+            context.recordContext().freeRawRecord();
+        }
+    }
+
     private <K, V> void handleException(final ProductionExceptionHandler.SerializationExceptionOrigin origin,
                                         final String topic,
                                         final K key,
@@ -388,7 +398,9 @@ private DefaultErrorHandlerContext errorHandlerContext(final InternalProcessorCo
                 recordContext.headers(),
                 processorNodeId,
                 taskId,
-                recordContext.timestamp()
+                recordContext.timestamp(),
+                context.recordContext().sourceRawKey(),
+                context.recordContext().sourceRawValue()
             ) :
             new DefaultErrorHandlerContext(
                 context,
@@ -398,7 +410,9 @@ private DefaultErrorHandlerContext errorHandlerContext(final InternalProcessorCo
                 new RecordHeaders(),
                 processorNodeId,
                 taskId,
-                -1L
+                -1L,
+                null,
+                null
             );
     }

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
index 6f9fe989552f8..153ca2e02f1ee 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java
@@ -95,7 +95,10 @@ public static void handleDeserializationFailure(final DeserializationExceptionHa
             rawRecord.headers(),
             sourceNodeName,
             processorContext.taskId(),
-            rawRecord.timestamp());
+            rawRecord.timestamp(),
+            rawRecord.key(),
+            rawRecord.value()
+        );

         final DeserializationHandlerResponse response;
         try {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
index d38d7b625ae8e..faa90572ca524 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
@@ -243,7 +243,7 @@ private void updateHead() {
                 lastCorruptedRecord = raw;
                 continue;
             }
-            headRecord = new StampedRecord(deserialized, timestamp);
+            headRecord = new StampedRecord(deserialized, timestamp, raw.key(), raw.value());
             headRecordSizeInBytes = consumerRecordSizeInBytes(raw);
         }

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
index c8ed35a9a8f6c..dd0a1298b6767 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java
@@ -23,8 +23,22 @@

 public class StampedRecord extends Stamped<ConsumerRecord<Object, Object>> {

+    private final byte[] rawKey;
+    private final byte[] rawValue;
+
     public StampedRecord(final ConsumerRecord<Object, Object> record, final long timestamp) {
         super(record, timestamp);
+        this.rawKey = null;
+        this.rawValue = null;
+    }
+
+    public StampedRecord(final ConsumerRecord<Object, Object> record,
+                         final long timestamp,
+                         final byte[] rawKey,
+                         final byte[] rawValue) {
+        super(record, timestamp);
+        this.rawKey = rawKey;
+        this.rawValue = rawValue;
     }

     public String topic() {
@@ -55,8 +69,26 @@ public Headers headers() {
         return value.headers();
     }

+    public byte[] rawKey() {
+        return rawKey;
+    }
+
+    public byte[] rawValue() {
+        return rawValue;
+    }
+
     @Override
     public String toString() {
         return value.toString() + ", timestamp = " + timestamp;
     }
+
+    @Override
+    public boolean equals(final Object other) {
+        return super.equals(other);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index b612223197e4b..3a2b864f277e3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -853,7 +853,9 @@ private void doProcess(final long wallClockTime) {
             record.offset(),
             record.partition(),
             record.topic(),
-            record.headers()
+            record.headers(),
+            record.rawKey(),
+            record.rawValue()
         );
         updateProcessorContext(currNode, wallClockTime, recordContext);

@@ -935,7 +937,9 @@ record = null;
             recordContext.headers(),
             node.name(),
             id(),
-            recordContext.timestamp()
+            recordContext.timestamp(),
+            recordContext.sourceRawKey(),
+            recordContext.sourceRawValue()
         );

         final ProcessingExceptionHandler.ProcessingHandlerResponse response;
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
index 5b4303a16955e..c5dc2b8884e3c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java
@@ -80,6 +80,8 @@ public class ProcessorNodeTest {
     private static final String NAME = "name";
     private static final String KEY = "key";
     private static final String VALUE = "value";
+    private static final byte[] RAW_KEY = KEY.getBytes();
+    private static final byte[] RAW_VALUE = VALUE.getBytes();

     @Test
     public void shouldThrowStreamsExceptionIfExceptionCaughtDuringInit() {
@@ -331,7 +333,9 @@ private InternalProcessorContext<Object, Object> mockInternalProcessorContext()
             OFFSET,
             PARTITION,
             TOPIC,
-            new RecordHeaders()));
+            new RecordHeaders(),
+            RAW_KEY,
+            RAW_VALUE));
         when(internalProcessorContext.currentNode()).thenReturn(new ProcessorNode<>(NAME));

         return internalProcessorContext;
@@ -359,6 +363,9 @@ public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHa
         assertEquals(internalProcessorContext.currentNode().name(), context.processorNodeId());
         assertEquals(internalProcessorContext.taskId(), context.taskId());
         assertEquals(internalProcessorContext.recordContext().timestamp(), context.timestamp());
+        assertEquals(internalProcessorContext.recordContext().sourceRawKey(), context.sourceRawKey());
+        assertEquals(internalProcessorContext.recordContext().sourceRawValue(), context.sourceRawValue());
+
         assertEquals(KEY, record.key());
         assertEquals(VALUE, record.value());
         assertInstanceOf(RuntimeException.class, exception);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index b01b87ed85f82..a2f7eb6af1c5c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -100,6 +100,8 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
@@ -1890,6 +1892,69 @@ public void shouldNotSendIfSendOfOtherTaskFailedInCallback() {
         ));
     }

+    @Test
+    public void shouldFreeRawRecordsInContextBeforeSending() {
+        final KafkaException exception = new KafkaException("KABOOM!");
+        final byte[][] sourceRawData = new byte[][]{new byte[]{}, new byte[]{}};
+
+        final RecordCollector collector = new RecordCollectorImpl(
+            logContext,
+            taskId,
+            getExceptionalStreamsProducerOnSend(exception),
+            new ProductionExceptionHandler() {
+                @Override
+                public void configure(final Map<String, ?> configs) {
+
+                }
+
+                @Override
+                public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context, final ProducerRecord<byte[], byte[]> record, final Exception exception) {
+                    sourceRawData[0] = context.sourceRawKey();
+                    sourceRawData[1] = context.sourceRawValue();
+                    return ProductionExceptionHandlerResponse.CONTINUE;
+                }
+            },
+            streamsMetrics,
+            topology
+        );
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner);
+        collector.flush();
+
+        assertNull(sourceRawData[0]);
+        assertNull(sourceRawData[1]);
+    }
+
+
+    @Test
+    public void shouldHaveRawDataDuringExceptionInSerialization() {
+        final byte[][] sourceRawData = new byte[][]{new byte[]{}, new byte[]{}};
+        try (final ErrorStringSerializer errorSerializer = new ErrorStringSerializer()) {
+            final RecordCollector collector = newRecordCollector(
+                new ProductionExceptionHandler() {
+                    @Override
+                    @SuppressWarnings({"rawtypes", "unused"})
+                    public ProductionExceptionHandlerResponse handleSerializationException(final ErrorHandlerContext context, final ProducerRecord record, final Exception exception, final SerializationExceptionOrigin origin) {
+                        sourceRawData[0] = context.sourceRawKey();
+                        sourceRawData[1] = context.sourceRawValue();
+                        return ProductionExceptionHandlerResponse.CONTINUE;
+                    }
+
+                    @Override
+                    public void configure(final Map<String, ?> configs) {
+
+                    }
+                }
+            );
+            collector.initialize();
+
+            collector.send(topic, "hello", "val", null, 0, null, (Serializer<String>) errorSerializer, stringSerializer, sinkNodeName, context);
+
+            assertNotNull(sourceRawData[0]);
+            assertNotNull(sourceRawData[1]);
+        }
+    }
+
     private RecordCollector newRecordCollector(final ProductionExceptionHandler productionExceptionHandler) {
         return new RecordCollectorImpl(
             logContext,
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index 228df8d63a1ac..ed68c86c49020 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -56,6 +56,7 @@
 import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;

 import java.io.File;
+import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -244,7 +245,9 @@ public InternalMockProcessorContext(final File stateDir,
             0,
             0,
             "topic",
-            new RecordHeaders()
+            new RecordHeaders(),
+            "sourceKey".getBytes(StandardCharsets.UTF_8),
+            "sourceValue".getBytes(StandardCharsets.UTF_8)
         );
     }
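For illustration, here is a minimal sketch of a custom ProcessingExceptionHandler that consumes the two new accessors exposed by this patch. The handler class name and its logging behavior are hypothetical application code, not part of the change; only the ProcessingExceptionHandler interface and the sourceRawKey()/sourceRawValue() methods come from Kafka Streams. Such a handler would be registered via the processing.exception.handler config.

import java.util.Map;

import org.apache.kafka.streams.errors.ErrorHandlerContext;
import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
import org.apache.kafka.streams.processor.api.Record;

// Minimal sketch, not part of the patch: a handler that logs the raw
// (pre-deserialization) bytes of a failed record and resumes processing.
public class RawBytesLoggingProcessingExceptionHandler implements ProcessingExceptionHandler {

    @Override
    public ProcessingHandlerResponse handle(final ErrorHandlerContext context,
                                            final Record<?, ?> record,
                                            final Exception exception) {
        // sourceRawKey()/sourceRawValue() return the source record's bytes as read
        // from the input (or repartition) topic; both are null when the failure
        // originates from a punctuation rather than an input record.
        final byte[] rawKey = context.sourceRawKey();
        final byte[] rawValue = context.sourceRawValue();
        System.err.printf("processing error on %s[%d] at offset %d: rawKey=%d bytes, rawValue=%d bytes%n",
            context.topic(), context.partition(), context.offset(),
            rawKey == null ? -1 : rawKey.length,
            rawValue == null ? -1 : rawValue.length);
        return ProcessingHandlerResponse.CONTINUE;
    }

    @Override
    public void configure(final Map<String, ?> configs) {
        // nothing to configure in this sketch
    }
}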