Skip to content
8 changes: 6 additions & 2 deletions src/Init/Grind/Interactive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ syntax (name := «sorry») "sorry" : grind
syntax anchor := "#" noWs hexnum
syntax thm := anchor <|> grindLemma <|> grindLemmaMin

/-- Instantiates theorems using E-matching. -/
syntax (name := instantiate) "instantiate" (colGt thm),* : grind
/--
Instantiates theorems using E-matching.
The `approx` modifier is just a marker for users to easily identify automatically generated `instantiate` tactics
that may have redundant arguments.
-/
syntax (name := instantiate) "instantiate" (&" only")? (&" approx")? (" [" withoutPosition(thm,*,?) "]")? : grind

-- **Note**: Should we rename the following tactics to `trace_`?
/-- Shows asserted facts. -/
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Grind/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def mkEvalTactic' (elaborator : Name) (params : Params) : TermElabM (Goal → TS
-- **Note**: we discard changes to `Term.State`
let (subgoals, grindState') ← Term.TermElabM.run' (ctx := termCtx) (s := termState) do
let (_, s) ← GrindTacticM.run
(ctx := { methods, ctx := grindCtx, params, elaborator })
(ctx := { recover := false, methods, ctx := grindCtx, params, elaborator })
(s := { state := grindState, goals := [goal] }) do
evalGrindTactic stx.raw
pruneSolvedGoals
Expand Down
26 changes: 15 additions & 11 deletions src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def logTheoremAnchor (proof : Expr) : TermElabM Unit := do
let stx ← getRef
Term.addTermInfo' stx proof

def ematchThms (thms : Array EMatchTheorem) : GrindTacticM Unit := do
let progress ← liftGoalM <| if thms.isEmpty then ematch else ematchTheorems thms
def ematchThms (only : Bool) (thms : Array EMatchTheorem) : GrindTacticM Unit := do
let progress ← liftGoalM <| if only then ematchOnly thms else ematch thms
unless progress do
throwError "`instantiate` tactic failed to instantiate new facts, use `show_patterns` to see active theorems and their patterns."
let goal ← getMainGoal
Expand All @@ -142,15 +142,19 @@ def elabAnchor (anchor : TSyntax `hexnum) : CoreM (Nat × UInt64) := do
return (numDigits, val)

@[builtin_grind_tactic instantiate] def evalInstantiate : GrindTactic := fun stx => withMainContext do
let `(grind| instantiate $[$thmRefs:thm],*) := stx | throwUnsupportedSyntax
let mut thms := #[]
for thmRef in thmRefs do
match thmRef with
| `(Parser.Tactic.Grind.thm| #$anchor:hexnum) => thms := thms ++ (← withRef thmRef <| elabLocalEMatchTheorem anchor)
| `(Parser.Tactic.Grind.thm| $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id false)
| `(Parser.Tactic.Grind.thm| ! $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id true)
| _ => throwErrorAt thmRef "unexpected theorem reference"
ematchThms thms
let `(grind| instantiate $[ only%$only ]? $[ approx ]? $[ [ $[$thmRefs?:thm],* ] ]?) := stx | throwUnsupportedSyntax
let goal ← getMainGoal
let only := only.isSome
let initThms ← if only then goal.getActiveMatchEqTheorems else pure #[]
let mut thms := initThms
if let some thmRefs := thmRefs? then
for thmRef in thmRefs do
match thmRef with
| `(Parser.Tactic.Grind.thm| #$anchor:hexnum) => thms := thms ++ (← withRef thmRef <| elabLocalEMatchTheorem anchor)
| `(Parser.Tactic.Grind.thm| $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id false)
| `(Parser.Tactic.Grind.thm| ! $[$mod?:grindMod]? $id:ident) => thms := thms ++ (← withRef thmRef <| elabThm mod? id true)
| _ => throwErrorAt thmRef "unexpected theorem reference"
ematchThms only thms
where
collectThms (numDigits : Nat) (anchorPrefix : UInt64) (thms : PArray EMatchTheorem) : StateT (Array EMatchTheorem) GrindTacticM Unit := do
let mut found : Std.HashSet Expr := {}
Expand Down
35 changes: 6 additions & 29 deletions src/Lean/Elab/Tactic/Grind/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public import Lean.Meta.Tactic.TryThis
public import Lean.Elab.Command
public import Lean.Elab.Tactic.Config
import Lean.Meta.Tactic.Grind.SimpUtil
import Lean.Meta.Tactic.Grind.EMatchTheoremParam
import Lean.Elab.Tactic.Grind.Basic
import Lean.Elab.MutualDef
meta import Lean.Meta.Tactic.Grind.Parser
Expand Down Expand Up @@ -292,37 +293,13 @@ def mkGrindOnly
let mut foundFns : NameSet := {}
for { origin, kind, minIndexable } in trace.thms.toList do
if let .decl declName := origin then
if let some declName ← isEqnThm? declName then
unless foundFns.contains declName do
foundFns := foundFns.insert declName
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
let param ← `(Parser.Tactic.grindParam| $decl:ident)
if let some fnName ← isEqnThm? declName then
unless foundFns.contains fnName do
foundFns := foundFns.insert fnName
let param ← Grind.globalDeclToGrindParamSyntax declName kind minIndexable
params := params.push param
else
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
let param ← match kind, minIndexable with
| .eqLhs false, _ => `(Parser.Tactic.grindParam| = $decl:ident)
| .eqLhs true, _ => `(Parser.Tactic.grindParam| = gen $decl:ident)
| .eqRhs false, _ => `(Parser.Tactic.grindParam| =_ $decl:ident)
| .eqRhs true, _ => `(Parser.Tactic.grindParam| =_ gen $decl:ident)
| .eqBoth false, _ => `(Parser.Tactic.grindParam| _=_ $decl:ident)
| .eqBoth true, _ => `(Parser.Tactic.grindParam| _=_ gen $decl:ident)
| .eqBwd, _ => `(Parser.Tactic.grindParam| ←= $decl:ident)
| .user, _ => `(Parser.Tactic.grindParam| usr $decl:ident)
| .bwd false, false => `(Parser.Tactic.grindParam| ← $decl:ident)
| .bwd true, false => `(Parser.Tactic.grindParam| ← gen $decl:ident)
| .fwd, false => `(Parser.Tactic.grindParam| → $decl:ident)
| .leftRight, false => `(Parser.Tactic.grindParam| => $decl:ident)
| .rightLeft, false => `(Parser.Tactic.grindParam| <= $decl:ident)
| .default false, false => `(Parser.Tactic.grindParam| $decl:ident)
| .default true, false => `(Parser.Tactic.grindParam| gen $decl:ident)
| .bwd false, true => `(Parser.Tactic.grindParam| ! ← $decl:ident)
| .bwd true, true => `(Parser.Tactic.grindParam| ! ← gen $decl:ident)
| .fwd, true => `(Parser.Tactic.grindParam| ! → $decl:ident)
| .leftRight, true => `(Parser.Tactic.grindParam| ! => $decl:ident)
| .rightLeft, true => `(Parser.Tactic.grindParam| ! <= $decl:ident)
| .default false, true => `(Parser.Tactic.grindParam| ! $decl:ident)
| .default true, true => `(Parser.Tactic.grindParam| ! gen $decl:ident)
let param ← Grind.globalDeclToGrindParamSyntax declName kind minIndexable
params := params.push param
for declName in trace.eagerCases.toList do
unless Grind.isBuiltinEagerCases declName do
Expand Down
22 changes: 2 additions & 20 deletions src/Lean/Elab/Tactic/Grind/ShowState.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ prelude
public import Lean.Elab.Tactic.Grind.Basic
public import Lean.Elab.Tactic.Grind.Filter
import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.EMatchTheoremParam
import Lean.Meta.Tactic.Grind.Anchor
import Lean.Meta.Tactic.Grind.Split
namespace Lean.Elab.Tactic.Grind
Expand Down Expand Up @@ -118,30 +119,11 @@ public def showState (filter : Filter) (isSilent := false) : GrindTacticM Unit :

@[builtin_grind_tactic showLocalThms] def evalShowLocalThms : GrindTactic := fun _ => withMainContext do
let goal ← getMainGoal
let entries ← liftGrindM do
let (found, entries) ← go {} {} goal.ematch.thms
let (_, entries) ← go found entries goal.ematch.newThms
pure entries
let entries ← liftGrindM <| getLocalTheoremAnchors goal
let numDigits := getNumDigitsForAnchors entries
let msgs := entries.map fun { anchor, e } =>
.trace { cls := `thm } m!"#{anchorToString numDigits anchor} := {e}" #[]
let msg := MessageData.trace { cls := `thms, collapsed := false } "Local theorems" msgs
logInfo msg
where
go (found : Std.HashSet Grind.Origin) (result : Array ExprWithAnchor) (thms : PArray EMatchTheorem)
: GrindM (Std.HashSet Grind.Origin × Array ExprWithAnchor) := do
let mut found := found
let mut result := result
for thm in thms do
-- **Note**: We only display local theorems
if thm.origin matches .local _ | .fvar _ then
unless found.contains thm.origin do
found := found.insert thm.origin
let type ← inferType thm.proof
-- **Note**: Evaluate how stable these anchors are.
let anchor ← getAnchor type
result := result.push { anchor, e := type }
pure ()
return (found, result)

end Lean.Elab.Tactic.Grind
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public import Lean.Meta.Tactic.Grind.PropagateInj
public import Lean.Meta.Tactic.Grind.Order
public import Lean.Meta.Tactic.Grind.Anchor
public import Lean.Meta.Tactic.Grind.Action
public import Lean.Meta.Tactic.Grind.EMatchTheoremParam
public import Lean.Meta.Tactic.Grind.EMatchAction
public section
namespace Lean
Expand Down
46 changes: 39 additions & 7 deletions src/Lean/Meta/Tactic/Grind/Action.lean
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ def mkGrindNext (s : List TGrind) : CoreM TGrind := do
let s := mkGrindSeq s
`(grind| next => $s:grindSeq)

/--
Given `[t₁, ..., tₙ]`, returns
```
(t₁
...
tₙ)
```
If the list is empty, it returns `(skip)`.
-/
private def mkGrindParen (s : List TGrind) : CoreM TGrind := do
let s ← if s == [] then pure [← `(grind| skip)] else pure s
let s := mkGrindSeq s
`(grind| ($s:grindSeq))

/--
If tracing is enabled and continuation produced `.closed [t₁, ..., tₙ]`,
returns the singleton sequence `[t]` where `t` is
Expand Down Expand Up @@ -251,21 +265,39 @@ def solverAction (check : GoalM CheckResult) (mkTac : GrindM (TSyntax `grind)) :
concatTactic (← kp goal') mkTac
| .closed => closeWith mkTac

def saveStateIfTracing : GrindM (Option SavedState) := do
if (← getConfig).trace then
return some (← saveState)
else
return none
/--
Returns `true` if the tactic sequence `seq` closes `goal` starting at saved state `s?`.
If `s?` is `none` just returns `true`.
-/
def checkSeqAt (s? : Option SavedState) (goal : Goal) (seq : List TGrind) : GrindM Bool := do
let some s := s? | return true
let tac ← mkGrindParen seq
Lean.withoutModifyingState do
s.restore
-- **Note**: Ensure tracing is disabled.
withTheReader Grind.Context (fun ctx => { ctx with config.trace := false }) do
try
let subgoals ← evalTactic goal tac
return subgoals.isEmpty
catch _ =>
return false

/--
Helper action that checks whether the resulting tactic script produced by its continuation
can close the original goal.
-/
def checkTactic : Action := fun goal _ kp => do
let s ← saveState
let s ← saveStateIfTracing
let r ← kp goal
match r with
| .closed seq =>
let tac ← mkGrindNext seq
Lean.withoutModifyingState do
s.restore
let subgoals ← evalTactic goal tac
unless subgoals.isEmpty do
throwError "generated tactic cannot close the goal{indentD tac}\nInitial goal\n{goal.mvarId}\nPending subgoals\n{subgoals.map (·.mvarId)}"
unless (← checkSeqAt s goal seq) do
throwError "generated tactic cannot close the goal{indentD (← mkGrindNext seq)}\nInitial goal\n{goal.mvarId}"
return r
| _ => return r

Expand Down
61 changes: 52 additions & 9 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,44 @@ structure Context where
initApp : Expr := default
deriving Inhabited

/--
A mapping `uniqueId ↦ thm`, where `uniqueId` is an auxiliary marker used to wrap a theorem instantiation proof of `thm`
using a `Expr.mdata`. The `uniqueId`s are created using `mkFreshId`.
-/
abbrev InstanceMap := Std.HashMap Name EMatchTheorem

private def thmInstanceKey := `_grind_thm_instance

private def markTheoremInstanceProof (proof : Expr) (uniqueId : Name) : Expr :=
Expr.mdata (KVMap.empty.insert thmInstanceKey uniqueId) proof

/-- Returns `some uniqueId` if `proof` was marked using `markTheoremInstanceProof` -/
def isTheoremInstanceProof? (proof : Expr) : Option Name :=
match proof with
| .mdata d _ =>
match d.find thmInstanceKey with
| some (DataValue.ofName uniqueId) => some uniqueId
| _ => none
| _ => none

/-- State for the E-matching monad -/
structure SearchState where
/-- Choices that still have to be processed. -/
choiceStack : List Choice := []
/--
When tracing is enabled track instances here. See comment at `InstanceMap`
-/
instanceMap : InstanceMap := {}
deriving Inhabited

abbrev M := ReaderT Context $ StateRefT SearchState GoalM

def M.run' (x : M α) : GoalM α :=
x {} |>.run' {}

def M.run (x : M α) : GoalM (α × SearchState) :=
x {} |>.run {}

@[inline] private abbrev withInitApp (e : Expr) (x : M α) : M α :=
withReader (fun ctx => { ctx with initApp := e }) x

Expand Down Expand Up @@ -453,6 +480,10 @@ where
go (proof prop : Expr) : M Unit := do
let mut proof := proof
let mut prop := prop
if (← getConfig).trace then
let uniqueId ← mkFreshId
proof := markTheoremInstanceProof proof uniqueId
modify fun s => { s with instanceMap := s.instanceMap.insert uniqueId thm }
if (← isMatchEqLikeDeclName thm.origin.key) then
prop ← annotateMatchEqnType prop (← read).initApp
-- We must add a hint here because `annotateMatchEqnType` introduces `simpMatchDiscrsOnly` and
Expand Down Expand Up @@ -706,29 +737,41 @@ end EMatch
open EMatch

/-- Performs one round of E-matching, and returns new instances. -/
private def ematchCore : GoalM Unit := do profileitM Exception "grind ematch" (← getOptions) do
private def ematchCore (extraThms : Array EMatchTheorem) : GoalM InstanceMap := do profileitM Exception "grind ematch" (← getOptions) do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
withReader (fun ctx => { ctx with useMT := false }) do
ematchTheorems newThms
extraThms.forM ematchTheorem
if (← checkMaxInstancesExceeded <||> checkMaxEmatchExceeded) then
return ()
return {}
else
go (← get).ematch.thms (← get).ematch.newThms |>.run'
let (_, s) ← go (← get).ematch.thms (← get).ematch.newThms |>.run
modify fun s => { s with
ematch.thms := s.ematch.thms ++ s.ematch.newThms
ematch.newThms := {}
ematch.gmt := s.ematch.gmt + 1
ematch.num := s.ematch.num + 1
}
return s.instanceMap

/-- Performs one round of E-matching, and returns `true` if new instances were generated. -/
def ematch : GoalM Bool := do
/--
Performs one round of E-matching, and returns `true` if new instances were generated.
Recall that the mapping is nonempty only if tracing is enabled.
-/
def ematch' (extraThms : Array EMatchTheorem := #[]) : GoalM (Bool × InstanceMap) := do
let numInstances := (← get).ematch.numInstances
ematchCore
return (← get).ematch.numInstances != numInstances
let map ← ematchCore extraThms
return ((← get).ematch.numInstances != numInstances, map)

/--
Performs one round of E-matching, and returns `true` if new instances were generated.
-/
def ematch (extraThms : Array EMatchTheorem := #[]) : GoalM Bool :=
return (← ematch' extraThms).1

/-- Performs one round of E-matching using the giving theorems, and returns `true` if new instances were generated. -/
def ematchTheorems (thms : Array EMatchTheorem) : GoalM Bool := do
def ematchOnly (thms : Array EMatchTheorem) : GoalM Bool := do
let numInstances := (← get).ematch.numInstances
go |>.run'
return (← get).ematch.numInstances != numInstances
Expand Down
Loading
Loading