Skip to content

Commit e551596

Browse files
authored
fix: FunInd: preserve have (#10519)
This PR improves FunInd by preserving the `nondep` flag on `.letE`, which makes it more likely that subsequent matcher transformations work. Fixes #10516.
1 parent ca1101d commit e551596

File tree

6 files changed

+88
-37
lines changed

6 files changed

+88
-37
lines changed

src/Lean/Meta/Match/MatcherApp/Transform.lean

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,8 @@ def transform
285285
let aux1 := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) params'
286286
let aux1 := mkApp aux1 motive'
287287
let aux1 := mkAppN aux1 discrs'
288-
unless (← isTypeCorrect aux1) do
289-
prependError m!"failed to transform matcher, type error when constructing new pre-splitter motive:{indentExpr aux1}\nfailed with" do
290-
check aux1
288+
prependError m!"failed to transform matcher, type error when constructing new pre-splitter motive:{indentExpr aux1}\nfailed with" do
289+
check aux1
291290
let origAltTypes ← inferArgumentTypesN matcherApp.alts.size aux1
292291

293292
-- We replace the matcher with the splitter

src/Lean/Meta/Tactic/FunInd.lean

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
377377
mkForallFVars #[x] body'
378378

379379
| .letE n t v b nondep =>
380+
trace[Meta.FunInd] "Let-binding {n} with (nondep := {nondep})"
380381
let t' ← foldAndCollect oldIH newIH isRecCall t
381382
let v' ← foldAndCollect oldIH newIH isRecCall v
382383
withLetDecl n t' v' (nondep := nondep) fun x => do
@@ -520,7 +521,7 @@ def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr)
520521
Like `mkLambdaFVars (usedOnly := true)`, but
521522
522523
* silently skips expression in `xs` that are not `.isFVar`
523-
* also skips let-bound variabls
524+
* also skips let-bound variables
524525
* returns a mask (same size as `xs`) indicating which variables have been abstracted
525526
(`true` means was abstracted).
526527
@@ -631,11 +632,17 @@ public def rwIfWith (hc : Expr) (e : Expr) : MetaM Simp.Result := do
631632
return { expr := e }
632633

633634
def rwLetWith (h : Expr) (e : Expr) : MetaM Simp.Result := do
634-
if e.isLet then
635-
if (← isDefEq e.letValue! h) then
636-
return { expr := e.letBody!.instantiate1 h }
637-
trace[Meta.FunInd] "rwLetWith failed:{inlineExpr e}not a let expression or `{h}` is not definitionally equal to `{e.letValue!}`"
638-
return { expr := e }
635+
match e with
636+
| .letE _ t v b nondep =>
637+
unless (← isDefEq t (← inferType h)) do
638+
trace[Meta.FunInd] "rwLetWith failed:The type of{inlineExpr h}is not definitionally equal to `{t}`"
639+
unless nondep do
640+
unless (← isDefEq v h) do
641+
trace[Meta.FunInd] "rwLetWith failed:{inlineExpr h}is not definitionally equal to{inlineExpr v}"
642+
return { expr := b.instantiate1 h }
643+
| _ =>
644+
trace[Meta.FunInd] "rwLetWith failed:{inlineExpr e}not a let expression"
645+
return { expr := e }
639646

640647
def rwMData (e : Expr) : MetaM Simp.Result := do
641648
return { expr := e.consumeMData }
@@ -735,7 +742,7 @@ as `MVars` as it goes.
735742
partial def buildInductionBody (toErase toClear : Array FVarId) (goal : Expr)
736743
(oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (e : Expr) : M2 Expr := do
737744
withTraceNode `Meta.FunInd
738-
(pure m!"{exceptEmoji ·} buildInductionBody: {oldIH.name} → {newIH.name}\ngoal: {goal}:{indentExpr e}") do
745+
(pure m!"{exceptEmoji ·} buildInductionBody: {oldIH.name} → {newIH.name}\ngoal:{indentExpr goal}\nexpr:{indentExpr e}") do
739746

740747
-- if-then-else cause case split:
741748
match_expr e with
@@ -863,13 +870,14 @@ partial def buildInductionBody (toErase toClear : Array FVarId) (goal : Expr)
863870
buildInductionBody toErase toClear goal' oldIH newIH isRecCall e.mdataExpr!
864871
return e.updateMData! b
865872

866-
if let .letE n t v b _ := e then
873+
if let .letE n t v b nondep := e then
874+
trace[Meta.FunInd] "Let-binding {n} with (nondep := {nondep})"
867875
let t' ← foldAndCollect oldIH newIH isRecCall t
868876
let v' ← foldAndCollect oldIH newIH isRecCall v
869-
return ← withLetDecl n t' v' fun x => M2.branch do
877+
return ← withLetDecl n t' v' (nondep := nondep) fun x => M2.branch do
870878
let b' ← withRewrittenMotiveArg goal (rwLetWith x) fun goal' =>
871879
buildInductionBody toErase toClear goal' oldIH newIH isRecCall (b.instantiate1 x)
872-
mkLetFVars #[x] b'
880+
mkLetFVars (generalizeNondepLet := false) #[x] b'
873881

874882
-- Special case for traversing the PProd’ed bodies in our encoding of structural mutual recursion
875883
if let .lam n t b bi := e then
@@ -903,19 +911,25 @@ do not handle delayed assignments correctly.
903911
-/
904912
def abstractIndependentMVars (mvars : Array MVarId) (index : Nat) (e : Expr) : MetaM Expr := do
905913
trace[Meta.FunInd] "abstractIndependentMVars, to revert after {index}, original mvars: {mvars}"
906-
let mvars ← mvars.mapM fun mvar => do
907-
let mvar ← cleanupAfter mvar index
908-
mvar.withContext do
909-
let fvarIds := (← getLCtx).foldl (init := #[]) (start := index+1) fun fvarIds decl => fvarIds.push decl.fvarId
910-
let (_, mvar) ← mvar.revert fvarIds
911-
pure mvar
912-
trace[Meta.FunInd] "abstractIndependentMVars, reverted mvars: {mvars}"
914+
let mvars ← mvars.mapM (cleanupAfter · index)
913915
let names := Array.ofFn (n := mvars.size) fun ⟨i,_⟩ => .mkSimple s!"case{i+1}"
914-
let types ← mvars.mapM MVarId.getType
916+
let types ← mvars.mapM fun mvar => do
917+
mvar.withContext do
918+
let goal ← mvar.getType
919+
let xs := (← getLCtx).foldl (init := #[]) (start := index+1) fun fvarIds decl =>
920+
fvarIds.push (mkFVar decl.fvarId)
921+
mkForallFVars (generalizeNondepLet := true) xs goal
922+
trace[Meta.FunInd] "abstractIndependentMVars, reverted types: {types}"
915923
Meta.withLocalDeclsDND (names.zip types) fun xs => do
916-
for mvar in mvars, x in xs do
917-
mvar.assign x
918-
mkLambdaFVars xs (← instantiateMVars e)
924+
for mvar in mvars, x in xs do
925+
mvar.withContext do
926+
let e := (← getLCtx).foldl (init := x) (start := index+1) fun e decl =>
927+
if decl.isLet (allowNondep := false) then
928+
e
929+
else
930+
.app e (mkFVar decl.fvarId)
931+
mvar.assign e
932+
mkLambdaFVars xs (← instantiateMVars e)
919933

920934
/--
921935
Given a unary definition `foo` defined via `WellFounded.fixF`, derive a suitable induction principle

tests/lean/run/doLogicTests.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ theorem fib_triple : ⦃⌜True⌝⦄ fib_impl n ⦃⇓ r => ⌜r = fib_spec n
241241

242242
theorem fib_triple_cases : ⦃⌜True⌝⦄ fib_impl n ⦃⇓ r => ⌜r = fib_spec n⌝⦄ := by
243243
apply fib_impl.fun_cases n _ ?case1 ?case2
244-
case case1 => rintro rfl; mintro -; simp only [fib_impl, ↓reduceIte]; mspec
245-
intro h
244+
case case1 => rintro _ rfl; mintro -; simp only [fib_impl, ↓reduceIte]; mspec
245+
intro _ h
246246
mintro -
247247
simp only [fib_impl, h, reduceIte]
248248
mspec
@@ -261,8 +261,8 @@ theorem fib_impl_vcs
261261
(I n hn).1 ⟨⟨pref, cur::suff, by simp[h]⟩, r⟩ ⊢ₛ (I n hn).1 ⟨⟨pref ++ [cur], suff, by simp[h]⟩, r.2, r.1+r.2⟩)
262262
: ⊢ₛ wp⟦fib_impl n⟧ (Q n) := by
263263
apply fib_impl.fun_cases n _ ?case1 ?case2
264-
case case1 => intro h; simp only [fib_impl, h, ↓reduceIte]; mstart; mspec
265-
intro hn
264+
case case1 => intro _ h; simp only [fib_impl, h, ↓reduceIte]; mstart; mspec
265+
intro _ hn
266266
simp only [fib_impl, hn, ↓reduceIte]
267267
mstart
268268
mspec

tests/lean/run/funind_tests.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ termination_by n => n
8181
info: have_tailrec.induct (motive : Nat → Prop) (case1 : motive 0) (case2 : ∀ (n : Nat), n < n + 1 → motive n → motive n.succ)
8282
(a✝ : Nat) : motive a✝
8383
-/
84-
#guard_msgs in
84+
#guard_msgs(pass trace, all) in
8585
#check have_tailrec.induct
8686

8787
set_option linter.unusedVariables false in

tests/lean/run/funind_unfolding.lean

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,31 @@ def fib'' (n : Nat) : Nat :=
5454
0
5555

5656
/--
57-
info: fib''.fun_cases_unfolding (n : Nat) (motive : Nat → Prop) (case1 : n < 2 → motive n)
57+
info: fib''.induct_unfolding (motive : NatNat → Prop) (case1 : ∀ (x : Nat), x < 2 → motive x x)
5858
(case2 :
59-
¬n < 2 →
60-
have foo := n - 2;
61-
foo < 100 → motive (fib'' (n - 1) + fib'' foo))
59+
∀ (x : Nat),
60+
¬x < 2 →
61+
have foo := x - 2;
62+
foo < 100 → motive (x - 1) (fib'' (x - 1)) → motive foo (fib'' foo) → motive x (fib'' (x - 1) + fib'' foo))
6263
(case3 :
63-
¬n < 2 →
64-
have foo := n - 2;
65-
¬foo < 100 → motive 0) :
66-
motive (fib'' n)
64+
∀ (x : Nat),
65+
¬x < 2 →
66+
have foo := x - 2;
67+
¬foo < 100 → motive x 0)
68+
(n : Nat) : motive n (fib'' n)
69+
-/
70+
#guard_msgs(pass trace, all) in
71+
#check fib''.induct_unfolding
72+
73+
/--
74+
info: fib''.fun_cases_unfolding (n : Nat) (motive : Nat → Prop) (case1 : n < 2 → motive n)
75+
(case2 : ¬n < 2 → ∀ (foo : Nat), foo < 100 → motive (fib'' (n - 1) + fib'' foo))
76+
(case3 : ¬n < 2 → ∀ (foo : Nat), ¬foo < 100 → motive 0) : motive (fib'' n)
6777
-/
6878
#guard_msgs(pass trace, all) in
6979
#check fib''.fun_cases_unfolding
7080

81+
7182
-- set_option trace.Meta.FunInd true in
7283
def filter (p : Nat → Bool) : List Nat → List Nat
7384
| [] => []

tests/lean/run/issue10516.lean

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
-- set_option trace.Meta.FunInd true
2+
3+
def bla (a : Nat) (h : 0 < a) : Nat :=
4+
if h : a ≤ 1 then
5+
0
6+
else
7+
have : 0 ≠ a := by omega
8+
match a with
9+
| 0 => False.elim (by contradiction)
10+
| a + 1 => bla a (by omega)
11+
termination_by structural a
12+
13+
/--
14+
info: theorem bla.induct_unfolding : ∀ (motive : (a : Nat) → 0 < a → Nat → Prop),
15+
(∀ (t : Nat) (h : 0 < t), t ≤ 1 → motive t h 0) →
16+
(∀ (a : Nat) (h : 0 < a + 1) (h_1 : ¬a + 1 ≤ 1), 0 ≠ a + 1 → motive a ⋯ (bla a ⋯) → motive a.succ h (bla a ⋯)) →
17+
∀ (a : Nat) (h : 0 < a), motive a h (bla a h)
18+
-/
19+
#guard_msgs(pass trace, all) in #print sig bla.induct_unfolding
20+
21+
/--
22+
info: theorem bla.induct : ∀ (motive : (a : Nat) → 0 < a → Prop),
23+
(∀ (t : Nat) (h : 0 < t), t ≤ 1 → motive t h) →
24+
(∀ (a : Nat) (h : 0 < a + 1) (h_1 : ¬a + 1 ≤ 1), 0 ≠ a + 1 → motive a ⋯ → motive a.succ h) →
25+
∀ (a : Nat) (h : 0 < a), motive a h
26+
-/
27+
#guard_msgs(pass trace, all) in #print sig bla.induct

0 commit comments

Comments
 (0)