Normalize tf record io #34411

Merged
Commits (46)
All commits authored by derrickaw.

- d693f6c create TFRecordReadSchemaTransform
- c2f61ce create TFRecordWriteSchemaTransform
- 00ce573 create TFRecordSchemaTransform Test
- 4453f50 fix conflicts
- 34238e1 add writeToTFRecord return
- f067a27 add translation file and getRowConfiguration method for said file
- 48a4b39 fix formatting
- 86d56e5 remove print
- c428774 fix more lint issues
- cbce5e1 add support for bytes in yaml
- 35d80c0 update ReadTransform with error handling and use string compression
- c78d69f add error handling to writetransform and change compression to string
- b28acfa change compression to string
- dc4f963 update compression type and no_spilling
- 539e648 update writeToTFRecord parameters
- 0402693 add tfrecord yaml test pipeline
- e9d6428 remove old code
- b944593 fix lint issue
- 4707213 update standard external transforms with tfrecord info
- e614f37 remove bad character and broken comment
- 414bdb0 fix lint issue
- d030e3f add no_spilling doc string
- 1a5c8b5 change tfrecord.yaml to write version
- 0a0e23d update parameter name
- 8eca45e update read and write yaml for tfrecord
- 821d4c6 update pipeline to handle write and read
- 3a17dc1 fix lint issues
- 80d23b0 fix lint
- 5158bb1 fix lint, precommit checker is broken for me, so hence many commits :)
- c98eb38 minor fix
- 057c681 fix java comments
- ed83c00 Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWr…
- 9bb0898 Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordSc…
- 610bcaa fix class name change
- 55359e7 fix MakeItWork test case
- 2136a47 fix spotless issues
- fc0fcc6 Update sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWr…
- fe3b598 remove nullable for outputprefix per comments
- a64234b remove no_spilling parameter and extra white space
- fe7a900 update standard external transforms
- 78840ff remove one more no_spilling
- 875b725 fix nullable on no_spilling
- 5fced8b rerun generate external transforms
- 5ce1adf fix order of python and java providers
- 818f1df revert java and python providers changes
- fe60137 update test and fix prior test failures
...a/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformConfiguration.java (111 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;

/**
 * Configuration for reading from TFRecord.
 *
 * <p>This class is meant to be used with {@link TFRecordReadSchemaTransformProvider}.
 *
 * <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
 * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
 * repository.
 */
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class TFRecordReadSchemaTransformConfiguration implements Serializable {

  public void validate() {
    String invalidConfigMessage = "Invalid TFRecord Read configuration: ";
    checkNotNull(getValidate(), "To read from TFRecord, validation must be specified.");
    checkNotNull(getCompression(), "To read from TFRecord, compression must be specified.");

    String filePattern = getFilePattern();
    if (filePattern == null || filePattern.isEmpty()) {
      throw new IllegalStateException(
          "Need to set the filepattern of a TFRecordReadSchema transform");
    }

    if (getValidate()) {
      try {
        MatchResult matches = FileSystems.match(filePattern);
        checkState(
            !matches.metadata().isEmpty(), "Unable to find any files matching %s", filePattern);
      } catch (IOException e) {
        throw new IllegalStateException(
            String.format(invalidConfigMessage + "Failed to validate %s", filePattern), e);
      }
    }

    ErrorHandling errorHandling = getErrorHandling();
    if (errorHandling != null) {
      checkArgument(
          !Strings.isNullOrEmpty(errorHandling.getOutput()),
          invalidConfigMessage + "Output must not be empty if error handling specified.");
    }
  }

  /** Instantiates a {@link TFRecordReadSchemaTransformConfiguration.Builder} instance. */
  public static TFRecordReadSchemaTransformConfiguration.Builder builder() {
    return new AutoValue_TFRecordReadSchemaTransformConfiguration.Builder();
  }

  @SchemaFieldDescription("Validate file pattern.")
  public abstract boolean getValidate();

  @SchemaFieldDescription("Decompression type to use when reading input files.")
  public abstract String getCompression();

  @SchemaFieldDescription("Filename or file pattern used to find input files.")
  public abstract @Nullable String getFilePattern();

  @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
  public abstract @Nullable ErrorHandling getErrorHandling();

  abstract Builder toBuilder();

  /** Builder for {@link TFRecordReadSchemaTransformConfiguration}. */
  @AutoValue.Builder
  public abstract static class Builder {

    public abstract Builder setValidate(boolean value);

    public abstract Builder setCompression(String value);

    public abstract Builder setFilePattern(String value);

    public abstract Builder setErrorHandling(@Nullable ErrorHandling errorHandling);

    /** Builds the {@link TFRecordReadSchemaTransformConfiguration} configuration. */
    public abstract TFRecordReadSchemaTransformConfiguration build();
  }
}
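For context on the on-disk format this transform reads: a TFRecord file is a sequence of length-prefixed records, each guarded by masked CRC32C checksums (8-byte little-endian length, 4-byte masked CRC of the length, payload, 4-byte masked CRC of the payload). The following is a minimal, Beam-free sketch of that framing; the class and method names are illustrative and not part of this PR.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.zip.CRC32C;

public class TFRecordFraming {
  // TensorFlow's masked CRC32C: rotate the CRC right by 15 bits, then add a constant.
  static int maskedCrc32c(byte[] data) {
    CRC32C crc = new CRC32C();
    crc.update(data, 0, data.length);
    int c = (int) crc.getValue();
    return ((c >>> 15) | (c << 17)) + 0xa282ead8;
  }

  // Frame one record: length (8-byte LE), masked CRC of the length bytes,
  // the payload, then masked CRC of the payload.
  static byte[] encodeRecord(byte[] payload) {
    byte[] len =
        ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(payload.length).array();
    ByteBuffer buf =
        ByteBuffer.allocate(8 + 4 + payload.length + 4).order(ByteOrder.LITTLE_ENDIAN);
    buf.put(len);
    buf.putInt(maskedCrc32c(len));
    buf.put(payload);
    buf.putInt(maskedCrc32c(payload));
    return buf.array();
  }
}
```

Each record therefore costs 16 bytes of framing overhead on top of its payload, which is why the schema transform surfaces records as raw `byte[]` values and leaves payload parsing (e.g. `tf.train.Example` protos) to downstream transforms.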
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformProvider.java (221 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class TFRecordReadSchemaTransformProvider
    extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
  private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
  private static final String OUTPUT = "output";
  private static final String ERROR = "errors";
  public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
  public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
  private static final Logger LOG =
      LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

  /** Returns the expected class of the configuration. */
  @Override
  protected Class<TFRecordReadSchemaTransformConfiguration> configurationClass() {
    return TFRecordReadSchemaTransformConfiguration.class;
  }

  /** Returns the expected {@link SchemaTransform} of the configuration. */
  @Override
  protected SchemaTransform from(TFRecordReadSchemaTransformConfiguration configuration) {
    return new TFRecordReadSchemaTransform(configuration);
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */
  @Override
  public String identifier() {
    return IDENTIFIER;
  }

  /**
   * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since
   * no input is expected, this returns an empty list.
   */
  @Override
  public List<String> inputCollectionNames() {
    return Collections.emptyList();
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. */
  @Override
  public List<String> outputCollectionNames() {
    return Arrays.asList(OUTPUT, ERROR);
  }

  /**
   * An implementation of {@link SchemaTransform} for TFRecord read jobs configured using {@link
   * TFRecordReadSchemaTransformConfiguration}.
   */
  static class TFRecordReadSchemaTransform extends SchemaTransform {
    private final TFRecordReadSchemaTransformConfiguration configuration;

    TFRecordReadSchemaTransform(TFRecordReadSchemaTransformConfiguration configuration) {
      this.configuration = configuration;
    }

    public Row getConfigurationRow() {
      try {
        // To stay consistent with our SchemaTransform configuration naming conventions,
        // we sort lexicographically
        return SchemaRegistry.createDefault()
            .getToRowFunction(TFRecordReadSchemaTransformConfiguration.class)
            .apply(configuration)
            .sorted()
            .toSnakeCase();
      } catch (NoSuchSchemaException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public PCollectionRowTuple expand(PCollectionRowTuple input) {
      // Validate configuration parameters
      configuration.validate();

      TFRecordIO.Read readTransform =
          TFRecordIO.read().withCompression(Compression.valueOf(configuration.getCompression()));

      String filePattern = configuration.getFilePattern();
      if (filePattern != null) {
        readTransform = readTransform.from(filePattern);
      }
      if (!configuration.getValidate()) {
        readTransform = readTransform.withoutValidation();
      }

      // Read TFRecord files into a PCollection of byte arrays.
      PCollection<byte[]> tfRecordValues = input.getPipeline().apply(readTransform);

      // Define the schema for the row
      Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));

      Schema errorSchema = ErrorHandling.errorSchemaBytes();
      boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

      SerializableFunction<byte[], Row> bytesToRowFn = getBytesToRowFn(schema);

      // Apply bytes to row fn
      PCollectionTuple outputTuple =
          tfRecordValues.apply(
              ParDo.of(
                      new ErrorFn(
                          "TFRecord-read-error-counter", bytesToRowFn, errorSchema, handleErrors))
                  .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

      PCollectionRowTuple outputRows =
          PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(schema));

      // Error handling
      PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
      if (handleErrors) {
        outputRows =
            outputRows.and(
                checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
      }
      return outputRows;
    }
  }

  public static SerializableFunction<byte[], Row> getBytesToRowFn(Schema schema) {
    return new SimpleFunction<byte[], Row>() {
      @Override
      public Row apply(byte[] input) {
        Row row = Row.withSchema(schema).addValues(input).build();
        if (row == null) {
          throw new NullPointerException();
        }
        return row;
      }
    };
  }

  public static class ErrorFn extends DoFn<byte[], Row> {
    private final SerializableFunction<byte[], Row> valueMapper;
    private final Counter errorCounter;
    private Long errorsInBundle = 0L;
    private final boolean handleErrors;
    private final Schema errorSchema;

    public ErrorFn(
        String name,
        SerializableFunction<byte[], Row> valueMapper,
        Schema errorSchema,
        boolean handleErrors) {
      this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
      this.valueMapper = valueMapper;
      this.handleErrors = handleErrors;
      this.errorSchema = errorSchema;
    }

    @ProcessElement
    public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
      Row mappedRow = null;
      try {
        mappedRow = valueMapper.apply(msg);
      } catch (Exception e) {
        if (!handleErrors) {
          throw new RuntimeException(e);
        }
        errorsInBundle += 1;
        LOG.warn("Error while parsing the element", e);
        receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
      }
      if (mappedRow != null) {
        receiver.get(OUTPUT_TAG).output(mappedRow);
      }
    }

    @FinishBundle
    public void finish(FinishBundleContext c) {
      errorCounter.inc(errorsInBundle);
      errorsInBundle = 0L;
    }
  }
}
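The ErrorFn above follows Beam's common dead-letter pattern: attempt the mapping, and on failure either rethrow (when no error handling is configured, failing the bundle) or count the failure and route it to a separate error output while healthy elements continue downstream. A stripped-down, Beam-free sketch of that control flow, with all names illustrative rather than Beam APIs:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

public class DeadLetterSketch {
  // Stand-ins for the OUTPUT_TAG and ERROR_TAG output collections.
  final List<String> output = new ArrayList<>();
  final List<String> errors = new ArrayList<>();
  long errorCount = 0; // stand-in for the bundle error counter

  void process(byte[] msg, Function<byte[], String> mapper, boolean handleErrors) {
    String mapped = null;
    try {
      mapped = mapper.apply(msg);
    } catch (RuntimeException e) {
      if (!handleErrors) {
        throw e; // no error output configured: fail the bundle, as ErrorFn does
      }
      errorCount++;
      errors.add("error: " + e.getMessage()); // stand-in for ErrorHandling.errorRecord
    }
    if (mapped != null) {
      output.add(mapped); // only successfully mapped elements reach the main output
    }
  }
}
```

The design point is that the counter is incremented per bundle and flushed in a finish-bundle hook (as in ErrorFn's `@FinishBundle` method), so metrics stay cheap even when many elements fail.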