Skip to content

Commit 852f33b

Browse files
committed
New IR -- WIP
1 parent 67c941f commit 852f33b

37 files changed

+4437
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.dialect.trino;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import com.google.common.collect.ImmutableMap;
18+
import io.trino.spi.type.Type;
19+
import io.trino.sql.ir.Logical;
20+
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
import static java.util.Objects.requireNonNull;
25+
26+
public class Attributes
27+
{
28+
public static final AttributeMetadata<Long> CARDINALITY = new AttributeMetadata<>("cardinality", Long.class, true);
29+
public static final AttributeMetadata<ConstantResult> CONSTANT_RESULT = new AttributeMetadata<>("constant_result", ConstantResult.class, true);
30+
public static final AttributeMetadata<String> FIELD_NAME = new AttributeMetadata<>("field_name", String.class, false);
31+
public static final AttributeMetadata<JoinType> JOIN_TYPE = new AttributeMetadata<>("join_type", JoinType.class, false);
32+
public static final AttributeMetadata<LogicalOperator> LOGICAL_OPERATOR = new AttributeMetadata<>("logical_operator", LogicalOperator.class, false);
33+
public static final AttributeMetadata<OutputNames> OUTPUT_NAMES = new AttributeMetadata<>("output_names", OutputNames.class, false);
34+
35+
// TODO define attributes for deeply nested fields, not just top level or column level
36+
37+
private Attributes() {}
38+
39+
public static class AttributeMetadata<T>
40+
{
41+
private final String name;
42+
private final Class<T> type;
43+
private final boolean external;
44+
45+
private AttributeMetadata(String name, Class<T> type, boolean external)
46+
{
47+
this.name = requireNonNull(name, "name is null");
48+
this.type = requireNonNull(type, "type is null");
49+
this.external = external;
50+
}
51+
52+
public T getAttribute(Map<String, Object> map)
53+
{
54+
return this.type.cast(map.get(this.name));
55+
}
56+
57+
public T putAttribute(Map<String, Object> map, T attribute)
58+
{
59+
return this.type.cast(map.put(name, attribute));
60+
}
61+
62+
public Map<String, Object> asMap(T attribute)
63+
{
64+
return ImmutableMap.of(name, attribute);
65+
}
66+
}
67+
68+
public record ConstantResult(Type type, Object value)
69+
{
70+
public ConstantResult
71+
{
72+
requireNonNull(type, "type is null");
73+
}
74+
75+
@Override
76+
public String toString()
77+
{
78+
return value.toString() + ":" + type.toString();
79+
}
80+
}
81+
82+
public enum JoinType
83+
{
84+
INNER,
85+
LEFT,
86+
RIGHT,
87+
FULL;
88+
89+
public static JoinType of(io.trino.sql.planner.plan.JoinType joinType)
90+
{
91+
return switch (joinType) {
92+
case INNER -> INNER;
93+
case LEFT -> LEFT;
94+
case RIGHT -> RIGHT;
95+
case FULL -> FULL;
96+
};
97+
}
98+
}
99+
100+
public record OutputNames(List<String> outputNames)
101+
{
102+
public OutputNames(List<String> outputNames)
103+
{
104+
this.outputNames = ImmutableList.copyOf(requireNonNull(outputNames, "outputNames is null"));
105+
}
106+
107+
@Override
108+
public String toString()
109+
{
110+
return outputNames.toString();
111+
}
112+
}
113+
114+
public enum LogicalOperator
115+
{
116+
AND,
117+
OR;
118+
119+
public static LogicalOperator of(Logical.Operator operator)
120+
{
121+
return switch (operator) {
122+
case AND -> AND;
123+
case OR -> OR;
124+
};
125+
}
126+
}
127+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.dialect.trino;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import com.google.common.collect.ImmutableMap;
18+
import io.trino.sql.newir.Block;
19+
import io.trino.sql.planner.Symbol;
20+
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
25+
import static java.util.HashMap.newHashMap;
26+
import static java.util.Objects.requireNonNull;
27+
28+
public record Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
29+
{
30+
public Context(Block.Builder block)
31+
{
32+
this(block, Map.of());
33+
}
34+
35+
public Context(Block.Builder block, Map<Symbol, RowField> symbolMapping)
36+
{
37+
this.block = requireNonNull(block, "block is null");
38+
this.symbolMapping = ImmutableMap.copyOf(requireNonNull(symbolMapping, "symbolMapping is null"));
39+
}
40+
41+
public static Map<Symbol, RowField> argumentMapping(Block.Parameter parameter, Map<Symbol, String> symbolMapping)
42+
{
43+
return symbolMapping.entrySet().stream()
44+
.collect(toImmutableMap(
45+
Map.Entry::getKey,
46+
entry -> new RowField(parameter, entry.getValue())));
47+
}
48+
49+
public static Map<Symbol, RowField> composedMapping(Context context, Map<Symbol, RowField> newMapping)
50+
{
51+
return composedMapping(context, ImmutableList.of(newMapping));
52+
}
53+
54+
/**
55+
* Compose the correlated mapping from the context with symbol mappings for the current block parameters.
56+
*
57+
* @param context rewrite context containing symbol mapping from all levels of correlation
58+
* @param newMappings list of symbol mappings for current block parameters
59+
* @return composed symbol mapping to rewrite the current block
60+
*/
61+
public static Map<Symbol, RowField> composedMapping(Context context, List<Map<Symbol, RowField>> newMappings)
62+
{
63+
Map<Symbol, RowField> composed = newHashMap(context.symbolMapping().size() + newMappings.stream().mapToInt(Map::size).sum());
64+
composed.putAll(context.symbolMapping());
65+
newMappings.stream().forEach(composed::putAll);
66+
return composed;
67+
}
68+
69+
public record RowField(Block.Parameter row, String field)
70+
{
71+
public RowField
72+
{
73+
requireNonNull(row, "row is null");
74+
requireNonNull(field, "field is null");
75+
}
76+
}
77+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.dialect.trino;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import com.google.common.collect.ImmutableMap;
18+
import com.google.common.collect.Sets;
19+
import io.trino.spi.TrinoException;
20+
import io.trino.sql.dialect.trino.operation.Query;
21+
import io.trino.sql.newir.Block;
22+
import io.trino.sql.newir.Program;
23+
import io.trino.sql.newir.SourceNode;
24+
import io.trino.sql.newir.Value;
25+
import io.trino.sql.planner.plan.OutputNode;
26+
import io.trino.sql.planner.plan.PlanNode;
27+
28+
import java.util.Map;
29+
import java.util.Optional;
30+
import java.util.Set;
31+
import java.util.stream.IntStream;
32+
33+
import static com.google.common.base.Preconditions.checkArgument;
34+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
35+
import static io.trino.spi.StandardErrorCode.IR_ERROR;
36+
37+
/**
38+
* ProgramBuilder builds a MLIR program from a PlanNode tree.
39+
* For now, it builds a program for a single query, and assumes that OutputNode is the root PlanNode.
40+
* In the future, we might support multiple statements.
41+
* The resulting program has the special Query operation as the top-level operation.
42+
* It encloses all query computations in one block.
43+
*/
44+
public class ProgramBuilder
45+
{
46+
private ProgramBuilder() {}
47+
48+
public static Program buildProgram(PlanNode root)
49+
{
50+
checkArgument(root instanceof OutputNode, "Expected root to be an OutputNode. Actual: " + root.getClass().getSimpleName());
51+
52+
ValueNameAllocator nameAllocator = new ValueNameAllocator();
53+
ImmutableMap.Builder<Value, SourceNode> valueMapBuilder = ImmutableMap.builder();
54+
Block.Builder rootBlock = new Block.Builder(Optional.of("^query"), ImmutableList.of());
55+
56+
// for now, ignoring return value. Could be worth to remember it as the final terminal Operation in the Program.
57+
root.accept(new RelationalProgramBuilder(nameAllocator, valueMapBuilder), new Context(rootBlock));
58+
59+
// verify if all values are mapped
60+
Set<String> allocatedValues = IntStream.range(0, nameAllocator.label)
61+
.mapToObj(index -> "%" + index)
62+
.collect(toImmutableSet());
63+
Map<Value, SourceNode> valueMap = valueMapBuilder.buildOrThrow();
64+
Set<String> mappedValues = valueMap.keySet().stream()
65+
.map(Value::name)
66+
.collect(toImmutableSet());
67+
if (!Sets.symmetricDifference(allocatedValues, mappedValues).isEmpty()) {
68+
throw new TrinoException(IR_ERROR, "allocated values differ from mapped values");
69+
}
70+
71+
// allocating this name last to avoid stealing the "%0" label. This label won't be printed.
72+
String resultName = nameAllocator.newName();
73+
74+
return new Program(new Query(resultName, rootBlock.build()), valueMap);
75+
}
76+
77+
public static class ValueNameAllocator
78+
{
79+
private int label = 0;
80+
81+
public String newName()
82+
{
83+
return "%" + label++;
84+
}
85+
}
86+
}

0 commit comments

Comments
 (0)