Skip to content

Commit 8655f77

Browse files
authored
refactor: structural recursion: prove .eq_def directly (#10606)
This PR changes how Lean proves the equational theorems for structural recursion. The core idea is to let-bind the `f` argument to `brecOn` and rewriting `.brecOn` with an unfolding theorem. This means no extra case split for the `.rec` in `.brecOn` is needed, and `simp` doesn't change the `f` argument which can break the definitional equality with the defined function. With this, we can prove the unfolding theorem first, and derive the equational theorems from that, like for all other ways of defining recursive functions. Backs out the changes from #10415, the old strategy works well with the new goals. Fixes #5667 Fixes #10431 Fixes #10195 Fixes #2962
1 parent 5c92ffc commit 8655f77

File tree

22 files changed

+834
-356
lines changed

22 files changed

+834
-356
lines changed

src/Init/Data/Dyadic/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ prelude
99
public import Init.Data.Rat.Lemmas
1010
import Init.Data.Int.Bitwise.Lemmas
1111
import Init.Data.Int.DivMod.Lemmas
12+
import Init.Hints
1213

1314
/-!
1415
# The dyadic rationals

src/Lean/Elab/PreDefinition/Eqns.lean

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@ partial def splitMatch? (mvarId : MVarId) (declNames : Array Name) : MetaM (Opti
101101
target declNames badCases then
102102
try
103103
Meta.Split.splitMatch mvarId e
104-
catch _ =>
104+
catch ex =>
105+
trace[Elab.definition.eqns] "cannot split {e}\n{ex.toMessageData}"
105106
go (badCases.insert e)
106107
else
107-
trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}"
108+
trace[Elab.definition.eqns] "did not find term to split\n{MessageData.ofGoal mvarId}"
108109
return none
109110
go {}
110111

@@ -288,7 +289,7 @@ public def deltaLHS (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
288289
let some lhs ← delta? lhs | throwTacticEx `deltaLHS mvarId "failed to delta reduce lhs"
289290
mvarId.replaceTargetDefEq (← mkEq lhs rhs)
290291

291-
def deltaRHS? (mvarId : MVarId) (declName : Name) : MetaM (Option MVarId) := mvarId.withContext do
292+
public def deltaRHS? (mvarId : MVarId) (declName : Name) : MetaM (Option MVarId) := mvarId.withContext do
292293
let target ← mvarId.getType'
293294
let some (_, lhs, rhs) := target.eq? | return none
294295
let some rhs ← delta? rhs.consumeMData (· == declName) | return none
@@ -347,7 +348,7 @@ private def unfoldLHS (declName : Name) (mvarId : MVarId) : MetaM MVarId := mvar
347348
deltaLHS mvarId
348349

349350
private partial def mkEqnProof (declName : Name) (type : Expr) (tryRefl : Bool) : MetaM Expr := do
350-
trace[Elab.definition.eqns] "proving: {type}"
351+
withTraceNode `Elab.definition.eqns (return m!"{exceptEmoji ·} proving:{indentExpr type}") do
351352
withNewMCtxDepth do
352353
let main ← mkFreshExprSyntheticOpaqueMVar type
353354
let (_, mvarId) ← main.mvarId!.intros
@@ -371,7 +372,7 @@ private partial def mkEqnProof (declName : Name) (type : Expr) (tryRefl : Bool)
371372
recursion and structural recursion can and should use this too.
372373
-/
373374
go (mvarId : MVarId) : MetaM Unit := do
374-
trace[Elab.definition.eqns] "step\n{MessageData.ofGoal mvarId}"
375+
withTraceNode `Elab.definition.eqns (return m!"{exceptEmoji ·} step:\n{MessageData.ofGoal mvarId}") do
375376
if (← tryURefl mvarId) then
376377
return ()
377378
else if (← tryContradiction mvarId) then

src/Lean/Elab/PreDefinition/PartialFixpoint/Eqns.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def getUnfoldFor? (declName : Name) : MetaM (Option Name) := do
117117

118118
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
119119
if let some info := eqnInfoExt.find? (← getEnv) declName then
120-
mkEqns declName info.declNames
120+
mkEqns declName info.declNames (tryRefl := false)
121121
else
122122
return none
123123

src/Lean/Elab/PreDefinition/Structural/Eqns.lean

Lines changed: 85 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import Lean.Meta.Tactic.Apply
1616
import Lean.Elab.PreDefinition.Basic
1717
import Lean.Elab.PreDefinition.Structural.Basic
1818
import Lean.Meta.Match.MatchEqs
19+
import Lean.Meta.Tactic.Rewrite
1920

2021
namespace Lean.Elab
2122
open Meta
@@ -29,114 +30,115 @@ public structure EqnInfo extends EqnInfoCore where
2930
fixedParamPerms : FixedParamPerms
3031
deriving Inhabited
3132

32-
private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
33+
/--
34+
Searches in the lhs of goal for a `.brecOn` application, possibly with extra arguments
35+
and under `PProd` projections. Returns the `.brecOn` application and the context
36+
`(fun x => (x).1.2.3 extraArgs = rhs)`.
37+
-/
38+
partial def findBRecOnLHS (goal : Expr) : MetaM (Expr × Expr) := do
39+
let some (_, lhs, rhs) := goal.eq? | throwError "goal not an equality{indentExpr goal}"
40+
go lhs fun brecOnApp x c =>
41+
return (brecOnApp, ← mkLambdaFVars #[x] (← mkEq c rhs))
42+
where
43+
go {α} (e : Expr) (k : Expr → Expr → Expr → MetaM α) : MetaM α := e.withApp fun f xs => do
44+
if let .proj t n e := f then
45+
return ← go e fun brecOnApp x c => k brecOnApp x (mkAppN (mkProj t n c) xs)
46+
if let .const name _ := f then
47+
if isBRecOnRecursor (← getEnv) name then
48+
let arity ← forallTelescope (← inferType f) fun xs _ => return xs.size
49+
if arity ≤ xs.size then
50+
let brecOnApp := mkAppN f xs[:arity]
51+
let extraArgs := xs[arity:]
52+
return ← withLocalDeclD `x (← inferType brecOnApp) fun x =>
53+
k brecOnApp x (mkAppN x extraArgs)
54+
throwError "could not find `.brecOn` application in{indentExpr e}"
55+
56+
/--
57+
Creates the proof of the unfolding theorem for `declName` with type `type`. It
58+
59+
1. unfolds the function on the left to expose the `.brecOn` application
60+
2. rewrites that using the `.brecOn.eq` theorem, unrolling it once
61+
3. let-binds the last argument, which should be the `.brecOn.go` call of type `.below …`.
62+
This way subsequent steps (which may involve `simp`) do not touch it and do
63+
not break the definitional equality with the recursive calls on the RHS.
64+
4. repeatedly splits `match` statements (because on the left we have `match` statements with extra
65+
`.below` arguments, and on the right we have the original `match` statements) until the goal
66+
is solved using `rfl` or `contradiction`.
67+
-/
68+
partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
3369
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} proving:{indentExpr type}") do
70+
prependError m!"failed to generate equational theorem for `{.ofConstName declName}`" do
3471
withNewMCtxDepth do
3572
let main ← mkFreshExprSyntheticOpaqueMVar type
3673
let (_, mvarId) ← main.mvarId!.intros
3774
unless (← tryURefl mvarId) do -- catch easy cases
38-
go1 mvarId
75+
goUnfold (← deltaLHS mvarId)
3976
instantiateMVars main
4077
where
41-
/--
42-
Step 1: Split the function body into its cases, but keeping the LHS intact, because the
43-
`.below`-added `match` statements and the `.rec` can quickly confuse `split`.
44-
-/
45-
go1 (mvarId : MVarId) : MetaM Unit := do
46-
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go1:\n{MessageData.ofGoal mvarId}") do
78+
goUnfold (mvarId : MVarId) : MetaM Unit := do
79+
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} goUnfold:\n{MessageData.ofGoal mvarId}") do
80+
let mvarId' ← mvarId.withContext do
81+
-- This should now be headed by `.brecOn`
82+
let goal ← mvarId.getType
83+
let (brecOnApp, context) ← findBRecOnLHS goal
84+
let brecOnName := brecOnApp.getAppFn.constName!
85+
let us := brecOnApp.getAppFn.constLevels!
86+
let brecOnThmName := brecOnName.str "eq"
87+
let brecOnAppArgs := brecOnApp.getAppArgs
88+
unless (← hasConst brecOnThmName) do
89+
throwError "no theorem `{brecOnThmName}`\n{MessageData.ofGoal mvarId}"
90+
-- We don't just `← inferType eqThmApp` as that beta-reduces more than we want
91+
let eqThmType ← inferType (mkConst brecOnThmName us)
92+
let eqThmType ← instantiateForall eqThmType brecOnAppArgs
93+
let some (_, _, rwRhs) := eqThmType.eq? | throwError "theorem `{brecOnThmName}` is not an equality\n{MessageData.ofGoal mvarId}"
94+
let recArg := rwRhs.getAppArgs.back!
95+
trace[Elab.definition.structural.eqns] "abstracting{inlineExpr recArg} from{indentExpr rwRhs}"
96+
let mvarId2 ← mvarId.define `r (← inferType recArg) recArg
97+
let (r, mvarId3) ← mvarId2.intro1P
98+
let mvarId4 ← mvarId3.withContext do
99+
let goal' := mkApp rwRhs.appFn! (mkFVar r)
100+
let thm ← mkCongrArg context (mkAppN (mkConst brecOnThmName us) brecOnAppArgs)
101+
mvarId3.replaceTargetEq (mkApp context goal') thm
102+
pure mvarId4
103+
go mvarId'
104+
105+
go (mvarId : MVarId) : MetaM Unit := do
106+
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} step:\n{MessageData.ofGoal mvarId}") do
47107
if (← tryURefl mvarId) then
48108
trace[Elab.definition.structural.eqns] "tryURefl succeeded"
49109
return ()
50110
else if (← tryContradiction mvarId) then
51111
trace[Elab.definition.structural.eqns] "tryContadiction succeeded"
52112
return ()
113+
else if let some mvarId ← whnfReducibleLHS? mvarId then
114+
trace[Elab.definition.structural.eqns] "whnfReducibleLHS succeeded"
115+
go mvarId
53116
else if let some mvarId ← simpMatch? mvarId then
54117
trace[Elab.definition.structural.eqns] "simpMatch? succeeded"
55-
go1 mvarId
118+
go mvarId
56119
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
57120
trace[Elab.definition.structural.eqns] "simpIf? succeeded"
58-
go1 mvarId
121+
go mvarId
59122
else
60123
let ctx ← Simp.mkContext
61124
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
62125
| TacticResultCNM.closed =>
63126
trace[Elab.definition.structural.eqns] "simpTargetStar closed the goal"
64127
| TacticResultCNM.modified mvarId =>
65128
trace[Elab.definition.structural.eqns] "simpTargetStar modified the goal"
66-
go1 mvarId
129+
go mvarId
67130
| TacticResultCNM.noChange =>
68-
if let some mvarIds ← casesOnStuckLHS? mvarId then
131+
if let some mvarId ← deltaRHS? mvarId declName then
132+
trace[Elab.definition.structural.eqns] "deltaRHS? succeeded"
133+
go mvarId
134+
else if let some mvarIds ← casesOnStuckLHS? mvarId then
69135
trace[Elab.definition.structural.eqns] "casesOnStuckLHS? succeeded"
70-
mvarIds.forM go1
136+
mvarIds.forM go
71137
else if let some mvarIds ← splitTarget? mvarId (useNewSemantics := true) then
72138
trace[Elab.definition.structural.eqns] "splitTarget? succeeded"
73-
mvarIds.forM go1
74-
else
75-
go2 (← deltaLHS mvarId)
76-
/-- Step 2: Unfold the lhs to expose the recursor. -/
77-
go2 (mvarId : MVarId) : MetaM Unit := do
78-
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go2:\n{MessageData.ofGoal mvarId}") do
79-
if let some mvarId ← whnfReducibleLHS? mvarId then
80-
go2 mvarId
81-
else
82-
go3 mvarId
83-
/-- Step 3: Simplify the match and if statements on the left hand side, until we have rfl. -/
84-
go3 (mvarId : MVarId) : MetaM Unit := do
85-
withTraceNode `Elab.definition.structural.eqns (return m!"{exceptEmoji ·} go3:\n{MessageData.ofGoal mvarId}") do
86-
if (← tryURefl mvarId) then
87-
trace[Elab.definition.structural.eqns] "tryURefl succeeded"
88-
return ()
89-
else if (← tryContradiction mvarId) then
90-
trace[Elab.definition.structural.eqns] "tryContadiction succeeded"
91-
return ()
92-
else if let some mvarId ← simpMatch? mvarId then
93-
trace[Elab.definition.structural.eqns] "simpMatch? succeeded"
94-
go3 mvarId
95-
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
96-
trace[Elab.definition.structural.eqns] "simpIf? succeeded"
97-
go3 mvarId
98-
else
99-
let ctx ← Simp.mkContext
100-
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
101-
| TacticResultCNM.closed =>
102-
trace[Elab.definition.structural.eqns] "simpTargetStar closed the goal"
103-
| TacticResultCNM.modified mvarId =>
104-
trace[Elab.definition.structural.eqns] "simpTargetStar modified the goal"
105-
go3 mvarId
106-
| TacticResultCNM.noChange =>
107-
if let some mvarIds ← casesOnStuckLHS? mvarId then
108-
trace[Elab.definition.structural.eqns] "casesOnStuckLHS? succeeded"
109-
mvarIds.forM go3
139+
mvarIds.forM go
110140
else
111-
throwError "failed to generate equational theorem for `{.ofConstName declName}`\n{MessageData.ofGoal mvarId}"
112-
113-
def mkEqns (info : EqnInfo) : MetaM (Array Name) :=
114-
withOptions (tactic.hygienic.set · false) do
115-
let eqnTypes ← withNewMCtxDepth <| lambdaTelescope (cleanupAnnotations := true) info.value fun xs body => do
116-
let us := info.levelParams.map mkLevelParam
117-
let target ← mkEq (mkAppN (Lean.mkConst info.declName us) xs) body
118-
let goal ← mkFreshExprSyntheticOpaqueMVar target
119-
mkEqnTypes info.declNames goal.mvarId!
120-
let mut thmNames := #[]
121-
for h : i in *...eqnTypes.size do
122-
let type := eqnTypes[i]
123-
trace[Elab.definition.structural.eqns] "eqnType {i+1}: {type}"
124-
let name := mkEqLikeNameFor (← getEnv) info.declName s!"{eqnThmSuffixBasePrefix}{i+1}"
125-
thmNames := thmNames.push name
126-
-- determinism: `type` should be independent of the environment changes since `baseName` was
127-
-- added
128-
realizeConst info.declNames[0]! name (doRealize name type)
129-
return thmNames
130-
where
131-
doRealize name type := withOptions (tactic.hygienic.set · false) do
132-
let value ← withoutExporting do mkProof info.declName type
133-
let (type, value) ← removeUnusedEqnHypotheses type value
134-
let type ← letToHave type
135-
addDecl <| Declaration.thmDecl {
136-
name, type, value
137-
levelParams := info.levelParams
138-
}
139-
inferDefEqAttr name
141+
throwError "no progress at goal\n{MessageData.ofGoal mvarId}"
140142

141143
public builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ←
142144
mkMapDeclarationExtension (exportEntriesFn := fun env s _ =>
@@ -151,7 +153,7 @@ public def registerEqnsInfo (preDef : PreDefinition) (declNames : Array Name) (r
151153

152154
def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
153155
if let some info := eqnInfoExt.find? (← getEnv) declName then
154-
mkEqns info
156+
mkEqns declName info.declNames
155157
else
156158
return none
157159

@@ -165,11 +167,10 @@ where
165167
lambdaTelescope info.value fun xs body => do
166168
let us := info.levelParams.map mkLevelParam
167169
let type ← mkEq (mkAppN (Lean.mkConst declName us) xs) body
168-
let goal ← mkFreshExprSyntheticOpaqueMVar type
169-
mkUnfoldProof declName goal.mvarId!
170+
let value ← withoutExporting <| mkProof declName type
170171
let type ← mkForallFVars xs type
171172
let type ← letToHave type
172-
let value ← mkLambdaFVars xs (← instantiateMVars goal)
173+
let value ← mkLambdaFVars xs value
173174
addDecl <| Declaration.thmDecl {
174175
name, type, value
175176
levelParams := info.levelParams
@@ -187,6 +188,7 @@ def getStructuralRecArgPosImp? (declName : Name) : CoreM (Option Nat) := do
187188
let some info := eqnInfoExt.find? (← getEnv) declName | return none
188189
return some info.recArgPos
189190

191+
190192
builtin_initialize
191193
registerGetEqnsFn getEqnsFor?
192194
registerGetUnfoldEqnFn getUnfoldFor?

0 commit comments

Comments
 (0)