Skip to content

Commit 16740a1

Browse files
authored
feat: some grind_pattern constraints (#11405)
This PR implements the following `grind_pattern` constraints: ```lean grind_pattern fax => f x where depth x < 2 grind_pattern fax => f x where is_ground x grind_pattern fax => f x where size x < 5 grind_pattern fax => f x where gen < 2 grind_pattern fax => f x where max_insts < 4 grind_pattern gax => g as where as =?= _ :: _ ```
1 parent 799d594 commit 16740a1

File tree

6 files changed

+173
-15
lines changed

6 files changed

+173
-15
lines changed

src/Lean/Elab/Tactic/Grind/Main.lean

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,15 @@ where
100100
else if kind == ``defEq then
101101
elabDefEq xs cnstr[0] cnstr[2]
102102
else if kind == ``genLt then
103-
let (_, lhs) ← findLHS xs cnstr[1]
104-
return .genLt lhs cnstr[3].toNat
103+
return .genLt cnstr[2].toNat
105104
else if kind == ``sizeLt then
106105
let (_, lhs) ← findLHS xs cnstr[1]
107106
return .sizeLt lhs cnstr[3].toNat
108107
else if kind == ``depthLt then
109108
let (_, lhs) ← findLHS xs cnstr[1]
110109
return .depthLt lhs cnstr[3].toNat
111110
else if kind == ``maxInsts then
112-
return .maxInsts cnstr[1].toNat
111+
return .maxInsts cnstr[2].toNat
113112
else if kind == ``isValue then
114113
let (_, lhs) ← findLHS xs cnstr[1]
115114
return .isValue lhs false

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -637,13 +637,17 @@ private abbrev withFreshNGen (x : M α) : M α := do
637637
finally
638638
setNGen ngen
639639

640-
/--
641-
Checks constraints of the form `lhs =/= rhs`.
642-
-/
643-
private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
640+
private def getLHS (args : Array Expr) (lhs : Nat) : MetaM Expr := do
644641
unless lhs < args.size do
645642
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
646-
let lhs := args[args.size - lhs - 1]!
643+
instantiateMVars args[args.size - lhs - 1]!
644+
645+
/--
646+
Checks constraints of the form `lhs =/= rhs` and `lhs =?= rhs`.
647+
`expectedResult` is `true` if `lhs` and `rhs` should be definitionally equal.
648+
-/
649+
private def checkDefEq (expectedResult : Bool) (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
650+
let lhs ← getLHS args lhs
647651
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
648652
let rhsExpr := rhs.expr.instantiateRev args
649653
let rhsExpr := rhsExpr.instantiateLevelParams levelParams us
@@ -657,7 +661,38 @@ private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Ar
657661
let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars)
658662
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
659663
let defEq ← isDefEqGuarded lhs rhsExpr
660-
return !defEq
664+
return defEq == expectedResult
665+
666+
/--
667+
Helper function for checking grind pattern constraints of the form `size e < threshold`
668+
Implicit arguments and type information in lambdas and let-expressions are ignored.
669+
-/
670+
partial def checkSize (e : Expr) (threshold : Nat) : MetaM Bool :=
671+
return (← go e |>.run |>.run 0).1.isSome
672+
where
673+
go (e : Expr) : OptionT (StateT Nat MetaM) Unit := do
674+
guard ((← get) < threshold)
675+
modify (·+1)
676+
match e with
677+
| .forallE _ d b _ => go d; go b
678+
| .lam _ _ b _ => go b
679+
| .letE _ _ v b _ => go v; go b
680+
| .mdata _ e
681+
| .proj _ _ e => go e
682+
| .app .. => e.withApp fun f args => do
683+
if f.hasLooseBVars then
684+
go f; args.forM go
685+
else
686+
let paramInfo := (← getFunInfo f).paramInfo
687+
for h : i in *...args.size do
688+
let arg := args[i]
689+
if h : i < paramInfo.size then
690+
let pinfo := paramInfo[i]
691+
if pinfo.isExplicit && !pinfo.isProp then
692+
go arg
693+
else
694+
go arg
695+
| _ => return ()
661696

662697
/--
663698
Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`.
@@ -672,14 +707,25 @@ In the example above, a `map_map` instance should be added to the logical contex
672707
673708
Remark: `proof` is used to extract the universe parameters in the proof.
674709
-/
675-
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do
710+
private def checkConstraints (thm : EMatchTheorem) (gen : Nat) (proof : Expr) (args : Array Expr) : GoalM Bool := do
676711
if thm.cnstrs.isEmpty then return true
677712
/- **Note**: Only top-level theorems have constraints. -/
678713
let .const declName us := proof | return true
679714
let info ← getConstInfo declName
680715
thm.cnstrs.allM fun cnstr => do
681716
match cnstr with
682-
| .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs
717+
| .notDefEq lhs rhs => checkDefEq (expectedResult := false) info.levelParams us args lhs rhs
718+
| .defEq lhs rhs => checkDefEq (expectedResult := true) info.levelParams us args lhs rhs
719+
| .depthLt lhs n => return (← getLHS args lhs).approxDepth.toNat < n
720+
| .isGround lhs => let lhs ← getLHS args lhs; return !lhs.hasFVar && !lhs.hasMVar
721+
| .sizeLt lhs n => checkSize (← getLHS args lhs) n
722+
| .genLt n => return gen < n
723+
| .maxInsts n =>
724+
/-
725+
**Note**: We are checking the number of instances produced in the whole proof.
726+
It may be useful to bound the number of instances in the current branch.
727+
-/
728+
return (← getEMatchTheoremNumInstances thm) + 1 < n
683729
| _ => throwError "NIY"
684730

685731
/--
@@ -700,7 +746,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
700746
return ()
701747
let (some _, c) ← applyAssignment mvars |>.run c | return ()
702748
let some _ ← synthesizeInsts mvars bis | return ()
703-
if (← checkConstraints thm proof mvars) then
749+
if (← checkConstraints thm c.gen proof mvars) then
704750
let proof := mkAppN proof mvars
705751
if (← mvars.allM (·.mvarId!.isAssigned)) then
706752
addNewInstance thm proof c.gen

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ inductive EMatchTheoremConstraint where
380380
| /--
381381
Instantiates the theorem only if its generation is less than `n`
382382
-/
383-
genLt (lhs : Nat) (n : Nat)
383+
genLt (n : Nat)
384384
| /--
385385
Constraints of the form `is_ground x`. Instantiates the theorem only if
386386
`x` is ground term.

src/Lean/Meta/Tactic/Grind/Parser.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> iden
1616
def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";"
1717
def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";"
1818
def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";"
19-
def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";"
20-
def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";"
19+
def genLt := leading_parser nonReservedSymbol "gen" >> " < " >> numLit >> optional ";"
20+
def maxInsts := leading_parser nonReservedSymbol "max_insts" >> " < " >> numLit >> optional ";"
2121
def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";"
2222
def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";"
2323
def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";"

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ private def incCounter [Hashable α] [BEq α] (s : PHashMap α Nat) (k : α) : P
358358
private def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
359359
modify fun s => { s with counters.thm := incCounter s.counters.thm thm.origin }
360360

361+
def getEMatchTheoremNumInstances (thm : EMatchTheorem) : GrindM Nat := do
362+
return (← get).counters.thm.find? thm.origin |>.getD 0
363+
361364
def saveCases (declName : Name) : GrindM Unit := do
362365
modify fun s => { s with counters.case := incCounter s.counters.case declName }
363366

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
2+
namespace Ex1
3+
opaque f : Nat → Nat
4+
axiom fax : f x ≥ f (f x)
5+
6+
grind_pattern fax => f x where
7+
depth x < 2
8+
9+
/--
10+
trace: [grind.ematch.instance] fax: f a ≥ f (f a)
11+
[grind.ematch.instance] fax: f (f a) ≥ f (f (f a))
12+
-/
13+
#guard_msgs (drop error, trace) in
14+
set_option trace.grind.ematch.instance true in
15+
example (h : f a = 0) : False := by
16+
grind
17+
end Ex1
18+
19+
namespace Ex2
20+
opaque f : Nat → Nat
21+
axiom fax : f x ≥ f (f x)
22+
23+
grind_pattern fax => f x where
24+
is_ground x
25+
depth x < 3
26+
27+
opaque b : Nat
28+
29+
-- Theorems containing `a` should not be instantiate since it is a local variable
30+
/--
31+
trace: [grind.ematch.instance] fax: f b ≥ f (f b)
32+
[grind.ematch.instance] fax: f (f b) ≥ f (f (f b))
33+
[grind.ematch.instance] fax: f (f (f b)) ≥ f (f (f (f b)))
34+
-/
35+
#guard_msgs (drop error, trace) in
36+
set_option trace.grind.ematch.instance true in
37+
example : f a = 0 → f b = 0 → False := by
38+
grind
39+
end Ex2
40+
41+
namespace Ex3
42+
def f {α : Type} : α → α → α := fun x _ => x
43+
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
44+
45+
grind_pattern fax => f x x where
46+
size x < 5
47+
48+
/--
49+
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
50+
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
51+
-/
52+
#guard_msgs (drop error, trace) in
53+
set_option trace.grind.ematch.instance true in
54+
example (a b : List (List Nat)) : f a a = b → False := by
55+
grind
56+
end Ex3
57+
58+
namespace Ex4
59+
def f {α : Type} : α → α → α := fun x _ => x
60+
axiom fax [LE α] (x : α) : f x x ≥ f (f x x) (f x x)
61+
62+
grind_pattern fax => f x x where
63+
gen < 2
64+
65+
/--
66+
trace: [grind.ematch.instance] fax: f a a ≥ f (f a a) (f a a)
67+
[grind.ematch.instance] fax: f (f a a) (f a a) ≥ f (f (f a a) (f a a)) (f (f a a) (f a a))
68+
-/
69+
#guard_msgs (drop error, trace) in
70+
set_option trace.grind.ematch.instance true in
71+
example (a b : List (List Nat)) : f a a = b → False := by
72+
grind
73+
end Ex4
74+
75+
76+
namespace Ex5
77+
opaque f : Nat → Nat
78+
axiom fax : f x ≥ f (f x)
79+
80+
grind_pattern fax => f x where
81+
max_insts < 4
82+
83+
/--
84+
trace: [grind.ematch.instance] fax: f c ≥ f (f c)
85+
[grind.ematch.instance] fax: f b ≥ f (f b)
86+
[grind.ematch.instance] fax: f a ≥ f (f a)
87+
-/
88+
#guard_msgs (drop error, trace) in
89+
set_option trace.grind.ematch.instance true in
90+
example : f a = 0 → f b = 0 → f c = 0 → False := by
91+
grind
92+
93+
end Ex5
94+
95+
namespace Ex6
96+
97+
opaque g : List Nat → Nat
98+
opaque f : List Nat → List Nat
99+
axiom gax (as : List Nat) : g as > g (f as)
100+
101+
grind_pattern gax => g as where
102+
as =?= _ :: _
103+
104+
/-- trace: [grind.ematch.instance] gax: g [1, 2, 3] > g (f [1, 2, 3]) -/
105+
#guard_msgs (drop error, trace) in
106+
set_option trace.grind.ematch.instance true in
107+
example (h : g [1, 2, 3] > 0) : False := by
108+
grind
109+
110+
end Ex6

0 commit comments

Comments
 (0)