Skip to content

Commit c28b052

Browse files
authored
feat: [grind?] attribute (#8426)
This PR adds the attribute `[grind?]`. It is like `[grind]` but displays inferred E-matching patterns. It is a more convinient than writing. Thanks @kim-em for suggesting this feature. ```lean set_option trace.grind.ematch.pattern true ``` This PR also improves some tests, and adds helper function `ENode.isRoot`.
1 parent a541b8e commit c28b052

File tree

10 files changed

+102
-48
lines changed

10 files changed

+102
-48
lines changed

src/Init/Grind/Tactics.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ syntax grindIntro := &"intro "
3030
syntax grindExt := &"ext "
3131
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
3232
syntax (name := grind) "grind" (grindMod)? : attr
33+
syntax (name := grind?) "grind?" (grindMod)? : attr
3334
end Attr
3435
end Lean.Parser
3536

src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
9595
-- Assign on expressions associated with cutsat terms or interpreted terms
9696
for e in goal.exprs do
9797
let node ← goal.getENode e
98-
if isSameExpr node.root node.self then
98+
if node.isRoot then
9999
if (← isIntNatENode node) then
100100
if let some v ← getAssignment? goal node.self then
101101
if v.den == 1 then used := used.insert v.num
@@ -111,7 +111,7 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do
111111
-- Assign the remaining ones with values not used by cutsat
112112
for e in goal.exprs do
113113
let node ← goal.getENode e
114-
if isSameExpr node.root node.self then
114+
if node.isRoot then
115115
if (← isIntNatENode node) then
116116
if model[node.self]?.isNone then
117117
let v := pickUnusedValue goal model node.self nextVal used

src/Lean/Meta/Tactic/Grind/Attr.lean

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,20 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
4949
def throwInvalidUsrModifier : CoreM α :=
5050
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
5151

52-
builtin_initialize
52+
/--
53+
Auxiliary function for registering `grind` and `grind?` attributes.
54+
The `grind?` is an alias for `grind` which displays patterns using `logInfo`.
55+
It is just a convenience for users.
56+
-/
57+
private def registerGrindAttr (showInfo : Bool) : IO Unit :=
5358
registerBuiltinAttribute {
54-
name := `grind
59+
name := if showInfo then `grind? else `grind
5560
descr :=
56-
"The `[grind]` attribute is used to annotate declarations.\
61+
let header := if showInfo then
62+
"The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
63+
else
64+
"The `[grind]` attribute is used to annotate declarations."
65+
header ++ "\
5766
\
5867
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
5968
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
@@ -73,12 +82,12 @@ builtin_initialize
7382
add := fun declName stx attrKind => MetaM.run' do
7483
match (← getAttrKindFromOpt stx) with
7584
| .ematch .user => throwInvalidUsrModifier
76-
| .ematch k => addEMatchAttr declName attrKind k
85+
| .ematch k => addEMatchAttr declName attrKind k (showInfo := showInfo)
7786
| .cases eager => addCasesAttr declName eager attrKind
7887
| .intro =>
7988
if let some info ← isCasesAttrPredicateCandidate? declName false then
8089
for ctor in info.ctors do
81-
addEMatchAttr ctor attrKind .default
90+
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
8291
else
8392
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
8493
| .ext => addExtAttr declName attrKind
@@ -89,10 +98,12 @@ builtin_initialize
8998
-- If it is an inductive predicate,
9099
-- we also add the constructors (intro rules) as E-matching rules
91100
for ctor in info.ctors do
92-
addEMatchAttr ctor attrKind .default
101+
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
93102
else
94-
addEMatchAttr declName attrKind .default
103+
addEMatchAttr declName attrKind .default (showInfo := showInfo)
95104
erase := fun declName => MetaM.run' do
105+
if showInfo then
106+
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
96107
if (← isCasesAttrCandidate declName false) then
97108
eraseCasesAttr declName
98109
else if (← isExtTheorem declName) then
@@ -101,4 +112,8 @@ builtin_initialize
101112
eraseEMatchAttr declName
102113
}
103114

115+
builtin_initialize
116+
registerGrindAttr true
117+
registerGrindAttr false
118+
104119
end Lean.Meta.Grind

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -589,18 +589,24 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
589589
msg := msg ++ m!"{x} : {← inferType x}"
590590
addMessageContextFull msg
591591

592+
private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do
593+
if showInfo then
594+
logInfo m!"{← origin.pp}: {patterns.map ppPattern}"
595+
592596
/--
593597
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
594598
Pattern variables are represented using de Bruijn indices.
595599
-/
596-
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : EMatchTheoremKind) : MetaM EMatchTheorem := do
600+
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr)
601+
(patterns : List Expr) (kind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
597602
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
598603
if symbols.isEmpty then
599604
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
600-
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
605+
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
601606
if let .missing pos ← checkCoverage proof numParams bvarFound then
602607
let pats : MessageData := m!"{patterns.map ppPattern}"
603608
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
609+
logPatternWhen showInfo origin patterns
604610
return {
605611
proof, patterns, numParams, symbols
606612
levelParams, origin, kind
@@ -627,7 +633,7 @@ Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs
627633
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
628634
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern.
629635
-/
630-
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do
636+
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) (showInfo := false) : MetaM EMatchTheorem := do
631637
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
632638
let (lhs, rhs) ← match_expr type with
633639
| Eq _ lhs rhs => pure (lhs, rhs)
@@ -640,15 +646,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
640646
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}"
641647
let pats := splitWhileForbidden (pat.abstract xs)
642648
return (xs.size, pats)
643-
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
649+
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) (showInfo := showInfo)
644650

645-
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
651+
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (showInfo := false) : MetaM EMatchTheorem := do
646652
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
647653
let_expr f@Eq α lhs rhs := type
648654
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
649655
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
650656
return (xs.size, [pat.abstract xs])
651-
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
657+
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd (showInfo := showInfo)
652658

653659
/--
654660
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@@ -657,8 +663,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
657663
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
658664
pattern.
659665
-/
660-
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do
661-
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs
666+
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) (showInfo := false) : MetaM EMatchTheorem := do
667+
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs (showInfo := showInfo)
662668

663669
/--
664670
Adds an E-matching theorem to the environment.
@@ -844,13 +850,13 @@ since the theorem is already in the `grind` state and there is nothing to be ins
844850
-/
845851
def mkEMatchTheoremWithKind?
846852
(origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : EMatchTheoremKind)
847-
(groundPatterns := true) : MetaM (Option EMatchTheorem) := do
853+
(groundPatterns := true) (showInfo := false) : MetaM (Option EMatchTheorem) := do
848854
if kind == .eqLhs then
849-
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
855+
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true) (showInfo := showInfo))
850856
else if kind == .eqRhs then
851-
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
857+
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false) (showInfo := showInfo))
852858
else if kind == .eqBwd then
853-
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
859+
return (← mkEMatchEqBwdTheoremCore origin levelParams proof (showInfo := showInfo))
854860
let type ← inferType proof
855861
/-
856862
Remark: we should not use `forallTelescopeReducing` (with default reducibility) here
@@ -894,25 +900,26 @@ where
894900
return none
895901
let numParams := xs.size
896902
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
903+
logPatternWhen showInfo origin patterns
897904
return some {
898905
proof, patterns, numParams, symbols
899906
levelParams, origin, kind
900907
}
901908

902-
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) : MetaM EMatchTheorem := do
903-
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
909+
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
910+
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind (showInfo := showInfo)
904911
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
905912
return thm
906913

907-
def mkEMatchEqTheoremsForDef? (declName : Name) : MetaM (Option (Array EMatchTheorem)) := do
914+
def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Option (Array EMatchTheorem)) := do
908915
let some eqns ← getEqnsFor? declName | return none
909916
eqns.mapM fun eqn => do
910-
mkEMatchEqTheorem eqn (normalizePattern := true)
917+
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
911918

912-
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) : MetaM Unit := do
919+
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
913920
if wasOriginallyTheorem (← getEnv) declName then
914-
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind
915-
else if let some thms ← mkEMatchEqTheoremsForDef? declName then
921+
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (showInfo := showInfo)) attrKind
922+
else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
916923
unless useLhs do
917924
throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns"
918925
thms.forM (ematchTheoremsExt.add · attrKind)
@@ -935,20 +942,20 @@ def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMat
935942
throwErr
936943
return s.erase <| .decl declName
937944

938-
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) : MetaM Unit := do
945+
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM Unit := do
939946
if thmKind == .eqLhs then
940-
addGrindEqAttr declName attrKind thmKind (useLhs := true)
947+
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
941948
else if thmKind == .eqRhs then
942-
addGrindEqAttr declName attrKind thmKind (useLhs := false)
949+
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
943950
else if thmKind == .eqBoth then
944-
addGrindEqAttr declName attrKind thmKind (useLhs := true)
945-
addGrindEqAttr declName attrKind thmKind (useLhs := false)
951+
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
952+
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
946953
else
947954
let info ← getConstInfo declName
948955
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
949-
addGrindEqAttr declName attrKind thmKind
956+
addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
950957
else
951-
let thm ← mkEMatchTheoremForDecl declName thmKind
958+
let thm ← mkEMatchTheoremForDecl declName thmKind (showInfo := showInfo)
952959
ematchTheoremsExt.add thm attrKind
953960

954961
def eraseEMatchAttr (declName : Name) : MetaM Unit := do

src/Lean/Meta/Tactic/Grind/Inv.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def checkInvariants (expensive := false) : GoalM Unit := do
123123
for e in (← getExprs) do
124124
let node ← getENode e
125125
checkParents node.self
126-
if isSameExpr node.self node.root then
126+
if node.isRoot then
127127
checkEqc node
128128
if expensive then
129129
checkPtrEqImpliesStructEq

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,9 @@ structure ENode where
343343
-- If the number of satellite solvers increases, we may add support for an arbitrary solvers like done in Z3.
344344
deriving Inhabited, Repr
345345

346+
def ENode.isRoot (n : ENode) :=
347+
isSameExpr n.self n.root
348+
346349
def ENode.isCongrRoot (n : ENode) :=
347350
isSameExpr n.self n.congr
348351

@@ -1250,7 +1253,7 @@ def filterENodes (p : ENode → GoalM Bool) : GoalM (Array ENode) := do
12501253
def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do
12511254
for e in (← getExprs) do
12521255
let n ← getENode e
1253-
if isSameExpr n.self n.root then
1256+
if n.isRoot then
12541257
f n
12551258

12561259
abbrev Propagator := Expr → GoalM Unit
@@ -1302,7 +1305,7 @@ partial def Goal.getEqcs (goal : Goal) : List (List Expr) := Id.run do
13021305
let mut r : List (List Expr) := []
13031306
for e in goal.exprs do
13041307
let some node := goal.getENode? e | pure ()
1305-
if isSameExpr node.root node.self then
1308+
if node.isRoot then
13061309
r := goal.getEqc node.self :: r
13071310
return r
13081311

tests/lean/run/grind_attrs.lean

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,18 @@ set_option trace.grind.ematch.pattern true in
4444
set_option trace.grind.ematch.pattern true in
4545
@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) :=
4646
sorry
47+
48+
49+
namespace Foo
50+
51+
/-- info: Rtrans: [R #4 #3, R #3 #2] -/
52+
#guard_msgs (info) in
53+
@[grind? ->]
54+
axiom Rtrans {x y z : Nat} : R x y → R y z → R x z
55+
56+
/-- info: Rtrans': [R #4 #3, R #3 #2] -/
57+
#guard_msgs (info) in
58+
@[grind? →]
59+
axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z
60+
61+
end Foo

tests/lean/run/grind_countP.lean

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@ attribute [grind] List.countP_nil List.countP_cons
55

66
theorem List.countP_le_countP (hpq : ∀ x ∈ l, P x → Q x) :
77
l.countP P ≤ l.countP Q := by
8-
induction l with
9-
| nil => grind
10-
| cons x xs ih =>
11-
grind
8+
induction l <;> grind
9+
10+
-- TODO: how to explain to the user that `l.countP P ≤ l.countP Q` is a bad pattern
11+
grind_pattern List.countP_le_countP => l.countP P, l.countP Q
1212

1313
theorem List.countP_lt_countP (hpq : ∀ x ∈ l, P x → Q x) (y:α) (hx: y ∈ l) (hxP : P y = false) (hxQ : Q y) :
1414
l.countP P < l.countP Q := by
15-
induction l with
16-
| nil => grind
17-
| cons x xs ih =>
18-
have : xs.countP P ≤ xs.countP Q := countP_le_countP (by grind)
19-
grind
15+
induction l <;> grind
16+
17+
/--
18+
info: List.countP_nil: [@List.countP #1 #0 (@List.nil _)]
19+
---
20+
info: List.countP_cons: [@List.countP #3 #2 (@List.cons _ #1 #0)]
21+
-/
22+
#guard_msgs (info) in
23+
attribute [grind?] List.countP_nil List.countP_cons

tests/lean/run/grind_eq.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,12 @@ trace: [grind.assert] x1 = appV a_2 b
7676
#guard_msgs (trace) in
7777
example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by
7878
grind
79+
80+
81+
/--
82+
info: appV_assoc': [@appV #6 #5 (@HAdd.hAdd `[Nat] `[Nat] `[Nat] `[instHAdd] #4 #3) #2 (@appV _ #4 #3 #1 #0)]
83+
-/
84+
#guard_msgs (info) in
85+
@[grind? =]
86+
theorem appV_assoc' (a : Vector α n) (b : Vector α m) (c : Vector α n') :
87+
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry

tests/lean/run/grind_getLast_dropLast renamed to tests/lean/run/grind_getLast_dropLast.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ theorem length_pos_of_ne_nil {l : List α} (h : l ≠ []) : 0 < l.length := by
88

99
theorem getLast?_dropLast {xs : List α} :
1010
xs.dropLast.getLast? = if xs.length ≤ 1 then none else xs[xs.length - 2]? := by
11-
grind (splits := 9) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
11+
grind (splits := 15) only [List.getElem?_eq_none, List.getElem?_reverse, getLast?_eq_getElem?,
1212
List.head?_eq_getLast?_reverse, getElem?_dropLast, List.getLast?_reverse, List.length_dropLast,
1313
List.length_reverse, length_nil, List.reverse_reverse, head?_nil, List.getElem?_eq_none,
1414
length_pos_of_ne_nil, getLast?_nil, List.head?_reverse, List.getLast?_eq_head?_reverse,

0 commit comments

Comments
 (0)