Skip to content

Commit faed852

Browse files
authored
feat: equality propagation in grind order (#11047)
This PR implements (nested term) equality propagation in `grind order`. That is, it propagates implied equalities from `grind order` back to the `grind` core. Examples: ```lean open Lean Grind Std example [LE α] [IsPartialOrder α] (a b : α) (f : α → Nat) : a ≤ b → b ≤ c → c ≤ a → f a = f b := by grind (splits := 0) example [CommRing α] [LE α] [LT α] [LawfulOrderLT α] [IsPartialOrder α] [OrderedRing α] (a b : α) (f : α → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by grind -mbtc -lia -linarith (splits := 0) example (a b : Int) (f : Int → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by grind -mbtc -lia -linarith (splits := 0) ```
1 parent 4939f44 commit faed852

File tree

8 files changed

+162
-35
lines changed

8 files changed

+162
-35
lines changed

src/Init/Grind/Order.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ theorem le_of_eq_2_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder
5555
rw [Ring.intCast_zero, Semiring.add_zero]
5656
intro h; subst a; apply Std.IsPreorder.le_refl
5757

58+
theorem le_of_offset_eq_1_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α]
59+
{a b : α} {k : Int} : a = b + k → a ≤ b + k := by
60+
intro h; subst a; apply Std.IsPreorder.le_refl
61+
62+
theorem le_of_offset_eq_2_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsPreorder α] [Ring α] [OrderedRing α]
63+
{a b : α} {k : Int} : a = b + k → b ≤ a + (-k : Int) := by
64+
intro h; subst a
65+
rw [Ring.intCast_neg, Semiring.add_assoc, Semiring.add_comm (α := α) k, Ring.neg_add_cancel, Semiring.add_zero]
66+
apply Std.IsPreorder.le_refl
67+
5868
theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α]
5969
{a b : α} : ¬ a ≤ b → b ≤ a := by
6070
intro h

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Poly.lean

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def Poly.isZero : Poly → Bool
200200
| .num 0 => true
201201
| _ => false
202202

203+
def Poly.getConst : Poly → Int
204+
| .num k => k
205+
| .add _ _ p => p.getConst
206+
203207
def Poly.checkCoeffs : Poly → Bool
204208
| .num _ => true
205209
| .add k _ p => k != 0 && checkCoeffs p

src/Lean/Meta/Tactic/Grind/Arith/CommRing/Proof.lean

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,16 @@ def mkEqIffProof (lhs rhs lhs' rhs' : RingExpr) : RingM Expr := do
413413
let h := mkApp2 (mkConst ``Grind.CommRing.eq_norm_expr [ring.u]) ring.type ring.commRingInst
414414
return mkApp6 h ctx (toExpr lhs) (toExpr rhs) (toExpr lhs') (toExpr rhs') eagerReflBoolTrue
415415

416+
/--
417+
Given `e` and `e'` s.t. `e.toPoly == e'.toPoly`, returns a proof that `e.denote ctx = e'.denote ctx`
418+
-/
419+
def mkTermEqProof (e e' : RingExpr) : RingM Expr := do
420+
let ring ← getCommRing
421+
let { lhs, lhs', vars, .. } := norm ring.vars e (.num 0) e' (.num 0)
422+
let ctx ← toContextExpr vars
423+
let h := mkApp2 (mkConst ``Grind.CommRing.Expr.eq_of_toPoly_eq [ring.u]) ring.type ring.commRingInst
424+
return mkApp4 h ctx (toExpr lhs) (toExpr lhs') eagerReflBoolTrue
425+
416426
def mkNonCommLeIffProof (leInst ltInst isPreorderInst orderedRingInst : Expr) (lhs rhs lhs' rhs' : RingExpr) : NonCommRingM Expr := do
417427
let ring ← getRing
418428
let { lhs, rhs, lhs', rhs', vars } := norm ring.vars lhs rhs lhs' rhs'
@@ -434,4 +444,14 @@ def mkNonCommEqIffProof (lhs rhs lhs' rhs' : RingExpr) : NonCommRingM Expr := do
434444
let h := mkApp2 (mkConst ``Grind.CommRing.eq_norm_expr_nc [ring.u]) ring.type ring.ringInst
435445
return mkApp6 h ctx (toExpr lhs) (toExpr rhs) (toExpr lhs') (toExpr rhs') eagerReflBoolTrue
436446

447+
/--
448+
Given `e` and `e'` s.t. `e.toPoly_nc == e'.toPoly_nc`, returns a proof that `e.denote ctx = e'.denote ctx`
449+
-/
450+
def mkNonCommTermEqProof (e e' : RingExpr) : NonCommRingM Expr := do
451+
let ring ← getRing
452+
let { lhs, lhs', vars, .. } := norm ring.vars e (.num 0) e' (.num 0)
453+
let ctx ← toContextExpr vars
454+
let h := mkApp2 (mkConst ``Grind.CommRing.Expr.eq_of_toPoly_nc_eq [ring.u]) ring.type ring.ringInst
455+
return mkApp4 h ctx (toExpr lhs) (toExpr lhs') eagerReflBoolTrue
456+
437457
end Lean.Meta.Grind.Arith.CommRing

src/Lean/Meta/Tactic/Grind/Order/Assert.lean

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public def propagateEqTrue (c : Cnstr NodeId) (e : Expr) (u v : NodeId) (k k' :
7575
let mut h ← mkPropagateEqTrueProof u v k kuv k'
7676
if let some he := c.h? then
7777
h := mkApp4 (mkConst ``Grind.Order.eq_trans_true) e c.e he h
78-
if let some (e', he) := (← get').cnstrsMapInv.find? { expr := e } then
78+
if let some (e', he) := (← get').termMapInv.find? { expr := e } then
7979
h := mkApp4 (mkConst ``Grind.Order.eq_trans_true) e' e he h
8080
pushEqTrue e' h
8181
else
@@ -87,7 +87,7 @@ public def propagateSelfEqTrue (c : Cnstr NodeId) (e : Expr) : OrderM Unit := do
8787
let mut h ← mkPropagateSelfEqTrueProof u c.getWeight
8888
if let some he := c.h? then
8989
h := mkApp4 (mkConst ``Grind.Order.eq_trans_true) e c.e he h
90-
if let some (e', he) := (← get').cnstrsMapInv.find? { expr := e } then
90+
if let some (e', he) := (← get').termMapInv.find? { expr := e } then
9191
h := mkApp4 (mkConst ``Grind.Order.eq_trans_true) e' e he h
9292
pushEqTrue e' h
9393
else
@@ -100,7 +100,7 @@ public def propagateEqFalse (c : Cnstr NodeId) (e : Expr) (u v : NodeId) (k k' :
100100
let mut h ← mkPropagateEqFalseProof u v k kuv k'
101101
if let some he := c.h? then
102102
h := mkApp4 (mkConst ``Grind.Order.eq_trans_false) e c.e he h
103-
if let some (e', he) := (← get').cnstrsMapInv.find? { expr := e } then
103+
if let some (e', he) := (← get').termMapInv.find? { expr := e } then
104104
h := mkApp4 (mkConst ``Grind.Order.eq_trans_false) e' e he h
105105
pushEqFalse e' h
106106
else
@@ -112,7 +112,7 @@ public def propagateSelfEqFalse (c : Cnstr NodeId) (e : Expr) : OrderM Unit := d
112112
let mut h ← mkPropagateSelfEqFalseProof u c.getWeight
113113
if let some he := c.h? then
114114
h := mkApp4 (mkConst ``Grind.Order.eq_trans_false) e c.e he h
115-
if let some (e', he) := (← get').cnstrsMapInv.find? { expr := e } then
115+
if let some (e', he) := (← get').termMapInv.find? { expr := e } then
116116
h := mkApp4 (mkConst ``Grind.Order.eq_trans_false) e' e he h
117117
pushEqFalse e' h
118118
else
@@ -140,7 +140,7 @@ Returns `true` if `e` is already `True` in the `grind` core.
140140
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
141141
-/
142142
private def isAlreadyTrue (e : Expr) : OrderM Bool := do
143-
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
143+
if let some (e', _) := (← get').termMapInv.find? { expr := e } then
144144
alreadyInternalized e' <&&> isEqTrue e'
145145
else
146146
alreadyInternalized e <&&> isEqTrue e
@@ -162,7 +162,7 @@ Returns `true` if `e` is already `False` in the `grind` core.
162162
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
163163
-/
164164
private def isAlreadyFalse (e : Expr) : OrderM Bool := do
165-
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
165+
if let some (e', _) := (← get').termMapInv.find? { expr := e } then
166166
alreadyInternalized e' <&&> isEqFalse e'
167167
else
168168
alreadyInternalized e <&&> isEqFalse e
@@ -221,7 +221,7 @@ Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
221221
if no negative cycle was created, updates the shortest distance of affected
222222
node pairs.
223223
-/
224-
def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
224+
public def addEdge (u : NodeId) (v : NodeId) (k : Weight) (h : Expr) : OrderM Unit := do
225225
if (← isInconsistent) then return ()
226226
if u == v then
227227
if k.isNeg then
@@ -303,7 +303,7 @@ def getStructIdOf? (e : Expr) : GoalM (Option Nat) := do
303303
return (← get').exprToStructId.find? { expr := e }
304304

305305
def propagateIneq (e : Expr) : GoalM Unit := do
306-
if let some (e', he) := (← get').cnstrsMap.find? { expr := e } then
306+
if let some (e', he) := (← get').termMap.find? { expr := e } then
307307
go e' (some he)
308308
else
309309
go e none

src/Lean/Meta/Tactic/Grind/Order/Internalize.lean

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module
77
prelude
88
public import Lean.Meta.Tactic.Grind.Order.OrderM
99
import Init.Data.Int.OfNat
10+
import Init.Grind.Module.Envelope
11+
import Init.Grind.Order
1012
import Lean.Meta.Tactic.Grind.Arith.CommRing.SafePoly
1113
import Lean.Meta.Tactic.Grind.Arith.CommRing.Reify
1214
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
@@ -15,6 +17,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat
1517
import Lean.Meta.Tactic.Grind.Order.StructId
1618
import Lean.Meta.Tactic.Grind.Order.Util
1719
import Lean.Meta.Tactic.Grind.Order.Assert
20+
import Lean.Meta.Tactic.Grind.Order.Proof
1821
namespace Lean.Meta.Grind.Order
1922

2023
open Arith CommRing
@@ -40,8 +43,24 @@ def getType? (e : Expr) : Option Expr :=
4043
| _ => none
4144

4245
def isForbiddenParent (parent? : Option Expr) : Bool :=
43-
if let some parent := parent? then
44-
getType? parent |>.isSome
46+
if isIntModuleVirtualParent parent? then
47+
/-
48+
**Note**: `linarith` uses a virtual parent to mark auxiliary declarations used to encode
49+
terms into an `IntModule`.
50+
-/
51+
true
52+
else if let some parent := parent? then
53+
(getType? parent |>.isSome)
54+
||
55+
/-
56+
**Note**: We currently ignore `•`. We may reconsider it in the future.
57+
-/
58+
match_expr parent with
59+
| HSMul.hSMul _ _ _ _ _ _ => true
60+
| Nat.cast _ _ _ => true
61+
| NatCast.natCast _ _ _ => true
62+
| Grind.IntModule.OfNatModule.toQ _ _ _ => true
63+
| _ => false
4564
else
4665
false
4766

@@ -206,27 +225,89 @@ def internalizeCnstr (e : Expr) (kind : CnstrKind) (lhs rhs : Expr) : OrderM Uni
206225
s.cnstrsOf.insert (u, v) cs
207226
}
208227

228+
/--
229+
Normalization result. A nested term `e` is normalized as `a + k` and
230+
`h` is a proof for `e = a + k`
231+
-/
232+
structure OffsetTermResult where
233+
a : Expr
234+
k : Int
235+
h : Expr
236+
237+
def toOffsetTermCommRing? (e : Expr) : RingM (Option OffsetTermResult) := do
238+
let some e ← reify? e (skipVar := false) | return none
239+
let p ← e.toPolyM
240+
let k := p.getConst
241+
let p := p.addConst (-k)
242+
let a ← shareCommon (← p.denoteExpr)
243+
let h ← mkTermEqProof e (.add (p.toExpr) (.intCast k))
244+
return some { a, k, h }
245+
246+
def toOffsetTermNonCommRing? (e : Expr) : NonCommRingM (Option OffsetTermResult) := do
247+
let some e ← ncreify? e (skipVar := false) | return none
248+
let p := e.toPoly_nc
249+
let k := p.getConst
250+
let p := p.addConst (-k)
251+
let a ← shareCommon (← p.denoteExpr)
252+
let h ← mkNonCommTermEqProof e (.add (p.toExpr) (.intCast k))
253+
return some { a, k, h }
254+
255+
def toOffsetTerm? (e : Expr) : OrderM (Option OffsetTermResult) := do
256+
let s ← getStruct
257+
/-
258+
**Note**: If it is not a partial order, then it is not worth internalizing term
259+
since we will not be able to propagate implied equalities back to core.
260+
-/
261+
unless s.isPartialInst?.isSome do return none
262+
let some ringId := s.ringId? | return none
263+
if s.isCommRing then
264+
RingM.run ringId <| toOffsetTermCommRing? e
265+
else
266+
NonCommRingM.run ringId <| toOffsetTermNonCommRing? e
267+
268+
def internalizeTerm (e : Expr) : OrderM Unit := do
269+
let some r ← toOffsetTerm? e | return ()
270+
let x ← mkNode e
271+
let y ← mkNode r.a
272+
let h₁ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_1_k
273+
let h₁ := mkApp4 h₁ e r.a (toExpr r.k) r.h
274+
addEdge x y { k := r.k } h₁
275+
let h₂ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_2_k
276+
let h₂ := mkApp4 h₂ e r.a (toExpr r.k) r.h
277+
addEdge y x { k := -r.k } h₂
278+
279+
def updateTermMap (e eNew h : Expr) : GoalM Unit := do
280+
modify' fun s => { s with
281+
termMap := s.termMap.insert { expr := e } (eNew, h)
282+
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
283+
}
284+
209285
open Arith.Cutsat in
210286
def adaptNat (e : Expr) : GoalM Expr := do
211-
let (eNew, h) ← match_expr e with
212-
| LE.le _ _ lhs rhs =>
213-
let (lhs', h₁) ← natToInt lhs
214-
let (rhs', h₂) ← natToInt rhs
215-
let eNew := mkIntLE lhs' rhs'
216-
let h := mkApp6 (mkConst ``Nat.ToInt.le_eq) lhs rhs lhs' rhs' h₁ h₂
217-
pure (eNew, h)
218-
| LT.lt _ _ lhs rhs =>
219-
let (lhs', h₁) ← natToInt lhs
220-
let (rhs', h₂) ← natToInt rhs
221-
let eNew := mkIntLT lhs' rhs'
222-
let h := mkApp6 (mkConst ``Nat.ToInt.lt_eq) lhs rhs lhs' rhs' h₁ h₂
223-
pure (eNew, h)
287+
if let some (eNew, _) := (← get').termMap.find? { expr := e } then
288+
return eNew
289+
else match_expr e with
290+
| LE.le _ _ lhs rhs => adaptCnstr lhs rhs (isLT := false)
291+
| LT.lt _ _ lhs rhs => adaptCnstr lhs rhs (isLT := true)
292+
| HAdd.hAdd _ _ _ _ _ _ => adaptTerm
293+
| HSub.hSub _ _ _ _ _ _ => adaptTerm
294+
| OfNat.ofNat _ _ _ => adaptTerm
224295
| _ => return e
225-
modify' fun s => { s with
226-
cnstrsMap := s.cnstrsMap.insert { expr := e } (eNew, h)
227-
cnstrsMapInv := s.cnstrsMapInv.insert { expr := eNew } (e, h)
228-
}
229-
return eNew
296+
where
297+
adaptCnstr (lhs rhs : Expr) (isLT : Bool) : GoalM Expr := do
298+
let (lhs', h₁) ← natToInt lhs
299+
let (rhs', h₂) ← natToInt rhs
300+
let eNew := if isLT then mkIntLT lhs' rhs' else mkIntLE lhs' rhs'
301+
let h := mkApp6
302+
(mkConst (if isLT then ``Nat.ToInt.lt_eq else ``Nat.ToInt.le_eq))
303+
lhs rhs lhs' rhs' h₁ h₂
304+
updateTermMap e eNew h
305+
return eNew
306+
307+
adaptTerm : GoalM Expr := do
308+
let (eNew, h) ← natToInt e
309+
updateTermMap e eNew h
310+
return eNew
230311

231312
def adapt (α : Expr) (e : Expr) : GoalM (Expr × Expr) := do
232313
-- **Note**: We currently only adapt `Nat` expressions
@@ -250,8 +331,9 @@ public def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
250331
match_expr e with
251332
| LE.le _ _ lhs rhs => internalizeCnstr e .le lhs rhs
252333
| LT.lt _ _ lhs rhs => if (← hasLt) then internalizeCnstr e .lt lhs rhs
253-
| _ =>
254-
-- **Note**: We currently do not internalize offset terms nested in other terms.
255-
return ()
334+
| HAdd.hAdd _ _ _ _ _ _ => internalizeTerm e
335+
| HSub.hSub _ _ _ _ _ _ => internalizeTerm e
336+
| OfNat.ofNat _ _ _ => internalizeTerm e
337+
| _ => return ()
256338

257339
end Lean.Meta.Grind.Order

src/Lean/Meta/Tactic/Grind/Order/Types.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ structure State where
144144
Example: given `x y : Nat`, `x ≤ y + 1` is mapped to `Int.ofNat x ≤ Int.ofNat y + 1`, and proof
145145
of equivalence.
146146
-/
147-
cnstrsMap : PHashMap ExprPtr (Expr × Expr) := {}
148-
/-- `cnstrsMap` inverse -/
149-
cnstrsMapInv : PHashMap ExprPtr (Expr × Expr) := {}
147+
termMap : PHashMap ExprPtr (Expr × Expr) := {}
148+
/-- `termMap` inverse -/
149+
termMapInv : PHashMap ExprPtr (Expr × Expr) := {}
150150
deriving Inhabited
151151

152152
builtin_initialize orderExt : SolverExtension State ← registerSolverExtension (return {})

tests/lean/run/grind_cutsat_toint_1.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → Fa
9090
grind
9191

9292
-- We use `↑a` when pretty printing `ToInt.toInt a`
93-
/-- trace: [grind.debug.ring.basis] (↑a + ↑b) % 3 + -1 * ↑a + -1 * ↑b + 3 * ((↑a + ↑b) / 3) = 0 -/
93+
/-- trace: [grind.debug.ring.basis] ↑a + ↑b + -1 * ((↑a + ↑b) % 3) + -3 * ((↑a + ↑b) / 3) = 0 -/
9494
#guard_msgs (drop error, trace) in
9595
set_option trace.grind.debug.ring.basis true in
9696
example (a b : Fin 3) : a > 0 → a ≠ b → a + b ≠ 0 → a + b ≠ 1 → False := by

tests/lean/run/grind_order_eq.lean

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
open Lean Grind Std
2+
3+
example [LE α] [IsPartialOrder α] (a b : α) (f : α → Nat) : a ≤ b → b ≤ c → c ≤ a → f a = f b := by
4+
grind (splits := 0)
5+
6+
example [CommRing α] [LE α] [LT α] [LawfulOrderLT α] [IsPartialOrder α] [OrderedRing α]
7+
(a b : α) (f : α → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by
8+
grind -mbtc -lia -linarith (splits := 0)
9+
10+
example (a b : Int) (f : Int → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by
11+
grind -mbtc -lia -linarith (splits := 0)

0 commit comments

Comments
 (0)