Skip to content

Commit 47f0c0f

Browse files
committed
feat: eta-expand for oversaturating arguments while specializing (#10924)
1 parent 14d76cc commit 47f0c0f

File tree

3 files changed

+259
-16
lines changed

3 files changed

+259
-16
lines changed

src/Lean/Compiler/LCNF/SpecInfo.lean

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ inductive SpecParamInfo where
2323
-/
2424
| fixedInst
2525
/--
26-
A parameter that is a function and is fixed in recursive declarations. If the user tags a declaration
27-
with `@[specialize]` without specifying which arguments should be specialized, Lean will specialize
28-
`.fixedHO` arguments in addition to `.fixedInst`.
26+
A parameter that is a function and is fixed in recursive declarations, or a parameter the type of
27+
which is the polymorphic return type `α` of the declaration and which could be instantiated to a
28+
function.
29+
If the user tags a declaration with `@[specialize]` without specifying which arguments should be
30+
specialized, Lean will specialize `.fixedHO` arguments in addition to `.fixedInst`.
2931
-/
3032
| fixedHO
3133
/--
@@ -142,6 +144,17 @@ private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat
142144
return true
143145
return false
144146

147+
def isFixedPolymorphicReturnType (decl : Decl) (type : Expr) (specInfos : Array SpecParamInfo) : CompilerM Bool := do
148+
-- logInfo m!"isFixedPolymorphicReturnType: {decl.name}, {type}, {specInfos}"
149+
let some idx := decl.params.findIdx? fun p => type == p.toExpr
150+
| return false
151+
let α := decl.params[idx]!.toExpr
152+
let retTy ← instantiateForall decl.type <| decl.params.map (mkFVar ·.fvarId)
153+
-- logInfo m!"isFixedPolymorphicReturnType2: {decl.name}, {α}, {retTy}"
154+
if specInfos[idx]! matches .fixedNeutral && retTy == α then
155+
return true
156+
return false
157+
145158
/--
146159
Save parameter information for `decls`.
147160
@@ -158,8 +171,11 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
158171
let specArgs? := getSpecializationArgs? (← getEnv) decl.name
159172
let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
160173
let mut paramsInfo : Array SpecParamInfo := #[]
174+
-- logInfo m!"decl.type: {decl.name} {decl.params.map fun p => (mkFVar p.fvarId, p.type)} {decl.type}"
161175
for h :i in *...decl.params.size do
162176
let param := decl.params[i]
177+
-- let b ← isFixedPolymorphicReturnType decl param.type paramsInfo
178+
-- logInfo m!"isFixedPolymorphicReturnType: {b}"
163179
let info ←
164180
if contains i then
165181
pure .user
@@ -178,7 +194,7 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
178194
specify which arguments must be specialized besides instances. In this case, we try to specialize
179195
any "fixed higher-order argument"
180196
-/
181-
else if specArgs? == some #[] && param.type matches .forallE .. then
197+
else if specArgs? == some #[] && (param.type matches .forallE .. || (← isFixedPolymorphicReturnType decl param.type paramsInfo)) then
182198
pure .fixedHO
183199
else
184200
pure .other

src/Lean/Compiler/LCNF/Specialize.lean

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ The keys never contain free variables or loose bound variables.
168168

169169
/--
170170
Given the specialization mask `paramsInfo` and the arguments `args`,
171-
collect their dependencies, and return an array `mask` of size `paramsInfo.size` s.t.
172-
- `mask[i] = some args[i]` if `paramsInfo[i] != .other`
171+
collect their dependencies, and return an array `mask` of size `args.size` s.t.
172+
- `mask[i] = some args[i]` if `paramsInfo[i]? != some .other`
173173
- `mask[i] = none`, otherwise.
174174
That is, `mask` contains only the arguments that are contributing to the code specialization.
175175
We use this information to compute a "key" to uniquely identify the code specialization, and
@@ -185,7 +185,9 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM
185185
!ctx.ground.contains fvarId
186186
Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do
187187
let mut argMask := #[]
188-
for paramInfo in paramsInfo, arg in args do
188+
for i in *...args.size do
189+
let paramInfo := paramsInfo[i]?.getD .fixedHO
190+
let arg := args[i]!
189191
match paramInfo with
190192
| .other =>
191193
argMask := argMask.push none
@@ -200,7 +202,9 @@ end Collector
200202
Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
201203
-/
202204
def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM Bool := do
203-
for paramInfo in paramsInfo, arg in args do
205+
for i in *...args.size do
206+
let arg := args[i]!
207+
let paramInfo := paramsInfo[i]?.getD .fixedHO -- .fixedHO might be too aggressive
204208
match paramInfo with
205209
| .other => pure ()
206210
| .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
@@ -267,21 +271,54 @@ where
267271
let .code code := decl.value | panic! "can only specialize decls with code"
268272
let mut params ← params.mapM internalizeParam
269273
let decls ← decls.mapM internalizeCodeDecl
270-
for param in decl.params, arg in argMask do
274+
let mut bodyType := decl.type.instantiateLevelParamsNoCache decl.levelParams us
275+
for arg in argMask, param in decl.params do
276+
let .forallE _ d b _ := bodyType.headBeta
277+
| panic! "has param of type {param.type}, but bodyType {bodyType} was not a forall"
271278
if let some arg := arg then
272279
let arg ← normArg arg
273280
modify fun s => s.insert param.fvarId arg
281+
bodyType := b.instantiate1 arg.toExpr
274282
else
275283
-- Keep the parameter
276-
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
277-
params := params.push (← internalizeParam param)
278-
for param in decl.params[argMask.size...*] do
279-
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
280-
params := params.push (← internalizeParam param)
284+
let param ← internalizeParam { param with type := d }
285+
params := params.push param
286+
bodyType := b.instantiate1 (.fvar param.fvarId)
287+
let extraParams := decl.params[argMask.size...*] -- non-empty if undersaturated app
288+
let extraMask := argMask[decl.params.size...*] -- non-empty if oversaturated app
289+
-- Add extraneous parameters to decl
290+
for param in extraParams do
291+
let .forallE _ d b _ := bodyType.headBeta
292+
| panic! "has param of type {param.type}, but bodyType {bodyType} was not a forall"
293+
-- Keep the parameter
294+
let param ← internalizeParam { param with type := d }
295+
params := params.push param
296+
bodyType := b.instantiate1 (.fvar param.fvarId)
281297
let code := code.instantiateValueLevelParams decl.levelParams us
282298
let code ← internalizeCode code
283299
let code := attachCodeDecls decls code
284-
let type ← code.inferType
300+
-- Eta-expand to accomodate extraneous args (cf. `etaExpandCore`)
301+
let code ←
302+
if extraMask.size = 0 then
303+
pure code
304+
else
305+
let mut extraArgs := #[]
306+
for arg in extraMask do
307+
let .forallE _ d b _ := bodyType.headBeta
308+
| panic! "oversaturated arg mask but decl.type was not a forall"
309+
if let some arg := arg then
310+
let arg ← normArg arg
311+
extraArgs := extraArgs.push arg
312+
bodyType := b.instantiate1 arg.toExpr
313+
else
314+
let p ← mkAuxParam d
315+
params := params.push p
316+
extraArgs := extraArgs.push (.fvar p.fvarId)
317+
bodyType := b.instantiate1 (.fvar p.fvarId)
318+
code.bind fun fvarId => do
319+
let auxDecl ← mkAuxLetDecl (.fvar fvarId extraArgs)
320+
return .let auxDecl (.return auxDecl.fvarId)
321+
let type := bodyType
285322
let type ← mkForallParams params type
286323
let value := .code code
287324
let safe := decl.safe
@@ -298,7 +335,7 @@ def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array Arg) : Arr
298335
for info in paramsInfo, arg in args do
299336
if info matches .other then
300337
result := result.push arg
301-
return result ++ args[paramsInfo.size...*]
338+
return result -- ++ args[paramsInfo.size...*]
302339

303340
def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet :=
304341
params.foldlM (init := {}) fun r p => do

tests/lean/run/10924.lean

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
@[specialize]
2+
def foo {α} : Nat → (α → α) → α → α
3+
| 0, f => f
4+
| n+1, f => foo n f
5+
6+
set_option trace.Compiler.saveBase true in
7+
/--
8+
trace: [Compiler.saveBase] size: 5
9+
def foo._at_._example.spec_0 x.1 : Nat :=
10+
cases x.1 : Nat
11+
| Nat.zero =>
12+
let _x.2 := 6;
13+
return _x.2
14+
| Nat.succ n.3 =>
15+
let _x.4 := foo._at_._example.spec_0 n.3;
16+
return _x.4
17+
[Compiler.saveBase] size: 1
18+
def _example n : Nat :=
19+
let _x.1 := foo._at_._example.spec_0 n;
20+
return _x.1
21+
-/
22+
#guard_msgs in
23+
example {n} := foo n (· + 1) 5
24+
25+
set_option trace.Compiler.saveBase true in
26+
/--
27+
trace: [Compiler.saveBase] size: 9
28+
def foo._at_._example.spec_0 x.1 : Nat :=
29+
fun _f.2 x.3 : Nat :=
30+
let _x.4 := 1;
31+
let _x.5 := Nat.add x.3 _x.4;
32+
return _x.5;
33+
let _x.6 := 5;
34+
cases x.1 : Nat
35+
| Nat.zero =>
36+
let _x.7 := _f.2 _x.6;
37+
let _x.8 := _f.2 _x.7;
38+
return _x.8
39+
| Nat.succ n.9 =>
40+
let _x.10 := foo._at_._example.spec_0 n.9;
41+
return _x.10
42+
[Compiler.saveBase] size: 1
43+
def _example n : Nat :=
44+
let _x.1 := foo._at_._example.spec_0 n;
45+
return _x.1
46+
-/
47+
#guard_msgs in
48+
example {n} := foo n (fun f a => f (f a)) (· + 1) 5
49+
50+
set_option trace.Compiler.saveBase true in
51+
/--
52+
trace: [Compiler.saveBase] size: 5
53+
def foo._at_._example.spec_0 x.1 : Nat :=
54+
let _x.2 := 5;
55+
cases x.1 : Nat
56+
| Nat.zero =>
57+
return _x.2
58+
| Nat.succ n.3 =>
59+
let _x.4 := foo._at_._example.spec_0 n.3;
60+
return _x.4
61+
[Compiler.saveBase] size: 1
62+
def _example n : Nat :=
63+
let _x.1 := foo._at_._example.spec_0 n;
64+
return _x.1
65+
-/
66+
#guard_msgs in
67+
example {n} := foo n id id id id id id 5
68+
69+
set_option trace.Compiler.saveBase true in
70+
/--
71+
trace: [Compiler.saveBase] size: 9
72+
def foo._at_._example.spec_0 x.1 : Nat :=
73+
fun _f.2 f g : Nat :=
74+
let _x.3 := f g;
75+
let _x.4 := f _x.3;
76+
return _x.4;
77+
fun _f.5 _y.6 : Nat :=
78+
return _y.6;
79+
let _x.7 := 5;
80+
cases x.1 : Nat
81+
| Nat.zero =>
82+
let _x.8 := _f.2 _f.5;
83+
let _x.9 := _f.2 _x.8 _x.7;
84+
return _x.9
85+
| Nat.succ n.10 =>
86+
let _x.11 := foo._at_._example.spec_0 n.10;
87+
return _x.11
88+
[Compiler.saveBase] size: 1
89+
def _example n : Nat :=
90+
let _x.1 := foo._at_._example.spec_0 n;
91+
return _x.1
92+
-/
93+
#guard_msgs in
94+
example {n} := foo n (fun f g => f <| f g) (fun f g => f <| f g) id 5
95+
96+
@[specialize]
97+
def List.forBreak_ {α : Type u} {m : Type w → Type x} [Monad m] (xs : List α) (body : α → ExceptCpsT PUnit m PUnit) : m PUnit :=
98+
match xs with
99+
| [] => pure ⟨⟩
100+
| x :: xs => body x (fun _ => forBreak_ xs body) (fun _ => pure ⟨⟩)
101+
102+
-- This one still does not properly specialize for the success and error continuations
103+
-- (`_y.4`, `_y.5`). The reason is that the loop body is not yet inlined when the specializer looks
104+
-- at the recursive call site in `List.forBreak_._at_._example.spec_0`, so it allocates another,
105+
-- strictly less general specialization `…spec_0.spec_0`.
106+
-- The reason the loop body is not yet inlined is that it occurs in the recursive call site as well,
107+
-- but only pre-specialization.
108+
set_option trace.Compiler.saveBase true in
109+
/--
110+
trace: [Compiler.saveBase] size: 23
111+
def List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _y.2 _y.3 _y.4 _y.5 xs : _y.3 :=
112+
cases xs : _y.3
113+
| List.nil =>
114+
let _x.6 := PUnit.unit;
115+
let _x.7 := @Prod.mk _ _ _x.6 _y.2;
116+
let _x.8 := _y.4 _x.7;
117+
return _x.8
118+
| List.cons head.9 tail.10 =>
119+
let _x.11 := 0;
120+
let _x.12 := instDecidableEqNat _y.2 _x.11;
121+
cases _x.12 : _y.3
122+
| Decidable.isFalse x.13 =>
123+
let _x.14 := 10;
124+
let _x.15 := Nat.decLt _x.14 _y.2;
125+
cases _x.15 : _y.3
126+
| Decidable.isFalse x.16 =>
127+
let _x.17 := Nat.add _y.2 head.9;
128+
let _x.18 := List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _x.17 _y.3 _y.4 _y.5 tail.10;
129+
return _x.18
130+
| Decidable.isTrue x.19 =>
131+
let _x.20 := Nat.add _y.2 _x.1;
132+
let _x.21 := PUnit.unit;
133+
let _x.22 := @Prod.mk _ _ _x.21 _x.20;
134+
let _x.23 := _y.4 _x.22;
135+
return _x.23
136+
| Decidable.isTrue x.24 =>
137+
let _x.25 := _y.5 _y.2;
138+
return _x.25
139+
[Compiler.saveBase] size: 19
140+
def List.forBreak_._at_._example.spec_0 _x.1 xs : Nat :=
141+
let x := 42;
142+
fun _f.2 a : Nat :=
143+
cases a : Nat
144+
| Prod.mk fst.3 snd.4 =>
145+
return snd.4;
146+
fun _f.5 _y.6 : Nat :=
147+
return _y.6;
148+
cases xs : Nat
149+
| List.nil =>
150+
return x
151+
| List.cons head.7 tail.8 =>
152+
let _x.9 := 0;
153+
let _x.10 := instDecidableEqNat x _x.9;
154+
cases _x.10 : Nat
155+
| Decidable.isFalse x.11 =>
156+
let _x.12 := 10;
157+
let _x.13 := Nat.decLt _x.12 x;
158+
cases _x.13 : Nat
159+
| Decidable.isFalse x.14 =>
160+
let _x.15 := Nat.add x head.7;
161+
let _x.16 := List.forBreak_._at_.List.forBreak_._at_._example.spec_0.spec_0 _x.1 _x.15 _ _f.2 _f.5 tail.8;
162+
return _x.16
163+
| Decidable.isTrue x.17 =>
164+
let _x.18 := Nat.add x _x.1;
165+
return _x.18
166+
| Decidable.isTrue x.19 =>
167+
return x
168+
[Compiler.saveBase] size: 8
169+
def _example : Nat :=
170+
let _x.1 := 1;
171+
let _x.2 := 2;
172+
let _x.3 := 3;
173+
let _x.4 := @List.nil _;
174+
let _x.5 := @List.cons _ _x.3 _x.4;
175+
let _x.6 := @List.cons _ _x.2 _x.5;
176+
let _x.7 := @List.cons _ _x.1 _x.6;
177+
let _x.8 := List.forBreak_._at_._example.spec_0 _x.1 _x.7;
178+
return _x.8
179+
-/
180+
#guard_msgs in
181+
-- set_option trace.Compiler.specialize.candidate true in
182+
-- set_option trace.Compiler.specialize.step true in
183+
example := Id.run <| ExceptCpsT.runCatch do
184+
let x := 42;
185+
let ((), x) ←
186+
(List.forBreak_ (m:=StateT Nat (ExceptCpsT Nat Id)) [1, 2, 3] fun i _β «continue» «break» x =>
187+
if x = 0 then throw x
188+
else if x > 10 then «break» () (x + 1)
189+
else «continue» PUnit.unit (x + i)).run x
190+
return x

0 commit comments

Comments
 (0)