Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ primaryExpression
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CAST '(' expression AS type ')' #cast
// This is a postgres extension to ANSI SQL, which allows for the use of "::" to cast
| primaryExpression DOUBLE_COLON type #cast
| TRY_CAST '(' expression AS type ')' #cast
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
| value=primaryExpression '[' index=valueExpression ']' #subscript
Expand Down Expand Up @@ -1327,6 +1329,7 @@ LT: '<';
LTE: '<=';
GT: '>';
GTE: '>=';
DOUBLE_COLON: '::';
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This token declaration isn't needed, as we don't compare the token value in the generated code.


PLUS: '+';
MINUS: '-';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@
import static io.trino.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -2387,7 +2388,8 @@ public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context)
public Node visitCast(SqlBaseParser.CastContext context)
{
boolean isTryCast = context.TRY_CAST() != null;
return new Cast(getLocation(context), (Expression) visit(context.expression()), (DataType) visit(context.type()), isTryCast);
Expression expression = (Expression) visit(requireNonNullElse(context.expression(), context.primaryExpression()));
return new Cast(getLocation(context), expression, (DataType) visit(context.type()), isTryCast);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,51 @@ public void testCase()
Optional.of(new LongLiteral(location(1, 38), "3"))));
}

@Test
public void testCast()
{
assertThat(expression("CAST(1 AS BIGINT)"))
.isEqualTo(new Cast(location(1, 1),
new LongLiteral(location(1, 6), "1"),
simpleType(location(1, 11), "BIGINT")));

assertThat(expression("1::BIGINT"))
.isEqualTo(new Cast(location(1, 1),
new LongLiteral(location(1, 1), "1"),
simpleType(location(1, 4), "BIGINT")));

assertThat(expression("-3::BIGINT"))
.isEqualTo(new Cast(location(1, 1),
new LongLiteral(location(1, 1), "-3"),
simpleType(location(1, 5), "BIGINT")));

assertThat(expression("3*'4'::BIGINT"))
.isEqualTo(new ArithmeticBinaryExpression(
location(1, 2),
ArithmeticBinaryExpression.Operator.MULTIPLY,
new LongLiteral(location(1, 1), "3"),
new Cast(
location(1, 3),
new StringLiteral(location(1, 3), "4"),
simpleType(location(1, 8), "BIGINT"))));

assertThat(expression("CAST(ROW(11, 12) AS ROW(COL0 INTEGER, COL1 INTEGER))"))
.isEqualTo(new Cast(location(1, 1),
new Row(location(1, 6), Lists.newArrayList(new LongLiteral(location(1, 10), "11"), new LongLiteral(location(1, 14), "12"))),
rowType(
location(1, 21),
field(location(1, 25), "COL0", simpleType(location(1, 30), "INTEGER")),
field(location(1, 39), "COL1", simpleType(location(1, 44), "INTEGER")))));

assertThat(expression("ROW(11, 12)::ROW(COL0 INTEGER, COL1 INTEGER)"))
.isEqualTo(new Cast(location(1, 1),
new Row(location(1, 1), Lists.newArrayList(new LongLiteral(location(1, 5), "11"), new LongLiteral(location(1, 9), "12"))),
rowType(
location(1, 14),
field(location(1, 18), "COL0", simpleType(location(1, 23), "INTEGER")),
field(location(1, 32), "COL1", simpleType(location(1, 37), "INTEGER")))));
}

@Test
public void testSearchedCase()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private static Stream<Arguments> expressions()
{
return Stream.of(
Arguments.of("", "line 1:1: mismatched input '<EOF>'. Expecting: <expression>"),
Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', '[', '||', <EOF>, <predicate>"));
Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', '[', '||', <EOF>, <predicate>"));
}

private static Stream<Arguments> statements()
Expand Down Expand Up @@ -67,7 +67,7 @@ private static Stream<Arguments> statements()
Arguments.of("select 1x from dual",
"line 1:8: identifiers must not start with a digit; surround the identifier with double quotes"),
Arguments.of("select fuu from dual order by fuu order by fuu",
"line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', <EOF>, <predicate>"),
"line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', '::', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', <EOF>, <predicate>"),
Arguments.of("select fuu from dual limit 10 order by fuu",
"line 1:31: mismatched input 'order'. Expecting: <EOF>"),
Arguments.of("select CAST(12223222232535343423232435343 AS BIGINT)",
Expand Down Expand Up @@ -99,7 +99,7 @@ private static Stream<Arguments> statements()
Arguments.of("SELECT x() over (ROWS select) FROM t",
"line 1:23: mismatched input 'select'. Expecting: ')', 'BETWEEN', 'CURRENT', 'GROUPS', 'MEASURES', 'ORDER', 'PARTITION', 'RANGE', 'ROWS', 'UNBOUNDED', <expression>"),
Arguments.of("SELECT X() OVER (ROWS UNBOUNDED) FROM T",
"line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', <predicate>, <string>"),
"line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', '::', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', <predicate>, <string>"),
Arguments.of("SELECT a FROM x ORDER BY (SELECT b FROM t WHERE ",
"line 1:49: mismatched input '<EOF>'. Expecting: <expression>"),
Arguments.of("SELECT a FROM a AS x TABLESAMPLE x ",
Expand Down Expand Up @@ -134,7 +134,7 @@ private static Stream<Arguments> statements()
Arguments.of("SELECT a FROM \"\".s.t",
"line 1:15: Zero-length delimited identifier not allowed"),
Arguments.of("WITH t AS (SELECT 1 SELECT t.* FROM t",
"line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " +
"line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " +
"'GROUP', 'HAVING', 'INTERSECT', 'LIMIT', 'OFFSET', 'OR', 'ORDER', 'UNION', 'WHERE', 'WINDOW', '[', '||', " +
"<identifier>, <predicate>"),
Arguments.of("SHOW CATALOGS LIKE '%$_%' ESCAPE",
Expand All @@ -160,9 +160,9 @@ private static Stream<Arguments> statements()
Arguments.of("SELECT * FROM t FOR VERSION AS OF TIMESTAMP WHERE",
"line 1:50: mismatched input '<EOF>'. Expecting: <expression>"),
Arguments.of("SELECT ROW(DATE '2022-10-10', DOUBLE 12.0)",
"line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', <predicate>, <string>"),
"line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', <predicate>, <string>"),
Arguments.of("VALUES(DATE 2)",
"line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'OVER', '[', '||', <predicate>, <string>"),
"line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'OVER', '[', '||', <predicate>, <string>"),
Arguments.of("SELECT count(DISTINCT *) FROM (VALUES 1)",
"line 1:23: mismatched input '*'. Expecting: <expression>"));
}
Expand All @@ -182,7 +182,7 @@ public void testPossibleExponentialBacktracking()
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " +
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " +
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9",
"line 1:375: mismatched input '<EOF>'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', 'THEN', '[', '||', <predicate>");
"line 1:375: mismatched input '<EOF>'. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', 'THEN', '[', '||', <predicate>");
}

@Test
Expand Down Expand Up @@ -212,7 +212,7 @@ public void testPossibleExponentialBacktracking2()
"OR (f()\n" +
"OR (f()\n" +
"GROUP BY id",
"line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', <predicate>");
"line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', <predicate>");
}

@ParameterizedTest
Expand Down
Loading