Skip to content

Commit cc046e0

Browse files
authored
perf: improve join point finding (#10999)
This PR improves join point finding in the compiler through two means: 1. We now correctly handle situations where a function `f` can only become a join point if a function `g` becomes a join point as well. 2. We introduce a second join point finding pass after specialisation and before the following simplification pass, as the specialiser might have introduced new join point opportunities for the simplifier to exploit. Notably, in the code from #10995 we now correctly detect the missing join point, which required both of these changes to be made. Closes: #10995
1 parent e11ef3e commit cc046e0

File tree

3 files changed

+155
-38
lines changed

3 files changed

+155
-38
lines changed

src/Lean/Compiler/LCNF/JoinPoints.lean

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,23 @@ structure CandidateInfo where
3434
associated : Std.HashSet FVarId
3535
deriving Inhabited
3636

37+
structure FindCtx where
38+
/--
39+
The current definition depth is defined by how many `fun` binders we are
40+
nested in at the current point. Note that this does *not* include `jp`
41+
binders.
42+
-/
43+
definitionDepth : Nat := 0
44+
/--
45+
A map from function declarations that are currently in scope to their
46+
definition depth.
47+
-/
48+
scope : FVarIdMap Nat := {}
49+
/--
50+
The current function binder we are inside of if any.
51+
-/
52+
currentFunction : Option FVarId := none
53+
3754
/--
3855
The state for the join point candidate finder.
3956
-/
@@ -42,38 +59,36 @@ structure FindState where
4259
All current join point candidates accessible by their `FVarId`.
4360
-/
4461
candidates : Std.HashMap FVarId CandidateInfo := ∅
45-
/--
46-
The `FVarId`s of all `fun` declarations that were declared within the
47-
current `fun`.
48-
-/
49-
scope : Std.HashSet FVarId := ∅
5062

51-
abbrev ReplaceCtx := Std.HashMap FVarId Name
5263

53-
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
64+
abbrev FindM := ReaderT FindCtx StateRefT FindState CompilerM
65+
66+
abbrev ReplaceCtx := Std.HashMap FVarId Name
5467
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
5568

5669
/--
5770
Attempt to find a join point candidate by its `FVarId`.
5871
-/
72+
@[inline]
5973
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
6074
return (← get).candidates[fvarId]?
6175

76+
/--
77+
Combinator for modifying the candidates in `FindM`.
78+
-/
79+
@[inline]
80+
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
81+
modify (fun state => { state with candidates := f state.candidates })
82+
6283
/--
6384
Erase a join point candidate as well as all the ones that depend on it
6485
by its `FVarId`, no error is thrown if the candidate does not exist.
6586
-/
6687
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
6788
if let some info ← findCandidate? fvarId then
68-
modify (fun state => { state with candidates := state.candidates.erase fvarId })
89+
modifyCandidates fun cs => cs.erase fvarId
6990
info.associated.forM eraseCandidate
7091

71-
/--
72-
Combinator for modifying the candidates in `FindM`.
73-
-/
74-
private def modifyCandidates (f : Std.HashMap FVarId CandidateInfo → Std.HashMap FVarId CandidateInfo) : FindM Unit :=
75-
modify (fun state => {state with candidates := f state.candidates })
76-
7792
/--
7893
Remove all join point candidates contained in `a`.
7994
-/
@@ -91,16 +106,30 @@ Add a new join point candidate to the state.
91106
-/
92107
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
93108
let cinfo := { arity, associated := ∅ }
94-
modifyCandidates (fun cs => cs.insert fvarId cinfo )
109+
modifyCandidates fun cs => cs.insert fvarId cinfo
95110

96111
/--
97112
Add a new join point dependency from `src` to `target`.
98113
-/
99114
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
100-
if let some targetInfo ← findCandidate? target then
101-
modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
102-
else
103-
eraseCandidate src
115+
modifyCandidates fun cs =>
116+
cs.modify target fun targetInfo =>
117+
{ targetInfo with associated := targetInfo.associated.insert src }
118+
119+
@[inline]
120+
private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
121+
withReader (fun ctx => {
122+
ctx with
123+
definitionDepth := ctx.definitionDepth + 1,
124+
currentFunction := some decl.fvarId }) do
125+
x
126+
127+
@[inline]
128+
private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α :=
129+
withReader (fun ctx => {
130+
ctx with
131+
scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
132+
x
104133

105134
/--
106135
Find all `fun` declarations that qualify as a join point, that is:
@@ -135,7 +164,7 @@ this. This is because otherwise the calls to `myjp` in `f` and `g` would
135164
produce out of scope join point jumps.
136165
-/
137166
partial def find (decl : Decl) : CompilerM FindState := do
138-
let (_, candidates) ← decl.value.forCodeM go |>.run none |>.run {} |>.run' {}
167+
let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
139168
return candidates
140169
where
141170
go : Code → FindM Unit
@@ -148,29 +177,30 @@ where
148177
if valId != decl.fvarId || args.size != candidateInfo.arity then
149178
eraseCandidate fvarId
150179
-- Out of scope join point candidate handling
151-
else if let some upperCandidate ← read then
152-
if !(← isInScope fvarId) then
153-
addDependency fvarId upperCandidate
154-
else
155-
eraseCandidate fvarId
180+
else
181+
let currDepth := (← read).definitionDepth
182+
let calleeDepth := (← read).scope.get! fvarId
183+
if currDepth == calleeDepth then
184+
return ()
185+
else if calleeDepth + 1 == currDepth then
186+
addDependency fvarId (← read).currentFunction.get!
187+
else
188+
eraseCandidate fvarId
156189
| _, _ =>
157190
removeCandidatesInLetValue decl.value
158191
go k
159192
| .fun decl k => do
160-
withReader (fun _ => some decl.fvarId) do
161-
withNewScope do
162-
go decl.value
163193
addCandidate decl.fvarId decl.getArity
164-
addToScope decl.fvarId
165-
go k
194+
withFnBody decl do
195+
go decl.value
196+
withFnDefined decl do
197+
go k
166198
| .jp decl k => do
167199
go decl.value
168200
go k
169201
| .jmp _ args => args.forM removeCandidatesInArg
170202
| .return val => eraseCandidate val
171-
| .cases c => do
172-
eraseCandidate c.discr
173-
c.alts.forM (·.forCodeM go)
203+
| .cases c => c.alts.forM (·.forCodeM go)
174204
| .unreach .. => return ()
175205

176206
/--
@@ -602,17 +632,23 @@ where
602632

603633
end JoinPointCommonArgs
604634

635+
def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
636+
let findResult ← JoinPointFinder.find decl
637+
trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
638+
if findResult.candidates.isEmpty then
639+
return none
640+
else
641+
return some (← JoinPointFinder.replace decl findResult)
642+
605643
/--
606644
Find all `fun` declarations in `decl` that qualify as join points then replace
607645
their definitions and call sites with `jp`/`jmp`.
608646
-/
609647
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
610-
let findResult ← JoinPointFinder.find decl
611-
trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
612-
JoinPointFinder.replace decl findResult
648+
return (← Decl.findJoinPoints? decl).getD decl
613649

614-
def findJoinPoints : Pass :=
615-
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base
650+
def findJoinPoints (occurrence : Nat := 0) : Pass :=
651+
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence)
616652

617653
builtin_initialize
618654
registerTraceClass `Compiler.findJoinPoints (inherited := true)

src/Lean/Compiler/LCNF/Passes.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def builtinPassManager : PassManager := {
9494
-- checked without nested functions whose bodies specialization does not require access to.
9595
checkTemplateVisibility,
9696
specialize,
97+
findJoinPoints (occurrence := 1),
9798
simp (occurrence := 2),
9899
cse (shouldElimFunDecls := false) (occurrence := 1),
99100
saveBase, -- End of base phase

tests/lean/run/more_jps.lean

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
@[specialize 3 4] def List.foldrNonTR (f : α → β → β) (init : β) : (l : List α) → β
2+
| [] => init
3+
| a :: l => f a (foldrNonTR f init l)
4+
5+
@[always_inline, inline]
6+
def List.forBreak_ {α : Type u} {m : Type w → Type x} [Monad m] (xs : List α) (s : σ) (body : α → OptionT (StateT σ (ExceptT ρ m)) PUnit) (kreturn : ρ → m γ) (kbreak : σ → m γ) : m γ :=
7+
List.foldrNonTR
8+
(fun a acc s => do
9+
let e ← body a s
10+
match e with
11+
| .error r => kreturn r
12+
| .ok (.some _, s) => acc s
13+
| .ok (none, s) => kbreak s)
14+
kbreak
15+
xs
16+
s
17+
18+
/--
19+
trace: [Compiler.saveBase] size: 25
20+
def List.foldrNonTR._at_._example.spec_0 x.1 _y.2 : Nat :=
21+
jp _jp.3 x : Nat :=
22+
let _x.4 := 13;
23+
let x := Nat.add x _x.4;
24+
let x := Nat.add x _x.4;
25+
let x := Nat.add x _x.4;
26+
let x := Nat.add x _x.4;
27+
let x := Nat.add x _x.4;
28+
let x := Nat.add x _x.4;
29+
let x := Nat.add x _x.4;
30+
return x;
31+
cases x.1 : Nat
32+
| List.nil =>
33+
goto _jp.3 _y.2
34+
| List.cons head.5 tail.6 =>
35+
let _x.7 := 0;
36+
let _x.8 := instDecidableEqNat _y.2 _x.7;
37+
cases _x.8 : Nat
38+
| Decidable.isFalse x.9 =>
39+
let _x.10 := 10;
40+
let _x.11 := Nat.decLt _x.10 _y.2;
41+
cases _x.11 : Nat
42+
| Decidable.isFalse x.12 =>
43+
let _x.13 := Nat.add _y.2 head.5;
44+
let _x.14 := List.foldrNonTR._at_._example.spec_0 tail.6 _x.13;
45+
return _x.14
46+
| Decidable.isTrue x.15 =>
47+
goto _jp.3 _y.2
48+
| Decidable.isTrue x.16 =>
49+
return _y.2
50+
[Compiler.saveBase] size: 9
51+
def _example : Nat :=
52+
let x := 42;
53+
let _x.1 := 1;
54+
let _x.2 := 2;
55+
let _x.3 := 3;
56+
let _x.4 := @List.nil _;
57+
let _x.5 := @List.cons _ _x.3 _x.4;
58+
let _x.6 := @List.cons _ _x.2 _x.5;
59+
let _x.7 := @List.cons _ _x.1 _x.6;
60+
let _x.8 := List.foldrNonTR._at_._example.spec_0 _x.7 x;
61+
return _x.8
62+
-/
63+
#guard_msgs in
64+
set_option trace.Compiler.saveBase true in
65+
example := Id.run do
66+
let x := 42;
67+
List.forBreak_ (m:=Id) (ρ := Nat) [1, 2, 3] x (fun i => do
68+
let x ← get
69+
if x = 0 then throw (m := ExceptT Nat Id) x -- return
70+
else if x > 10 then failure (f := OptionT _) -- break
71+
else set (x + i) >>= fun _ => pure ()) -- continue
72+
pure fun x => do
73+
let x := x + 13;
74+
let x := x + 13;
75+
let x := x + 13;
76+
let x := x + 13;
77+
let x := x + 13;
78+
let x := x + 13;
79+
let x := x + 13;
80+
return x

0 commit comments

Comments
 (0)