Skip to content

Commit 56f3ca6

Browse files
authored
fix: propagation in grind order (#10877)
This PR fixes theory propagation issue in `grind order`.
1 parent 94cb32b commit 56f3ca6

File tree

3 files changed

+158
-11
lines changed

3 files changed

+158
-11
lines changed

src/Lean/Meta/Tactic/Grind/Order/Assert.lean

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,23 @@ def propagatePending : OrderM Unit := do
111111
let h ← mkEqProofOfLeOfLe ue ve huv hvu
112112
pushEq ue ve h
113113

114+
/--
115+
Returns `true` if `e` is already `True` in the `grind` core.
116+
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
117+
-/
118+
private def isAlreadyTrue (e : Expr) : OrderM Bool := do
119+
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
120+
alreadyInternalized e' <&&> isEqTrue e'
121+
else
122+
alreadyInternalized e <&&> isEqTrue e
123+
114124
/--
115125
Given `e` represented by constraint `c` (from `u` to `v`).
116126
Checks whether `e = True` can be propagated using the path `u --(k)--> v`.
117127
If it can, adds a new entry to propagation list.
118128
-/
119129
def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
120-
if (← alreadyInternalized e <&&> isEqTrue e) then return true
130+
if (← isAlreadyTrue e) then return true
121131
let k' := c.getWeight
122132
trace[grind.debug.order.check_eq_true] "{← getExpr u}, {← getExpr v}, {k}, {k'}, {← c.pp}"
123133
if k ≤ k' then
@@ -126,13 +136,23 @@ def checkEqTrue (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : Orde
126136
else
127137
return false
128138

139+
/--
140+
Returns `true` if `e` is already `False` in the `grind` core.
141+
Recall that `e` may be an auxiliary term created for a term `e'` (see `cnstrsMapInv`).
142+
-/
143+
private def isAlreadyFalse (e : Expr) : OrderM Bool := do
144+
if let some (e', _) := (← get').cnstrsMapInv.find? { expr := e } then
145+
alreadyInternalized e' <&&> isEqFalse e'
146+
else
147+
alreadyInternalized e <&&> isEqFalse e
148+
129149
/--
130150
Given `e` represented by constraint `c` (from `v` to `u`).
131151
Checks whether `e = False` can be propagated using the path `u --(k)--> v`.
132152
If it can, adds a new entry to propagation list.
133153
-/
134154
def checkEqFalse (u v : NodeId) (k : Weight) (c : Cnstr NodeId) (e : Expr) : OrderM Bool := do
135-
if (← alreadyInternalized e <&&> isEqFalse e) then return true
155+
if (← isAlreadyFalse e) then return true
136156
let k' := c.getWeight
137157
trace[grind.debug.order.check_eq_false] "{← getExpr u}, {← getExpr v}, {k}, {k'} {← c.pp}"
138158
if (k + k').isNeg then
@@ -168,8 +188,8 @@ def checkEq (u v : NodeId) (k : Weight) : OrderM Unit := do
168188

169189
/-- Finds constrains and equalities to be propagated. -/
170190
def checkToPropagate (u v : NodeId) (k : Weight) : OrderM Unit := do
171-
updateCnstrsOf u v fun c e => return !(← checkEqTrue u v k c e)
172-
updateCnstrsOf v u fun c e => return !(← checkEqFalse u v k c e)
191+
updateCnstrsOf u v fun c e => checkEqTrue u v k c e
192+
updateCnstrsOf v u fun c e => checkEqFalse u v k c e
173193
checkEq u v k
174194

175195
/--

tests/lean/run/grind_indexmap_trace.lean

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
239239
/--
240240
info: Try this:
241241
[apply]
242-
instantiate only [= getElem_def, insert]
242+
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
243243
instantiate only [= getElem?_neg, = getElem?_pos]
244244
cases #f590
245245
next =>
@@ -249,8 +249,7 @@ info: Try this:
249249
instantiate only [= Array.getElem_set]
250250
next =>
251251
instantiate only
252-
instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push,
253-
= HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push]
252+
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, = HashMap.mem_insert]
254253
next =>
255254
instantiate only [= getElem_def, = mem_indices_of_mem]
256255
instantiate only [usr getElem_indices_lt]
@@ -272,7 +271,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
272271
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
273272
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
274273
grind =>
275-
instantiate only [= getElem_def, insert]
274+
-- **TODO**: Check approx here
275+
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
276276
instantiate only [= getElem?_neg, = getElem?_pos]
277277
cases #f590
278278
next =>
@@ -282,9 +282,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
282282
instantiate only [= Array.getElem_set]
283283
next =>
284284
instantiate only
285-
-- **TODO**: Investigate why we need `approx` here
286-
instantiate approx [= HashMap.getElem_insert, = Array.size_push, size, = Array.getElem_push,
287-
= HashMap.contains_insert, = HashMap.mem_insert, = Array.size_push]
285+
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert,
286+
= HashMap.mem_insert]
288287
next =>
289288
instantiate only [= getElem_def, = mem_indices_of_mem]
290289
instantiate only [usr getElem_indices_lt]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import Std.Data.HashMap
2+
set_option warn.sorry false
3+
macro_rules | `(tactic| get_elem_tactic_extensible) => `(tactic| grind)
4+
5+
open Std
6+
7+
structure IndexMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
8+
private indices : HashMap α Nat
9+
private keys : Array α
10+
private values : Array β
11+
private size_keys' : keys.size = values.size := by grind
12+
private WF : ∀ (i : Nat) (a : α), keys[i]? = some a ↔ indices[a]? = some i := by grind
13+
14+
namespace IndexMap
15+
16+
variable {α : Type u} {β : Type v} [BEq α] [Hashable α]
17+
variable {m : IndexMap α β} {a : α} {b : β} {i : Nat}
18+
19+
@[inline] def size (m : IndexMap α β) : Nat :=
20+
m.values.size
21+
22+
@[local grind =] private theorem size_keys : m.keys.size = m.size := m.size_keys'
23+
24+
def emptyWithCapacity (capacity := 8) : IndexMap α β where
25+
indices := HashMap.emptyWithCapacity capacity
26+
keys := Array.emptyWithCapacity capacity
27+
values := Array.emptyWithCapacity capacity
28+
29+
instance : EmptyCollection (IndexMap α β) where
30+
emptyCollection := emptyWithCapacity
31+
32+
instance : Inhabited (IndexMap α β) where
33+
default := ∅
34+
35+
@[inline] def contains (m : IndexMap α β)
36+
(a : α) : Bool :=
37+
m.indices.contains a
38+
39+
instance : Membership α (IndexMap α β) where
40+
mem m a := a ∈ m.indices
41+
42+
instance {m : IndexMap α β} {a : α} : Decidable (a ∈ m) :=
43+
inferInstanceAs (Decidable (a ∈ m.indices))
44+
45+
@[local grind =] private theorem mem_indices_of_mem {m : IndexMap α β} {a : α} :
46+
a ∈ m ↔ a ∈ m.indices := Iff.rfl
47+
48+
@[inline] def findIdx? (m : IndexMap α β) (a : α) : Option Nat := m.indices[a]?
49+
50+
@[inline] def findIdx (m : IndexMap α β) (a : α) (h : a ∈ m := by get_elem_tactic) : Nat := m.indices[a]
51+
52+
@[inline] def getIdx? (m : IndexMap α β) (i : Nat) : Option β := m.values[i]?
53+
54+
@[inline] def getIdx (m : IndexMap α β) (i : Nat) (h : i < m.size := by get_elem_tactic) : β :=
55+
m.values[i]
56+
57+
variable [LawfulBEq α] [LawfulHashable α]
58+
59+
attribute [local grind _=_] IndexMap.WF
60+
61+
private theorem getElem_indices_lt {h : a ∈ m} : m.indices[a] < m.size := by
62+
have : m.indices[a]? = some m.indices[a] := by grind
63+
grind
64+
65+
grind_pattern getElem_indices_lt => m.indices[a]
66+
67+
attribute [local grind] size
68+
69+
instance : GetElem? (IndexMap α β) α β (fun m a => a ∈ m) where
70+
getElem m a h := m.values[m.indices[a]'h]
71+
getElem? m a := m.indices[a]?.bind (fun i => (m.values[i]?))
72+
getElem! m a := m.indices[a]?.bind (fun i => (m.values[i]?)) |>.getD default
73+
74+
@[local grind =] private theorem getElem_def (m : IndexMap α β) (a : α) (h : a ∈ m) : m[a] = m.values[m.indices[a]'h] := rfl
75+
@[local grind =] private theorem getElem?_def (m : IndexMap α β) (a : α) :
76+
m[a]? = m.indices[a]?.bind (fun i => (m.values[i]?)) := rfl
77+
@[local grind =] private theorem getElem!_def [Inhabited β] (m : IndexMap α β) (a : α) :
78+
m[a]! = (m.indices[a]?.bind (fun i => (m.values[i]?))).getD default := rfl
79+
80+
instance : LawfulGetElem (IndexMap α β) α β (fun m a => a ∈ m) where
81+
getElem?_def := by grind
82+
getElem!_def := by grind
83+
84+
@[inline] def insert [LawfulBEq α] (m : IndexMap α β) (a : α) (b : β) :
85+
IndexMap α β :=
86+
match h : m.indices[a]? with
87+
| some i =>
88+
{ indices := m.indices
89+
keys := m.keys.set i a
90+
values := m.values.set i b }
91+
| none =>
92+
{ indices := m.indices.insert a m.size
93+
keys := m.keys.push a
94+
values := m.values.push b }
95+
96+
/-! ### Verification theorems -/
97+
98+
attribute [local grind] getIdx findIdx insert
99+
100+
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
101+
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
102+
grind -offset -ring -linarith -cutsat =>
103+
instantiate only [= getElem_def, insert]
104+
cases #f590
105+
next =>
106+
cases #ffdf
107+
next => sorry
108+
next =>
109+
instantiate only
110+
instantiate only [= HashMap.getElem_insert]
111+
instantiate only [= size]
112+
instantiate only [= Array.getElem_push, = mem_indices_of_mem]
113+
next => sorry
114+
115+
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
116+
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
117+
grind -offset -ring -linarith -cutsat =>
118+
instantiate only [= getElem_def, insert]
119+
cases #f590
120+
next =>
121+
cases #ffdf
122+
next => sorry
123+
next =>
124+
instantiate only
125+
instantiate only [= HashMap.getElem_insert]
126+
instantiate only [= size]
127+
instantiate only [= mem_indices_of_mem, = Array.getElem_push]
128+
next => sorry

0 commit comments

Comments
 (0)