Skip to content

Commit 7e1d0cc

Browse files
authored
feat: use CommRing to normalize linarith expressions (#8682)
This PR uses the `CommRing` module to normalize linarith inequalities.
1 parent 2ae066f commit 7e1d0cc

File tree

7 files changed

+206
-42
lines changed

7 files changed

+206
-42
lines changed

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Reify.lean

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ private def reportAppIssue (e : Expr) : GoalM Unit := do
3131
/--
3232
Converts a Lean expression `e` in the `CommRing` with id `ringId` into
3333
a `CommRing.Expr` object.
34+
35+
If `skipVar` is `true`, then the result is `none` if `e` is not an interpreted `CommRing` term.
36+
We use `skipVar := false` when processing inequalities, and `skipVar := true` for equalities and disequalities
3437
-/
35-
partial def reify? (e : Expr) : RingM (Option RingExpr) := do
38+
partial def reify? (e : Expr) (skipVar := true) : RingM (Option RingExpr) := do
3639
let toVar (e : Expr) : RingM RingExpr := do
3740
return .var (← mkVar e)
3841
let asVar (e : Expr) : RingM RingExpr := do
@@ -67,36 +70,41 @@ partial def reify? (e : Expr) : RingM (Option RingExpr) := do
6770
let some k ← getNatValue? n | toVar e
6871
return .num k
6972
| _ => toVar e
70-
let asNone (e : Expr) : GoalM (Option RingExpr) := do
73+
let toTopVar (e : Expr) : RingM (Option RingExpr) := do
74+
if skipVar then
75+
return none
76+
else
77+
return some (← toVar e)
78+
let asTopVar (e : Expr) : RingM (Option RingExpr) := do
7179
reportAppIssue e
72-
return none
80+
toTopVar e
7381
match_expr e with
7482
| HAdd.hAdd _ _ _ i a b =>
75-
if isAddInst (← getRing) i then return some (.add (← go a) (← go b)) else asNone e
83+
if isAddInst (← getRing) i then return some (.add (← go a) (← go b)) else asTopVar e
7684
| HMul.hMul _ _ _ i a b =>
77-
if isMulInst (← getRing) i then return some (.mul (← go a) (← go b)) else asNone e
85+
if isMulInst (← getRing) i then return some (.mul (← go a) (← go b)) else asTopVar e
7886
| HSub.hSub _ _ _ i a b =>
79-
if isSubInst (← getRing) i then return some (.sub (← go a) (← go b)) else asNone e
87+
if isSubInst (← getRing) i then return some (.sub (← go a) (← go b)) else asTopVar e
8088
| HPow.hPow _ _ _ i a b =>
8189
let some k ← getNatValue? b | return none
82-
if isPowInst (← getRing) i then return some (.pow (← go a) k) else asNone e
90+
if isPowInst (← getRing) i then return some (.pow (← go a) k) else asTopVar e
8391
| Neg.neg _ i a =>
84-
if isNegInst (← getRing) i then return some (.neg (← go a)) else asNone e
85-
| IntCast.intCast _ i e =>
92+
if isNegInst (← getRing) i then return some (.neg (← go a)) else asTopVar e
93+
| IntCast.intCast _ i a =>
8694
if isIntCastInst (← getRing) i then
87-
let some k ← getIntValue? e | return none
95+
let some k ← getIntValue? a | asTopVar e
8896
return some (.num k)
8997
else
90-
asNone e
91-
| NatCast.natCast _ i e =>
98+
asTopVar e
99+
| NatCast.natCast _ i a =>
92100
if isNatCastInst (← getRing) i then
93-
let some k ← getNatValue? e | return none
101+
let some k ← getNatValue? a | asTopVar e
94102
return some (.num k)
95103
else
96-
asNone e
104+
asTopVar e
97105
| OfNat.ofNat _ n _ =>
98-
let some k ← getNatValue? n | return none
106+
let some k ← getNatValue? n | asTopVar e
99107
return some (.num k)
100-
| _ => return none
108+
| _ => toTopVar e
101109

102110
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Arith/Linear/DenoteExpr.lean

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,58 @@ private def mkEq (a b : Expr) : M Expr := do
4646
return mkApp3 (mkConst ``Eq [s.u.succ]) s.type a b
4747

4848
def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do
49-
mkEq (← c.p.denoteExpr) (← getStruct).zero
49+
mkEq (← c.p.denoteExpr) (← getStruct).ofNatZero
5050

5151
def DiseqCnstr.denoteExpr (c : DiseqCnstr) : M Expr := do
52-
return mkNot (← mkEq (← c.p.denoteExpr) (← getStruct).zero)
52+
return mkNot (← mkEq (← c.p.denoteExpr) (← getStruct).ofNatZero)
5353

5454
private def denoteIneq (p : Poly) (strict : Bool) : M Expr := do
5555
if strict then
56-
return mkApp2 (← getStruct).ltFn (← p.denoteExpr) (← getStruct).zero
56+
return mkApp2 (← getStruct).ltFn (← p.denoteExpr) (← getStruct).ofNatZero
5757
else
58-
return mkApp2 (← getStruct).leFn (← p.denoteExpr) (← getStruct).zero
58+
return mkApp2 (← getStruct).leFn (← p.denoteExpr) (← getStruct).ofNatZero
5959

6060
def IneqCnstr.denoteExpr (c : IneqCnstr) : M Expr := do
6161
denoteIneq c.p c.strict
6262

6363
def NotIneqCnstr.denoteExpr (c : NotIneqCnstr) : M Expr := do
6464
return mkNot (← denoteIneq c.p c.strict)
6565

66+
private def denoteNum (k : Int) : LinearM Expr := do
67+
return mkApp2 (← getStruct).hmulFn (mkIntLit k) (← getOne)
68+
69+
def _root_.Lean.Grind.CommRing.Power.denoteAsIntModuleExpr (pw : Grind.CommRing.Power) : LinearM Expr := do
70+
let x := (← getRing).vars[pw.x]!
71+
if pw.k == 1 then
72+
return x
73+
else
74+
return mkApp2 (← getRing).powFn x (toExpr pw.k)
75+
76+
def _root_.Lean.Grind.CommRing.Mon.denoteAsIntModuleExpr (m : Grind.CommRing.Mon) : LinearM Expr := do
77+
match m with
78+
| .unit => getOne
79+
| .mult pw m => go m (← pw.denoteAsIntModuleExpr)
80+
where
81+
go (m : Grind.CommRing.Mon) (acc : Expr) : LinearM Expr := do
82+
match m with
83+
| .unit => return acc
84+
| .mult pw m => go m (mkApp2 (← getRing).mulFn acc (← pw.denoteAsIntModuleExpr))
85+
86+
def _root_.Lean.Grind.CommRing.Poly.denoteAsIntModuleExpr (p : Grind.CommRing.Poly) : LinearM Expr := do
87+
match p with
88+
| .num k => denoteNum k
89+
| .add k m p => go p (← denoteTerm k m)
90+
where
91+
denoteTerm (k : Int) (m : Grind.CommRing.Mon) : LinearM Expr := do
92+
if k == 1 then
93+
m.denoteAsIntModuleExpr
94+
else
95+
return mkApp2 (← getStruct).hmulFn (mkIntLit k) (← m.denoteAsIntModuleExpr)
96+
97+
go (p : Grind.CommRing.Poly) (acc : Expr) : LinearM Expr := do
98+
match p with
99+
| .num 0 => return acc
100+
| .num k => return mkApp2 (← getStruct).addFn acc (← denoteNum k)
101+
| .add k m p => go p (mkApp2 (← getStruct).addFn acc (← denoteTerm k m))
102+
66103
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
44
Authors: Leonardo de Moura
55
-/
66
prelude
7+
import Init.Grind.CommRing.Poly
8+
import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
9+
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
710
import Lean.Meta.Tactic.Grind.Arith.Linear.Var
811
import Lean.Meta.Tactic.Grind.Arith.Linear.StructId
912
import Lean.Meta.Tactic.Grind.Arith.Linear.Reify
@@ -16,6 +19,57 @@ def isLeInst (struct : Struct) (inst : Expr) : Bool :=
1619
def isLtInst (struct : Struct) (inst : Expr) : Bool :=
1720
isSameExpr struct.ltFn.appArg! inst
1821

22+
def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
23+
trace[grind.linarith.assert] "{← c.denoteExpr}"
24+
-- TODO
25+
26+
def NotIneqCnstr.assert (c : NotIneqCnstr) : LinearM Unit := do
27+
trace[grind.linarith.assert] "{← c.denoteExpr}"
28+
-- TODO
29+
30+
def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
31+
let some lhs ← withRingM <| CommRing.reify? lhs (skipVar := false) | return ()
32+
let some rhs ← withRingM <| CommRing.reify? rhs (skipVar := false) | return ()
33+
if eqTrue then
34+
let p' := (lhs.sub rhs).toPoly
35+
let lhs' ← p'.denoteAsIntModuleExpr
36+
let some lhs' ← reify? lhs' (skipVar := false) | return ()
37+
let p := lhs'.norm
38+
let c : IneqCnstr := { p, strict, h := .coreCommRing e lhs rhs lhs' }
39+
c.assert
40+
else if (← isLinearOrder) then
41+
let p' := (rhs.sub lhs).toPoly
42+
let strict := !strict
43+
let lhs' ← p'.denoteAsIntModuleExpr
44+
let some lhs' ← reify? lhs' (skipVar := false) | return ()
45+
let p := lhs'.norm
46+
let c : IneqCnstr := { p, strict, h := .notCoreCommRing e lhs rhs lhs' }
47+
c.assert
48+
else
49+
let p' := (lhs.sub rhs).toPoly
50+
let lhs' ← p'.denoteAsIntModuleExpr
51+
let some lhs' ← reify? lhs' (skipVar := false) | return ()
52+
let p := lhs'.norm
53+
let c : NotIneqCnstr := { p, strict, h := .coreCommRing e lhs rhs lhs' }
54+
c.assert
55+
56+
def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
57+
let some lhs ← reify? lhs (skipVar := false) | return ()
58+
let some rhs ← reify? rhs (skipVar := false) | return ()
59+
if eqTrue then
60+
let p := (lhs.sub rhs).norm
61+
let c : IneqCnstr := { p, strict, h := .core e lhs rhs }
62+
c.assert
63+
else if (← isLinearOrder) then
64+
let p := (rhs.sub lhs).norm
65+
let strict := !strict
66+
let c : IneqCnstr := { p, strict, h := .notCore e lhs rhs }
67+
c.assert
68+
else
69+
let p := (lhs.sub rhs).norm
70+
let c : NotIneqCnstr := { p, strict, h := .core e lhs rhs }
71+
c.assert
72+
1973
def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do
2074
let numArgs := e.getAppNumArgs
2175
unless numArgs == 4 do return ()
@@ -29,15 +83,13 @@ def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do
2983
else if isLtInst struct inst then
3084
pure true
3185
else
32-
trace[grind.linarith] "invalid {e}, {(← getStruct).leFn}, {(← getStruct).ltFn}"
3386
return ()
34-
let some lhs ← reify? (e.getArg! 2 numArgs) (skipVar := false) | trace[grind.linarith] "lhs failed {e}"; return ()
35-
let some rhs ← reify? (e.getArg! 3 numArgs) (skipVar := false) | trace[grind.linarith] "rhs failed {e}"; return ()
36-
let p := (lhs.sub rhs).norm
37-
-- TODO
38-
trace[grind.linarith] "{e}, {eqTrue}, strict: {strict}, structId: {structId}"
39-
trace[grind.linarith] "{← p.denoteExpr}"
40-
trace[grind.linarith] "structId: {structId}"
41-
return ()
87+
let lhs := e.getArg! 2 numArgs
88+
let rhs := e.getArg! 3 numArgs
89+
if (← isCommRing) then
90+
propagateCommRingIneq e lhs rhs strict eqTrue
91+
-- TODO: non-commutative ring normalizer
92+
else
93+
propagateIntModuleIneq e lhs rhs strict eqTrue
4294

4395
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,7 @@ partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
4141
reportInstIssue e
4242
return .var (← mkVar e)
4343
let isOfNatZero (e : Expr) : LinearM Bool := do
44-
let_expr OfNat.ofNat _ n _ := e | return false
45-
let some k ← getNatValue? n | return false
46-
unless k == 0 do return false
47-
withDefault <| isDefEq e (← getStruct).zero
44+
withDefault <| isDefEq e (← getStruct).ofNatZero
4845
let rec go (e : Expr) : LinearM LinExpr := do
4946
match_expr e with
5047
| HAdd.hAdd _ _ _ i a b =>

src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@ prelude
77
import Init.Grind.Ordered.Module
88
import Lean.Meta.Tactic.Grind.Simp
99
import Lean.Meta.Tactic.Grind.Internalize
10+
import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
1011
import Lean.Meta.Tactic.Grind.Arith.Linear.Util
1112
import Lean.Meta.Tactic.Grind.Arith.Linear.Var
1213

1314
namespace Lean.Meta.Grind.Arith.Linear
1415

16+
private def preprocess (e : Expr) : GoalM Expr := do
17+
shareCommon (← canon e)
18+
1519
private def internalizeFn (fn : Expr) : GoalM Expr := do
16-
shareCommon (← canon fn)
20+
preprocess fn
21+
22+
private def preprocessConst (c : Expr) : GoalM Expr := do
23+
let c ← preprocess c
24+
internalize c none
25+
return c
1726

1827
private def internalizeConst (c : Expr) : GoalM Expr := do
1928
let c ← shareCommon (← canon c)
@@ -22,6 +31,13 @@ private def internalizeConst (c : Expr) : GoalM Expr := do
2231

2332
open Grind.Linarith (Poly)
2433

34+
private def mkExpectedDefEqMsg (a b : Expr) : MetaM MessageData :=
35+
return m!"`grind linarith` expected{indentExpr a}\nto be definitionally equal to{indentExpr b}"
36+
37+
private def ensureDefEq (a b : Expr) : MetaM Unit := do
38+
unless (← withDefault <| isDefEq a b) do
39+
throwError (← mkExpectedDefEqMsg a b)
40+
2541
def getStructId? (type : Expr) : GoalM (Option Nat) := do
2642
if let some id? := (← get').typeIdOf.find? { expr := type } then
2743
return id?
@@ -55,21 +71,24 @@ where
5571
let some inst := inst? | return none
5672
let toField := mkApp2 (mkConst toFieldName [u]) type inst
5773
unless (← withDefault <| isDefEq parentInst toField) do
58-
reportIssue! "`grind linarith` expected{indentExpr parentInst}\nto be definitionally equal to{indentExpr toField}"
74+
reportIssue! (← mkExpectedDefEqMsg parentInst toField)
5975
return none
6076
return some inst
6177
let ensureToFieldDefEq (parentInst : Expr) (inst : Expr) (toFieldName : Name) : GoalM Unit := do
6278
let toField := mkApp2 (mkConst toFieldName [u]) type inst
63-
unless (← withDefault <| isDefEq parentInst toField) do
64-
throwError "`grind linarith` expected{indentExpr parentInst}\nto be definitionally equal to{indentExpr toField}"
79+
ensureDefEq parentInst toField
6580
let ensureToHomoFieldDefEq (parentInst : Expr) (inst : Expr) (toFieldName : Name) (toHeteroName : Name) : GoalM Unit := do
6681
let toField := mkApp2 (mkConst toFieldName [u]) type inst
6782
let heteroToField := mkApp2 (mkConst toHeteroName [u]) type toField
68-
unless (← withDefault <| isDefEq parentInst heteroToField) do
69-
throwError "`grind linarith` expected{indentExpr parentInst}\nto be definitionally equal to{indentExpr heteroToField}"
83+
ensureDefEq parentInst heteroToField
7084
let some intModuleInst ← getInst? ``Grind.IntModule | return none
7185
let zeroInst ← getInst ``Zero
7286
let zero ← internalizeConst <| mkApp2 (mkConst ``Zero.zero [u]) type zeroInst
87+
let ofNatZeroType := mkApp2 (mkConst ``OfNat [u]) type (mkRawNatLit 0)
88+
let some ofNatZeroInst := LOption.toOption (← trySynthInstance ofNatZeroType) | return none
89+
-- `ofNatZero` is used internally, we don't need to internalize
90+
let ofNatZero ← preprocess <| mkApp3 (mkConst ``OfNat.ofNat [u]) type (mkRawNatLit 0) ofNatZeroInst
91+
ensureDefEq zero ofNatZero
7392
let addInst ← getBinHomoInst ``HAdd
7493
let addFn ← internalizeFn <| mkApp4 (mkConst ``HAdd.hAdd [u, u, u]) type type type addInst
7594
let subInst ← getBinHomoInst ``HSub
@@ -100,15 +119,16 @@ where
100119
let smulFn ← internalizeFn <| mkApp4 (mkConst ``HSMul.hSMul [0, u, u]) Int.mkType type smulInst smulInst
101120
if (← withDefault <| isDefEq hmulFn smulFn) then
102121
return smulFn
103-
reportIssue! "`grind linarith` expected{indentExpr hmulFn}\nto be definitionally equal to{indentExpr smulFn}"
122+
reportIssue! (← mkExpectedDefEqMsg hmulFn smulFn)
104123
return none
105124
let smulFn? ← getSMulFn?
125+
let ringId? ← CommRing.getRingId? type
106126
let ringInst? ← getInst? ``Grind.Ring
107127
let getOne? : GoalM (Option Expr) := do
108128
let some oneInst ← getInst? ``One | return none
109129
let one ← internalizeConst <| mkApp2 (mkConst ``One.one [u]) type oneInst
110130
let one' ← mkNumeral type 1
111-
unless (← withDefault <| isDefEq one one') do reportIssue! "`grind linarith` expected{indentExpr one}\nto be definitionally equal to{indentExpr one'}"
131+
unless (← withDefault <| isDefEq one one') do reportIssue! (← mkExpectedDefEqMsg one one')
112132
return some one
113133
let one? ← getOne?
114134
let commRingInst? ← getInst? ``Grind.CommRing
@@ -127,7 +147,7 @@ where
127147
let struct : Struct := {
128148
id, type, u, intModuleInst, preorderInst, isOrdInst, partialInst?, linearInst?, noNatDivInst?
129149
leFn, ltFn, addFn, subFn, negFn, hmulFn, smulFn?, zero, one?
130-
ringInst?, commRingInst?, ringIsOrdInst?
150+
ringInst?, commRingInst?, ringIsOrdInst?, ringId?, ofNatZero
131151
}
132152
modify' fun s => { s with structs := s.structs.push struct }
133153
if let some one := one? then

src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Authors: Leonardo de Moura
55
-/
66
prelude
77
import Std.Internal.Rat
8+
import Init.Grind.CommRing.Poly
89
import Init.Grind.Ordered.Linarith
910
import Lean.Data.PersistentArray
1011
import Lean.Meta.Tactic.Grind.ExprPtr
@@ -38,6 +39,8 @@ structure IneqCnstr where
3839
inductive IneqCnstrProof where
3940
| core (e : Expr) (lhs rhs : LinExpr)
4041
| notCore (e : Expr) (lhs rhs : LinExpr)
42+
| coreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (lhs' : LinExpr)
43+
| notCoreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (lhs' : LinExpr)
4144
| combine (c₁ : IneqCnstr) (c₂ : IneqCnstr)
4245
| combineEq (c₁ : IneqCnstr) (c₂ : EqCnstr)
4346
| norm (c₁ : IneqCnstr) (k : Nat)
@@ -62,6 +65,7 @@ structure NotIneqCnstr where
6265

6366
inductive NotIneqCnstrProof where
6467
| core (e : Expr) (lhs rhs : LinExpr)
68+
| coreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (lhs' : LinExpr)
6569
-- TODO: norm, and combineEq
6670

6771
inductive UnsatProof where
@@ -77,6 +81,8 @@ Each type must be at least implement the instances `IntModule`, `Preorder`, and
7781
-/
7882
structure Struct where
7983
id : Nat
84+
/-- If the structure is a ring, we store its id in the `CommRing` module at `ringId?` -/
85+
ringId? : Option Nat
8086
type : Expr
8187
/-- Cached `getDecLevel type` -/
8288
u : Level
@@ -99,6 +105,7 @@ structure Struct where
99105
/-- `Ring.IsOrdered` instance with `Preorder` -/
100106
ringIsOrdInst? : Option Expr
101107
zero : Expr
108+
ofNatZero : Expr
102109
one? : Option Expr
103110
leFn : Expr
104111
ltFn : Expr

0 commit comments

Comments
 (0)