Skip to content

Commit 82d298e

Browse files
lbooker42 and Copilot authored
fix: DH-20614: correct PartitionAwareSourceTable partition column filter handling (#7296)
Cherry-pick of #7294.
Co-authored-by: Copilot <[email protected]>
1 parent 8b99211 commit 82d298e

File tree

5 files changed

+269
-16
lines changed

5 files changed

+269
-16
lines changed

engine/table/src/main/java/io/deephaven/engine/table/impl/PartitionAwareSourceTable.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import io.deephaven.engine.table.impl.filter.ExtractBarriers;
1212
import io.deephaven.engine.table.impl.filter.ExtractInnerConjunctiveFilters;
1313
import io.deephaven.engine.table.impl.filter.ExtractRespectedBarriers;
14+
import io.deephaven.engine.table.impl.filter.ExtractSerialFilters;
1415
import io.deephaven.engine.table.impl.select.analyzers.SelectAndViewAnalyzer;
1516
import io.deephaven.engine.updategraph.UpdateSourceRegistrar;
1617
import io.deephaven.engine.table.impl.perf.QueryPerformanceRecorder;
@@ -305,7 +306,8 @@ private Table whereImpl(final List<WhereFilter> whereFilters) {
305306
for (WhereFilter whereFilter : whereFilters) {
306307
whereFilter.init(definition, compilationProcessor);
307308

308-
if (!whereFilter.permitParallelization()) {
309+
// Test for user-mandated serial filters (e.g. FilterSerial.of() or Filter.serial())
310+
if (!ExtractSerialFilters.of(whereFilter).isEmpty()) {
309311
serialFilterFound = true;
310312
}
311313

@@ -319,7 +321,7 @@ private Table whereImpl(final List<WhereFilter> whereFilters) {
319321

320322
// similarly, anytime we prioritize a partitioning filter, we record the barriers that it declares. A filter
321323
// that respects no barriers, or only those prioritized barriers may also be prioritized. A filter that
322-
// respects any barrier which was not in partition filters (meaning it must be in in otherFilters - because
324+
// respects any barrier which was not in partition filters (meaning it must be in otherFilters - because
323325
// otherwise you would be respecting an undeclared barrier); cannot be prioritized because that would jump
324326
// the barrier.
325327
if (serialFilterFound || missingBarrier) {
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//
2+
// Copyright (c) 2016-2025 Deephaven Data Labs and Patent Pending
3+
//
4+
package io.deephaven.engine.table.impl.filter;
5+
6+
import io.deephaven.api.RawString;
7+
import io.deephaven.api.expression.Function;
8+
import io.deephaven.api.expression.Method;
9+
import io.deephaven.api.filter.*;
10+
import io.deephaven.api.filter.Filter.Visitor;
11+
import io.deephaven.engine.table.impl.select.*;
12+
13+
import java.util.*;
14+
15+
/**
16+
* Performs a recursive filter extraction against {@code filter}. If {@code filter}, or any sub-filter, is a
17+
* {@link FilterSerial} or {@link WhereFilterSerialImpl}, the filter will be included in the returned collection.
18+
* Otherwise, an empty collection will be returned.
19+
*/
20+
public enum ExtractSerialFilters implements Visitor<Collection<Filter>>, WhereFilter.Visitor<Collection<Filter>> {
21+
INSTANCE;
22+
23+
public static Collection<Filter> of(Filter filter) {
24+
if (filter instanceof WhereFilter) {
25+
final Collection<Filter> retVal =
26+
((WhereFilter) filter).walkWhereFilter(INSTANCE);
27+
return retVal == null ? Collections.emptyList() : retVal;
28+
}
29+
return filter.walk(INSTANCE);
30+
}
31+
32+
@Override
33+
public Collection<Filter> visit(FilterIsNull isNull) {
34+
return Collections.emptyList();
35+
}
36+
37+
@Override
38+
public Collection<Filter> visit(FilterComparison comparison) {
39+
return Collections.emptyList();
40+
}
41+
42+
@Override
43+
public Collection<Filter> visit(FilterIn in) {
44+
return Collections.emptyList();
45+
}
46+
47+
@Override
48+
public Collection<Filter> visit(FilterNot<?> not) {
49+
return not.filter().walk(this);
50+
}
51+
52+
@Override
53+
public Collection<Filter> visit(FilterOr ors) {
54+
final List<Filter> serialFilters = new ArrayList<>();
55+
for (final Filter filter : ors.filters()) {
56+
serialFilters.addAll(of(filter));
57+
}
58+
return serialFilters;
59+
}
60+
61+
@Override
62+
public Collection<Filter> visit(FilterAnd ands) {
63+
final List<Filter> serialFilters = new ArrayList<>();
64+
for (final Filter filter : ands.filters()) {
65+
serialFilters.addAll(of(filter));
66+
}
67+
return serialFilters;
68+
}
69+
70+
@Override
71+
public Collection<Filter> visit(FilterPattern pattern) {
72+
return Collections.emptyList();
73+
}
74+
75+
@Override
76+
public Collection<Filter> visit(FilterSerial serial) {
77+
return List.of(serial); // return this filter
78+
}
79+
80+
@Override
81+
public Collection<Filter> visit(FilterWithDeclaredBarriers declaredBarrier) {
82+
return of(declaredBarrier.filter());
83+
}
84+
85+
@Override
86+
public Collection<Filter> visit(FilterWithRespectedBarriers respectedBarrier) {
87+
return of(respectedBarrier.filter());
88+
}
89+
90+
@Override
91+
public Collection<Filter> visit(Function function) {
92+
return Collections.emptyList();
93+
}
94+
95+
@Override
96+
public Collection<Filter> visit(Method method) {
97+
return Collections.emptyList();
98+
}
99+
100+
@Override
101+
public Collection<Filter> visit(boolean literal) {
102+
return Collections.emptyList();
103+
}
104+
105+
@Override
106+
public Collection<Filter> visit(RawString rawString) {
107+
return Collections.emptyList();
108+
}
109+
110+
@Override
111+
public Collection<Filter> visitWhereFilter(WhereFilterInvertedImpl filter) {
112+
return of(filter.getWrappedFilter());
113+
}
114+
115+
@Override
116+
public Collection<Filter> visitWhereFilter(WhereFilterSerialImpl filter) {
117+
return List.of(filter); // return this filter
118+
}
119+
120+
@Override
121+
public Collection<Filter> visitWhereFilter(WhereFilterWithDeclaredBarriersImpl filter) {
122+
return of(filter.getWrappedFilter());
123+
}
124+
125+
@Override
126+
public Collection<Filter> visitWhereFilter(WhereFilterWithRespectedBarriersImpl filter) {
127+
return of(filter.getWrappedFilter());
128+
}
129+
130+
@Override
131+
public Collection<Filter> visitWhereFilter(DisjunctiveFilter disjunctiveFilters) {
132+
final List<Filter> serialFilters = new ArrayList<>();
133+
for (final Filter filter : disjunctiveFilters.getFilters()) {
134+
serialFilters.addAll(of(filter));
135+
}
136+
return serialFilters;
137+
}
138+
139+
@Override
140+
public Collection<Filter> visitWhereFilter(ConjunctiveFilter conjunctiveFilters) {
141+
final List<Filter> serialFilters = new ArrayList<>();
142+
for (final Filter filter : conjunctiveFilters.getFilters()) {
143+
serialFilters.addAll(of(filter));
144+
}
145+
return serialFilters;
146+
}
147+
}

engine/table/src/test/java/io/deephaven/engine/table/impl/QueryTableWhereTest.java

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.deephaven.engine.table.impl.verify.TableAssertions;
3232
import io.deephaven.engine.testutil.*;
3333
import io.deephaven.engine.testutil.QueryTableTestBase.TableComparator;
34+
import io.deephaven.engine.testutil.filters.ParallelizedRowSetCapturingFilter;
3435
import io.deephaven.engine.testutil.filters.RowSetCapturingFilter;
3536
import io.deephaven.engine.testutil.generator.*;
3637
import io.deephaven.engine.testutil.junit4.EngineCleanup;
@@ -2393,20 +2394,6 @@ public void testNullRowKeyAgnosticColumnSources() {
23932394
"A", "A = null", "A != null");
23942395
}
23952396

2396-
/**
2397-
* Private helper to force parallelization of the RowSetCapturingFilter.
2398-
*/
2399-
private class ParallelizedRowSetCapturingFilter extends RowSetCapturingFilter {
2400-
public ParallelizedRowSetCapturingFilter(Filter filter) {
2401-
super(filter);
2402-
}
2403-
2404-
@Override
2405-
public boolean permitParallelization() {
2406-
return true;
2407-
}
2408-
}
2409-
24102397
private void testRowKeyAgnosticColumnSource(
24112398
final ColumnSource<?> columnSource,
24122399
final String columnName,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//
2+
// Copyright (c) 2016-2025 Deephaven Data Labs and Patent Pending
3+
//
4+
package io.deephaven.engine.testutil.filters;
5+
6+
import io.deephaven.api.filter.Filter;
7+
8+
/**
9+
* Helper to force parallelization of the RowSetCapturingFilter.
10+
*/
11+
public class ParallelizedRowSetCapturingFilter extends RowSetCapturingFilter {
12+
public ParallelizedRowSetCapturingFilter(Filter filter) {
13+
super(filter);
14+
}
15+
16+
@Override
17+
public boolean permitParallelization() {
18+
return true;
19+
}
20+
}

extensions/parquet/table/src/test/java/io/deephaven/parquet/table/ParquetTableFilterTest.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//
44
package io.deephaven.parquet.table;
55

6+
import io.deephaven.api.RawString;
67
import io.deephaven.api.filter.Filter;
78
import io.deephaven.base.FileUtils;
89
import io.deephaven.engine.context.ExecutionContext;
@@ -18,6 +19,8 @@
1819
import io.deephaven.engine.table.impl.select.WhereFilter;
1920
import io.deephaven.engine.table.impl.util.ColumnHolder;
2021
import io.deephaven.engine.table.impl.util.ImmediateJobScheduler;
22+
import io.deephaven.engine.testutil.filters.ParallelizedRowSetCapturingFilter;
23+
import io.deephaven.engine.testutil.filters.RowSetCapturingFilter;
2124
import io.deephaven.engine.testutil.junit4.EngineCleanup;
2225
import io.deephaven.engine.util.TableTools;
2326
import io.deephaven.parquet.table.location.ParquetColumnResolverMap;
@@ -43,6 +46,7 @@
4346
import java.time.Instant;
4447
import java.util.*;
4548
import java.util.concurrent.CompletableFuture;
49+
import java.util.concurrent.atomic.AtomicLong;
4650
import java.util.function.BiFunction;
4751

4852
import static io.deephaven.base.FileUtils.convertToURI;
@@ -136,6 +140,10 @@ private static void filterAndVerifyResults(Table diskTable, Table memTable, Stri
136140
verifyResults(diskTable.where(filters).coalesce(), memTable.where(filters).coalesce());
137141
}
138142

143+
private static void filterAndVerifyResults(Table diskTable, Table memTable, Filter filter) {
144+
verifyResults(diskTable.where(filter).coalesce(), memTable.where(filter).coalesce());
145+
}
146+
139147
private static void filterAndVerifyResults(Table diskTable, Table memTable, WhereFilter filter) {
140148
verifyResults(diskTable.where(filter).coalesce(), memTable.where(filter).coalesce());
141149
}
@@ -412,6 +420,83 @@ public void flatPartitionsNoDataIndexAllNullTest() {
412420
filterAndVerifyResultsAllowEmpty(diskTable, memTable, "boolean_col = true");
413421
}
414422

423+
// New test with custom function counting invocations
424+
@Test
425+
public void partitionedDataSerialFilterTest() {
426+
final String destPath = Path.of(rootFile.getPath(), "ParquetTest_kvPartitionsSerialTest").toString();
427+
final int tableSize = 1_000_000;
428+
429+
final Instant baseTime = parseInstant("2023-01-01T00:00:00 NY");
430+
QueryScope.addParam("baseTime", baseTime);
431+
432+
final Table largeTable = TableTools.emptyTable(tableSize).update(
433+
"symbol = ii % 100",
434+
"sequential_val = ii");
435+
436+
final PartitionedTable partitionedTable = largeTable.partitionBy("symbol");
437+
ParquetTools.writeKeyValuePartitionedTable(partitionedTable, destPath, EMPTY);
438+
439+
final Table diskTable = ParquetTools.readTable(destPath);
440+
final Table memTable = diskTable.select();
441+
442+
assertTableEquals(diskTable, memTable);
443+
444+
final AtomicLong invocationCount = new AtomicLong();
445+
QueryScope.addParam("invocationCount", invocationCount);
446+
447+
final Filter partitionFilter = RawString.of("symbol >= 0 && invocationCount.incrementAndGet() >= 0");
448+
final Filter serialPartitionFilter = partitionFilter.withSerial();
449+
450+
final Filter nonPartitionFilter =
451+
RawString.of("sequential_val >= 0 && invocationCount.incrementAndGet() >= 0");
452+
final Filter serialNonPartitionFilter = nonPartitionFilter.withSerial();
453+
454+
Table result;
455+
456+
// Test non-serial partition filter
457+
assertEquals(0L, invocationCount.get());
458+
result = diskTable.where(partitionFilter).coalesce();
459+
assertEquals(100L, invocationCount.get()); // one per partition
460+
// Verify the table contents are equivalent
461+
assertTableEquals(result, diskTable.coalesce().where(partitionFilter));
462+
463+
// Test serial partition filter
464+
invocationCount.set(0);
465+
assertEquals(0L, invocationCount.get());
466+
result = diskTable.where(serialPartitionFilter).coalesce();
467+
assertEquals(1_000_000L, invocationCount.get()); // one per row
468+
// Verify the table contents are equivalent
469+
assertTableEquals(result, diskTable.coalesce().where(serialPartitionFilter));
470+
471+
// Test non-serial non-partition filter
472+
invocationCount.set(0);
473+
assertEquals(0L, invocationCount.get());
474+
result = diskTable.where(nonPartitionFilter).coalesce();
475+
assertEquals(1_000_000L, invocationCount.get()); // one per row
476+
// Verify the table contents are equivalent
477+
assertTableEquals(result, diskTable.coalesce().where(nonPartitionFilter));
478+
479+
// Test serial non-partition filter
480+
invocationCount.set(0);
481+
assertEquals(0L, invocationCount.get());
482+
result = diskTable.where(serialNonPartitionFilter).coalesce();
483+
assertEquals(1_000_000L, invocationCount.get()); // one per row
484+
// Verify the table contents are equivalent
485+
assertTableEquals(result, diskTable.coalesce().where(serialNonPartitionFilter));
486+
487+
// Test stateless partition filter
488+
final RowSetCapturingFilter statelessPartitionFilter =
489+
new ParallelizedRowSetCapturingFilter(RawString.of("symbol >= 0"));
490+
result = diskTable.where(statelessPartitionFilter).coalesce();
491+
assertEquals(100, statelessPartitionFilter.numRowsProcessed()); // one per partition
492+
493+
// Test stateless non-partition filter
494+
final RowSetCapturingFilter statelessNonPartitionFilter =
495+
new ParallelizedRowSetCapturingFilter(RawString.of("sequential_val >= 0"));
496+
result = diskTable.where(statelessNonPartitionFilter).coalesce();
497+
assertEquals(1_000_000, statelessNonPartitionFilter.numRowsProcessed()); // one per row
498+
}
499+
415500
@Test
416501
public void partitionedNoDataIndexTest() {
417502
final String destPath = Path.of(rootFile.getPath(), "ParquetTest_kvPartitionsTest").toString();
@@ -442,6 +527,18 @@ public void partitionedNoDataIndexTest() {
442527
filterAndVerifyResults(diskTable, memTable, "symbol < `s100`");
443528
filterAndVerifyResults(diskTable, memTable, "symbol = `s500`");
444529

530+
// Conditional on partition column
531+
filterAndVerifyResults(diskTable, memTable, "symbol = `s` + `500`");
532+
// Serial conditional on partition column
533+
filterAndVerifyResults(diskTable, memTable,
534+
Filter.serial(Filter.and(Filter.from("symbol = `s` + `500`"))));
535+
536+
// Conditional on non-partition column
537+
filterAndVerifyResults(diskTable, memTable, "sequential_val >= 50 + 1");
538+
// Serial conditional on non-partition column
539+
filterAndVerifyResults(diskTable, memTable,
540+
Filter.serial(Filter.and(Filter.from("sequential_val >= 50 + 1"))));
541+
445542
// Timestamp range and match filters
446543
filterAndVerifyResults(diskTable, memTable, "Timestamp < '2023-01-02T00:00:00 NY'");
447544
filterAndVerifyResults(diskTable, memTable, "Timestamp = '2023-01-02T00:00:00 NY'");

0 commit comments

Comments
 (0)