@@ -34,6 +34,23 @@ structure CandidateInfo where
3434 associated : Std.HashSet FVarId
3535 deriving Inhabited
3636
37+ structure FindCtx where
38+ /--
39+ The current definition depth is defined by how many `fun` binders we are
40+ nested in at the current point. Note that this does *not* include `jp`
41+ binders.
42+ -/
43+ definitionDepth : Nat := 0
44+ /--
45+ A map from function declarations that are currently in scope to their
46+ definition depth.
47+ -/
48+ scope : FVarIdMap Nat := {}
49+ /--
50+ The current function binder we are inside of if any.
51+ -/
52+ currentFunction : Option FVarId := none
53+
3754/--
3855The state for the join point candidate finder.
3956-/
@@ -42,38 +59,36 @@ structure FindState where
4259 All current join point candidates accessible by their `FVarId`.
4360 -/
4461 candidates : Std.HashMap FVarId CandidateInfo := ∅
45- /--
46- The `FVarId`s of all `fun` declarations that were declared within the
47- current `fun`.
48- -/
49- scope : Std.HashSet FVarId := ∅
5062
51- abbrev ReplaceCtx := Std.HashMap FVarId Name
5263
53- abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
64+ abbrev FindM := ReaderT FindCtx StateRefT FindState CompilerM
65+
66+ abbrev ReplaceCtx := Std.HashMap FVarId Name
5467abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
5568
5669/--
5770Attempt to find a join point candidate by its `FVarId`.
5871-/
72+ @[inline]
5973private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
6074 return (← get).candidates[fvarId]?
6175
76+ /--
77+ Combinator for modifying the candidates in `FindM`.
78+ -/
79+ @[inline]
80+ private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
81+ modify (fun state => { state with candidates := f state.candidates })
82+
6283/--
6384Erase a join point candidate as well as all the ones that depend on it
6485by its `FVarId`, no error is thrown is the candidate does not exist.
6586-/
6687private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
6788 if let some info ← findCandidate? fvarId then
68- modify ( fun state => { state with candidates := state.candidates. erase fvarId })
89+ modifyCandidates fun cs => cs. erase fvarId
6990 info.associated.forM eraseCandidate
7091
71- /--
72- Combinator for modifying the candidates in `FindM`.
73- -/
74- private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
75- modify (fun state => {state with candidates := f state.candidates })
76-
7792/--
7893Remove all join point candidates contained in `a`.
7994-/
@@ -91,16 +106,30 @@ Add a new join point candidate to the state.
91106-/
92107private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
93108 let cinfo := { arity, associated := ∅ }
94- modifyCandidates ( fun cs => cs.insert fvarId cinfo )
109+ modifyCandidates fun cs => cs.insert fvarId cinfo
95110
96111/--
97112Add a new join point dependency from `src` to `dst`.
98113-/
99114private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
100- if let some targetInfo ← findCandidate? target then
101- modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
102- else
103- eraseCandidate src
115+ modifyCandidates fun cs =>
116+ cs.modify target fun targetInfo =>
117+ { targetInfo with associated := targetInfo.associated.insert src }
118+
119+ @[inline]
120+ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
121+ withReader (fun ctx => {
122+ ctx with
123+ definitionDepth := ctx.definitionDepth + 1 ,
124+ currentFunction := some decl.fvarId }) do
125+ x
126+
127+ @[inline]
128+ private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α :=
129+ withReader (fun ctx => {
130+ ctx with
131+ scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
132+ x
104133
105134/--
106135Find all `fun` declarations that qualify as a join point, that is:
@@ -135,7 +164,7 @@ this. This is because otherwise the calls to `myjp` in `f` and `g` would
135164produce out of scope join point jumps.
136165-/
137166partial def find (decl : Decl) : CompilerM FindState := do
138- let (_, candidates) ← decl.value.forCodeM go |>.run none |>.run {} |>.run' {}
167+ let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
139168 return candidates
140169where
141170 go : Code → FindM Unit
@@ -148,29 +177,30 @@ where
148177 if valId != decl.fvarId || args.size != candidateInfo.arity then
149178 eraseCandidate fvarId
150179 -- Out of scope join point candidate handling
151- else if let some upperCandidate ← read then
152- if !(← isInScope fvarId) then
153- addDependency fvarId upperCandidate
154- else
155- eraseCandidate fvarId
180+ else
181+ let currDepth := (← read).definitionDepth
182+ let calleeDepth := (← read).scope.get! fvarId
183+ if currDepth == calleeDepth then
184+ return ()
185+ else if calleeDepth + 1 == currDepth then
186+ addDependency fvarId (← read).currentFunction.get!
187+ else
188+ eraseCandidate fvarId
156189 | _, _ =>
157190 removeCandidatesInLetValue decl.value
158191 go k
159192 | .fun decl k => do
160- withReader (fun _ => some decl.fvarId) do
161- withNewScope do
162- go decl.value
163193 addCandidate decl.fvarId decl.getArity
164- addToScope decl.fvarId
165- go k
194+ withFnBody decl do
195+ go decl.value
196+ withFnDefined decl do
197+ go k
166198 | .jp decl k => do
167199 go decl.value
168200 go k
169201 | .jmp _ args => args.forM removeCandidatesInArg
170202 | .return val => eraseCandidate val
171- | .cases c => do
172- eraseCandidate c.discr
173- c.alts.forM (·.forCodeM go)
203+ | .cases c => c.alts.forM (·.forCodeM go)
174204 | .unreach .. => return ()
175205
176206/--
@@ -602,17 +632,23 @@ where
602632
603633end JoinPointCommonArgs
604634
635+ def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
636+ let findResult ← JoinPointFinder.find decl
637+ trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
638+ if findResult.candidates.isEmpty then
639+ return none
640+ else
641+ return some (← JoinPointFinder.replace decl findResult)
642+
605643/--
606644Find all `fun` declarations in `decl` that qualify as join points then replace
607645their definitions and call sites with `jp`/`jmp`.
608646-/
609647def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
610- let findResult ← JoinPointFinder.find decl
611- trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
612- JoinPointFinder.replace decl findResult
648+ return (← Decl.findJoinPoints? decl).getD decl
613649
614- def findJoinPoints : Pass :=
615- .mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base
650+ def findJoinPoints (occurrence : Nat := 0 ) : Pass :=
651+ .mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence)
616652
617653builtin_initialize
618654 registerTraceClass `Compiler.findJoinPoints (inherited := true )
0 commit comments