Skip to content

Add map literal syntax #25276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ primaryExpression
| POSITION '(' valueExpression IN valueExpression ')' #position
| '(' expression (',' expression)+ ')' #rowConstructor
| ROW '(' expression (',' expression)* ')' #rowConstructor
| '{' (mapEntry (',' mapEntry)*)? '}' #mapConstructor
| name=LISTAGG '(' setQuantifier? expression (',' string)?
(ON OVERFLOW listAggOverflowBehavior)? ')'
(WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')')
Expand Down Expand Up @@ -646,6 +647,10 @@ primaryExpression
')' #jsonArray
;

mapEntry
: key=expression ':' value=expression
;

jsonPathInvocation
: jsonValueExpression ',' path=string
(AS pathName=identifier)?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import io.trino.spi.type.DecimalParseResult;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimeWithTimeZoneType;
Expand Down Expand Up @@ -119,6 +120,7 @@
import io.trino.sql.tree.LocalTimestamp;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.MapLiteral;
import io.trino.sql.tree.MeasureDefinition;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
Expand Down Expand Up @@ -1099,6 +1101,15 @@ protected Type visitSubscriptExpression(SubscriptExpression node, Context contex
return getOperator(context, node, SUBSCRIPT, node.getBase(), node.getIndex());
}

@Override
protected Type visitMapLiteral(MapLiteral node, Context context)
{
Type keyType = coerceToSingleType(context, "All MAP keys", node.getKeys());
Type valueType = coerceToSingleType(context, "All MAP values", node.getValues());
MapType mapType = new MapType(keyType, valueType, plannerContext.getTypeOperators());
return setExpressionType(node, mapType);
}

@Override
protected Type visitArray(Array node, Context context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.trino.spi.type.DecimalParseResult;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimeWithTimeZoneType;
Expand Down Expand Up @@ -101,6 +102,8 @@
import io.trino.sql.tree.LocalTimestamp;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.MapLiteral;
import io.trino.sql.tree.MapLiteral.EntryLiteral;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullIfExpression;
Expand All @@ -119,6 +122,7 @@
import io.trino.type.JsonPath2016Type;
import io.trino.type.UnknownType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -134,6 +138,7 @@
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RowType.anonymousRow;
import static io.trino.spi.type.TimeWithTimeZoneType.createTimeWithTimeZoneType;
import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType;
import static io.trino.spi.type.TinyintType.TINYINT;
Expand Down Expand Up @@ -318,6 +323,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot)
case FunctionCall expression -> translate(expression);
case DereferenceExpression expression -> translate(expression);
case Array expression -> translate(expression);
case MapLiteral expression -> translate(expression);
case CurrentCatalog expression -> translate(expression);
case CurrentSchema expression -> translate(expression);
case CurrentPath expression -> translate(expression);
Expand Down Expand Up @@ -687,6 +693,23 @@ private io.trino.sql.ir.Expression translate(DereferenceExpression expression)
return new FieldReference(translateExpression(expression.getBase()), index);
}

private io.trino.sql.ir.Expression translate(MapLiteral expression)
{
MapType mapType = (MapType) analysis.getType(expression);

// convert entries to array(row(key, value))
List<io.trino.sql.ir.Expression> entries = new ArrayList<>();
for (EntryLiteral entry : expression.getEntries()) {
entries.add(new io.trino.sql.ir.Row(ImmutableList.of(translateExpression(entry.key()), translateExpression(entry.value()))));
}
io.trino.sql.ir.Array array = new io.trino.sql.ir.Array(anonymousRow(mapType.getKeyType(), mapType.getValueType()), entries);

return BuiltinFunctionCallBuilder.resolve(plannerContext.getMetadata())
.setName("map_from_entries")
.addArgument(array.type(), array)
.build();
}

private io.trino.sql.ir.Expression translate(Array expression)
{
List<io.trino.sql.ir.Expression> values = expression.getValues().stream()
Expand Down
33 changes: 33 additions & 0 deletions core/trino-main/src/test/java/io/trino/type/TestMapOperators.java
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,39 @@ public void testMapToMapCast()
.hasErrorCode(INVALID_CAST_ARGUMENT);
}

@Test
public void testMapLiteral()
{
assertThat(assertions.expression("{}"))
.hasType(mapType(UNKNOWN, UNKNOWN))
.isEqualTo(ImmutableMap.of());

assertThat(assertions.expression("{'x' : 1, 'y' : 2}"))
.hasType(mapType(createVarcharType(1), INTEGER))
.isEqualTo(ImmutableMap.of("x", 1, "y", 2));

assertThat(assertions.expression("{{'a' : 1, 'b' : 2} : 11, {'c' : 3, 'd' : 4} : 22}"))
.hasType(mapType(mapType(createVarcharType(1), INTEGER), INTEGER))
.isEqualTo(ImmutableMap.of(ImmutableMap.of("a", 1, "b", 2), 11, ImmutableMap.of("c", 3, "d", 4), 22));

Map<String, Integer> expectedNullValueMap = new HashMap<>();
expectedNullValueMap.put("x", 1);
expectedNullValueMap.put("y", null);
assertThat(assertions.expression("{'x' : 1, 'y' : null}"))
.hasType(mapType(createVarcharType(1), INTEGER))
.isEqualTo(expectedNullValueMap);

// invalid invocation
assertTrinoExceptionThrownBy(assertions.expression("{'a' : 1, 'a' : 2}")::evaluate)
.hasMessage("Duplicate map keys (a) are not allowed");

assertTrinoExceptionThrownBy(assertions.expression("{1 : 'a', 1 : 'b'}")::evaluate)
.hasMessage("Duplicate map keys (1) are not allowed");

assertTrinoExceptionThrownBy(assertions.expression("{'a' : 1, null : 2}")::evaluate)
.hasMessage("map key cannot be null");
}

@Test
public void testMapFromEntries()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import io.trino.sql.tree.LocalTimestamp;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.MapLiteral;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -303,6 +304,19 @@ protected String visitAllRows(AllRows node, Void context)
return "ALL";
}

@Override
protected String visitMapLiteral(MapLiteral node, Void context)
{
return node.getEntries().stream()
.map(Formatter::formatMapEntry)
.collect(joining(", ", "{", "}"));
}

private static String formatMapEntry(MapLiteral.EntryLiteral entry)
{
return formatSql(entry.key()) + " : " + formatSql(entry.value());
}

@Override
protected String visitArray(Array node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.LoopStatement;
import io.trino.sql.tree.MapLiteral;
import io.trino.sql.tree.MapLiteral.EntryLiteral;
import io.trino.sql.tree.MeasureDefinition;
import io.trino.sql.tree.Merge;
import io.trino.sql.tree.MergeCase;
Expand Down Expand Up @@ -2377,6 +2379,16 @@ public Node visitRowConstructor(SqlBaseParser.RowConstructorContext context)
return new Row(getLocation(context), visit(context.expression(), Expression.class));
}

@Override
public Node visitMapConstructor(SqlBaseParser.MapConstructorContext context)
{
List<EntryLiteral> entries = new ArrayList<>();
for (SqlBaseParser.MapEntryContext mapEntry : context.mapEntry()) {
entries.add(new EntryLiteral((Expression) visit(mapEntry.key), (Expression) visit(mapEntry.value)));
}
return new MapLiteral(getLocation(context), entries);
}

@Override
public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ protected R visitIsNullPredicate(IsNullPredicate node, C context)
return visitExpression(node, context);
}

protected R visitMapLiteral(MapLiteral node, C context)
{
return visitExpression(node, context);
}

protected R visitArray(Array node, C context)
{
return visitExpression(node, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ public Expression rewriteLiteral(Literal node, C context, ExpressionTreeRewriter
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteMapLiteral(MapLiteral node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
}

public Expression rewriteArray(Array node, C context, ExpressionTreeRewriter<C> treeRewriter)
{
return rewriteExpression(node, context, treeRewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,37 @@ public Expression visitArithmeticBinary(ArithmeticBinaryExpression node, Context
return node;
}

@Override
protected Expression visitMapLiteral(MapLiteral node, Context<C> context)
{
if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteMapLiteral(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) {
return result;
}
}

boolean changed = false;
ImmutableList.Builder<MapLiteral.EntryLiteral> entries = ImmutableList.builder();
for (MapLiteral.EntryLiteral entry : node.getEntries()) {
Expression key = rewrite(entry.key(), context.get());
Expression value = rewrite(entry.value(), context.get());
if (entry.key() != key || entry.value() != value) {
entries.add(new MapLiteral.EntryLiteral(key, value));
changed = true;
}
else {
entries.add(entry);
}
}

if (changed) {
return new MapLiteral(node.getLocation().orElseThrow(), entries.build());
}

return node;
}

@Override
protected Expression visitArray(Array node, Context<C> context)
{
Expand Down
103 changes: 103 additions & 0 deletions core/trino-parser/src/main/java/io/trino/sql/tree/MapLiteral.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.tree;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public final class MapLiteral
extends Expression
{
private final List<EntryLiteral> entries;

public MapLiteral(NodeLocation location, List<EntryLiteral> entries)
{
super(location);
this.entries = ImmutableList.copyOf(requireNonNull(entries, "entries is null"));
}

public List<EntryLiteral> getEntries()
{
return entries;
}

public List<Expression> getKeys()
{
return entries.stream()
.map(EntryLiteral::key)
.collect(toImmutableList());
}

public List<Expression> getValues()
{
return entries.stream()
.map(EntryLiteral::value)
.collect(toImmutableList());
}

@Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context)
{
return visitor.visitMapLiteral(this, context);
}

@Override
public List<Node> getChildren()
{
return entries.stream()
.flatMap(entry -> Stream.of(entry.key(), entry.value()))
.collect(toImmutableList());
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
return (o != null) && (getClass() == o.getClass());
}

@Override
public int hashCode()
{
return MapLiteral.class.hashCode();
}

@Override
public boolean shallowEquals(Node other)
{
return sameClass(this, other);
}

public record EntryLiteral(Expression key, Expression value)
{
public EntryLiteral
{
requireNonNull(key, "key is null");
requireNonNull(value, "value is null");
}

@Override
public String toString()
{
return key + " => " + value;
}
}
}
Loading