Passing in key metadata when Reading From Kafka #34426

Open · wants to merge 4 commits into base: master
Changes from 2 commits
@@ -22,6 +22,7 @@
import javax.annotation.Nullable;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;

@AutoValue
@@ -60,6 +61,14 @@ public static Schema errorSchemaBytes() {
Schema.Field.of("error_message", Schema.FieldType.STRING));
}

public static Schema errorSchemaKvBytes() {
return Schema.builder()
.addField("inputKey", Schema.FieldType.BYTES)
.addField("inputValue", Schema.FieldType.BYTES)
.addField("error", Schema.FieldType.STRING)
.build();
}

@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@@ -79,4 +88,10 @@ public static Row errorRecord(Schema errorSchema, byte[] inputBytes, Throwable th)
.withFieldValue("error_message", th.getMessage())
.build();
}

public static Row errorRecord(Schema errorSchema, KV<byte[], byte[]> input, Exception e) {
return Row.withSchema(errorSchema)
.addValues(input.getKey(), input.getValue(), e.toString())
.build();
}
}
@@ -149,6 +149,10 @@ public static Builder builder() {
/** Sets the topic from which to read. */
public abstract String getTopic();

@SchemaFieldDescription("Include key metadata when reading from Kafka.")
@Nullable
public abstract Boolean getWithKeyMetadata();

@SchemaFieldDescription("Upper bound of how long to read from Kafka.")
@Nullable
public abstract Integer getMaxReadTimeSeconds();
@@ -180,6 +184,8 @@ public abstract static class Builder {

public abstract Builder setConsumerConfigUpdates(Map<String, String> consumerConfigUpdates);

public abstract Builder setWithKeyMetadata(Boolean value);

/** Sets the topic from which to read. */
public abstract Builder setTopic(String value);

@@ -52,10 +52,12 @@
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
@@ -103,6 +105,16 @@ public Row apply(byte[] input) {
};
}

public static SerializableFunction<KV<byte[], byte[]>, Row> getRawBytesKvToRowFunction(
Schema rawSchema) {
return new SimpleFunction<KV<byte[], byte[]>, Row>() {
@Override
public Row apply(KV<byte[], byte[]> input) {
return Row.withSchema(rawSchema).addValues(input.getKey(), input.getValue()).build();
}
};
}

Contributor comment on lines +108 to +117: This should probably be a fixed function with a fixed output Schema (i.e. not a method)

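A possible shape for that suggestion — a sketch only, assuming the Beam SDK types already imported in this file; the constant names `RAW_KV_SCHEMA` and `RAW_BYTES_KV_TO_ROW` are illustrative, not part of this PR:

```java
// Sketch: fix the output schema and the function once, as constants,
// rather than constructing a new SimpleFunction (and schema) per call.
private static final Schema RAW_KV_SCHEMA =
    Schema.builder()
        .addField("key", Schema.FieldType.BYTES)
        .addField("payload", Schema.FieldType.BYTES)
        .build();

private static final SerializableFunction<KV<byte[], byte[]>, Row> RAW_BYTES_KV_TO_ROW =
    new SimpleFunction<KV<byte[], byte[]>, Row>() {
      @Override
      public Row apply(KV<byte[], byte[]> input) {
        return Row.withSchema(RAW_KV_SCHEMA)
            .addValues(input.getKey(), input.getValue())
            .build();
      }
    };
```

Callers would then reference the constant directly instead of calling `getRawBytesKvToRowFunction(beamSchema)`.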
@Override
public String identifier() {
return getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ);
@@ -191,8 +203,64 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}

if ("RAW".equals(format)) {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
boolean withKeyMetadata = Boolean.TRUE.equals(configuration.getWithKeyMetadata());
if (withKeyMetadata) {
Contributor comment: This if block is really long. Can we put the logic here into its own function? Something like

if (withKeyMetadata) {
  return someFnHere(beamSchema);
} else {
  beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
  valueMapper = getRawBytesToRowFunction(beamSchema);
}

beamSchema =
Schema.builder()
.addField("key", Schema.FieldType.BYTES)
.addField("payload", Schema.FieldType.BYTES)
.build();
SerializableFunction<KV<byte[], byte[]>, Row> kvValueMapper =
getRawBytesKvToRowFunction(beamSchema);
KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
if (maxReadTimeSeconds != null) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
}
PCollection<KafkaRecord<byte[], byte[]>> kafkaRecords =
input.getPipeline().apply(kafkaRead);

PCollection<KV<byte[], byte[]>> kafkaValues =
kafkaRecords.apply(
MapElements.via(
new SimpleFunction<KafkaRecord<byte[], byte[]>, KV<byte[], byte[]>>() {
@Override
public KV<byte[], byte[]> apply(KafkaRecord<byte[], byte[]> record) {
return KV.of(record.getKV().getKey(), record.getKV().getValue());
}
}));

Schema errorSchema = ErrorHandling.errorSchemaKvBytes();

PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(
new ErrorKvFn(
"Kafka-read-error-counter", kvValueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of(
"output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows =
outputRows.and(
checkArgumentNotNull(configuration.getErrorHandling()).getOutput(),
errorOutput);
}
return outputRows;

} else {
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
}
} else if ("PROTO".equals(format)) {
String fileDescriptorPath = configuration.getFileDescriptorPath();
String messageName = checkArgumentNotNull(configuration.getMessageName());
@@ -295,6 +363,54 @@ public void finish(FinishBundleContext c) {
}
}

public static class ErrorKvFn extends DoFn<KV<byte[], byte[]>, Row> {
private static final Logger LOG = LoggerFactory.getLogger(ErrorKvFn.class);
private final SerializableFunction<KV<byte[], byte[]>, Row> valueMapper;
private final Counter errorCounter;
private Long errorsInBundle = 0L;
private final boolean handleErrors;
private final Schema errorSchema;

public ErrorKvFn(
String name,
SerializableFunction<KV<byte[], byte[]>, Row> valueMapper,
Schema errorSchema,
boolean handleErrors) {
this.errorCounter = Metrics.counter(KafkaReadSchemaTransformProvider.class, name);
this.valueMapper = valueMapper;
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
}

@ProcessElement
public void process(@DoFn.Element KV<byte[], byte[]> msg, MultiOutputReceiver receiver) {
Row mappedRow = null;
try {
mappedRow = valueMapper.apply(msg);
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
}
errorsInBundle += 1;
LOG.warn("Error while parsing the element", e);
receiver
.get(ERROR_TAG)
.output(
ErrorHandling.errorRecord(errorSchema, msg, e));
}
if (mappedRow != null) {
receiver.get(OUTPUT_TAG).output(mappedRow);
}
}

@FinishBundle
public void finish(FinishBundleContext c) {
errorCounter.inc(errorsInBundle);
errorsInBundle = 0L;
}
}

private static class ConsumerFactoryWithGcsTrustStores
implements SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> {

@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.kafka;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
@@ -34,11 +35,15 @@
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -134,7 +139,8 @@ public void testFindTransformAndMakeItWork() {
"error_handling",
"file_descriptor_path",
"message_name",
"max_read_time_seconds"),
"max_read_time_seconds",
"with_key_metadata"),
kafkaProvider.configurationSchema().getFields().stream()
.map(field -> field.getName())
.collect(Collectors.toSet()));
@@ -340,4 +346,61 @@ public void testManagedMappings() {
assertTrue(configSchemaFieldNames.contains(paramName));
}
}

@Test
public void testGetRawBytesKvToRowFunction() {
Schema testSchema =
Schema.builder()
.addField("key", Schema.FieldType.BYTES)
.addField("value", Schema.FieldType.BYTES)
.build();

SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

KV<byte[], byte[]> inputKv =
KV.of(
"testKey".getBytes(StandardCharsets.UTF_8),
"testValue".getBytes(StandardCharsets.UTF_8));
Row outputRow = kvToRow.apply(inputKv);

assertEquals("testKey", new String(outputRow.getBytes("key"), StandardCharsets.UTF_8));
assertEquals("testValue", new String(outputRow.getBytes("value"), StandardCharsets.UTF_8));
}

@Test
public void testGetRawBytesKvToRowFunctionEmptyKey() {
Contributor comment: Can we also add an integration test?

Contributor comment: I just saw above comment about ITs. You can add them to KafkaIOIT, where you can write a pipeline which triggers your code path as if it were a user's pipeline.

Contributor comment: +1 to keep using KafkaIOIT. It should already have some integration tests for KafkaIO as a ManagedIO (look for Managed.write/read), which uses the SchemaTransform under the hood

Schema testSchema =
Schema.builder()
.addField("key", Schema.FieldType.BYTES)
.addField("value", Schema.FieldType.BYTES)
.build();

SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

KV<byte[], byte[]> inputKv = KV.of(new byte[0], "testValue".getBytes(StandardCharsets.UTF_8));
Row outputRow = kvToRow.apply(inputKv);

assertArrayEquals(new byte[0], outputRow.getBytes("key"));
}
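The KafkaIOIT suggestion from the reviewers might look roughly like the following — a hedged sketch, not code from this PR: the topic and bootstrap-server values are placeholders, and the config keys mirror the snake_case fields of this provider's configuration schema.

```java
// Sketch of an integration-test pipeline that exercises the new option
// through Managed, as a user pipeline would (assumes a reachable broker
// and the Beam Managed/Kafka dependencies on the classpath).
Map<String, Object> config =
    ImmutableMap.of(
        "topic", "it-test-topic", // placeholder
        "bootstrap_servers", "localhost:9092", // placeholder
        "format", "RAW",
        "with_key_metadata", true);

PCollection<Row> rows =
    pipeline
        .apply(Managed.read(Managed.KAFKA).withConfig(config))
        .getSinglePCollection();
// Each Row should now carry both "key" and "payload" BYTES fields,
// so the test can assert on the key metadata as well as the payload.
```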

@Test
public void testGetRawBytesKvToRowFunctionNullValue() {
Schema testSchema =
Schema.builder()
.addField("key", Schema.FieldType.BYTES)
.addField("value", Schema.FieldType.BYTES.withNullable(true))
.build();

SerializableFunction<KV<byte[], byte[]>, Row> kvToRow =
KafkaReadSchemaTransformProvider.getRawBytesKvToRowFunction(testSchema);

KV<byte[], byte[]> inputKv = KV.of("testKey".getBytes(StandardCharsets.UTF_8), null);
Row outputRow = kvToRow.apply(inputKv);

byte[] valueBytes = outputRow.getBytes("value");
assertEquals(null, valueBytes);
}
}