Skip to content

Commit 075f1d6

Browse files
authored
feat: guard and check in grind_pattern (#11428)
This PR implements support for **guards** in `grind_pattern`. The new feature provides additional control over theorem instantiation. For example, consider the following monotonicity theorem: ```lean opaque f : Nat → Nat theorem fMono : x ≤ y → f x ≤ f y := ... ``` We can use `grind_pattern` to instruct `grind` to instantiate the theorem for every pair `f x` and `f y` occurring in the goal: ```lean grind_pattern fMono => f x, f y ``` Then we can automatically prove the following simple example using `grind`: ```lean /-- trace: [grind.ematch.instance] fMono: f a ≤ b → f (f a) ≤ f b [grind.ematch.instance] fMono: f a ≤ c → f (f a) ≤ f c [grind.ematch.instance] fMono: f a ≤ a → f (f a) ≤ f a [grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a)) [grind.ematch.instance] fMono: f a ≤ f a → f (f a) ≤ f (f a) [grind.ematch.instance] fMono: f (f a) ≤ b → f (f (f a)) ≤ f b [grind.ematch.instance] fMono: f (f a) ≤ c → f (f (f a)) ≤ f c [grind.ematch.instance] fMono: f (f a) ≤ a → f (f (f a)) ≤ f a [grind.ematch.instance] fMono: f (f a) ≤ f (f a) → f (f (f a)) ≤ f (f (f a)) [grind.ematch.instance] fMono: f (f a) ≤ f a → f (f (f a)) ≤ f (f a) [grind.ematch.instance] fMono: a ≤ b → f a ≤ f b [grind.ematch.instance] fMono: a ≤ c → f a ≤ f c [grind.ematch.instance] fMono: a ≤ a → f a ≤ f a [grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a)) [grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a) [grind.ematch.instance] fMono: c ≤ b → f c ≤ f b [grind.ematch.instance] fMono: c ≤ c → f c ≤ f c [grind.ematch.instance] fMono: c ≤ a → f c ≤ f a [grind.ematch.instance] fMono: c ≤ f (f a) → f c ≤ f (f (f a)) [grind.ematch.instance] fMono: c ≤ f a → f c ≤ f (f a) [grind.ematch.instance] fMono: b ≤ b → f b ≤ f b [grind.ematch.instance] fMono: b ≤ c → f b ≤ f c [grind.ematch.instance] fMono: b ≤ a → f b ≤ f a [grind.ematch.instance] fMono: b ≤ f (f a) → f b ≤ f (f (f a)) [grind.ematch.instance] fMono: b ≤ f a → f b ≤ f (f a) -/ #guard_msgs in example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by set_option trace.grind.ematch.instance true in grind ``` However, many unnecessary theorem instantiations are generated. With the new `guard` feature, we can instruct `grind` to instantiate the theorem **only if** `x ≤ y` is already known to be true in the current `grind` state: ```lean grind_pattern fMono => f x, f y where guard x ≤ y x =/= y ``` If we run the example again, only three instances are generated: ```lean /-- trace: [grind.ematch.instance] fMono: a ≤ f a → f a ≤ f (f a) [grind.ematch.instance] fMono: f a ≤ f (f a) → f (f a) ≤ f (f (f a)) [grind.ematch.instance] fMono: a ≤ f (f a) → f a ≤ f (f (f a)) -/ #guard_msgs in example : f b = f c → a ≤ f a → f (f a) ≤ f (f (f a)) := by set_option trace.grind.ematch.instance true in grind ``` Note that `guard` does **not** check whether the expression is *implied*. It only checks whether the expression is *already known* to be true in the current `grind` state. If this fact is eventually learned, the theorem will be instantiated. If you want `grind` to check whether the expression is implied, you should use: ```lean grind_pattern fMono => f x, f y where check x ≤ y x =/= y ``` Remark: we can use multiple `guard`/`check`s in a `grind_pattern` command.
1 parent 3f05179 commit 075f1d6

File tree

7 files changed

+279
-21
lines changed

7 files changed

+279
-21
lines changed

src/Lean/Meta/Tactic/Grind.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ builtin_initialize registerTraceClass `grind.ematch
6161
builtin_initialize registerTraceClass `grind.ematch.pattern
6262
builtin_initialize registerTraceClass `grind.ematch.instance
6363
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
64+
builtin_initialize registerTraceClass `grind.ematch.instance.delayed
6465
builtin_initialize registerTraceClass `grind.eqResolution
6566
builtin_initialize registerTraceClass `grind.issues
6667
builtin_initialize registerTraceClass `grind.simp

src/Lean/Meta/Tactic/Grind/Core.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ where
240240
propagateDown e
241241
propagateUnitConstFuns lams₁ lams₂
242242
toPropagateSolvers.propagate
243+
if rhsNode.root.isTrue then
244+
checkDelayedThmInsts toPropagateDown
245+
checkDelayedThmInsts (toPropagateDown : List Expr) : GoalM Unit := do
246+
if (← isInconsistent) then return ()
247+
if (← get).delayedThmInsts.isEmpty then return ()
248+
for e in toPropagateDown do
249+
let some delayedThms := (← get).delayedThmInsts.find? { expr := e } | pure ()
250+
modify fun s => { s with delayedThmInsts := s.delayedThmInsts.erase { expr := e } }
251+
delayedThms.forM (·.check)
252+
243253
updateRoots (lhs : Expr) (rootNew : Expr) : GoalM Unit := do
244254
let isFalseRoot ← isFalseExpr rootNew
245255
traverseEqc lhs fun n => do

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ macro "reportEMatchIssue!" s:(interpolatedStr(term) <|> term) : doElem => do
461461
Stores new theorem instance in the state.
462462
Recall that new instances are internalized later, after a full round of ematching.
463463
-/
464-
private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Nat) : M Unit := do
464+
private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Nat) (guards : List TheoremGuard) : M Unit := do
465465
let proof ← instantiateMVars proof
466466
if grind.debug.proofs.get (← getOptions) then
467467
check proof
@@ -499,8 +499,7 @@ where
499499
-- We must add a hint because `annotateEqnTypeConds` introduces `Grind.PreMatchCond`
500500
-- which is not reducible.
501501
proof := mkExpectedPropHint proof prop
502-
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
503-
addTheoremInstance thm proof prop (generation+1)
502+
addTheoremInstance thm proof prop (generation+1) guards
504503

505504
private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : OptionT M Unit := do
506505
let thm := (← read).thm
@@ -741,7 +740,33 @@ private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (a
741740
It may be useful to bound the number of instances in the current branch.
742741
-/
743742
return (← getEMatchTheoremNumInstances thm) + 1 < n
744-
| _ => throwError "NIY"
743+
| .check _ | .guard _ => return true
744+
745+
private def collectGuards (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM (List TheoremGuard) := do
746+
if thm.cnstrs.isEmpty then return []
747+
/- **Note**: Only top-level theorems have constraints. -/
748+
let .const declName us := proof | return []
749+
unless thm.cnstrs.any fun c => c matches .check _ | .guard _ do return []
750+
let info ← getConstInfo declName
751+
let mut result := #[]
752+
let applySubst (e : Expr) : GoalM (Option Expr) := do
753+
let e := e.instantiateRev args
754+
let e := e.instantiateLevelParams info.levelParams us
755+
let e ← instantiateMVars e
756+
if e.hasMVar then
757+
reportIssue! "guard for `{thm.origin.pp}` was skipped because it contains metavariables after theorem instantiation{indentExpr e}"
758+
return none
759+
return some e
760+
for cnstr in thm.cnstrs do
761+
match cnstr with
762+
| .check e =>
763+
let some e ← applySubst e | pure ()
764+
result := result.push <| { e, check := true }
765+
| .guard e =>
766+
let some e ← applySubst e | pure ()
767+
result := result.push <| { e, check := false }
768+
| _ => pure ()
769+
return result.toList
745770

746771
/--
747772
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.
@@ -762,15 +787,16 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
762787
let (some _, c) ← applyAssignment mvars |>.run c | return ()
763788
let some _ ← synthesizeInsts mvars bis | return ()
764789
if (← checkConstraints thm c.gen proof mvars) then
790+
let guards ← collectGuards thm proof mvars
765791
let proof := mkAppN proof mvars
766792
if (← mvars.allM (·.mvarId!.isAssigned)) then
767-
addNewInstance thm proof c.gen
793+
addNewInstance thm proof c.gen guards
768794
else
769795
let mvars ← mvars.filterM fun mvar => return !(← mvar.mvarId!.isAssigned)
770796
if let some mvarBad ← mvars.findM? fun mvar => return !(← isProof mvar) then
771797
reportEMatchIssue! "failed to instantiate {thm.origin.pp}, failed to instantiate non propositional argument with type{indentExpr (← inferType mvarBad)}"
772798
let proof ← mkLambdaFVars (binderInfoForMVars := .default) mvars (← instantiateMVars proof)
773-
addNewInstance thm proof c.gen
799+
addNewInstance thm proof c.gen guards
774800

775801
/-- Process choice stack until we don't have more choices to be processed. -/
776802
private def processChoices : M Unit := do
@@ -891,8 +917,19 @@ Recall that the mapping is nonempty only if tracing is enabled.
891917
-/
892918
def ematch' (extraThms : Array EMatchTheorem := #[]) : GoalM (Bool × InstanceMap) := do
893919
let numInstances := (← get).ematch.numInstances
920+
let numDelayedInstances := (← get).ematch.numDelayedInstances
894921
let map ← ematchCore extraThms
895-
return ((← get).ematch.numInstances != numInstances, map)
922+
let progress :=
923+
(← get).ematch.numInstances != numInstances
924+
||
925+
(← get).ematch.numDelayedInstances != numDelayedInstances
926+
if (← get).ematch.numDelayedInstances != numDelayedInstances then
927+
/-
928+
**Note**: If delayed instances were produced, new guards may have been internalized,
929+
and we may have pending facts to process.
930+
-/
931+
processNewFacts
932+
return (progress, map)
896933

897934
/--
898935
Performs one round of E-matching, and returns `true` if new instances were generated.

src/Lean/Meta/Tactic/Grind/Parser.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end GrindCnstr
2828
open GrindCnstr in
2929
def grindPatternCnstr : Parser :=
3030
isValue <|> isStrictValue <|> isGround <|> sizeLt <|> depthLt <|> genLt <|> maxInsts
31-
<|> guard <|> check <|> notDefEq <|> defEq
31+
<|> guard <|> GrindCnstr.check <|> notDefEq <|> defEq
3232

3333
def grindPatternCnstrs : Parser := leading_parser "where " >> many1Indent (ppLine >> grindPatternCnstr)
3434

src/Lean/Meta/Tactic/Grind/Simp.lean

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def dsimpCore (e : Expr) : GrindM Expr := do profileitM Exception "grind dsimp"
4545
Preprocesses `e` using `grind` normalization theorems and simprocs,
4646
and then applies several other preprocessing steps.
4747
-/
48-
def preprocess (e : Expr) : GoalM Simp.Result := do
48+
@[export lean_grind_preprocess]
49+
def preprocessImpl (e : Expr) : GoalM Simp.Result := do
4950
let e ← instantiateMVars e
5051
let r ← simpCore e
5152
/-

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,20 @@ inductive SplitSource where
121121
input
122122
| /-- Injectivity theorem. -/
123123
inj (origin : Origin)
124+
| /-- `grind_pattern` guard -/
125+
guard (origin : Origin)
124126
deriving Inhabited
125127

126128
def SplitSource.toMessageData : SplitSource → MessageData
127-
| .ematch origin => m!"E-matching {origin.pp}"
128-
| .ext declName => m!"Extensionality {declName}"
129+
| .ematch origin => m!"E-matching `{origin.pp}`"
130+
| .guard origin => m!"Theorem instantiation guard for `{origin.pp}`"
131+
| .ext declName => m!"Extensionality `{declName}`"
129132
| .mbtc a b i => m!"Model-based theory combination at argument #{i} of{indentExpr a}\nand{indentExpr b}"
130133
| .beta e => m!"Beta-reduction of{indentExpr e}"
131134
| .forallProp e => m!"Forall propagation at{indentExpr e}"
132135
| .existsProp e => m!"Exists propagation at{indentExpr e}"
133136
| .input => "Initial goal"
134-
| .inj origin => m!"Injectivity {origin.pp}"
137+
| .inj origin => m!"Injectivity `{origin.pp}`"
135138

136139
/-- Context for `GrindM` monad. -/
137140
structure Context where
@@ -762,8 +765,10 @@ structure EMatch.State where
762765
thms : PArray EMatchTheorem := {}
763766
/-- Active theorems that we have not performed any round of ematching yet. -/
764767
newThms : PArray EMatchTheorem := {}
765-
/-- Number of theorem instances generated so far -/
768+
/-- Number of theorem instances generated so far. -/
766769
numInstances : Nat := 0
770+
/-- Number of delayed theorem instances generated so far. We track them to decide whether E-match made progress or not. -/
771+
numDelayedInstances : Nat := 0
767772
/-- Number of E-matching rounds performed in this goal since the last case-split. -/
768773
num : Nat := 0
769774
/-- (pre-)instances found so far. It includes instances that failed to be instantiated. -/
@@ -900,6 +905,30 @@ structure Injective.State where
900905
fns : PHashMap ExprPtr InjectiveInfo := {}
901906
deriving Inhabited
902907

908+
/--
909+
Users can attach guards to `grind_pattern`s. A guard ensures that a theorem is instantiated
910+
only when the guard expression becomes provably true.
911+
912+
If `check` is `true`, then `grind` attempts to prove `e` by asserting its negation and
913+
checking whether this leads to a contradiction.
914+
-/
915+
structure TheoremGuard where
916+
e : Expr
917+
check : Bool
918+
deriving Inhabited
919+
920+
/--
921+
A delayed theorem instantiation is an instantiation that includes one or more guards.
922+
See `TheoremGuard`.
923+
-/
924+
structure DelayedTheoremInstance where
925+
thm : EMatchTheorem
926+
proof : Expr
927+
prop : Expr
928+
generation : Nat
929+
guards : List TheoremGuard
930+
deriving Inhabited
931+
903932
/-- The `grind` goal. -/
904933
structure Goal where
905934
mvarId : MVarId
@@ -936,6 +965,11 @@ structure Goal where
936965
clean : Clean.State := {}
937966
/-- Solver states. -/
938967
sstates : Array SolverExtensionState := #[]
968+
/--
969+
Delayed instantiations is a mapping from guards to theorems that are waiting them
970+
to become `True`.
971+
-/
972+
delayedThmInsts : PHashMap ExprPtr (List DelayedTheoremInstance) := {}
939973
deriving Inhabited
940974

941975
def Goal.hasSameRoot (g : Goal) (a b : Expr) : Bool :=
@@ -1001,12 +1035,6 @@ def addNewRawFact (proof : Expr) (prop : Expr) (generation : Nat) (splitSource :
10011035
def getNumTheoremInstances : GoalM Nat := do
10021036
return (← get).ematch.numInstances
10031037

1004-
/-- Adds a new theorem instance produced using E-matching. -/
1005-
def addTheoremInstance (thm : EMatchTheorem) (proof : Expr) (prop : Expr) (generation : Nat) : GoalM Unit := do
1006-
saveEMatchTheorem thm
1007-
addNewRawFact proof prop generation (.ematch thm.origin)
1008-
modify fun s => { s with ematch.numInstances := s.ematch.numInstances + 1 }
1009-
10101038
/-- Returns `true` if the maximum number of instances has been reached. -/
10111039
def checkMaxInstancesExceeded : GoalM Bool := do
10121040
return (← get).ematch.numInstances >= (← getConfig).instances
@@ -1316,13 +1344,17 @@ It assumes `a` and `b` are in the same equivalence class.
13161344
@[extern "lean_grind_mk_heq_proof"]
13171345
opaque mkHEqProof (a b : Expr) : GoalM Expr
13181346

1347+
-- Forward definition
1348+
@[extern "lean_grind_process_new_facts"]
1349+
opaque processNewFacts : GoalM Unit
1350+
13191351
-- Forward definition
13201352
@[extern "lean_grind_internalize"]
13211353
opaque internalize (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit
13221354

13231355
-- Forward definition
1324-
@[extern "lean_grind_process_new_facts"]
1325-
opaque processNewFacts : GoalM Unit
1356+
@[extern "lean_grind_preprocess"]
1357+
opaque preprocess : Expr → GoalM Simp.Result
13261358

13271359
/--
13281360
Internalizes a local declaration which is not a proposition.
@@ -1589,6 +1621,45 @@ def addSplitCandidate (sinfo : SplitInfo) : GoalM Unit := do
15891621
}
15901622
updateSplitArgPosMap sinfo
15911623

1624+
inductive ActivateNextGuardResult where
1625+
| ready
1626+
| next (guard : Expr) (pending : List TheoremGuard)
1627+
1628+
def activateNextGuard (thm : EMatchTheorem) (guards : List TheoremGuard) (generation : Nat) : GoalM ActivateNextGuardResult := do
1629+
go guards
1630+
where
1631+
go : List TheoremGuard → GoalM ActivateNextGuardResult
1632+
| [] => return .ready
1633+
| guard :: guards => do
1634+
let { expr := e, .. } ← preprocess guard.e
1635+
internalize e generation
1636+
if (← isEqTrue e) then
1637+
go guards
1638+
else
1639+
if guard.check then
1640+
addSplitCandidate <| .default e (.guard thm.origin)
1641+
return .next e guards
1642+
1643+
/-- Adds a new theorem instance produced using E-matching. -/
1644+
def addTheoremInstance (thm : EMatchTheorem) (proof : Expr) (prop : Expr) (generation : Nat) (guards : List TheoremGuard) : GoalM Unit := do
1645+
match (← activateNextGuard thm guards generation) with
1646+
| .ready =>
1647+
trace_goal[grind.ematch.instance] "{thm.origin.pp}: {prop}"
1648+
saveEMatchTheorem thm
1649+
addNewRawFact proof prop generation (.ematch thm.origin)
1650+
modify fun s => { s with ematch.numInstances := s.ematch.numInstances + 1 }
1651+
| .next guard guards =>
1652+
let thms := (← get).delayedThmInsts.find? { expr := guard } |>.getD []
1653+
let thms := { thm, proof, prop, generation, guards } :: thms
1654+
trace_goal[grind.ematch.instance.delayed] "`{thm.origin.pp}` waiting{indentExpr guard}"
1655+
modify fun s => { s with
1656+
delayedThmInsts := s.delayedThmInsts.insert { expr := guard } thms
1657+
ematch.numDelayedInstances := s.ematch.numDelayedInstances + 1
1658+
}
1659+
1660+
def DelayedTheoremInstance.check (delayed : DelayedTheoremInstance) : GoalM Unit := do
1661+
addTheoremInstance delayed.thm delayed.proof delayed.prop delayed.generation delayed.guards
1662+
15921663
/--
15931664
Returns extensionality theorems for the given type if available.
15941665
If `Config.ext` is `false`, the result is `#[]`.

0 commit comments

Comments
 (0)