Normalize tf record io #34411

Status: Draft. Wants to merge 30 commits into base: master. The diff below shows changes from 22 of the 30 commits.

Commits (30):
d693f6c  create TFRecordReadSchemaTransform (derrickaw, Mar 24, 2025)
c2f61ce  create TFRecordWriteSchemaTransform (derrickaw, Mar 24, 2025)
00ce573  create TFRecordSchemaTransform Test (derrickaw, Mar 24, 2025)
4453f50  fix conflicts (derrickaw, Mar 27, 2025)
34238e1  add writeToTFRecord return (derrickaw, Mar 24, 2025)
f067a27  add translation file and getRowConfiguration method for said file (derrickaw, Mar 25, 2025)
48a4b39  fix formatting (derrickaw, Mar 25, 2025)
86d56e5  remove print (derrickaw, Mar 25, 2025)
c428774  fix more lint issues (derrickaw, Mar 25, 2025)
cbce5e1  add support for bytes in yaml (derrickaw, Mar 27, 2025)
35d80c0  update ReadTransform with error handling and use string compression (derrickaw, Mar 27, 2025)
c78d69f  add error handling to writetransform and change compression to string (derrickaw, Mar 27, 2025)
b28acfa  change compression to string (derrickaw, Mar 27, 2025)
dc4f963  update compression type and no_spilling (derrickaw, Mar 27, 2025)
539e648  update writeToTFRecord parameters (derrickaw, Mar 27, 2025)
0402693  add tfrecord yaml test pipeline (derrickaw, Mar 27, 2025)
e9d6428  remove old code (derrickaw, Mar 27, 2025)
b944593  fix lint issue (derrickaw, Mar 27, 2025)
4707213  update standard external transforms with tfrecord info (derrickaw, Mar 27, 2025)
e614f37  remove bad character and broken comment (derrickaw, Mar 27, 2025)
414bdb0  fix lint issue (derrickaw, Mar 27, 2025)
d030e3f  add no_spilling doc string (derrickaw, Mar 27, 2025)
1a5c8b5  change tfrecord.yaml to write version (derrickaw, Mar 28, 2025)
0a0e23d  update parameter name (derrickaw, Mar 29, 2025)
8eca45e  update read and write yaml for tfrecord (derrickaw, Mar 29, 2025)
821d4c6  update pipeline to handle write and read (derrickaw, Mar 29, 2025)
3a17dc1  fix lint issues (derrickaw, Mar 29, 2025)
80d23b0  fix lint (derrickaw, Mar 29, 2025)
5158bb1  fix lint, precommit checker is broken for me, so hence many commits :) (derrickaw, Mar 29, 2025)
c98eb38  minor fix (derrickaw, Mar 29, 2025)
New file: TFRecordReadSchemaTransformConfiguration.java (+111 lines)
Reviewer comment (Contributor): Once this is ready for full review (IMO, that is probably now, if we can add a few integration tests), I'd recommend doing a separate PR each for read and write, since they're not really tied together at all. That will make them easier to review and iterate on, and it should unblock the read review while you deal with the issues you're seeing on write.

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;

/**
* Configuration for reading from TFRecord.
*
* <p>This class is meant to be used with {@link TFRecordReadSchemaTransformProvider}.
*
* <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
* provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
* repository.
*/
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class TFRecordReadSchemaTransformConfiguration implements Serializable {

public void validate() {
String invalidConfigMessage = "Invalid TFRecord Read configuration: ";
checkNotNull(getValidate(), "To read from TFRecord, validation must be specified.");
checkNotNull(getCompression(), "To read from TFRecord, compression must be specified.");

String filePattern = getFilePattern();
if (filePattern == null || filePattern.isEmpty()) {
throw new IllegalStateException(
invalidConfigMessage + "filePattern must be set for a TFRecordReadSchemaTransform.");
}
if (getValidate()) {
try {
MatchResult matches = FileSystems.match(filePattern);
checkState(
!matches.metadata().isEmpty(), "Unable to find any files matching %s", filePattern);
} catch (IOException e) {
throw new IllegalStateException(
String.format(invalidConfigMessage + "Failed to validate %s", filePattern), e);
}
}

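// If error handling is configured, it must name a non-empty output for the error PCollection.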
ErrorHandling errorHandling = getErrorHandling();
if (errorHandling != null) {
checkArgument(
!Strings.isNullOrEmpty(errorHandling.getOutput()),
invalidConfigMessage + "Output must not be empty if error handling specified.");
}
}

/** Instantiates a {@link TFRecordReadSchemaTransformConfiguration.Builder} instance. */
public static TFRecordReadSchemaTransformConfiguration.Builder builder() {
return new AutoValue_TFRecordReadSchemaTransformConfiguration.Builder();
}

@SchemaFieldDescription("Validate file pattern.")
public abstract boolean getValidate();

@SchemaFieldDescription("Decompression type to use when reading input files.")
public abstract String getCompression();

@SchemaFieldDescription("Filename or file pattern used to find input files.")
public abstract @Nullable String getFilePattern();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
public abstract @Nullable ErrorHandling getErrorHandling();

abstract Builder toBuilder();

/** Builder for {@link TFRecordReadSchemaTransformConfiguration}. */
@AutoValue.Builder
public abstract static class Builder {

public abstract Builder setValidate(boolean value);

public abstract Builder setCompression(String value);

public abstract Builder setFilePattern(String value);

public abstract Builder setErrorHandling(@Nullable ErrorHandling errorHandling);

/** Builds the {@link TFRecordReadSchemaTransformConfiguration} configuration. */
public abstract TFRecordReadSchemaTransformConfiguration build();
}
}
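
As a quick usage sketch (not part of the diff), a configuration could be assembled with the generated builder; the bucket and file pattern below are hypothetical:

TFRecordReadSchemaTransformConfiguration config =
    TFRecordReadSchemaTransformConfiguration.builder()
        .setFilePattern("gs://example-bucket/records/train-*.tfrecord") // hypothetical path
        .setCompression("GZIP") // must name a Compression enum constant
        .setValidate(true) // fail fast if no files match the pattern
        .build();
config.validate();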
New file: TFRecordReadSchemaTransformProvider.java (+221 lines)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class TFRecordReadSchemaTransformProvider
extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
private static final String OUTPUT = "output";
private static final String ERROR = "errors";
public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
private static final Logger LOG =
LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

/** Returns the expected class of the configuration. */
@Override
protected Class<TFRecordReadSchemaTransformConfiguration> configurationClass() {
return TFRecordReadSchemaTransformConfiguration.class;
}

/** Returns the {@link SchemaTransform} built from the given configuration. */
@Override
protected SchemaTransform from(TFRecordReadSchemaTransformConfiguration configuration) {
return new TFRecordReadSchemaTransform(configuration);
}

/** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */
@Override
public String identifier() {
return IDENTIFIER;
}

/**
* Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since
* no input is expected, this returns an empty list.
*/
@Override
public List<String> inputCollectionNames() {
return Collections.emptyList();
}

/** Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. */
@Override
public List<String> outputCollectionNames() {
return Arrays.asList(OUTPUT, ERROR);
}

/**
* An implementation of {@link SchemaTransform} for TFRecord read jobs configured using {@link
* TFRecordReadSchemaTransformConfiguration}.
*/
static class TFRecordReadSchemaTransform extends SchemaTransform {
private final TFRecordReadSchemaTransformConfiguration configuration;

TFRecordReadSchemaTransform(TFRecordReadSchemaTransformConfiguration configuration) {
this.configuration = configuration;
}

public Row getConfigurationRow() {
try {
// To stay consistent with our SchemaTransform configuration naming conventions,
// we sort the fields lexicographically and convert their names to snake_case.
return SchemaRegistry.createDefault()
.getToRowFunction(TFRecordReadSchemaTransformConfiguration.class)
.apply(configuration)
.sorted()
.toSnakeCase();
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e);
}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
// Validate configuration parameters
configuration.validate();

TFRecordIO.Read readTransform =
TFRecordIO.read().withCompression(Compression.valueOf(configuration.getCompression()));

String filePattern = configuration.getFilePattern();
if (filePattern != null) {
readTransform = readTransform.from(filePattern);
}
if (!configuration.getValidate()) {
readTransform = readTransform.withoutValidation();
}

// Read TFRecord files into a PCollection of byte arrays.
PCollection<byte[]> tfRecordValues = input.getPipeline().apply(readTransform);

// Define the schema for the row
Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
Schema errorSchema = ErrorHandling.errorSchemaBytes();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> bytesToRowFn = getBytesToRowFn(schema);

// Map the raw bytes to Rows, routing parse failures to the error output
PCollectionTuple outputTuple =
tfRecordValues.apply(
ParDo.of(
new ErrorFn(
"TFRecord-read-error-counter", bytesToRowFn, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(schema));

// Error handling
PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows =
outputRows.and(
checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
}
return outputRows;
}
}

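/** Wraps each raw TFRecord payload into a single-field Row using the provided schema. */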
public static SerializableFunction<byte[], Row> getBytesToRowFn(Schema schema) {
return new SimpleFunction<byte[], Row>() {
@Override
public Row apply(byte[] input) {
Row row = Row.withSchema(schema).addValues(input).build();
if (row == null) {
throw new NullPointerException();
}
return row;
}
};
}

public static class ErrorFn extends DoFn<byte[], Row> {
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
private long errorsInBundle = 0L;
private final boolean handleErrors;
private final Schema errorSchema;

public ErrorFn(
String name,
SerializableFunction<byte[], Row> valueMapper,
Schema errorSchema,
boolean handleErrors) {
this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
this.valueMapper = valueMapper;
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
}

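// Map the raw record to a Row; on failure, rethrow unless error handling is
// enabled, in which case emit an error Row to ERROR_TAG and count it.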
@ProcessElement
public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
Row mappedRow = null;
try {
mappedRow = valueMapper.apply(msg);
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
}
errorsInBundle += 1;
LOG.warn("Error while parsing the element", e);
receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
}
if (mappedRow != null) {
receiver.get(OUTPUT_TAG).output(mappedRow);
}
}

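// Flush the accumulated error count to the metrics backend once per bundle.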
@FinishBundle
public void finish(FinishBundleContext c) {
errorCounter.inc(errorsInBundle);
errorsInBundle = 0L;
}
}
}
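
For context, here is a minimal end-to-end sketch (not part of the diff) of exercising the read transform. It assumes same-package access, since the from(...) override is protected; the file pattern is hypothetical, and the ErrorHandling builder usage is an assumption based on how it appears elsewhere in Beam's providers:

Pipeline pipeline = Pipeline.create();
TFRecordReadSchemaTransformConfiguration config =
    TFRecordReadSchemaTransformConfiguration.builder()
        .setFilePattern("/tmp/data/*.tfrecord") // hypothetical path
        .setCompression("UNCOMPRESSED")
        .setValidate(false)
        .setErrorHandling(ErrorHandling.builder().setOutput("errors").build())
        .build();
PCollectionRowTuple result =
    PCollectionRowTuple.empty(pipeline)
        .apply(new TFRecordReadSchemaTransformProvider().from(config));
PCollection<Row> records = result.get("output"); // rows with a single "record" BYTES field
PCollection<Row> errors = result.get("errors"); // present because error handling is configured
pipeline.run().waitUntilFinish();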