Skip to content

Commit 227b659

Browse files
committed
CTE reuse WIP
1 parent aa23d8a commit 227b659

File tree

39 files changed

+3670
-26
lines changed

39 files changed

+3670
-26
lines changed

core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
import io.trino.sql.analyzer.Analysis;
6262
import io.trino.sql.analyzer.Analyzer;
6363
import io.trino.sql.analyzer.AnalyzerFactory;
64+
import io.trino.sql.newir.FormatOptions;
65+
import io.trino.sql.newir.Program;
6466
import io.trino.sql.planner.AdaptivePlanner;
6567
import io.trino.sql.planner.InputExtractor;
6668
import io.trino.sql.planner.LogicalPlanner;
@@ -73,6 +75,7 @@
7375
import io.trino.sql.planner.SplitSourceFactory;
7476
import io.trino.sql.planner.SubPlan;
7577
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
78+
import io.trino.sql.planner.optimizations.ctereuse.CteReuse;
7679
import io.trino.sql.planner.optimizations.PlanOptimizer;
7780
import io.trino.sql.planner.plan.OutputNode;
7881
import io.trino.sql.tree.ExplainAnalyze;
@@ -148,6 +151,7 @@ public class SqlQueryExecution
148151
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
149152
private final TaskDescriptorStorage taskDescriptorStorage;
150153
private final PlanOptimizersStatsCollector planOptimizersStatsCollector;
154+
private final FormatOptions formatOptions;
151155

152156
private SqlQueryExecution(
153157
PreparedQuery preparedQuery,
@@ -185,7 +189,8 @@ private SqlQueryExecution(
185189
SqlTaskManager coordinatorTaskManager,
186190
ExchangeManagerRegistry exchangeManagerRegistry,
187191
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
188-
TaskDescriptorStorage taskDescriptorStorage)
192+
TaskDescriptorStorage taskDescriptorStorage,
193+
FormatOptions formatOptions)
189194
{
190195
try (SetThreadName _ = new SetThreadName("Query-" + stateMachine.getQueryId())) {
191196
this.slug = requireNonNull(slug, "slug is null");
@@ -240,6 +245,7 @@ private SqlQueryExecution(
240245
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "taskSourceFactory is null");
241246
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
242247
this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null");
248+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
243249
}
244250
}
245251

@@ -503,6 +509,8 @@ private PlanRoot doPlanQuery(CachingTableStatsProvider tableStatsProvider)
503509
Plan plan = logicalPlanner.plan(analysis);
504510
queryPlan.set(plan);
505511

512+
Optional<Program> optimizedProgram = CteReuse.reuseCommonSubqueries(plan, plannerContext, getSession(), formatOptions);
513+
506514
// fragment the plan
507515
SubPlan fragmentedPlan;
508516
try (var _ = scopedSpan(tracer, "fragment-plan")) {
@@ -809,6 +817,7 @@ public static class SqlQueryExecutionFactory
809817
private final ExchangeManagerRegistry exchangeManagerRegistry;
810818
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
811819
private final TaskDescriptorStorage taskDescriptorStorage;
820+
private final FormatOptions formatOptions;
812821

813822
@Inject
814823
SqlQueryExecutionFactory(
@@ -841,7 +850,8 @@ public static class SqlQueryExecutionFactory
841850
SqlTaskManager coordinatorTaskManager,
842851
ExchangeManagerRegistry exchangeManagerRegistry,
843852
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
844-
TaskDescriptorStorage taskDescriptorStorage)
853+
TaskDescriptorStorage taskDescriptorStorage,
854+
FormatOptions formatOptions)
845855
{
846856
this.tracer = requireNonNull(tracer, "tracer is null");
847857
this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
@@ -875,6 +885,7 @@ public static class SqlQueryExecutionFactory
875885
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
876886
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null");
877887
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
888+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
878889
}
879890

880891
@Override
@@ -925,7 +936,8 @@ public QueryExecution createQueryExecution(
925936
coordinatorTaskManager,
926937
exchangeManagerRegistry,
927938
eventDrivenTaskSourceFactory,
928-
taskDescriptorStorage);
939+
taskDescriptorStorage,
940+
formatOptions);
929941
}
930942
}
931943
}

core/trino-main/src/main/java/io/trino/metadata/Metadata.java

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import io.trino.spi.connector.TableFunctionApplicationResult;
5555
import io.trino.spi.connector.TableScanRedirectApplicationResult;
5656
import io.trino.spi.connector.TopNApplicationResult;
57+
import io.trino.spi.connector.UnificationResult;
5758
import io.trino.spi.connector.WriterScalingOptions;
5859
import io.trino.spi.expression.ConnectorExpression;
5960
import io.trino.spi.expression.Constant;
@@ -569,6 +570,8 @@ Optional<TopNApplicationResult<TableHandle>> applyTopN(
569570

570571
Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(Session session, TableFunctionHandle handle);
571572

573+
Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second);
574+
572575
default void validateScan(Session session, TableHandle table) {}
573576

574577
//

core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java

+19
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import io.trino.spi.connector.TableFunctionApplicationResult;
8989
import io.trino.spi.connector.TableScanRedirectApplicationResult;
9090
import io.trino.spi.connector.TopNApplicationResult;
91+
import io.trino.spi.connector.UnificationResult;
9192
import io.trino.spi.connector.WriterScalingOptions;
9293
import io.trino.spi.expression.ConnectorExpression;
9394
import io.trino.spi.expression.Constant;
@@ -2157,6 +2158,24 @@ public Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(
21572158
result.getColumnHandles()));
21582159
}
21592160

2161+
@Override
2162+
public Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second)
2163+
{
2164+
CatalogHandle catalogHandle = first.catalogHandle();
2165+
ConnectorTransactionHandle transaction = first.transaction();
2166+
if (!catalogHandle.equals(second.catalogHandle()) || !transaction.equals(second.transaction())) {
2167+
return Optional.empty();
2168+
}
2169+
ConnectorMetadata metadata = getMetadata(session, catalogHandle);
2170+
2171+
return metadata.unifyTables(session.toConnectorSession(catalogHandle), first.connectorHandle(), second.connectorHandle())
2172+
.map(result -> new UnificationResult<>(
2173+
new TableHandle(catalogHandle, result.unifiedHandle(), transaction),
2174+
result.firstCompensationFilter(),
2175+
result.secondCompensationFilter(),
2176+
result.enforcedProperties()));
2177+
}
2178+
21602179
private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
21612180
{
21622181
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));

core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.util.Arrays;
4141
import java.util.List;
4242
import java.util.Map;
43+
import java.util.Optional;
4344

4445
import static io.trino.spi.StandardErrorCode.IR_ERROR;
4546
import static io.trino.sql.dialect.trino.TrinoDialect.TRINO;

core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java

+10
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ public static class ValueNameAllocator
7979
{
8080
private int label;
8181

82+
public ValueNameAllocator()
83+
{
84+
this(0);
85+
}
86+
87+
public ValueNameAllocator(int initialLabel)
88+
{
89+
this.label = initialLabel;
90+
}
91+
8292
public String newName()
8393
{
8494
return "%" + label++;

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/AggregateCall.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import io.trino.sql.dialect.trino.Attributes.SortOrderList;
2424
import io.trino.sql.newir.Block;
2525
import io.trino.sql.newir.FormatOptions;
26-
import io.trino.sql.newir.Operation;
2726
import io.trino.sql.newir.Region;
2827
import io.trino.sql.newir.Value;
2928
import io.trino.type.FunctionType;
@@ -51,7 +50,7 @@
5150
import static java.util.Objects.requireNonNull;
5251

5352
public class AggregateCall
54-
extends Operation
53+
extends TrinoOperation
5554
{
5655
private static final String NAME = "aggregate_call";
5756

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Aggregation.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
import static java.util.Objects.requireNonNull;
4949

5050
public class Aggregation
51-
extends Operation
51+
extends TrinoOperation
5252
{
5353
private static final String NAME = "aggregation";
5454

@@ -152,6 +152,18 @@ public Aggregation(
152152
this.attributes = attributes.buildOrThrow();
153153
}
154154

155+
// TODO checks
156+
private Aggregation(Result result, Value input, Region aggregateCalls, Region groupingKeysSelector, Region hashSelector, Map<AttributeKey, Object> attributes)
157+
{
158+
super(TRINO, NAME);
159+
this.result = result;
160+
this.input = input;
161+
this.aggregateCalls = aggregateCalls;
162+
this.groupingKeysSelector = groupingKeysSelector;
163+
this.hashSelector = hashSelector;
164+
this.attributes = ImmutableMap.copyOf(attributes);
165+
}
166+
155167
@Override
156168
public Result result()
157169
{
@@ -182,6 +194,15 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
182194
return "pretty aggregation";
183195
}
184196

197+
@Override
198+
public Operation withArgument(Value newArgument, int index)
199+
{
200+
// TODO this does not validate the new Regions. Should run the same checks as the constructor.
201+
// TODO missing real source attributes - it's incorrect to copy source attributes if the source changed
202+
validateArgument(newArgument, index);
203+
return new Aggregation(result, newArgument, aggregateCalls, groupingKeysSelector, this.hashSelector, this.attributes);
204+
}
205+
185206
@Override
186207
public boolean equals(Object obj)
187208
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Comparison.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import io.trino.spi.TrinoException;
1818
import io.trino.sql.dialect.trino.Attributes.ComparisonOperator;
1919
import io.trino.sql.newir.FormatOptions;
20-
import io.trino.sql.newir.Operation;
2120
import io.trino.sql.newir.Region;
2221
import io.trino.sql.newir.Value;
2322

@@ -35,7 +34,7 @@
3534
import static java.util.Objects.requireNonNull;
3635

3736
public final class Comparison
38-
extends Operation
37+
extends TrinoOperation
3938
{
4039
private static final String NAME = "comparison";
4140

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Constant.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import io.trino.spi.predicate.NullableValue;
1818
import io.trino.spi.type.Type;
1919
import io.trino.sql.newir.FormatOptions;
20-
import io.trino.sql.newir.Operation;
2120
import io.trino.sql.newir.Region;
2221
import io.trino.sql.newir.Value;
2322

@@ -31,7 +30,7 @@
3130
import static java.util.Objects.requireNonNull;
3231

3332
public final class Constant
34-
extends Operation
33+
extends TrinoOperation
3534
{
3635
private static final String NAME = "constant";
3736

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Exchange.java

+39-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
import static java.util.Objects.requireNonNull;
6363

6464
public class Exchange
65-
extends Operation
65+
extends TrinoOperation
6666
{
6767
private static final String NAME = "exchange";
6868

@@ -210,6 +210,26 @@ public Exchange(
210210
this.attributes = attributes.buildOrThrow();
211211
}
212212

213+
private Exchange(
214+
Result result,
215+
List<Value> inputs,
216+
List<Region> inputFieldSelectors,
217+
Region partitioningBoundArguments,
218+
Region partitioningHashSelector,
219+
Region orderingSelector,
220+
Map<AttributeKey, Object> attributes)
221+
{
222+
// TODO checks!!!
223+
super(TRINO, NAME);
224+
this.result = result;
225+
this.inputs = inputs;
226+
this.inputFieldSelectors = inputFieldSelectors;
227+
this.partitioningBoundArguments = partitioningBoundArguments;
228+
this.partitioningHashSelector = partitioningHashSelector;
229+
this.orderingSelector = orderingSelector;
230+
this.attributes = attributes;
231+
}
232+
213233
@Override
214234
public Result result()
215235
{
@@ -245,6 +265,24 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
245265
return "pretty exchange";
246266
}
247267

268+
@Override
269+
public Operation withArgument(Value newArgument, int index)
270+
{
271+
validateArgument(newArgument, index);
272+
ImmutableList.Builder<Value> newInputs = ImmutableList.builder();
273+
for (int i = 0; i < inputs.size(); i++) {
274+
if (i == index) {
275+
newInputs.add(newArgument);
276+
}
277+
else {
278+
newInputs.add(inputs.get(i));
279+
}
280+
}
281+
// TODO add checks in the copy constructor
282+
// TODO missing real source attributes - it's incorrect to copy source attributes if the source changed
283+
return new Exchange(result, newInputs.build(), inputFieldSelectors, partitioningBoundArguments, partitioningHashSelector, orderingSelector, attributes);
284+
}
285+
248286
@Override
249287
public boolean equals(Object obj)
250288
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldReference.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package io.trino.sql.dialect.trino.operation;
1515

1616
import com.google.common.collect.ImmutableList;
17+
import com.google.common.collect.ImmutableMap;
1718
import io.trino.spi.TrinoException;
1819
import io.trino.spi.type.RowType;
1920
import io.trino.sql.newir.FormatOptions;
@@ -34,7 +35,7 @@
3435
import static java.util.Objects.requireNonNull;
3536

3637
public final class FieldReference
37-
extends Operation
38+
extends TrinoOperation
3839
{
3940
private static final String NAME = "field_reference";
4041

@@ -96,6 +97,19 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
9697
return "pretty field reference";
9798
}
9899

100+
@Override
101+
public Operation withArgument(Value newArgument, int index)
102+
{
103+
validateArgument(newArgument, index);
104+
// TODO missing source attributes
105+
return new FieldReference(result.name(), newArgument, FIELD_INDEX.getAttribute(attributes), ImmutableMap.of());
106+
}
107+
108+
public Value base()
109+
{
110+
return base;
111+
}
112+
99113
@Override
100114
public boolean equals(Object obj)
101115
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import static java.util.Objects.requireNonNull;
3737

3838
public final class Filter
39-
extends Operation
39+
extends TrinoOperation
4040
{
4141
private static final String NAME = "filter";
4242

@@ -61,6 +61,7 @@ public Filter(String resultName, Value input, Block predicate, Map<AttributeKey,
6161

6262
this.input = input;
6363

64+
// TODO validate block labels
6465
if (predicate.parameters().size() != 1 ||
6566
!trinoType(predicate.parameters().getFirst().type()).equals(relationRowType(trinoType(input.type()))) ||
6667
!trinoType(predicate.getReturnedType()).equals(BOOLEAN)) {
@@ -102,6 +103,24 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
102103
return "pretty filter";
103104
}
104105

106+
@Override
107+
public Operation withArgument(Value newArgument, int index)
108+
{
109+
validateArgument(newArgument, index);
110+
// TODO missing source attributes
111+
return new Filter(result.name(), newArgument, predicate.getOnlyBlock(), ImmutableMap.of());
112+
}
113+
114+
public Value argument()
115+
{
116+
return input;
117+
}
118+
119+
public Block predicate()
120+
{
121+
return predicate.getOnlyBlock();
122+
}
123+
105124
@Override
106125
public boolean equals(Object obj)
107126
{

0 commit comments

Comments
 (0)