Skip to content

Commit e778cbe

Browse files
committed
Rust: Resolve function calls to traits methods
1 parent e4d1b01 commit e778cbe

File tree

8 files changed

+238
-123
lines changed

8 files changed

+238
-123
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallExprImpl.qll

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ private import codeql.rust.elements.PathExpr
1414
module Impl {
1515
private import rust
1616
private import codeql.rust.internal.PathResolution as PathResolution
17+
private import codeql.rust.internal.TypeInference as TypeInference
1718

1819
pragma[nomagic]
1920
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
@@ -36,7 +37,14 @@ module Impl {
3637
class CallExpr extends Generated::CallExpr {
3738
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }
3839

39-
override Callable getStaticTarget() { result = getResolvedFunction(this) }
40+
override Callable getStaticTarget() {
41+
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
42+
// if type inference can resolve it to the correct trait implementation.
43+
result = TypeInference::resolveMethodCallTarget(this)
44+
or
45+
not exists(TypeInference::resolveMethodCallTarget(this)) and
46+
result = getResolvedFunction(this)
47+
}
4048

4149
/** Gets the struct that this call resolves to, if any. */
4250
Struct getStruct() { result = getResolvedFunction(this) }

rust/ql/lib/codeql/rust/elements/internal/MethodCallExprImpl.qll

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@ private import codeql.rust.internal.TypeInference
1414
* be referenced directly.
1515
*/
1616
module Impl {
17-
private predicate isInherentImplFunction(Function f) {
18-
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
19-
}
20-
21-
private predicate isTraitImplFunction(Function f) {
22-
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
23-
}
24-
2517
// the following QLdoc is generated: if you need to edit it, do it in the schema file
2618
/**
2719
* A method call expression. For example:
@@ -31,38 +23,7 @@ module Impl {
3123
* ```
3224
*/
3325
class MethodCallExpr extends Generated::MethodCallExpr {
34-
private Function getStaticTargetFrom(boolean fromSource) {
35-
result = resolveMethodCallExpr(this) and
36-
(if result.fromSource() then fromSource = true else fromSource = false) and
37-
(
38-
// prioritize inherent implementation methods first
39-
isInherentImplFunction(result)
40-
or
41-
not isInherentImplFunction(resolveMethodCallExpr(this)) and
42-
(
43-
// then trait implementation methods
44-
isTraitImplFunction(result)
45-
or
46-
not isTraitImplFunction(resolveMethodCallExpr(this)) and
47-
(
48-
// then trait methods with default implementations
49-
result.hasBody()
50-
or
51-
// and finally trait methods without default implementations
52-
not resolveMethodCallExpr(this).hasBody()
53-
)
54-
)
55-
)
56-
}
57-
58-
override Function getStaticTarget() {
59-
// Functions in source code also gets extracted as library code, due to
60-
// this duplication we prioritize functions from source code.
61-
result = this.getStaticTargetFrom(true)
62-
or
63-
not exists(this.getStaticTargetFrom(true)) and
64-
result = this.getStaticTargetFrom(false)
65-
}
26+
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }
6627

6728
private string toStringPart(int index) {
6829
index = 0 and

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 179 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
678678
Declaration getTarget() {
679679
result = CallExprImpl::getResolvedFunction(this)
680680
or
681-
result = resolveMethodCallExpr(this) // mutual recursion; resolving method calls requires resolving types and vice versa
681+
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
682682
}
683683
}
684684

@@ -1000,6 +1000,150 @@ private StructType inferLiteralType(LiteralExpr le) {
10001000
)
10011001
}
10021002

1003+
private module MethodCall {
1004+
/** An expression that calls a method. */
1005+
abstract private class MethodCallImpl extends Expr {
1006+
/** Gets the name of the method targeted. */
1007+
abstract string getMethodName();
1008+
1009+
/** Gets the number of arguments _excluding_ the `self` argument. */
1010+
abstract int getArity();
1011+
1012+
/** Gets the trait targeted by this method call, if any. */
1013+
Trait getTrait() { none() }
1014+
1015+
/** Gets the type of the receiver of the method call at `path`. */
1016+
abstract Type getTypeAt(TypePath path);
1017+
}
1018+
1019+
final class MethodCall = MethodCallImpl;
1020+
1021+
private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
1022+
override string getMethodName() { result = super.getIdentifier().getText() }
1023+
1024+
override int getArity() { result = super.getArgList().getNumberOfArgs() }
1025+
1026+
pragma[nomagic]
1027+
override Type getTypeAt(TypePath path) {
1028+
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
1029+
path0.isCons(TRefTypeParameter(), path)
1030+
or
1031+
not path0.isCons(TRefTypeParameter(), _) and
1032+
not (path0.isEmpty() and result = TRefType()) and
1033+
path = path0
1034+
)
1035+
}
1036+
}
1037+
1038+
private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
1039+
TraitItemNode trait;
1040+
string methodName;
1041+
Expr receiver;
1042+
1043+
CallExprMethodCall() {
1044+
receiver = this.getArgList().getArg(0) and
1045+
exists(Path path, Function f |
1046+
path = this.getFunction().(PathExpr).getPath() and
1047+
f = resolvePath(path) and
1048+
f.getParamList().hasSelfParam() and
1049+
trait = resolvePath(path.getQualifier()) and
1050+
trait.getAnAssocItem() = f and
1051+
path.getSegment().getIdentifier().getText() = methodName
1052+
)
1053+
}
1054+
1055+
override string getMethodName() { result = methodName }
1056+
1057+
override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }
1058+
1059+
override Trait getTrait() { result = trait }
1060+
1061+
pragma[nomagic]
1062+
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
1063+
}
1064+
}
1065+
1066+
import MethodCall
1067+
1068+
/**
1069+
* Holds if a method for `type` with the name `name` and the arity `arity`
1070+
* exists in `impl`.
1071+
*/
1072+
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
1073+
type = impl.getSelfTy().(TypeMention).resolveType() and
1074+
exists(Function f |
1075+
f = impl.(ImplItemNode).getASuccessor(name) and
1076+
f.getParamList().hasSelfParam() and
1077+
arity = f.getParamList().getNumberOfParams()
1078+
)
1079+
}
1080+
1081+
/**
1082+
* Holds if a method for `type` for `trait` with the name `name` and the arity
1083+
* `arity` exists in `impl`.
1084+
*/
1085+
pragma[nomagic]
1086+
private predicate methodCandidateTrait(Type type, Trait trait, string name, int arity, Impl impl) {
1087+
trait = resolvePath(impl.getTrait().(PathTypeRepr).getPath()) and
1088+
methodCandidate(type, name, arity, impl)
1089+
}
1090+
1091+
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
1092+
pragma[nomagic]
1093+
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
1094+
exists(Type rootType, string name, int arity |
1095+
rootType = mc.getTypeAt(TypePath::nil()) and
1096+
name = mc.getMethodName() and
1097+
arity = mc.getArity() and
1098+
constraint = impl.(ImplTypeAbstraction).getSelfTy()
1099+
|
1100+
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
1101+
or
1102+
not exists(mc.getTrait()) and
1103+
methodCandidate(rootType, name, arity, impl)
1104+
)
1105+
}
1106+
1107+
predicate relevantTypeMention(TypeMention constraint) {
1108+
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
1109+
}
1110+
}
1111+
1112+
bindingset[item, name]
1113+
pragma[inline_late]
1114+
private Function getMethodSuccessor(ItemNode item, string name) {
1115+
result = item.getASuccessor(name)
1116+
}
1117+
1118+
bindingset[tp, name]
1119+
pragma[inline_late]
1120+
private Function getTypeParameterMethod(TypeParameter tp, string name) {
1121+
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
1122+
or
1123+
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
1124+
}
1125+
1126+
/** Gets a method from an `impl` block that matches the method call `mc`. */
1127+
private Function getMethodFromImpl(MethodCall mc) {
1128+
exists(Impl impl |
1129+
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
1130+
result = getMethodSuccessor(impl, mc.getMethodName())
1131+
)
1132+
}
1133+
1134+
/**
1135+
* Gets a method that the method call `mc` resolves to based on type inference,
1136+
* if any.
1137+
*/
1138+
private Function inferMethodCallTarget(MethodCall mc) {
1139+
// The method comes from an `impl` block targeting the type of the receiver.
1140+
result = getMethodFromImpl(mc)
1141+
or
1142+
// The type of the receiver is a type parameter and the method comes from a
1143+
// trait bound on the type parameter.
1144+
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
1145+
}
1146+
10031147
cached
10041148
private module Cached {
10051149
private import codeql.rust.internal.CachedStages
@@ -1026,90 +1170,47 @@ private module Cached {
10261170
)
10271171
}
10281172

1029-
private class ReceiverExpr extends Expr {
1030-
MethodCallExpr mce;
1031-
1032-
ReceiverExpr() { mce.getReceiver() = this }
1033-
1034-
string getField() { result = mce.getIdentifier().getText() }
1035-
1036-
int getNumberOfArgs() { result = mce.getArgList().getNumberOfArgs() }
1037-
1038-
pragma[nomagic]
1039-
Type getTypeAt(TypePath path) {
1040-
exists(TypePath path0 | result = inferType(this, path0) |
1041-
path0.isCons(TRefTypeParameter(), path)
1042-
or
1043-
not path0.isCons(TRefTypeParameter(), _) and
1044-
not (path0.isEmpty() and result = TRefType()) and
1045-
path = path0
1046-
)
1047-
}
1048-
}
1049-
1050-
/** Holds if a method for `type` with the name `name` and the arity `arity` exists in `impl`. */
1051-
pragma[nomagic]
1052-
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
1053-
type = impl.getSelfTy().(TypeReprMention).resolveType() and
1054-
exists(Function f |
1055-
f = impl.(ImplItemNode).getASuccessor(name) and
1056-
f.getParamList().hasSelfParam() and
1057-
arity = f.getParamList().getNumberOfParams()
1058-
)
1059-
}
1060-
1061-
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<ReceiverExpr> {
1062-
pragma[nomagic]
1063-
predicate potentialInstantiationOf(
1064-
ReceiverExpr receiver, TypeAbstraction impl, TypeMention constraint
1065-
) {
1066-
methodCandidate(receiver.getTypeAt(TypePath::nil()), receiver.getField(),
1067-
receiver.getNumberOfArgs(), impl) and
1068-
constraint = impl.(ImplTypeAbstraction).getSelfTy()
1069-
}
1070-
1071-
predicate relevantTypeMention(TypeMention constraint) {
1072-
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
1073-
}
1074-
}
1075-
1076-
bindingset[item, name]
1077-
pragma[inline_late]
1078-
private Function getMethodSuccessor(ItemNode item, string name) {
1079-
result = item.getASuccessor(name)
1173+
private predicate isInherentImplFunction(Function f) {
1174+
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
10801175
}
10811176

1082-
bindingset[tp, name]
1083-
pragma[inline_late]
1084-
private Function getTypeParameterMethod(TypeParameter tp, string name) {
1085-
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
1086-
or
1087-
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
1177+
private predicate isTraitImplFunction(Function f) {
1178+
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
10881179
}
10891180

1090-
/**
1091-
* Gets the method from an `impl` block with an implementing type that matches
1092-
* the type of `receiver` and with a name of the method call in which
1093-
* `receiver` occurs, if any.
1094-
*/
1095-
private Function getMethodFromImpl(ReceiverExpr receiver) {
1096-
exists(Impl impl |
1097-
IsInstantiationOf<ReceiverExpr, IsInstantiationOfInput>::isInstantiationOf(receiver, impl, _) and
1098-
result = getMethodSuccessor(impl, receiver.getField())
1181+
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
1182+
result = inferMethodCallTarget(mc) and
1183+
(if result.fromSource() then fromSource = true else fromSource = false) and
1184+
(
1185+
// prioritize inherent implementation methods first
1186+
isInherentImplFunction(result)
1187+
or
1188+
not isInherentImplFunction(inferMethodCallTarget(mc)) and
1189+
(
1190+
// then trait implementation methods
1191+
isTraitImplFunction(result)
1192+
or
1193+
not isTraitImplFunction(inferMethodCallTarget(mc)) and
1194+
(
1195+
// then trait methods with default implementations
1196+
result.hasBody()
1197+
or
1198+
// and finally trait methods without default implementations
1199+
not inferMethodCallTarget(mc).hasBody()
1200+
)
1201+
)
10991202
)
11001203
}
11011204

1102-
/** Gets a method that the method call `mce` resolves to, if any. */
1205+
/** Gets a method that the method call `mc` resolves to, if any. */
11031206
cached
1104-
Function resolveMethodCallExpr(MethodCallExpr mce) {
1105-
exists(ReceiverExpr receiver | mce.getReceiver() = receiver |
1106-
// The method comes from an `impl` block targeting the type of `receiver`.
1107-
result = getMethodFromImpl(receiver)
1108-
or
1109-
// The type of `receiver` is a type parameter and the method comes from a
1110-
// trait bound on the type parameter.
1111-
result = getTypeParameterMethod(receiver.getTypeAt(TypePath::nil()), receiver.getField())
1112-
)
1207+
Function resolveMethodCallTarget(MethodCall mc) {
1208+
// Functions in source code also gets extracted as library code, due to
1209+
// this duplication we prioritize functions from source code.
1210+
result = resolveMethodCallTargetFrom(mc, true)
1211+
or
1212+
not exists(resolveMethodCallTargetFrom(mc, true)) and
1213+
result = resolveMethodCallTargetFrom(mc, false)
11131214
}
11141215

11151216
pragma[inline]
@@ -1243,6 +1344,6 @@ private module Debug {
12431344

12441345
Function debugResolveMethodCallExpr(MethodCallExpr mce) {
12451346
mce = getRelevantLocatable() and
1246-
result = resolveMethodCallExpr(mce)
1347+
result = resolveMethodCallTarget(mce)
12471348
}
12481349
}

0 commit comments

Comments
 (0)