Skip to content

Commit 595d87b

Browse files
authored
feat: include symbols in ground grind patterns (#11589)
This PR improves indexing for `grind` patterns. We now include symbols occurring in nested ground patterns. This important to minimize the number of activated E-match theorems.
1 parent 361bfdb commit 595d87b

File tree

8 files changed

+69
-27
lines changed

8 files changed

+69
-27
lines changed

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module
77
prelude
88
public import Lean.Meta.Tactic.Grind.Theorems
99
import Init.Grind.Util
10+
import Lean.Util.ForEachExpr
1011
import Lean.Meta.Tactic.Grind.Util
1112
import Lean.Meta.Match.Basic
1213
import Lean.Meta.Tactic.TryThis
@@ -580,6 +581,22 @@ private def saveSymbol (h : HeadIndex) : M Unit := do
580581
unless (← get).symbolSet.contains h do
581582
modify fun s => { s with symbols := s.symbols.push h, symbolSet := s.symbolSet.insert h }
582583

584+
private def saveSymbolsAt (e : Expr) : M Unit := do
585+
e.forEach' fun e => do
586+
if e.isApp || e.isConst then
587+
/- **Note**: We ignore function symbols that have special handling in the internalizer. -/
588+
if let .const declName _ := e.getAppFn then
589+
if declName == ``OfNat.ofNat || declName == ``Grind.nestedProof
590+
|| declName == ``Grind.eqBwdPattern
591+
|| declName == ``Grind.nestedDecidable || declName == ``ite then
592+
return false
593+
match e with
594+
| .const .. =>
595+
saveSymbol e.toHeadIndex
596+
return false
597+
| _ =>
598+
return true
599+
583600
private def foundBVar (idx : Nat) : M Bool :=
584601
return (← get).bvarsFound.contains idx
585602

@@ -672,30 +689,48 @@ private def getPatternFn? (pattern : Expr) (inSupport : Bool) (root : Bool) (arg
672689

673690
private partial def go (pattern : Expr) (inSupport : Bool) (root : Bool) : M Expr := do
674691
if let some (e, k) := isOffsetPattern? pattern then
675-
let e ← goArg e inSupport .relevant
692+
let e ← goArg e inSupport .relevant (isEqBwdParent := false)
676693
if e == dontCare then
677694
return dontCare
678695
else
679696
return mkOffsetPattern e k
680697
let some f ← getPatternFn? pattern inSupport root .relevant
681698
| throwError "invalid pattern, (non-forbidden) application expected{indentD (ppPattern pattern)}"
682699
assert! f.isConst || f.isFVar
683-
unless f.isConstOf ``Grind.eqBwdPattern do
700+
let isEqBwd := f.isConstOf ``Grind.eqBwdPattern
701+
unless isEqBwd do
684702
saveSymbol f.toHeadIndex
685703
let mut args := pattern.getAppArgs.toVector
686704
let patternArgKinds ← getPatternArgKinds f args.size
687705
for h : i in *...args.size do
688706
let arg := args[i]
689707
let argKind := patternArgKinds[i]?.getD .relevant
690-
args := args.set i (← goArg arg (inSupport || argKind.isSupport) argKind)
708+
args := args.set i (← goArg arg (inSupport || argKind.isSupport) argKind isEqBwd)
691709
return mkAppN f args.toArray
692710
where
693-
goArg (arg : Expr) (inSupport : Bool) (argKind : PatternArgKind) : M Expr := do
711+
goArg (arg : Expr) (inSupport : Bool) (argKind : PatternArgKind) (isEqBwdParent : Bool) : M Expr := do
694712
if !arg.hasLooseBVars then
695713
if arg.hasMVar then
696714
pure dontCare
715+
else if (← isProof arg) then
716+
pure dontCare
697717
else
698-
return mkGroundPattern (← expandOffsetPatterns arg)
718+
let arg ← expandOffsetPatterns arg
719+
unless isEqBwdParent do
720+
/-
721+
**Note**: We ignore symbols in ground patterns if the parent is the auxiliary ``Grind.eqBwdPattern
722+
We do that because we want to sign an error in examples such as:
723+
```
724+
theorem dummy (x : Nat) : x = x :=
725+
rfl
726+
-- error: invalid pattern for `dummy`
727+
-- [@Lean.Grind.eqBwdPattern `[Nat] #0 #0]
728+
-- the pattern does not contain constant symbols for indexing
729+
attribute [grind ←=] dummy
730+
```
731+
-/
732+
saveSymbolsAt arg
733+
return mkGroundPattern arg
699734
else match arg with
700735
| .bvar idx =>
701736
if inSupport && (← foundBVar idx) then

src/Lean/Meta/Tactic/Grind/Internalize.lean

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,17 @@ where
8989
unless (← isEqFalse e) do return false
9090
return !(← isEqFalse e')
9191

92+
def updateIndicesFound (k : HeadIndex) : GoalM Unit := do
93+
if (← get).indicesFound.contains k then return ()
94+
modify fun s => { s with indicesFound := s.indicesFound.insert k }
95+
9296
/--
9397
Given an application `e` of the form `f a_1 ... a_n`,
9498
adds entry `f ↦ e` to `appMap`. Recall that `appMap` is a multi-map.
9599
-/
96100
private def updateAppMap (e : Expr) : GoalM Unit := do
97101
let key := e.toHeadIndex
102+
updateIndicesFound key
98103
trace_goal[grind.debug.appMap] "{e} => {repr key}"
99104
modify fun s => { s with
100105
appMap := if let some es := s.appMap.find? key then
@@ -280,10 +285,10 @@ private def activateTheoremsCore [TheoremLike α] (declName : Name)
280285
let origin := TheoremLike.getOrigin thm
281286
trace_goal[grind.debug.theorem.activate] "`{declName}` => `{origin.key}`"
282287
unless s.isErased origin do
283-
let appMap := (← get).appMap
284-
let symbols := TheoremLike.getSymbols thm
285-
let symbols := symbols.filter fun sym => !appMap.contains sym
286-
let thm := TheoremLike.setSymbols thm symbols
288+
let indicesFound := (← get).indicesFound
289+
let symbols := TheoremLike.getSymbols thm
290+
let symbols := symbols.filter fun sym => !indicesFound.contains sym
291+
let thm := TheoremLike.setSymbols thm symbols
287292
match symbols with
288293
| [] =>
289294
trace_goal[grind.debug.theorem.activate] "`{origin.key}`"
@@ -515,6 +520,7 @@ where
515520
| .lit .. =>
516521
mkENode e generation
517522
| .const declName _ =>
523+
updateIndicesFound (.const declName)
518524
mkENode e generation
519525
activateTheorems declName generation
520526
| .mvar .. =>

src/Lean/Meta/Tactic/Grind/Main.lean

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ def assertExtra (params : Params) : GoalM Unit := do
121121
for thm in params.extraInj do
122122
activateInjectiveTheorem thm 0
123123

124+
private def initENodeCore (e : Expr) (interpreted ctor : Bool) : GoalM Unit := do
125+
if let .const declName _ := e then
126+
updateIndicesFound (.const declName)
127+
mkENodeCore e interpreted ctor (generation := 0) (funCC := false)
128+
124129
private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
125130
let mvarId ← if params.config.clean then mvarId.exposeNames else pure mvarId
126131
let trueExpr ← getTrueExpr
@@ -134,12 +139,12 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
134139
let clean ← mkCleanState mvarId params
135140
let sstates ← Solvers.mkInitialStates
136141
GoalM.run' { mvarId, ematch.thmMap := thmMap, inj.thms := params.inj, split.casesTypes := casesTypes, clean, sstates } do
137-
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
138-
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
139-
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
140-
mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
141-
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0) (funCC := false)
142-
mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0) (funCC := false)
142+
initENodeCore falseExpr (interpreted := true) (ctor := false)
143+
initENodeCore trueExpr (interpreted := true) (ctor := false)
144+
initENodeCore btrueExpr (interpreted := false) (ctor := true)
145+
initENodeCore bfalseExpr (interpreted := false) (ctor := true)
146+
initENodeCore natZeroExpr (interpreted := true) (ctor := false)
147+
initENodeCore ordEqExpr (interpreted := false) (ctor := true)
143148
assertExtra params
144149

145150
structure Result where

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,11 @@ structure Goal where
951951
it is its unique id.
952952
-/
953953
appMap : PHashMap HeadIndex (List Expr) := {}
954+
/--
955+
All constants (*not* in `appMap`) that have been internalized, *and*
956+
`appMap`'s domain. We use this collection during theorem activation.
957+
-/
958+
indicesFound : PHashSet HeadIndex := {}
954959
/-- Equations and propositions to be processed. -/
955960
newFacts : Array NewFact := #[]
956961
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/

stage0/src/stdlib_flags.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// update me!
12
#include "util/options.h"
23

34
namespace lean {

tests/lean/run/grind_const_pattern.lean

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,8 @@ h : ¬f x = 11
7373
[prop] ¬f x = 11
7474
[eqc] False propositions
7575
[prop] f x = 11
76-
[ematch] E-matching patterns
77-
[thm] fa: [f `[a]]
7876
[cutsat] Assignment satisfying linear constraints
7977
[assign] x := 1
80-
[assign] a := 3
8178
[assign] f x := 2
8279
-/
8380
#guard_msgs (error) in

tests/lean/run/grind_pattern2.lean

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,6 @@ trace: [grind.internalize] [0] x
4545
[grind.internalize] [0] y
4646
[grind.internalize] [0] z
4747
[grind.internalize] [0] foo x y
48-
[grind.internalize] [0] [a, b]
49-
[grind.internalize] [0] Nat
50-
[grind.internalize] [0] a
51-
[grind.internalize] [0] [b]
52-
[grind.internalize] [0] b
53-
[grind.internalize] [0] []
54-
[grind.ematch] activated `fooThm`, [foo #0 `[[a, b]]]
5548
-/
5649
#guard_msgs (trace) in
5750
set_option trace.grind.internalize true in

tests/lean/run/try_induction.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,14 @@ theorem hyperoperation_recursion (n m k : ℕ) :
130130
/--
131131
info: Try these:
132132
[apply] (induction k) <;> grind
133-
[apply] (induction k) <;> grind only [hyperoperation, = add_zero, = add_succ, = hyperoperation_zero]
133+
[apply] (induction k) <;> grind only [hyperoperation, = add_zero, = add_succ]
134134
[apply] ·
135135
induction k
136136
· grind => instantiate only [hyperoperation, = add_zero]
137137
·
138138
grind =>
139139
instantiate only [hyperoperation, = add_succ]
140-
instantiate only [= hyperoperation_zero]
140+
instantiate only [hyperoperation]
141141
-/
142142
#guard_msgs in
143143
@[grind =]

0 commit comments

Comments
 (0)