@@ -7,13 +7,22 @@ prelude
77import Init.Grind.Ordered.Module
88import Lean.Meta.Tactic.Grind.Simp
99import Lean.Meta.Tactic.Grind.Internalize
10+ import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
1011import Lean.Meta.Tactic.Grind.Arith.Linear.Util
1112import Lean.Meta.Tactic.Grind.Arith.Linear.Var
1213
1314namespace Lean.Meta.Grind.Arith.Linear
1415
16+ private def preprocess (e : Expr) : GoalM Expr := do
17+ shareCommon (← canon e)
18+
1519private def internalizeFn (fn : Expr) : GoalM Expr := do
16- shareCommon (← canon fn)
20+ preprocess fn
21+
22+ private def preprocessConst (c : Expr) : GoalM Expr := do
23+ let c ← preprocess c
24+ internalize c none
25+ return c
1726
1827private def internalizeConst (c : Expr) : GoalM Expr := do
1928 let c ← shareCommon (← canon c)
@@ -22,6 +31,13 @@ private def internalizeConst (c : Expr) : GoalM Expr := do
2231
2332open Grind.Linarith (Poly)
2433
34+ private def mkExpectedDefEqMsg (a b : Expr) : MetaM MessageData :=
35+ return m!"`grind linarith` expected{indentExpr a}\n to be definitionally equal to{indentExpr b}"
36+
37+ private def ensureDefEq (a b : Expr) : MetaM Unit := do
38+ unless (← withDefault <| isDefEq a b) do
39+ throwError (← mkExpectedDefEqMsg a b)
40+
2541def getStructId? (type : Expr) : GoalM (Option Nat) := do
2642 if let some id? := (← get').typeIdOf.find? { expr := type } then
2743 return id?
@@ -55,21 +71,24 @@ where
5571 let some inst := inst? | return none
5672 let toField := mkApp2 (mkConst toFieldName [u]) type inst
5773 unless (← withDefault <| isDefEq parentInst toField) do
58- reportIssue! "`grind linarith` expected{indentExpr parentInst} \n to be definitionally equal to{indentExpr toField}"
74+ reportIssue! (← mkExpectedDefEqMsg parentInst toField)
5975 return none
6076 return some inst
6177 let ensureToFieldDefEq (parentInst : Expr) (inst : Expr) (toFieldName : Name) : GoalM Unit := do
6278 let toField := mkApp2 (mkConst toFieldName [u]) type inst
63- unless (← withDefault <| isDefEq parentInst toField) do
64- throwError "`grind linarith` expected{indentExpr parentInst}\n to be definitionally equal to{indentExpr toField}"
79+ ensureDefEq parentInst toField
6580 let ensureToHomoFieldDefEq (parentInst : Expr) (inst : Expr) (toFieldName : Name) (toHeteroName : Name) : GoalM Unit := do
6681 let toField := mkApp2 (mkConst toFieldName [u]) type inst
6782 let heteroToField := mkApp2 (mkConst toHeteroName [u]) type toField
68- unless (← withDefault <| isDefEq parentInst heteroToField) do
69- throwError "`grind linarith` expected{indentExpr parentInst}\n to be definitionally equal to{indentExpr heteroToField}"
83+ ensureDefEq parentInst heteroToField
7084 let some intModuleInst ← getInst? ``Grind.IntModule | return none
7185 let zeroInst ← getInst ``Zero
7286 let zero ← internalizeConst <| mkApp2 (mkConst ``Zero.zero [u]) type zeroInst
87+ let ofNatZeroType := mkApp2 (mkConst ``OfNat [u]) type (mkRawNatLit 0 )
88+ let some ofNatZeroInst := LOption.toOption (← trySynthInstance ofNatZeroType) | return none
89+ -- `ofNatZero` is used internally, we don't need to internalize
90+ let ofNatZero ← preprocess <| mkApp3 (mkConst ``OfNat.ofNat [u]) type (mkRawNatLit 0 ) ofNatZeroInst
91+ ensureDefEq zero ofNatZero
7392 let addInst ← getBinHomoInst ``HAdd
7493 let addFn ← internalizeFn <| mkApp4 (mkConst ``HAdd.hAdd [u, u, u]) type type type addInst
7594 let subInst ← getBinHomoInst ``HSub
@@ -100,15 +119,16 @@ where
100119 let smulFn ← internalizeFn <| mkApp4 (mkConst ``HSMul.hSMul [0 , u, u]) Int.mkType type smulInst smulInst
101120 if (← withDefault <| isDefEq hmulFn smulFn) then
102121 return smulFn
103- reportIssue! "`grind linarith` expected{indentExpr hmulFn} \n to be definitionally equal to{indentExpr smulFn}"
122+ reportIssue! (← mkExpectedDefEqMsg hmulFn smulFn)
104123 return none
105124 let smulFn? ← getSMulFn?
125+ let ringId? ← CommRing.getRingId? type
106126 let ringInst? ← getInst? ``Grind.Ring
107127 let getOne? : GoalM (Option Expr) := do
108128 let some oneInst ← getInst? ``One | return none
109129 let one ← internalizeConst <| mkApp2 (mkConst ``One.one [u]) type oneInst
110130 let one' ← mkNumeral type 1
111- unless (← withDefault <| isDefEq one one') do reportIssue! "`grind linarith` expected{indentExpr one} \n to be definitionally equal to{indentExpr one'}"
131+ unless (← withDefault <| isDefEq one one') do reportIssue! (← mkExpectedDefEqMsg one one')
112132 return some one
113133 let one? ← getOne?
114134 let commRingInst? ← getInst? ``Grind.CommRing
@@ -127,7 +147,7 @@ where
127147 let struct : Struct := {
128148 id, type, u, intModuleInst, preorderInst, isOrdInst, partialInst?, linearInst?, noNatDivInst?
129149 leFn, ltFn, addFn, subFn, negFn, hmulFn, smulFn?, zero, one?
130- ringInst?, commRingInst?, ringIsOrdInst?
150+ ringInst?, commRingInst?, ringIsOrdInst?, ringId?, ofNatZero
131151 }
132152 modify' fun s => { s with structs := s.structs.push struct }
133153 if let some one := one? then
0 commit comments