Skip to content

Commit 2d67524

Browse files
authored
feat: equality in grind linarith (#8697)
This PR implements support for inequalities in the `grind` linear arithmetic procedure and simplifies its design. Some examples that can already be solved: ```lean open Lean.Grind example [IntModule α] [Preorder α] [IntModule.IsOrdered α] (a b c d : α) : a + d < c → b = a + (2:Int)*d → b - d > c → False := by grind example [CommRing α] [LinearOrder α] [Ring.IsOrdered α] (a b : α) : a = 0 → b = 1 → a + b ≤ 2 := by grind example [CommRing α] [Preorder α] [Ring.IsOrdered α] (a b c d e : α) : 2*a + b ≥ 1 → b ≥ 0 → c ≥ 0 → d ≥ 0 → e ≥ 0 → a ≥ 3*c → c ≥ 6*e → d - e*5 ≥ 0 → a + b + 3*c + d + 2*e < 0 → False := by grind ```
1 parent 41c41e4 commit 2d67524

File tree

16 files changed

+297
-98
lines changed

16 files changed

+297
-98
lines changed

src/Init/Grind/Ordered/Linarith.lean

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ theorem eq_norm {α} [IntModule α] (ctx : Context α) (lhs rhs : Expr) (p : Pol
370370
: norm_cert lhs rhs p → lhs.denote ctx = rhs.denote ctx → p.denote' ctx = 0 := by
371371
simp [norm_cert]; intro _ h₁; subst p; simp [Expr.denote, h₁, sub_self]
372372

373+
theorem le_of_eq {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α] (ctx : Context α) (lhs rhs : Expr) (p : Poly)
374+
: norm_cert lhs rhs p → lhs.denote ctx = rhs.denote ctx → p.denote' ctx ≤ 0 := by
375+
simp [norm_cert]; intro _ h₁; subst p; simp [Expr.denote, h₁, sub_self]
376+
apply Preorder.le_refl
377+
373378
theorem diseq_norm {α} [IntModule α] (ctx : Context α) (lhs rhs : Expr) (p : Poly)
374379
: norm_cert lhs rhs p → lhs.denote ctx ≠ rhs.denote ctx → p.denote' ctx ≠ 0 := by
375380
simp [norm_cert]; intro _ h₁; subst p; simp [Expr.denote, h₁, sub_self]

src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ prelude
77
import Lean.Meta.Tactic.Grind.Arith.Offset
88
import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
99
import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize
10+
import Lean.Meta.Tactic.Grind.Arith.Linear.Internalize
1011

1112
namespace Lean.Meta.Grind.Arith
1213

1314
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
1415
Offset.internalize e parent?
1516
Cutsat.internalize e parent?
1617
CommRing.internalize e parent?
18+
Linear.internalize e parent?
1719

1820
end Lean.Meta.Grind.Arith

src/Lean/Meta/Tactic/Grind/Arith/Linear.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr
1515
import Lean.Meta.Tactic.Grind.Arith.Linear.Proof
1616
import Lean.Meta.Tactic.Grind.Arith.Linear.SearchM
1717
import Lean.Meta.Tactic.Grind.Arith.Linear.Search
18+
import Lean.Meta.Tactic.Grind.Arith.Linear.PropagateEq
19+
import Lean.Meta.Tactic.Grind.Arith.Linear.Internalize
1820

1921
namespace Lean
2022

src/Lean/Meta/Tactic/Grind/Arith/Linear/DenoteExpr.lean

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ private def mkEq (a b : Expr) : M Expr := do
4646
let s ← getStruct
4747
return mkApp3 (mkConst ``Eq [s.u.succ]) s.type a b
4848

49-
def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do
50-
mkEq (← c.p.denoteExpr) (← getStruct).ofNatZero
51-
5249
def DiseqCnstr.denoteExpr (c : DiseqCnstr) : M Expr := do
5350
return mkNot (← mkEq (← c.p.denoteExpr) (← getStruct).ofNatZero)
5451

@@ -61,9 +58,6 @@ private def denoteIneq (p : Poly) (strict : Bool) : M Expr := do
6158
def IneqCnstr.denoteExpr (c : IneqCnstr) : M Expr := do
6259
denoteIneq c.p c.strict
6360

64-
def NotIneqCnstr.denoteExpr (c : NotIneqCnstr) : M Expr := do
65-
return mkNot (← denoteIneq c.p c.strict)
66-
6761
private def denoteNum (k : Int) : LinearM Expr := do
6862
return mkApp2 (← getStruct).hmulFn (mkIntLit k) (← getOne)
6963

src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
3838
if (← c.satisfied) == .false then
3939
resetAssignmentFrom x
4040

41-
def NotIneqCnstr.assert (c : NotIneqCnstr) : LinearM Unit := do
42-
trace[grind.linarith.assert] "{← c.denoteExpr}"
43-
-- TODO
44-
4541
def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
4642
let some lhs ← withRingM <| CommRing.reify? lhs (skipVar := false) | return ()
4743
let some rhs ← withRingM <| CommRing.reify? rhs (skipVar := false) | return ()
@@ -61,12 +57,8 @@ def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue :
6157
let c : IneqCnstr := { p, strict, h := .notCoreCommRing e lhs rhs p' lhs' }
6258
c.assert
6359
else
64-
let p' := (lhs.sub rhs).toPoly
65-
let lhs' ← p'.denoteAsIntModuleExpr
66-
let some lhs' ← reify? lhs' (skipVar := false) | return ()
67-
let p := lhs'.norm
68-
let c : NotIneqCnstr := { p, strict, h := .coreCommRing e lhs rhs lhs' }
69-
c.assert
60+
-- Negation for preorders is not supported
61+
return ()
7062

7163
def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
7264
let some lhs ← reify? lhs (skipVar := false) | return ()
@@ -81,9 +73,8 @@ def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue :
8173
let c : IneqCnstr := { p, strict, h := .notCore e lhs rhs }
8274
c.assert
8375
else
84-
let p := (lhs.sub rhs).norm
85-
let c : NotIneqCnstr := { p, strict, h := .core e lhs rhs }
86-
c.assert
76+
-- Negation for preorders is not supported
77+
return ()
8778

8879
def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do
8980
unless (← getConfig).linarith do return ()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/-
2+
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
prelude
7+
import Lean.Meta.Tactic.Grind.Simp
8+
import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
9+
import Lean.Meta.Tactic.Grind.Arith.Linear.StructId
10+
import Lean.Meta.Tactic.Grind.Arith.Linear.Reify
11+
12+
namespace Lean.Meta.Grind.Arith
13+
14+
15+
namespace Linear
16+
17+
/-- If `e` is a function application supported by the linarith module, return its type. -/
18+
private def getType? (e : Expr) : Option Expr :=
19+
match_expr e with
20+
| HAdd.hAdd _ _ α _ _ _ => some α
21+
| HSub.hSub _ _ α _ _ _ => some α
22+
| HMul.hMul _ _ α _ _ _ => some α
23+
| HSMul.hSMul _ _ α _ _ _ => some α
24+
| Neg.neg α _ _ => some α
25+
| Zero.zero α _ => some α
26+
| OfNat.ofNat α _ _ => some α
27+
| NatCast.natCast α _ _ => some α
28+
| IntCast.intCast α _ _ => some α
29+
| _ => none
30+
31+
private def isForbiddenParent (parent? : Option Expr) : Bool :=
32+
if let some parent := parent? then
33+
if getType? parent |>.isSome then
34+
true
35+
else
36+
-- We also ignore the following parents.
37+
-- Remark: `HDiv` should appear in `getType?` as soon as we add support for `Field`
38+
match_expr parent with
39+
| LE.le _ _ _ _ => true
40+
| HDiv.hDiv _ _ _ _ _ _ => true
41+
| HMod.hMod _ _ _ _ _ _ => true
42+
| _ => false
43+
else
44+
true
45+
46+
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
47+
unless (← getConfig).linarith do return ()
48+
let some type := getType? e | return ()
49+
if isForbiddenParent parent? then return ()
50+
let some structId ← getStructId? type | return ()
51+
LinearM.run structId do
52+
setTermStructId e
53+
markAsLinarithTerm e
54+
55+
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@ private def mkIntModLinOrdThmPrefix (declName : Name) : ProofM Expr := do
149149
let s ← getStruct
150150
return mkApp5 (mkConst declName [s.u]) s.type s.intModuleInst (← getLinearOrderInst) s.isOrdInst (← getContext)
151151

152+
/--
153+
Returns the prefix of a theorem with name `declName` where the first three arguments are
154+
`{α} [CommRing α] (rctx : Context α)`
155+
-/
156+
private def mkCommRingThmPrefix (declName : Name) : ProofM Expr := do
157+
let s ← getStruct
158+
return mkApp3 (mkConst declName [s.u]) s.type (← getCommRingInst) (← getRingContext)
159+
152160
/--
153161
Returns the prefix of a theorem with name `declName` where the first five arguments are
154162
`{α} [CommRing α] [Preorder α] [Ring.IsOrdered α] (rctx : Context α)`
@@ -166,9 +174,6 @@ private def mkCommRingLinOrdThmPrefix (declName : Name) : ProofM Expr := do
166174
return mkApp5 (mkConst declName [s.u]) s.type (← getCommRingInst) (← getLinearOrderInst) (← getRingIsOrdInst) (← getRingContext)
167175

168176
mutual
169-
partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
170-
throwError "NIY"
171-
172177
partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' do
173178
match c'.h with
174179
| .core e lhs rhs =>
@@ -200,14 +205,19 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d
200205
let s ← getStruct
201206
let h := mkApp5 (mkConst ``Grind.Linarith.zero_lt_one [s.u]) s.type (← getRingInst) s.preorderInst (← getRingIsOrdInst) (← getContext)
202207
return mkApp3 h (← mkPolyDecl c'.p) reflBoolTrue (← mkEqRefl (← getOne))
208+
| .ofEq a b la lb =>
209+
let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq
210+
return mkApp5 h (← mkExprDecl la) (← mkExprDecl lb) (← mkPolyDecl c'.p) reflBoolTrue (← mkEqProof a b)
211+
| .ofCommRingEq a b la lb p' lhs' =>
212+
let h' ← mkCommRingThmPrefix ``Grind.CommRing.eq_norm
213+
let h' := mkApp5 h' (← mkRingExprDecl la) (← mkRingExprDecl lb) (← mkRingPolyDecl p') reflBoolTrue (← mkEqProof a b)
214+
let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq
215+
return mkApp5 h (← mkExprDecl lhs') (← mkExprDecl .zero) (← mkPolyDecl c'.p) reflBoolTrue h'
203216
| _ => throwError "NIY"
204217

205218
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
206219
throwError "NIY"
207220

208-
partial def NotIneqCnstr.toExprProof (c' : NotIneqCnstr) : ProofM Expr := caching c' do
209-
throwError "NIY"
210-
211221
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
212222
match h with
213223
| .lt c => return mkApp (← mkIntModPreThmPrefix ``Grind.Linarith.lt_unsat) (← c.toExprProof)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/-
2+
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
prelude
7+
import Init.Grind.CommRing.Poly
8+
import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
9+
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
10+
import Lean.Meta.Tactic.Grind.Arith.Linear.Var
11+
import Lean.Meta.Tactic.Grind.Arith.Linear.StructId
12+
import Lean.Meta.Tactic.Grind.Arith.Linear.Reify
13+
import Lean.Meta.Tactic.Grind.Arith.Linear.IneqCnstr
14+
import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
15+
import Lean.Meta.Tactic.Grind.Arith.Linear.Proof
16+
17+
namespace Lean.Meta.Grind.Arith.Linear
18+
/-- Returns `some structId` if `a` and `b` are elements of the same structure. -/
19+
private def inSameStruct? (a b : Expr) : GoalM (Option Nat) := do
20+
let some structId ← getTermStructId? a | return none
21+
let some structId' ← getTermStructId? b | return none
22+
unless structId == structId' do return none -- This can happen when we have heterogeneous equalities
23+
return structId
24+
25+
private def processNewCommRingEq (a b : Expr) : LinearM Unit := do
26+
let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return ()
27+
let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return ()
28+
let p' := (lhs.sub rhs).toPoly
29+
let lhs' ← p'.denoteAsIntModuleExpr
30+
let some lhs' ← reify? lhs' (skipVar := false) | return ()
31+
let p := lhs'.norm
32+
if p == .nil then return ()
33+
let c₁ : IneqCnstr := { p, strict := false, h := .ofCommRingEq a b lhs rhs p' lhs' }
34+
c₁.assert
35+
let p := p.mul (-1)
36+
let p' := p'.mulConst (-1)
37+
let lhs' ← p'.denoteAsIntModuleExpr
38+
let some lhs' ← reify? lhs' (skipVar := false) | return ()
39+
let c₂ : IneqCnstr := { p, strict := false, h := .ofCommRingEq b a rhs lhs p' lhs' }
40+
c₂.assert
41+
42+
private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do
43+
let some lhs ← reify? a (skipVar := false) | return ()
44+
let some rhs ← reify? b (skipVar := false) | return ()
45+
let p := (lhs.sub rhs).norm
46+
if p == .nil then return ()
47+
let c₁ : IneqCnstr := { p, strict := false, h := .ofEq a b lhs rhs }
48+
c₁.assert
49+
let p := p.mul (-1)
50+
let c₂ : IneqCnstr := { p, strict := false, h := .ofEq b a rhs lhs }
51+
c₂.assert
52+
53+
@[export lean_process_linarith_eq]
54+
def processNewEqImpl (a b : Expr) : GoalM Unit := do
55+
if isSameExpr a b then return () -- TODO: check why this is needed
56+
let some structId ← inSameStruct? a b | return ()
57+
LinearM.run structId do
58+
trace_goal[grind.linarith.assert] "{← mkEq a b}"
59+
if (← isCommRing) then
60+
processNewCommRingEq a b
61+
else
62+
processNewIntModuleEq a b
63+
64+
@[export lean_process_linarith_diseq]
65+
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
66+
trace[grind.linarith.assert] "{a} ≠ {b}"
67+
-- TODO
68+
69+
end Lean.Meta.Grind.Arith.Linear

src/Lean/Meta/Tactic/Grind/Arith/Linear/Search.lean

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ namespace Lean.Meta.Grind.Arith.Linear
1313
def IneqCnstr.throwUnexpected (c : IneqCnstr) : LinearM α := do
1414
throwError "`grind linarith` internal error, unexpected{indentD (← c.denoteExpr)}"
1515

16-
def EqCnstr.throwUnexpected (c : EqCnstr) : LinearM α := do
17-
throwError "`grind linarith` internal error, unexpected{indentD (← c.denoteExpr)}"
18-
1916
def DiseqCnstr.throwUnexpected (c : DiseqCnstr) : LinearM α := do
2017
throwError "`grind linarith` internal error, unexpected{indentD (← c.denoteExpr)}"
2118

@@ -79,14 +76,6 @@ def getDiseqValues (x : Var) : LinearM (Array (Rat × DiseqCnstr)) := do
7976
r := r.push (((-v)/k), c)
8077
return r
8178

82-
def getEqValue? (x : Var) : LinearM (Option (Rat × EqCnstr)) := do
83-
let s ← getStruct
84-
let some c := s.eqs[x]! | return none
85-
let .add k _ p := c.p | c.throwUnexpected
86-
let some v ← p.eval? | c.throwUnexpected
87-
let val := (-v) / k
88-
return some (val, c)
89-
9079
def findDiseq? (v : Rat) (dvals : Array (Rat × DiseqCnstr)) : Option DiseqCnstr :=
9180
(·.2) <$> dvals.find? fun (v', _) => v' == v
9281

@@ -140,18 +129,17 @@ def processVar (x : Var) : SearchM Unit := do
140129
let lower? ← getBestLower? x
141130
let upper? ← getBestUpper? x
142131
let diseqVals ← getDiseqValues x
143-
let val? ← getEqValue? x
144132
-- TODO: handle special variable One.one
145-
match lower?, upper?, val? with
146-
| none, none, none =>
133+
match lower?, upper? with
134+
| none, none =>
147135
setAssignment x <| geAvoiding 0 diseqVals
148-
| some (lower, _), none, none =>
136+
| some (lower, _), none =>
149137
let v := geAvoiding (lower.ceil + 1) diseqVals
150138
setAssignment x v
151-
| none, some (upper, _), none =>
139+
| none, some (upper, _) =>
152140
let v := geAvoiding (upper.floor - 1) diseqVals
153141
setAssignment x v
154-
| some (lower, c₁), some (upper, c₂), none =>
142+
| some (lower, c₁), some (upper, c₂) =>
155143
if lower > upper || (lower == upper && (c₁.strict || c₂.strict)) then
156144
resolveLowerUpperConflict c₁ c₂
157145
else if lower == upper then
@@ -165,9 +153,6 @@ def processVar (x : Var) : SearchM Unit := do
165153
else
166154
findRat lower upper diseqVals
167155
setAssignment x v
168-
| _, _, _ =>
169-
-- Handle equalities
170-
throwError "NIY"
171156

172157
/-- Returns `true` if we already have a complete assignment / model. -/
173158
def hasAssignment : LinearM Bool := do

src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ Authors: Leonardo de Moura
66
prelude
77
import Init.Grind.Ordered.Module
88
import Lean.Meta.Tactic.Grind.Simp
9-
import Lean.Meta.Tactic.Grind.Internalize
109
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
1110
import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
1211
import Lean.Meta.Tactic.Grind.Arith.Linear.Util
@@ -22,12 +21,12 @@ private def internalizeFn (fn : Expr) : GoalM Expr := do
2221

2322
private def preprocessConst (c : Expr) : GoalM Expr := do
2423
let c ← preprocess c
25-
internalize c none
24+
internalize c 0 none
2625
return c
2726

2827
private def internalizeConst (c : Expr) : GoalM Expr := do
2928
let c ← shareCommon (← canon c)
30-
internalize c none
29+
internalize c 0 none
3130
return c
3231

3332
open Grind.Linarith (Poly)
@@ -160,7 +159,7 @@ where
160159
if let some one := one? then
161160
if ringInst?.isSome then LinearM.run id do
162161
-- Create `1` variable, and assert strict lower bound `0 < 1`
163-
let x ← mkVar one
162+
let x ← mkVar one (mark := false)
164163
let p := Poly.add (-1) x .nil
165164
modifyStruct fun s => { s with
166165
lowers := s.lowers.modify x fun cs => cs.push { p, h := .oneGtZero, strict := true }

0 commit comments

Comments
 (0)