Skip to content

Commit cb50d57

Browse files
committed
Add SQL:2023 non-static method invocation
SQL:2023 specifies <method invocation>, expr.method(args), as a sibling of <static method invocation>. Where T::method(args) namespaces a function under a type, expr.method(args) dispatches on the runtime type of the receiver and passes it as the implicit self argument. Functions tagged with the new @instancemethod SPI annotation register with a receiver type taken from their first @SqlType parameter and become callable only as expr.method(args). A plain method(args) call cannot resolve to an instance method, mirroring the static-method namespace separation. The receiver expression form (expr).method(args) parses unambiguously as a method call. The bare A.B(args) form (where A is in scope as a column) keeps parsing as a routine invocation but, per SQL:2023 6.3 Syntax Rule 2, the analyzer treats it as a method invocation when one applies and only falls back to function resolution otherwise.
1 parent a2f4c96 commit cb50d57

19 files changed

Lines changed: 719 additions & 14 deletions

File tree

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ primaryExpression
606606
| processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)?
607607
orderBy? ')' filter? (nullTreatment? over)? #functionCall
608608
| qualifiedName '::' identifier '(' (expression (',' expression)*)? ')' #staticMethodCall
609+
| primaryExpression '.' identifier '(' (expression (',' expression)*)? ')' #methodCall
609610
| identifier over #measure
610611
| identifier '->' expression #lambda
611612
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda

core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ public ResolvedFunction resolveStaticMethod(
146146
parameterTypes,
147147
catalogSchemaFunctionName -> filterCandidates(
148148
metadata.getFunctions(session, catalogSchemaFunctionName),
149-
candidate -> candidate.functionMetadata().getReceiverType()
149+
candidate -> !candidate.functionMetadata().isInstanceMethod()
150+
&& candidate.functionMetadata().getReceiverType()
150151
.map(TypeSignature::getBase).equals(Optional.of(receiver))),
151152
accessControl);
152153

@@ -158,6 +159,33 @@ public ResolvedFunction resolveStaticMethod(
158159
return resolve(session, catalogFunctionBinding, accessControl);
159160
}
160161

162+
public ResolvedFunction resolveInstanceMethod(
163+
Session session,
164+
TypeSignature receiverType,
165+
QualifiedName methodName,
166+
List<TypeSignatureProvider> parameterTypes,
167+
AccessControl accessControl)
168+
{
169+
String receiver = receiverType.getBase();
170+
CatalogFunctionBinding catalogFunctionBinding = bindFunction(
171+
session,
172+
methodName,
173+
parameterTypes,
174+
catalogSchemaFunctionName -> filterCandidates(
175+
metadata.getFunctions(session, catalogSchemaFunctionName),
176+
candidate -> candidate.functionMetadata().isInstanceMethod()
177+
&& candidate.functionMetadata().getReceiverType()
178+
.map(TypeSignature::getBase).equals(Optional.of(receiver))),
179+
accessControl);
180+
181+
FunctionMetadata functionMetadata = catalogFunctionBinding.boundFunctionMetadata();
182+
if (functionMetadata.isDeprecated()) {
183+
warningCollector.add(new TrinoWarning(DEPRECATED_FUNCTION, "Use of deprecated function: %s.%s: %s".formatted(receiverType, methodName, functionMetadata.getDescription())));
184+
}
185+
186+
return resolve(session, catalogFunctionBinding, accessControl);
187+
}
188+
161189
private static Collection<CatalogFunctionMetadata> filterCandidates(
162190
Collection<CatalogFunctionMetadata> candidates,
163191
Predicate<CatalogFunctionMetadata> predicate)

core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,14 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Scal
8181
if (deprecated) {
8282
functionMetadata.deprecated();
8383
}
84-
details.getReceiverType().ifPresent(functionMetadata::receiverType);
84+
if (details.isInstanceMethod()) {
85+
checkCondition(!signature.getArgumentTypes().isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "Instance method %s must declare a self argument", details.getName());
86+
functionMetadata.receiverType(signature.getArgumentTypes().getFirst());
87+
functionMetadata.instanceMethod();
88+
}
89+
else {
90+
details.getReceiverType().ifPresent(functionMetadata::receiverType);
91+
}
8592

8693
if (functionNullability.isReturnNullable()) {
8794
functionMetadata.nullable();

core/trino-main/src/main/java/io/trino/operator/scalar/ScalarHeader.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import com.google.common.collect.ImmutableList;
1717
import com.google.common.collect.ImmutableSet;
18+
import io.trino.spi.function.InstanceMethod;
1819
import io.trino.spi.function.OperatorType;
1920
import io.trino.spi.function.ScalarFunction;
2021
import io.trino.spi.function.ScalarOperator;
@@ -45,13 +46,14 @@ public class ScalarHeader
4546
private final boolean deterministic;
4647
private final boolean neverFails;
4748
private final Optional<TypeSignature> receiverType;
49+
private final boolean instanceMethod;
4850

4951
public ScalarHeader(String name, Set<String> aliases, Optional<String> description, boolean hidden, boolean deterministic, boolean neverFails)
5052
{
51-
this(name, aliases, description, hidden, deterministic, neverFails, Optional.empty());
53+
this(name, aliases, description, hidden, deterministic, neverFails, Optional.empty(), false);
5254
}
5355

54-
public ScalarHeader(String name, Set<String> aliases, Optional<String> description, boolean hidden, boolean deterministic, boolean neverFails, Optional<TypeSignature> receiverType)
56+
public ScalarHeader(String name, Set<String> aliases, Optional<String> description, boolean hidden, boolean deterministic, boolean neverFails, Optional<TypeSignature> receiverType, boolean instanceMethod)
5557
{
5658
this.name = requireNonNull(name, "name is null");
5759
checkArgument(!name.isEmpty());
@@ -63,6 +65,8 @@ public ScalarHeader(String name, Set<String> aliases, Optional<String> descripti
6365
this.deterministic = deterministic;
6466
this.neverFails = neverFails;
6567
this.receiverType = requireNonNull(receiverType, "receiverType is null");
68+
checkArgument(!instanceMethod || receiverType.isEmpty(), "instance method receiver type is inferred from the first argument");
69+
this.instanceMethod = instanceMethod;
6670
}
6771

6872
public ScalarHeader(OperatorType operatorType, Optional<String> description)
@@ -75,30 +79,36 @@ public ScalarHeader(OperatorType operatorType, Optional<String> description)
7579
this.deterministic = true;
7680
this.neverFails = false;
7781
this.receiverType = Optional.empty();
82+
this.instanceMethod = false;
7883
}
7984

8085
public static List<ScalarHeader> fromAnnotatedElement(AnnotatedElement annotated)
8186
{
8287
ScalarFunction scalarFunction = annotated.getAnnotation(ScalarFunction.class);
8388
ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class);
8489
StaticMethod staticMethod = annotated.getAnnotation(StaticMethod.class);
90+
InstanceMethod instanceMethod = annotated.getAnnotation(InstanceMethod.class);
8591
Optional<String> description = parseDescription(annotated);
8692

8793
ImmutableList.Builder<ScalarHeader> builder = ImmutableList.builder();
8894

8995
if (scalarFunction != null) {
96+
checkArgument(staticMethod == null || instanceMethod == null, "@StaticMethod and @InstanceMethod are mutually exclusive on %s", annotated);
9097
String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value();
9198
Optional<TypeSignature> receiverType = Optional.empty();
9299
if (staticMethod != null) {
93100
TypeSignature parsed = parseTypeSignature(staticMethod.value(), ImmutableSet.of());
94101
checkArgument(parsed.getParameters().isEmpty(), "@StaticMethod receiver type must not have parameters: %s", staticMethod.value());
95102
receiverType = Optional.of(parsed);
96103
}
97-
builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails(), receiverType));
104+
builder.add(new ScalarHeader(baseName, ImmutableSet.copyOf(scalarFunction.alias()), description, scalarFunction.hidden(), scalarFunction.deterministic(), scalarFunction.neverFails(), receiverType, instanceMethod != null));
98105
}
99106
else if (staticMethod != null) {
100107
throw new IllegalArgumentException("@StaticMethod requires @ScalarFunction on " + annotated);
101108
}
109+
else if (instanceMethod != null) {
110+
throw new IllegalArgumentException("@InstanceMethod requires @ScalarFunction on " + annotated);
111+
}
102112

103113
if (scalarOperator != null) {
104114
builder.add(new ScalarHeader(scalarOperator.value(), description));
@@ -165,4 +175,9 @@ public Optional<TypeSignature> getReceiverType()
165175
{
166176
return receiverType;
167177
}
178+
179+
public boolean isInstanceMethod()
180+
{
181+
return instanceMethod;
182+
}
168183
}

core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ public class Analysis
225225
private final Map<NodeRef<Expression>, ResolvedFunction> frameBoundCalculations = new LinkedHashMap<>();
226226
private final Map<NodeRef<Relation>, List<Type>> relationCoercions = new LinkedHashMap<>();
227227
private final Map<NodeRef<Node>, RoutineEntry> resolvedFunctions = new LinkedHashMap<>();
228+
private final Map<NodeRef<FunctionCall>, Identifier> methodCallReceivers = new LinkedHashMap<>();
228229
private final Map<NodeRef<Identifier>, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>();
229230

230231
private final Map<Field, ColumnHandle> columns = new LinkedHashMap<>();
@@ -720,6 +721,16 @@ public void addResolvedFunction(Node node, ResolvedFunction function, String aut
720721
resolvedFunctions.put(NodeRef.of(node), new RoutineEntry(function, authorization));
721722
}
722723

724+
public void addMethodCallReceiver(FunctionCall node, Identifier receiver)
725+
{
726+
methodCallReceivers.put(NodeRef.of(node), receiver);
727+
}
728+
729+
public Optional<Identifier> getMethodCallReceiver(FunctionCall node)
730+
{
731+
return Optional.ofNullable(methodCallReceivers.get(NodeRef.of(node)));
732+
}
733+
723734
public Set<NodeRef<Expression>> getColumnReferences()
724735
{
725736
return unmodifiableSet(columnReferences.keySet());

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
import io.trino.sql.tree.LogicalExpression;
130130
import io.trino.sql.tree.LongLiteral;
131131
import io.trino.sql.tree.MeasureDefinition;
132+
import io.trino.sql.tree.MethodCall;
132133
import io.trino.sql.tree.Node;
133134
import io.trino.sql.tree.NodeRef;
134135
import io.trino.sql.tree.NotExpression;
@@ -329,6 +330,7 @@ public class ExpressionAnalyzer
329330
private final Cache<String, Type> varcharCastableTypeCache = buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000));
330331

331332
private final Map<NodeRef<Node>, ResolvedFunction> resolvedFunctions = new LinkedHashMap<>();
333+
private final Map<NodeRef<FunctionCall>, Identifier> methodCallReceivers = new LinkedHashMap<>();
332334
private final Set<NodeRef<SubqueryExpression>> subqueries = new LinkedHashSet<>();
333335
private final Set<NodeRef<ExistsPredicate>> existsSubqueries = new LinkedHashSet<>();
334336
private final Map<NodeRef<Expression>, Type> expressionCoercions = new LinkedHashMap<>();
@@ -437,6 +439,11 @@ public Map<NodeRef<Node>, ResolvedFunction> getResolvedFunctions()
437439
return unmodifiableMap(resolvedFunctions);
438440
}
439441

442+
public Map<NodeRef<FunctionCall>, Identifier> getMethodCallReceivers()
443+
{
444+
return unmodifiableMap(methodCallReceivers);
445+
}
446+
440447
public Map<NodeRef<Expression>, Type> getExpressionTypes()
441448
{
442449
return unmodifiableMap(expressionTypes);
@@ -1309,6 +1316,14 @@ protected Type visitNullLiteral(NullLiteral node, Context context)
13091316
@Override
13101317
protected Type visitFunctionCall(FunctionCall node, Context context)
13111318
{
1319+
// SQL:2023 6.3 Syntax Rule 2: a non-parenthesized value expression primary
1320+
// of the form A.B(args) is treated as a method invocation if it satisfies
1321+
// the rules for one; otherwise it is a routine invocation.
1322+
Optional<Type> asMethod = tryResolveAsInstanceMethod(node, context);
1323+
if (asMethod.isPresent()) {
1324+
return asMethod.get();
1325+
}
1326+
13121327
boolean isAggregation = functionResolver.isAggregationFunction(session, node.getName(), accessControl);
13131328
boolean isRowPatternCount = context.isPatternRecognition() &&
13141329
isAggregation &&
@@ -1460,6 +1475,123 @@ else if (isAggregation) {
14601475
return setExpressionType(node, type);
14611476
}
14621477

1478+
private Optional<Type> tryResolveAsInstanceMethod(FunctionCall node, Context context)
1479+
{
1480+
QualifiedName name = node.getName();
1481+
if (name.getParts().size() != 2) {
1482+
return Optional.empty();
1483+
}
1484+
if (node.isDistinct()
1485+
|| node.getFilter().isPresent()
1486+
|| node.getOrderBy().isPresent()
1487+
|| node.getWindow().isPresent()
1488+
|| node.getProcessingMode().isPresent()
1489+
|| node.getNullTreatment().isPresent()) {
1490+
return Optional.empty();
1491+
}
1492+
if (context.isPatternRecognition() || context.isInWindow()) {
1493+
return Optional.empty();
1494+
}
1495+
1496+
Identifier receiver = name.getOriginalParts().get(0);
1497+
Identifier method = name.getOriginalParts().get(1);
1498+
1499+
// Method-call interpretation only applies when the receiver resolves
1500+
// as a field in the current scope. tryResolveField has no side effects
1501+
// so a non-match leaves the analyzer state untouched.
1502+
Optional<ResolvedField> resolvedReceiver = context.getScope()
1503+
.tryResolveField(receiver, QualifiedName.of(receiver.getValue()));
1504+
if (resolvedReceiver.isEmpty()) {
1505+
return Optional.empty();
1506+
}
1507+
Type receiverType = resolvedReceiver.get().getField().getType();
1508+
1509+
MethodResolution resolution;
1510+
try {
1511+
resolution = resolveInstanceMethodCall(receiverType, method.getValue(), node.getArguments(), context);
1512+
}
1513+
catch (TrinoException e) {
1514+
return Optional.empty();
1515+
}
1516+
1517+
// Commit to method-call interpretation: record the receiver field reference.
1518+
process(receiver, context);
1519+
1520+
Type result = analyzeInstanceMethodInvocation(node, receiver, receiverType, method.getValue(), node.getArguments(), resolution, context);
1521+
methodCallReceivers.put(NodeRef.of(node), receiver);
1522+
return Optional.of(result);
1523+
}
1524+
1525+
@Override
1526+
protected Type visitMethodCall(MethodCall node, Context context)
1527+
{
1528+
Type receiverType = process(node.getReceiver(), context);
1529+
String methodName = node.getMethod().getValue();
1530+
1531+
MethodResolution resolution;
1532+
try {
1533+
resolution = resolveInstanceMethodCall(receiverType, methodName, node.getArguments(), context);
1534+
}
1535+
catch (TrinoException e) {
1536+
if (e.getLocation().isPresent()) {
1537+
throw e;
1538+
}
1539+
throw new TrinoException(e::getErrorCode, extractLocation(node), e.getMessage(), e);
1540+
}
1541+
1542+
return analyzeInstanceMethodInvocation(node, node.getReceiver(), receiverType, methodName, node.getArguments(), resolution, context);
1543+
}
1544+
1545+
private MethodResolution resolveInstanceMethodCall(Type receiverType, String methodName, List<Expression> arguments, Context context)
1546+
{
1547+
List<TypeSignatureProvider> argumentTypes = ImmutableList.<TypeSignatureProvider>builder()
1548+
.add(new TypeSignatureProvider(receiverType.getTypeSignature()))
1549+
.addAll(getCallArgumentTypes(arguments, context))
1550+
.build();
1551+
ResolvedFunction function = functionResolver.resolveInstanceMethod(
1552+
session,
1553+
receiverType.getTypeSignature(),
1554+
QualifiedName.of(methodName),
1555+
argumentTypes,
1556+
accessControl);
1557+
return new MethodResolution(function, argumentTypes);
1558+
}
1559+
1560+
private Type analyzeInstanceMethodInvocation(
1561+
Expression node,
1562+
Expression receiver,
1563+
Type receiverType,
1564+
String methodName,
1565+
List<Expression> arguments,
1566+
MethodResolution resolution,
1567+
Context context)
1568+
{
1569+
if (arguments.size() + 1 > 127) {
1570+
throw semanticException(TOO_MANY_ARGUMENTS, node, "Too many arguments for method call .%s()", methodName);
1571+
}
1572+
1573+
BoundSignature signature = resolution.function().signature();
1574+
Type expectedReceiverType = signature.getArgumentTypes().getFirst();
1575+
coerceType(receiver, receiverType, expectedReceiverType, format("Method .%s receiver", methodName));
1576+
// Slot 0 of the signature is the receiver (self), so user-visible argument i maps to signature slot i + 1.
1577+
for (int i = 0; i < arguments.size(); i++) {
1578+
Expression expression = arguments.get(i);
1579+
Type expectedType = signature.getArgumentTypes().get(i + 1);
1580+
if (resolution.argumentTypes().get(i + 1).hasDependency()) {
1581+
FunctionType expectedFunctionType = (FunctionType) expectedType;
1582+
process(expression, context.expectingLambda(expectedFunctionType.getArgumentTypes()));
1583+
}
1584+
else {
1585+
Type actualType = plannerContext.getTypeManager().getType(resolution.argumentTypes().get(i + 1).getTypeSignature());
1586+
coerceType(expression, actualType, expectedType, format("Method .%s argument %d", methodName, i));
1587+
}
1588+
}
1589+
resolvedFunctions.put(NodeRef.of(node), resolution.function());
1590+
return setExpressionType(node, signature.getReturnType());
1591+
}
1592+
1593+
private record MethodResolution(ResolvedFunction function, List<TypeSignatureProvider> argumentTypes) {}
1594+
14631595
@Override
14641596
protected Type visitStaticMethodCall(StaticMethodCall node, Context context)
14651597
{
@@ -3948,6 +4080,7 @@ private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyze
39484080
analyzer.getSortKeyCoercionsForFrameBoundComparison());
39494081
analysis.addFrameBoundCalculations(analyzer.getFrameBoundCalculations());
39504082
analyzer.getResolvedFunctions().forEach((key, value) -> analysis.addResolvedFunction(key.getNode(), value, session.getUser()));
4083+
analyzer.getMethodCallReceivers().forEach((key, value) -> analysis.addMethodCallReceiver(key.getNode(), value));
39514084
analysis.addColumnReferences(analyzer.getColumnReferences());
39524085
analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences());
39534086
analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences());

0 commit comments

Comments
 (0)