Skip to content

Commit 5a0b63a

Browse files
committed
lets try again!
1 parent db7cda4 commit 5a0b63a

File tree

1 file changed

+52
-59
lines changed

1 file changed

+52
-59
lines changed

src/Lean/Compiler/LCNF/JoinPoints.lean

Lines changed: 52 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ structure FindCtx where
4545
A map from function declarations that are currently in scope to their
4646
definition depth.
4747
-/
48-
scope : FVarIdMap Nat := {}
48+
scope : PersistentHashMap FVarId Nat := {}
4949
/--
5050
The current function binder we are inside of if any.
5151
-/
@@ -69,51 +69,42 @@ abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
6969
/--
7070
Attempt to find a join point candidate by its `FVarId`.
7171
-/
72+
@[inline]
7273
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
7374
return (← get).candidates[fvarId]?
7475

75-
/--
76-
Erase a join point candidate as well as all the ones that depend on it
77-
by its `FVarId`, no error is thrown is the candidate does not exist.
78-
-/
79-
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
80-
if let some info ← findCandidate? fvarId then
81-
modify (fun state => { state with candidates := state.candidates.erase fvarId })
82-
info.associated.forM eraseCandidate
83-
8476
/--
8577
Combinator for modifying the candidates in `FindM`.
8678
-/
79+
@[inline]
8780
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
8881
modify (fun state => {state with candidates := f state.candidates })
8982

9083
/--
91-
Remove all join point candidates contained in `a`.
92-
-/
93-
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
94-
forFVarM eraseCandidate a
95-
96-
/--
97-
Remove all join point candidates contained in `a`.
84+
Erase a join point candidate as well as all the ones that depend on it
85+
by its `FVarId`, no error is thrown is the candidate does not exist.
9886
-/
99-
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
100-
forFVarM eraseCandidate e
87+
@[inline]
88+
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
89+
if let some info ← findCandidate? fvarId then
90+
modifyCandidates fun cs => cs.erase fvarId
91+
info.associated.forM eraseCandidate
10192

10293
/--
10394
Add a new join point candidate to the state.
10495
-/
96+
@[inline]
10597
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
106-
let cinfo := { arity, associated := ∅ }
107-
modifyCandidates (fun cs => cs.insert fvarId cinfo )
98+
modifyCandidates fun cs => cs.insert fvarId { arity, associated := ∅ }
10899

109100
/--
110101
Add a new join point dependency from `src` to `dst`.
111102
-/
103+
@[inline]
112104
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
113-
if let some targetInfo ← findCandidate? target then
114-
modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
115-
else
116-
eraseCandidate src
105+
modifyCandidates fun cs =>
106+
cs.modify target fun targetInfo =>
107+
{ targetInfo with associated := targetInfo.associated.insert src }
117108

118109
/--
119110
Find all `fun` declarations that qualify as a join point, that is:
@@ -152,49 +143,51 @@ partial def find (decl : Decl) : CompilerM FindState := do
152143
return candidates
153144
where
154145
go : Code → FindM Unit
155-
| .let decl k => do
156-
match k, decl.value with
157-
| .return valId, .fvar fvarId args =>
158-
args.forM removeCandidatesInArg
159-
if let some candidateInfo ← findCandidate? fvarId then
160-
-- Erase candidate that are not fully applied or applied outside of tail position
161-
if valId != decl.fvarId || args.size != candidateInfo.arity then
162-
eraseCandidate fvarId
163-
-- Out of scope join point candidate handling
164-
else
165-
let currDepth := (← read).definitionDepth
166-
let calleeDepth := (← read).scope.get! fvarId
167-
if currDepth == calleeDepth then
168-
return ()
169-
else if calleeDepth + 1 == currDepth then
170-
addDependency fvarId (← read).currentFunction.get!
171-
else
146+
| .let decl k => do
147+
match k, decl.value with
148+
| .return valId, .fvar fvarId args =>
149+
args.forM (forFVarM eraseCandidate)
150+
if let some candidateInfo ← findCandidate? fvarId then
151+
-- Erase candidate that are not fully applied or applied outside of tail position
152+
if valId != decl.fvarId || args.size != candidateInfo.arity then
172153
eraseCandidate fvarId
173-
| _, _ =>
174-
removeCandidatesInLetValue decl.value
175-
go k
176-
| .fun decl k => do
177-
addCandidate decl.fvarId decl.getArity
178-
withReader (fun ctx => {
179-
ctx with
180-
definitionDepth := ctx.definitionDepth + 1,
181-
currentFunction := some decl.fvarId }) do
154+
-- Out of scope join point candidate handling
155+
else
156+
let currDepth := (← read).definitionDepth
157+
let calleeDepth := (← read).scope.find! fvarId
158+
if currDepth == calleeDepth then
159+
return ()
160+
else if calleeDepth + 1 == currDepth then
161+
addDependency fvarId (← read).currentFunction.get!
162+
else
163+
eraseCandidate fvarId
164+
| _, _ =>
165+
forFVarM eraseCandidate decl.value
166+
go k
167+
| .fun decl k => do
168+
addCandidate decl.fvarId decl.getArity
169+
withReader (fun ctx => {
170+
ctx with
171+
definitionDepth := ctx.definitionDepth + 1,
172+
currentFunction := some decl.fvarId }) do
173+
go decl.value
174+
withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
175+
go k
176+
| .jp decl k => do
182177
go decl.value
183-
withReader (fun ctx => { ctx with scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
184178
go k
185-
| .jp decl k => do
186-
go decl.value
187-
go k
188-
| .jmp _ args => args.forM removeCandidatesInArg
189-
| .return val => eraseCandidate val
190-
| .cases c => c.alts.forM (·.forCodeM go)
191-
| .unreach .. => return ()
179+
| .jmp _ args => args.forM (forFVarM eraseCandidate)
180+
| .return val => eraseCandidate val
181+
| .cases c => c.alts.forM (·.forCodeM go)
182+
| .unreach .. => return ()
192183

193184
/--
194185
Replace all join point candidate `fun` declarations with `jp` ones
195186
and all calls to them with `jmp`s.
196187
-/
197188
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
189+
if state.candidates.isEmpty then
190+
return decl
198191
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
199192
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := ∅) mapper
200193
let newValue ← decl.value.mapCodeM go |>.run replaceCtx

0 commit comments

Comments
 (0)