Normalize tf record io #34411
Draft
derrickaw wants to merge 30 commits into apache:master from derrickaw:normalizeTFRecordIO
Commits (30 total; the changes shown below are from 22 commits)
d693f6c  create TFRecordReadSchemaTransform (derrickaw)
c2f61ce  create TFRecordWriteSchemaTransform (derrickaw)
00ce573  create TFRecordSchemaTransform Test (derrickaw)
4453f50  fix conflicts (derrickaw)
34238e1  add writeToTFRecord return (derrickaw)
f067a27  add translation file and getRowConfiguration method for said file (derrickaw)
48a4b39  fix formatting (derrickaw)
86d56e5  remove print (derrickaw)
c428774  fix more lint issues (derrickaw)
cbce5e1  add support for bytes in yaml (derrickaw)
35d80c0  update ReadTransform with error handling and use string compression (derrickaw)
c78d69f  add error handling to writetransform and change compression to string (derrickaw)
b28acfa  change compression to string (derrickaw)
dc4f963  update compression type and no_spilling (derrickaw)
539e648  update writeToTFRecord parameters (derrickaw)
0402693  add tfrecord yaml test pipeline (derrickaw)
e9d6428  remove old code (derrickaw)
b944593  fix lint issue (derrickaw)
4707213  update standard external transforms with tfrecord info (derrickaw)
e614f37  remove bad character and broken comment (derrickaw)
414bdb0  fix lint issue (derrickaw)
d030e3f  add no_spilling doc string (derrickaw)
1a5c8b5  change tfrecord.yaml to write version (derrickaw)
0a0e23d  update parameter name (derrickaw)
8eca45e  update read and write yaml for tfrecord (derrickaw)
821d4c6  update pipeline to handle write and read (derrickaw)
3a17dc1  fix lint issues (derrickaw)
80d23b0  fix lint (derrickaw)
5158bb1  fix lint, precommit checker is broken for me, so hence many commits :) (derrickaw)
c98eb38  minor fix (derrickaw)
...a/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformConfiguration.java
111 changes: 111 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.io.Serializable;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;

/**
 * Configuration for reading from TFRecord.
 *
 * <p>This class is meant to be used with {@link TFRecordReadSchemaTransformProvider}.
 *
 * <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
 * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
 * repository.
 */
@DefaultSchema(AutoValueSchema.class)
@AutoValue
public abstract class TFRecordReadSchemaTransformConfiguration implements Serializable {

  public void validate() {
    String invalidConfigMessage = "Invalid TFRecord Read configuration: ";
    checkNotNull(getValidate(), "To read from TFRecord, validation must be specified.");
    checkNotNull(getCompression(), "To read from TFRecord, compression must be specified.");

    String filePattern = getFilePattern();
    if (filePattern == null || filePattern.isEmpty()) {
      throw new IllegalStateException(
          "Need to set the filepattern of a TFRecordReadSchema transform");
    }
    if (getValidate()) {
      try {
        MatchResult matches = FileSystems.match(filePattern);
        checkState(
            !matches.metadata().isEmpty(), "Unable to find any files matching %s", filePattern);
      } catch (IOException e) {
        throw new IllegalStateException(
            String.format(invalidConfigMessage + "Failed to validate %s", filePattern), e);
      }
    }

    ErrorHandling errorHandling = getErrorHandling();
    if (errorHandling != null) {
      checkArgument(
          !Strings.isNullOrEmpty(errorHandling.getOutput()),
          invalidConfigMessage + "Output must not be empty if error handling specified.");
    }
  }

  /** Instantiates a {@link TFRecordReadSchemaTransformConfiguration.Builder} instance. */
  public static TFRecordReadSchemaTransformConfiguration.Builder builder() {
    return new AutoValue_TFRecordReadSchemaTransformConfiguration.Builder();
  }

  @SchemaFieldDescription("Validate file pattern.")
  public abstract boolean getValidate();

  @SchemaFieldDescription("Decompression type to use when reading input files.")
  public abstract String getCompression();

  @SchemaFieldDescription("Filename or file pattern used to find input files.")
  public abstract @Nullable String getFilePattern();

  @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
  public abstract @Nullable ErrorHandling getErrorHandling();

  abstract Builder toBuilder();

  /** Builder for {@link TFRecordReadSchemaTransformConfiguration}. */
  @AutoValue.Builder
  public abstract static class Builder {

    public abstract Builder setValidate(boolean value);

    public abstract Builder setCompression(String value);

    public abstract Builder setFilePattern(String value);

    public abstract Builder setErrorHandling(@Nullable ErrorHandling errorHandling);

    /** Builds the {@link TFRecordReadSchemaTransformConfiguration} configuration. */
    public abstract TFRecordReadSchemaTransformConfiguration build();
  }
}
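For orientation while reviewing, a minimal usage sketch of the configuration above; it is not part of this diff, and the compression string and file pattern are illustrative assumptions. The provider below parses the string with Compression.valueOf, so it must name a Compression constant such as AUTO or GZIP.

    // Hypothetical usage sketch, not part of this PR's diff.
    TFRecordReadSchemaTransformConfiguration config =
        TFRecordReadSchemaTransformConfiguration.builder()
            .setValidate(false) // skip matching the pattern against the filesystem
            .setCompression("AUTO") // parsed downstream via Compression.valueOf(...)
            .setFilePattern("/tmp/records/*.tfrecord") // illustrative pattern
            .build();

    // validate() throws IllegalStateException when the file pattern is missing,
    // and, when validate is true, when the pattern matches no files.
    config.validate();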
sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordReadSchemaTransformProvider.java
221 changes: 221 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(SchemaTransformProvider.class)
public class TFRecordReadSchemaTransformProvider
    extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
  private static final String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
  private static final String OUTPUT = "output";
  private static final String ERROR = "errors";
  public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
  public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
  private static final Logger LOG =
      LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

  /** Returns the expected class of the configuration. */
  @Override
  protected Class<TFRecordReadSchemaTransformConfiguration> configurationClass() {
    return TFRecordReadSchemaTransformConfiguration.class;
  }

  /** Returns the expected {@link SchemaTransform} of the configuration. */
  @Override
  protected SchemaTransform from(TFRecordReadSchemaTransformConfiguration configuration) {
    return new TFRecordReadSchemaTransform(configuration);
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */
  @Override
  public String identifier() {
    return IDENTIFIER;
  }

  /**
   * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since
   * no input is expected, this returns an empty list.
   */
  @Override
  public List<String> inputCollectionNames() {
    return Collections.emptyList();
  }

  /** Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. */
  @Override
  public List<String> outputCollectionNames() {
    return Arrays.asList(OUTPUT, ERROR);
  }

  /**
   * An implementation of {@link SchemaTransform} for TFRecord read jobs configured using {@link
   * TFRecordReadSchemaTransformConfiguration}.
   */
  static class TFRecordReadSchemaTransform extends SchemaTransform {
    private final TFRecordReadSchemaTransformConfiguration configuration;

    TFRecordReadSchemaTransform(TFRecordReadSchemaTransformConfiguration configuration) {
      this.configuration = configuration;
    }

    public Row getConfigurationRow() {
      try {
        // To stay consistent with our SchemaTransform configuration naming conventions,
        // we sort lexicographically
        return SchemaRegistry.createDefault()
            .getToRowFunction(TFRecordReadSchemaTransformConfiguration.class)
            .apply(configuration)
            .sorted()
            .toSnakeCase();
      } catch (NoSuchSchemaException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public PCollectionRowTuple expand(PCollectionRowTuple input) {
      // Validate configuration parameters
      configuration.validate();

      TFRecordIO.Read readTransform =
          TFRecordIO.read().withCompression(Compression.valueOf(configuration.getCompression()));

      String filePattern = configuration.getFilePattern();
      if (filePattern != null) {
        readTransform = readTransform.from(filePattern);
      }
      if (!configuration.getValidate()) {
        readTransform = readTransform.withoutValidation();
      }

      // Read TFRecord files into a PCollection of byte arrays.
      PCollection<byte[]> tfRecordValues = input.getPipeline().apply(readTransform);

      // Define the schema for the row
      Schema schema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
      Schema errorSchema = ErrorHandling.errorSchemaBytes();
      boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

      SerializableFunction<byte[], Row> bytesToRowFn = getBytesToRowFn(schema);

      // Apply bytes to row fn
      PCollectionTuple outputTuple =
          tfRecordValues.apply(
              ParDo.of(
                      new ErrorFn(
                          "TFRecord-read-error-counter", bytesToRowFn, errorSchema, handleErrors))
                  .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

      PCollectionRowTuple outputRows =
          PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(schema));

      // Error handling
      PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
      if (handleErrors) {
        outputRows =
            outputRows.and(
                checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
      }
      return outputRows;
    }
  }

  public static SerializableFunction<byte[], Row> getBytesToRowFn(Schema schema) {
    return new SimpleFunction<byte[], Row>() {
      @Override
      public Row apply(byte[] input) {
        Row row = Row.withSchema(schema).addValues(input).build();
        if (row == null) {
          throw new NullPointerException();
        }
        return row;
      }
    };
  }

  public static class ErrorFn extends DoFn<byte[], Row> {
    private final SerializableFunction<byte[], Row> valueMapper;
    private final Counter errorCounter;
    private Long errorsInBundle = 0L;
    private final boolean handleErrors;
    private final Schema errorSchema;

    public ErrorFn(
        String name,
        SerializableFunction<byte[], Row> valueMapper,
        Schema errorSchema,
        boolean handleErrors) {
      this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
      this.valueMapper = valueMapper;
      this.handleErrors = handleErrors;
      this.errorSchema = errorSchema;
    }

    @ProcessElement
    public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
      Row mappedRow = null;
      try {
        mappedRow = valueMapper.apply(msg);
      } catch (Exception e) {
        if (!handleErrors) {
          throw new RuntimeException(e);
        }
        errorsInBundle += 1;
        LOG.warn("Error while parsing the element", e);
        receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
      }
      if (mappedRow != null) {
        receiver.get(OUTPUT_TAG).output(mappedRow);
      }
    }

    @FinishBundle
    public void finish(FinishBundleContext c) {
      errorCounter.inc(errorsInBundle);
      errorsInBundle = 0L;
    }
  }
}
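For reviewers who want to exercise the provider end to end, a hedged test-style sketch along these lines might work. It is not part of this diff; the class name and file pattern are hypothetical, and it assumes the caller lives in the org.apache.beam.sdk.io package (as a unit test would) since the typed from(...) overload is protected.

    package org.apache.beam.sdk.io; // so the protected from(...) overload is visible

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionRowTuple;
    import org.apache.beam.sdk.values.Row;

    /** Hypothetical smoke test for the read provider; not part of this PR's diff. */
    public class TFRecordReadProviderSketch {
      public static void main(String[] args) {
        Pipeline pipeline = Pipeline.create();

        SchemaTransform read =
            new TFRecordReadSchemaTransformProvider()
                .from(
                    TFRecordReadSchemaTransformConfiguration.builder()
                        .setValidate(false)
                        .setCompression("AUTO") // must name a Compression constant
                        .setFilePattern("/tmp/records/*.tfrecord") // illustrative
                        .build());

        // The transform takes no input collections, so start from an empty tuple;
        // parsed records come back on the "output" tag as single-field BYTES rows.
        PCollection<Row> records =
            PCollectionRowTuple.empty(pipeline).apply(read).get("output");

        pipeline.run().waitUntilFinish();
      }
    }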
Review comment: Once this is ready for full review (IMO, that's probably now, if we can add a few integration tests), I'd recommend doing a separate PR each for read and write, since they're not really tied together at all. That will make it easier to review and iterate. It should also unblock the read review while you work through the issues you're seeing on write.