Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ private CloseableIterator<FilteredColumnarBatch> applyDataSkipping(
// pruning it after is much simpler
StructType prunedStatsSchema =
DataSkippingUtils.pruneStatsSchema(
getStatsSchema(metadata.getDataSchema()), dataSkippingFilter.getReferencedCols());
getStatsSchema(metadata.getDataSchema(), dataSkippingFilter.getReferencedCollations()),
dataSkippingFilter.getReferencedCols());

// Skipping happens in two steps:
// 1. The predicate produces false for any file whose stats prove we can safely skip it. A
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ public class DataSkippingPredicate extends Predicate {
/** Set of {@link Column}s referenced by the predicate or any of its child expressions */
private final Set<Column> referencedCols;

/**
* Set of {@link CollationIdentifier}s referenced by this predicate or any of its child
* expressions
*/
private final Set<CollationIdentifier> collationIdentifiers;

Comment on lines -30 to -35
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should have this field. We already have collationIdentifier in Predicate, so having this just increases complexity. Also, getReferencedCollations is only called once in the codebase, so we don't lose much by not persisting its value.

/**
* @param name the predicate name
* @param children list of expressions that are input to this predicate.
Expand All @@ -42,7 +36,6 @@ public class DataSkippingPredicate extends Predicate {
/**
 * Creates a data-skipping predicate with no collation.
 *
 * @param name the predicate name
 * @param children list of expressions that are input to this predicate
 * @param referencedCols set of columns referenced by this predicate or any of its children
 */
DataSkippingPredicate(String name, List<Expression> children, Set<Column> referencedCols) {
  super(name, children);
  // Expose the referenced columns as an unmodifiable view so callers cannot mutate them.
  this.referencedCols = Collections.unmodifiableSet(referencedCols);
}

/**
Expand All @@ -59,7 +52,6 @@ public class DataSkippingPredicate extends Predicate {
Set<Column> referencedCols) {
super(name, children, collationIdentifier);
this.referencedCols = Collections.unmodifiableSet(referencedCols);
this.collationIdentifiers = Collections.singleton(collationIdentifier);
}

/**
Expand All @@ -73,8 +65,6 @@ public class DataSkippingPredicate extends Predicate {
DataSkippingPredicate(String name, DataSkippingPredicate left, DataSkippingPredicate right) {
  super(name, Arrays.asList(left, right));
  // Combined predicate (e.g. AND/OR) references the union of both children's columns.
  this.referencedCols = immutableUnion(left.referencedCols, right.referencedCols);
}

/** @return set of columns referenced by this predicate or any of its child expressions */
Expand All @@ -87,7 +77,26 @@ public Set<Column> getReferencedCols() {
* expressions
*/
public Set<CollationIdentifier> getReferencedCollations() {
  // Computed on demand by walking the expression tree rather than cached in a field;
  // per the PR discussion this is only called once per scan, so persisting it is not worth
  // the extra state.
  Set<CollationIdentifier> referencedCollations = new HashSet<>();

  // This predicate's own collation, if any.
  this.getCollationIdentifier().ifPresent(referencedCollations::add);

  // Recurse into child predicates. Any child that is a Predicate must itself be a
  // DataSkippingPredicate so its collations can be collected; anything else indicates a
  // construction bug, so fail fast.
  for (Expression child : children) {
    if (child instanceof Predicate) {
      if (child instanceof DataSkippingPredicate) {
        referencedCollations.addAll(((DataSkippingPredicate) child).getReferencedCollations());
      } else {
        throw new IllegalStateException(
            String.format(
                "Expected child Predicate of DataSkippingPredicate to also be a"
                    + " DataSkippingPredicate, but found %s",
                child.getClass().getName()));
      }
    }
  }
  return Collections.unmodifiableSet(referencedCollations);
}

/** @return an unmodifiable set containing all elements from both sets. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCo
* |-- a: struct (nullable = true)
* | |-- b: struct (nullable = true)
* | | |-- c: long (nullable = true)
* | | |-- d: string (nullable = true)
* </pre>
*
* <p>Collected Statistics:
Expand All @@ -102,18 +103,32 @@ public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCo
* | | |-- a: struct (nullable = false)
* | | | |-- b: struct (nullable = false)
* | | | | |-- c: long (nullable = true)
* | | | | |-- d: string (nullable = true)
* | |-- maxValues: struct (nullable = false)
* | | |-- a: struct (nullable = false)
* | | | |-- b: struct (nullable = false)
* | | | | |-- c: long (nullable = true)
* | | | | |-- d: string (nullable = true)
* | |-- nullCount: struct (nullable = false)
* | | |-- a: struct (nullable = false)
* | | | |-- b: struct (nullable = false)
* | | | | |-- c: long (nullable = true)
* | | | | |-- d: string (nullable = true)
* | |-- statsWithCollation: struct (nullable = true)
* | | |-- collationName: struct (nullable = true)
* | | | |-- min: struct (nullable = false)
* | | | | |-- a: struct (nullable = false)
* | | | | | |-- b: struct (nullable = false)
* | | | | | | |-- d: string (nullable = true)
* | | | |-- max: struct (nullable = false)
* | | | | |-- a: struct (nullable = false)
* | | | | | |-- b: struct (nullable = false)
* | | | | | | |-- d: string (nullable = true)
* | |-- tightBounds: boolean (nullable = true)
* </pre>
*/
public static StructType getStatsSchema(StructType dataSchema) {
public static StructType getStatsSchema(
StructType dataSchema, Set<CollationIdentifier> collationIdentifiers) {
StructType statsSchema = new StructType().add(NUM_RECORDS, LongType.LONG, true);

StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema);
Expand All @@ -128,6 +143,11 @@ public static StructType getStatsSchema(StructType dataSchema) {

statsSchema = statsSchema.add(TIGHT_BOUNDS, BooleanType.BOOLEAN, true);

StructType collatedMinMaxStatsSchema = getCollatedStatsSchema(dataSchema, collationIdentifiers);
if (collatedMinMaxStatsSchema.length() > 0) {
statsSchema = statsSchema.add(STATS_WITH_COLLATION, collatedMinMaxStatsSchema, true);
}

return statsSchema;
}

Expand Down Expand Up @@ -272,24 +292,60 @@ public boolean isSkippingEligibleNullCountColumn(Column column) {
/**
 * Given a data schema returns the expected schema for a min or max statistics column. This means
 * 1) replace logical names with physical names 2) set nullable=true 3) only keep stats-eligible
 * fields (i.e. don't include fields with isSkippingEligibleDataType=false). Collation-aware
 * statistics are not included.
 */
private static StructType getMinMaxStatsSchema(StructType dataSchema) {
  return getMinMaxStatsSchema(dataSchema, /* isCollatedSkipping */ false);
}

/**
* Given a data schema returns the expected schema for a min or max statistics column. This means
* 1) replace logical names with physical names 2) set nullable=true 3) only keep stats eligible
* fields (i.e. don't include fields with isSkippingEligibleDataType=false). In case when
* isCollatedSkipping is true, only `StringType` fields are eligible.
*/
private static StructType getMinMaxStatsSchema(
StructType dataSchema, boolean isCollatedSkipping) {
List<StructField> fields = new ArrayList<>();
for (StructField field : dataSchema.fields()) {
if (isSkippingEligibleDataType(field.getDataType(), false)) {
if (isSkippingEligibleDataType(field.getDataType(), isCollatedSkipping)) {
fields.add(new StructField(getPhysicalName(field), field.getDataType(), true));
} else if (field.getDataType() instanceof StructType) {
fields.add(
new StructField(
getPhysicalName(field),
getMinMaxStatsSchema((StructType) field.getDataType()),
getMinMaxStatsSchema((StructType) field.getDataType(), isCollatedSkipping),
true));
}
}
return new StructType(fields);
}

/**
* Given a data schema and a set of collation identifiers returns the expected schema for
* collation-aware statistics columns. This means 1) replace logical names with physical names 2)
* set nullable=true 3) only keep collated-stats eligible fields (`StringType` fields)
*/
Comment on lines +325 to +329
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this what we do in delta spark for the stats read schema? Or is there an optimization where we only read the collated stats for columns referenced in the predicate (with that collation)?

Maybe this is a future optimization we could consider?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can optimize it like that. We can do this for both collated and binary stats. Let's leave this for future optimization so we don't complicate this PR further.

private static StructType getCollatedStatsSchema(
    StructType dataSchema, Set<CollationIdentifier> collationIdentifiers) {
  // Min/max schema restricted to collated-stats eligible (string-typed) fields only.
  StructType collatedFieldsMinMax =
      getMinMaxStatsSchema(dataSchema, /* isCollatedSkipping */ true);

  StructType result = new StructType();
  if (collatedFieldsMinMax.length() == 0) {
    // No eligible string columns anywhere in the schema: nothing to collect per collation.
    return result;
  }

  // One child struct per requested collation, each holding a min and a max sub-schema.
  for (CollationIdentifier collationIdentifier : collationIdentifiers) {
    StructType minMaxPair =
        new StructType()
            .add(MIN, collatedFieldsMinMax, true)
            .add(MAX, collatedFieldsMinMax, true);
    result = result.add(collationIdentifier.toString(), minMaxPair, true);
  }
  return result;
}

/**
* Given a data schema returns the expected schema for a null_count statistics column. This means
* 1) replace logical names with physical names 2) set nullable=true 3) use LongType for all
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
collation: CollationIdentifier,
statName: String,
fieldName: String): Column = {
new Column(Array(STATS_WITH_COLLATION, collation.toString, statName, fieldName))
val columnPath =
Array(STATS_WITH_COLLATION, collation.toString, statName) ++ fieldName.split('.')
new Column(columnPath)
}

/* For struct type checks for equality based on field names & data type only */
Expand Down Expand Up @@ -188,6 +190,118 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils {
new StructType())
}

test("pruneStatsSchema - collated min/max columns") {
  val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE")
  val unicode = CollationIdentifier.fromString("ICU.UNICODE")

  // Shared schema fragments, factored out to avoid repeating the same structs inline.
  val nestedStruct = new StructType().add("s2", StringType.STRING)
  // All data columns: two strings (one nested) and two integers.
  val allFields = new StructType()
    .add("s1", StringType.STRING)
    .add("i1", INTEGER)
    .add("i2", INTEGER)
    .add("nested", nestedStruct)
  // Only string-typed columns are eligible for collated stats.
  val collatedFields = new StructType()
    .add("s1", StringType.STRING)
    .add("nested", nestedStruct)
  val collatedMinMax = new StructType()
    .add(MIN, collatedFields)
    .add(MAX, collatedFields)

  val testSchema = new StructType()
    .add(MIN, allFields)
    .add(MAX, allFields)
    .add(
      STATS_WITH_COLLATION,
      new StructType()
        .add(utf8Lcase.toString, collatedMinMax)
        .add(unicode.toString, collatedMinMax))

  val testCases = Seq(
    // Non-collated references only: collated stats pruned away entirely.
    (
      Set(nestedCol(s"$MIN.nested.s2"), nestedCol(s"$MAX.i1")),
      new StructType()
        .add(MIN, new StructType().add("nested", nestedStruct))
        .add(MAX, new StructType().add("i1", INTEGER))),
    // Collated references only: plain min/max pruned away entirely.
    (
      Set(
        collatedStatsCol(utf8Lcase, MIN, "s1"),
        collatedStatsCol(unicode, MAX, "nested.s2")),
      new StructType()
        .add(
          STATS_WITH_COLLATION,
          new StructType()
            .add(
              utf8Lcase.toString,
              new StructType().add(MIN, new StructType().add("s1", StringType.STRING)))
            .add(
              unicode.toString,
              new StructType().add(MAX, new StructType().add("nested", nestedStruct))))),
    // Mixed references: both plain and collated stats kept, but pruned per column/collation.
    (
      Set(
        nestedCol(s"$MIN.i2"),
        collatedStatsCol(utf8Lcase, MAX, "nested.s2"),
        collatedStatsCol(utf8Lcase, MIN, "nested.s2")),
      new StructType()
        .add(MIN, new StructType().add("i2", INTEGER))
        .add(
          STATS_WITH_COLLATION,
          new StructType()
            .add(
              utf8Lcase.toString,
              new StructType()
                .add(MIN, new StructType().add("nested", nestedStruct))
                .add(MAX, new StructType().add("nested", nestedStruct))))))

  testCases.foreach { case (referencedCols, expectedSchema) =>
    checkPruneStatsSchema(testSchema, referencedCols, expectedSchema)
  }
}

// TODO: add tests for remaining operators
test("check constructDataSkippingFilter") {
val testCases = Seq(
Expand Down
Loading
Loading