Skip to content

Commit 1d00dee

Browse files
authored
fix: grind order equality propagation for Nat (#11050)
This PR fixes equality propagation for `Nat` in `grind order`.
1 parent e5a6901 commit 1d00dee

File tree

9 files changed

+91
-37
lines changed

9 files changed

+91
-37
lines changed

src/Init/Grind/Order.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ theorem nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCa
6969
intro _ _; subst x y; intro h
7070
exact Int.natCast_inj.mp h
7171

72+
theorem of_nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCast b = y → a = b → x = y := by
73+
intro _ _; subst x y; intro; simp [*]
74+
7275
theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α]
7376
{a b : α} : ¬ a ≤ b → b ≤ a := by
7477
intro h

src/Lean/Meta/Tactic/Grind/Arith/CommRing/DenoteExpr.lean

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ variable [Monad M] [MonadError M] [MonadLiftT MetaM M] [MonadCanon M] [MonadRing
1717
def denoteNum (k : Int) : M Expr := do
1818
let ring ← getRing
1919
let n := mkRawNatLit k.natAbs
20-
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
20+
let ofNatInst ← if let some inst ← synthInstance? (mkApp2 (mkConst ``OfNat [ring.u]) ring.type n) then
21+
pure inst
22+
else
23+
pure <| mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
2124
let n := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
2225
if k < 0 then
2326
return mkApp (← getNegFn) n

src/Lean/Meta/Tactic/Grind/Order/Assert.lean

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,18 @@ def propagatePending : OrderM Unit := do
148148
- `h₁ : ↑ue' = ue`
149149
- `h₂ : ↑ve' = ve`
150150
- `h : ue = ve`
151+
**Note**: We currently only support `Nat`. Thus `↑a` is actually
152+
`NatCast.natCast a`. If we decide to support arbitrary semirings
153+
in this module, we must adjust this code.
151154
-/
152155
pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
153156
where
154157
/--
155158
If `e` is an auxiliary term used to represent some term `a`, returns
156159
`some (a, h)` s.t. `h : ↑a = e`
157-
**Note**: We currently only support `Nat`. Thus `↑a` is actually
158-
`NatCast.natCast a`. If we decide to support arbitrary semirings
159-
in this module, we must adjust this code.
160160
-/
161161
getOriginal? (e : Expr) : GoalM (Option (Expr × Expr)) := do
162-
if let some r := (← get').termMapInv.find? { expr := e } then
163-
return some r
164-
else
165-
let_expr NatCast.natCast _ _ a := e | return none
166-
let h ← mkEqRefl e
167-
return some (a, h)
162+
return (← get').termMapInv.find? { expr := e }
168163

169164
/--
170165
Returns `true` if `e` is already `True` in the `grind` core.
@@ -233,10 +228,7 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
233228
pushToPropagate <| .eq u v
234229
else
235230
/-
236-
Check whether `ue` and `ve` are auxiliary terms used to encode `Nat` terms.
237-
**Note**: `getOriginal?` is currently hard coded to the `Nat` case since
238-
it is the only type we map to rings. If in the future, we want to support
239-
arbitrary `Semiring`s, we must adjust this code.
231+
Check whether `ue` and `ve` are auxiliary terms.
240232
-/
241233
let some ue ← getOriginal? ue | return ()
242234
let some ve ← getOriginal? ve | return ()
@@ -245,9 +237,7 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
245237
pushToPropagate <| .eq u v
246238
where
247239
getOriginal? (e : Expr) : GoalM (Option Expr) := do
248-
let_expr NatCast.natCast _ _ a := e
249-
| let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
250-
return some a
240+
let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
251241
return some a
252242

253243
/-- Finds constrains and equalities to be propagated. -/
@@ -378,14 +368,31 @@ builtin_grind_propagator propagateLT ↓LT.lt := propagateIneq
378368

379369
public def processNewEq (a b : Expr) : GoalM Unit := do
380370
unless isSameExpr a b do
371+
let h ← mkEqProof a b
372+
if let some (a', h₁) ← getAuxTerm? a then
373+
let some (b', h₂) ← getAuxTerm? b | return ()
374+
/-
375+
We have
376+
- `h : a = b`
377+
- `h₁ : ↑a = a'`
378+
- `h₂ : ↑b = b'`
379+
-/
380+
let h := mkApp7 (mkConst ``Grind.Order.of_nat_eq) a b a' b' h₁ h₂ h
381+
go a' b' h
382+
else
383+
go a b h
384+
where
385+
getAuxTerm? (e : Expr) : GoalM (Option (Expr × Expr)) := do
386+
return (← get').termMap.find? { expr := e }
387+
388+
go (a b h : Expr) : GoalM Unit := do
381389
let some id₁ ← getStructIdOf? a | return ()
382390
let some id₂ ← getStructIdOf? b | return ()
383391
unless id₁ == id₂ do return ()
384392
OrderM.run id₁ do
385393
trace[grind.order.assert] "{a} = {b}"
386394
let u ← getNodeId a
387395
let v ← getNodeId b
388-
let h ← mkEqProof a b
389396
if (← isRing) then
390397
let h₁ := mkApp3 (← mkOrdRingPrefix ``Grind.Order.le_of_eq_1_k) a b h
391398
let h₂ := mkApp3 (← mkOrdRingPrefix ``Grind.Order.le_of_eq_2_k) a b h

src/Lean/Meta/Tactic/Grind/Order/Internalize.lean

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,20 @@ def isForbiddenParent (parent? : Option Expr) : Bool :=
5050
-/
5151
true
5252
else if let some parent := parent? then
53-
(getType? parent |>.isSome)
54-
||
55-
/-
56-
**Note**: We currently ignore `•`. We may reconsider it in the future.
57-
-/
58-
match_expr parent with
59-
| HSMul.hSMul _ _ _ _ _ _ => true
60-
| Nat.cast _ _ _ => true
61-
| NatCast.natCast _ _ _ => true
62-
| Grind.IntModule.OfNatModule.toQ _ _ _ => true
63-
| _ => false
53+
if parent.isEq then
54+
false
55+
else
56+
(getType? parent |>.isSome)
57+
||
58+
/-
59+
**Note**: We currently ignore `•`. We may reconsider it in the future.
60+
-/
61+
match_expr parent with
62+
| HSMul.hSMul _ _ _ _ _ _ => true
63+
| Nat.cast _ _ _ => true
64+
| NatCast.natCast _ _ _ => true
65+
| Grind.IntModule.OfNatModule.toQ _ _ _ => true
66+
| _ => false
6467
else
6568
false
6669

@@ -173,6 +176,12 @@ def setStructId (e : Expr) : OrderM Unit := do
173176
exprToStructId := s.exprToStructId.insert { expr := e } structId
174177
}
175178

179+
def updateTermMap (e eNew h : Expr) : GoalM Unit := do
180+
modify' fun s => { s with
181+
termMap := s.termMap.insert { expr := e } (eNew, h)
182+
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
183+
}
184+
176185
def mkNode (e : Expr) : OrderM NodeId := do
177186
if let some nodeId := (← getStruct).nodeMap.find? { expr := e } then
178187
return nodeId
@@ -192,7 +201,19 @@ def mkNode (e : Expr) : OrderM NodeId := do
192201
-/
193202
if (← alreadyInternalized e) then
194203
orderExt.markTerm e
204+
if let some e' ← getOriginal? e then
205+
orderExt.markTerm e'
195206
return nodeId
207+
where
208+
getOriginal? (e : Expr) : GoalM (Option Expr) := do
209+
if let some (e', _) := (← get').termMapInv.find? { expr := e } then
210+
return some e'
211+
let_expr NatCast.natCast _ _ a := e | return none
212+
if (← alreadyInternalized a) then
213+
updateTermMap a e (← mkEqRefl e)
214+
return some a
215+
else
216+
return none
196217

197218
def internalizeCnstr (e : Expr) (kind : CnstrKind) (lhs rhs : Expr) : OrderM Unit := do
198219
let some c ← mkCnstr? e kind lhs rhs | return ()
@@ -267,6 +288,7 @@ def toOffsetTerm? (e : Expr) : OrderM (Option OffsetTermResult) := do
267288

268289
def internalizeTerm (e : Expr) : OrderM Unit := do
269290
let some r ← toOffsetTerm? e | return ()
291+
if e == r.a && r.k == 0 then return ()
270292
let x ← mkNode e
271293
let y ← mkNode r.a
272294
let h₁ ← mkOrdRingPrefix ``Grind.Order.le_of_offset_eq_1_k
@@ -276,12 +298,6 @@ def internalizeTerm (e : Expr) : OrderM Unit := do
276298
let h₂ := mkApp4 h₂ e r.a (toExpr r.k) r.h
277299
addEdge y x { k := -r.k } h₂
278300

279-
def updateTermMap (e eNew h : Expr) : GoalM Unit := do
280-
modify' fun s => { s with
281-
termMap := s.termMap.insert { expr := e } (eNew, h)
282-
termMapInv := s.termMapInv.insert { expr := eNew } (e, h)
283-
}
284-
285301
open Arith.Cutsat in
286302
def adaptNat (e : Expr) : GoalM Expr := do
287303
if let some (eNew, _) := (← get').termMap.find? { expr := e } then
@@ -317,7 +333,7 @@ def adapt (α : Expr) (e : Expr) : GoalM (Expr × Expr) := do
317333
else
318334
return (α, e)
319335

320-
def alreadyInternalized (e : Expr) : OrderM Bool := do
336+
def alreadyInternalizedHere (e : Expr) : OrderM Bool := do
321337
let s ← getStruct
322338
return s.cnstrs.contains { expr := e } || s.nodeMap.contains { expr := e }
323339

@@ -327,7 +343,7 @@ public def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
327343
let (α, e) ← adapt α e
328344
if isForbiddenParent parent? then return ()
329345
if let some structId ← getStructId? α then OrderM.run structId do
330-
if (← alreadyInternalized e) then return ()
346+
if (← alreadyInternalizedHere e) then return ()
331347
match_expr e with
332348
| LE.le _ _ lhs rhs => internalizeCnstr e .le lhs rhs
333349
| LT.lt _ _ lhs rhs => if (← hasLt) then internalizeCnstr e .lt lhs rhs

tests/lean/run/grind_dep_match_overlap.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,9 @@ example (a b : Vec α 2) : h a b = 20 := by
1818

1919
example (a b : Vec α 2) : h a b = 20 := by
2020
grind (splits := 4) [h.eq_def, Vec]
21+
22+
example (a b : Vec α 2) : h a b = 20 := by
23+
grind -offset [h.eq_def, Vec]
24+
25+
example (a b : Vec α 2) : h a b = 20 := by
26+
grind -offset (splits := 4) [h.eq_def, Vec]

tests/lean/run/grind_fun_singleton.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ example [Inhabited α] : ((fun (_ : α) => x = a + 1) = fun (_ : α) => True)
3939

4040
example : c = 5 → ((fun (_ : Nat × Nat) => { down := a + c = b + 5 : ULift Prop }) = fun (_ : Nat × Nat) => { down := c < 10 : ULift Prop }) → a = b := by
4141
grind
42+
43+
example : c = 5 → ((fun (_ : Nat × Nat) => { down := a + c = b + 5 : ULift Prop }) = fun (_ : Nat × Nat) => { down := c < 10 : ULift Prop }) → a = b := by
44+
grind -offset

tests/lean/run/grind_match_eq_propagation.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def h (v w : Vec α n) : Nat :=
7373
example : h a b > 0 := by
7474
grind [h.eq_def]
7575

76+
example : h a b > 0 := by
77+
grind -offset [h.eq_def]
78+
7679
-- TODO: introduce casts while instantiating equation theorems for `h.match_1`
7780
-- example (a b : Vec α 2) : h a b = 20 := by
7881
-- unfold h

tests/lean/run/grind_order_eq.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,15 @@ example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ c → c ≤ a
1818

1919
example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by
2020
grind -offset -mbtc -lia -linarith (splits := 0)
21+
22+
example
23+
: 2*n_1 + a = 12*n_1 + a = n_2 + 1 → n = 1 → n = n_3 + 1 → n_2 ≠ n_3 → False := by
24+
grind -lia -linarith -offset -ring (splits := 0)
25+
26+
example
27+
: a = b → a ≤ b + 1 := by
28+
grind -lia -linarith -offset -ring (splits := 0) only
29+
30+
example
31+
: a = b + 1 → a ≤ b + 2 := by
32+
grind -lia -linarith -offset -ring (splits := 0) only

tests/lean/run/grind_sort_eqc.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ h_2 : ¬f (f x) = g x x
7272
[eqc] {z, g y z}
7373
[eqc] {0, f x + -1 * f (f x) + -1, f (f x) + -1 * f (f (f x)) + -1, f (f (f x)) + -1 * f (f (f (f x))) + -1}
7474
[eqc] {g z y, g y x}
75+
[eqc] {f x + -1 * f (f x), f (f x) + -1 * f (f (f x)), f (f (f x)) + -1 * f (f (f (f x)))}
7576
[ematch] E-matching patterns
7677
[thm] feq: [@f #4 #3 #0]
7778
[thm] geq: [@g #2 #1 #0 #0]

0 commit comments

Comments
 (0)