Skip to content

Commit e5a6901

Browse files
authored
feat: Nat equality propagation in grind order (#11049)
This PR implements equality propagation for `Nat` in `grind order`. `grind order` supports offset equalities for rings, but it has an adapter for `Nat`. Example: ```lean example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by grind -offset -mbtc -lia -linarith (splits := 0) ```
1 parent cf871a8 commit e5a6901

File tree

4 files changed

+66
-11
lines changed

4 files changed

+66
-11
lines changed

src/Init/Grind/Order.lean

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ theorem le_of_offset_eq_2_k {α} [LE α] [LT α] [Std.LawfulOrderLT α] [Std.IsP
6565
rw [Ring.intCast_neg, Semiring.add_assoc, Semiring.add_comm (α := α) k, Ring.neg_add_cancel, Semiring.add_zero]
6666
apply Std.IsPreorder.le_refl
6767

68+
theorem nat_eq (a b : Nat) (x y : Int) : NatCast.natCast a = x → NatCast.natCast b = y → x = y → a = b := by
69+
intro _ _; subst x y; intro h
70+
exact Int.natCast_inj.mp h
71+
6872
theorem le_of_not_le {α} [LE α] [Std.IsLinearPreorder α]
6973
{a b : α} : ¬ a ≤ b → b ≤ a := by
7074
intro h

src/Lean/Meta/Tactic/Grind/Order/Assert.lean

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,42 @@ def propagatePending : OrderM Unit := do
129129
| .eq u v =>
130130
let ue ← getExpr u
131131
let ve ← getExpr v
132-
unless (← isEqv ue ve) do
132+
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
133+
unless (← isEqv ue ve) do
134+
let huv ← mkProofForPath u v
135+
let hvu ← mkProofForPath v u
136+
let h ← mkEqProofOfLeOfLe ue ve huv hvu
137+
pushEq ue ve h
138+
-- Checks whether `ue` and `ve` are auxiliary terms
139+
let some (ue', h₁) ← getOriginal? ue | continue
140+
let some (ve', h₂) ← getOriginal? ve | continue
141+
if (← alreadyInternalized ue' <&&> alreadyInternalized ve') then
142+
unless (← isEqv ue' ve') do
133143
let huv ← mkProofForPath u v
134144
let hvu ← mkProofForPath v u
135145
let h ← mkEqProofOfLeOfLe ue ve huv hvu
136-
pushEq ue ve h
146+
/-
147+
We have
148+
- `h₁ : ↑ue' = ue`
149+
- `h₂ : ↑ve' = ve`
150+
- `h : ue = ve`
151+
-/
152+
pushEq ue' ve' <| mkApp7 (mkConst ``Grind.Order.nat_eq) ue' ve' ue ve h₁ h₂ h
153+
where
154+
/--
155+
If `e` is an auxiliary term used to represent some term `a`, returns
156+
`some (a, h)` s.t. `h : ↑a = e`
157+
**Note**: We currently only support `Nat`. Thus `↑a` is actually
158+
`NatCast.natCast a`. If we decide to support arbitrary semirings
159+
in this module, we must adjust this code.
160+
-/
161+
getOriginal? (e : Expr) : GoalM (Option (Expr × Expr)) := do
162+
if let some r := (← get').termMapInv.find? { expr := e } then
163+
return some r
164+
else
165+
let_expr NatCast.natCast _ _ a := e | return none
166+
let h ← mkEqRefl e
167+
return some (a, h)
137168

138169
/--
139170
Returns `true` if `e` is already `True` in the `grind` core.
@@ -190,6 +221,7 @@ Traverses the constraints `c` (representing an expression `e`) s.t.
190221

191222
/-- Equality propagation. -/
192223
def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
224+
if u == v then return ()
193225
if (← isPartialOrder) then
194226
if !k.isZero then return ()
195227
let some k' ← getDist? v u | return ()
@@ -199,6 +231,24 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
199231
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
200232
if (← isEqv ue ve) then return ()
201233
pushToPropagate <| .eq u v
234+
else
235+
/-
236+
Check whether `ue` and `ve` are auxiliary terms used to encode `Nat` terms.
237+
**Note**: `getOriginal?` is currently hard coded to the `Nat` case since
238+
it is the only type we map to rings. If in the future, we want to support
239+
arbitrary `Semiring`s, we must adjust this code.
240+
-/
241+
let some ue ← getOriginal? ue | return ()
242+
let some ve ← getOriginal? ve | return ()
243+
if (← alreadyInternalized ue <&&> alreadyInternalized ve) then
244+
if (← isEqv ue ve) then return ()
245+
pushToPropagate <| .eq u v
246+
where
247+
getOriginal? (e : Expr) : GoalM (Option Expr) := do
248+
let_expr NatCast.natCast _ _ a := e
249+
| let some (a, _) := (← get').termMapInv.find? { expr := e } | return none
250+
return some a
251+
return some a
202252

203253
/-- Finds constrains and equalities to be propagated. -/
204254
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do

tests/lean/run/grind_10885.lean

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
11
example {a b : Nat} (ha : 1 ≤ a) : (a - 1 + 1) * b = a * b := by grind
22

3-
/--
4-
info: Try these:
5-
[apply]
6-
mbtc
7-
cases #9501
8-
[apply] finish only [#9501]
9-
-/
10-
#guard_msgs in
113
example {a b : Nat} (ha : 1 ≤ a) : (a - 1 + 1) * b = a * b := by
12-
grind => finish? -- mbtc was applied consider nonlinear `*`
4+
grind => done

tests/lean/run/grind_order_eq.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,12 @@ example [CommRing α] [LE α] [LT α] [LawfulOrderLT α] [IsPartialOrder α] [Or
99

1010
example (a b : Int) (f : Int → Int) : a ≤ b + 1 → b ≤ a - 1 → f a = f (2 + b - 1) := by
1111
grind -mbtc -lia -linarith (splits := 0)
12+
13+
example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f a = f (1 + b + 0) := by
14+
grind -offset -mbtc -lia -linarith (splits := 0)
15+
16+
example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ c → c ≤ a → f a = f c := by
17+
grind -offset -mbtc -lia -linarith (splits := 0)
18+
19+
example (a b : Nat) (f : Nat → Int) : a ≤ b + 1 → b + 1 ≤ a → f (1 + a) = f (1 + b + 1) := by
20+
grind -offset -mbtc -lia -linarith (splits := 0)

0 commit comments

Comments
 (0)