1+ use std:: collections:: BTreeMap ;
12use std:: sync:: Arc ;
23
34use rustc_hash:: { FxHashMap as HashMap , FxHashSet as HashSet } ;
45
56use crate :: ssa:: ir:: call_graph:: CallGraph ;
7+ use crate :: ssa:: opt:: brillig_entry_points:: get_brillig_entry_points;
68use crate :: ssa:: {
79 ir:: {
810 function:: { Function , FunctionId } ,
@@ -27,17 +29,26 @@ impl Ssa {
2729 #[ tracing:: instrument( level = "trace" , skip( self ) ) ]
2830 pub ( crate ) fn purity_analysis ( mut self ) -> Ssa {
2931 let call_graph = CallGraph :: from_ssa ( & self ) ;
32+ let brillig_entry_points =
33+ get_brillig_entry_points ( & self . functions , self . main_id , & call_graph) ;
3034
3135 let ( sccs, recursive_functions) = call_graph. sccs ( ) ;
3236
3337 // First look through each function to get a baseline on its purity and collect
3438 // the functions it calls to build a call graph.
35- let purities: HashMap < _ , _ > =
36- self . functions . values ( ) . map ( |function| ( function. id ( ) , function. is_pure ( ) ) ) . collect ( ) ;
39+ let purities: HashMap < _ , _ > = self
40+ . functions
41+ . values ( )
42+ . map ( |function| {
43+ let is_brillig_entry_point = brillig_entry_points. contains ( & function. id ( ) ) ;
44+ ( function. id ( ) , function. is_pure ( is_brillig_entry_point) )
45+ } )
46+ . collect ( ) ;
3747
3848 // Then transitively 'infect' any functions which call impure functions as also
3949 // impure.
40- let purities = analyze_call_graph ( call_graph, purities, & sccs, & recursive_functions) ;
50+ let purities =
51+ analyze_call_graph ( call_graph, purities, & sccs, & recursive_functions, & self . functions ) ;
4152 let purities = Arc :: new ( purities) ;
4253
4354 // We're done, now store purities somewhere every dfg can find it.
@@ -113,7 +124,7 @@ impl std::fmt::Display for Purity {
113124}
114125
115126impl Function {
116- fn is_pure ( & self ) -> Purity {
127+ fn is_pure ( & self , is_brillig_entrypoint : bool ) -> Purity {
117128 let contains_reference = |value_id : & ValueId | {
118129 let typ = self . dfg . type_of_value ( * value_id) ;
119130 typ. contains_reference ( )
@@ -123,7 +134,7 @@ impl Function {
123134 return Purity :: Impure ;
124135 }
125136
126- let mut result = if self . runtime ( ) . is_acir ( ) {
137+ let mut result = if self . runtime ( ) . is_acir ( ) || !is_brillig_entrypoint {
127138 Purity :: Pure
128139 } else {
129140 // Because we return bogus values when a brillig function is called from acir
@@ -229,6 +240,7 @@ fn analyze_call_graph(
229240 starting_purities : FunctionPurities ,
230241 sccs : & [ Vec < FunctionId > ] ,
231242 recursive_functions : & HashSet < FunctionId > ,
243+ functions : & BTreeMap < FunctionId , Function > ,
232244) -> FunctionPurities {
233245 let mut finished = HashMap :: default ( ) ;
234246
@@ -269,11 +281,16 @@ fn analyze_call_graph(
269281 }
270282 }
271283
272- // Recursive functions cannot be fully pure (may recurse indefinitely),
273- // but we still treat them as PureWithPredicate for deduplication purposes.
274- // If we were to mark recursive functions pure we may entirely eliminate an infinite loop.
284+ // Recursive functions cannot be fully pure (may recurse indefinitely).
285+ // For Brillig functions, recursion is a side-effect (infinite loop),
286+ // so we mark them as Impure to prevent incorrect elimination.
287+ // For ACIR functions, we treat them as PureWithPredicate for deduplication purposes.
275288 if recursive_functions. contains ( & func) {
276- combined_purity = combined_purity. unify ( Purity :: PureWithPredicate ) ;
289+ if functions[ & func] . runtime ( ) . is_brillig ( ) {
290+ combined_purity = combined_purity. unify ( Purity :: Impure ) ;
291+ } else {
292+ combined_purity = combined_purity. unify ( Purity :: PureWithPredicate ) ;
293+ }
277294 }
278295 }
279296
@@ -485,7 +502,7 @@ mod tests {
485502 assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: Impure ) ;
486503 assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: Impure ) ;
487504 assert_eq ! ( purities[ & FunctionId :: test_new( 2 ) ] , Purity :: Impure ) ;
488- assert_eq ! ( purities[ & FunctionId :: test_new( 3 ) ] , Purity :: PureWithPredicate ) ;
505+ assert_eq ! ( purities[ & FunctionId :: test_new( 3 ) ] , Purity :: Pure ) ;
489506 }
490507
491508 #[ test]
@@ -506,7 +523,7 @@ mod tests {
506523
507524 let purities = & ssa. main ( ) . dfg . function_purities ;
508525 assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: PureWithPredicate ) ;
509- assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: PureWithPredicate ) ;
526+ assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: Pure ) ;
510527 }
511528
512529 /// Functions using inc_rc or dec_rc are always impure - see constant_folding::do_not_deduplicate_call_with_inc_rc
@@ -589,7 +606,7 @@ mod tests {
589606 }
590607
591608 #[ test]
592- fn direct_brillig_recursion_marks_functions_pure_with_predicate ( ) {
609+ fn direct_brillig_recursion_marks_functions_impure ( ) {
593610 let src = r#"
594611 brillig(inline) fn main f0 {
595612 b0():
@@ -607,14 +624,13 @@ mod tests {
607624 let ssa = ssa. purity_analysis ( ) ;
608625
609626 let purities = & ssa. main ( ) . dfg . function_purities ;
610- assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: PureWithPredicate ) ;
611- assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: PureWithPredicate ) ;
627+ assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: Impure ) ;
628+ assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: Impure ) ;
612629 }
613630
614631 #[ test]
615- fn mutual_recursion_marks_functions_pure ( ) {
616- // We want to test that two pure mutually recursive functions do in fact mark each other as PureWithPredicate.
617- // If we have indefinite recursion and we may accidentally eliminate an infinite loop before inlining can catch it.
632+ fn mutual_recursion_marks_acir_functions_pure_with_predicate ( ) {
633+ // We want to test that two pure mutually recursive ACIR functions are marked as PureWithPredicate.
618634 let src = r#"
619635 acir(inline) fn main f0 {
620636 b0():
@@ -658,9 +674,9 @@ mod tests {
658674 assert_eq ! ( purities[ & FunctionId :: test_new( 2 ) ] , Purity :: PureWithPredicate ) ;
659675 }
660676
661- /// This test matches [mutual_recursion_marks_functions_pure] except all functions have a Brillig runtime
662677 #[ test]
663- fn brillig_mutual_recursion_marks_functions_pure_with_predicate ( ) {
678+ fn brillig_mutual_recursion_marks_functions_impure ( ) {
679+ // Brillig mutually recursive functions should be marked as impure
664680 let src = r#"
665681 brillig(inline) fn main f0 {
666682 b0():
@@ -699,9 +715,9 @@ mod tests {
699715 let ssa = ssa. purity_analysis ( ) ;
700716
701717 let purities = & ssa. main ( ) . dfg . function_purities ;
702- assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: PureWithPredicate ) ;
703- assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: PureWithPredicate ) ;
704- assert_eq ! ( purities[ & FunctionId :: test_new( 2 ) ] , Purity :: PureWithPredicate ) ;
718+ assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: Impure ) ;
719+ assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: Impure ) ;
720+ assert_eq ! ( purities[ & FunctionId :: test_new( 2 ) ] , Purity :: Impure ) ;
705721 }
706722
707723 #[ test]
@@ -753,7 +769,7 @@ mod tests {
753769
754770 /// This test matches [mutual_recursion_marks_functions_impure] except all functions have a Brillig runtime
755771 #[ test]
756- fn brillig_mutual_recursion_marks_functions_impure ( ) {
772+ fn brillig_mutual_recursion_marks_functions_impure_with_reference ( ) {
757773 let src = r#"
758774 brillig(inline) fn main f0 {
759775 b0():
@@ -817,7 +833,7 @@ mod tests {
817833 }
818834
819835 #[ test]
820- fn brillig_functions_are_pure_with_predicate_if_they_are_not_an_entry_point ( ) {
836+ fn brillig_functions_are_pure_if_they_are_not_an_entry_point ( ) {
821837 let src = "
822838 brillig(inline) fn main f0 {
823839 b0(v0: u1):
@@ -840,10 +856,7 @@ mod tests {
840856
841857 let purities = & ssa. main ( ) . dfg . function_purities ;
842858 assert_eq ! ( purities[ & FunctionId :: test_new( 0 ) ] , Purity :: PureWithPredicate ) ;
843-
844- // Note: even though it would be fine to mark f1 as pure, something in Aztec-Packages
845- // gets broken so until we figure out what that is we can't mark these as pure.
846- assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: PureWithPredicate ) ;
859+ assert_eq ! ( purities[ & FunctionId :: test_new( 1 ) ] , Purity :: Pure ) ;
847860 }
848861
849862 #[ test]
0 commit comments