Skip to content

Commit a4f9a79

Browse files
authored
feat: new constraints in grind_pattern (#11391)
This PR implements new kinds of constraints for the `grind_pattern` command. These constraints allow users to control theorem instantiation in `grind`. It requires a manual `update-stage0` because the change affects the `.olean` format, and the PR fails without it.
1 parent 490d714 commit a4f9a79

File tree

9 files changed

+12503
-4667
lines changed

9 files changed

+12503
-4667
lines changed

src/Lean/Elab/Tactic/Grind/Main.lean

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ declare_config_elab elabCutsatConfig Grind.CutsatConfig
3030
declare_config_elab elabGrobnerConfig Grind.GrobnerConfig
3131

3232
open Command Term in
33+
open Lean.Parser.Command.GrindCnstr in
3334
@[builtin_command_elab Lean.Parser.Command.grindPattern]
3435
def elabGrindPattern : CommandElab := fun stx => do
3536
match stx with
@@ -38,41 +39,92 @@ def elabGrindPattern : CommandElab := fun stx => do
3839
| `(local grind_pattern $thmName:ident => $terms,* $[$cnstrs?:grindPatternCnstrs]?) => go thmName terms cnstrs? .local
3940
| _ => throwUnsupportedSyntax
4041
where
42+
findLHS (xs : Array Expr) (lhs : Syntax) : TermElabM (LocalDecl × Nat) := do
43+
let lhsId := lhs.getId
44+
let mut i := 0
45+
for x in xs do
46+
let xDecl ← x.fvarId!.getDecl
47+
if xDecl.userName == lhsId then
48+
return (xDecl, xs.size - i - 1)
49+
i := i + 1
50+
throwErrorAt lhs "invalid constraint, `{lhsId}` is not local variable of the theorem"
51+
52+
elabCnstrRHS (xs : Array Expr) (rhs : Syntax) (expectedType : Expr) : TermElabM Grind.CnstrRHS := do
53+
/-
54+
**Note**: We need better sanity checking here.
55+
We must check whether the type of `rhs` is type correct with respect to
56+
an arbitrary instantiation of `xs`. That is, we should use meta-variables
57+
in the check. It is incorrect to use `xDecl.type`. For example, suppose the
58+
type of `xDecl` is `α → β` where `α` and `β` are variables in `xs` occurring before
59+
`xDecl`, and `rhsExpr` is `some : ?m → ?m`. The types `α → β =?= ?m → ?m` are
60+
not definitionally equal, but `?α → ?β =?= ?m → ?m` are.
61+
-/
62+
let rhsExpr ← Term.elabTerm rhs expectedType
63+
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
64+
let rhsExpr ← instantiateMVars rhsExpr
65+
if rhsExpr.hasSyntheticSorry then
66+
throwErrorAt rhs "invalid constraint, rhs contains a synthetic `sorry`"
67+
let rhsExpr := rhsExpr.eta
68+
let { paramNames := levelNames, mvars, expr := rhs } ← abstractMVars rhsExpr
69+
let numMVars := mvars.size
70+
let rhs := rhs.abstract xs
71+
return { levelNames, numMVars, expr := rhs }
72+
73+
elabProp (xs : Array Expr) (term : Syntax) : TermElabM Expr := do
74+
let e ← Term.elabTermAndSynthesize term (Expr.sort 0)
75+
let e ← instantiateMVars e
76+
if e.hasSyntheticSorry then
77+
throwErrorAt term "invalid proposition, it contains a synthetic `sorry`"
78+
if e.hasMVar then
79+
throwErrorAt term "invalid proposition, it contains metavariables{indentExpr e}"
80+
return e.abstract xs
81+
82+
elabNotDefEq (xs : Array Expr) (lhs rhs : Syntax) : TermElabM Grind.EMatchTheoremConstraint := do
83+
let (localDecl, lhsBVarIdx) ← findLHS xs lhs
84+
let rhs ← elabCnstrRHS xs rhs localDecl.type
85+
return .notDefEq lhsBVarIdx rhs
86+
87+
elabDefEq (xs : Array Expr) (lhs rhs : Syntax) : TermElabM Grind.EMatchTheoremConstraint := do
88+
let (localDecl, lhsBVarIdx) ← findLHS xs lhs
89+
let rhs ← elabCnstrRHS xs rhs localDecl.type
90+
return .defEq lhsBVarIdx rhs
91+
4192
elabCnstrs (xs : Array Expr) (cnstrs? : Option (TSyntax ``Parser.Command.grindPatternCnstrs))
4293
: TermElabM (List (Grind.EMatchTheoremConstraint)) := do
4394
let some cnstrs := cnstrs? | return []
4495
let cnstrs := cnstrs.raw[1].getArgs
4596
cnstrs.toList.mapM fun cnstr => do
46-
-- **Note**: Hack because syntax matching is not working. Fix after another update stage0
47-
let lhs := cnstr[0]
48-
let rhs := cnstr[2]
49-
let lhsId := lhs.getId
50-
let mut i := 0
51-
for x in xs do
52-
let xDecl ← x.fvarId!.getDecl
53-
if xDecl.userName == lhsId then
54-
let bvarIdx := xs.size - i - 1
55-
/-
56-
**Note**: We need better sanity checking here.
57-
We must check whether the type of `rhs` is type correct with respect to
58-
an arbitrary instantiation of `xs`. That is, we should use meta-variables
59-
in the check. It is incorrect to use `xDecl.type`. For example, suppose the
60-
type of `xDecl` is `α → β` where `α` and `β` are variables in `xs` occurring before
61-
`xDecl`, and `rhsExpr` is `some : ?m → ?m`. The types `α → β =?= ?m → ?m` are
62-
not definitionally equal, but `?α → ?β =?= ?m → ?m` are.
63-
-/
64-
let rhsExpr ← Term.elabTerm rhs xDecl.type
65-
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
66-
let rhsExpr ← instantiateMVars rhsExpr
67-
if rhsExpr.hasSyntheticSorry then
68-
throwErrorAt rhs "invalid constraint, rhs contains a synthetic `sorry`"
69-
let rhsExpr := rhsExpr.eta
70-
let { paramNames := levelNames, mvars, expr := rhs } ← abstractMVars rhsExpr
71-
let numMVars := mvars.size
72-
let rhs := rhs.abstract xs
73-
return { bvarIdx, levelNames, numMVars, rhs }
74-
i := i + 1
75-
throwErrorAt lhs "invalid constraint, `{lhsId}` is not local variable of the theorem"
97+
let kind := cnstr.getKind
98+
if kind == ``notDefEq then
99+
elabNotDefEq xs cnstr[0] cnstr[2]
100+
else if kind == ``defEq then
101+
elabDefEq xs cnstr[0] cnstr[2]
102+
else if kind == ``genLt then
103+
let (_, lhs) ← findLHS xs cnstr[1]
104+
return .genLt lhs cnstr[3].toNat
105+
else if kind == ``sizeLt then
106+
let (_, lhs) ← findLHS xs cnstr[1]
107+
return .sizeLt lhs cnstr[3].toNat
108+
else if kind == ``depthLt then
109+
let (_, lhs) ← findLHS xs cnstr[1]
110+
return .depthLt lhs cnstr[3].toNat
111+
else if kind == ``maxInsts then
112+
return .maxInsts cnstr[1].toNat
113+
else if kind == ``isValue then
114+
let (_, lhs) ← findLHS xs cnstr[1]
115+
return .isValue lhs false
116+
else if kind == ``isStrictValue then
117+
let (_, lhs) ← findLHS xs cnstr[1]
118+
return .isValue lhs true
119+
else if kind == ``isGround then
120+
let (_, lhs) ← findLHS xs cnstr[1]
121+
return .isGround lhs
122+
else if kind == ``Parser.Command.GrindCnstr.check then
123+
return .check (← elabProp xs cnstr[1])
124+
else if kind == ``Parser.Command.GrindCnstr.guard then
125+
return .guard (← elabProp xs cnstr[1])
126+
else
127+
throwErrorAt cnstr "unexpected constraint"
76128

77129
go (thmName : TSyntax `ident) (terms : Syntax.TSepArray `term ",")
78130
(cnstrs? : Option (TSyntax ``Parser.Command.grindPatternCnstrs))

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,28 @@ private abbrev withFreshNGen (x : M α) : M α := do
637637
finally
638638
setNGen ngen
639639

640+
/--
641+
Checks constraints of the form `lhs =/= rhs`.
642+
-/
643+
private def checkNotDefEq (levelParams : List Name) (us : List Level) (args : Array Expr) (lhs : Nat) (rhs : CnstrRHS) : GoalM Bool := do
644+
unless lhs < args.size do
645+
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
646+
let lhs := args[args.size - lhs - 1]!
647+
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
648+
let rhsExpr := rhs.expr.instantiateRev args
649+
let rhsExpr := rhsExpr.instantiateLevelParams levelParams us
650+
withNewMCtxDepth do
651+
/-
652+
**Note**: Recall that we have abstracted metavariables occurring in `rhs` after we elaborated it.
653+
So, we must "recreate" them.
654+
-/
655+
let us ← rhs.levelNames.mapM fun _ => mkFreshLevelMVar
656+
let rhsExpr := rhsExpr.instantiateLevelParamsArray rhs.levelNames us
657+
let (_, _, rhsExpr) ← lambdaMetaTelescope rhsExpr (some rhs.numMVars)
658+
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
659+
let defEq ← isDefEqGuarded lhs rhsExpr
660+
return !defEq
661+
640662
/--
641663
Checks whether `vars` satisfies the `grind_pattern` constraints attached at `thm`.
642664
Example:
@@ -650,29 +672,15 @@ In the example above, a `map_map` instance should be added to the logical contex
650672
651673
Remark: `proof` is used to extract the universe parameters in the proof.
652674
-/
653-
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : MetaM Bool := do
675+
private def checkConstraints (thm : EMatchTheorem) (proof : Expr) (args : Array Expr) : GoalM Bool := do
654676
if thm.cnstrs.isEmpty then return true
655677
/- **Note**: Only top-level theorems have constraints. -/
656678
let .const declName us := proof | return true
657679
let info ← getConstInfo declName
658680
thm.cnstrs.allM fun cnstr => do
659-
unless cnstr.bvarIdx < args.size do
660-
throwError "`grind` internal error, invalid variable in `grind_pattern` constraint"
661-
let lhs := args[args.size - cnstr.bvarIdx - 1]!
662-
/- **Note**: We first instantiate the theorem variables and universes occurring in `rhs`. -/
663-
let rhs := cnstr.rhs.instantiateRev args
664-
let rhs := rhs.instantiateLevelParams info.levelParams us
665-
withNewMCtxDepth do
666-
/-
667-
**Note**: Recall that we have abstracted metavariables occurring in `rhs` after we elaborated it.
668-
So, we must "recreate" them.
669-
-/
670-
let us ← cnstr.levelNames.mapM fun _ => mkFreshLevelMVar
671-
let rhs := rhs.instantiateLevelParamsArray cnstr.levelNames us
672-
let (_, _, rhs) ← lambdaMetaTelescope rhs (some cnstr.numMVars)
673-
/- **Note**: We used the guarded version to ensure type errors will not interrupt `grind`. -/
674-
let defEq ← isDefEqGuarded lhs rhs
675-
return !defEq
681+
match cnstr with
682+
| .notDefEq lhs rhs => checkNotDefEq info.levelParams us args lhs rhs
683+
| _ => throwError "NIY"
676684

677685
/--
678686
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -344,20 +344,68 @@ private def EMatchTheoremKind.explainFailure : EMatchTheoremKind → String
344344
| .default _ => "failed to find patterns"
345345
| .user => unreachable!
346346

347-
/--
348-
Grind patterns may have constraints of the form `lhs =/= rhs` associated with them.
349-
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
350-
equal to a term `t` assigned to `lhs`.
351-
-/
352-
structure EMatchTheoremConstraint where
353-
/-- `lhs` -/
354-
bvarIdx : Nat
347+
structure CnstrRHS where
355348
/-- Abstracted universe level param names in the `rhs` -/
356349
levelNames : Array Name
357350
/-- Number of abstracted metavariable in the `rhs` -/
358351
numMVars : Nat
359352
/-- The actual `rhs`. -/
360-
rhs : Expr
353+
expr : Expr
354+
deriving Inhabited, BEq, Repr
355+
356+
/--
357+
Grind patterns may have constraints associated with them.
358+
-/
359+
inductive EMatchTheoremConstraint where
360+
| /--
361+
A constraint of the form `lhs =/= rhs`.
362+
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must not be definitionally
363+
equal to a term `t` assigned to `lhs`. -/
364+
notDefEq (lhs : Nat) (rhs : CnstrRHS)
365+
| /--
366+
A constraint of the form `lhs =?= rhs`.
367+
The `lhs` is one of the bound variables, and the `rhs` an abstract term that must be definitionally
368+
equal to a term `t` assigned to `lhs`. -/
369+
defEq (lhs : Nat) (rhs : CnstrRHS)
370+
| /--
371+
A constraint of the form `size lhs < n`. The `lhs` is one of the bound variables.
372+
The size is computed ignoring implicit terms, but sharing is not taken into account.
373+
-/
374+
sizeLt (lhs : Nat) (n : Nat)
375+
| /--
376+
A constraint of the form `depth lhs < n`. The `lhs` is one of the bound variables.
377+
The depth is computed in constant time using the `approxDepth` field attached to expressions.
378+
-/
379+
depthLt (lhs : Nat) (n : Nat)
380+
| /--
381+
Instantiates the theorem only if its generation is less than `n`
382+
-/
383+
genLt (lhs : Nat) (n : Nat)
384+
| /--
385+
Constraints of the form `is_ground x`. Instantiates the theorem only if
386+
`x` is ground term.
387+
-/
388+
isGround (bvarIdx : Nat)
389+
| /--
390+
Constraints of the form `is_value x` and `is_strict_value x`.
391+
A value is defined as
392+
- A constructor fully applied to value arguments.
393+
- A literal: numerals, strings, etc.
394+
- A lambda. In the strict case, lambdas are not considered.
395+
-/
396+
isValue (bvarIdx : Nat) (strict : Bool)
397+
| /--
398+
Instantiates the theorem only if less than `n` instances have been generated for this theorem.
399+
-/
400+
maxInsts (n : Nat)
401+
| /--
402+
It instructs `grind` to postpone the instantiation of the theorem until `e` is known to be `true`.
403+
-/
404+
guard (e : Expr)
405+
| /--
406+
Similar to `guard`, but checks whether `e` is implied by asserting `¬e`.
407+
-/
408+
check (e : Expr)
361409
deriving Inhabited, Repr, BEq
362410

363411
/-- A theorem for heuristic instantiation based on E-matching. -/

src/Lean/Meta/Tactic/Grind/Parser.lean

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,26 @@ public import Lean.Parser.Command
99
public section
1010
namespace Lean.Parser.Command
1111

12-
def grindPatternCnstr : Parser := leading_parser ident >> " =/= " >> checkColGe "irrelevant" >> termParser >> optional ";"
12+
namespace GrindCnstr
13+
14+
def isValue := leading_parser nonReservedSymbol "is_value " >> ident >> optional ";"
15+
def isStrictValue := leading_parser nonReservedSymbol "is_strict_value " >> ident >> optional ";"
16+
def isGround := leading_parser nonReservedSymbol "is_ground " >> ident >> optional ";"
17+
def sizeLt := leading_parser nonReservedSymbol "size " >> ident >> " < " >> numLit >> optional ";"
18+
def depthLt := leading_parser nonReservedSymbol "depth " >> ident >> " < " >> numLit >> optional ";"
19+
def genLt := leading_parser nonReservedSymbol "gen " >> ident >> " < " >> numLit >> optional ";"
20+
def maxInsts := leading_parser nonReservedSymbol "max_insts " >> numLit >> optional ";"
21+
def guard := leading_parser nonReservedSymbol "guard " >> checkColGe "irrelevant" >> termParser >> optional ";"
22+
def check := leading_parser nonReservedSymbol "check " >> checkColGe "irrelevant" >> termParser >> optional ";"
23+
def notDefEq := leading_parser atomic (ident >> " =/= ") >> checkColGe "irrelevant" >> termParser >> optional ";"
24+
def defEq := leading_parser atomic (ident >> " =?= ") >> checkColGe "irrelevant" >> termParser >> optional ";"
25+
26+
end GrindCnstr
27+
28+
open GrindCnstr in
29+
def grindPatternCnstr : Parser :=
30+
isValue <|> isStrictValue <|> isGround <|> sizeLt <|> depthLt <|> genLt <|> maxInsts
31+
<|> guard <|> check <|> notDefEq <|> defEq
1332

1433
def grindPatternCnstrs : Parser := leading_parser "where " >> many1Indent (ppLine >> grindPatternCnstr)
1534

0 commit comments

Comments
 (0)