Normalize tf record io #34411

Merged: 46 commits, Apr 14, 2025
Commits
d693f6c  create TFRecordReadSchemaTransform (derrickaw, Mar 24, 2025)
c2f61ce  create TFRecordWriteSchemaTransform (derrickaw, Mar 24, 2025)
00ce573  create TFRecordSchemaTransform Test (derrickaw, Mar 24, 2025)
4453f50  fix conflicts (derrickaw, Mar 27, 2025)
34238e1  add writeToTFRecord return (derrickaw, Mar 24, 2025)
f067a27  add translation file and getRowConfiguration method for said file (derrickaw, Mar 25, 2025)
48a4b39  fix formatting (derrickaw, Mar 25, 2025)
86d56e5  remove print (derrickaw, Mar 25, 2025)
c428774  fix more lint issues (derrickaw, Mar 25, 2025)
cbce5e1  add support for bytes in yaml (derrickaw, Mar 27, 2025)
35d80c0  update ReadTransform with error handling and use string compression (derrickaw, Mar 27, 2025)
c78d69f  add error handling to writetransform and change compression to string (derrickaw, Mar 27, 2025)
b28acfa  change compression to string (derrickaw, Mar 27, 2025)
dc4f963  update compression type and no_spilling (derrickaw, Mar 27, 2025)
539e648  update writeToTFRecord parameters (derrickaw, Mar 27, 2025)
0402693  add tfrecord yaml test pipeline (derrickaw, Mar 27, 2025)
e9d6428  remove old code (derrickaw, Mar 27, 2025)
b944593  fix lint issue (derrickaw, Mar 27, 2025)
4707213  update standard external transforms with tfrecord info (derrickaw, Mar 27, 2025)
e614f37  remove bad character and broken comment (derrickaw, Mar 27, 2025)
414bdb0  fix lint issue (derrickaw, Mar 27, 2025)
d030e3f  add no_spilling doc string (derrickaw, Mar 27, 2025)
1a5c8b5  change tfrecord.yaml to write version (derrickaw, Mar 28, 2025)
0a0e23d  update parameter name (derrickaw, Mar 29, 2025)
8eca45e  update read and write yaml for tfrecord (derrickaw, Mar 29, 2025)
821d4c6  update pipeline to handle write and read (derrickaw, Mar 29, 2025)
3a17dc1  fix lint issues (derrickaw, Mar 29, 2025)
80d23b0  fix lint (derrickaw, Mar 29, 2025)
5158bb1  fix lint, precommit checker is broken for me, so hence many commits :) (derrickaw, Mar 29, 2025)
c98eb38  minor fix (derrickaw, Mar 29, 2025)
057c681  fix java comments (derrickaw, Apr 7, 2025)
ed83c00  Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWr… (derrickaw, Apr 7, 2025)
9bb0898  Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordSc… (derrickaw, Apr 7, 2025)
610bcaa  fix class name change (derrickaw, Apr 7, 2025)
55359e7  fix MakeItWork test case (derrickaw, Apr 7, 2025)
2136a47  fix spotless issues (derrickaw, Apr 7, 2025)
fc0fcc6  Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWr… (derrickaw, Apr 7, 2025)
fe3b598  remove nullable for outputprefix per comments (derrickaw, Apr 8, 2025)
a64234b  remove no_spilling parameter and extra white space (derrickaw, Apr 9, 2025)
fe7a900  update standard external transforms (derrickaw, Apr 9, 2025)
78840ff  remove one more no_spilling (derrickaw, Apr 10, 2025)
875b725  fix nullable on no_spilling (derrickaw, Apr 10, 2025)
5fced8b  rerun generate external transforms (derrickaw, Apr 10, 2025)
5ce1adf  fix order of python and java providers (derrickaw, Apr 11, 2025)
818f1df  revert java and python providers changes (derrickaw, Apr 11, 2025)
fe60137  update test and fix prior test failures (derrickaw, Apr 14, 2025)
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformConfiguration.java
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;

/**
* Configuration for reading from TFRecord.
*
* <p>This class is meant to be used with {@link TFRecordReadSchemaTransformProvider}.
*
* <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
* provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
* repository.
*/
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class TFRecordReadSchemaTransformConfiguration implements Serializable {

public void validate() {
String invalidConfigMessage = "Invalid TFRecord Read configuration: ";
checkNotNull(getValidate(), "To read from TFRecord, validation must be specified.");
checkNotNull(getCompression(), "To read from TFRecord, compression must be specified.");

String filePattern = getFilePattern();
if (filePattern == null || filePattern.isEmpty()) {
throw new IllegalStateException(
"Need to set the filepattern of a TFRecordReadSchema transform");
}
if (getValidate()) {
try {
MatchResult matches = FileSystems.match(filePattern);
checkState(
!matches.metadata().isEmpty(), "Unable to find any files matching %s", filePattern);
} catch (IOException e) {
throw new IllegalStateException(
String.format(invalidConfigMessage + "Failed to validate %s", filePattern), e);
}
}

ErrorHandling errorHandling = getErrorHandling();
if (errorHandling != null) {
checkArgument(
!Strings.isNullOrEmpty(errorHandling.getOutput()),
invalidConfigMessage + "Output must not be empty if error handling specified.");
}
}

/** Instantiates a {@link TFRecordReadSchemaTransformConfiguration.Builder} instance. */
public static TFRecordReadSchemaTransformConfiguration.Builder builder() {
return new AutoValue_TFRecordReadSchemaTransformConfiguration.Builder();
}

@SchemaFieldDescription("Validate file pattern.")
public abstract boolean getValidate();

@SchemaFieldDescription("Decompression type to use when reading input files.")
public abstract String getCompression();

@SchemaFieldDescription("Filename or file pattern used to find input files.")
public abstract @Nullable String getFilePattern();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
public abstract @Nullable ErrorHandling getErrorHandling();

abstract Builder toBuilder();

/** Builder for {@link TFRecordReadSchemaTransformConfiguration}. */
@AutoValue.Builder
public abstract static class Builder {

public abstract Builder setValidate(boolean value);

public abstract Builder setCompression(String value);

public abstract Builder setFilePattern(String value);

public abstract Builder setErrorHandling(@Nullable ErrorHandling errorHandling);

/** Builds the {@link TFRecordReadSchemaTransformConfiguration} configuration. */
public abstract TFRecordReadSchemaTransformConfiguration build();
}
}
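For orientation, a minimal sketch of how this configuration might be constructed and validated; the compression value and file pattern below are illustrative assumptions, not values taken from this PR:

// Hypothetical usage sketch. The compression string must name a value of the
// org.apache.beam.sdk.io.Compression enum (e.g. "GZIP"), since the read
// provider resolves it with Compression.valueOf(...).
TFRecordReadSchemaTransformConfiguration config =
    TFRecordReadSchemaTransformConfiguration.builder()
        .setValidate(true)
        .setCompression("GZIP")
        .setFilePattern("gs://example-bucket/data/*.tfrecord") // illustrative path
        .build();
config.validate(); // throws IllegalStateException if the pattern matches no files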
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformProvider.java
@@ -0,0 +1,221 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class TFRecordReadSchemaTransformProvider
extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
private static final String OUTPUT = "output";
private static final String ERROR = "errors";
public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
private static final Logger LOG =
LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

/** Returns the expected class of the configuration. */
@Override
protected Class<TFRecordReadSchemaTransformConfiguration> configurationClass() {
return TFRecordReadSchemaTransformConfiguration.class;
}

/** Returns the expected {@link SchemaTransform} of the configuration. */
@Override
protected SchemaTransform from(TFRecordReadSchemaTransformConfiguration configuration) {
return new TFRecordReadSchemaTransform(configuration);
}

/** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */
@Override
public String identifier() {
return IDENTIFIER;
}

/**
* Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since
* no input is expected, this returns an empty list.
*/
@Override
public List<String> inputCollectionNames() {
return Collections.emptyList();
}

/** Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. */
@Override
public List<String> outputCollectionNames() {
return Arrays.asList(OUTPUT, ERROR);
}

/**
* An implementation of {@link SchemaTransform} for TFRecord read jobs configured using {@link
* TFRecordReadSchemaTransformConfiguration}.
*/
static class TFRecordReadSchemaTransform extends SchemaTransform {
private final TFRecordReadSchemaTransformConfiguration configuration;

TFRecordReadSchemaTransform(TFRecordReadSchemaTransformConfiguration configuration) {
this.configuration = configuration;
}

public Row getConfigurationRow() {
try {
// To stay consistent with our SchemaTransform configuration naming conventions,
// we sort lexicographically
return SchemaRegistry.createDefault()
.getToRowFunction(TFRecordReadSchemaTransformConfiguration.class)
.apply(configuration)
.sorted()
.toSnakeCase();
} catch (NoSuchSchemaException e) {
throw new RuntimeException(e);
}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
// Validate configuration parameters
configuration.validate();

TFRecordIO.Read readTransform =
TFRecordIO.read().withCompression(Compression.valueOf(configuration.getCompression()));

String filePattern = configuration.getFilePattern();
if (filePattern != null) {
readTransform = readTransform.from(filePattern);
}
if (!configuration.getValidate()) {
readTransform = readTransform.withoutValidation();
}

// Read TFRecord files into a PCollection of byte arrays.
PCollection<byte[]> tfRecordValues = input.getPipeline().apply(readTransform);

// Define the schema for the row
Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
Schema errorSchema = ErrorHandling.errorSchemaBytes();
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

SerializableFunction<byte[], Row> bytesToRowFn = getBytesToRowFn(schema);

// Apply bytes to row fn
PCollectionTuple outputTuple =
tfRecordValues.apply(
ParDo.of(
new ErrorFn(
"TFRecord-read-error-counter", bytesToRowFn, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

PCollectionRowTuple outputRows =
PCollectionRowTuple.of(OUTPUT, outputTuple.get(OUTPUT_TAG).setRowSchema(schema));

// Error handling
PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows =
outputRows.and(
checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
}
return outputRows;
}
}

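/**
 * Returns a function that wraps each TFRecord payload in a single-field {@link Row} built with
 * the supplied schema. The null check guards the result for the benefit of static nullness
 * analysis; {@code build()} is not expected to return null in practice.
 */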
public static SerializableFunction<byte[], Row> getBytesToRowFn(Schema schema) {
return new SimpleFunction<byte[], Row>() {
@Override
public Row apply(byte[] input) {
Row row = Row.withSchema(schema).addValues(input).build();
if (row == null) {
throw new NullPointerException();
}
return row;
}
};
}

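/**
 * A {@link DoFn} that applies the bytes-to-row mapping. Successfully converted rows are emitted
 * to OUTPUT_TAG; when error handling is enabled, failing elements are counted, logged, and
 * routed to ERROR_TAG as error rows instead of failing the bundle.
 */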
public static class ErrorFn extends DoFn<byte[], Row> {
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
private Long errorsInBundle = 0L;
private final boolean handleErrors;
private final Schema errorSchema;

public ErrorFn(
String name,
SerializableFunction<byte[], Row> valueMapper,
Schema errorSchema,
boolean handleErrors) {
this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
this.valueMapper = valueMapper;
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
}

@ProcessElement
public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
Row mappedRow = null;
try {
mappedRow = valueMapper.apply(msg);
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
}
errorsInBundle += 1;
LOG.warn("Error while parsing the element", e);
receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
}
if (mappedRow != null) {
receiver.get(OUTPUT_TAG).output(mappedRow);
}
}

@FinishBundle
public void finish(FinishBundleContext c) {
errorCounter.inc(errorsInBundle);
errorsInBundle = 0L;
}
}
}
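Similarly, a minimal sketch of expanding the read transform in a pipeline. This assumes same-package access (the typed from(...) method is protected, so callers outside org.apache.beam.sdk.io would go through the Row-based from(Row) entry point), and it assumes ErrorHandling exposes a builder with setOutput, as used elsewhere in Beam's providers; the "my_errors" label and file pattern are illustrative:

// Hypothetical usage sketch, written as if in package org.apache.beam.sdk.io.
Pipeline pipeline = Pipeline.create();

TFRecordReadSchemaTransformConfiguration config =
    TFRecordReadSchemaTransformConfiguration.builder()
        .setValidate(false)
        .setCompression("UNCOMPRESSED")
        .setFilePattern("/tmp/example-*.tfrecord") // illustrative path
        .setErrorHandling(ErrorHandling.builder().setOutput("my_errors").build())
        .build();

PCollectionRowTuple result =
    PCollectionRowTuple.empty(pipeline)
        .apply(new TFRecordReadSchemaTransformProvider().from(config));

// One Row per TFRecord, each holding a single BYTES field named "record".
PCollection<Row> records = result.get("output");
// Present only because error handling was configured above.
PCollection<Row> errors = result.get("my_errors");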