Skip to content

Commit 73ab98e

Browse files
committed
Add map literal syntax
1 parent 8d2c8bb commit 73ab98e

File tree

11 files changed

+258
-0
lines changed

11 files changed

+258
-0
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

+5
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ primaryExpression
575575
| POSITION '(' valueExpression IN valueExpression ')' #position
576576
| '(' expression (',' expression)+ ')' #rowConstructor
577577
| ROW '(' expression (',' expression)* ')' #rowConstructor
578+
| '{' (mapEntry (',' mapEntry)*)? '}' #mapConstructor
578579
| name=LISTAGG '(' setQuantifier? expression (',' string)?
579580
(ON OVERFLOW listAggOverflowBehavior)? ')'
580581
(WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')')
@@ -646,6 +647,10 @@ primaryExpression
646647
')' #jsonArray
647648
;
648649

650+
mapEntry
651+
: key=expression ':' value=expression
652+
;
653+
649654
jsonPathInvocation
650655
: jsonValueExpression ',' path=string
651656
(AS pathName=identifier)?

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

+11
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import io.trino.spi.type.DecimalParseResult;
4242
import io.trino.spi.type.DecimalType;
4343
import io.trino.spi.type.Decimals;
44+
import io.trino.spi.type.MapType;
4445
import io.trino.spi.type.RowType;
4546
import io.trino.spi.type.TimeType;
4647
import io.trino.spi.type.TimeWithTimeZoneType;
@@ -119,6 +120,7 @@
119120
import io.trino.sql.tree.LocalTimestamp;
120121
import io.trino.sql.tree.LogicalExpression;
121122
import io.trino.sql.tree.LongLiteral;
123+
import io.trino.sql.tree.MapLiteral;
122124
import io.trino.sql.tree.MeasureDefinition;
123125
import io.trino.sql.tree.Node;
124126
import io.trino.sql.tree.NodeRef;
@@ -1099,6 +1101,15 @@ protected Type visitSubscriptExpression(SubscriptExpression node, Context contex
10991101
return getOperator(context, node, SUBSCRIPT, node.getBase(), node.getIndex());
11001102
}
11011103

1104+
@Override
1105+
protected Type visitMapLiteral(MapLiteral node, Context context)
1106+
{
1107+
Type keyType = coerceToSingleType(context, "All MAP keys", node.getKeys());
1108+
Type valueType = coerceToSingleType(context, "All MAP values", node.getValues());
1109+
MapType mapType = new MapType(keyType, valueType, plannerContext.getTypeOperators());
1110+
return setExpressionType(node, mapType);
1111+
}
1112+
11021113
@Override
11031114
protected Type visitArray(Array node, Context context)
11041115
{

core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java

+23
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.trino.spi.type.DecimalParseResult;
2828
import io.trino.spi.type.DecimalType;
2929
import io.trino.spi.type.Decimals;
30+
import io.trino.spi.type.MapType;
3031
import io.trino.spi.type.RowType;
3132
import io.trino.spi.type.TimeType;
3233
import io.trino.spi.type.TimeWithTimeZoneType;
@@ -101,6 +102,8 @@
101102
import io.trino.sql.tree.LocalTimestamp;
102103
import io.trino.sql.tree.LogicalExpression;
103104
import io.trino.sql.tree.LongLiteral;
105+
import io.trino.sql.tree.MapLiteral;
106+
import io.trino.sql.tree.MapLiteral.EntryLiteral;
104107
import io.trino.sql.tree.NodeRef;
105108
import io.trino.sql.tree.NotExpression;
106109
import io.trino.sql.tree.NullIfExpression;
@@ -119,6 +122,7 @@
119122
import io.trino.type.JsonPath2016Type;
120123
import io.trino.type.UnknownType;
121124

125+
import java.util.ArrayList;
122126
import java.util.Arrays;
123127
import java.util.Collections;
124128
import java.util.HashMap;
@@ -134,6 +138,7 @@
134138
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
135139
import static io.trino.spi.type.BooleanType.BOOLEAN;
136140
import static io.trino.spi.type.DoubleType.DOUBLE;
141+
import static io.trino.spi.type.RowType.anonymousRow;
137142
import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType;
138143
import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType;
139144
import static io.trino.spi.type.TinyintType.TINYINT;
@@ -318,6 +323,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot)
318323
case FunctionCall expression -> translate(expression);
319324
case DereferenceExpression expression -> translate(expression);
320325
case Array expression -> translate(expression);
326+
case MapLiteral expression -> translate(expression);
321327
case CurrentCatalog expression -> translate(expression);
322328
case CurrentSchema expression -> translate(expression);
323329
case CurrentPath expression -> translate(expression);
@@ -687,6 +693,23 @@ private io.trino.sql.ir.Expression translate(DereferenceExpression expression)
687693
return new FieldReference(translateExpression(expression.getBase()), index);
688694
}
689695

696+
private io.trino.sql.ir.Expression translate(MapLiteral expression)
697+
{
698+
MapType mapType = (MapType) analysis.getType(expression);
699+
700+
// convert entries to array(row(key, value))
701+
List<io.trino.sql.ir.Expression> entries = new ArrayList<>();
702+
for (EntryLiteral entry : expression.getEntries()) {
703+
entries.add(new io.trino.sql.ir.Row(ImmutableList.of(translateExpression(entry.key()), translateExpression(entry.value()))));
704+
}
705+
io.trino.sql.ir.Array array = new io.trino.sql.ir.Array(anonymousRow(mapType.getKeyType(), mapType.getValueType()), entries);
706+
707+
return BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata())
708+
.setName("map_from_entries")
709+
.addArgument(array.type(), array)
710+
.build();
711+
}
712+
690713
private io.trino.sql.ir.Expression translate(Array expression)
691714
{
692715
List<io.trino.sql.ir.Expression> values = expression.getValues().stream()

core/trino-main/src/test/java/io/trino/type/TestMapOperators.java

+33
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,39 @@ public void testMapToMapCast()
16121612
.hasErrorCode(INVALID_CAST_ARGUMENT);
16131613
}
16141614

1615+
@Test
1616+
public void testMapLiteral()
1617+
{
1618+
assertThat(assertions.expression("{}"))
1619+
.hasType(mapType(UNKNOWN, UNKNOWN))
1620+
.isEqualTo(ImmutableMap.of());
1621+
1622+
assertThat(assertions.expression("{'x' : 1, 'y' : 2}"))
1623+
.hasType(mapType(createVarcharType(1), INTEGER))
1624+
.isEqualTo(ImmutableMap.of("x", 1, "y", 2));
1625+
1626+
assertThat(assertions.expression("{{'a' : 1, 'b' : 2} : 11, {'c' : 3, 'd' : 4} : 22}"))
1627+
.hasType(mapType(mapType(createVarcharType(1), INTEGER), INTEGER))
1628+
.isEqualTo(ImmutableMap.of(ImmutableMap.of("a", 1, "b", 2), 11, ImmutableMap.of("c", 3, "d", 4), 22));
1629+
1630+
Map<String, Integer> expectedNullValueMap = new HashMap<>();
1631+
expectedNullValueMap.put("x", 1);
1632+
expectedNullValueMap.put("y", null);
1633+
assertThat(assertions.expression("{'x' : 1, 'y' : null}"))
1634+
.hasType(mapType(createVarcharType(1), INTEGER))
1635+
.isEqualTo(expectedNullValueMap);
1636+
1637+
// invalid invocation
1638+
assertTrinoExceptionThrownBy(assertions.expression("{'a' : 1, 'a' : 2}")::evaluate)
1639+
.hasMessage("Duplicate map keys (a) are not allowed");
1640+
1641+
assertTrinoExceptionThrownBy(assertions.expression("{1 : 'a', 1 : 'b'}")::evaluate)
1642+
.hasMessage("Duplicate map keys (1) are not allowed");
1643+
1644+
assertTrinoExceptionThrownBy(assertions.expression("{'a' : 1, null : 2}")::evaluate)
1645+
.hasMessage("map key cannot be null");
1646+
}
1647+
16151648
@Test
16161649
public void testMapFromEntries()
16171650
{

core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java

+14
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
import io.trino.sql.tree.LocalTimestamp;
7575
import io.trino.sql.tree.LogicalExpression;
7676
import io.trino.sql.tree.LongLiteral;
77+
import io.trino.sql.tree.MapLiteral;
7778
import io.trino.sql.tree.Node;
7879
import io.trino.sql.tree.NotExpression;
7980
import io.trino.sql.tree.NullIfExpression;
@@ -303,6 +304,19 @@ protected String visitAllRows(AllRows node, Void context)
303304
return "ALL";
304305
}
305306

307+
@Override
308+
protected String visitMapLiteral(MapLiteral node, Void context)
309+
{
310+
return node.getEntries().stream()
311+
.map(Formatter::formatMapEntry)
312+
.collect(joining(", ", "{", "}"));
313+
}
314+
315+
private static String formatMapEntry(MapLiteral.EntryLiteral entry)
316+
{
317+
return formatSql(entry.key()) + " : " + formatSql(entry.value());
318+
}
319+
306320
@Override
307321
protected String visitArray(Array node, Void context)
308322
{

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

+12
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@
167167
import io.trino.sql.tree.LogicalExpression;
168168
import io.trino.sql.tree.LongLiteral;
169169
import io.trino.sql.tree.LoopStatement;
170+
import io.trino.sql.tree.MapLiteral;
171+
import io.trino.sql.tree.MapLiteral.EntryLiteral;
170172
import io.trino.sql.tree.MeasureDefinition;
171173
import io.trino.sql.tree.Merge;
172174
import io.trino.sql.tree.MergeCase;
@@ -2377,6 +2379,16 @@ public Node visitRowConstructor(SqlBaseParser.RowConstructorContext context)
23772379
return new Row(getLocation(context), visit(context.expression(), Expression.class));
23782380
}
23792381

2382+
@Override
2383+
public Node visitMapConstructor(SqlBaseParser.MapConstructorContext context)
2384+
{
2385+
List<EntryLiteral> entries = new ArrayList<>();
2386+
for (SqlBaseParser.MapEntryContext mapEntry : context.mapEntry()) {
2387+
entries.add(new EntryLiteral((Expression) visit(mapEntry.key), (Expression) visit(mapEntry.value)));
2388+
}
2389+
return new MapLiteral(getLocation(context), entries);
2390+
}
2391+
23802392
@Override
23812393
public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context)
23822394
{

core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java

+5
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,11 @@ protected R visitIsNullPredicate(IsNullPredicate node, C context)
437437
return visitExpression(node, context);
438438
}
439439

440+
protected R visitMapLiteral(MapLiteral node, C context)
441+
{
442+
return visitExpression(node, context);
443+
}
444+
440445
protected R visitArray(Array node, C context)
441446
{
442447
return visitExpression(node, context);

core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java

+5
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ public Expression rewriteLiteral(Literal node, C context, ExpressionTreeRewriter
140140
return rewriteExpression(node, context, treeRewriter);
141141
}
142142

143+
public Expression rewriteMapLiteral(MapLiteral node, C context, ExpressionTreeRewriter<C> treeRewriter)
144+
{
145+
return rewriteExpression(node, context, treeRewriter);
146+
}
147+
143148
public Expression rewriteArray(Array node, C context, ExpressionTreeRewriter<C> treeRewriter)
144149
{
145150
return rewriteExpression(node, context, treeRewriter);

core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java

+31
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,37 @@ public Expression visitArithmeticBinary(ArithmeticBinaryExpression node, Context
135135
return node;
136136
}
137137

138+
@Override
139+
protected Expression visitMapLiteral(MapLiteral node, Context<C> context)
140+
{
141+
if (!context.isDefaultRewrite()) {
142+
Expression result = rewriter.rewriteMapLiteral(node, context.get(), ExpressionTreeRewriter.this);
143+
if (result != null) {
144+
return result;
145+
}
146+
}
147+
148+
boolean changed = false;
149+
ImmutableList.Builder<MapLiteral.EntryLiteral> entries = ImmutableList.builder();
150+
for (MapLiteral.EntryLiteral entry : node.getEntries()) {
151+
Expression key = rewrite(entry.key(), context.get());
152+
Expression value = rewrite(entry.value(), context.get());
153+
if (entry.key() != key || entry.value() != value) {
154+
entries.add(new MapLiteral.EntryLiteral(key, value));
155+
changed = true;
156+
}
157+
else {
158+
entries.add(entry);
159+
}
160+
}
161+
162+
if (changed) {
163+
return new MapLiteral(node.getLocation().orElseThrow(), entries.build());
164+
}
165+
166+
return node;
167+
}
168+
138169
@Override
139170
protected Expression visitArray(Array node, Context<C> context)
140171
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.tree;
15+
16+
import com.google.common.collect.ImmutableList;
17+
18+
import java.util.List;
19+
import java.util.stream.Stream;
20+
21+
import static com.google.common.collect.ImmutableList.toImmutableList;
22+
import static java.util.Objects.requireNonNull;
23+
24+
public final class MapLiteral
25+
extends Expression
26+
{
27+
private final List<EntryLiteral> entries;
28+
29+
public MapLiteral(NodeLocation location, List<EntryLiteral> entries)
30+
{
31+
super(location);
32+
this.entries = ImmutableList.copyOf(requireNonNull(entries, "entries is null"));
33+
}
34+
35+
public List<EntryLiteral> getEntries()
36+
{
37+
return entries;
38+
}
39+
40+
public List<Expression> getKeys()
41+
{
42+
return entries.stream()
43+
.map(EntryLiteral::key)
44+
.collect(toImmutableList());
45+
}
46+
47+
public List<Expression> getValues()
48+
{
49+
return entries.stream()
50+
.map(EntryLiteral::value)
51+
.collect(toImmutableList());
52+
}
53+
54+
@Override
55+
public <R, C> R accept(AstVisitor<R, C> visitor, C context)
56+
{
57+
return visitor.visitMapLiteral(this, context);
58+
}
59+
60+
@Override
61+
public List<Node> getChildren()
62+
{
63+
return entries.stream()
64+
.flatMap(entry -> Stream.of(entry.key(), entry.value()))
65+
.collect(toImmutableList());
66+
}
67+
68+
@Override
69+
public boolean equals(Object o)
70+
{
71+
if (this == o) {
72+
return true;
73+
}
74+
return (o != null) && (getClass() == o.getClass());
75+
}
76+
77+
@Override
78+
public int hashCode()
79+
{
80+
return MapLiteral.class.hashCode();
81+
}
82+
83+
@Override
84+
public boolean shallowEquals(Node other)
85+
{
86+
return sameClass(this, other);
87+
}
88+
89+
public record EntryLiteral(Expression key, Expression value)
90+
{
91+
public EntryLiteral
92+
{
93+
requireNonNull(key, "key is null");
94+
requireNonNull(value, "value is null");
95+
}
96+
97+
@Override
98+
public String toString()
99+
{
100+
return key + " => " + value;
101+
}
102+
}
103+
}

0 commit comments

Comments
 (0)