Skip to content

Commit c0ab3dc

Browse files
committed
CTE reuse WIP
1 parent aa3e5ae commit c0ab3dc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+3906
-43
lines changed

core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
import io.trino.sql.analyzer.Analysis;
6262
import io.trino.sql.analyzer.Analyzer;
6363
import io.trino.sql.analyzer.AnalyzerFactory;
64+
import io.trino.sql.newir.FormatOptions;
65+
import io.trino.sql.newir.Program;
6466
import io.trino.sql.planner.AdaptivePlanner;
6567
import io.trino.sql.planner.InputExtractor;
6668
import io.trino.sql.planner.LogicalPlanner;
@@ -74,6 +76,7 @@
7476
import io.trino.sql.planner.SubPlan;
7577
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
7678
import io.trino.sql.planner.optimizations.PlanOptimizer;
79+
import io.trino.sql.planner.optimizations.ctereuse.CteReuse;
7780
import io.trino.sql.planner.plan.OutputNode;
7881
import io.trino.sql.tree.ExplainAnalyze;
7982
import io.trino.sql.tree.Query;
@@ -148,6 +151,7 @@ public class SqlQueryExecution
148151
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
149152
private final TaskDescriptorStorage taskDescriptorStorage;
150153
private final PlanOptimizersStatsCollector planOptimizersStatsCollector;
154+
private final FormatOptions formatOptions;
151155

152156
private SqlQueryExecution(
153157
PreparedQuery preparedQuery,
@@ -185,7 +189,8 @@ private SqlQueryExecution(
185189
SqlTaskManager coordinatorTaskManager,
186190
ExchangeManagerRegistry exchangeManagerRegistry,
187191
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
188-
TaskDescriptorStorage taskDescriptorStorage)
192+
TaskDescriptorStorage taskDescriptorStorage,
193+
FormatOptions formatOptions)
189194
{
190195
try (SetThreadName _ = new SetThreadName("Query-" + stateMachine.getQueryId())) {
191196
this.slug = requireNonNull(slug, "slug is null");
@@ -240,6 +245,7 @@ private SqlQueryExecution(
240245
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "taskSourceFactory is null");
241246
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
242247
this.planOptimizersStatsCollector = requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null");
248+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
243249
}
244250
}
245251

@@ -503,6 +509,9 @@ private PlanRoot doPlanQuery(CachingTableStatsProvider tableStatsProvider)
503509
Plan plan = logicalPlanner.plan(analysis);
504510
queryPlan.set(plan);
505511

512+
Optional<Program> optimizedProgram = CteReuse.reuseCommonSubqueries(plan, plannerContext, getSession(), formatOptions);
513+
checkState(optimizedProgram.isEmpty());
514+
506515
// fragment the plan
507516
SubPlan fragmentedPlan;
508517
try (var _ = scopedSpan(tracer, "fragment-plan")) {
@@ -809,6 +818,7 @@ public static class SqlQueryExecutionFactory
809818
private final ExchangeManagerRegistry exchangeManagerRegistry;
810819
private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory;
811820
private final TaskDescriptorStorage taskDescriptorStorage;
821+
private final FormatOptions formatOptions;
812822

813823
@Inject
814824
SqlQueryExecutionFactory(
@@ -841,7 +851,8 @@ public static class SqlQueryExecutionFactory
841851
SqlTaskManager coordinatorTaskManager,
842852
ExchangeManagerRegistry exchangeManagerRegistry,
843853
EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory,
844-
TaskDescriptorStorage taskDescriptorStorage)
854+
TaskDescriptorStorage taskDescriptorStorage,
855+
FormatOptions formatOptions)
845856
{
846857
this.tracer = requireNonNull(tracer, "tracer is null");
847858
this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
@@ -875,6 +886,7 @@ public static class SqlQueryExecutionFactory
875886
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
876887
this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null");
877888
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
889+
this.formatOptions = requireNonNull(formatOptions, "formatOptions is null");
878890
}
879891

880892
@Override
@@ -925,7 +937,8 @@ public QueryExecution createQueryExecution(
925937
coordinatorTaskManager,
926938
exchangeManagerRegistry,
927939
eventDrivenTaskSourceFactory,
928-
taskDescriptorStorage);
940+
taskDescriptorStorage,
941+
formatOptions);
929942
}
930943
}
931944
}

core/trino-main/src/main/java/io/trino/metadata/Metadata.java

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import io.trino.spi.connector.TableFunctionApplicationResult;
5555
import io.trino.spi.connector.TableScanRedirectApplicationResult;
5656
import io.trino.spi.connector.TopNApplicationResult;
57+
import io.trino.spi.connector.UnificationResult;
5758
import io.trino.spi.connector.WriterScalingOptions;
5859
import io.trino.spi.expression.ConnectorExpression;
5960
import io.trino.spi.expression.Constant;
@@ -569,6 +570,8 @@ Optional<TopNApplicationResult<TableHandle>> applyTopN(
569570

570571
Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(Session session, TableFunctionHandle handle);
571572

573+
Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second);
574+
572575
default void validateScan(Session session, TableHandle table) {}
573576

574577
//

core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java

+19
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import io.trino.spi.connector.TableFunctionApplicationResult;
8989
import io.trino.spi.connector.TableScanRedirectApplicationResult;
9090
import io.trino.spi.connector.TopNApplicationResult;
91+
import io.trino.spi.connector.UnificationResult;
9192
import io.trino.spi.connector.WriterScalingOptions;
9293
import io.trino.spi.expression.ConnectorExpression;
9394
import io.trino.spi.expression.Constant;
@@ -2157,6 +2158,24 @@ public Optional<TableFunctionApplicationResult<TableHandle>> applyTableFunction(
21572158
result.getColumnHandles()));
21582159
}
21592160

2161+
@Override
2162+
public Optional<UnificationResult<TableHandle>> unifyTables(Session session, TableHandle first, TableHandle second)
2163+
{
2164+
CatalogHandle catalogHandle = first.catalogHandle();
2165+
ConnectorTransactionHandle transaction = first.transaction();
2166+
if (!catalogHandle.equals(second.catalogHandle()) || !transaction.equals(second.transaction())) {
2167+
return Optional.empty();
2168+
}
2169+
ConnectorMetadata metadata = getMetadata(session, catalogHandle);
2170+
2171+
return metadata.unifyTables(session.toConnectorSession(catalogHandle), first.connectorHandle(), second.connectorHandle())
2172+
.map(result -> new UnificationResult<>(
2173+
new TableHandle(catalogHandle, result.unifiedHandle(), transaction),
2174+
result.firstCompensationFilter(),
2175+
result.secondCompensationFilter(),
2176+
result.enforcedProperties()));
2177+
}
2178+
21602179
private void verifyProjection(TableHandle table, List<ConnectorExpression> projections, List<Assignment> assignments, int expectedProjectionSize)
21612180
{
21622181
projections.forEach(projection -> requireNonNull(projection, "one of the projections is null"));

core/trino-main/src/main/java/io/trino/sql/dialect/trino/ProgramBuilder.java

+10
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ public static class ValueNameAllocator
7979
{
8080
private int label;
8181

82+
public ValueNameAllocator()
83+
{
84+
this(0);
85+
}
86+
87+
public ValueNameAllocator(int initialLabel)
88+
{
89+
this.label = initialLabel;
90+
}
91+
8292
public String newName()
8393
{
8494
return "%" + label++;

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/AggregateCall.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import static io.trino.spi.StandardErrorCode.IR_ERROR;
3737
import static io.trino.spi.type.EmptyRowType.EMPTY_ROW;
38+
import static io.trino.sql.dialect.trino.Attributes.AGGREGATION_STEP;
3839
import static io.trino.sql.dialect.trino.Attributes.AggregationStep.FINAL;
3940
import static io.trino.sql.dialect.trino.Attributes.AggregationStep.PARTIAL;
4041
import static io.trino.sql.dialect.trino.Attributes.AggregationStep.SINGLE;
@@ -51,7 +52,7 @@
5152
import static java.util.Objects.requireNonNull;
5253

5354
public class AggregateCall
54-
extends Operation
55+
extends TrinoOperation
5556
{
5657
private static final String NAME = "aggregate_call";
5758

@@ -221,6 +222,24 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
221222
return "pretty aggregate call";
222223
}
223224

225+
@Override
226+
public Operation withArgument(Value newArgument, int index)
227+
{
228+
validateArgument(newArgument, index);
229+
return new AggregateCall(
230+
result.name(),
231+
newArgument,
232+
trinoType(result.type()),
233+
arguments.getOnlyBlock(),
234+
filterSelector.getOnlyBlock(),
235+
maskSelector.getOnlyBlock(),
236+
orderingSelector.getOnlyBlock(),
237+
Optional.ofNullable(SORT_ORDERS.getAttribute(attributes)),
238+
RESOLVED_FUNCTION.getAttribute(attributes),
239+
DISTINCT.getAttribute(attributes),
240+
AGGREGATION_STEP.getAttribute(attributes));
241+
}
242+
224243
@Override
225244
public boolean equals(Object obj)
226245
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Aggregation.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.List;
3030
import java.util.Map;
3131
import java.util.Objects;
32+
import java.util.Optional;
3233
import java.util.OptionalInt;
3334

3435
import static io.trino.spi.StandardErrorCode.IR_ERROR;
@@ -48,7 +49,7 @@
4849
import static java.util.Objects.requireNonNull;
4950

5051
public class Aggregation
51-
extends Operation
52+
extends TrinoOperation
5253
{
5354
private static final String NAME = "aggregation";
5455

@@ -182,6 +183,25 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
182183
return "pretty aggregation";
183184
}
184185

186+
@Override
187+
public Operation withArgument(Value newArgument, int index)
188+
{
189+
validateArgument(newArgument, index);
190+
return new Aggregation(
191+
result.name(),
192+
newArgument,
193+
aggregateCalls.getOnlyBlock(),
194+
groupingKeysSelector.getOnlyBlock(),
195+
hashSelector.getOnlyBlock(),
196+
GROUPING_SETS_COUNT.getAttribute(attributes),
197+
GLOBAL_GROUPING_SETS.getAttribute(attributes),
198+
Optional.ofNullable(GROUP_ID_INDEX.getAttribute(attributes)).map(OptionalInt::of).orElse(OptionalInt.empty()),
199+
PRE_GROUPED_INDEXES.getAttribute(attributes),
200+
AGGREGATION_STEP.getAttribute(attributes),
201+
INPUT_REDUCING.getAttribute(attributes),
202+
ImmutableMap.of());
203+
}
204+
185205
@Override
186206
public boolean equals(Object obj)
187207
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Array.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.trino.sql.newir.Region;
2424
import io.trino.sql.newir.Value;
2525

26+
import java.util.ArrayList;
2627
import java.util.List;
2728
import java.util.Map;
2829
import java.util.Objects;
@@ -35,7 +36,7 @@
3536
import static java.util.Objects.requireNonNull;
3637

3738
public final class Array
38-
extends Operation
39+
extends TrinoOperation
3940
{
4041
private static final String NAME = "array";
4142

@@ -95,6 +96,19 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
9596
return "array :)";
9697
}
9798

99+
@Override
100+
public Operation withArgument(Value newArgument, int index)
101+
{
102+
validateArgument(newArgument, index);
103+
List<Value> newArguments = new ArrayList<>(elements);
104+
newArguments.set(index, newArgument);
105+
return new Array(
106+
result.name(),
107+
((ArrayType) trinoType(result.type())).getElementType(),
108+
newArguments,
109+
ImmutableList.of());
110+
}
111+
98112
@Override
99113
public boolean equals(Object obj)
100114
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Between.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import static java.util.Objects.requireNonNull;
3535

3636
public final class Between
37-
extends Operation
37+
extends TrinoOperation
3838
{
3939
private static final String NAME = "between";
4040

@@ -100,6 +100,18 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
100100
return "pretty between";
101101
}
102102

103+
@Override
104+
public Operation withArgument(Value newArgument, int index)
105+
{
106+
validateArgument(newArgument, index);
107+
return new Between(
108+
result.name(),
109+
index == 0 ? newArgument : input,
110+
index == 1 ? newArgument : min,
111+
index == 2 ? newArgument : max,
112+
ImmutableList.of());
113+
}
114+
103115
@Override
104116
public boolean equals(Object obj)
105117
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Bind.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.trino.sql.newir.Value;
2424
import io.trino.type.FunctionType;
2525

26+
import java.util.ArrayList;
2627
import java.util.List;
2728
import java.util.Map;
2829
import java.util.Objects;
@@ -35,7 +36,7 @@
3536
import static java.util.Objects.requireNonNull;
3637

3738
public final class Bind
38-
extends Operation
39+
extends TrinoOperation
3940
{
4041
private static final String NAME = "bind";
4142

@@ -106,6 +107,21 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
106107
return "bind :)";
107108
}
108109

110+
@Override
111+
public Operation withArgument(Value newArgument, int index)
112+
{
113+
validateArgument(newArgument, index);
114+
List<Value> newValues = new ArrayList<>(values);
115+
if (index < values.size()) {
116+
newValues.set(index, newArgument);
117+
}
118+
return new Bind(
119+
result.name(),
120+
newValues,
121+
index == values.size() ? newArgument : lambda,
122+
ImmutableList.of());
123+
}
124+
109125
@Override
110126
public boolean equals(Object obj)
111127
{

core/trino-main/src/main/java/io/trino/sql/dialect/trino/operation/Call.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.trino.sql.newir.Region;
2323
import io.trino.sql.newir.Value;
2424

25+
import java.util.ArrayList;
2526
import java.util.List;
2627
import java.util.Map;
2728
import java.util.Objects;
@@ -35,7 +36,7 @@
3536
import static java.util.Objects.requireNonNull;
3637

3738
public final class Call
38-
extends Operation
39+
extends TrinoOperation
3940
{
4041
private static final String NAME = "call";
4142

@@ -100,6 +101,19 @@ public String prettyPrint(int indentLevel, FormatOptions formatOptions)
100101
return "call :)";
101102
}
102103

104+
@Override
105+
public Operation withArgument(Value newArgument, int index)
106+
{
107+
validateArgument(newArgument, index);
108+
List<Value> newArguments = new ArrayList<>(arguments);
109+
newArguments.set(index, newArgument);
110+
return new Call(
111+
result.name(),
112+
newArguments,
113+
RESOLVED_FUNCTION.getAttribute(attributes),
114+
ImmutableList.of());
115+
}
116+
103117
@Override
104118
public boolean equals(Object obj)
105119
{

0 commit comments

Comments
 (0)