Skip to content

Commit a2b50ac

Browse files
authored
feat: generate instantiate only [...] at finish? (#10841)
This PR improves the `grind` tactic generated by the `instantiate` action in tracing mode. It also updates the syntax for the `instantiate` tactic, making it similar to `simp`. For example: * `instantiate only [thm1, thm2]` instantiates only theorems `thm1` and `thm2`. * `instantiate [thm1, thm2]` instantiates theorems marked with the `@[grind]` attribute **and** theorems `thm1` and `thm2`. The action produces `instantiate only [...]` tactics. Example: ```lean /-- info: Try this: [apply] ⏎ instantiate only [= Array.getElem_set] instantiate only [= Array.getElem_set] -/ #guard_msgs in example (as bs cs : Array α) (v₁ v₂ : α) (i₁ i₂ j : Nat) (h₁ : i₁ < as.size) (h₂ : bs = as.set i₁ v₁) (h₃ : i₂ < bs.size) (h₄ : cs = bs.set i₂ v₂) (h₅ : i₁ ≠ j ∧ i₂ ≠ j) (h₆ : j < cs.size) (h₇ : j < as.size) : cs[j] = as[j] := by grind => finish? ``` Recall that `finish?` replays generated tactics before suggesting them. The `instantiate` action inspects the generated proof term to decide which theorems to include as parameters in the `instantiate only [...]` tactic. However, in some cases, a theorem contributes only by adding a term to the state. In such cases, the generated tactic cannot be fully replayed, and the action uses `instantiate approx [<thms instantiated>]` to indicate which parts of the tactic script are approximate. The `approx` is just a hint for users.
1 parent 61ee3b2 commit a2b50ac

File tree

13 files changed

+330
-98
lines changed

13 files changed

+330
-98
lines changed

src/Init/Grind/Interactive.lean

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ syntax (name := «sorry») "sorry" : grind
6666
syntax anchor := "#" noWs hexnum
6767
syntax thm := anchor <|> grindLemma <|> grindLemmaMin
6868

69-
/-- Instantiates theorems using E-matching. -/
70-
syntax (name := instantiate) "instantiate" (colGt thm),* : grind
69+
/--
70+
Instantiates theorems using E-matching.
71+
The `approx` modifier is just a marker for users to easily identify automatically generated `instantiate` tactics
72+
that may have redundant arguments.
73+
-/
74+
syntax (name := instantiate) "instantiate" (&" only")? (&" approx")? (" [" withoutPosition(thm,*,?) "]")? : grind
7175

7276
-- **Note**: Should we rename the following tactics to `trace_`?
7377
/-- Shows asserted facts. -/

src/Lean/Elab/Tactic/Grind/Basic.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def mkEvalTactic' (elaborator : Name) (params : Params) : TermElabM (Goal → TS
354354
-- **Note**: we discard changes to `Term.State`
355355
let (subgoals, grindState') ← Term.TermElabM.run' (ctx := termCtx) (s := termState) do
356356
let (_, s) ← GrindTacticM.run
357-
(ctx := { methods, ctx := grindCtx, params, elaborator })
357+
(ctx := { recover := false, methods, ctx := grindCtx, params, elaborator })
358358
(s := { state := grindState, goals := [goal] }) do
359359
evalGrindTactic stx.raw
360360
pruneSolvedGoals

src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def logTheoremAnchor (proof : Expr) : TermElabM Unit := do
123123
let stx ← getRef
124124
Term.addTermInfo' stx proof
125125

126-
def ematchThms (thms : Array EMatchTheorem) : GrindTacticM Unit := do
127-
let progress ← liftGoalM <| if thms.isEmpty then ematch else ematchTheorems thms
126+
def ematchThms (only : Bool) (thms : Array EMatchTheorem) : GrindTacticM Unit := do
127+
let progress ← liftGoalM <| if only then ematchOnly thms else ematch thms
128128
unless progress do
129129
throwError "`instantiate` tactic failed to instantiate new facts, use `show_patterns` to see active theorems and their patterns."
130130
let goal ← getMainGoal
@@ -142,15 +142,19 @@ def elabAnchor (anchor : TSyntax `hexnum) : CoreM (Nat × UInt64) := do
142142
return (numDigits, val)
143143

144144
@[builtin_grind_tactic instantiate] def evalInstantiate : GrindTactic := fun stx => withMainContext do
145-
let `(grind| instantiate $[$thmRefs:thm],*) := stx | throwUnsupportedSyntax
146-
let mut thms := #[]
147-
for thmRef in thmRefs do
148-
match thmRef with
149-
| `(Parser.Tactic.Grind.thm| #$anchor:hexnum) => thms := thms ++ (← withRef thmRef <| elabLocalEMatchTheorem anchor)
150-
| `(Parser.Tactic.Grind.thm| $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id false)
151-
| `(Parser.Tactic.Grind.thm| ! $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id true)
152-
| _ => throwErrorAt thmRef "unexpected theorem reference"
153-
ematchThms thms
145+
let `(grind| instantiate $[ only%$only ]? $[ approx ]? $[ [ $[$thmRefs?:thm],* ] ]?) := stx | throwUnsupportedSyntax
146+
let goal ← getMainGoal
147+
let only := only.isSome
148+
let initThms ← if only then goal.getActiveMatchEqTheorems else pure #[]
149+
let mut thms := initThms
150+
if let some thmRefs := thmRefs? then
151+
for thmRef in thmRefs do
152+
match thmRef with
153+
| `(Parser.Tactic.Grind.thm| #$anchor:hexnum) => thms := thms ++ (← withRef thmRef <| elabLocalEMatchTheorem anchor)
154+
| `(Parser.Tactic.Grind.thm| $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id false)
155+
| `(Parser.Tactic.Grind.thm| ! $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id true)
156+
| _ => throwErrorAt thmRef "unexpected theorem reference"
157+
ematchThms only thms
154158
where
155159
collectThms (numDigits : Nat) (anchorPrefix : UInt64) (thms : PArray EMatchTheorem) : StateT (Array EMatchTheorem) GrindTacticM Unit := do
156160
let mut found : Std.HashSet Expr := {}

src/Lean/Elab/Tactic/Grind/Main.lean

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public import Lean.Meta.Tactic.TryThis
1010
public import Lean.Elab.Command
1111
public import Lean.Elab.Tactic.Config
1212
import Lean.Meta.Tactic.Grind.SimpUtil
13+
import Lean.Meta.Tactic.Grind.EMatchTheoremParam
1314
import Lean.Elab.Tactic.Grind.Basic
1415
import Lean.Elab.MutualDef
1516
meta import Lean.Meta.Tactic.Grind.Parser
@@ -292,37 +293,13 @@ def mkGrindOnly
292293
let mut foundFns : NameSet := {}
293294
for { origin, kind, minIndexable } in trace.thms.toList do
294295
if let .decl declName := origin then
295-
if let some declName ← isEqnThm? declName then
296-
unless foundFns.contains declName do
297-
foundFns := foundFns.insert declName
298-
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
299-
let param ← `(Parser.Tactic.grindParam| $decl:ident)
296+
if let some fnName ← isEqnThm? declName then
297+
unless foundFns.contains fnName do
298+
foundFns := foundFns.insert fnName
299+
let param ← Grind.globalDeclToGrindParamSyntax declName kind minIndexable
300300
params := params.push param
301301
else
302-
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
303-
let param ← match kind, minIndexable with
304-
| .eqLhs false, _ => `(Parser.Tactic.grindParam| = $decl:ident)
305-
| .eqLhs true, _ => `(Parser.Tactic.grindParam| = gen $decl:ident)
306-
| .eqRhs false, _ => `(Parser.Tactic.grindParam| =_ $decl:ident)
307-
| .eqRhs true, _ => `(Parser.Tactic.grindParam| =_ gen $decl:ident)
308-
| .eqBoth false, _ => `(Parser.Tactic.grindParam| _=_ $decl:ident)
309-
| .eqBoth true, _ => `(Parser.Tactic.grindParam| _=_ gen $decl:ident)
310-
| .eqBwd, _ => `(Parser.Tactic.grindParam| ←= $decl:ident)
311-
| .user, _ => `(Parser.Tactic.grindParam| usr $decl:ident)
312-
| .bwd false, false => `(Parser.Tactic.grindParam| ← $decl:ident)
313-
| .bwd true, false => `(Parser.Tactic.grindParam| ← gen $decl:ident)
314-
| .fwd, false => `(Parser.Tactic.grindParam| → $decl:ident)
315-
| .leftRight, false => `(Parser.Tactic.grindParam| => $decl:ident)
316-
| .rightLeft, false => `(Parser.Tactic.grindParam| <= $decl:ident)
317-
| .default false, false => `(Parser.Tactic.grindParam| $decl:ident)
318-
| .default true, false => `(Parser.Tactic.grindParam| gen $decl:ident)
319-
| .bwd false, true => `(Parser.Tactic.grindParam| ! ← $decl:ident)
320-
| .bwd true, true => `(Parser.Tactic.grindParam| ! ← gen $decl:ident)
321-
| .fwd, true => `(Parser.Tactic.grindParam| ! → $decl:ident)
322-
| .leftRight, true => `(Parser.Tactic.grindParam| ! => $decl:ident)
323-
| .rightLeft, true => `(Parser.Tactic.grindParam| ! <= $decl:ident)
324-
| .default false, true => `(Parser.Tactic.grindParam| ! $decl:ident)
325-
| .default true, true => `(Parser.Tactic.grindParam| ! gen $decl:ident)
302+
let param ← Grind.globalDeclToGrindParamSyntax declName kind minIndexable
326303
params := params.push param
327304
for declName in trace.eagerCases.toList do
328305
unless Grind.isBuiltinEagerCases declName do

src/Lean/Elab/Tactic/Grind/ShowState.lean

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ prelude
88
public import Lean.Elab.Tactic.Grind.Basic
99
public import Lean.Elab.Tactic.Grind.Filter
1010
import Lean.Meta.Tactic.Grind.PP
11+
import Lean.Meta.Tactic.Grind.EMatchTheoremParam
1112
import Lean.Meta.Tactic.Grind.Anchor
1213
import Lean.Meta.Tactic.Grind.Split
1314
namespace Lean.Elab.Tactic.Grind
@@ -118,30 +119,11 @@ public def showState (filter : Filter) (isSilent := false) : GrindTacticM Unit :
118119

119120
@[builtin_grind_tactic showLocalThms] def evalShowLocalThms : GrindTactic := fun _ => withMainContext do
120121
let goal ← getMainGoal
121-
let entries ← liftGrindM do
122-
let (found, entries) ← go {} {} goal.ematch.thms
123-
let (_, entries) ← go found entries goal.ematch.newThms
124-
pure entries
122+
let entries ← liftGrindM <| getLocalTheoremAnchors goal
125123
let numDigits := getNumDigitsForAnchors entries
126124
let msgs := entries.map fun { anchor, e } =>
127125
.trace { cls := `thm } m!"#{anchorToString numDigits anchor} := {e}" #[]
128126
let msg := MessageData.trace { cls := `thms, collapsed := false } "Local theorems" msgs
129127
logInfo msg
130-
where
131-
go (found : Std.HashSet Grind.Origin) (result : Array ExprWithAnchor) (thms : PArray EMatchTheorem)
132-
: GrindM (Std.HashSet Grind.Origin × Array ExprWithAnchor) := do
133-
let mut found := found
134-
let mut result := result
135-
for thm in thms do
136-
-- **Note**: We only display local theorems
137-
if thm.origin matches .local _ | .fvar _ then
138-
unless found.contains thm.origin do
139-
found := found.insert thm.origin
140-
let type ← inferType thm.proof
141-
-- **Note**: Evaluate how stable these anchors are.
142-
let anchor ← getAnchor type
143-
result := result.push { anchor, e := type }
144-
pure ()
145-
return (found, result)
146128

147129
end Lean.Elab.Tactic.Grind

src/Lean/Meta/Tactic/Grind.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public import Lean.Meta.Tactic.Grind.PropagateInj
4343
public import Lean.Meta.Tactic.Grind.Order
4444
public import Lean.Meta.Tactic.Grind.Anchor
4545
public import Lean.Meta.Tactic.Grind.Action
46+
public import Lean.Meta.Tactic.Grind.EMatchTheoremParam
4647
public import Lean.Meta.Tactic.Grind.EMatchAction
4748
public section
4849
namespace Lean

src/Lean/Meta/Tactic/Grind/Action.lean

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,20 @@ def mkGrindNext (s : List TGrind) : CoreM TGrind := do
163163
let s := mkGrindSeq s
164164
`(grind| next => $s:grindSeq)
165165

166+
/--
167+
Given `[t₁, ..., tₙ]`, returns
168+
```
169+
(t₁
170+
...
171+
tₙ)
172+
```
173+
If the list is empty, it returns `(skip)`.
174+
-/
175+
private def mkGrindParen (s : List TGrind) : CoreM TGrind := do
176+
let s ← if s == [] then pure [← `(grind| skip)] else pure s
177+
let s := mkGrindSeq s
178+
`(grind| ($s:grindSeq))
179+
166180
/--
167181
If tracing is enabled and continuation produced `.closed [t₁, ..., tₙ]`,
168182
returns the singleton sequence `[t]` where `t` is
@@ -251,21 +265,39 @@ def solverAction (check : GoalM CheckResult) (mkTac : GrindM (TSyntax `grind)) :
251265
concatTactic (← kp goal') mkTac
252266
| .closed => closeWith mkTac
253267

268+
def saveStateIfTracing : GrindM (Option SavedState) := do
269+
if (← getConfig).trace then
270+
return some (← saveState)
271+
else
272+
return none
273+
/--
274+
Returns `true` if the tactic sequence `seq` closes `goal` starting at saved state `s?`.
275+
If `s?` is `none` just returns `true`.
276+
-/
277+
def checkSeqAt (s? : Option SavedState) (goal : Goal) (seq : List TGrind) : GrindM Bool := do
278+
let some s := s? | return true
279+
let tac ← mkGrindParen seq
280+
Lean.withoutModifyingState do
281+
s.restore
282+
-- **Note**: Ensure tracing is disabled.
283+
withTheReader Grind.Context (fun ctx => { ctx with config.trace := false }) do
284+
try
285+
let subgoals ← evalTactic goal tac
286+
return subgoals.isEmpty
287+
catch _ =>
288+
return false
289+
254290
/--
255291
Helper action that checks whether the resulting tactic script produced by its continuation
256292
can close the original goal.
257293
-/
258294
def checkTactic : Action := fun goal _ kp => do
259-
let s ← saveState
295+
let s ← saveStateIfTracing
260296
let r ← kp goal
261297
match r with
262298
| .closed seq =>
263-
let tac ← mkGrindNext seq
264-
Lean.withoutModifyingState do
265-
s.restore
266-
let subgoals ← evalTactic goal tac
267-
unless subgoals.isEmpty do
268-
throwError "generated tactic cannot close the goal{indentD tac}\nInitial goal\n{goal.mvarId}\nPending subgoals\n{subgoals.map (·.mvarId)}"
299+
unless (← checkSeqAt s goal seq) do
300+
throwError "generated tactic cannot close the goal{indentD (← mkGrindNext seq)}\nInitial goal\n{goal.mvarId}"
269301
return r
270302
| _ => return r
271303

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,44 @@ structure Context where
6767
initApp : Expr := default
6868
deriving Inhabited
6969

70+
/--
71+
A mapping `uniqueId ↦ thm`, where `uniqueId` is an auxiliary marker used to wrap a theorem instantiation proof of `thm`
72+
using a `Expr.mdata`. The `uniqueId`s are created using `mkFreshId`.
73+
-/
74+
abbrev InstanceMap := Std.HashMap Name EMatchTheorem
75+
76+
private def thmInstanceKey := `_grind_thm_instance
77+
78+
private def markTheoremInstanceProof (proof : Expr) (uniqueId : Name) : Expr :=
79+
Expr.mdata (KVMap.empty.insert thmInstanceKey uniqueId) proof
80+
81+
/-- Returns `some uniqueId` if `proof` was marked using `markTheoremInstanceProof` -/
82+
def isTheoremInstanceProof? (proof : Expr) : Option Name :=
83+
match proof with
84+
| .mdata d _ =>
85+
match d.find thmInstanceKey with
86+
| some (DataValue.ofName uniqueId) => some uniqueId
87+
| _ => none
88+
| _ => none
89+
7090
/-- State for the E-matching monad -/
7191
structure SearchState where
7292
/-- Choices that still have to be processed. -/
7393
choiceStack : List Choice := []
94+
/--
95+
When tracing is enabled track instances here. See comment at `InstanceMap`
96+
-/
97+
instanceMap : InstanceMap := {}
7498
deriving Inhabited
7599

76100
abbrev M := ReaderT Context $ StateRefT SearchState GoalM
77101

78102
def M.run' (x : M α) : GoalM α :=
79103
x {} |>.run' {}
80104

105+
def M.run (x : M α) : GoalM (α × SearchState) :=
106+
x {} |>.run {}
107+
81108
@[inline] private abbrev withInitApp (e : Expr) (x : M α) : M α :=
82109
withReader (fun ctx => { ctx with initApp := e }) x
83110

@@ -453,6 +480,10 @@ where
453480
go (proof prop : Expr) : M Unit := do
454481
let mut proof := proof
455482
let mut prop := prop
483+
if (← getConfig).trace then
484+
let uniqueId ← mkFreshId
485+
proof := markTheoremInstanceProof proof uniqueId
486+
modify fun s => { s with instanceMap := s.instanceMap.insert uniqueId thm }
456487
if (← isMatchEqLikeDeclName thm.origin.key) then
457488
prop ← annotateMatchEqnType prop (← read).initApp
458489
-- We must add a hint here because `annotateMatchEqnType` introduces `simpMatchDiscrsOnly` and
@@ -706,29 +737,41 @@ end EMatch
706737
open EMatch
707738

708739
/-- Performs one round of E-matching, and returns new instances. -/
709-
private def ematchCore : GoalM Unit := do profileitM Exception "grind ematch" (← getOptions) do
740+
private def ematchCore (extraThms : Array EMatchTheorem) : GoalM InstanceMap := do profileitM Exception "grind ematch" (← getOptions) do
710741
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
711742
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
712-
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
743+
withReader (fun ctx => { ctx with useMT := false }) do
744+
ematchTheorems newThms
745+
extraThms.forM ematchTheorem
713746
if (← checkMaxInstancesExceeded <||> checkMaxEmatchExceeded) then
714-
return ()
747+
return {}
715748
else
716-
go (← get).ematch.thms (← get).ematch.newThms |>.run'
749+
let (_, s) ← go (← get).ematch.thms (← get).ematch.newThms |>.run
717750
modify fun s => { s with
718751
ematch.thms := s.ematch.thms ++ s.ematch.newThms
719752
ematch.newThms := {}
720753
ematch.gmt := s.ematch.gmt + 1
721754
ematch.num := s.ematch.num + 1
722755
}
756+
return s.instanceMap
723757

724-
/-- Performs one round of E-matching, and returns `true` if new instances were generated. -/
725-
def ematch : GoalM Bool := do
758+
/--
759+
Performs one round of E-matching, and returns `true` if new instances were generated.
760+
Recall that the mapping is nonempty only if tracing is enabled.
761+
-/
762+
def ematch' (extraThms : Array EMatchTheorem := #[]) : GoalM (Bool × InstanceMap) := do
726763
let numInstances := (← get).ematch.numInstances
727-
ematchCore
728-
return (← get).ematch.numInstances != numInstances
764+
let map ← ematchCore extraThms
765+
return ((← get).ematch.numInstances != numInstances, map)
766+
767+
/--
768+
Performs one round of E-matching, and returns `true` if new instances were generated.
769+
-/
770+
def ematch (extraThms : Array EMatchTheorem := #[]) : GoalM Bool :=
771+
return (← ematch' extraThms).1
729772

730773
/-- Performs one round of E-matching using the giving theorems, and returns `true` if new instances were generated. -/
731-
def ematchTheorems (thms : Array EMatchTheorem) : GoalM Bool := do
774+
def ematchOnly (thms : Array EMatchTheorem) : GoalM Bool := do
732775
let numInstances := (← get).ematch.numInstances
733776
go |>.run'
734777
return (← get).ematch.numInstances != numInstances

0 commit comments

Comments
 (0)