Skip to content

Commit da37b5e

Browse files
committed
Follow re-export chains when checking call safety
Re-exported symbols (e.g. bar.Foo re-exporting foo.Foo) were not being resolved to their original definitions during safety analysis. This caused check_call to miss unsafe constructors/functions accessed through re-exports, incorrectly marking them as safe. Change re_exports from a HashSet to a HashMap that preserves the mapping to original names, add resolve_re_export() that follows chains, and apply it in check_call and the ImportedTypeAttr property lookup.
1 parent 26fb058 commit da37b5e

2 files changed

Lines changed: 120 additions & 10 deletions

File tree

src/project.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ fn merge_all_functions_and_methods(
174174
)
175175
}
176176

177-
fn get_all_safe_re_exports(effect_table: &EffectTable, re_exports: &mut AHashSet<ModuleName>) {
177+
fn get_all_safe_re_exports(
178+
effect_table: &EffectTable,
179+
re_exports: &mut AHashMap<ModuleName, ModuleName>,
180+
) {
178181
let unsafe_re_exports = effect_table
179182
.values()
180183
.flatten()
@@ -790,7 +793,7 @@ struct ProjectInfo {
790793
classes: ClassTable,
791794
// Mappings of functions to the containing module
792795
functions: AHashMap<ModuleName, ModuleName>,
793-
re_exports: AHashSet<ModuleName>,
796+
re_exports: AHashMap<ModuleName, ModuleName>,
794797
// Mapping of all methods called on imported objects
795798
methods: AHashMap<ModuleName, ModuleName>,
796799
// Reverse mapping: parent function → nested function scopes.
@@ -810,9 +813,9 @@ impl ProjectInfo {
810813
merge_all_classes(&mut analysis_map)
811814
});
812815
let re_exports = time(" Getting re-exports", || {
813-
let mut re_exports = exports
816+
let mut re_exports: AHashMap<ModuleName, ModuleName> = exports
814817
.get_re_exports()
815-
.map(|(name, _)| name.as_module_name())
818+
.map(|(name, (original, _))| (name.as_module_name(), original.as_module_name()))
816819
.collect();
817820
get_all_safe_re_exports(&effect_table, &mut re_exports);
818821
re_exports
@@ -833,6 +836,22 @@ impl ProjectInfo {
833836
}
834837
}
835838

839+
/// Resolve a name through the re-exports table, following chains
840+
/// (e.g. baz.Foo → bar.Foo → foo.Foo). Returns the original definition
841+
/// name, or the name unchanged if it is not a re-export. Bails out of
842+
/// cycles (e.g. `from b import X` / `from a import X`).
843+
fn resolve_re_export(&self, name: &ModuleName) -> ModuleName {
844+
let mut current = *name;
845+
let mut seen = AHashSet::new();
846+
while let Some(&original) = self.re_exports.get(&current) {
847+
if !seen.insert(current) {
848+
break;
849+
}
850+
current = original;
851+
}
852+
current
853+
}
854+
836855
pub fn contains_callable(&self, name: &ModuleName) -> bool {
837856
if self.functions.contains_key(name) || self.classes.contains(name) {
838857
return true;
@@ -841,7 +860,7 @@ impl ProjectInfo {
841860
if self.functions.contains_key(&call_name) || self.classes.contains(&call_name) {
842861
true
843862
} else {
844-
self.re_exports.contains(&call_name)
863+
self.re_exports.contains_key(&call_name)
845864
}
846865
}
847866

@@ -943,6 +962,7 @@ impl ProjectInfo {
943962
} else if eff.kind == EffectKind::ImportedTypeAttr {
944963
// Check if this is a property access
945964
if let Some((typ, attr)) = eff.name.split_attr() {
965+
let typ = self.resolve_re_export(&typ);
946966
if let Some(field) = self
947967
.classes
948968
.lookup(&typ)
@@ -1024,11 +1044,15 @@ impl ProjectInfo {
10241044
}
10251045

10261046
fn check_call(&self, call: &mut Call, state: &GlobalAnalysisState) -> Result<bool> {
1027-
if self.classes.contains(&call.func) {
1047+
// Resolve re-exports to their original name so that class/function
1048+
// lookups find the actual definition (e.g. bar.Foo → foo.Foo).
1049+
let resolved_func = self.resolve_re_export(&call.func);
1050+
let mut resolved_call = call.clone_with_name(resolved_func);
1051+
if self.classes.contains(&resolved_call.func) {
10281052
// This is a class constructor
1029-
self.check_constructor_call(call, state)
1053+
self.check_constructor_call(&resolved_call, state)
10301054
} else {
1031-
self.check_call_body(call, state)
1055+
self.check_call_body(&mut resolved_call, state)
10321056
}
10331057
}
10341058

tests/calls.rs

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,7 @@ import importlib
222222
from importlib import import_module
223223
224224
a = importlib.import_module("sys")
225-
# This depends on a chain of import aliases which we don't handle well
226-
b = importlib.__import__("math") # TODO: unsafe-function-call
225+
b = importlib.__import__("math") # E: unsafe-function-call
227226
228227
import_module("bar")
229228
"#;
@@ -478,4 +477,91 @@ foo.bar.f()
478477
"#;
479478
check_all(vec![("foo.bar", foo_bar), ("__main__", __main__)]);
480479
}
480+
481+
#[test]
482+
fn test_reexported_unsafe_class() {
483+
let foo = r#"
484+
class Foo:
485+
def __init__(self) -> None:
486+
raise Exception
487+
"#;
488+
let bar = r#"
489+
from foo import Foo
490+
"#;
491+
let baz = r#"
492+
from bar import Foo
493+
x = Foo() # E: unsafe-function-call
494+
"#;
495+
check_all(vec![("foo", foo), ("bar", bar), ("baz", baz)]);
496+
}
497+
498+
#[test]
499+
fn test_reexported_safe_class() {
500+
let foo = r#"
501+
class Foo:
502+
def __init__(self) -> None:
503+
pass
504+
"#;
505+
let bar = r#"
506+
from foo import Foo
507+
"#;
508+
let baz = r#"
509+
from bar import Foo
510+
x = Foo()
511+
"#;
512+
check_all(vec![("foo", foo), ("bar", bar), ("baz", baz)]);
513+
}
514+
515+
#[test]
516+
fn test_reexported_chain_unsafe_class() {
517+
let foo = r#"
518+
class Foo:
519+
def __init__(self) -> None:
520+
raise Exception
521+
"#;
522+
let bar = r#"
523+
from foo import Foo
524+
"#;
525+
let baz = r#"
526+
from bar import Foo
527+
"#;
528+
let consumer = r#"
529+
from baz import Foo
530+
x = Foo() # E: unsafe-function-call
531+
"#;
532+
check_all(vec![("foo", foo), ("bar", bar), ("baz", baz), ("consumer", consumer)]);
533+
}
534+
535+
#[test]
536+
fn test_reexport_cycle_terminates() {
537+
// `a.X` and `b.X` re-export each other with no real definition.
538+
// resolve_re_export must break out of the cycle rather than spin.
539+
let a = r#"
540+
from b import X
541+
"#;
542+
let b = r#"
543+
from a import X
544+
"#;
545+
let main = r#"
546+
from a import X
547+
x = X()
548+
"#;
549+
check_all(vec![("a", a), ("b", b), ("main", main)]);
550+
}
551+
552+
#[test]
553+
fn test_reexported_unsafe_function() {
554+
let foo = r#"
555+
def f():
556+
raise Exception
557+
"#;
558+
let bar = r#"
559+
from foo import f
560+
"#;
561+
let baz = r#"
562+
from bar import f
563+
f() # E: unsafe-function-call
564+
"#;
565+
check_all(vec![("foo", foo), ("bar", bar), ("baz", baz)]);
566+
}
481567
}

0 commit comments

Comments
 (0)