Skip to content

Commit af87540

Browse files
committed
fix for kernel performance issues
1 parent 70981dd commit af87540

File tree

4 files changed

+170
-46
lines changed

4 files changed

+170
-46
lines changed

src/Init/SimpLemmas.lean

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ theorem let_body_congr {α : Sort u} {β : α → Sort v} {b b' : (a : α) →
7474
(a : α) (h : ∀ x, b x = b' x) : (let x := a; b x) = (let x := a; b' x) :=
7575
(funext h : b = b') ▸ rfl
7676

77+
/-!
78+
Simp lemmas for `have` have kernel performance issues when stated using `have` directly.
79+
Illustration of the problem: the kernel infers that the type of
80+
`have_congr (fun x => b) (fun x => b') h₁ h₂`
81+
is
82+
`(have x := a; (fun x => b) x) = (have x := a'; (fun x => b') x)`
83+
rather than
84+
`(have x := a; b x) = (have x := a'; b' x)`
85+
That means the kernel will do `whnf_core` at every step of checking a sequence of these lemmas.
86+
Thus, we get quadratically many zeta reductions.
87+
88+
For reference, we have the `have` versions of the theorems in the following comment,
89+
and then after that we have the versions that `simpHaveTelescope` actually uses,
90+
which avoid this issue.
91+
-/
92+
/-
7793
theorem have_unused {α : Sort u} {β : Sort v} (a : α) {b b' : β}
7894
(h : b = b') : (have _ := a; b) = b' := h
7995
@@ -95,6 +111,29 @@ theorem have_body_congr_dep {α : Sort u} {β : α → Sort v} (a : α) {f f' :
95111
theorem have_body_congr {α : Sort u} {β : Sort v} (a : α) {f f' : α → β}
96112
(h : ∀ x, f x = f' x) : (have x := a; f x) = (have x := a; f' x) :=
97113
h a
114+
-/
115+
116+
theorem have_unused' {α : Sort u} {β : Sort v} (a : α) {b b' : β}
117+
(h : b = b') : (fun _ => b) a = b' := h
118+
119+
theorem have_unused_dep' {α : Sort u} {β : Sort v} (a : α) {b : α → β} {b' : β}
120+
(h : ∀ x, b x = b') : b a = b' := h a
121+
122+
theorem have_congr' {α : Sort u} {β : Sort v} {a a' : α} {f f' : α → β}
123+
(h₁ : a = a') (h₂ : ∀ x, f x = f' x) : f a = f' a' :=
124+
@congr α β f f' a a' (funext h₂) h₁
125+
126+
theorem have_val_congr' {α : Sort u} {β : Sort v} {a a' : α} {f : α → β}
127+
(h : a = a') : f a = f a' :=
128+
@congrArg α β a a' f h
129+
130+
theorem have_body_congr_dep' {α : Sort u} {β : α → Sort v} (a : α) {f f' : (x : α) → β x}
131+
(h : ∀ x, f x = f' x) : f a = f' a :=
132+
h a
133+
134+
theorem have_body_congr' {α : Sort u} {β : Sort v} (a : α) {f f' : α → β}
135+
(h : ∀ x, f x = f' x) : f a = f' a :=
136+
h a
98137

99138
theorem letFun_unused {α : Sort u} {β : Sort v} (a : α) {b b' : β} (h : b = b') : @letFun α (fun _ => β) a (fun _ => b) = b' :=
100139
h

src/Lean/Meta/Tactic/Simp/Main.lean

Lines changed: 113 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,12 @@ def HaveTelescopeInfo.computeFixedUsed (info : HaveTelescopeInfo) (keepUnused :
488488
return (info.bodyTypeDeps, used)
489489

490490
/--
491-
Auxiliary structure used to represent the return value of `simpLet.simpHaveTelescopeAux`.
491+
Auxiliary structure used to represent the return value of `simpHaveTelescopeAux`.
492492
-/
493493
private structure SimpHaveResult where
494494
/--
495-
The simplified expression. Note that it may contain loose bound variables.
496-
`simpLet` attempts to minimize the quadratic overhead imposed
495+
The simplified expression in `(fun x => b) v` form. Note that it may contain loose bound variables.
496+
`simpHaveTelescope` attempts to minimize the quadratic overhead imposed
497497
by the locally nameless discipline in a sequence of `have` expressions.
498498
-/
499499
expr : Expr
@@ -503,6 +503,14 @@ private structure SimpHaveResult where
503503
-/
504504
exprType : Expr
505505
/--
506+
The initial expression in `(fun x => b) v` form.
507+
-/
508+
exprInit : Expr
509+
/--
510+
The expression `expr` in `have x := v; b` form.
511+
-/
512+
exprResult : Expr
513+
/--
506514
The proof that the simplified expression is equal to the input one.
507515
It may contain loose bound variables like in the `expr` field.
508516
-/
@@ -513,14 +521,27 @@ private structure SimpHaveResult where
513521
-/
514522
modified : Bool
515523

516-
private def SimpHaveResult.toResult : SimpHaveResult → Result
517-
| { expr, proof, modified, .. } => { expr, proof? := if modified then some proof else none }
524+
private def SimpHaveResult.toResult (u : Level) (source : Expr) : SimpHaveResult → Result
525+
| { expr, exprType, exprInit, exprResult, proof, modified, .. } =>
526+
{ expr := exprResult
527+
proof? :=
528+
if modified then
529+
-- Add a type hint to convert back into `have` form.
530+
some <| mkApp2 (mkConst ``id [levelZero]) (mkApp3 (mkConst ``Eq [u]) exprType source exprResult)
531+
-- Add in a second type hint, for use in an optimization to avoid zeta/beta reductions in the kernel
532+
-- (see the base case in `simpHaveTelescopeAux`).
533+
(mkApp2 (mkConst ``id [levelZero]) (mkApp3 (mkConst ``Eq [u]) exprType exprInit expr)
534+
proof)
535+
else
536+
none }
518537

519538
/--
520539
Routine for simplifying `have` telescopes. Used by `simpLet`.
521540
This is optimized to be able to handle deep `have` telescopes while avoiding quadratic complexity
522541
arising from the locally nameless expression representations.
523542
543+
## Overview
544+
524545
Consider a `have` telescope:
525546
```
526547
have x₁ : T₁ := v₁; ...; have xₙ : Tₙ := vₙ; b
@@ -529,14 +550,44 @@ We say `xᵢ` is *fixed* if it appears in the type of `b`.
529550
If `xᵢ` is fixed, then it can only be rewritten using definitional equalities.
530551
Otherwise, we can freely rewrite the value using a propositional equality `vᵢ = vᵢ'`.
531552
The body `b` can always be freely rewritten using a propositional equality `b = b'`.
553+
(All the mentioned equalities must be type correct with respect to the obvious local contexts.)
532554
533555
If `xᵢ` neither appears in `b` nor any `Tⱼ` nor any `vⱼ`, then its `have` is *unused*.
534556
Unused `have`s can be eliminated, which we do if `cfg.zetaUnused` is true.
535-
Note that it is best to clear unused `have`s from the right,
536-
to eliminate any uses from unused `have`s.
537-
This is why we honor `zetaUnused` here even though `reduceStep` is also responsible for it.
557+
We clear `have`s from the end, to be able to transitively clear chains of unused `have`s.
558+
This is why we honor `zetaUnused` here even though `reduceStep` is also responsible for it;
559+
`reduceStep` can only eliminate unused `have`s at the start of a telescope.
560+
Eliminating all transitively unused `have`s at once like this also avoids quadratic complexity.
538561
539562
If `debug.simp.check` is enabled then we typecheck the resulting expression and proof.
563+
564+
## Proof generation
565+
566+
There are two main complications with generating proofs.
567+
1. We work almost exclusively with expressions with loose bound variables,
568+
to avoid overhead from instantiating and abstracting free variables,
569+
which can lead to complexity quadratic in the telescope length.
570+
2. We want to avoid triggering zeta reductions in the kernel.
571+
572+
Regarding this second point, the issue is that we cannot organize the proof using congruence theorems
573+
in the obvious way. For example, given
574+
```
575+
theorem have_congr {α : Sort u} {β : Sort v} {a a' : α} {f f' : α → β}
576+
(h₁ : a = a') (h₂ : ∀ x, f x = f' x) :
577+
(have x := a; f x) = (have x := a'; f' x)
578+
```
579+
the kernel sees that the type of `have_congr (fun x => b) (fun x => b') h₁ h₂`
580+
is `(have x := a; (fun x => b) x) = (have x := a'; (fun x => b') x)`,
581+
since when instantiating values it does not beta reduce at the same time.
582+
Hence, when `is_def_eq` is applied to the LHS and `have x := a; b` for example,
583+
it will do `whnf_core` to both sides.
584+
585+
Instead, we use the form `(fun x => b) a = (fun x => b') a'` in the proofs,
586+
which we can reliably match up without triggering beta reductions in the kernel.
587+
The zeta/beta reductions are then limited to the type hint for the entire telescope,
588+
when we convert this back into `have` form.
589+
In the base case, we include an optimization to avoid unnecessary zeta/beta reductions,
590+
by detecting a `simpHaveTelescope` proofs and removing the type hint.
540591
-/
541592
def simpHaveTelescope (e : Expr) : SimpM Result := do
542593
Prod.fst <$> withTraceNode `Debug.Meta.Tactic.simp (fun
@@ -549,7 +600,7 @@ def simpHaveTelescope (e : Expr) : SimpM Result := do
549600
if r.modified && debug.simp.check.get (← getOptions) then
550601
check r.expr
551602
check r.proof
552-
return (r.toResult, used, fixed, r.modified)
603+
return (r.toResult info.level e, used, fixed, r.modified)
553604
where
554605
/-
555606
Re-enters the telescope recorded in `info` using `withExistingLocalDecls` while simplifying according to `fixed`/`used`.
@@ -579,18 +630,20 @@ where
579630
We use a dummy `x` for debugging purposes. (Recall that `Expr.abstract` skips non-fvar/mvars.)
580631
-/
581632
let x := Expr.const `_simp_let_unused_dummy []
582-
let { expr, exprType, proof, modified } ← simpHaveTelescopeAux info fixed used b (i + 1) (xs.push x)
583-
let expr := expr.lowerLooseBVars 1 1
584-
let exprType := exprType.lowerLooseBVars 1 1
585-
if modified then
586-
let proof := mkApp6 (mkConst ``have_unused_dep us) t exprType v (mkLambda n .default t b) expr
587-
(mkLambda n .default t proof)
588-
return { expr, exprType, proof, modified := true }
633+
let rb ← simpHaveTelescopeAux info fixed used b (i + 1) (xs.push x)
634+
let expr := rb.expr.lowerLooseBVars 1 1
635+
let exprType := rb.exprType.lowerLooseBVars 1 1
636+
let exprInit := Expr.app (Expr.lam n t rb.exprInit .default) v
637+
let exprResult := rb.exprResult.lowerLooseBVars 1 1
638+
if rb.modified then
639+
let proof := mkApp6 (mkConst ``have_unused_dep' us) t exprType v (mkLambda n .default t rb.exprInit) expr
640+
(mkLambda n .default t rb.proof)
641+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
589642
else
590-
-- If not modified, this must have been a non-transitively unused `have`.
591-
let proof := mkApp6 (mkConst ``have_unused us) t exprType v expr expr
643+
-- If not modified, this must have been a non-transitively unused `have`, so no need for `dep` form.
644+
let proof := mkApp6 (mkConst ``have_unused' us) t exprType v expr expr
592645
(mkApp2 (mkConst ``Eq.refl [info.level]) exprType expr)
593-
return { expr, exprType, proof, modified := true }
646+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
594647
else if fixed.contains i then
595648
/-
596649
Fixed `have` (like `CongrArgKind.fixed`): dsimp the value and simp the body.
@@ -602,15 +655,21 @@ where
602655
let v' := if vModified then val'.abstract xs else v
603656
withExistingLocalDecls [hinfo.decl] <| withNewLemmas #[x] do
604657
let rb ← simpHaveTelescopeAux info fixed used b (i + 1) (xs.push x)
605-
let expr := Expr.letE n t v' rb.expr true
606-
let exprType := Expr.letE n t v' rb.exprType true
658+
let expr := Expr.app (Expr.lam n t rb.expr .default) v'
659+
let exprType := Expr.app (Expr.lam n t rb.exprType .default) v'
660+
let exprInit := Expr.app (Expr.lam n t rb.exprInit .default) v
661+
let exprResult := Expr.letE n t v' rb.exprResult true
662+
-- Note: if `vModified`, then the kernel will reduce the telescopes and potentially do an expensive defeq check.
663+
-- If this is a performance issue, we could try using a `letFun` encoding here make use of the `is_def_eq_args` optimization.
664+
-- Namely, to guide the defeqs via the sequence
665+
-- `(fun x => b) v` = `letFun (fun x => b) v` = `letFun (fun x => b) v'` = `(fun x => b) v'`
607666
if rb.modified then
608-
let proof := mkApp6 (mkConst ``have_body_congr_dep us) t (mkLambda n .default t rb.exprType) v'
609-
(mkLambda n .default t b) (mkLambda n .default t rb.expr) (mkLambda n .default t rb.proof)
610-
return { expr, exprType, proof, modified := true }
667+
let proof := mkApp6 (mkConst ``have_body_congr_dep' us) t (mkLambda n .default t rb.exprType) v'
668+
(mkLambda n .default t rb.exprInit) (mkLambda n .default t rb.expr) (mkLambda n .default t rb.proof)
669+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
611670
else
612671
let proof := mkApp2 (mkConst ``Eq.refl [info.level]) exprType expr
613-
return { expr, exprType, proof, modified := vModified }
672+
return { expr, exprType, exprInit, exprResult, proof, modified := vModified }
614673
else
615674
/-
616675
Non-fixed `have` (like `CongrArgKind.eq`): simp both the value and the body.
@@ -626,36 +685,52 @@ where
626685
pure <| mkApp2 (mkConst `Eq.refl [hinfo.level]) t v
627686
withExistingLocalDecls [hinfo.decl] <| withNewLemmas #[x] do
628687
let rb ← simpHaveTelescopeAux info fixed used b (i + 1) (xs.push x)
629-
let expr := Expr.letE n t v' rb.expr true
688+
let expr := Expr.app (Expr.lam n t rb.expr .default) v'
630689
let exprType := rb.exprType.lowerLooseBVars 1 1
690+
let exprInit := Expr.app (Expr.lam n t rb.exprInit .default) v
691+
let exprResult := Expr.letE n t v' rb.exprResult true
631692
match valModified, rb.modified with
632693
| false, false =>
633694
let proof := mkApp2 (mkConst `Eq.refl [info.level]) exprType expr
634-
return { expr, exprType, proof, modified := false }
695+
return { expr, exprType, exprInit, exprResult, proof, modified := false }
635696
| true, false =>
636-
let proof := mkApp6 (mkConst ``have_val_congr us) t exprType v v'
637-
(mkLambda n .default t b) vproof
638-
return { expr, exprType, proof, modified := true }
697+
let proof := mkApp6 (mkConst ``have_val_congr' us) t exprType v v'
698+
(mkLambda n .default t rb.exprInit) vproof
699+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
639700
| false, true =>
640-
let proof := mkApp6 (mkConst ``have_body_congr us) t exprType v
641-
(mkLambda n .default t b) (mkLambda n .default t rb.expr) (mkLambda n .default t rb.proof)
642-
return { expr, exprType, proof, modified := true }
701+
let proof := mkApp6 (mkConst ``have_body_congr' us) t exprType v
702+
(mkLambda n .default t rb.exprInit) (mkLambda n .default t rb.expr) (mkLambda n .default t rb.proof)
703+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
643704
| true, true =>
644-
let proof := mkApp8 (mkConst ``have_congr us) t exprType v v'
645-
(mkLambda n .default t b) (mkLambda n .default t rb.expr) vproof (mkLambda n .default t rb.proof)
646-
return { expr, exprType, proof, modified := true }
705+
let proof := mkApp8 (mkConst ``have_congr' us) t exprType v v'
706+
(mkLambda n .default t rb.exprInit) (mkLambda n .default t rb.expr) vproof (mkLambda n .default t rb.proof)
707+
return { expr, exprType, exprInit, exprResult, proof, modified := true }
647708
else
648709
-- Base case: simplify the body.
649710
trace[Debug.Meta.Tactic.simp] "have telescope; simplifying body {info.body}"
650711
let r ← simp info.body
651712
let exprType := info.bodyType.abstract xs
652713
if r.expr == info.body then
653714
let proof := mkApp2 (mkConst `Eq.refl [info.level]) exprType e
654-
return { expr := e, exprType, proof, modified := false }
715+
return { expr := e, exprType, exprInit := e, exprResult := e, proof, modified := false }
655716
else
656717
let expr := r.expr.abstract xs
657718
let proof := (← r.getProof).abstract xs
658-
return { expr, exprType, proof, modified := true }
719+
-- Optimization: if the proof is a `simpHaveTelescope` proof, then remove the type hint
720+
-- to avoid the zeta/beta reductions that the kernel would otherwise do.
721+
-- In `SimpHaveResult.toResult` we insert two type hints; the inner one
722+
-- encodes the `exprInit` and `expr`.
723+
let detectSimpHaveLemma (proof : Expr) : Option SimpHaveResult := do
724+
let_expr id eq proof' := proof | failure
725+
guard <| eq.isAppOfArity ``Eq 3
726+
let_expr id eq' proof'' := proof' | failure
727+
let_expr Eq _ lhs rhs := eq' | failure
728+
let .const n _ := proof''.getAppFn | failure
729+
guard (n matches ``have_unused_dep' | ``have_unused' | ``have_body_congr_dep' | ``have_val_congr' | ``have_body_congr' | ``have_congr')
730+
return { expr := rhs, exprType, exprInit := lhs, exprResult := expr, proof := proof'', modified := true }
731+
if let some res := detectSimpHaveLemma proof then
732+
return res
733+
return { expr, exprType, exprInit := e, exprResult := expr, proof, modified := true }
659734

660735
/--
661736
Routine for simplifying `let` expressions.

tests/lean/run/simpHave.lean

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,10 @@ example (n : Nat) (h : n = 190) : lp 20 0 = n := by
611611
simp only
612612
simp [h]
613613

614+
-- set_option Elab.async false
615+
-- set_option profiler true
616+
-- set_option profiler.threshold 2
617+
-- #time
614618
set_option debug.simp.check false in
615619
example (n : Nat) (h : n = 4950) : lp 100 0 = n := by
616620
simp -zeta -zetaUnused only [lp]

tests/lean/simpZetaFalse.lean.expected.out

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ fun x h =>
1212
Eq.mpr
1313
(id
1414
(congrArg (fun x => x = 1)
15-
(have_congr (Nat.zero_add (x * x)) fun y =>
16-
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
17-
Eq.refl (y + 1))))
15+
(id
16+
(id
17+
(have_congr' (Nat.zero_add (x * x)) fun y =>
18+
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
19+
Eq.refl (y + 1))))))
1820
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
1921
x : Nat
2022
h : f (f x) = x
@@ -30,9 +32,11 @@ fun x h =>
3032
Eq.mpr
3133
(id
3234
(congrArg (fun x => x = 1)
33-
(have_body_congr (x * x) fun y =>
34-
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
35-
Eq.refl (y + 1))))
35+
(id
36+
(id
37+
(have_body_congr' (x * x) fun y =>
38+
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
39+
Eq.refl (y + 1))))))
3640
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
3741
x z : Nat
3842
h : f (f x) = x
@@ -47,7 +51,8 @@ theorem ex2 : ∀ (x z : Nat),
4751
y) =
4852
z :=
4953
fun x z h h' =>
50-
Eq.mpr (id (congrArg (fun x => x = z) (have_val_congr h))) (of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x)))
54+
Eq.mpr (id (congrArg (fun x => x = z) (id (id (have_val_congr' h)))))
55+
(of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x)))
5156
x z : Nat
5257
⊢ (let α := Nat;
5358
fun x => 0 + x) =
@@ -63,5 +68,6 @@ theorem ex4 : ∀ (p : Prop),
6368
fun x => x = x) =
6469
fun z => p :=
6570
fun p h =>
66-
Eq.mpr (id (congrArg (fun x => x = fun z => p) (have_body_congr_dep 10 fun n => funext fun x => eq_self x)))
71+
Eq.mpr
72+
(id (congrArg (fun x => x = fun z => p) (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))))
6773
(of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))

0 commit comments

Comments
 (0)