Passing in key metadata when Reading From Kafka #34426
base: master
Changes from 2 commits
File: KafkaReadSchemaTransformProvider.java
@@ -52,10 +52,12 @@
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
@@ -103,6 +105,16 @@ public Row apply(byte[] input) {
    };
  }

  public static SerializableFunction<KV<byte[], byte[]>, Row> getRawBytesKvToRowFunction(
      Schema rawSchema) {
    return new SimpleFunction<KV<byte[], byte[]>, Row>() {
      @Override
      public Row apply(KV<byte[], byte[]> input) {
        return Row.withSchema(rawSchema).addValues(input.getKey(), input.getValue()).build();
      }
    };
  }

  @Override
  public String identifier() {
    return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ);
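For reference, a minimal use of the new helper; this snippet is illustrative only (not part of the PR) and assumes the two-field key/payload schema that the expand() change below builds:

// Illustrative usage of getRawBytesKvToRowFunction.
Schema kvSchema =
    Schema.builder()
        .addField("key", Schema.FieldType.BYTES)
        .addField("payload", Schema.FieldType.BYTES)
        .build();
SerializableFunction<KV<byte[], byte[]>, Row> toRow = getRawBytesKvToRowFunction(kvSchema);
Row row =
    toRow.apply(
        KV.of("k".getBytes(StandardCharsets.UTF_8), "v".getBytes(StandardCharsets.UTF_8)));
// row.getBytes("key") holds the Kafka record key, row.getBytes("payload") the value bytes.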
@@ -191,8 +203,64 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
    }

    if ("RAW".equals(format)) {
      beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
      valueMapper = getRawBytesToRowFunction(beamSchema);
      boolean withKeyMetadata = Boolean.TRUE.equals(configuration.getWithKeyMetadata());
      if (withKeyMetadata) {

Reviewer comment on the line above: "This if block is really long. Can we put the logic here into its own function? Something like" [the suggested snippet is truncated here; an illustrative sketch follows this hunk].

        beamSchema =
            Schema.builder()
                .addField("key", Schema.FieldType.BYTES)
                .addField("payload", Schema.FieldType.BYTES)
                .build();
        SerializableFunction<KV<byte[], byte[]>, Row> kvValueMapper =
            getRawBytesKvToRowFunction(beamSchema);
        KafkaIO.Read<byte[], byte[]> kafkaRead =
            KafkaIO.readBytes()
                .withConsumerConfigUpdates(consumerConfigs)
                .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
                .withTopic(configuration.getTopic())
                .withBootstrapServers(configuration.getBootstrapServers());
        Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
        if (maxReadTimeSeconds != null) {
          kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
        }
        PCollection<KafkaRecord<byte[], byte[]>> kafkaRecords =
            input.getPipeline().apply(kafkaRead);

        PCollection<KV<byte[], byte[]>> kafkaValues =
            kafkaRecords.apply(
                MapElements.via(
                    new SimpleFunction<KafkaRecord<byte[], byte[]>, KV<byte[], byte[]>>() {
                      @Override
                      public KV<byte[], byte[]> apply(KafkaRecord<byte[], byte[]> record) {
                        return KV.of(record.getKV().getKey(), record.getKV().getValue());
                      }
                    }));

        Schema errorSchema = ErrorHandling.errorSchemaKvBytes();

        PCollectionTuple outputTuple =
            kafkaValues.apply(
                ParDo.of(
                        new ErrorKvFn(
                            "Kafka-read-error-counter", kvValueMapper, errorSchema, handleErrors))
                    .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

        PCollectionRowTuple outputRows =
            PCollectionRowTuple.of(
                "output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

        PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
        if (handleErrors) {
          outputRows =
              outputRows.and(
                  checkArgumentNotNull(configuration.getErrorHandling()).getOutput(),
                  errorOutput);
        }
        return outputRows;

      } else {
        beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
        valueMapper = getRawBytesToRowFunction(beamSchema);
      }
    } else if ("PROTO".equals(format)) {
      String fileDescriptorPath = configuration.getFileDescriptorPath();
      String messageName = checkArgumentNotNull(configuration.getMessageName());
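Since the reviewer's suggested snippet was cut off above, here is one possible shape for that refactor, purely as an illustration: the method name expandRawWithKeyMetadata and its parameter list are assumptions, and the body simply reuses the logic from this hunk.

// Hypothetical extraction of the withKeyMetadata branch into its own method.
private PCollectionRowTuple expandRawWithKeyMetadata(
    PCollectionRowTuple input,
    KafkaReadSchemaTransformConfiguration configuration,
    Map<String, Object> consumerConfigs,
    boolean handleErrors) {
  Schema beamSchema =
      Schema.builder()
          .addField("key", Schema.FieldType.BYTES)
          .addField("payload", Schema.FieldType.BYTES)
          .build();
  SerializableFunction<KV<byte[], byte[]>, Row> kvValueMapper =
      getRawBytesKvToRowFunction(beamSchema);

  KafkaIO.Read<byte[], byte[]> kafkaRead =
      KafkaIO.readBytes()
          .withConsumerConfigUpdates(consumerConfigs)
          .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
          .withTopic(configuration.getTopic())
          .withBootstrapServers(configuration.getBootstrapServers());
  Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
  if (maxReadTimeSeconds != null) {
    kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
  }

  // Drop the KafkaRecord wrapper, keeping only the key/value pair.
  PCollection<KV<byte[], byte[]>> kafkaValues =
      input
          .getPipeline()
          .apply(kafkaRead)
          .apply(
              MapElements.via(
                  new SimpleFunction<KafkaRecord<byte[], byte[]>, KV<byte[], byte[]>>() {
                    @Override
                    public KV<byte[], byte[]> apply(KafkaRecord<byte[], byte[]> record) {
                      return KV.of(record.getKV().getKey(), record.getKV().getValue());
                    }
                  }));

  // Map to Rows, routing parse failures to the error output when error handling is enabled.
  Schema errorSchema = ErrorHandling.errorSchemaKvBytes();
  PCollectionTuple outputTuple =
      kafkaValues.apply(
          ParDo.of(new ErrorKvFn("Kafka-read-error-counter", kvValueMapper, errorSchema, handleErrors))
              .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

  PCollectionRowTuple outputRows =
      PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
  if (handleErrors) {
    outputRows =
        outputRows.and(
            checkArgumentNotNull(configuration.getErrorHandling()).getOutput(),
            outputTuple.get(ERROR_TAG).setRowSchema(errorSchema));
  }
  return outputRows;
}

The withKeyMetadata branch in expand() would then reduce to a single call, e.g. return expandRawWithKeyMetadata(input, configuration, consumerConfigs, handleErrors);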
@@ -295,6 +363,54 @@ public void finish(FinishBundleContext c) {
    }
  }

  public static class ErrorKvFn extends DoFn<KV<byte[], byte[]>, Row> {
    private static final Logger LOG = LoggerFactory.getLogger(ErrorKvFn.class);
    private final SerializableFunction<KV<byte[], byte[]>, Row> valueMapper;
    private final Counter errorCounter;
    private Long errorsInBundle = 0L;
    private final boolean handleErrors;
    private final Schema errorSchema;

    public ErrorKvFn(
        String name,
        SerializableFunction<KV<byte[], byte[]>, Row> valueMapper,
        Schema errorSchema,
        boolean handleErrors) {
      this.errorCounter = Metrics.counter(KafkaReadSchemaTransformProvider.class, name);
      this.valueMapper = valueMapper;
      this.handleErrors = handleErrors;
      this.errorSchema = errorSchema;
    }

    @ProcessElement
    public void process(@DoFn.Element KV<byte[], byte[]> msg, MultiOutputReceiver receiver) {
      Row mappedRow = null;
      try {
        mappedRow = valueMapper.apply(msg);
      } catch (Exception e) {
        if (!handleErrors) {
          throw new RuntimeException(e);
        }
        errorsInBundle += 1;
        LOG.warn("Error while parsing the element", e);
        receiver
            .get(ERROR_TAG)
            .output(ErrorHandling.errorRecord(errorSchema, msg, e)); // Use ErrorHandling.errorRecord for KV
      }
      if (mappedRow != null) {
        receiver.get(OUTPUT_TAG).output(mappedRow);
      }
    }

    @FinishBundle
    public void finish(FinishBundleContext c) {
      errorCounter.inc(errorsInBundle);
      errorsInBundle = 0L;
    }
  }

  private static class ConsumerFactoryWithGcsTrustStores
      implements SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> {
File: KafkaReadSchemaTransformProviderTest.java
@@ -17,6 +17,7 @@
 */
package org.apache.beam.sdk.io.kafka;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
@@ -34,11 +35,15 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -134,7 +139,8 @@ public void testFindTransformAndMakeItWork() {
            "error_handling",
            "file_descriptor_path",
            "message_name",
            "max_read_time_seconds"),
            "max_read_time_seconds",
            "with_key_metadata"),
        kafkaProvider.configurationSchema().getFields().stream()
            .map(field -> field.getName())
            .collect(Collectors.toSet()));
@@ -340,4 +346,61 @@ public void testManagedMappings() {
      assertTrue(configSchemaFieldNames.contains(paramName));
    }
  }

  @Test
  public void testGetRawBytesKvToRowFunction() {
    Schema testSchema =
        Schema.builder()
            .addField("key", Schema.FieldType.BYTES)
            .addField("value", Schema.FieldType.BYTES)
            .build();

    SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
        KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

    KV<byte[], byte[]> inputKv =
        KV.of(
            "testKey".getBytes(StandardCharsets.UTF_8),
            "testValue".getBytes(StandardCharsets.UTF_8));
    Row outputRow = kvToRow.apply(inputKv);

    assertEquals("testKey", new String(outputRow.getBytes("key"), StandardCharsets.UTF_8));
    assertEquals("testValue", new String(outputRow.getBytes("value"), StandardCharsets.UTF_8));
  }

  @Test
  public void testGetRawBytesKvToRowFunctionEmptyKey() {

Reviewer comment: "Can we also add an integration test?"
Reviewer comment: "I just saw above comment about ITs. You can add them to KafkaIOIT, where you can write a pipeline which triggers your code path as if it were a user's pipeline."
Reviewer comment: "+1 to keep using KafkaIOIT. It should already have some integration tests for KafkaIO as a ManagedIO (look for" [comment truncated; an illustrative integration-test sketch follows this diff].

    Schema testSchema =
        Schema.builder()
            .addField("key", Schema.FieldType.BYTES)
            .addField("value", Schema.FieldType.BYTES)
            .build();

    SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
        KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

    KV<byte[], byte[]> inputKv = KV.of(new byte[0], "testValue".getBytes(StandardCharsets.UTF_8));
    Row outputRow = kvToRow.apply(inputKv);

    assertArrayEquals(new byte[0], outputRow.getBytes("key"));
  }

  @Test
  public void testGetRawBytesKvToRowFunctionNullValue() {
    Schema testSchema =
        Schema.builder()
            .addField("key", Schema.FieldType.BYTES)
            .addField("value", Schema.FieldType.BYTES.withNullable(true))
            .build();

    SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
        KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

    KV<byte[], byte[]> inputKv = KV.of("testKey".getBytes(StandardCharsets.UTF_8), null);
    Row outputRow = kvToRow.apply(inputKv);

    byte[] valueBytes = outputRow.getBytes("value");
    assertEquals(null, valueBytes);
  }
}
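Per the reviewers' suggestion above, an integration test could live in KafkaIOIT and exercise the new option end to end through the Managed API. The sketch below is an illustration only, not existing KafkaIOIT code: the test name, topic, bootstrap servers, and pipeline fixture (readPipeline) are placeholders, and the snake_case config keys are assumed to match the managed Kafka read configuration; "with_key_metadata" is the field added by this PR.

// Illustrative only: drive the RAW + key metadata path as a user pipeline via Managed.
@Test
public void testManagedKafkaReadRawWithKeyMetadata() {
  Map<String, Object> config =
      ImmutableMap.<String, Object>builder()
          .put("topic", "key-metadata-it-topic")        // placeholder topic
          .put("bootstrap_servers", "localhost:9092")   // placeholder; use the IT's broker option
          .put("format", "RAW")
          .put("with_key_metadata", true)               // option introduced by this PR
          .build();

  PCollection<Row> rows =
      PCollectionRowTuple.empty(readPipeline)
          .apply(Managed.read(Managed.KAFKA).withConfig(config))
          .getSinglePCollection();

  // With key metadata enabled, every output Row should expose both the Kafka key and the payload.
  PAssert.that(rows)
      .satisfies(
          iter -> {
            for (Row row : iter) {
              assertNotNull(row.getBytes("key"));
              assertNotNull(row.getBytes("payload"));
            }
            return null;
          });
  readPipeline.run().waitUntilFinish();
}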
Reviewer comment: "This should probably be a fixed function with a fixed output Schema (i.e. not a method)."
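One way to read that suggestion, sketched here as an assumption rather than the agreed design, is to replace the schema parameter of getRawBytesKvToRowFunction with a fixed schema constant and a fixed mapping function:

// Hypothetical alternative: a fixed schema and a fixed mapper, so callers cannot pass a
// mismatched schema. Names are illustrative; the logic mirrors the method in this PR.
public static final Schema RAW_KV_SCHEMA =
    Schema.builder()
        .addField("key", Schema.FieldType.BYTES)
        .addField("payload", Schema.FieldType.BYTES)
        .build();

public static final SerializableFunction<KV<byte[], byte[]>, Row> RAW_BYTES_KV_TO_ROW =
    new SimpleFunction<KV<byte[], byte[]>, Row>() {
      @Override
      public Row apply(KV<byte[], byte[]> input) {
        return Row.withSchema(RAW_KV_SCHEMA)
            .addValues(input.getKey(), input.getValue())
            .build();
      }
    };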