Skip to content

Commit 6d665f3

Browse files
authored
fix: bugs at grind => finish? (#10936)
This PR fixes issues in `grind => finish?` that were preventing generated `grind` tactic scripts from being successfully replayed.
1 parent 74fd468 commit 6d665f3

File tree

5 files changed

+117
-4
lines changed

5 files changed

+117
-4
lines changed

src/Lean/Elab/Tactic/Grind/Basic.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def GrindTacticM.runAtGoal (mvarId : MVarId) (params : Params) (k : GrindTacticM
379379
-- **Note**: We use `withCheapCasesOnly` to ensure multiple goals are not created.
380380
-- We will add support for this case in the future.
381381
let (goal, _) ← withCheapCasesOnly <| SearchM.run goal do
382-
intros 0
382+
intros 0; discard <| assertAll
383383
getGoal
384384
let goals := if goal.inconsistent then [] else [goal]
385385
let ctx ← readThe Meta.Grind.Context

src/Lean/Elab/Tactic/Grind/BuiltinTactic.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def logAnchor (e : Expr) : TermElabM Unit := do
281281
goal.withContext <| withRef anchor <| logAnchor e
282282
let goals ← goals.filterMapM fun goal => do
283283
let (goal, _) ← liftGrindM <| SearchM.run goal do
284-
intros genNew
284+
intros genNew; discard <| assertAll
285285
getGoal
286286
if goal.inconsistent then
287287
return none

src/Lean/Meta/Tactic/Grind/EMatchAction.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def mkInstantiateTactic (goal : Goal) (usedThms : Array EMatchTheorem) (approx :
108108
| false, true => `(grind| instantiate only approx [$params,*])
109109

110110
def mkNewSeq (goal : Goal) (thms : Array EMatchTheorem) (seq : List TGrind) (approx : Bool) : GrindM (List TGrind) := do
111-
if thms.isEmpty then
111+
if thms.isEmpty && !approx then
112112
return seq
113113
else
114114
return ((← mkInstantiateTactic goal thms approx) :: seq)

src/Lean/Meta/Tactic/Grind/Split.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def splitNext (stopAtFirstFailure := true) (compress := true) : Action := fun go
444444
| kna goal
445445
let cExpr := c.getExpr
446446
let gen := goal.getGeneration cExpr
447-
let x : Action := splitCore c numCases isRec stopAtFirstFailure compress >> intros gen
447+
let x : Action := splitCore c numCases isRec stopAtFirstFailure compress >> intros gen >> assertAll
448448
x goal kna kp
449449

450450
end Action
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
module
2+
public import Std.Data.HashMap
3+
public import Std.Data.TreeMap
4+
5+
inductive IfExpr
6+
| lit : Bool → IfExpr
7+
| var : Nat → IfExpr
8+
| ite : IfExpr → IfExpr → IfExpr → IfExpr
9+
deriving DecidableEq
10+
11+
namespace IfExpr
12+
13+
def hasNestedIf : IfExpr → Bool
14+
| lit _ => false
15+
| var _ => false
16+
| ite (ite _ _ _) _ _ => true
17+
| ite _ t e => t.hasNestedIf || e.hasNestedIf
18+
19+
def hasConstantIf : IfExpr → Bool
20+
| lit _ => false
21+
| var _ => false
22+
| ite (lit _) _ _ => true
23+
| ite i t e => i.hasConstantIf || t.hasConstantIf || e.hasConstantIf
24+
25+
def hasRedundantIf : IfExpr → Bool
26+
| lit _ => false
27+
| var _ => false
28+
| ite i t e => t == e || i.hasRedundantIf || t.hasRedundantIf || e.hasRedundantIf
29+
30+
def vars : IfExpr → List Nat
31+
| lit _ => []
32+
| var i => [i]
33+
| ite i t e => i.vars ++ t.vars ++ e.vars
34+
35+
def _root_.List.disjoint {α} [DecidableEq α] : List α → List α → Bool
36+
| [], _ => true
37+
| x::xs, ys => x ∉ ys && xs.disjoint ys
38+
39+
def disjoint : IfExpr → Bool
40+
| lit _ => true
41+
| var _ => true
42+
| ite i t e =>
43+
i.vars.disjoint t.vars && i.vars.disjoint e.vars && i.disjoint && t.disjoint && e.disjoint
44+
45+
def normalized (e : IfExpr) : Bool :=
46+
!e.hasNestedIf && !e.hasConstantIf && !e.hasRedundantIf && e.disjoint
47+
48+
def eval (f : Nat → Bool) : IfExpr → Bool
49+
| lit b => b
50+
| var i => f i
51+
| ite i t e => bif i.eval f then t.eval f else e.eval f
52+
53+
end IfExpr
54+
55+
def IfNormalization : Type := { Z : IfExpr → IfExpr // ∀ e, (Z e).normalized ∧ (Z e).eval = e.eval }
56+
57+
namespace IfExpr
58+
59+
@[simp] def normSize : IfExpr → Nat
60+
| lit _ => 0
61+
| var _ => 1
62+
| .ite i t e => 2 * normSize i + max (normSize t) (normSize e) + 1
63+
64+
def normalize (assign : Std.HashMap Nat Bool) : IfExpr → IfExpr
65+
| lit b => lit b
66+
| var v =>
67+
match assign[v]? with
68+
| none => var v
69+
| some b => lit b
70+
| ite (lit true) t _ => normalize assign t
71+
| ite (lit false) _ e => normalize assign e
72+
| ite (ite a b c) t e => normalize assign (ite a (ite b t e) (ite c t e))
73+
| ite (var v) t e =>
74+
match assign[v]? with
75+
| none =>
76+
let t' := normalize (assign.insert v true) t
77+
let e' := normalize (assign.insert v false) e
78+
if t' = e' then t' else ite (var v) t' e'
79+
| some b => normalize assign (ite (lit b) t e)
80+
termination_by e => e.normSize
81+
82+
-- We tell `grind` to unfold our definitions above.
83+
attribute [local grind] normalized hasNestedIf hasConstantIf hasRedundantIf disjoint vars eval List.disjoint
84+
85+
theorem normalize_spec (assign : Std.HashMap Nat Bool) (e : IfExpr) :
86+
(normalize assign e).normalized
87+
∧ (∀ f, (normalize assign e).eval f = e.eval fun w => assign[w]?.getD (f w))
88+
∧ ∀ (v : Nat), v ∈ vars (normalize assign e) → ¬ v ∈ assign := by
89+
fun_induction normalize
90+
next => grind => finish?
91+
next => grind => finish?
92+
next => grind => finish?
93+
next => grind => finish?
94+
next => grind => finish?
95+
next => grind => finish?
96+
next => grind => finish -- TODO: ensure `finish?` works here
97+
next => grind => finish -- TODO: ensure `finish?` works here
98+
next => grind => finish?
99+
100+
example (assign : Std.HashMap Nat Bool) (e : IfExpr) :
101+
(normalize assign e).normalized
102+
∧ (∀ f, (normalize assign e).eval f = e.eval fun w => assign[w]?.getD (f w))
103+
∧ ∀ (v : Nat), v ∈ vars (normalize assign e) → assign.contains v = false := by
104+
fun_induction normalize
105+
next => grind => finish?
106+
next => grind => finish?
107+
next => grind => finish?
108+
next => grind => finish?
109+
next => grind => finish?
110+
next => grind => finish?
111+
next => grind => finish -- TODO: ensure `finish?` works here
112+
next => grind => finish -- TODO: ensure `finish?` works here
113+
next => grind => finish?

0 commit comments

Comments
 (0)