@@ -45,7 +45,7 @@ structure FindCtx where
4545 A map from function declarations that are currently in scope to their
4646 definition depth.
4747 -/
48- scope : FVarIdMap Nat := {}
48+ scope : PersistentHashMap FVarId Nat := {}
4949 /--
5050 The current function binder we are inside of if any.
5151 -/
@@ -69,13 +69,15 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
6969/--
7070Attempt to find a join point candidate by its `FVarId`.
7171-/
72+ @[inline]
7273private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
7374 return (← get).candidates[fvarId]?
7475
7576/--
7677Erase a join point candidate as well as all the ones that depend on it
7778by its `FVarId`, no error is thrown is the candidate does not exist.
7879-/
80+ @[inline]
7981private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
8082 if let some info ← findCandidate? fvarId then
8183 modify (fun state => { state with candidates := state.candidates.erase fvarId })
@@ -84,37 +86,55 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
8486/--
8587Combinator for modifying the candidates in `FindM`.
8688-/
89+ @[inline]
8790private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
8891 modify (fun state => {state with candidates := f state.candidates })
8992
9093/--
9194Remove all join point candidates contained in `a`.
9295-/
93- private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
96+ @[inline]
97+ private def removeCandidatesInArg (a : Arg) : FindM Unit := do
9498 forFVarM eraseCandidate a
9599
96100/--
97101Remove all join point candidates contained in `a`.
98102-/
99- private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
103+ @[inline]
104+ private def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
100105 forFVarM eraseCandidate e
101106
102107/--
103108Add a new join point candidate to the state.
104109-/
110+ @[inline]
105111private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
106112 let cinfo := { arity, associated := ∅ }
107113 modifyCandidates (fun cs => cs.insert fvarId cinfo )
108114
109115/--
110116Add a new join point dependency from `src` to `dst`.
111117-/
118+ @[inline]
112119private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
113120 if let some targetInfo ← findCandidate? target then
114121 modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
115122 else
116123 eraseCandidate src
117124
125+ @[inline]
126+ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α := do
127+ withReader (fun ctx => {
128+ ctx with
129+ definitionDepth := ctx.definitionDepth + 1 ,
130+ currentFunction := some decl.fvarId }) do
131+ x
132+
133+ @[inline]
134+ private def withDefinedFn (decl : FunDecl) (x : FindM α) : FindM α := do
135+ withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
136+ x
137+
118138/--
119139Find all `fun` declarations that qualify as a join point, that is:
120140- are always fully applied
@@ -137,58 +157,83 @@ def test (b : Bool) (x y : Nat) : Nat :=
137157 fun f y =>
138158 let x := Nat.add y y
139159 myjp x
140- fun f y =>
160+ fun g y =>
141161 let x := Nat.mul y y
142162 myjp x
143163 cases b (f x) (g y)
144164```
145165`f` and `g` can be detected as a join point right away, however
146166`myjp` can only ever be detected as a join point after we have established
147167this. This is because otherwise the calls to `myjp` in `f` and `g` would
148- produce out of scope join point jumps.
168+ produce out of scope join point jumps. This analysis supports detecting `myjp`
169+ as a join point. However, it does not support this in situations where `f` and `g`
170+ are nested within another function block that might become a join point.
171+ We believe this should be fine because this analysis gets run multiple times together
172+ with floating declarations out of nested ones so the vast majority of practically
173+ detectable join points should be detectable.
149174-/
150175partial def find (decl : Decl) : CompilerM FindState := do
151176 let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
152177 return candidates
153178where
154179 go : Code → FindM Unit
155- | .let decl k => do
156- match k, decl.value with
157- | .return valId, .fvar fvarId args =>
158- args.forM removeCandidatesInArg
159- if let some candidateInfo ← findCandidate? fvarId then
160- -- Erase candidate that are not fully applied or applied outside of tail position
161- if valId != decl.fvarId || args.size != candidateInfo.arity then
162- eraseCandidate fvarId
163- -- Out of scope join point candidate handling
164- else
165- let currDepth := (← read).definitionDepth
166- let calleeDepth := (← read).scope.get! fvarId
167- if currDepth == calleeDepth then
168- return ()
169- else if calleeDepth + 1 == currDepth then
170- addDependency fvarId (← read).currentFunction.get!
171- else
180+ | .let decl k => do
181+ match k, decl.value with
182+ | .return valId, .fvar fvarId args =>
183+ args.forM removeCandidatesInArg
184+ if let some candidateInfo ← findCandidate? fvarId then
185+ -- Erase candidate that are not fully applied or applied outside of tail position
186+ if valId != decl.fvarId || args.size != candidateInfo.arity then
172187 eraseCandidate fvarId
173- | _, _ =>
174- removeCandidatesInLetValue decl.value
175- go k
176- | .fun decl k => do
177- addCandidate decl.fvarId decl.getArity
178- withReader (fun ctx => {
179- ctx with
180- definitionDepth := ctx.definitionDepth + 1 ,
181- currentFunction := some decl.fvarId }) do
188+ else
189+ let currDepth := (← read).definitionDepth
190+ let calleeDepth := (← read).scope.find! fvarId
191+ if currDepth == calleeDepth then
192+ /-
193+ we are in a situation like:
194+ fun f x :=
195+ ...
196+ ...
197+ f ()
198+ -/
199+ return ()
200+ else if calleeDepth + 1 == currDepth then
201+ /-
202+ we are in a situation like:
203+ fun f x :=
204+ ...
205+ fun g x :=
206+ ...
207+ f ()
208+ -/
209+ addDependency fvarId (← read).currentFunction.get!
210+ else
211+ /-
212+ we are in a situation like:
213+ fun f x :=
214+ ...
215+ fun h x :=
216+ fun g x :=
217+ ...
218+ f ()
219+ -/
220+ eraseCandidate fvarId
221+ | _, _ =>
222+ removeCandidatesInLetValue decl.value
223+ go k
224+ | .fun decl k => do
225+ addCandidate decl.fvarId decl.getArity
226+ withFnBody decl do
227+ go decl.value
228+ withDefinedFn decl do
229+ go k
230+ | .jp decl k => do
182231 go decl.value
183- withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
184232 go k
185- | .jp decl k => do
186- go decl.value
187- go k
188- | .jmp _ args => args.forM removeCandidatesInArg
189- | .return val => eraseCandidate val
190- | .cases c => c.alts.forM (·.forCodeM go)
191- | .unreach .. => return ()
233+ | .jmp _ args => args.forM removeCandidatesInArg
234+ | .return val => eraseCandidate val
235+ | .cases c => c.alts.forM (·.forCodeM go)
236+ | .unreach .. => return ()
192237
193238/--
194239Replace all join point candidate `fun` declarations with `jp` ones
0 commit comments