Skip to content

Commit 63f2735

Browse files
committed
feat: grind? attribute
1 parent 897cd85 commit 63f2735

File tree

6 files changed

+88
-33
lines changed

6 files changed

+88
-33
lines changed

src/Init/Grind/Tactics.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ syntax grindIntro := &"intro "
3030
syntax grindExt := &"ext "
3131
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindRL <|> grindLR <|> grindUsr <|> grindCasesEager <|> grindCases <|> grindIntro <|> grindExt
3232
syntax (name := grind) "grind" (grindMod)? : attr
33+
syntax (name := grind?) "grind?" (grindMod)? : attr
3334
end Attr
3435
end Lean.Parser
3536

src/Lean/Meta/Tactic/Grind/Attr.lean

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,20 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
4949
def throwInvalidUsrModifier : CoreM α :=
5050
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
5151

52-
builtin_initialize
52+
/--
53+
Auxiliary function for registering `grind` and `grind?` attributes.
54+
The `grind?` is an alias for `grind` which displays patterns using `logInfo`.
55+
It is just a convenience for users.
56+
-/
57+
private def registerGrindAttr (showInfo : Bool) : IO Unit :=
5358
registerBuiltinAttribute {
54-
name := `grind
59+
name := if showInfo then `grind? else `grind
5560
descr :=
56-
"The `[grind]` attribute is used to annotate declarations.\
61+
let header := if showInfo then
62+
"The `[grind?]` attribute is identical to the `[grind]` attribute, but displays inferred pattern information."
63+
else
64+
"The `[grind]` attribute is used to annotate declarations."
65+
header ++ "\
5766
\
5867
When applied to an equational theorem, `[grind =]`, `[grind =_]`, or `[grind _=_]`\
5968
will mark the theorem for use in heuristic instantiations by the `grind` tactic,
@@ -73,12 +82,12 @@ builtin_initialize
7382
add := fun declName stx attrKind => MetaM.run' do
7483
match (← getAttrKindFromOpt stx) with
7584
| .ematch .user => throwInvalidUsrModifier
76-
| .ematch k => addEMatchAttr declName attrKind k
85+
| .ematch k => addEMatchAttr declName attrKind k (showInfo := showInfo)
7786
| .cases eager => addCasesAttr declName eager attrKind
7887
| .intro =>
7988
if let some info ← isCasesAttrPredicateCandidate? declName false then
8089
for ctor in info.ctors do
81-
addEMatchAttr ctor attrKind .default
90+
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
8291
else
8392
throwError "invalid `[grind intro]`, `{declName}` is not an inductive predicate"
8493
| .ext => addExtAttr declName attrKind
@@ -89,10 +98,12 @@ builtin_initialize
8998
-- If it is an inductive predicate,
9099
-- we also add the constructors (intro rules) as E-matching rules
91100
for ctor in info.ctors do
92-
addEMatchAttr ctor attrKind .default
101+
addEMatchAttr ctor attrKind .default (showInfo := showInfo)
93102
else
94-
addEMatchAttr declName attrKind .default
103+
addEMatchAttr declName attrKind .default (showInfo := showInfo)
95104
erase := fun declName => MetaM.run' do
105+
if showInfo then
106+
throwError "`[grind?]` is a helper attribute for displaying inferred patterns, if you want to remove the attribute, consider using `[grind]` instead"
96107
if (← isCasesAttrCandidate declName false) then
97108
eraseCasesAttr declName
98109
else if (← isExtTheorem declName) then
@@ -101,4 +112,8 @@ builtin_initialize
101112
eraseEMatchAttr declName
102113
}
103114

115+
builtin_initialize
116+
registerGrindAttr true
117+
registerGrindAttr false
118+
104119
end Lean.Meta.Grind

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -589,18 +589,24 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
589589
msg := msg ++ m!"{x} : {← inferType x}"
590590
addMessageContextFull msg
591591

592+
private def logPatternWhen (showInfo : Bool) (origin : Origin) (patterns : List Expr) : MetaM Unit := do
593+
if showInfo then
594+
logInfo m!"{← origin.pp}: {patterns.map ppPattern}"
595+
592596
/--
593597
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
594598
Pattern variables are represented using de Bruijn indices.
595599
-/
596-
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : EMatchTheoremKind) : MetaM EMatchTheorem := do
600+
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr)
601+
(patterns : List Expr) (kind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
597602
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
598603
if symbols.isEmpty then
599604
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
600-
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
605+
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
601606
if let .missing pos ← checkCoverage proof numParams bvarFound then
602607
let pats : MessageData := m!"{patterns.map ppPattern}"
603608
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
609+
logPatternWhen showInfo origin patterns
604610
return {
605611
proof, patterns, numParams, symbols
606612
levelParams, origin, kind
@@ -627,7 +633,7 @@ Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs
627633
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
628634
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the pattern.
629635
-/
630-
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) : MetaM EMatchTheorem := do
636+
def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (normalizePattern : Bool) (useLhs : Bool) (showInfo := false) : MetaM EMatchTheorem := do
631637
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
632638
let (lhs, rhs) ← match_expr type with
633639
| Eq _ lhs rhs => pure (lhs, rhs)
@@ -640,15 +646,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
640646
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat normConfig}"
641647
let pats := splitWhileForbidden (pat.abstract xs)
642648
return (xs.size, pats)
643-
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
649+
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) (showInfo := showInfo)
644650

645-
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
651+
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) (showInfo := false) : MetaM EMatchTheorem := do
646652
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
647653
let_expr f@Eq α lhs rhs := type
648654
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
649655
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
650656
return (xs.size, [pat.abstract xs])
651-
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
657+
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd (showInfo := showInfo)
652658

653659
/--
654660
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
@@ -657,8 +663,8 @@ creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
657663
If `normalizePattern` is true, it applies the `grind` simplification theorems and simprocs to the
658664
pattern.
659665
-/
660-
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) : MetaM EMatchTheorem := do
661-
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs
666+
def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Bool := true) (showInfo := false) : MetaM EMatchTheorem := do
667+
mkEMatchEqTheoremCore (.decl declName) #[] (← getProofFor declName) normalizePattern useLhs (showInfo := showInfo)
662668

663669
/--
664670
Adds an E-matching theorem to the environment.
@@ -844,13 +850,13 @@ since the theorem is already in the `grind` state and there is nothing to be ins
844850
-/
845851
def mkEMatchTheoremWithKind?
846852
(origin : Origin) (levelParams : Array Name) (proof : Expr) (kind : EMatchTheoremKind)
847-
(groundPatterns := true) : MetaM (Option EMatchTheorem) := do
853+
(groundPatterns := true) (showInfo := false) : MetaM (Option EMatchTheorem) := do
848854
if kind == .eqLhs then
849-
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
855+
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true) (showInfo := showInfo))
850856
else if kind == .eqRhs then
851-
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
857+
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false) (showInfo := showInfo))
852858
else if kind == .eqBwd then
853-
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
859+
return (← mkEMatchEqBwdTheoremCore origin levelParams proof (showInfo := showInfo))
854860
let type ← inferType proof
855861
/-
856862
Remark: we should not use `forallTelescopeReducing` (with default reducibility) here
@@ -894,25 +900,26 @@ where
894900
return none
895901
let numParams := xs.size
896902
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
903+
logPatternWhen showInfo origin patterns
897904
return some {
898905
proof, patterns, numParams, symbols
899906
levelParams, origin, kind
900907
}
901908

902-
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) : MetaM EMatchTheorem := do
903-
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
909+
def mkEMatchTheoremForDecl (declName : Name) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM EMatchTheorem := do
910+
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind (showInfo := showInfo)
904911
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
905912
return thm
906913

907-
def mkEMatchEqTheoremsForDef? (declName : Name) : MetaM (Option (Array EMatchTheorem)) := do
914+
def mkEMatchEqTheoremsForDef? (declName : Name) (showInfo := false) : MetaM (Option (Array EMatchTheorem)) := do
908915
let some eqns ← getEqnsFor? declName | return none
909916
eqns.mapM fun eqn => do
910-
mkEMatchEqTheorem eqn (normalizePattern := true)
917+
mkEMatchEqTheorem eqn (normalizePattern := true) (showInfo := showInfo)
911918

912-
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) : MetaM Unit := do
919+
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (useLhs := true) (showInfo := false) : MetaM Unit := do
913920
if wasOriginallyTheorem (← getEnv) declName then
914-
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs)) attrKind
915-
else if let some thms ← mkEMatchEqTheoremsForDef? declName then
921+
ematchTheoremsExt.add (← mkEMatchEqTheorem declName (normalizePattern := true) (useLhs := useLhs) (showInfo := showInfo)) attrKind
922+
else if let some thms ← mkEMatchEqTheoremsForDef? declName (showInfo := showInfo) then
916923
unless useLhs do
917924
throwError "`{declName}` is a definition, you must only use the left-hand side for extracting patterns"
918925
thms.forM (ematchTheoremsExt.add · attrKind)
@@ -935,20 +942,20 @@ def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMat
935942
throwErr
936943
return s.erase <| .decl declName
937944

938-
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) : MetaM Unit := do
945+
def addEMatchAttr (declName : Name) (attrKind : AttributeKind) (thmKind : EMatchTheoremKind) (showInfo := false) : MetaM Unit := do
939946
if thmKind == .eqLhs then
940-
addGrindEqAttr declName attrKind thmKind (useLhs := true)
947+
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
941948
else if thmKind == .eqRhs then
942-
addGrindEqAttr declName attrKind thmKind (useLhs := false)
949+
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
943950
else if thmKind == .eqBoth then
944-
addGrindEqAttr declName attrKind thmKind (useLhs := true)
945-
addGrindEqAttr declName attrKind thmKind (useLhs := false)
951+
addGrindEqAttr declName attrKind thmKind (useLhs := true) (showInfo := showInfo)
952+
addGrindEqAttr declName attrKind thmKind (useLhs := false) (showInfo := showInfo)
946953
else
947954
let info ← getConstInfo declName
948955
if !wasOriginallyTheorem (← getEnv) declName && !info.isCtor && !info.isAxiom then
949-
addGrindEqAttr declName attrKind thmKind
956+
addGrindEqAttr declName attrKind thmKind (showInfo := showInfo)
950957
else
951-
let thm ← mkEMatchTheoremForDecl declName thmKind
958+
let thm ← mkEMatchTheoremForDecl declName thmKind (showInfo := showInfo)
952959
ematchTheoremsExt.add thm attrKind
953960

954961
def eraseEMatchAttr (declName : Name) : MetaM Unit := do

tests/lean/run/grind_attrs.lean

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,18 @@ set_option trace.grind.ematch.pattern true in
4444
set_option trace.grind.ematch.pattern true in
4545
@[grind =>] theorem State.update_le_update (h : State.le σ' σ) : State.le (σ'.update x v) (σ.update x v) :=
4646
sorry
47+
48+
49+
namespace Foo
50+
51+
/-- info: Rtrans: [R #4 #3, R #3 #2] -/
52+
#guard_msgs (info) in
53+
@[grind? ->]
54+
axiom Rtrans {x y z : Nat} : R x y → R y z → R x z
55+
56+
/-- info: Rtrans': [R #4 #3, R #3 #2] -/
57+
#guard_msgs (info) in
58+
@[grind? →]
59+
axiom Rtrans' {x y z : Nat} : R x y → R y z → R x z
60+
61+
end Foo

tests/lean/run/grind_countP.lean

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,11 @@ grind_pattern List.countP_le_countP => l.countP P, l.countP Q
1313
theorem List.countP_lt_countP (hpq : ∀ x ∈ l, P x → Q x) (y:α) (hx: y ∈ l) (hxP : P y = false) (hxQ : Q y) :
1414
l.countP P < l.countP Q := by
1515
induction l <;> grind
16+
17+
/--
18+
info: List.countP_nil: [@List.countP #1 #0 (@List.nil _)]
19+
---
20+
info: List.countP_cons: [@List.countP #3 #2 (@List.cons _ #1 #0)]
21+
-/
22+
#guard_msgs (info) in
23+
attribute [grind?] List.countP_nil List.countP_cons

tests/lean/run/grind_eq.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,12 @@ trace: [grind.assert] x1 = appV a_2 b
7676
#guard_msgs (trace) in
7777
example : x1 = appV a b → x2 = appV x1 c → x3 = appV b c → x4 = appV a x3 → HEq x2 x4 := by
7878
grind
79+
80+
81+
/--
82+
info: appV_assoc': [@appV #6 #5 (@HAdd.hAdd `[Nat] `[Nat] `[Nat] `[instHAdd] #4 #3) #2 (@appV _ #4 #3 #1 #0)]
83+
-/
84+
#guard_msgs (info) in
85+
@[grind? =]
86+
theorem appV_assoc' (a : Vector α n) (b : Vector α m) (c : Vector α n') :
87+
HEq (appV a (appV b c)) (appV (appV a b) c) := sorry

0 commit comments

Comments
 (0)