Skip to content

Commit 1377da0

Browse files
authored
feat: heterogeneous constructor injectivity theorems (#11487)
This PR adds a heterogeneous version of the constructor injectivity theorems. These theorems are useful for indexed families, and will be used in `grind`.
1 parent 5db4f96 commit 1377da0

File tree

3 files changed

+158
-24
lines changed

3 files changed

+158
-24
lines changed

src/Lean/Meta/Injective.lean

Lines changed: 128 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE.
44
Authors: Leonardo de Moura
55
-/
66
module
7-
87
prelude
98
public import Lean.Meta.Basic
109
import Lean.Meta.Tactic.Refl
1110
import Lean.Meta.Tactic.Cases
1211
import Lean.Meta.Tactic.Assumption
1312
import Lean.Meta.Tactic.Simp.Main
1413
import Lean.Meta.SameCtorUtils
15-
1614
public section
17-
1815
namespace Lean.Meta
1916

2017
private def mkAnd? (args : Array Expr) : Option Expr := Id.run do
@@ -33,20 +30,26 @@ def elimOptParam (type : Expr) : CoreM Expr := do
3330
else
3431
return .continue
3532

33+
private def mkEqs (args1 args2 : Array Expr) (skipIfPropOrEq : Bool := true) : MetaM (Array Expr) := do
34+
let mut eqs := #[]
35+
for arg1 in args1, arg2 in args2 do
36+
let arg1Type ← inferType arg1
37+
if !skipIfPropOrEq then
38+
eqs := eqs.push (← mkEqHEq arg1 arg2)
39+
else if !(← isProp arg1Type) && arg1 != arg2 then
40+
eqs := eqs.push (← mkEqHEq arg1 arg2)
41+
return eqs
42+
3643
private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
3744
let us := ctorVal.levelParams.map mkLevelParam
3845
let type ← elimOptParam ctorVal.type
3946
forallBoundedTelescope type ctorVal.numParams fun params type =>
4047
forallTelescope type fun args1 resultType => do
41-
let jp (args2 args2New : Array Expr) : MetaM (Option Expr) := do
48+
let k (args2 args2New : Array Expr) : MetaM (Option Expr) := do
4249
let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
4350
let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
4451
let eq ← mkEq lhs rhs
45-
let mut eqs := #[]
46-
for arg1 in args1, arg2 in args2 do
47-
let arg1Type ← inferType arg1
48-
if !(← isProp arg1Type) && arg1 != arg2 then
49-
eqs := eqs.push (← mkEqHEq arg1 arg2)
52+
let eqs ← mkEqs args1 args2
5053
if let some andEqs := mkAnd? eqs then
5154
let result ← if useEq then
5255
mkEq eq andEqs
@@ -57,17 +60,15 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE
5760
return none
5861
let rec mkArgs2 (i : Nat) (type : Expr) (args2 args2New : Array Expr) : MetaM (Option Expr) := do
5962
if h : i < args1.size then
60-
match (← whnf type) with
61-
| Expr.forallE n d b _ =>
62-
let arg1 := args1[i]
63-
if occursOrInType (← getLCtx) arg1 resultType then
64-
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
65-
else
66-
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
67-
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
68-
| _ => throwError "unexpected constructor type for `{ctorVal.name}`"
63+
let .forallE n d b _ ← whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
64+
let arg1 := args1[i]
65+
if occursOrInType (← getLCtx) arg1 resultType then
66+
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
67+
else
68+
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
69+
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
6970
else
70-
jp args2 args2New
71+
k args2 args2New
7172
if useEq then
7273
mkArgs2 0 type #[] #[]
7374
else
@@ -84,14 +85,16 @@ private def injTheoremFailureHeader (ctorName : Name) : MessageData :=
8485
private def throwInjectiveTheoremFailure {α} (ctorName : Name) (mvarId : MVarId) : MetaM α :=
8586
throwError "{injTheoremFailureHeader ctorName}{indentD <| MessageData.ofGoal mvarId}"
8687

88+
private def splitAndAssumption (mvarId : MVarId) (ctorName : Name) : MetaM Unit := do
89+
(← mvarId.splitAnd).forM fun mvarId =>
90+
unless (← mvarId.assumptionCore) do
91+
throwInjectiveTheoremFailure ctorName mvarId
92+
8793
private def solveEqOfCtorEq (ctorName : Name) (mvarId : MVarId) (h : FVarId) : MetaM Unit := do
8894
trace[Meta.injective] "solving injectivity goal for {ctorName} with hypothesis {mkFVar h} at\n{mvarId}"
8995
match (← injection mvarId h) with
9096
| InjectionResult.solved => unreachable!
91-
| InjectionResult.subgoal mvarId .. =>
92-
(← mvarId.splitAnd).forM fun mvarId =>
93-
unless (← mvarId.assumptionCore) do
94-
throwInjectiveTheoremFailure ctorName mvarId
97+
| InjectionResult.subgoal mvarId .. => splitAndAssumption mvarId ctorName
9598

9699
private def mkInjectiveTheoremValue (ctorName : Name) (targetType : Expr) : MetaM Expr :=
97100
forallTelescopeReducing targetType fun xs type => do
@@ -178,4 +181,106 @@ def mkInjectiveTheorems (declName : Name) : MetaM Unit := do
178181
builtin_initialize
179182
registerTraceClass `Meta.injective
180183

184+
private def getIndices? (ctorApp : Expr) : MetaM (Option (Array Expr)) := do
185+
let type ← whnfD (← inferType ctorApp)
186+
type.withApp fun typeFn typeArgs => do
187+
let .const declName _ := typeFn | return none
188+
let .inductInfo val ← getConstInfo declName | return none
189+
if val.numIndices == 0 then return some #[]
190+
return some typeArgs[val.numParams...*].toArray
191+
192+
private def mkArrows (hs : Array Expr) (type : Expr) : CoreM Expr := do
193+
hs.foldrM (init := type) mkArrow
194+
195+
private structure MkHInjTypeResult where
196+
thmType : Expr
197+
us : List Level
198+
numIndices : Nat
199+
200+
private partial def mkHInjType? (ctorVal : ConstructorVal) : MetaM (Option MkHInjTypeResult) := do
201+
let us := ctorVal.levelParams.map mkLevelParam
202+
let type ← elimOptParam ctorVal.type
203+
forallBoundedTelescope type ctorVal.numParams fun params type =>
204+
forallTelescope type fun args1 resultType => do
205+
let k (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
206+
let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
207+
let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
208+
let eq ← mkEqHEq lhs rhs
209+
let eqs ← mkEqs args1 args2
210+
if let some andEqs := mkAnd? eqs then
211+
let result ← mkArrow eq andEqs
212+
let some idxs1 ← getIndices? lhs | return none
213+
let some idxs2 ← getIndices? rhs | return none
214+
-- **Note**: We dot not skip here because the type of `noConfusion` does not.
215+
let idxEqs ← mkEqs idxs1 idxs2 (skipIfPropOrEq := false)
216+
let result ← mkArrows idxEqs result
217+
let thmType ← mkForallFVars params (← mkForallFVars args1 (← mkForallFVars args2 result))
218+
return some { thmType, us, numIndices := idxs1.size }
219+
else
220+
return none
221+
let rec mkArgs2 (i : Nat) (type : Expr) (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
222+
if h : i < args1.size then
223+
let .forallE n d b _ ← whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
224+
let arg1 := args1[i]
225+
withLocalDecl n .implicit d fun arg2 =>
226+
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2)
227+
else
228+
k args2
229+
withNewBinderInfos (params.map fun param => (param.fvarId!, BinderInfo.implicit)) <|
230+
withNewBinderInfos (args1.map fun arg1 => (arg1.fvarId!, BinderInfo.implicit)) <|
231+
mkArgs2 0 type #[]
232+
233+
private def failedToGenHInj (ctorVal : ConstructorVal) : MetaM α :=
234+
throwError "failed to generate heterogeneous injectivity theorem for `{ctorVal.name}`"
235+
236+
private partial def mkHInjectiveTheoremValue? (ctorVal : ConstructorVal) (typeInfo : MkHInjTypeResult) : MetaM (Option Expr) := do
237+
forallTelescopeReducing typeInfo.thmType fun xs type => do
238+
let noConfusionName := ctorVal.induct.str "noConfusion"
239+
let params := xs[*...ctorVal.numParams]
240+
let noConfusion := mkAppN (mkConst noConfusionName (0 :: typeInfo.us)) params
241+
let noConfusion := mkApp noConfusion type
242+
let n := xs.size - typeInfo.numIndices - 1
243+
let eqs := xs[n...*].toArray
244+
let eqExprs ← eqs.mapM fun x => do
245+
match_expr (← inferType x) with
246+
| Eq _ lhs rhs => return (lhs, rhs)
247+
| HEq _ lhs _ rhs => return (lhs, rhs)
248+
| _ => failedToGenHInj ctorVal
249+
let (args₁, args₂) := eqExprs.unzip
250+
let noConfusion := mkAppN (mkAppN (mkAppN noConfusion args₁) args₂) eqs
251+
let .forallE _ d _ _ ← whnf (← inferType noConfusion) | failedToGenHInj ctorVal
252+
let mvar ← mkFreshExprSyntheticOpaqueMVar d
253+
let noConfusion := mkApp noConfusion mvar
254+
let mvarId := mvar.mvarId!
255+
let (_, mvarId) ← mvarId.intros
256+
splitAndAssumption mvarId ctorVal.name
257+
check noConfusion
258+
let result ← instantiateMVars noConfusion
259+
mkLambdaFVars xs result
260+
261+
private def hinjSuffix := "hinj"
262+
263+
def mkHInjectiveTheoremNameFor (ctorName : Name) : Name :=
264+
ctorName.str hinjSuffix
265+
266+
private def mkHInjectiveTheorem? (thmName : Name) (ctorVal : ConstructorVal) : MetaM (Option TheoremVal) := do
267+
let some typeInfo ← mkHInjType? ctorVal | return none
268+
let some value ← mkHInjectiveTheoremValue? ctorVal typeInfo | return none
269+
return some { name := thmName, value, levelParams := ctorVal.levelParams, type := typeInfo.thmType }
270+
271+
builtin_initialize registerReservedNamePredicate fun env n =>
272+
match n with
273+
| .str p "hinj" => (env.find? p matches some (.ctorInfo _))
274+
| _ => false
275+
276+
builtin_initialize
277+
registerReservedNameAction fun name => do
278+
let .str p "hinj" := name | return false
279+
let some (.ctorInfo ctorVal) := (← getEnv).find? p | return false
280+
MetaM.run' do
281+
let some thmVal ← mkHInjectiveTheorem? name ctorVal | return false
282+
realizeConst p name do
283+
addDecl (← mkThmOrUnsafeDef thmVal)
284+
return true
285+
181286
end Lean.Meta

tests/lean/run/hinj_thm.lean

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
opaque double : Nat → Nat
2+
def P (n : Nat) : Prop := n >= 0
3+
theorem pax (n : Nat) : P n := by grind [P]
4+
def T (n : Nat) : Type := Vector Nat n
5+
6+
inductive Foo' (α β : Type u) : (n : Nat) → P n -> Type u
7+
| even (a : α) (n : Nat) (v : T n) (h : P n) : Foo' α β (double n) (pax _)
8+
| odd (b : β) (n : Nat) (v : T n) : Foo' α β (Nat.succ (double n)) (pax _)
9+
10+
/--
11+
info: Foo'.even.hinj.{u} {α β : Type u} {a : α} {n : Nat} {v : T n} {h : P n} {a✝ : α} {n✝ : Nat} {v✝ : T n✝} {h✝ : P n✝} :
12+
double n = double n✝ → ⋯ ≍ ⋯ → Foo'.even a n v h ≍ Foo'.even a✝ n✝ v✝ h✝ → a = a✝ ∧ n = n✝ ∧ v ≍ v✝
13+
-/
14+
#guard_msgs in
15+
#check Foo'.even.hinj
16+
17+
/--
18+
info: Foo'.odd.hinj.{u} {α β : Type u} {b : β} {n : Nat} {v : T n} {b✝ : β} {n✝ : Nat} {v✝ : T n✝} :
19+
(double n).succ = (double n✝).succ → ⋯ ≍ ⋯ → Foo'.odd b n v ≍ Foo'.odd b✝ n✝ v✝ → b = b✝ ∧ n = n✝ ∧ v ≍ v✝
20+
-/
21+
#guard_msgs in
22+
#check Foo'.odd.hinj

tests/lean/run/issue11450.lean

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,15 @@ info: Vec.cons.inj.{u} {α : Type u} {n : Nat} {x : α} {xs : Vec α n} {x✝ :
5656
#guard_msgs in
5757
#check Vec.cons.inj
5858

59-
theorem Vec.cons.hinj {α : Type u}
59+
theorem Vec.cons.hinj' {α : Type u}
6060
{x : α} {n : Nat} {xs : Vec α n} {x' : α} {n' : Nat} {xs' : Vec α n'} :
6161
Vec.cons x xs ≍ Vec.cons x' xs' → (n + 1 = n' + 1 → (x = x' ∧ xs ≍ xs')) := by
6262
intro h eq_1
6363
apply Vec.cons.noConfusion eq_1 h (fun _ eq_x eq_xs => ⟨eq_x, eq_xs⟩)
64+
65+
/--
66+
info: Vec.cons.hinj.{u} {α : Type u} {n : Nat} {x : α} {xs : Vec α n} {n✝ : Nat} {x✝ : α} {xs✝ : Vec α n✝} :
67+
n + 1 = n✝ + 1 → Vec.cons x xs ≍ Vec.cons x✝ xs✝ → n = n✝ ∧ x = x✝ ∧ xs ≍ xs✝
68+
-/
69+
#guard_msgs in
70+
#check Vec.cons.hinj

0 commit comments

Comments
 (0)