Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,27 @@ private static class ColumnIOCreatorVisitor implements TypeVisitor {
private final boolean validating;
private final MessageType requestedSchema;
private final String createdBy;
private final boolean strictUnsignedIntegerValidation;
private int currentRequestedIndex;
private Type currentRequestedType;
private boolean strictTypeChecking;

private ColumnIOCreatorVisitor(
boolean validating, MessageType requestedSchema, String createdBy, boolean strictTypeChecking) {
boolean validating,
MessageType requestedSchema,
String createdBy,
boolean strictTypeChecking,
boolean strictUnsignedIntegerValidation) {
this.validating = validating;
this.requestedSchema = requestedSchema;
this.createdBy = createdBy;
this.strictTypeChecking = strictTypeChecking;
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
}

@Override
public void visit(MessageType messageType) {
columnIO = new MessageColumnIO(requestedSchema, validating, createdBy);
columnIO = new MessageColumnIO(requestedSchema, validating, strictUnsignedIntegerValidation, createdBy);
visitChildren(columnIO, messageType, requestedSchema);
columnIO.setLevels();
columnIO.setLeaves(leaves);
Expand Down Expand Up @@ -113,12 +119,13 @@ public MessageColumnIO getColumnIO() {

private final String createdBy;
private final boolean validating;
private final boolean strictUnsignedIntegerValidation;

/**
* validation is off by default
*/
public ColumnIOFactory() {
this(null, false);
this(null, false, false);
}

/**
Expand All @@ -127,24 +134,42 @@ public ColumnIOFactory() {
* @param createdBy createdBy string for readers
*/
public ColumnIOFactory(String createdBy) {
this(createdBy, false);
this(createdBy, false, false);
}

/**
* @param validating to turn validation on
*/
public ColumnIOFactory(boolean validating) {
this(null, validating);
this(null, validating, false);
}

/**
* @param validating to turn validation on
* @param strictUnsignedIntegerValidation to turn strict unsigned integer validation on
*/
public ColumnIOFactory(boolean validating, boolean strictUnsignedIntegerValidation) {
this(null, validating, strictUnsignedIntegerValidation);
}

/**
* @param createdBy createdBy string for readers
* @param validating to turn validation on
*/
public ColumnIOFactory(String createdBy, boolean validating) {
this(createdBy, validating, false);
}

/**
* @param createdBy createdBy string for readers
* @param validating to turn validation on
* @param strictUnsignedIntegerValidation to turn strict unsigned integer validation on
*/
public ColumnIOFactory(String createdBy, boolean validating, boolean strictUnsignedIntegerValidation) {
super();
this.createdBy = createdBy;
this.validating = validating;
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
}

/**
Expand All @@ -163,7 +188,8 @@ public MessageColumnIO getColumnIO(MessageType requestedSchema, MessageType file
* @return the corresponding serializing/deserializing structure
*/
public MessageColumnIO getColumnIO(MessageType requestedSchema, MessageType fileSchema, boolean strict) {
ColumnIOCreatorVisitor visitor = new ColumnIOCreatorVisitor(validating, requestedSchema, createdBy, strict);
ColumnIOCreatorVisitor visitor = new ColumnIOCreatorVisitor(
validating, requestedSchema, createdBy, strict, strictUnsignedIntegerValidation);
fileSchema.accept(visitor);
return visitor.getColumnIO();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,18 @@ public class MessageColumnIO extends GroupColumnIO {
private List<PrimitiveColumnIO> leaves;

private final boolean validating;
private final boolean strictUnsignedIntegerValidation;
private final String createdBy;

MessageColumnIO(MessageType messageType, boolean validating, String createdBy) {
this(messageType, validating, false, createdBy);
}

MessageColumnIO(
MessageType messageType, boolean validating, boolean strictUnsignedIntegerValidation, String createdBy) {
super(messageType, null, 0);
this.validating = validating;
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
this.createdBy = createdBy;
}

Expand Down Expand Up @@ -508,7 +515,9 @@ public void flush() {
public RecordConsumer getRecordWriter(ColumnWriteStore columns) {
RecordConsumer recordWriter = new MessageColumnIORecordConsumer(columns);
if (DEBUG) recordWriter = new RecordConsumerLoggingWrapper(recordWriter);
return validating ? new ValidatingRecordConsumer(recordWriter, getType()) : recordWriter;
return validating
? new ValidatingRecordConsumer(recordWriter, getType(), strictUnsignedIntegerValidation)
: recordWriter;
}

void setLevels() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Optional;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
Expand All @@ -46,7 +48,11 @@
public class ValidatingRecordConsumer extends RecordConsumer {
private static final Logger LOG = LoggerFactory.getLogger(ValidatingRecordConsumer.class);

private static final int UINT_8_MAX_VALUE = 255;
private static final int UINT_16_MAX_VALUE = 65535;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only 8 and 16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UINT_8, UINT_16, and UINT_32 annotate an int32 value, UINT_64 annotatyes a i64 value. For UINT_8, and UINT16 valid values are in the range [0, UINT_8_MAX_VALUE] and [0, UINT_16_MAX_VALUE] respectively. For UINT_32 and UINT_64, all non negative values are valid.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that was the reason thanks!


private final RecordConsumer delegate;
private final boolean strictUnsignedIntegerValidation;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be validated at all times? This is already the ValidatingRecordConsumer so it might need to validate everything.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately there are several test breaking and existing code built with the assumption of no validation. Since we put the changes directly within ValidatingRecordConsumer it causes existing behaviour to break. I added this with a feature flag so new users can start using the validation. I'll investigate as a followup if we can cleanup/fix existing code thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help investigate why existing test break? Are all of them failed because of unsigned integer overflow? If yes, they are actually bugs and should be fixed in this PR so we don't have to add a separate flag for unsigned integer.


private Deque<Type> types = new ArrayDeque<>();
private Deque<Integer> fields = new ArrayDeque<>();
Expand All @@ -58,7 +64,18 @@ public class ValidatingRecordConsumer extends RecordConsumer {
* @param schema the schema to validate against
*/
public ValidatingRecordConsumer(RecordConsumer delegate, MessageType schema) {
this(delegate, schema, false);
}

/**
* @param delegate the consumer to pass down the event to
* @param schema the schema to validate against
* @param strictUnsignedIntegerValidation whether to enable strict unsigned integer validation
*/
public ValidatingRecordConsumer(
RecordConsumer delegate, MessageType schema, boolean strictUnsignedIntegerValidation) {
this.delegate = delegate;
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
this.types.push(schema);
}

Expand Down Expand Up @@ -202,6 +219,9 @@ private void validate(PrimitiveTypeName... ptypes) {
@Override
public void addInteger(int value) {
validate(INT32);
if (strictUnsignedIntegerValidation) {
validateUnsignedInteger(value);
}
delegate.addInteger(value);
}

Expand All @@ -211,6 +231,9 @@ public void addInteger(int value) {
@Override
public void addLong(long value) {
validate(INT64);
if (strictUnsignedIntegerValidation) {
validateUnsignedLong(value);
}
delegate.addLong(value);
}

Expand Down Expand Up @@ -249,4 +272,66 @@ public void addDouble(double value) {
validate(DOUBLE);
delegate.addDouble(value);
}

private void validateUnsignedInteger(int value) {
Type currentType = types.peek().asGroupType().getType(fields.peek());
if (currentType != null && currentType.isPrimitive()) {
LogicalTypeAnnotation logicalType = currentType.asPrimitiveType().getLogicalTypeAnnotation();
if (logicalType != null) {
logicalType.accept(new LogicalTypeAnnotation.LogicalTypeAnnotationVisitor<Void>() {
@Override
public Optional<Void> visit(LogicalTypeAnnotation.IntLogicalTypeAnnotation intType) {
if (!intType.isSigned()) {
switch (intType.getBitWidth()) {
case 8:
if (value < 0 || value > UINT_8_MAX_VALUE) {
throw new InvalidRecordException("Value " + value
+ " is out of range for UINT_8 (0-" + UINT_8_MAX_VALUE + ") in field "
+ currentType.getName());
}
break;
case 16:
if (value < 0 || value > UINT_16_MAX_VALUE) {
throw new InvalidRecordException("Value " + value
+ " is out of range for UINT_16 (0-" + UINT_16_MAX_VALUE + ") in field "
+ currentType.getName());
}
break;
case 32:
case 64:
if (value < 0) {
throw new InvalidRecordException("Negative value " + value
+ " is not allowed for unsigned integer type "
+ currentType.getName());
}
break;
}
}
return Optional.empty();
}
});
}
}
}

private void validateUnsignedLong(long value) {
Type currentType = types.peek().asGroupType().getType(fields.peek());
if (currentType != null && currentType.isPrimitive()) {
LogicalTypeAnnotation logicalType = currentType.asPrimitiveType().getLogicalTypeAnnotation();
if (logicalType != null) {
logicalType.accept(new LogicalTypeAnnotation.LogicalTypeAnnotationVisitor<Void>() {
@Override
public Optional<Void> visit(LogicalTypeAnnotation.IntLogicalTypeAnnotation intType) {
if (!intType.isSigned()) {
if (value < 0) {
throw new InvalidRecordException("Negative value " + value
+ " is not allowed for unsigned integer type " + currentType.getName());
}
}
return Optional.empty();
}
});
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class InternalParquetRecordWriter<T> {
private long nextRowGroupSize;
private final BytesInputCompressor compressor;
private final boolean validating;
private final boolean strictUnsignedIntegerValidation;
private final ParquetProperties props;

private boolean closed;
Expand Down Expand Up @@ -87,6 +88,28 @@ public InternalParquetRecordWriter(
BytesInputCompressor compressor,
boolean validating,
ParquetProperties props) {
this(
parquetFileWriter,
writeSupport,
schema,
extraMetaData,
rowGroupSize,
compressor,
validating,
false,
props);
}

public InternalParquetRecordWriter(
ParquetFileWriter parquetFileWriter,
WriteSupport<T> writeSupport,
MessageType schema,
Map<String, String> extraMetaData,
long rowGroupSize,
BytesInputCompressor compressor,
boolean validating,
boolean strictUnsignedIntegerValidation,
ParquetProperties props) {
this.parquetFileWriter = parquetFileWriter;
this.writeSupport = Objects.requireNonNull(writeSupport, "writeSupport cannot be null");
this.schema = schema;
Expand All @@ -96,6 +119,7 @@ public InternalParquetRecordWriter(
this.nextRowGroupSize = rowGroupSizeThreshold;
this.compressor = compressor;
this.validating = validating;
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
this.props = props;
this.fileEncryptor = parquetFileWriter.getEncryptor();
this.rowGroupOrdinal = 0;
Expand All @@ -120,7 +144,7 @@ private void initStore() {
bloomFilterWriteStore = columnChunkPageWriteStore;

columnStore = props.newColumnWriteStore(schema, pageStore, bloomFilterWriteStore);
MessageColumnIO columnIO = new ColumnIOFactory(validating).getColumnIO(schema);
MessageColumnIO columnIO = new ColumnIOFactory(validating, strictUnsignedIntegerValidation).getColumnIO(schema);
this.recordConsumer = columnIO.getRecordWriter(columnStore);
writeSupport.prepareForWrite(recordConsumer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ public ParquetWriter(Path file, Configuration conf, WriteSupport<T> writeSupport
new CodecFactory(conf, encodingProps.getPageSizeThreshold()),
rowGroupSize,
validating,
false,
conf,
maxPaddingSize,
encodingProps,
Expand All @@ -375,6 +376,7 @@ public ParquetWriter(Path file, Configuration conf, WriteSupport<T> writeSupport
CompressionCodecFactory codecFactory,
long rowGroupSize,
boolean validating,
boolean strictUnsignedIntegerValidation,
ParquetConfiguration conf,
int maxPaddingSize,
ParquetProperties encodingProps,
Expand Down Expand Up @@ -417,7 +419,15 @@ public ParquetWriter(Path file, Configuration conf, WriteSupport<T> writeSupport
}

this.writer = new InternalParquetRecordWriter<T>(
fileWriter, writeSupport, schema, extraMetadata, rowGroupSize, compressor, validating, encodingProps);
fileWriter,
writeSupport,
schema,
extraMetadata,
rowGroupSize,
compressor,
validating,
strictUnsignedIntegerValidation,
encodingProps);
}

public void write(T object) throws IOException {
Expand Down Expand Up @@ -474,6 +484,7 @@ public abstract static class Builder<T, SELF extends Builder<T, SELF>> {
private long rowGroupSize = DEFAULT_BLOCK_SIZE;
private int maxPaddingSize = MAX_PADDING_SIZE_DEFAULT;
private boolean enableValidation = DEFAULT_IS_VALIDATING_ENABLED;
private boolean strictUnsignedIntegerValidation = false;
private ParquetProperties.Builder encodingPropsBuilder = ParquetProperties.builder();

protected Builder(Path path) {
Expand Down Expand Up @@ -715,6 +726,27 @@ public SELF withValidation(boolean enableValidation) {
return self();
}

/**
* Enable strict unsigned integer validation for the constructed writer.
*
* @return this builder for method chaining.
*/
public SELF enableStrictUnsignedIntegerValidation() {
this.strictUnsignedIntegerValidation = true;
return self();
}

/**
* Enable or disable strict unsigned integer validation for the constructed writer.
*
* @param strictUnsignedIntegerValidation whether strict unsigned integer validation should be enabled
* @return this builder for method chaining.
*/
public SELF withStrictUnsignedIntegerValidation(boolean strictUnsignedIntegerValidation) {
this.strictUnsignedIntegerValidation = strictUnsignedIntegerValidation;
return self();
}

/**
* Set the {@link WriterVersion format version} used by the constructed
* writer.
Expand Down Expand Up @@ -978,6 +1010,7 @@ public ParquetWriter<T> build() throws IOException {
codecFactory,
rowGroupSize,
enableValidation,
strictUnsignedIntegerValidation,
conf,
maxPaddingSize,
encodingProps,
Expand Down
Loading