@@ -45,7 +45,7 @@ structure FindCtx where
4545 A map from function declarations that are currently in scope to their
4646 definition depth.
4747 -/
48- scope : FVarIdMap Nat := {}
48+ scope : PersistentHashMap FVarId Nat := {}
4949 /--
5050 The current function binder we are inside of if any.
5151 -/
@@ -69,51 +69,42 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
6969/--
7070Attempt to find a join point candidate by its `FVarId`.
7171-/
72+ @[inline]
7273private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
7374 return (← get).candidates[fvarId]?
7475
75- /--
76- Erase a join point candidate as well as all the ones that depend on it
77- by its `FVarId`, no error is thrown is the candidate does not exist.
78- -/
79- private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
80- if let some info ← findCandidate? fvarId then
81- modify (fun state => { state with candidates := state.candidates.erase fvarId })
82- info.associated.forM eraseCandidate
83-
8476/--
8577Combinator for modifying the candidates in `FindM`.
8678-/
79+ @[inline]
8780private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
8881 modify (fun state => {state with candidates := f state.candidates })
8982
9083/--
91- Remove all join point candidates contained in `a`.
92- -/
93- private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
94- forFVarM eraseCandidate a
95-
96- /--
97- Remove all join point candidates contained in `a`.
84+ Erase a join point candidate as well as all the ones that depend on it
85+ by its `FVarId`, no error is thrown is the candidate does not exist.
9886-/
99- private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
100- forFVarM eraseCandidate e
87+ @[inline]
88+ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
89+ if let some info ← findCandidate? fvarId then
90+ modifyCandidates fun cs => cs.erase fvarId
91+ info.associated.forM eraseCandidate
10192
10293/--
10394Add a new join point candidate to the state.
10495-/
96+ @[inline]
10597private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
106- let cinfo := { arity, associated := ∅ }
107- modifyCandidates (fun cs => cs.insert fvarId cinfo )
98+ modifyCandidates fun cs => cs.insert fvarId { arity, associated := ∅ }
10899
109100/--
110101Add a new join point dependency from `src` to `dst`.
111102-/
103+ @[inline]
112104private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
113- if let some targetInfo ← findCandidate? target then
114- modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
115- else
116- eraseCandidate src
105+ modifyCandidates fun cs =>
106+ cs.modify target fun targetInfo =>
107+ { targetInfo with associated := targetInfo.associated.insert src }
117108
118109/--
119110Find all `fun` declarations that qualify as a join point, that is:
@@ -152,49 +143,51 @@ partial def find (decl : Decl) : CompilerM FindState := do
152143 return candidates
153144where
154145 go : Code → FindM Unit
155- | .let decl k => do
156- match k, decl.value with
157- | .return valId, .fvar fvarId args =>
158- args.forM removeCandidatesInArg
159- if let some candidateInfo ← findCandidate? fvarId then
160- -- Erase candidate that are not fully applied or applied outside of tail position
161- if valId != decl.fvarId || args.size != candidateInfo.arity then
162- eraseCandidate fvarId
163- -- Out of scope join point candidate handling
164- else
165- let currDepth := (← read).definitionDepth
166- let calleeDepth := (← read).scope.get! fvarId
167- if currDepth == calleeDepth then
168- return ()
169- else if calleeDepth + 1 == currDepth then
170- addDependency fvarId (← read).currentFunction.get!
171- else
146+ | .let decl k => do
147+ match k, decl.value with
148+ | .return valId, .fvar fvarId args =>
149+ args.forM (forFVarM eraseCandidate)
150+ if let some candidateInfo ← findCandidate? fvarId then
151+ -- Erase candidate that are not fully applied or applied outside of tail position
152+ if valId != decl.fvarId || args.size != candidateInfo.arity then
172153 eraseCandidate fvarId
173- | _, _ =>
174- removeCandidatesInLetValue decl.value
175- go k
176- | .fun decl k => do
177- addCandidate decl.fvarId decl.getArity
178- withReader (fun ctx => {
179- ctx with
180- definitionDepth := ctx.definitionDepth + 1 ,
181- currentFunction := some decl.fvarId }) do
154+ -- Out of scope join point candidate handling
155+ else
156+ let currDepth := (← read).definitionDepth
157+ let calleeDepth := (← read).scope.find! fvarId
158+ if currDepth == calleeDepth then
159+ return ()
160+ else if calleeDepth + 1 == currDepth then
161+ addDependency fvarId (← read).currentFunction.get!
162+ else
163+ eraseCandidate fvarId
164+ | _, _ =>
165+ forFVarM eraseCandidate decl.value
166+ go k
167+ | .fun decl k => do
168+ addCandidate decl.fvarId decl.getArity
169+ withReader (fun ctx => {
170+ ctx with
171+ definitionDepth := ctx.definitionDepth + 1 ,
172+ currentFunction := some decl.fvarId }) do
173+ go decl.value
174+ withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
175+ go k
176+ | .jp decl k => do
182177 go decl.value
183- withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
184178 go k
185- | .jp decl k => do
186- go decl.value
187- go k
188- | .jmp _ args => args.forM removeCandidatesInArg
189- | .return val => eraseCandidate val
190- | .cases c => c.alts.forM (·.forCodeM go)
191- | .unreach .. => return ()
179+ | .jmp _ args => args.forM (forFVarM eraseCandidate)
180+ | .return val => eraseCandidate val
181+ | .cases c => c.alts.forM (·.forCodeM go)
182+ | .unreach .. => return ()
192183
193184/--
194185Replace all join point candidate `fun` declarations with `jp` ones
195186and all calls to them with `jmp`s.
196187-/
197188partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
189+ if state.candidates.isEmpty then
190+ return decl
198191 let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
199192 let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := ∅) mapper
200193 let newValue ← decl.value.mapCodeM go |>.run replaceCtx
0 commit comments