Skip to content

Commit cfc68f0

Browse files
committed
Add PostgreSQL :: style casts
1 parent e437b16 commit cfc68f0

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

+3
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,8 @@ primaryExpression
592592
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
593593
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
594594
| CAST '(' expression AS type ')' #cast
595+
// This is a postgres extension to ANSI SQL, which allows for the use of "::" to cast
596+
| primaryExpression DOUBLE_COLON type #cast
595597
| TRY_CAST '(' expression AS type ')' #cast
596598
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
597599
| value=primaryExpression '[' index=valueExpression ']' #subscript
@@ -1327,6 +1329,7 @@ LT: '<';
13271329
LTE: '<=';
13281330
GT: '>';
13291331
GTE: '>=';
1332+
DOUBLE_COLON: '::';
13301333

13311334
PLUS: '+';
13321335
MINUS: '-';

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@
368368
import static io.trino.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument;
369369
import static java.util.Locale.ENGLISH;
370370
import static java.util.Objects.requireNonNull;
371+
import static java.util.Objects.requireNonNullElse;
371372
import static java.util.stream.Collectors.joining;
372373
import static java.util.stream.Collectors.toList;
373374

@@ -2387,7 +2388,8 @@ public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context)
23872388
public Node visitCast(SqlBaseParser.CastContext context)
23882389
{
23892390
boolean isTryCast = context.TRY_CAST() != null;
2390-
return new Cast(getLocation(context), (Expression) visit(context.expression()), (DataType) visit(context.type()), isTryCast);
2391+
Expression expression = (Expression) visit(requireNonNullElse(context.expression(), context.primaryExpression()));
2392+
return new Cast(getLocation(context), expression, (DataType) visit(context.type()), isTryCast);
23912393
}
23922394

23932395
@Override

core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java

+45
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,51 @@ public void testCase()
13451345
Optional.of(new LongLiteral(location(1, 38), "3"))));
13461346
}
13471347

1348+
@Test
1349+
public void testCast()
1350+
{
1351+
assertThat(expression("CAST(1 AS BIGINT)"))
1352+
.isEqualTo(new Cast(location(1, 1),
1353+
new LongLiteral(location(1, 6), "1"),
1354+
simpleType(location(1, 11), "BIGINT")));
1355+
1356+
assertThat(expression("1::BIGINT"))
1357+
.isEqualTo(new Cast(location(1, 1),
1358+
new LongLiteral(location(1, 1), "1"),
1359+
simpleType(location(1, 4), "BIGINT")));
1360+
1361+
assertThat(expression("-3::BIGINT"))
1362+
.isEqualTo(new Cast(location(1, 1),
1363+
new LongLiteral(location(1, 1), "-3"),
1364+
simpleType(location(1, 5), "BIGINT")));
1365+
1366+
assertThat(expression("3*'4'::BIGINT"))
1367+
.isEqualTo(new ArithmeticBinaryExpression(
1368+
location(1, 2),
1369+
ArithmeticBinaryExpression.Operator.MULTIPLY,
1370+
new LongLiteral(location(1, 1), "3"),
1371+
new Cast(
1372+
location(1, 3),
1373+
new StringLiteral(location(1, 3), "4"),
1374+
simpleType(location(1, 8), "BIGINT"))));
1375+
1376+
assertThat(expression("CAST(ROW(11, 12) AS ROW(COL0 INTEGER, COL1 INTEGER))"))
1377+
.isEqualTo(new Cast(location(1, 1),
1378+
new Row(location(1, 6), Lists.newArrayList(new LongLiteral(location(1, 10), "11"), new LongLiteral(location(1, 14), "12"))),
1379+
rowType(
1380+
location(1, 21),
1381+
field(location(1, 25), "COL0", simpleType(location(1, 30), "INTEGER")),
1382+
field(location(1, 39), "COL1", simpleType(location(1, 44), "INTEGER")))));
1383+
1384+
assertThat(expression("ROW(11, 12)::ROW(COL0 INTEGER, COL1 INTEGER)"))
1385+
.isEqualTo(new Cast(location(1, 1),
1386+
new Row(location(1, 1), Lists.newArrayList(new LongLiteral(location(1, 5), "11"), new LongLiteral(location(1, 9), "12"))),
1387+
rowType(
1388+
location(1, 14),
1389+
field(location(1, 18), "COL0", simpleType(location(1, 23), "INTEGER")),
1390+
field(location(1, 32), "COL1", simpleType(location(1, 37), "INTEGER")))));
1391+
}
1392+
13481393
@Test
13491394
public void testSearchedCase()
13501395
{

0 commit comments

Comments
 (0)