Skip to content

Commit 6a1228b

Browse files
committed
CTE reuse WIP
1 parent 40aeed5 commit 6a1228b

File tree

21 files changed

+1501
-4
lines changed

21 files changed

+1501
-4
lines changed

core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
import io.trino.sql.analyzer.Analysis;
6262
import io.trino.sql.analyzer.Analyzer;
6363
import io.trino.sql.analyzer.AnalyzerFactory;
64+
import io.trino.sql.newir.FormatOptions;
65+
import io.trino.sql.newir.Program;
6466
import io.trino.sql.planner.AdaptivePlanner;
6567
import io.trino.sql.planner.InputExtractor;
6668
import io.trino.sql.planner.LogicalPlanner;
@@ -73,6 +75,7 @@
7375
import io.trino.sql.planner.SplitSourceFactory;
7476
import io.trino.sql.planner.SubPlan;
7577
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
78+
import io.trino.sql.planner.optimizations.CteReuse;
7679
import io.trino.sql.planner.optimizations.PlanOptimizer;
7780
import io.trino.sql.planner.plan.OutputNode;
7881
import io.trino.sql.tree.ExplainAnalyze;
@@ -148,6 +151,7 @@ public class SqlQueryExecution
148151
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
149152
private final TaskDescriptorStorage taskDescriptorStorage;
150153
private final PlanOptimizersStatsCollector planOptimizersStatsCollector;
154+
private final FormatOptions formatOptions;
151155

152156
private SqlQueryExecution(
153157
PreparedQuery preparedQuery,
@@ -185,7 +189,8 @@ private SqlQueryExecution(
185189
SqlTaskManager coordinatorTaskManager,
186190
ExchangeManagerRegistry exchangeManagerRegistry,
187191
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
188-
TaskDescriptorStorage taskDescriptorStorage)
192+
TaskDescriptorStorage taskDescriptorStorage,
193+
FormatOptions formatOptions)
189194
{
190195
try (SetThreadName _ = new SetThreadName("Query-" + stateMachine.getQueryId())) {
191196
this.slug = requireNonNull(slug, "slug is null");
@@ -240,6 +245,7 @@ private SqlQueryExecution(
240245
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "taskSourceFactory is null");
241246
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
242247
this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null");
248+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
243249
}
244250
}
245251

@@ -503,6 +509,8 @@ private PlanRoot doPlanQuery(CachingTableStatsProvider tableStatsProvider)
503509
Plan plan = logicalPlanner.plan(analysis);
504510
queryPlan.set(plan);
505511

512+
Optional<Program> optimizedProgram = CteReuse.reuseCommonSubqueries(plan, plannerContext, getSession());
513+
506514
// fragment the plan
507515
SubPlan fragmentedPlan;
508516
try (var _ = scopedSpan(tracer, "fragment-plan")) {
@@ -809,6 +817,7 @@ public static class SqlQueryExecutionFactory
809817
private final ExchangeManagerRegistry exchangeManagerRegistry;
810818
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
811819
private final TaskDescriptorStorage taskDescriptorStorage;
820+
private final FormatOptions formatOptions;
812821

813822
@Inject
814823
SqlQueryExecutionFactory(
@@ -841,7 +850,8 @@ public static class SqlQueryExecutionFactory
841850
SqlTaskManager coordinatorTaskManager,
842851
ExchangeManagerRegistry exchangeManagerRegistry,
843852
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
844-
TaskDescriptorStorage taskDescriptorStorage)
853+
TaskDescriptorStorage taskDescriptorStorage,
854+
FormatOptions formatOptions)
845855
{
846856
this.tracer = requireNonNull(tracer, "tracer is null");
847857
this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
@@ -875,6 +885,7 @@ public static class SqlQueryExecutionFactory
875885
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
876886
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null");
877887
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
888+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
878889
}
879890

880891
@Override
@@ -925,7 +936,8 @@ public QueryExecution createQueryExecution(
925936
coordinatorTaskManager,
926937
exchangeManagerRegistry,
927938
eventDrivenTaskSourceFactory,
928-
taskDescriptorStorage);
939+
taskDescriptorStorage,
940+
formatOptions);
929941
}
930942
}
931943
}

core/trino-main/src/main/java/io/trino/metadata/Metadata.java

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import io.trino.spi.connector.TableFunctionApplicationResult;
5555
import io.trino.spi.connector.TableScanRedirectApplicationResult;
5656
import io.trino.spi.connector.TopNApplicationResult;
57+
import io.trino.spi.connector.UnificationResult;
5758
import io.trino.spi.connector.WriterScalingOptions;
5859
import io.trino.spi.expression.ConnectorExpression;
5960
import io.trino.spi.expression.Constant;
@@ -569,6 +570,8 @@ Optional<TopNApplicationResult<TableHandle>> applyTopN(
569570

570571
Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(Session session, TableFunctionHandle handle);
571572

573+
Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second);
574+
572575
default void validateScan(Session session, TableHandle table) {}
573576

574577
//

core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java

+19
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import io.trino.spi.connector.TableFunctionApplicationResult;
8989
import io.trino.spi.connector.TableScanRedirectApplicationResult;
9090
import io.trino.spi.connector.TopNApplicationResult;
91+
import io.trino.spi.connector.UnificationResult;
9192
import io.trino.spi.connector.WriterScalingOptions;
9293
import io.trino.spi.expression.ConnectorExpression;
9394
import io.trino.spi.expression.Constant;
@@ -2157,6 +2158,24 @@ public Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(
21572158
result.getColumnHandles()));
21582159
}
21592160

2161+
@Override
2162+
public Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second)
2163+
{
2164+
CatalogHandle catalogHandle = first.catalogHandle();
2165+
ConnectorTransactionHandle transaction = first.transaction();
2166+
if (!catalogHandle.equals(second.catalogHandle()) || !transaction.equals(second.transaction())) {
2167+
return Optional.empty();
2168+
}
2169+
ConnectorMetadata metadata = getMetadata(session, catalogHandle);
2170+
2171+
return metadata.unifyTables(session.toConnectorSession(catalogHandle), first.connectorHandle(), second.connectorHandle())
2172+
.map(result -> new UnificationResult<>(
2173+
new TableHandle(catalogHandle, result.unifiedHandle(), transaction),
2174+
result.firstCompensation(),
2175+
result.secondCompensation(),
2176+
result.enforcedProperties()));
2177+
}
2178+
21602179
private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
21612180
{
21622181
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));

core/trino-main/src/main/java/io/trino/sql/dialect/trino/Attributes.java

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.util.Arrays;
4141
import java.util.List;
4242
import java.util.Map;
43+
import java.util.Optional;
4344

4445
import static io.trino.spi.StandardErrorCode.IR_ERROR;
4546
import static io.trino.sql.dialect.trino.TrinoDialect.TRINO;

core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java

+10
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ public static class ValueNameAllocator
7979
{
8080
private int label;
8181

82+
public ValueNameAllocator()
83+
{
84+
this(0);
85+
}
86+
87+
public ValueNameAllocator(int initialLabel)
88+
{
89+
this.label = initialLabel;
90+
}
91+
8292
public String newName()
8393
{
8494
return "%" + label++;

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldReference.java

+5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
9595
return "pretty field reference";
9696
}
9797

98+
public Value base()
99+
{
100+
return base;
101+
}
102+
98103
@Override
99104
public boolean equals(Object obj)
100105
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/FieldSelection.java

+5
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
110110
return "pretty field selection";
111111
}
112112

113+
public Value base()
114+
{
115+
return base;
116+
}
117+
113118
@Override
114119
public boolean equals(Object obj)
115120
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Filter.java

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
102102
return "pretty filter";
103103
}
104104

105+
public Value argument()
106+
{
107+
return input;
108+
}
109+
105110
@Override
106111
public boolean equals(Object obj)
107112
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Project.java

+24
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
import io.trino.sql.newir.FormatOptions;
2424
import io.trino.sql.newir.Operation;
2525
import io.trino.sql.newir.Region;
26+
import io.trino.sql.newir.SourceNode;
2627
import io.trino.sql.newir.Value;
2728

2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Objects;
3132

33+
import static com.google.common.collect.Iterables.getOnlyElement;
3234
import static io.trino.spi.StandardErrorCode.IR_ERROR;
3335
import static io.trino.spi.type.EmptyRowType.EMPTY_ROW;
3436
import static io.trino.sql.dialect.trino.RelationalProgramBuilder.assignRelationRowTypeFieldNames;
@@ -112,6 +114,28 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
112114
return "pretty project";
113115
}
114116

117+
public Block assignments()
118+
{
119+
return assignments.getOnlyBlock();
120+
}
121+
122+
public boolean isPruning(Map<Value, SourceNode> valueMap)
123+
{
124+
if (relationRowType(trinoType(result.type())).equals(EMPTY_ROW)) {
125+
// prunes all fields
126+
return true;
127+
}
128+
Block assignments = assignments();
129+
Row rowConstructor = (Row) valueMap.get(((Return) assignments.getTerminalOperation()).argument());
130+
for (Value rowElement : rowConstructor.arguments()) {
131+
SourceNode assignment = valueMap.get(rowElement);
132+
if (!(assignment instanceof FieldSelection fieldSelection) || !valueMap.get(fieldSelection.base()).equals(assignments)) {
133+
return false;
134+
}
135+
}
136+
return true;
137+
}
138+
115139
@Override
116140
public boolean equals(Object obj)
117141
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Return.java

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
8080
return "pretty return";
8181
}
8282

83+
public Value argument()
84+
{
85+
return input;
86+
}
87+
8388
@Override
8489
public boolean equals(Object obj)
8590
{

core/trino-main/src/main/java/io/trino/sql/newir/Region.java

+10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.List;
2020

21+
import static com.google.common.collect.Iterables.getOnlyElement;
2122
import static io.trino.spi.StandardErrorCode.IR_ERROR;
2223
import static io.trino.sql.newir.FormatOptions.INDENT;
2324
import static java.util.Objects.requireNonNull;
@@ -47,6 +48,15 @@ public static Region singleBlockRegion(Block block)
4748
return new Region(ImmutableList.of(block));
4849
}
4950

51+
public Block getOnlyBlock()
52+
{
53+
if (blocks().size() != 1) {
54+
throw new TrinoException(IR_ERROR, "expected 1 block, actual: " + blocks.size());
55+
}
56+
57+
return getOnlyElement(blocks());
58+
}
59+
5060
public String print(int version, int indentLevel, FormatOptions formatOptions)
5161
{
5262
String indent = INDENT.repeat(indentLevel);

0 commit comments

Comments
 (0)