@@ -49,11 +49,23 @@ partial def consumed (x : VarId) : FnBody → Bool
4949 | e => !e.isTerminal && consumed x e.body
5050
5151abbrev Mask := Array (Option VarId)
52+ abbrev ProjCounts := Std.HashMap (VarId × Nat) Nat
53+
54+ partial def computeProjCounts (bs : Array FnBody) : ProjCounts :=
55+ let incrementCountIfProj r b :=
56+ if let .vdecl _ _ (.proj i v) _ := b then
57+ r.alter (v, i) fun
58+ | some n => some (n + 1 )
59+ | none => some 1
60+ else
61+ r
62+ bs.foldl incrementCountIfProj Std.HashMap.emptyWithCapacity
5263
5364/-- Auxiliary function for eraseProjIncFor -/
54- partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
65+ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (projCounts : ProjCounts)
66+ (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
5567 let done (_ : Unit) := (bs ++ keep.reverse, mask)
56- let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
68+ let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop projCounts mask (keep.push b)
5769 if h : bs.size < 2 then done ()
5870 else
5971 let b := bs.back!
@@ -65,7 +77,10 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke
6577 let b' := bs[bs.size - 2 ]
6678 match b' with
6779 | .vdecl w _ (.proj i x) _ =>
68- if w == z && y == x then
80+ -- We disable the inc optimization if there are multiple projections with the same base
81+ -- and index, because the downstream transformations are incapable of correctly handling
82+ -- the aliasing.
83+ if w == z && y == x && projCounts[(x, i)]! == 1 then
6984 /- Found
7085 ```
7186 let z := proj[i] y
@@ -77,15 +92,15 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke
7792 let mask := mask.set! i (some z)
7893 let keep := keep.push b'
7994 let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1 ) c p FnBody.nil)
80- eraseProjIncForAux y bs mask keep
95+ eraseProjIncForAux y bs projCounts mask keep
8196 else done ()
8297 | _ => done ()
8398 | _ => done ()
8499
85100/-- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
86101 Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
87102def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
88- eraseProjIncForAux y bs (.replicate n none) #[]
103+ eraseProjIncForAux y bs (computeProjCounts bs) ( .replicate n none) #[]
89104
90105/-- Replace `reuse x ctor ...` with `ctor ...`, and remove `dec x` -/
91106partial def reuseToCtor (x : VarId) : FnBody → FnBody
0 commit comments