Skip to content

Commit 02bcc41

Browse files
committed
clean it up
1 parent 4442762 commit 02bcc41

File tree

1 file changed

+84
-39
lines changed

1 file changed

+84
-39
lines changed

src/Lean/Compiler/LCNF/JoinPoints.lean

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ structure FindCtx where
4545
A map from function declarations that are currently in scope to their
4646
definition depth.
4747
-/
48-
scope : FVarIdMap Nat := {}
48+
scope : PersistentHashMap FVarId Nat := {}
4949
/--
5050
The current function binder we are inside of if any.
5151
-/
@@ -69,13 +69,15 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
6969
/--
7070
Attempt to find a join point candidate by its `FVarId`.
7171
-/
72+
@[inline]
7273
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
7374
return (← get).candidates[fvarId]?
7475

7576
/--
7677
Erase a join point candidate as well as all the ones that depend on it
7778
by its `FVarId`, no error is thrown is the candidate does not exist.
7879
-/
80+
@[inline]
7981
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
8082
if let some info ← findCandidate? fvarId then
8183
modify (fun state => { state with candidates := state.candidates.erase fvarId })
@@ -84,37 +86,55 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
8486
/--
8587
Combinator for modifying the candidates in `FindM`.
8688
-/
89+
@[inline]
8790
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
8891
modify (fun state => {state with candidates := f state.candidates })
8992

9093
/--
9194
Remove all join point candidates contained in `a`.
9295
-/
93-
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
96+
@[inline]
97+
private def removeCandidatesInArg (a : Arg) : FindM Unit := do
9498
forFVarM eraseCandidate a
9599

96100
/--
97101
Remove all join point candidates contained in `a`.
98102
-/
99-
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
103+
@[inline]
104+
private def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
100105
forFVarM eraseCandidate e
101106

102107
/--
103108
Add a new join point candidate to the state.
104109
-/
110+
@[inline]
105111
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
106112
let cinfo := { arity, associated := ∅ }
107113
modifyCandidates (fun cs => cs.insert fvarId cinfo )
108114

109115
/--
110116
Add a new join point dependency from `src` to `dst`.
111117
-/
118+
@[inline]
112119
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
113120
if let some targetInfo ← findCandidate? target then
114121
modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
115122
else
116123
eraseCandidate src
117124

125+
@[inline]
126+
private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α := do
127+
withReader (fun ctx => {
128+
ctx with
129+
definitionDepth := ctx.definitionDepth + 1,
130+
currentFunction := some decl.fvarId }) do
131+
x
132+
133+
@[inline]
134+
private def withDefinedFn (decl : FunDecl) (x : FindM α) : FindM α := do
135+
withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
136+
x
137+
118138
/--
119139
Find all `fun` declarations that qualify as a join point, that is:
120140
- are always fully applied
@@ -137,58 +157,83 @@ def test (b : Bool) (x y : Nat) : Nat :=
137157
fun f y =>
138158
let x := Nat.add y y
139159
myjp x
140-
fun f y =>
160+
fun g y =>
141161
let x := Nat.mul y y
142162
myjp x
143163
cases b (f x) (g y)
144164
```
145165
`f` and `g` can be detected as a join point right away, however
146166
`myjp` can only ever be detected as a join point after we have established
147167
this. This is because otherwise the calls to `myjp` in `f` and `g` would
148-
produce out of scope join point jumps.
168+
produce out of scope join point jumps. This analysis supports detecting `myjp`
169+
as a join point. However, it does not support this in situations where `f` and `g`
170+
are nested within another function block that might become a join point.
171+
We believe this should be fine because this analysis gets run multiple times together
172+
with floating declarations out of nested ones so the vast majority of practically
173+
detectable join points should be detectable.
149174
-/
150175
partial def find (decl : Decl) : CompilerM FindState := do
151176
let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
152177
return candidates
153178
where
154179
go : Code → FindM Unit
155-
| .let decl k => do
156-
match k, decl.value with
157-
| .return valId, .fvar fvarId args =>
158-
args.forM removeCandidatesInArg
159-
if let some candidateInfo ← findCandidate? fvarId then
160-
-- Erase candidate that are not fully applied or applied outside of tail position
161-
if valId != decl.fvarId || args.size != candidateInfo.arity then
162-
eraseCandidate fvarId
163-
-- Out of scope join point candidate handling
164-
else
165-
let currDepth := (← read).definitionDepth
166-
let calleeDepth := (← read).scope.get! fvarId
167-
if currDepth == calleeDepth then
168-
return ()
169-
else if calleeDepth + 1 == currDepth then
170-
addDependency fvarId (← read).currentFunction.get!
171-
else
180+
| .let decl k => do
181+
match k, decl.value with
182+
| .return valId, .fvar fvarId args =>
183+
args.forM removeCandidatesInArg
184+
if let some candidateInfo ← findCandidate? fvarId then
185+
-- Erase candidate that are not fully applied or applied outside of tail position
186+
if valId != decl.fvarId || args.size != candidateInfo.arity then
172187
eraseCandidate fvarId
173-
| _, _ =>
174-
removeCandidatesInLetValue decl.value
175-
go k
176-
| .fun decl k => do
177-
addCandidate decl.fvarId decl.getArity
178-
withReader (fun ctx => {
179-
ctx with
180-
definitionDepth := ctx.definitionDepth + 1,
181-
currentFunction := some decl.fvarId }) do
188+
else
189+
let currDepth := (← read).definitionDepth
190+
let calleeDepth := (← read).scope.find! fvarId
191+
if currDepth == calleeDepth then
192+
/-
193+
we are in a situation like:
194+
fun f x :=
195+
...
196+
...
197+
f ()
198+
-/
199+
return ()
200+
else if calleeDepth + 1 == currDepth then
201+
/-
202+
we are in a situation like:
203+
fun f x :=
204+
...
205+
fun g x :=
206+
...
207+
f ()
208+
-/
209+
addDependency fvarId (← read).currentFunction.get!
210+
else
211+
/-
212+
we are in a situation like:
213+
fun f x :=
214+
...
215+
fun h x :=
216+
fun g x :=
217+
...
218+
f ()
219+
-/
220+
eraseCandidate fvarId
221+
| _, _ =>
222+
removeCandidatesInLetValue decl.value
223+
go k
224+
| .fun decl k => do
225+
addCandidate decl.fvarId decl.getArity
226+
withFnBody decl do
227+
go decl.value
228+
withDefinedFn decl do
229+
go k
230+
| .jp decl k => do
182231
go decl.value
183-
withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
184232
go k
185-
| .jp decl k => do
186-
go decl.value
187-
go k
188-
| .jmp _ args => args.forM removeCandidatesInArg
189-
| .return val => eraseCandidate val
190-
| .cases c => c.alts.forM (·.forCodeM go)
191-
| .unreach .. => return ()
233+
| .jmp _ args => args.forM removeCandidatesInArg
234+
| .return val => eraseCandidate val
235+
| .cases c => c.alts.forM (·.forCodeM go)
236+
| .unreach .. => return ()
192237

193238
/--
194239
Replace all join point candidate `fun` declarations with `jp` ones

0 commit comments

Comments
 (0)