Skip to content

Commit 0173444

Browse files
authored
feat: heterogeneous contructor injectivity in grind (#11491)
This PR implements heterogeneous contructor injectivity in `grind`. Example: ```lean opaque double : Nat → Nat inductive Parity : Nat -> Type | even (n) : Parity (double n) | odd (n) : Parity (Nat.succ (double n)) opaque q : Nat → Nat → Prop axiom qax : q a b → double a = double b attribute [grind →] qax example (motive : (x : Nat) → Parity x → Sort u_1) (h_2 : (j : Nat) → motive (double j) (Parity.even j)) (j n : Nat) (heq_1 : q j n) -- Implies that `double j = double n` (heq_2 : Parity.even n ≍ Parity.even j): h_2 n ≍ h_2 j := by grind ``` Closes #11449
1 parent 1377da0 commit 0173444

File tree

4 files changed

+113
-23
lines changed

4 files changed

+113
-23
lines changed

src/Lean/Meta/Injective.lean

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ private def mkEqs (args1 args2 : Array Expr) (skipIfPropOrEq : Bool := true) : M
4040
eqs := eqs.push (← mkEqHEq arg1 arg2)
4141
return eqs
4242

43-
private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
43+
private def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
4444
let us := ctorVal.levelParams.map mkLevelParam
4545
let type ← elimOptParam ctorVal.type
4646
forallBoundedTelescope type ctorVal.numParams fun params type =>
@@ -181,7 +181,7 @@ def mkInjectiveTheorems (declName : Name) : MetaM Unit := do
181181
builtin_initialize
182182
registerTraceClass `Meta.injective
183183

184-
private def getIndices? (ctorApp : Expr) : MetaM (Option (Array Expr)) := do
184+
def getCtorAppIndices? (ctorApp : Expr) : MetaM (Option (Array Expr)) := do
185185
let type ← whnfD (← inferType ctorApp)
186186
type.withApp fun typeFn typeArgs => do
187187
let .const declName _ := typeFn | return none
@@ -197,20 +197,20 @@ private structure MkHInjTypeResult where
197197
us : List Level
198198
numIndices : Nat
199199

200-
private partial def mkHInjType? (ctorVal : ConstructorVal) : MetaM (Option MkHInjTypeResult) := do
200+
private def mkHInjType? (ctorVal : ConstructorVal) : MetaM (Option MkHInjTypeResult) := do
201201
let us := ctorVal.levelParams.map mkLevelParam
202202
let type ← elimOptParam ctorVal.type
203203
forallBoundedTelescope type ctorVal.numParams fun params type =>
204-
forallTelescope type fun args1 resultType => do
204+
forallTelescope type fun args1 _ => do
205205
let k (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
206206
let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
207207
let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
208208
let eq ← mkEqHEq lhs rhs
209209
let eqs ← mkEqs args1 args2
210210
if let some andEqs := mkAnd? eqs then
211211
let result ← mkArrow eq andEqs
212-
let some idxs1 ← getIndices? lhs | return none
213-
let some idxs2 ← getIndices? rhs | return none
212+
let some idxs1 ← getCtorAppIndices? lhs | return none
213+
let some idxs2 ← getCtorAppIndices? rhs | return none
214214
-- **Note**: We dot not skip here because the type of `noConfusion` does not.
215215
let idxEqs ← mkEqs idxs1 idxs2 (skipIfPropOrEq := false)
216216
let result ← mkArrows idxEqs result
@@ -254,7 +254,6 @@ private partial def mkHInjectiveTheoremValue? (ctorVal : ConstructorVal) (typeIn
254254
let mvarId := mvar.mvarId!
255255
let (_, mvarId) ← mvarId.intros
256256
splitAndAssumption mvarId ctorVal.name
257-
check noConfusion
258257
let result ← instantiateMVars noConfusion
259258
mkLambdaFVars xs result
260259

src/Lean/Meta/Tactic/Grind/Ctor.lean

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
66
module
77
prelude
88
public import Lean.Meta.Tactic.Grind.Types
9+
import Lean.Meta.Injective
910
import Lean.Meta.Tactic.Grind.Simp
1011
public section
1112
namespace Lean.Meta.Grind
@@ -33,20 +34,13 @@ private partial def propagateInjEqs (eqs : Expr) (proof : Expr) (generation : Na
3334
reportIssue! "unexpected injectivity theorem result type{indentExpr eqs}"
3435
return ()
3536

36-
/--
37-
Given constructors `a` and `b`, propagate equalities if they are the same,
38-
and close goal if they are different.
39-
-/
40-
def propagateCtor (a b : Expr) : GoalM Unit := do
41-
let aType ← whnfD (← inferType a)
42-
let bType ← whnfD (← inferType b)
43-
unless (← isDefEqD aType bType) do
44-
return ()
37+
/-- Homogeneous case where constructor applications `a` and `b` have the same type `α`. -/
38+
private def propagateCtorHomo (α : Expr) (a b : Expr) : GoalM Unit := do
4539
let ctor₁ := a.getAppFn
4640
let ctor₂ := b.getAppFn
4741
if ctor₁ == ctor₂ then
48-
let .const ctorName _ := a.getAppFn | return ()
49-
let injDeclName := Name.mkStr ctorName "inj"
42+
let .const ctorName _ := ctor₁ | return ()
43+
let injDeclName := mkInjectiveTheoremNameFor ctorName
5044
unless (← getEnv).contains injDeclName do return ()
5145
let info ← getConstInfo injDeclName
5246
let n := info.type.getForallArity
@@ -62,9 +56,57 @@ def propagateCtor (a b : Expr) : GoalM Unit := do
6256
let gen := max (← getGeneration a) (← getGeneration b)
6357
propagateInjEqs injLemmaType injLemma gen
6458
else
65-
let .const declName _ := aType.getAppFn | return ()
59+
let .const declName _ := α.getAppFn | return ()
6660
let noConfusionDeclName := Name.mkStr declName "noConfusion"
6761
unless (← getEnv).contains noConfusionDeclName do return ()
6862
closeGoal (← mkNoConfusion (← getFalseExpr) (← mkEqProof a b))
6963

64+
/-- Heterogeneous case where constructor applications `a` and `b` have different types `α` and `β`. -/
65+
private def propagateCtorHetero (a b : Expr) : GoalM Unit := do
66+
a.withApp fun ctor₁ args₁ =>
67+
b.withApp fun ctor₂ args₂ => do
68+
let .const ctorName₁ us₁ := ctor₁ | return ()
69+
let .const ctorName₂ us₂ := ctor₂ | return ()
70+
let .ctorInfo ctorVal₁ ← getConstInfo ctorName₁ | return ()
71+
let .ctorInfo ctorVal₂ ← getConstInfo ctorName₂ | return ()
72+
unless ctorVal₁.induct == ctorVal₂.induct do return ()
73+
let params₁ := args₁[*...ctorVal₁.numParams]
74+
let params₂ := args₂[*...ctorVal₂.numParams]
75+
let fields₁ := args₁[ctorVal₁.numParams...*]
76+
let fields₂ := args₂[ctorVal₂.numParams...*]
77+
if h : params₁.size ≠ params₂.size then return () else
78+
unless (← params₁.size.allM fun i h => isDefEq params₁[i] params₂[i]) do return ()
79+
unless us₁.length == us₂.length do return ()
80+
unless (← us₁.zip us₂ |>.allM fun (u₁, u₂) => isLevelDefEq u₁ u₂) do return ()
81+
let gen := max (← getGeneration a) (← getGeneration b)
82+
if ctorName₁ == ctorName₂ then
83+
let hinjDeclName := mkHInjectiveTheoremNameFor ctorName₁
84+
unless (← getEnv).containsOnBranch hinjDeclName do
85+
let _ ← executeReservedNameAction hinjDeclName
86+
let proof := mkAppN (mkConst hinjDeclName us₁) params₁
87+
let proof := mkAppN (mkAppN proof fields₁) fields₂
88+
addNewRawFact proof (← inferType proof) gen (.inj (.decl hinjDeclName))
89+
else
90+
let some indices₁ ← getCtorAppIndices? a | return ()
91+
let some indices₂ ← getCtorAppIndices? b | return ()
92+
let noConfusionName := ctorVal₁.induct.str "noConfusion"
93+
let noConfusion := mkAppN (mkConst noConfusionName (0 :: us₁)) params₁
94+
let noConfusion := mkApp noConfusion (← getFalseExpr)
95+
let noConfusion := mkApp (mkAppN noConfusion indices₁) a
96+
let noConfusion := mkApp (mkAppN noConfusion indices₂) b
97+
let proof := noConfusion
98+
addNewRawFact proof (← inferType proof) gen (.inj (.decl noConfusionName))
99+
100+
/--
101+
Given constructors `a` and `b`, propagate equalities if they are the same,
102+
and close goal if they are different.
103+
-/
104+
def propagateCtor (a b : Expr) : GoalM Unit := do
105+
let aType ← whnfD (← inferType a)
106+
let bType ← whnfD (← inferType b)
107+
if (← isDefEqD aType bType) then
108+
propagateCtorHomo aType a b
109+
else
110+
propagateCtorHetero a b
111+
70112
end Lean.Meta.Grind

tests/lean/run/grind_11449.lean

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
opaque double : Nat → Nat
2+
def P (n : Nat) : Prop := n >= 0
3+
theorem pax (n : Nat) : P n := by grind [P]
4+
def T (n : Nat) : Type := Vector Nat n
5+
6+
inductive Foo' (α β : Type u) : (n : Nat) → P n -> Type u
7+
| even (a : α) (n : Nat) (v : T n) (h : P n) : Foo' α β (double n) (pax _)
8+
| odd (b : β) (n : Nat) (v : T n) : Foo' α β (Nat.succ (double n)) (pax _)
9+
10+
example (α β : Type) (a₁ a₂ : α)
11+
(n₁ n₂ : Nat)
12+
(v₁ : T n₁) (v₂ : T n₂)
13+
(h₁ : P n₁) (h₂ : P n₂)
14+
(h_1 : double n₁ = double n₂)
15+
(h_2 : Foo'.even (β := β) a₁ n₁ v₁ h₁ ≍ Foo'.even (β := β) a₂ n₂ v₂ h₂) :
16+
a₁ = a₂ ∧ n₁ = n₂ ∧ v₁ ≍ v₂ ∧ h₁ ≍ h₂ := by
17+
grind
18+
19+
example (α β : Type) (a : α) (b : β)
20+
(n₁ n₂ : Nat)
21+
(v₁ : T n₁) (v₂ : T n₂)
22+
(h₁ : P n₁)
23+
(h_1 : double n₁ = (double n₂).succ)
24+
(h_2 : Foo'.even (β := β) a n₁ v₁ h₁ ≍ Foo'.odd (α := α) b n₂ v₂)
25+
: False := by
26+
grind
27+
28+
inductive Parity : Nat -> Type
29+
| even (n) : Parity (double n)
30+
| odd (n) : Parity (Nat.succ (double n))
31+
32+
example
33+
(motive : (x : Nat) → Parity x → Sort u_1)
34+
(h_2 : (j : Nat) → motive (double j) (Parity.even j))
35+
(j n : Nat)
36+
(heq_1 : double n = double j)
37+
(heq_2 : Parity.even n ≍ Parity.even j):
38+
h_2 n ≍ h_2 j := by
39+
grind
40+
41+
opaque q : Nat → Nat → Prop
42+
axiom qax : q a b → double a = double b
43+
attribute [grind →] qax
44+
45+
example
46+
(motive : (x : Nat) → Parity x → Sort u_1)
47+
(h_2 : (j : Nat) → motive (double j) (Parity.even j))
48+
(j n : Nat)
49+
(heq_1 : q j n)
50+
(heq_2 : Parity.even n ≍ Parity.even j):
51+
h_2 n ≍ h_2 j := by
52+
grind

tests/lean/run/issue11449.lean

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ example
1111
(heq_1 : double n = double j)
1212
(heq_2 : Parity.even n ≍ Parity.even j):
1313
h_2 n ≍ h_2 j := by
14-
fail_if_success grind -- does not work yet
15-
sorry
16-
17-
-- manual proof using heterogenenous noConfusion
14+
grind
1815

1916
example
2017
(motive : (x : Nat) → Parity x → Sort u_1)

0 commit comments

Comments
 (0)