Skip to content

feat: transform nondependent lets into haves in declarations and equation lemmas #8373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/Lean/Elab/MutualDef.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,7 @@ where
finishElab headers
processDeriving headers
elabAsync header view declId := do
assert! view.kind.isTheorem
let env ← getEnv
let async ← env.addConstAsync declId.declName .thm
(exportedKind? := guard (!isPrivateName declId.declName) *> some .axiom)
Expand All @@ -1178,6 +1179,12 @@ where
s := collectLevelParams s type
let scopeLevelNames ← getLevelNames
let levelParams ← IO.ofExcept <| sortDeclLevelParams scopeLevelNames allUserLevelNames s.params

let type ← if cleanup.letToHave.get (← getOptions) then
withRef header.declId <| Meta.letToHave type
else
pure type

async.commitSignature { name := header.declName, levelParams, type }

-- attributes should be applied on the main thread; see below
Expand Down
43 changes: 38 additions & 5 deletions src/Lean/Elab/PreDefinition/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Lean.Util.NumApps
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.ForEachExpr
import Lean.Meta.Eqns
import Lean.Meta.LetToHave
import Lean.Elab.RecAppSyntax
import Lean.Elab.DefView
import Lean.Elab.PreDefinition.TerminationHint
Expand All @@ -21,6 +22,11 @@ namespace Lean.Elab
open Meta
open Term

register_builtin_option cleanup.letToHave : Bool := {
defValue := true
descr := "Enables transforming `let`s to `have`s after elaborating declarations."
}

/--
A (potentially recursive) definition.
The elaborator converts it into Kernel definitions using many different strategies.
Expand Down Expand Up @@ -88,6 +94,31 @@ def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : Attribu
for preDef in preDefs do
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime

/--
Applies `Meta.letToHave` to the values of defs, instances, and abbrevs.
Does not apply the transformation to values that are proofs, or to unsafe definitions.
-/
def letToHaveValue (preDef : PreDefinition) : MetaM PreDefinition := withRef preDef.ref do
if !cleanup.letToHave.get (← getOptions)
|| preDef.modifiers.isUnsafe
|| preDef.kind matches .theorem | .example | .opaque then
return preDef
else if ← Meta.isProp preDef.type then
return preDef
else
let value ← Meta.letToHave preDef.value
return { preDef with value }

/--
Applies `Meta.letToHave` to the type of the predef.
-/
def letToHaveType (preDef : PreDefinition) : MetaM PreDefinition := withRef preDef.ref do
if !cleanup.letToHave.get (← getOptions) || preDef.kind matches .example then
return preDef
else
let type ← Meta.letToHave preDef.type
return { preDef with type }

def abstractNestedProofs (preDef : PreDefinition) (cache := true) : MetaM PreDefinition := withRef preDef.ref do
if preDef.kind.isTheorem || preDef.kind.isExample then
pure preDef
Expand Down Expand Up @@ -138,9 +169,11 @@ private def checkMeta (preDef : PreDefinition) : TermElabM Unit := do
| _, _ => pure ()
return true

private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) (cacheProofs := true) : TermElabM Unit :=
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) (cacheProofs := true) (cleanupValue := false) : TermElabM Unit :=
withRef preDef.ref do
let preDef ← abstractNestedProofs (cache := cacheProofs) preDef
let preDef ← letToHaveType preDef
let preDef ← if cleanupValue then letToHaveValue preDef else pure preDef
let mkDefDecl : TermElabM Declaration :=
return Declaration.defnDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
Expand Down Expand Up @@ -185,11 +218,11 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List N
generateEagerEqns preDef.declName
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation

def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) : TermElabM Unit := do
addNonRecAux preDef (compile := true) (all := all)
def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) (cleanupValue := false) : TermElabM Unit := do
addNonRecAux preDef (compile := true) (all := all) (cleanupValue := cleanupValue)

def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) (cacheProofs := true) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all) (cacheProofs := cacheProofs)
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) (cacheProofs := true) (cleanupValue := false) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all) (cacheProofs := cacheProofs) (cleanupValue := cleanupValue)

/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/PreDefinition/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,9 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
let preDef ← eraseRecAppSyntax preDefs[0]!
ensureEqnReservedNamesAvailable preDef.declName
if preDef.modifiers.isNoncomputable then
addNonRec preDef
addNonRec preDef (cleanupValue := true)
else
addAndCompileNonRec preDef
addAndCompileNonRec preDef (cleanupValue := true)
preDef.termination.ensureNone "not recursive"
else if preDefs.any (·.modifiers.isUnsafe) then
addAndCompileUnsafe preDefs
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ where
catch e =>
throwError "failed to generate unfold theorem for '{declName}':\n{e.toMessageData}"
let type ← mkForallFVars xs type
let type ← letToHave type
let value ← mkLambdaFVars xs goal
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Elab/PreDefinition/Structural/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ where
doRealize name type := withOptions (tactic.hygienic.set · false) do
let value ← mkProof info.declName type
let (type, value) ← removeUnusedEqnHypotheses type value
let type ← letToHave type
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
Expand Down Expand Up @@ -126,6 +127,7 @@ where
let goal ← mkFreshExprSyntheticOpaqueMVar type
mkUnfoldProof declName goal.mvarId!
let type ← mkForallFVars xs type
let type ← letToHave type
let value ← mkLambdaFVars xs (← instantiateMVars goal)
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/PreDefinition/Structural/SmartUnfolding.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ partial def addSmartUnfoldingDef (preDef : PreDefinition) (recArgPos : Nat) : Te
else
withEnableInfoTree false do
let preDefSUnfold ← addSmartUnfoldingDefAux preDef recArgPos
addNonRec preDefSUnfold
addNonRec preDefSUnfold (cleanupValue := true)

end Lean.Elab.Structural
2 changes: 2 additions & 0 deletions src/Lean/Elab/PreDefinition/WF/Unfold.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessPr

let value ← instantiateMVars main
let type ← mkForallFVars xs type
let type ← letToHave type
let value ← mkLambdaFVars xs value
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down Expand Up @@ -123,6 +124,7 @@ def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM U

let value ← instantiateMVars main
let type ← mkForallFVars xs type
let type ← letToHave type
let value ← mkLambdaFVars xs value
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Meta/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Lean.Meta.Basic
import Lean.Meta.AppBuilder
import Lean.Meta.Match.MatcherInfo
import Lean.DefEqAttrib
import Lean.Meta.LetToHave

namespace Lean.Meta

Expand Down Expand Up @@ -178,6 +179,8 @@ where doRealize name info := do
lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
let lhs := mkAppN (mkConst info.name <| info.levelParams.map mkLevelParam) xs
let type ← mkForallFVars xs (← mkEq lhs body)
-- Note: if this definition was added using `def`, then `letToHave` has already been applied to the body.
let type ← letToHave type
let value ← mkLambdaFVars xs (← mkEqRefl lhs)
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/LetToHave.lean
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ The `Meta.letToHave` trace class logs errors and messages.
def letToHave (e : Expr) : MetaM Expr := do
profileitM Exception "let-to-have transformation" (← getOptions) do
let e ← instantiateMVars e
LetToHave.main e
withoutExporting <| LetToHave.main e

builtin_initialize
registerTraceClass `Meta.letToHave
Expand Down
5 changes: 5 additions & 0 deletions src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ where doRealize (inductName : Name) := do

let eTyp ← inferType e'
let eTyp ← elimTypeAnnotations eTyp
let eTyp ← letToHave eTyp
-- logInfo m!"eTyp: {eTyp}"
let levelParams := (collectLevelParams {} eTyp).params
-- Prune unused level parameters, preserving the original order
Expand Down Expand Up @@ -1090,6 +1091,7 @@ def projectMutualInduct (unfolding : Bool) (names : Array Name) (mutualInduct :
let value ← PProdN.projM names.size idx value
mkLambdaFVars xs value
let type ← inferType value
let type ← letToHave type
addDecl <| Declaration.thmDecl { name := inductName, levelParams, type, value }

if idx == 0 then finalizeFirstInd
Expand Down Expand Up @@ -1248,6 +1250,7 @@ where doRealize inductName := do
check value
let type ← inferType value
let type ← elimOptParam type
let type ← letToHave type

addDecl <| Declaration.thmDecl
{ name := inductName, levelParams := ci.levelParams, type, value }
Expand Down Expand Up @@ -1480,6 +1483,7 @@ where doRealize inductName := do

let eTyp ← inferType e'
let eTyp ← elimTypeAnnotations eTyp
let eTyp ← letToHave eTyp
-- logInfo m!"eTyp: {eTyp}"
let levelParams := (collectLevelParams {} eTyp).params
-- Prune unused level parameters, preserving the original order
Expand Down Expand Up @@ -1623,6 +1627,7 @@ def deriveCases (unfolding : Bool) (name : Name) : MetaM Unit := do

let eTyp ← inferType e'
let eTyp ← elimTypeAnnotations eTyp
let eTyp ← letToHave eTyp
-- logInfo m!"eTyp: {eTyp}"
let levelParams := (collectLevelParams {} eTyp).params
-- Prune unused level parameters, preserving the original order
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/ParserCompiler.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ partial def parserNodeKind? (e : Expr) : MetaM (Option Name) := do
let reduceEval? e : MetaM (Option Name) := do
try pure <| some (← reduceEval e) catch _ => pure none
let e ← whnfCore e
if e matches Expr.lam .. then
lambdaLetTelescope e fun _ e => parserNodeKind? e
if e matches Expr.lam .. | Expr.letE .. then
lambdaLetTelescope (preserveNondepLet := false) e fun _ e => parserNodeKind? e
else if e.isAppOfArity ``leadingNode 3 || e.isAppOfArity ``trailingNode 4 || e.isAppOfArity ``node 2 then
reduceEval? (e.getArg! 0)
else if e.isAppOfArity ``withAntiquot 2 then
Expand All @@ -61,7 +61,7 @@ variable {α} (ctx : Context α) (builtin : Bool) (force : Bool) in
partial def compileParserExpr (e : Expr) : MetaM Expr := do
let e ← whnfCore e
match e with
| .lam .. => mapLambdaLetTelescope e fun _ b => compileParserExpr b
| .lam .. | .letE .. => mapLambdaLetTelescope (preserveNondepLet := false) e fun _ b => compileParserExpr b
| .fvar .. => return e
| _ => do
let fn := e.getAppFn
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/445.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ i + 1
i + i * 2
i + i * r i j
i + i * r i j
let z := s i j;
have z := s i j;
z + z
4 changes: 2 additions & 2 deletions tests/lean/funind_errors.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ funind_errors.lean:22:7-22:23: error: unknown constant 'takeWhile.induct'
takeWhile.foo.induct.{u_1} {α : Type u_1} (p : α → Bool) (as : Array α) (motive : Nat → Array α → Prop)
(case1 :
∀ (i : Nat) (r : Array α) (h : i < as.size),
let a := as.get i h;
have a := as.get i h;
p a = true → motive (i + 1) (r.push a) → motive i r)
(case2 :
∀ (i : Nat) (r : Array α) (h : i < as.size),
let a := as.get i h;
have a := as.get i h;
¬p a = true → motive i r)
(case3 : ∀ (i : Nat) (r : Array α), ¬i < as.size → motive i r) (i : Nat) (r : Array α) : motive i r
funind_errors.lean:38:7-38:20: error: unknown constant 'idEven.induct'
Expand Down
14 changes: 13 additions & 1 deletion tests/lean/run/1026.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,21 @@ info: foo.eq_def (n : Nat) :
foo n =
if n = 0 then 0
else
let x := n - 1;
have x := n - 1;
have this := foo._proof_4;
foo x
-/
#guard_msgs in
#check foo.eq_def

/--
info: foo.eq_unfold :
foo = fun n =>
if n = 0 then 0
else
have x := n - 1;
have this := foo._proof_4;
foo x
-/
#guard_msgs in
#check foo.eq_unfold
2 changes: 1 addition & 1 deletion tests/lean/run/2058.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def foo : IO Unit := do
-- specializes to 0 on error
/--
info: def foo : IO Unit :=
let x := PUnit.unit.{0};
have x := PUnit.unit.{0};
pure.{0, 0} Unit.unit
-/
#guard_msgs in set_option pp.universes true in #print foo
Expand Down
6 changes: 3 additions & 3 deletions tests/lean/run/delabConst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def f (true : Bool) : Nat :=
/--
info: def MatchTest2.f : Bool → Nat :=
fun true =>
let Bool.true := 1;
have Bool.true := 1;
match true with
| _root_.Bool.true => 0
| false => 1
Expand Down Expand Up @@ -169,8 +169,8 @@ def f (true : Bool) :=
/--
info: def MatchTest3.f : Bool → Bool :=
fun true =>
let Bool.true := true;
let false := true;
have Bool.true := true;
have false := true;
match true with
| _root_.Bool.true => false
| Bool.false =>
Expand Down
4 changes: 2 additions & 2 deletions tests/lean/run/funind_tests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -733,11 +733,11 @@ def foo (n : Nat) : Nat :=
info: Dite.foo.induct (motive : Nat → Prop)
(case1 :
∀ (x : Nat),
let j := x - 1;
have j := x - 1;
j < x → motive j → motive x)
(case2 :
∀ (x : Nat),
let j := x - 1;
have j := x - 1;
¬j < x → motive x)
(n : Nat) : motive n
-/
Expand Down
8 changes: 4 additions & 4 deletions tests/lean/run/funind_unfolding.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def fib'' (n : Nat) : Nat :=
info: fib''.fun_cases_unfolding (n : Nat) (motive : Nat → Prop) (case1 : n < 2 → motive n)
(case2 :
¬n < 2 →
let foo := n - 2;
have foo := n - 2;
foo < 100 → motive (fib'' (n - 1) + fib'' foo))
(case3 :
¬n < 2 →
let foo := n - 2;
have foo := n - 2;
¬foo < 100 → motive 0) :
motive (fib'' n)
-/
Expand Down Expand Up @@ -381,7 +381,7 @@ info: siftDown.induct_unfolding (e : Nat) (motive : (a : Array Int) → Nat →
∀ (a : Array Int) (root : Nat) (h : e ≤ a.size),
leftChild root < e →
let child := leftChild root;
let child := if x : child + 1 < e then if h : a[child] < a[child + 1] then child + 1 else child else child;
have child := if x : child + 1 < e then if h : a[child] < a[child + 1] then child + 1 else child else child;
¬a[root]! < a[child]! → motive a root h a)
(case3 : ∀ (a : Array Int) (root : Nat) (h : e ≤ a.size), ¬leftChild root < e → motive a root h a) (a : Array Int)
(root : Nat) (h : e ≤ a.size) : motive a root h (siftDown a root e h)
Expand All @@ -403,7 +403,7 @@ info: siftDown.induct (e : Nat) (motive : (a : Array Int) → Nat → e ≤ a.si
∀ (a : Array Int) (root : Nat) (h : e ≤ a.size),
leftChild root < e →
let child := leftChild root;
let child := if x : child + 1 < e then if h : a[child] < a[child + 1] then child + 1 else child else child;
have child := if x : child + 1 < e then if h : a[child] < a[child + 1] then child + 1 else child else child;
¬a[root]! < a[child]! → motive a root h)
(case3 : ∀ (a : Array Int) (root : Nat) (h : e ≤ a.size), ¬leftChild root < e → motive a root h) (a : Array Int)
(root : Nat) (h : e ≤ a.size) : motive a root h
Expand Down
6 changes: 3 additions & 3 deletions tests/lean/run/issue5767.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def go1 (ss : Int) (st0 : St) : Bool :=
info: go1.induct (ss : Int) (motive : St → Prop)
(case1 :
∀ (x : St),
let st1 := { m := x.m, map := x.map.insert };
have st1 := { m := x.m, map := x.map.insert };
∀ (val : Unit), st1.map.get? ss = some val → P st1 → P st1 → motive st1 → motive x)
(case2 :
∀ (x : St),
let st1 := { m := x.m, map := x.map.insert };
have st1 := { m := x.m, map := x.map.insert };
st1.map.get? ss = none → motive x)
(st0 : St) : motive st0
-/
Expand All @@ -55,7 +55,7 @@ def go2 (ss : Int) (st0 : St) : Bool :=
/--
info: go2.induct : ∀ (motive : St → Prop),
(∀ (x : St),
let st1 := { m := x.m, map := x.map.insert };
have st1 := { m := x.m, map := x.map.insert };
motive st1 → motive x) →
∀ (st0 : St), motive st0
-/
Expand Down
Loading
Loading