Skip to content

Commit 1bbb3fd

Browse files
committed
Rust: Take nested functions into account when resolving variables
1 parent fe216ae commit 1bbb3fd

File tree

6 files changed

+125
-60
lines changed

6 files changed

+125
-60
lines changed

rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll

+109-44
Original file line numberDiff line numberDiff line change
@@ -397,20 +397,23 @@ module Impl {
397397
)
398398
}
399399

400-
private newtype TVariableOrAccessCand =
401-
TVariableOrAccessCandVariable(Variable v) or
402-
TVariableOrAccessCandVariableAccessCand(VariableAccessCand va)
400+
private newtype TDefOrAccessCand =
401+
TDefOrAccessCandNestedFunction(Function f, BlockExprScope scope) {
402+
f = scope.getStmtList().getAStatement()
403+
} or
404+
TDefOrAccessCandVariable(Variable v) or
405+
TDefOrAccessCandVariableAccessCand(VariableAccessCand va)
403406

404407
/**
405-
* A variable declaration or variable access candidate.
408+
* A nested function declaration, variable declaration, or variable (or function)
409+
* access candidate.
406410
*
407-
* In order to determine whether a candidate is an actual variable access,
408-
* we rank declarations and candidates by their position in source code.
411+
* In order to determine whether a candidate is an actual variable/function access,
412+
* we rank declarations and candidates by their position in the AST.
409413
*
410-
* The ranking must take variable names into account, but also variable scopes;
411-
* below a comment `rank(scope, name, i)` means that the declaration/access on
412-
* the given line has rank `i` amongst all declarations/accesses inside variable
413-
* scope `scope`, for variable name `name`:
414+
* The ranking must take names into account, but also variable scopes; below a comment
415+
* `rank(scope, name, i)` means that the declaration/access on the given line has rank
416+
* `i` amongst all declarations/accesses inside variable scope `scope`, for name `name`:
414417
*
415418
* ```rust
416419
* fn f() { // scope0
@@ -430,8 +433,8 @@ module Impl {
430433
* }
431434
* ```
432435
*
433-
* Variable declarations are only ranked in the scope that they bind into, while
434-
* accesses candidates propagate outwards through scopes, as they may access
436+
* Function/variable declarations are only ranked in the scope that they bind into,
437+
* while accesses candidates propagate outwards through scopes, as they may access
435438
* declarations from outer scopes.
436439
*
437440
* For an access candidate with ranks `{ rank(scope_i, name, rnk_i) | i in I }` and
@@ -448,41 +451,80 @@ module Impl {
448451
* i.e., its the nearest declaration before the access in the same (or outer) scope
449452
* as the access.
450453
*/
451-
private class VariableOrAccessCand extends TVariableOrAccessCand {
452-
Variable asVariable() { this = TVariableOrAccessCandVariable(result) }
454+
abstract private class DefOrAccessCand extends TDefOrAccessCand {
455+
abstract string toString();
453456

454-
VariableAccessCand asVariableAccessCand() {
455-
this = TVariableOrAccessCandVariableAccessCand(result)
456-
}
457+
abstract Location getLocation();
457458

458-
string toString() {
459-
result = this.asVariable().toString() or result = this.asVariableAccessCand().toString()
460-
}
459+
pragma[nomagic]
460+
abstract predicate rankBy(string name, VariableScope scope, int ord, int kind);
461+
}
461462

462-
Location getLocation() {
463-
result = this.asVariable().getLocation() or result = this.asVariableAccessCand().getLocation()
464-
}
463+
abstract private class NestedFunctionOrVariable extends DefOrAccessCand { }
465464

466-
pragma[nomagic]
467-
predicate rankBy(string name, VariableScope scope, int ord, int kind) {
468-
variableDeclInScope(this.asVariable(), scope, name, ord) and
465+
private class DefOrAccessCandNestedFunction extends NestedFunctionOrVariable,
466+
TDefOrAccessCandNestedFunction
467+
{
468+
private Function f;
469+
private BlockExprScope scope_;
470+
471+
DefOrAccessCandNestedFunction() { this = TDefOrAccessCandNestedFunction(f, scope_) }
472+
473+
override string toString() { result = f.toString() }
474+
475+
override Location getLocation() { result = f.getLocation() }
476+
477+
override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
478+
// nested functions behave as if they are defined at the beginning of the scope
479+
name = f.getName().getText() and
480+
scope = scope_ and
481+
ord = 0 and
469482
kind = 0
470-
or
471-
variableAccessCandInScope(this.asVariableAccessCand(), scope, name, _, ord) and
483+
}
484+
}
485+
486+
private class DefOrAccessCandVariable extends NestedFunctionOrVariable, TDefOrAccessCandVariable {
487+
private Variable v;
488+
489+
DefOrAccessCandVariable() { this = TDefOrAccessCandVariable(v) }
490+
491+
override string toString() { result = v.toString() }
492+
493+
override Location getLocation() { result = v.getLocation() }
494+
495+
override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
496+
variableDeclInScope(v, scope, name, ord) and
472497
kind = 1
473498
}
474499
}
475500

501+
private class DefOrAccessCandVariableAccessCand extends DefOrAccessCand,
502+
TDefOrAccessCandVariableAccessCand
503+
{
504+
private VariableAccessCand va;
505+
506+
DefOrAccessCandVariableAccessCand() { this = TDefOrAccessCandVariableAccessCand(va) }
507+
508+
override string toString() { result = va.toString() }
509+
510+
override Location getLocation() { result = va.getLocation() }
511+
512+
override predicate rankBy(string name, VariableScope scope, int ord, int kind) {
513+
variableAccessCandInScope(va, scope, name, _, ord) and
514+
kind = 2
515+
}
516+
}
517+
476518
private module DenseRankInput implements DenseRankInputSig2 {
477519
class C1 = VariableScope;
478520

479521
class C2 = string;
480522

481-
class Ranked = VariableOrAccessCand;
523+
class Ranked = DefOrAccessCand;
482524

483-
int getRank(VariableScope scope, string name, VariableOrAccessCand v) {
525+
int getRank(VariableScope scope, string name, DefOrAccessCand v) {
484526
v =
485-
rank[result](VariableOrAccessCand v0, int ord, int kind |
527+
rank[result](DefOrAccessCand v0, int ord, int kind |
486528
v0.rankBy(name, scope, ord, kind)
487529
|
488530
v0 order by ord, kind
@@ -494,7 +536,7 @@ module Impl {
494536
* Gets the rank of `v` amongst all other declarations or access candidates
495537
* to a variable named `name` in the variable scope `scope`.
496538
*/
497-
private int rankVariableOrAccess(VariableScope scope, string name, VariableOrAccessCand v) {
539+
private int rankVariableOrAccess(VariableScope scope, string name, DefOrAccessCand v) {
498540
v = DenseRank2<DenseRankInput>::denseRank(scope, name, result + 1)
499541
}
500542

@@ -512,25 +554,38 @@ module Impl {
512554
* the declaration at rank 0 can only reach the access at rank 1, while the declaration
513555
* at rank 2 can only reach the access at rank 3.
514556
*/
515-
private predicate variableReachesRank(VariableScope scope, string name, Variable v, int rnk) {
516-
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariable(v))
557+
private predicate variableReachesRank(
558+
VariableScope scope, string name, NestedFunctionOrVariable v, int rnk
559+
) {
560+
rnk = rankVariableOrAccess(scope, name, v)
517561
or
518562
variableReachesRank(scope, name, v, rnk - 1) and
519-
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(_))
563+
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(_))
520564
}
521565

522566
private predicate variableReachesCand(
523-
VariableScope scope, string name, Variable v, VariableAccessCand cand, int nestLevel
567+
VariableScope scope, string name, NestedFunctionOrVariable v, VariableAccessCand cand,
568+
int nestLevel
524569
) {
525570
exists(int rnk |
526571
variableReachesRank(scope, name, v, rnk) and
527-
rnk = rankVariableOrAccess(scope, name, TVariableOrAccessCandVariableAccessCand(cand)) and
572+
rnk = rankVariableOrAccess(scope, name, TDefOrAccessCandVariableAccessCand(cand)) and
528573
variableAccessCandInScope(cand, scope, name, nestLevel, _)
529574
)
530575
}
531576

577+
pragma[nomagic]
578+
predicate access(string name, NestedFunctionOrVariable v, VariableAccessCand cand) {
579+
v =
580+
min(NestedFunctionOrVariable v0, int nestLevel |
581+
variableReachesCand(_, name, v0, cand, nestLevel)
582+
|
583+
v0 order by nestLevel
584+
)
585+
}
586+
532587
/** A variable access. */
533-
class VariableAccess extends PathExprBaseImpl::PathExprBase instanceof VariableAccessCand {
588+
class VariableAccess extends PathExprBaseImpl::PathExprBase {
534589
private string name;
535590
private Variable v;
536591

@@ -574,6 +629,16 @@ module Impl {
574629
}
575630
}
576631

632+
/** A nested function access. */
633+
class NestedFunctionAccess extends PathExprBaseImpl::PathExprBase {
634+
private Function f;
635+
636+
NestedFunctionAccess() { nestedFunctionAccess(_, f, this) }
637+
638+
/** Gets the function being accessed. */
639+
Function getFunction() { result = f }
640+
}
641+
577642
cached
578643
private module Cached {
579644
cached
@@ -582,12 +647,12 @@ module Impl {
582647

583648
cached
584649
predicate variableAccess(string name, Variable v, VariableAccessCand cand) {
585-
v =
586-
min(Variable v0, int nestLevel |
587-
variableReachesCand(_, name, v0, cand, nestLevel)
588-
|
589-
v0 order by nestLevel
590-
)
650+
access(name, TDefOrAccessCandVariable(v), cand)
651+
}
652+
653+
cached
654+
predicate nestedFunctionAccess(string name, Function f, VariableAccessCand cand) {
655+
access(name, TDefOrAccessCandNestedFunction(f, _), cand)
591656
}
592657
}
593658

rust/ql/test/library-tests/variables/Cfg.expected

+5-5
Original file line numberDiff line numberDiff line change
@@ -748,17 +748,17 @@ edges
748748
| main.rs:342:17:342:17 | 2 | main.rs:342:15:342:18 | f(...) | |
749749
| main.rs:344:5:358:5 | { ... } | main.rs:331:22:359:1 | { ... } | |
750750
| main.rs:345:9:345:17 | print_i64 | main.rs:345:19:345:19 | f | |
751-
| main.rs:345:9:345:23 | print_i64(...) | main.rs:345:26:348:9 | fn f | |
751+
| main.rs:345:9:345:23 | print_i64(...) | main.rs:346:9:348:9 | fn f | |
752752
| main.rs:345:9:345:24 | ExprStmt | main.rs:345:9:345:17 | print_i64 | |
753753
| main.rs:345:19:345:19 | f | main.rs:345:21:345:21 | 3 | |
754754
| main.rs:345:19:345:22 | f(...) | main.rs:345:9:345:23 | print_i64(...) | |
755755
| main.rs:345:21:345:21 | 3 | main.rs:345:19:345:22 | f(...) | |
756-
| main.rs:345:26:348:9 | enter fn f | main.rs:346:14:346:14 | x | |
757-
| main.rs:345:26:348:9 | exit fn f (normal) | main.rs:345:26:348:9 | exit fn f | |
758-
| main.rs:345:26:348:9 | fn f | main.rs:350:9:352:9 | ExprStmt | |
756+
| main.rs:346:9:348:9 | enter fn f | main.rs:346:14:346:14 | x | |
757+
| main.rs:346:9:348:9 | exit fn f (normal) | main.rs:346:9:348:9 | exit fn f | |
758+
| main.rs:346:9:348:9 | fn f | main.rs:350:9:352:9 | ExprStmt | |
759759
| main.rs:346:14:346:14 | x | main.rs:346:14:346:19 | ...: i64 | match |
760760
| main.rs:346:14:346:19 | ...: i64 | main.rs:347:13:347:13 | 2 | |
761-
| main.rs:346:29:348:9 | { ... } | main.rs:345:26:348:9 | exit fn f (normal) | |
761+
| main.rs:346:29:348:9 | { ... } | main.rs:346:9:348:9 | exit fn f (normal) | |
762762
| main.rs:347:13:347:13 | 2 | main.rs:347:17:347:17 | x | |
763763
| main.rs:347:13:347:17 | ... * ... | main.rs:346:29:348:9 | { ... } | |
764764
| main.rs:347:17:347:17 | x | main.rs:347:13:347:17 | ... * ... | |

rust/ql/test/library-tests/variables/Ssa.expected

+1-5
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ read
239239
| main.rs:326:9:326:10 | n2 | main.rs:326:9:326:10 | n2 | main.rs:328:15:328:16 | n2 |
240240
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:336:15:336:15 | f |
241241
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f |
242-
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:345:19:345:19 | f |
243-
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:351:23:351:23 | f |
244242
| main.rs:334:10:334:10 | x | main.rs:334:10:334:10 | x | main.rs:335:9:335:9 | x |
245243
| main.rs:338:10:338:10 | x | main.rs:338:10:338:10 | x | main.rs:339:9:339:9 | x |
246244
| main.rs:346:14:346:14 | x | main.rs:346:14:346:14 | x | main.rs:347:17:347:17 | x |
@@ -497,7 +495,7 @@ lastRead
497495
| main.rs:323:9:323:26 | immutable_variable | main.rs:323:9:323:26 | immutable_variable | main.rs:327:9:327:26 | immutable_variable |
498496
| main.rs:324:10:324:10 | x | main.rs:324:10:324:10 | x | main.rs:325:9:325:9 | x |
499497
| main.rs:326:9:326:10 | n2 | main.rs:326:9:326:10 | n2 | main.rs:328:15:328:16 | n2 |
500-
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:351:23:351:23 | f |
498+
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f |
501499
| main.rs:334:10:334:10 | x | main.rs:334:10:334:10 | x | main.rs:335:9:335:9 | x |
502500
| main.rs:338:10:338:10 | x | main.rs:338:10:338:10 | x | main.rs:339:9:339:9 | x |
503501
| main.rs:346:14:346:14 | x | main.rs:346:14:346:14 | x | main.rs:347:17:347:17 | x |
@@ -560,8 +558,6 @@ adjacentReads
560558
| main.rs:290:9:290:11 | a10 | main.rs:279:13:279:15 | a10 | main.rs:292:9:292:11 | a10 | main.rs:296:15:296:17 | a10 |
561559
| main.rs:290:9:290:11 | a10 | main.rs:279:13:279:15 | a10 | main.rs:296:15:296:17 | a10 | main.rs:310:15:310:17 | a10 |
562560
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:336:15:336:15 | f | main.rs:342:15:342:15 | f |
563-
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:342:15:342:15 | f | main.rs:345:19:345:19 | f |
564-
| main.rs:333:9:333:9 | f | main.rs:333:9:333:9 | f | main.rs:345:19:345:19 | f | main.rs:351:23:351:23 | f |
565561
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:387:6:387:6 | x | main.rs:388:10:388:10 | x |
566562
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:388:10:388:10 | x | main.rs:389:10:389:10 | x |
567563
| main.rs:386:17:386:17 | x | main.rs:386:17:386:17 | x | main.rs:389:10:389:10 | x | main.rs:390:12:390:12 | x |

rust/ql/test/library-tests/variables/main.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,13 @@ fn nested_function() {
342342
print_i64(f(2)); // $ read_access=f1
343343

344344
{
345-
print_i64(f(3)); // $ SPURIOUS: read_access=f1
345+
print_i64(f(3));
346346
fn f(x: i64) -> i64 { // x_3
347347
2 * x // $ read_access=x_3
348348
}
349349

350350
{
351-
print_i64(f(4)); // $ SPURIOUS: read_access=f1
351+
print_i64(f(4));
352352
}
353353

354354
let f = // f2

rust/ql/test/library-tests/variables/variables.expected

+3-4
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,7 @@ variableAccess
206206
| main.rs:336:15:336:15 | f | main.rs:333:9:333:9 | f |
207207
| main.rs:339:9:339:9 | x | main.rs:338:10:338:10 | x |
208208
| main.rs:342:15:342:15 | f | main.rs:333:9:333:9 | f |
209-
| main.rs:345:19:345:19 | f | main.rs:333:9:333:9 | f |
210209
| main.rs:347:17:347:17 | x | main.rs:346:14:346:14 | x |
211-
| main.rs:351:23:351:23 | f | main.rs:333:9:333:9 | f |
212210
| main.rs:356:13:356:13 | x | main.rs:355:14:355:14 | x |
213211
| main.rs:357:19:357:19 | f | main.rs:354:13:354:13 | f |
214212
| main.rs:365:12:365:12 | v | main.rs:362:9:362:9 | v |
@@ -393,9 +391,7 @@ variableReadAccess
393391
| main.rs:336:15:336:15 | f | main.rs:333:9:333:9 | f |
394392
| main.rs:339:9:339:9 | x | main.rs:338:10:338:10 | x |
395393
| main.rs:342:15:342:15 | f | main.rs:333:9:333:9 | f |
396-
| main.rs:345:19:345:19 | f | main.rs:333:9:333:9 | f |
397394
| main.rs:347:17:347:17 | x | main.rs:346:14:346:14 | x |
398-
| main.rs:351:23:351:23 | f | main.rs:333:9:333:9 | f |
399395
| main.rs:356:13:356:13 | x | main.rs:355:14:355:14 | x |
400396
| main.rs:357:19:357:19 | f | main.rs:354:13:354:13 | f |
401397
| main.rs:365:12:365:12 | v | main.rs:362:9:362:9 | v |
@@ -542,3 +538,6 @@ capturedAccess
542538
| main.rs:459:9:459:9 | z |
543539
| main.rs:468:9:468:9 | i |
544540
| main.rs:523:13:523:16 | self |
541+
nestedFunctionAccess
542+
| main.rs:345:19:345:19 | f | main.rs:346:9:348:9 | fn f |
543+
| main.rs:351:23:351:23 | f | main.rs:346:9:348:9 | fn f |

rust/ql/test/library-tests/variables/variables.ql

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import rust
22
import utils.test.InlineExpectationsTest
3+
import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl
34

45
query predicate variable(Variable v) { any() }
56

@@ -15,6 +16,10 @@ query predicate capturedVariable(Variable v) { v.isCaptured() }
1516

1617
query predicate capturedAccess(VariableAccess va) { va.isCapture() }
1718

19+
query predicate nestedFunctionAccess(VariableImpl::NestedFunctionAccess nfa, Function f) {
20+
f = nfa.getFunction()
21+
}
22+
1823
module VariableAccessTest implements TestSig {
1924
string getARelevantTag() { result = ["", "write_", "read_"] + "access" }
2025

0 commit comments

Comments
 (0)