Skip to content

Commit e2b5747

Browse files
authored
feat: evalTactic in GrindM (#10833)
This PR implements infrastructure for evaluating `grind` tactics in the `GrindM` monad. We are going to use it to check whether auto-generated tactics can effectively close the original goal.
1 parent dad5412 commit e2b5747

File tree

4 files changed

+82
-7
lines changed

4 files changed

+82
-7
lines changed

src/Lean/Elab/Tactic/Grind/Basic.lean

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,34 @@ def liftSearchM (k : SearchM α) : GrindTacticM α := do
341341
replaceMainGoal [state.goal]
342342
return a
343343

344+
def GrindTacticM.run (x : GrindTacticM α) (ctx : Context) (s : State) : TermElabM (α × State) :=
345+
x ctx |>.run s
346+
347+
def mkEvalTactic' (elaborator : Name) (params : Params) : TermElabM (Goal → TSyntax `grind → GrindM (List Goal)) := do
348+
let termState ← getThe Term.State
349+
let termCtx ← readThe Term.Context
350+
let eval (goal : Goal) (stx : TSyntax `grind) : GrindM (List Goal) := do
351+
let methods ← getMethods
352+
let grindCtx ← readThe Meta.Grind.Context
353+
let grindState ← get
354+
-- **Note**: we discard changes to `Term.State`
355+
let (subgoals, grindState') ← Term.TermElabM.run' (ctx := termCtx) (s := termState) do
356+
let (_, s) ← GrindTacticM.run
357+
(ctx := { methods, ctx := grindCtx, params, elaborator })
358+
(s := { state := grindState, goals := [goal] }) do
359+
evalGrindTactic stx.raw
360+
pruneSolvedGoals
361+
return (s.goals, s.state)
362+
set grindState'
363+
return subgoals
364+
return eval
365+
366+
def mkEvalTactic (params : Params) : TacticM (Goal → TSyntax `grind → GrindM (List Goal)) := do
367+
mkEvalTactic' (← read).elaborator params
368+
344369
def GrindTacticM.runAtGoal (mvarId : MVarId) (params : Params) (k : GrindTacticM α) : TacticM (α × State) := do
345-
let (methods, ctx, state) ← liftMetaM <| GrindM.runAtGoal mvarId params fun goal => do
370+
let evalTactic ← mkEvalTactic params
371+
let (methods, ctx, state) ← liftMetaM <| GrindM.runAtGoal mvarId params (evalTactic? := some evalTactic) fun goal => do
346372
let methods ← getMethods
347373
-- **Note**: We use `withCheapCasesOnly` to ensure multiple goals are not created.
348374
-- We will add support for this case in the future.

src/Lean/Meta/Tactic/Grind/Action.lean

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ A terminal action which closes the goal or not.
247247
This kind of action may make progress, but we only include `mkTac` into the resulting tactic sequence
248248
if it closed the goal.
249249
-/
250-
public def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind)) : Action := fun goal kna kp => do
250+
def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind)) : Action := fun goal kna kp => do
251251
let (progress, goal') ← GoalM.run goal check
252252
if progress then
253253
if goal'.inconsistent then
@@ -257,6 +257,24 @@ public def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind))
257257
else
258258
kna goal'
259259

260+
/--
261+
Helper action that checks whether the resulting tactic script produced by its continuation
262+
can close the original goal.
263+
-/
264+
def checkTactic : Action := fun goal _ kp => do
265+
let s ← saveState
266+
let r ← kp goal
267+
match r with
268+
| .closed seq =>
269+
let tac ← mkGrindNext seq
270+
Lean.withoutModifyingState do
271+
s.restore
272+
let subgoals ← evalTactic goal tac
273+
unless subgoals.isEmpty do
274+
throwError "generated tactic cannot close the goal{indentD tac}\nInitial goal\n{goal.mvarId}\nPending subgoals\n{subgoals.map (·.mvarId)}"
275+
return r
276+
| _ => return r
277+
260278
section
261279
/-!
262280
Some sanity check properties.

src/Lean/Meta/Tactic/Grind/Main.lean

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ def mkParams (config : Grind.Config) : MetaM Params := do
4545
let symPrios ← getGlobalSymbolPriorities
4646
return { config, norm, normProcs, symPrios }
4747

48-
def mkMethods : CoreM Methods := do
48+
def mkMethods (evalTactic? : Option EvalTactic := none) : CoreM Methods := do
4949
let builtinPropagators ← builtinPropagatorsRef.get
50+
let evalTactic : EvalTactic := evalTactic?.getD EvalTactic.skip
5051
return {
52+
evalTactic
5153
propagateUp := fun e => do
5254
propagateForallPropUp e
5355
propagateReflCmp e
@@ -75,7 +77,7 @@ private def discharge? (e : Expr) : SimpM (Option Expr) := do
7577
else
7678
return none
7779

78-
def GrindM.run (x : GrindM α) (params : Params) : MetaM α := do
80+
def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTactic := none) : MetaM α := do
7981
let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {}
8082
let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState
8183
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
@@ -88,7 +90,7 @@ def GrindM.run (x : GrindM α) (params : Params) : MetaM α := do
8890
let simp := params.norm
8991
let config := params.config
9092
let symPrios := params.symPrios
91-
x (← mkMethods).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios }
93+
x (← mkMethods evalTactic?).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios }
9294
|>.run' { scState }
9395

9496
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
@@ -217,11 +219,11 @@ def mkResult (params : Params) (failure? : Option Goal) : GrindM Result := do
217219
logInfo msg
218220
return { failure?, issues, config := params.config, trace, counters, simp, splitDiags }
219221

220-
def GrindM.runAtGoal (mvarId : MVarId) (params : Params) (k : Goal → GrindM α) : MetaM α := do
222+
def GrindM.runAtGoal (mvarId : MVarId) (params : Params) (k : Goal → GrindM α) (evalTactic? : Option EvalTactic := none) : MetaM α := do
221223
let go : GrindM α := withReducible do
222224
let goal ← initCore mvarId params
223225
k goal
224-
go.run params
226+
go.run params (evalTactic? := evalTactic?)
225227

226228
def main (mvarId : MVarId) (params : Params) : MetaM Result := do profileitM Exception "grind" (← getOptions) do
227229
GrindM.runAtGoal mvarId params fun goal => do

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ structure State where
219219
-/
220220
anchors : PHashMap ExprPtr UInt64 := {}
221221

222+
instance : Nonempty State :=
223+
.intro {}
224+
222225
private opaque MethodsRefPointed : NonemptyType.{0}
223226
def MethodsRef : Type := MethodsRefPointed.type
224227
instance : Nonempty MethodsRef := by exact MethodsRefPointed.property
@@ -228,6 +231,26 @@ abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM
228231
@[inline] def mapGrindM [MonadControlT GrindM m] [Monad m] (f : {α : Type} → GrindM α → GrindM α) {α} (x : m α) : m α :=
229232
controlAt GrindM fun runInBase => f <| runInBase x
230233

234+
/--
235+
Backtrackable state for the `GrindM` monad.
236+
-/
237+
structure SavedState where
238+
«meta» : Meta.SavedState
239+
grind : State
240+
deriving Nonempty
241+
242+
protected def saveState : GrindM SavedState :=
243+
return { «meta» := (← Meta.saveState), grind := (← get) }
244+
245+
/-- Restore backtrackable parts of the state. -/
246+
def SavedState.restore (b : SavedState) : GrindM Unit := do
247+
b.meta.restore
248+
set b.grind
249+
250+
instance : MonadBacktrack SavedState GrindM where
251+
saveState := Grind.saveState
252+
restoreState s := s.restore
253+
231254
/--
232255
`withoutReportingMVarIssues x` executes `x` without reporting metavariables found during internalization.
233256
See comment at `Grind.Context.reportMVarIssue` for additional details.
@@ -1325,10 +1348,13 @@ def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do
13251348
f n
13261349

13271350
abbrev Propagator := Expr → GoalM Unit
1351+
abbrev EvalTactic := Goal → TSyntax `grind → GrindM (List Goal)
1352+
def EvalTactic.skip : EvalTactic := fun goal _ => return [goal]
13281353

13291354
structure Methods where
13301355
propagateUp : Propagator := fun _ => return ()
13311356
propagateDown : Propagator := fun _ => return ()
1357+
evalTactic : EvalTactic := EvalTactic.skip
13321358
deriving Inhabited
13331359

13341360
def Methods.toMethodsRef (m : Methods) : MethodsRef :=
@@ -1346,6 +1372,9 @@ def propagateUp (e : Expr) : GoalM Unit := do
13461372
def propagateDown (e : Expr) : GoalM Unit := do
13471373
(← getMethods).propagateDown e
13481374

1375+
def evalTactic (goal : Goal) (stx : TSyntax `grind) : GrindM (List Goal) := do
1376+
(← getMethods).evalTactic goal stx
1377+
13491378
/-- Returns expressions in the given expression equivalence class. -/
13501379
partial def Goal.getEqc (goal : Goal) (e : Expr) (sort := false) : List Expr :=
13511380
let eqc := go e e #[]

0 commit comments

Comments
 (0)