Skip to content

Commit

Permalink
Rust: Take nested functions into account when resolving variables
Browse files Browse the repository at this point in the history
  • Loading branch information
hvitved committed Jan 13, 2025
1 parent 492f4e9 commit 50f80f2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 60 deletions.
153 changes: 109 additions & 44 deletions rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -397,20 +397,23 @@ module Impl {
)
}

private newtype TVariableOrAccessCand =
TVariableOrAccessCandVariable(Variable v) or
TVariableOrAccessCandVariableAccessCand(VariableAccessCand va)
private newtype TDefOrAccessCand =
TDefOrAccessCandNestedFunction(Function f, BlockExprScope scope) {
f = scope.getStmtList().getAStatement()
} or
TDefOrAccessCandVariable(Variable v) or
TDefOrAccessCandVariableAccessCand(VariableAccessCand va)

/**
* A variable declaration or variable access candidate.
* A nested function declaration, variable declaration, or variable (or function)
* access candidate.
*
* In order to determine whether a candidate is an actual variable access,
* we rank declarations and candidates by their position in source code.
* In order to determine whether a candidate is an actual variable/function access,
* we rank declarations and candidates by their position in the AST.
*
* The ranking must take variable names into account, but also variable scopes;
* below a comment `rank(scope, name, i)` means that the declaration/access on
* the given line has rank `i` amongst all declarations/accesses inside variable
* scope `scope`, for variable name `name`:
* The ranking must take names into account, but also variable scopes; below a comment
* `rank(scope, name, i)` means that the declaration/access on the given line has rank
* `i` amongst all declarations/accesses inside variable scope `scope`, for name `name`:
*
* ```rust
* fn f() { // scope0
Expand All @@ -430,8 +433,8 @@ module Impl {
* }
* ```
*
* Variable declarations are only ranked in the scope that they bind into, while
* accesses candidates propagate outwards through scopes, as they may access
* Function/variable declarations are only ranked in the scope that they bind into,
* while accesses candidates propagate outwards through scopes, as they may access
* declarations from outer scopes.
*
* For an access candidate with ranks `{ rank(scope_i, name, rnk_i) | i in I }` and
Expand All @@ -448,41 +451,80 @@ module Impl {
* i.e., its the nearest declaration before the access in the same (or outer) scope
* as the access.
*/
private class VariableOrAccessCand extends TVariableOrAccessCand {
Variable asVariable() { this = TVariableOrAccessCandVariable(result) }
abstract private class DefOrAccessCand extends TDefOrAccessCand {
abstract string toString();

VariableAccessCand asVariableAccessCand() {
this = TVariableOrAccessCandVariableAccessCand(result)
}
abstract Location getLocation();

string toString() {
result = this.asVariable().toString() or result = this.asVariableAccessCand().toString()
}
pragma[nomagic]
abstract predicate rankBy(string name, VariableScope scope, int ord, int kind);
}

Location getLocation() {
result = this.asVariable().getLocation() or result = this.asVariableAccessCand().getLocation()
}
abstract private class NestedFunctionOrVariable extends DefOrAccessCand { }

pragma[nomagic]
predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableDeclInScope(this.asVariable(), scope, name, ord) and
private class DefOrAccessCandNestedFunction extends NestedFunctionOrVariable,
TDefOrAccessCandNestedFunction
{
private Function f;
private BlockExprScope scope_;

DefOrAccessCandNestedFunction() { this = TDefOrAccessCandNestedFunction(f, scope_) }

override string toString() { result = f.toString() }

override Location getLocation() { result = f.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
// nested functions behave as if they are defined at the beginning of the scope
name = f.getName().getText() and
scope = scope_ and
ord = 0 and
kind = 0
or
variableAccessCandInScope(this.asVariableAccessCand(), scope, name, _, ord) and
}
}

private class DefOrAccessCandVariable extends NestedFunctionOrVariable, TDefOrAccessCandVariable {
private Variable v;

DefOrAccessCandVariable() { this = TDefOrAccessCandVariable(v) }

override string toString() { result = v.toString() }

override Location getLocation() { result = v.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableDeclInScope(v, scope, name, ord) and
kind = 1
}
}

private class DefOrAccessCandVariableAccessCand extends DefOrAccessCand,
TDefOrAccessCandVariableAccessCand
{
private VariableAccessCand va;

DefOrAccessCandVariableAccessCand() { this = TDefOrAccessCandVariableAccessCand(va) }

override string toString() { result = va.toString() }

override Location getLocation() { result = va.getLocation() }

override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
variableAccessCandInScope(va, scope, name, _, ord) and
kind = 2
}
}

private module DenseRankInput implements DenseRankInputSig2 {
class C1 = VariableScope;

class C2 = string;

class Ranked = VariableOrAccessCand;
class Ranked = DefOrAccessCand;

int getRank(VariableScope scope, string name, VariableOrAccessCand v) {
int getRank(VariableScope scope, string name, DefOrAccessCand v) {
v =
rank[result](VariableOrAccessCand v0, int ord, int kind |
rank[result](DefOrAccessCand v0, int ord, int kind |
v0.rankBy(name, scope, ord, kind)
|
v0 order by ord, kind
Expand All @@ -494,7 +536,7 @@ module Impl {
* Gets the rank of `v` amongst all other declarations or access candidates
* to a variable named `name` in the variable scope `scope`.
*/
private int rankVariableOrAccess(VariableScope scope, string name, VariableOrAccessCand v) {
private int rankVariableOrAccess(VariableScope scope, string name, DefOrAccessCand v) {
v = DenseRank2<DenseRankInput>::denseRank(scope, name, result + 1)
}

Expand All @@ -512,25 +554,38 @@ module Impl {
* the declaration at rank 0 can only reach the access at rank 1, while the declaration
* at rank 2 can only reach the access at rank 3.
*/
private predicate variableReachesRank(VariableScope scope, string name, Variable v, int rnk) {
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariable(v))
private predicate variableReachesRank(
VariableScope scope, string name, NestedFunctionOrVariable v, int rnk
) {
rnk = rankVariableOrAccess(scope, name, v)
or
variableReachesRank(scope, name, v, rnk - 1) and
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(_))
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(_))
}

private predicate variableReachesCand(
VariableScope scope, string name, Variable v, VariableAccessCand cand, int nestLevel
VariableScope scope, string name, NestedFunctionOrVariable v, VariableAccessCand cand,
int nestLevel
) {
exists(int rnk |
variableReachesRank(scope, name, v, rnk) and
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(cand)) and
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(cand)) and
variableAccessCandInScope(cand, scope, name, nestLevel, _)
)
}

pragma[nomagic]
predicate access(string name, NestedFunctionOrVariable v, VariableAccessCand cand) {
v =
min(NestedFunctionOrVariable v0, int nestLevel |
variableReachesCand(_, name, v0, cand, nestLevel)
|
v0 order by nestLevel
)
}

/** A variable access. */
class VariableAccess extends PathExprBaseImpl::PathExprBase instanceof VariableAccessCand {
class VariableAccess extends PathExprBaseImpl::PathExprBase {
private string name;
private Variable v;

Expand Down Expand Up @@ -574,6 +629,16 @@ module Impl {
}
}

/** A variable access. */
class NestedFunctionAccess extends PathExprBaseImpl::PathExprBase {
private Function f;

NestedFunctionAccess() { nestedFunctionAccess(_, f, this) }

/** Gets the function being accessed. */
Function getFunction() { result = f }
}

cached
private module Cached {
cached
Expand All @@ -582,12 +647,12 @@ module Impl {

cached
predicate variableAccess(string name, Variable v, VariableAccessCand cand) {
v =
min(Variable v0, int nestLevel |
variableReachesCand(_, name, v0, cand, nestLevel)
|
v0 order by nestLevel
)
access(name, TDefOrAccessCandVariable(v), cand)
}

cached
predicate nestedFunctionAccess(string name, Function f, VariableAccessCand cand) {
access(name, TDefOrAccessCandNestedFunction(f, _), cand)
}
}

Expand Down
10 changes: 5 additions & 5 deletions rust/ql/test/library-tests/variables/Cfg.expected
Original file line number Diff line number Diff line change
Expand Up @@ -748,17 +748,17 @@ edges
| main.rs:342:17:342:17 | 2 | main.rs:342:15:342:18 | f(...) | |
| main.rs:344:5:358:5 | { ... } | main.rs:331:22:359:1 | { ... } | |
| main.rs:345:9:345:17 | print_i64 | main.rs:345:19:345:19 | f | |
| main.rs:345:9:345:23 | print_i64(...) | main.rs:345:26:348:9 | fn f | |
| main.rs:345:9:345:23 | print_i64(...) | main.rs:346:9:348:9 | fn f | |
| main.rs:345:9:345:24 | ExprStmt | main.rs:345:9:345:17 | print_i64 | |
| main.rs:345:19:345:19 | f | main.rs:345:21:345:21 | 3 | |
| main.rs:345:19:345:22 | f(...) | main.rs:345:9:345:23 | print_i64(...) | |
| main.rs:345:21:345:21 | 3 | main.rs:345:19:345:22 | f(...) | |
| main.rs:345:26:348:9 | enter fn f | main.rs:346:14:346:14 | x | |
| main.rs:345:26:348:9 | exit fn f (normal) | main.rs:345:26:348:9 | exit fn f | |
| main.rs:345:26:348:9 | fn f | main.rs:350:9:352:9 | ExprStmt | |
| main.rs:346:9:348:9 | enter fn f | main.rs:346:14:346:14 | x | |
| main.rs:346:9:348:9 | exit fn f (normal) | main.rs:346:9:348:9 | exit fn f | |
| main.rs:346:9:348:9 | fn f | main.rs:350:9:352:9 | ExprStmt | |
| main.rs:346:14:346:14 | x | main.rs:346:14:346:19 | ...: i64 | match |
| main.rs:346:14:346:19 | ...: i64 | main.rs:347:13:347:13 | 2 | |
| main.rs:346:29:348:9 | { ... } | main.rs:345:26:348:9 | exit fn f (normal) | |
| main.rs:346:29:348:9 | { ... } | main.rs:346:9:348:9 | exit fn f (normal) | |
| main.rs:347:13:347:13 | 2 | main.rs:347:17:347:17 | x | |
| main.rs:347:13:347:17 | ... * ... | main.rs:346:29:348:9 | { ... } | |
| main.rs:347:17:347:17 | x | main.rs:347:13:347:17 | ... * ... | |
Expand Down
6 changes: 1 addition & 5 deletions rust/ql/test/library-tests/variables/Ssa.expected
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ read
| main.rs:326:9:326:10 | n2 | main.rs:326:9:326:10 | n2 | main.rs:328:15:328:16 | n2 |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:336:15:336:15 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:345:19:345:19 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:351:23:351:23 | f |
| main.rs:334:10:334:10 | x | main.rs:334:10:334:10 | x | main.rs:335:9:335:9 | x |
| main.rs:338:10:338:10 | x | main.rs:338:10:338:10 | x | main.rs:339:9:339:9 | x |
| main.rs:346:14:346:14 | x | main.rs:346:14:346:14 | x | main.rs:347:17:347:17 | x |
Expand Down Expand Up @@ -497,7 +495,7 @@ lastRead
| main.rs:323:9:323:26 | immutable_variable | main.rs:323:9:323:26 | immutable_variable | main.rs:327:9:327:26 | immutable_variable |
| main.rs:324:10:324:10 | x | main.rs:324:10:324:10 | x | main.rs:325:9:325:9 | x |
| main.rs:326:9:326:10 | n2 | main.rs:326:9:326:10 | n2 | main.rs:328:15:328:16 | n2 |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:351:23:351:23 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f |
| main.rs:334:10:334:10 | x | main.rs:334:10:334:10 | x | main.rs:335:9:335:9 | x |
| main.rs:338:10:338:10 | x | main.rs:338:10:338:10 | x | main.rs:339:9:339:9 | x |
| main.rs:346:14:346:14 | x | main.rs:346:14:346:14 | x | main.rs:347:17:347:17 | x |
Expand Down Expand Up @@ -560,8 +558,6 @@ adjacentReads
| main.rs:290:9:290:11 | a10 | main.rs:279:13:279:15 | a10 | main.rs:292:9:292:11 | a10 | main.rs:296:15:296:17 | a10 |
| main.rs:290:9:290:11 | a10 | main.rs:279:13:279:15 | a10 | main.rs:296:15:296:17 | a10 | main.rs:310:15:310:17 | a10 |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:336:15:336:15 | f | main.rs:342:15:342:15 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f | main.rs:345:19:345:19 | f |
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:345:19:345:19 | f | main.rs:351:23:351:23 | f |
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:387:6:387:6 | x | main.rs:388:10:388:10 | x |
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:388:10:388:10 | x | main.rs:389:10:389:10 | x |
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:389:10:389:10 | x | main.rs:390:12:390:12 | x |
Expand Down
4 changes: 2 additions & 2 deletions rust/ql/test/library-tests/variables/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,13 @@ fn nested_function() {
print_i64(f(2)); // $ read_access=f1

{
print_i64(f(3)); // $ SPURIOUS: read_access=f1
print_i64(f(3));
fn f(x: i64) -> i64 { // x_3
2 * x // $ read_access=x_3
}

{
print_i64(f(4)); // $ SPURIOUS: read_access=f1
print_i64(f(4));
}

let f = // f2
Expand Down
7 changes: 3 additions & 4 deletions rust/ql/test/library-tests/variables/variables.expected
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,7 @@ variableAccess
| main.rs:336:15:336:15 | f | main.rs:333:9:333:9 | f |
| main.rs:339:9:339:9 | x | main.rs:338:10:338:10 | x |
| main.rs:342:15:342:15 | f | main.rs:333:9:333:9 | f |
| main.rs:345:19:345:19 | f | main.rs:333:9:333:9 | f |
| main.rs:347:17:347:17 | x | main.rs:346:14:346:14 | x |
| main.rs:351:23:351:23 | f | main.rs:333:9:333:9 | f |
| main.rs:356:13:356:13 | x | main.rs:355:14:355:14 | x |
| main.rs:357:19:357:19 | f | main.rs:354:13:354:13 | f |
| main.rs:365:12:365:12 | v | main.rs:362:9:362:9 | v |
Expand Down Expand Up @@ -393,9 +391,7 @@ variableReadAccess
| main.rs:336:15:336:15 | f | main.rs:333:9:333:9 | f |
| main.rs:339:9:339:9 | x | main.rs:338:10:338:10 | x |
| main.rs:342:15:342:15 | f | main.rs:333:9:333:9 | f |
| main.rs:345:19:345:19 | f | main.rs:333:9:333:9 | f |
| main.rs:347:17:347:17 | x | main.rs:346:14:346:14 | x |
| main.rs:351:23:351:23 | f | main.rs:333:9:333:9 | f |
| main.rs:356:13:356:13 | x | main.rs:355:14:355:14 | x |
| main.rs:357:19:357:19 | f | main.rs:354:13:354:13 | f |
| main.rs:365:12:365:12 | v | main.rs:362:9:362:9 | v |
Expand Down Expand Up @@ -542,3 +538,6 @@ capturedAccess
| main.rs:459:9:459:9 | z |
| main.rs:468:9:468:9 | i |
| main.rs:523:13:523:16 | self |
nestedFunctionAccess
| main.rs:345:19:345:19 | f | main.rs:346:9:348:9 | fn f |
| main.rs:351:23:351:23 | f | main.rs:346:9:348:9 | fn f |
5 changes: 5 additions & 0 deletions rust/ql/test/library-tests/variables/variables.ql
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import rust
import utils.test.InlineExpectationsTest
import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl

query predicate variable(Variable v) { any() }

Expand All @@ -15,6 +16,10 @@ query predicate capturedVariable(Variable v) { v.isCaptured() }

query predicate capturedAccess(VariableAccess va) { va.isCapture() }

query predicate nestedFunctionAccess(VariableImpl::NestedFunctionAccess nfa, Function f) {
f = nfa.getFunction()
}

module VariableAccessTest implements TestSig {
string getARelevantTag() { result = ["", "write_", "read_"] + "access" }

Expand Down

0 comments on commit 50f80f2

Please sign in to comment.