Skip to content

Commit 63c0672

Browse files
authored
feat: preserve instantiation order at finish? (#10899)
This PR ensures the generated `instantiate` tactic instantiates the theorems using the same order used by `finish?`
1 parent b5dc11e commit 63c0672

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

src/Lean/Meta/Tactic/Grind/EMatchAction.lean

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,31 @@ import Lean.Meta.Tactic.Grind.EMatchTheoremParam
1212
import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
1313
namespace Lean.Meta.Grind.Action
1414

15+
/-
16+
**Note**: The unique IDs created to instantiate theorems have the form `<prefix>.<num>`,
17+
where `<num>` corresponds to the instantiation order within a particular proof branch.
18+
Thus, by sorting the collected theorems using their corresponding unique IDs,
19+
we can construct an `instantiate` tactic that performs the instantiations using
20+
the original order.
21+
22+
**Note**: It is unclear at this point whether this is a good strategy or not.
23+
The order in which things are asserted affects the proof found by `grind`.
24+
Thus, preserving the original order should intuitively help ensure that the generated
25+
tactic script for the continuation still closes the goal when combined with the
26+
generated `instantiate` tactic. However, it does not guarantee that the
27+
script can be successfully replayed, since we are filtering out instantiations that do
28+
not appear in the final proof term. Recall that a theorem instance may
29+
contribute to the proof search even if it does not appear in the final proof term.
30+
-/
31+
1532
structure CollectState where
1633
visited : Std.HashSet ExprPtr := {}
1734
collectedThms : Std.HashSet (Origin × EMatchTheoremKind) := {}
18-
thms : Array EMatchTheorem := #[]
35+
idAndThms : Array (Name × EMatchTheorem) := #[]
1936

20-
def collect (e : Expr) (map : EMatch.InstanceMap) : Array EMatchTheorem :=
37+
def collect (e : Expr) (map : EMatch.InstanceMap) : Array (Name × EMatchTheorem) :=
2138
let (_, s) := go e |>.run {}
22-
s.thms
39+
s.idAndThms
2340
where
2441
go (e : Expr) : StateM CollectState Unit := do
2542
if isMarkedSubsingletonApp e then
@@ -35,7 +52,7 @@ where
3552
if let some thm := map[uniqueId]? then
3653
let key := (thm.origin, thm.kind)
3754
unless (← get).collectedThms.contains key do
38-
modify fun s => { s with collectedThms := s.collectedThms.insert key, thms := s.thms.push thm }
55+
modify fun s => { s with collectedThms := s.collectedThms.insert key, idAndThms := s.idAndThms.push (uniqueId, thm) }
3956
match e with
4057
| .lam _ d b _
4158
| .forallE _ d b _ => go d; go b
@@ -93,7 +110,10 @@ public def instantiate' : Action := fun goal kna kp => do
93110
| .closed seq =>
94111
if (← getConfig).trace then
95112
let proof ← instantiateMVars (mkMVar goal.mvarId)
96-
let usedThms := collect proof map
113+
let usedIdAndThms := collect proof map
114+
-- **Note**: See note above. We want to sort here to reproduce the original instantiation order.
115+
let usedIdAndThms := usedIdAndThms.qsort fun (id₁, _) (id₂, _) => id₁.lt id₂
116+
let usedThms := usedIdAndThms.map (·.2)
97117
let newSeq ← mkNewSeq goal usedThms seq (approx := false)
98118
if (← checkSeqAt saved? goal newSeq) then
99119
return .closed newSeq

tests/lean/run/grind_indexmap_trace.lean

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ example (m : IndexMap α β) (a : α) (h : a ∈ m) :
147147
info: Try this:
148148
[apply]
149149
instantiate only [= mem_indices_of_mem, insert]
150-
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
150+
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
151151
cases #4ed2
152152
next =>
153153
cases #ffdf
@@ -179,7 +179,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
179179
info: Try this:
180180
[apply]
181181
instantiate only [= mem_indices_of_mem, insert]
182-
instantiate only [= getElem?_neg, = getElem?_pos, =_ HashMap.contains_iff_mem]
182+
instantiate only [=_ HashMap.contains_iff_mem, = getElem?_neg, = getElem?_pos]
183183
cases #4ed2
184184
next =>
185185
cases #ffdf
@@ -247,19 +247,19 @@ info: Try this:
247247
instantiate only [= Array.getElem_set]
248248
next =>
249249
instantiate only
250-
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert, = HashMap.mem_insert]
250+
instantiate only [size, = HashMap.mem_insert, = HashMap.getElem_insert, = Array.getElem_push]
251251
next =>
252-
instantiate only [= getElem_def, = mem_indices_of_mem]
252+
instantiate only [= mem_indices_of_mem, = getElem_def]
253253
instantiate only [usr getElem_indices_lt]
254254
instantiate only [size]
255255
cases #ffdf
256256
next =>
257257
instantiate only [=_ WF]
258-
instantiate only [= Array.getElem_set, = getElem?_neg, = getElem?_pos]
258+
instantiate only [= getElem?_neg, = getElem?_pos, = Array.getElem_set]
259259
instantiate only [WF']
260260
next =>
261261
instantiate only
262-
instantiate only [= Array.getElem_push, = HashMap.mem_insert, = HashMap.getElem_insert]
262+
instantiate only [= HashMap.mem_insert, = HashMap.getElem_insert, = Array.getElem_push]
263263
-/
264264
#guard_msgs in
265265
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
@@ -298,8 +298,8 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
298298
/--
299299
info: Try this:
300300
[apply]
301-
instantiate only [insert, = mem_indices_of_mem, findIdx]
302-
instantiate only [= getElem?_pos, = getElem?_neg]
301+
instantiate only [findIdx, insert, = mem_indices_of_mem]
302+
instantiate only [= getElem?_neg, = getElem?_pos]
303303
cases #1bba
304304
next => instantiate only [findIdx]
305305
next =>

0 commit comments

Comments
 (0)