@@ -72,7 +72,7 @@ impl Function {
7272 let mut context = Context :: default ( ) ;
7373
7474 context. find_rcs_in_entry_block ( self ) ;
75- context. scan_for_array_sets ( self ) ;
75+ context. scan_for_mutations ( self ) ;
7676 let to_remove = context. find_rcs_to_remove ( self ) ;
7777 remove_instructions ( to_remove, self ) ;
7878 }
@@ -99,20 +99,51 @@ impl Context {
9999 }
100100 }
101101
102- /// Find each array_set instruction in the function and mark any arrays used
102+ /// Find each array_set or call instruction in the function and mark any arrays used
103103 /// by the inc_rc instructions as possibly mutated if they're the same type.
104- fn scan_for_array_sets ( & mut self , function : & Function ) {
104+ fn scan_for_mutations ( & mut self , function : & Function ) {
105105 for block in function. reachable_blocks ( ) {
106106 for instruction in function. dfg [ block] . instructions ( ) {
107- if let Instruction :: ArraySet { array, .. } = function. dfg [ * instruction] {
108- let typ = function. dfg . type_of_value ( array) ;
109- if let Some ( inc_rcs) = self . inc_rcs . get_mut ( & typ) {
110- for inc_rc in inc_rcs {
111- inc_rc. possibly_mutated = true ;
107+ match & function. dfg [ * instruction] {
108+ Instruction :: ArraySet { array, .. } => {
109+ let typ = function. dfg . type_of_value ( * array) ;
110+ self . mark_as_mutated ( & typ) ;
111+ }
112+ Instruction :: Call { arguments, .. } => {
113+ // A call with an array argument could mutate that array
114+ for arg in arguments {
115+ let typ = function. dfg . type_of_value ( * arg) ;
116+ self . mark_each_contained_array_as_mutated ( & typ) ;
112117 }
113118 }
119+
120+ _ => { }
121+ }
122+ }
123+ }
124+ }
125+
126+ fn mark_as_mutated ( & mut self , typ : & Type ) {
127+ if let Some ( inc_rcs) = self . inc_rcs . get_mut ( typ) {
128+ for inc_rc in inc_rcs {
129+ inc_rc. possibly_mutated = true ;
130+ }
131+ }
132+ }
133+
134+ /// Recursively unwrap references and mark any contained arrays as mutated.
135+ fn mark_each_contained_array_as_mutated ( & mut self , typ : & Type ) {
136+ match typ {
137+ Type :: Reference ( element) => self . mark_each_contained_array_as_mutated ( element) ,
138+ Type :: Array ( element_types, _) | Type :: Vector ( element_types) => {
139+ // Mark the array type we have found as being possibly mutated
140+ self . mark_as_mutated ( typ) ;
141+ // We now need to also mark nested arrays which are possibly mutated
142+ for element in element_types. iter ( ) {
143+ self . mark_each_contained_array_as_mutated ( element) ;
114144 }
115145 }
146+ Type :: Numeric ( _) | Type :: Function => { }
116147 }
117148 }
118149
@@ -441,4 +472,91 @@ mod tests {
441472 }
442473 " ) ;
443474 }
475+
476+ #[ test]
477+ #[ ignore]
478+ fn mutation_through_call_with_mutable_reference ( ) {
479+ // We expect `inc_rc v0` to remain.
480+ // If you accessed v0 directly after the call (not through the reference):
481+ // - With inc_rc: v0 is protected, COW happens, v0 retains old value
482+ // - Without inc_rc: v0 is mutated in place
483+ let src = "
484+ brillig(inline) fn main f0 {
485+ b0(v0: [Field; 2]):
486+ inc_rc v0
487+ v1 = allocate -> &mut [Field; 2]
488+ store v0 at v1
489+ call f1(v1)
490+ v3 = load v1 -> [Field; 2]
491+ v5 = array_get v3, index u32 0 -> Field
492+ constrain v5 == Field 5
493+ dec_rc v0
494+ return
495+ }
496+ brillig(inline) fn mutator f1 {
497+ b0(v0: &mut [Field; 2]):
498+ v1 = load v0 -> [Field; 2]
499+ v2 = array_set v1, index u32 0, value Field 5
500+ store v2 at v0
501+ return
502+ }
503+ " ;
504+ assert_ssa_does_not_change ( src, Ssa :: remove_paired_rc) ;
505+ }
506+
507+ /// Same as [mutation_through_call_with_mutable_reference] except with a deeply nested reference to an array (e.g., `&mut &mut [Field; 2]`)
508+ #[ test]
509+ #[ ignore]
510+ fn mutation_through_call_with_deeply_nested_reference ( ) {
511+ let src = "
512+ brillig(inline) fn main f0 {
513+ b0(v0: [Field; 2]):
514+ inc_rc v0
515+ v1 = allocate -> &mut [Field; 2]
516+ store v0 at v1
517+ v2 = allocate -> &mut &mut [Field; 2]
518+ store v1 at v2
519+ call f1(v2)
520+ v4 = load v1 -> [Field; 2]
521+ v6 = array_get v4, index u32 0 -> Field
522+ constrain v6 == Field 5
523+ dec_rc v0
524+ return
525+ }
526+ brillig(inline) fn mutator f1 {
527+ b0(v0: &mut &mut [Field; 2]):
528+ v1 = load v0 -> &mut [Field; 2]
529+ v2 = load v1 -> [Field; 2]
530+ v3 = array_set v2, index u32 0, value Field 5
531+ store v3 at v1
532+ return
533+ }
534+ " ;
535+ assert_ssa_does_not_change ( src, Ssa :: remove_paired_rc) ;
536+ }
537+
538+ #[ test]
539+ #[ ignore]
540+ fn mutation_through_call_with_array_passed_by_value ( ) {
541+ // We expect `inc_rc v0` to remain
542+ // After the call to f1 we expect v0 to be unchanged.
543+ // If the inc_rc were removed, f1 would mutate v0 in place.
544+ let src = "
545+ brillig(inline) fn main f0 {
546+ b0(v0: [Field; 2]):
547+ inc_rc v0
548+ call f1(v0)
549+ v3 = array_get v0, index u32 0 -> Field
550+ constrain v3 == Field 5
551+ dec_rc v0
552+ return
553+ }
554+ brillig(inline) fn mutator f1 {
555+ b0(v0: [Field; 2]):
556+ v3 = array_set v0, index u32 0, value Field 5
557+ return
558+ }
559+ " ;
560+ assert_ssa_does_not_change ( src, Ssa :: remove_paired_rc) ;
561+ }
444562}
0 commit comments