diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index 9334c12d0d1c..a2217cad5f09 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -592,6 +592,8 @@ primaryExpression | CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CAST '(' expression AS type ')' #cast + // This is a postgres extension to ANSI SQL, which allows for the use of "::" to cast + | primaryExpression DOUBLE_COLON type #cast | TRY_CAST '(' expression AS type ')' #cast | ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor | value=primaryExpression '[' index=valueExpression ']' #subscript @@ -1327,6 +1329,7 @@ LT: '<'; LTE: '<='; GT: '>'; GTE: '>='; +DOUBLE_COLON: '::'; PLUS: '+'; MINUS: '-'; diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 06322c2c9234..eca662665771 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -368,6 +368,7 @@ import static io.trino.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; +import static java.util.Objects.requireNonNullElse; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @@ -2387,7 +2388,8 @@ public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context) public Node visitCast(SqlBaseParser.CastContext context) { boolean isTryCast = context.TRY_CAST() != null; - return new Cast(getLocation(context), (Expression) visit(context.expression()), (DataType) visit(context.type()), isTryCast); + Expression expression = (Expression) visit(requireNonNullElse(context.expression(), context.primaryExpression())); + return new Cast(getLocation(context), expression, (DataType) visit(context.type()), isTryCast); } @Override diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index b3c35ef28ab0..10f1f282cbf2 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -1345,6 +1345,51 @@ public void testCase() Optional.of(new LongLiteral(location(1, 38), "3")))); } + @Test + public void testCast() + { + assertThat(expression("CAST(1 AS BIGINT)")) + .isEqualTo(new Cast(location(1, 1), + new LongLiteral(location(1, 6), "1"), + simpleType(location(1, 11), "BIGINT"))); + + assertThat(expression("1::BIGINT")) + .isEqualTo(new Cast(location(1, 1), + new LongLiteral(location(1, 1), "1"), + simpleType(location(1, 4), "BIGINT"))); + + assertThat(expression("-3::BIGINT")) + .isEqualTo(new Cast(location(1, 1), + new LongLiteral(location(1, 1), "-3"), + simpleType(location(1, 5), "BIGINT"))); + + assertThat(expression("3*'4'::BIGINT")) + .isEqualTo(new ArithmeticBinaryExpression( + location(1, 2), + ArithmeticBinaryExpression.Operator.MULTIPLY, + new LongLiteral(location(1, 1), "3"), + new Cast( + location(1, 3), + new StringLiteral(location(1, 3), "4"), + simpleType(location(1, 8), "BIGINT")))); + + assertThat(expression("CAST(ROW(11, 12) AS ROW(COL0 INTEGER, COL1 INTEGER))")) + .isEqualTo(new Cast(location(1, 1), + new Row(location(1, 6), Lists.newArrayList(new LongLiteral(location(1, 10), "11"), new LongLiteral(location(1, 14), "12"))), + rowType( + location(1, 21), + field(location(1, 25), "COL0", simpleType(location(1, 30), "INTEGER")), + field(location(1, 39), "COL1", simpleType(location(1, 44), "INTEGER"))))); + + assertThat(expression("ROW(11, 12)::ROW(COL0 INTEGER, COL1 INTEGER)")) + .isEqualTo(new Cast(location(1, 1), + new Row(location(1, 1), Lists.newArrayList(new LongLiteral(location(1, 5), "11"), new LongLiteral(location(1, 9), "12"))), + rowType( + location(1, 14), + field(location(1, 18), "COL0", simpleType(location(1, 23), "INTEGER")), + field(location(1, 32), "COL1", simpleType(location(1, 37), "INTEGER"))))); + } + @Test public void testSearchedCase() { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index 41940b9368d5..da452648a8af 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -33,7 +33,7 @@ private static Stream expressions() { return Stream.of( Arguments.of("", "line 1:1: mismatched input ''. Expecting: "), - Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', '[', '||', , ")); + Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', '[', '||', , ")); } private static Stream statements() @@ -67,7 +67,7 @@ private static Stream statements() Arguments.of("select 1x from dual", "line 1:8: identifiers must not start with a digit; surround the identifier with double quotes"), Arguments.of("select fuu from dual order by fuu order by fuu", - "line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', , "), + "line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', '::', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', , "), Arguments.of("select fuu from dual limit 10 order by fuu", "line 1:31: mismatched input 'order'. Expecting: "), Arguments.of("select CAST(12223222232535343423232435343 AS BIGINT)", @@ -99,7 +99,7 @@ private static Stream statements() Arguments.of("SELECT x() over (ROWS select) FROM t", "line 1:23: mismatched input 'select'. Expecting: ')', 'BETWEEN', 'CURRENT', 'GROUPS', 'MEASURES', 'ORDER', 'PARTITION', 'RANGE', 'ROWS', 'UNBOUNDED', "), Arguments.of("SELECT X() OVER (ROWS UNBOUNDED) FROM T", - "line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', , "), + "line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', '::', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', , "), Arguments.of("SELECT a FROM x ORDER BY (SELECT b FROM t WHERE ", "line 1:49: mismatched input ''. Expecting: "), Arguments.of("SELECT a FROM a AS x TABLESAMPLE x ", @@ -134,7 +134,7 @@ private static Stream statements() Arguments.of("SELECT a FROM \"\".s.t", "line 1:15: Zero-length delimited identifier not allowed"), Arguments.of("WITH t AS (SELECT 1 SELECT t.* FROM t", - "line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " + + "line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " + "'GROUP', 'HAVING', 'INTERSECT', 'LIMIT', 'OFFSET', 'OR', 'ORDER', 'UNION', 'WHERE', 'WINDOW', '[', '||', " + ", "), Arguments.of("SHOW CATALOGS LIKE '%$_%' ESCAPE", @@ -160,9 +160,9 @@ private static Stream statements() Arguments.of("SELECT * FROM t FOR VERSION AS OF TIMESTAMP WHERE", "line 1:50: mismatched input ''. Expecting: "), Arguments.of("SELECT ROW(DATE '2022-10-10', DOUBLE 12.0)", - "line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', , "), + "line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', , "), Arguments.of("VALUES(DATE 2)", - "line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'OVER', '[', '||', , "), + "line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'OVER', '[', '||', , "), Arguments.of("SELECT count(DISTINCT *) FROM (VALUES 1)", "line 1:23: mismatched input '*'. Expecting: ")); } @@ -182,7 +182,7 @@ public void testPossibleExponentialBacktracking() "1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " + "1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " + "1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9", - "line 1:375: mismatched input ''. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', 'THEN', '[', '||', "); + "line 1:375: mismatched input ''. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', 'THEN', '[', '||', "); } @Test @@ -212,7 +212,7 @@ public void testPossibleExponentialBacktracking2() "OR (f()\n" + "OR (f()\n" + "GROUP BY id", - "line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', "); + "line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', "); } @ParameterizedTest