@@ -4,17 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE.
44Authors: Leonardo de Moura
55-/
66module
7-
87prelude
98public import Lean.Meta.Basic
109import Lean.Meta.Tactic.Refl
1110import Lean.Meta.Tactic.Cases
1211import Lean.Meta.Tactic.Assumption
1312import Lean.Meta.Tactic.Simp.Main
1413import Lean.Meta.SameCtorUtils
15-
1614public section
17-
1815namespace Lean.Meta
1916
2017private def mkAnd? (args : Array Expr) : Option Expr := Id.run do
@@ -33,20 +30,26 @@ def elimOptParam (type : Expr) : CoreM Expr := do
3330 else
3431 return .continue
3532
33+ private def mkEqs (args1 args2 : Array Expr) (skipIfPropOrEq : Bool := true ) : MetaM (Array Expr) := do
34+ let mut eqs := #[]
35+ for arg1 in args1, arg2 in args2 do
36+ let arg1Type ← inferType arg1
37+ if !skipIfPropOrEq then
38+ eqs := eqs.push (← mkEqHEq arg1 arg2)
39+ else if !(← isProp arg1Type) && arg1 != arg2 then
40+ eqs := eqs.push (← mkEqHEq arg1 arg2)
41+ return eqs
42+
3643private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
3744 let us := ctorVal.levelParams.map mkLevelParam
3845 let type ← elimOptParam ctorVal.type
3946 forallBoundedTelescope type ctorVal.numParams fun params type =>
4047 forallTelescope type fun args1 resultType => do
41- let jp (args2 args2New : Array Expr) : MetaM (Option Expr) := do
48+ let k (args2 args2New : Array Expr) : MetaM (Option Expr) := do
4249 let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
4350 let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
4451 let eq ← mkEq lhs rhs
45- let mut eqs := #[]
46- for arg1 in args1, arg2 in args2 do
47- let arg1Type ← inferType arg1
48- if !(← isProp arg1Type) && arg1 != arg2 then
49- eqs := eqs.push (← mkEqHEq arg1 arg2)
52+ let eqs ← mkEqs args1 args2
5053 if let some andEqs := mkAnd? eqs then
5154 let result ← if useEq then
5255 mkEq eq andEqs
@@ -57,17 +60,15 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE
5760 return none
5861 let rec mkArgs2 (i : Nat) (type : Expr) (args2 args2New : Array Expr) : MetaM (Option Expr) := do
5962 if h : i < args1.size then
60- match (← whnf type) with
61- | Expr.forallE n d b _ =>
62- let arg1 := args1[i]
63- if occursOrInType (← getLCtx) arg1 resultType then
64- mkArgs2 (i + 1 ) (b.instantiate1 arg1) (args2.push arg1) args2New
65- else
66- withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
67- mkArgs2 (i + 1 ) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
68- | _ => throwError "unexpected constructor type for `{ctorVal.name}`"
63+ let .forallE n d b _ ← whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
64+ let arg1 := args1[i]
65+ if occursOrInType (← getLCtx) arg1 resultType then
66+ mkArgs2 (i + 1 ) (b.instantiate1 arg1) (args2.push arg1) args2New
67+ else
68+ withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
69+ mkArgs2 (i + 1 ) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
6970 else
70- jp args2 args2New
71+ k args2 args2New
7172 if useEq then
7273 mkArgs2 0 type #[] #[]
7374 else
@@ -84,14 +85,16 @@ private def injTheoremFailureHeader (ctorName : Name) : MessageData :=
8485private def throwInjectiveTheoremFailure {α} (ctorName : Name) (mvarId : MVarId) : MetaM α :=
8586 throwError "{injTheoremFailureHeader ctorName}{indentD <| MessageData.ofGoal mvarId}"
8687
88+ private def splitAndAssumption (mvarId : MVarId) (ctorName : Name) : MetaM Unit := do
89+ (← mvarId.splitAnd).forM fun mvarId =>
90+ unless (← mvarId.assumptionCore) do
91+ throwInjectiveTheoremFailure ctorName mvarId
92+
8793private def solveEqOfCtorEq (ctorName : Name) (mvarId : MVarId) (h : FVarId) : MetaM Unit := do
8894 trace[Meta.injective] "solving injectivity goal for {ctorName} with hypothesis {mkFVar h} at\n {mvarId}"
8995 match (← injection mvarId h) with
9096 | InjectionResult.solved => unreachable!
91- | InjectionResult.subgoal mvarId .. =>
92- (← mvarId.splitAnd).forM fun mvarId =>
93- unless (← mvarId.assumptionCore) do
94- throwInjectiveTheoremFailure ctorName mvarId
97+ | InjectionResult.subgoal mvarId .. => splitAndAssumption mvarId ctorName
9598
9699private def mkInjectiveTheoremValue (ctorName : Name) (targetType : Expr) : MetaM Expr :=
97100 forallTelescopeReducing targetType fun xs type => do
@@ -178,4 +181,106 @@ def mkInjectiveTheorems (declName : Name) : MetaM Unit := do
178181builtin_initialize
179182 registerTraceClass `Meta.injective
180183
184+ private def getIndices? (ctorApp : Expr) : MetaM (Option (Array Expr)) := do
185+ let type ← whnfD (← inferType ctorApp)
186+ type.withApp fun typeFn typeArgs => do
187+ let .const declName _ := typeFn | return none
188+ let .inductInfo val ← getConstInfo declName | return none
189+ if val.numIndices == 0 then return some #[]
190+ return some typeArgs[val.numParams...*].toArray
191+
192+ private def mkArrows (hs : Array Expr) (type : Expr) : CoreM Expr := do
193+ hs.foldrM (init := type) mkArrow
194+
195+ private structure MkHInjTypeResult where
196+ thmType : Expr
197+ us : List Level
198+ numIndices : Nat
199+
200+ private partial def mkHInjType? (ctorVal : ConstructorVal) : MetaM (Option MkHInjTypeResult) := do
201+ let us := ctorVal.levelParams.map mkLevelParam
202+ let type ← elimOptParam ctorVal.type
203+ forallBoundedTelescope type ctorVal.numParams fun params type =>
204+ forallTelescope type fun args1 resultType => do
205+ let k (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
206+ let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
207+ let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
208+ let eq ← mkEqHEq lhs rhs
209+ let eqs ← mkEqs args1 args2
210+ if let some andEqs := mkAnd? eqs then
211+ let result ← mkArrow eq andEqs
212+ let some idxs1 ← getIndices? lhs | return none
213+ let some idxs2 ← getIndices? rhs | return none
214+ -- **Note** : We dot not skip here because the type of `noConfusion` does not.
215+ let idxEqs ← mkEqs idxs1 idxs2 (skipIfPropOrEq := false )
216+ let result ← mkArrows idxEqs result
217+ let thmType ← mkForallFVars params (← mkForallFVars args1 (← mkForallFVars args2 result))
218+ return some { thmType, us, numIndices := idxs1.size }
219+ else
220+ return none
221+ let rec mkArgs2 (i : Nat) (type : Expr) (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
222+ if h : i < args1.size then
223+ let .forallE n d b _ ← whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
224+ let arg1 := args1[i]
225+ withLocalDecl n .implicit d fun arg2 =>
226+ mkArgs2 (i + 1 ) (b.instantiate1 arg2) (args2.push arg2)
227+ else
228+ k args2
229+ withNewBinderInfos (params.map fun param => (param.fvarId!, BinderInfo.implicit)) <|
230+ withNewBinderInfos (args1.map fun arg1 => (arg1.fvarId!, BinderInfo.implicit)) <|
231+ mkArgs2 0 type #[]
232+
233+ private def failedToGenHInj (ctorVal : ConstructorVal) : MetaM α :=
234+ throwError "failed to generate heterogeneous injectivity theorem for `{ctorVal.name}`"
235+
236+ private partial def mkHInjectiveTheoremValue? (ctorVal : ConstructorVal) (typeInfo : MkHInjTypeResult) : MetaM (Option Expr) := do
237+ forallTelescopeReducing typeInfo.thmType fun xs type => do
238+ let noConfusionName := ctorVal.induct.str "noConfusion"
239+ let params := xs[*...ctorVal.numParams]
240+ let noConfusion := mkAppN (mkConst noConfusionName (0 :: typeInfo.us)) params
241+ let noConfusion := mkApp noConfusion type
242+ let n := xs.size - typeInfo.numIndices - 1
243+ let eqs := xs[n...*].toArray
244+ let eqExprs ← eqs.mapM fun x => do
245+ match_expr (← inferType x) with
246+ | Eq _ lhs rhs => return (lhs, rhs)
247+ | HEq _ lhs _ rhs => return (lhs, rhs)
248+ | _ => failedToGenHInj ctorVal
249+ let (args₁, args₂) := eqExprs.unzip
250+ let noConfusion := mkAppN (mkAppN (mkAppN noConfusion args₁) args₂) eqs
251+ let .forallE _ d _ _ ← whnf (← inferType noConfusion) | failedToGenHInj ctorVal
252+ let mvar ← mkFreshExprSyntheticOpaqueMVar d
253+ let noConfusion := mkApp noConfusion mvar
254+ let mvarId := mvar.mvarId!
255+ let (_, mvarId) ← mvarId.intros
256+ splitAndAssumption mvarId ctorVal.name
257+ check noConfusion
258+ let result ← instantiateMVars noConfusion
259+ mkLambdaFVars xs result
260+
261+ private def hinjSuffix := "hinj"
262+
263+ def mkHInjectiveTheoremNameFor (ctorName : Name) : Name :=
264+ ctorName.str hinjSuffix
265+
266+ private def mkHInjectiveTheorem? (thmName : Name) (ctorVal : ConstructorVal) : MetaM (Option TheoremVal) := do
267+ let some typeInfo ← mkHInjType? ctorVal | return none
268+ let some value ← mkHInjectiveTheoremValue? ctorVal typeInfo | return none
269+ return some { name := thmName, value, levelParams := ctorVal.levelParams, type := typeInfo.thmType }
270+
271+ builtin_initialize registerReservedNamePredicate fun env n =>
272+ match n with
273+ | .str p "hinj" => (env.find? p matches some (.ctorInfo _))
274+ | _ => false
275+
276+ builtin_initialize
277+ registerReservedNameAction fun name => do
278+ let .str p "hinj" := name | return false
279+ let some (.ctorInfo ctorVal) := (← getEnv).find? p | return false
280+ MetaM.run' do
281+ let some thmVal ← mkHInjectiveTheorem? name ctorVal | return false
282+ realizeConst p name do
283+ addDecl (← mkThmOrUnsafeDef thmVal)
284+ return true
285+
181286end Lean.Meta
0 commit comments