Skip to content

Commit 7087c4a

Browse files
authored
feat: add splitNext grind action (#10801)
This PR implements the `splitNext` action for `grind`.
1 parent b7ea66d commit 7087c4a

File tree

10 files changed

+225
-181
lines changed

10 files changed

+225
-181
lines changed

src/Lean/Meta/Tactic/Grind/Action.lean

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -151,33 +151,6 @@ Executes `x`, but behaves like a `skip` if it is not applicable.
151151
def skipIfNA (x : Action) : Action := fun goal _ kp =>
152152
x goal kp kp
153153

154-
private def isTargetFalse (mvarId : MVarId) : MetaM Bool := do
155-
return (← mvarId.getType).isFalse
156-
157-
private def getFalseProof? (mvarId : MVarId) : MetaM (Option Expr) := mvarId.withContext do
158-
let proof ← instantiateMVars (mkMVar mvarId)
159-
if (← isTargetFalse mvarId) then
160-
return some proof
161-
else if proof.isAppOfArity ``False.elim 2 || proof.isAppOfArity ``False.casesOn 2 then
162-
return some proof.appArg!
163-
else
164-
return none
165-
166-
/--
167-
Returns the maximum free variable id occurring in `e`
168-
-/
169-
private def findMaxFVarIdx? (e : Expr) : MetaM (Option Nat) := do
170-
let go (e : Expr) : StateT (Option Nat) MetaM Bool := do
171-
unless e.hasFVar do return false
172-
let .fvar fvarId := e | return true
173-
let localDecl ← fvarId.getDecl
174-
modify fun
175-
| none => some localDecl.index
176-
| some index => some (max index localDecl.index)
177-
return false
178-
let (_, s?) ← e.forEach' go |>.run none
179-
return s?
180-
181154
private def mkGrindSeq (s : List (TSyntax `grind)) : TSyntax ``Parser.Tactic.Grind.grindSeq :=
182155
let s := s.map (·.raw)
183156
let s := s.intersperse (mkNullNode #[])
@@ -236,73 +209,6 @@ def ungroup : Action := fun goal _ kp => do
236209
else
237210
return r
238211

239-
/--
240-
Returns `some falseProof` if we can use non-chronological backtracking with `subgoal`.
241-
That is, `subgoal` was closed using `falseProof`, but its proof does not use any of the
242-
new hypotheses. A hypothesis is new if its `index >= oldNumIndices`.
243-
-/
244-
def useNCB? (oldNumIndices : Nat) (subgoal : Goal) : MetaM (Option Expr) := do
245-
let some falseProof ← getFalseProof? subgoal.mvarId
246-
| return none
247-
let some max ← subgoal.mvarId.withContext <| findMaxFVarIdx? falseProof
248-
| return some falseProof -- Proof actually closes any pending split
249-
if max < oldNumIndices then
250-
return some falseProof
251-
else
252-
return none
253-
254-
/--
255-
Helper function for implementing tactics that perform case-splits
256-
**Note**: We will probably delete this function.
257-
-/
258-
def splitCore
259-
(goal : Goal)
260-
(anchor? : Option (Nat × UInt64))
261-
(s : MVarId → GrindM (List MVarId))
262-
(kp : ActionCont) : GrindM ActionResult := do
263-
let mvarDecl ← goal.mvarId.getDecl
264-
let numIndices := mvarDecl.lctx.numIndices
265-
let mvarId ← goal.mkAuxMVar
266-
let mvarIds ← s mvarId
267-
let subgoals := mvarIds.map fun mvarId => { goal with mvarId }
268-
let traceEnabled := (← getConfig).trace
269-
let mut seqNew : Array (List (TSyntax `grind)) := #[]
270-
let mut stuckNew : Array Goal := #[]
271-
for subgoal in subgoals do
272-
match (← kp subgoal) with
273-
| .stuck gs =>
274-
-- *TODO*: Add support for saving multiple failures
275-
return .stuck gs
276-
| .closed seq =>
277-
if let some falseProof ← useNCB? numIndices subgoal then
278-
goal.mvarId.assignFalseProof falseProof
279-
return .closed seq
280-
else
281-
seqNew := seqNew.push seq
282-
if stuckNew.isEmpty then
283-
goal.mvarId.assign (← instantiateMVars (mkMVar mvarId))
284-
if traceEnabled then
285-
let seqListNew ← if h : seqNew.size = 1 then
286-
pure seqNew[0]
287-
else
288-
seqNew.toList.mapM fun s => mkGrindNext s
289-
let mut seqListNew := seqListNew
290-
if let some anchor := anchor? then
291-
let hexnum := mkNode `hexnum #[mkAtom (anchorToString anchor.1 anchor.2)]
292-
/-
293-
*TODO*: We need to distinguish between user-facing `cases` which `intros` new hypotheses
294-
automatically, and auto-generated `cases` produced by `grind?` and `finish?` which does not
295-
`intros` automatically. Each branch provides includes its own `intros`.
296-
*Current strategy*: Use only one `cases` (`intros`) automatically and add `rename_i`.
297-
-/
298-
let cases ← `(grind| cases #$hexnum)
299-
seqListNew := cases :: seqListNew
300-
return .closed seqListNew
301-
else
302-
return .closed []
303-
else
304-
return .stuck stuckNew.toList
305-
306212
section
307213
/-!
308214
Some sanity check properties.

src/Lean/Meta/Tactic/Grind/Split.lean

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ module
77
prelude
88
public import Lean.Meta.Tactic.Grind.Types
99
public import Lean.Meta.Tactic.Grind.SearchM
10+
public import Lean.Meta.Tactic.Grind.Action
1011
import Lean.Meta.Tactic.Grind.Intro
1112
import Lean.Meta.Tactic.Grind.Cases
1213
import Lean.Meta.Tactic.Grind.Util
1314
import Lean.Meta.Tactic.Grind.CasesMatch
1415
import Lean.Meta.Tactic.Grind.Internalize
16+
import Lean.Meta.Tactic.Grind.Anchor
1517
public section
1618
namespace Lean.Meta.Grind
1719

@@ -242,6 +244,131 @@ private def casesWithTrace (mvarId : MVarId) (major : Expr) : GoalM (List MVarId
242244
saveCases declName false
243245
cases mvarId major
244246

247+
namespace Action
248+
249+
/--
250+
Given a `mvarId` associated with a subgoal created by `splitCore`, inspects the
251+
proof term assigned to `mvarId` and tries to extract the proof of `False` that does not
252+
depend on hypotheses introduced in the subgoal.
253+
For example: suppose the subgoal is of the form `p → q → False` where `p` and `q` are new
254+
hypotheses introduced during case analysis. If the proof is of the form `fun _ _ => h`, returns
255+
`some h`.
256+
-/
257+
private def getFalseProof? (mvarId : MVarId) : MetaM (Option Expr) := mvarId.withContext do
258+
let proof ← instantiateMVars (mkMVar mvarId)
259+
go proof
260+
where
261+
go (proof : Expr) : MetaM (Option Expr) := do
262+
match_expr proof with
263+
| False.elim _ p => return some p
264+
| False.casesOn _ p => return some p
265+
| id α p => if α.isFalse then return some p else return none
266+
| _ =>
267+
/-
268+
**Note**: `intros` tactics may hide the `False` proof behind a `casesOn`
269+
For example: suppose the subgoal has a type of the form `p₁ → q₁ ∧ q₂ → p₂ → False`
270+
The proof will be of the form `fun _ h => h.casesOn (fun _ _ => hf)` where `hf` is the proof
271+
of `False` we are looking for.
272+
Non-chronological backtracking currently fails in this kind of example.
273+
-/
274+
let .lam _ _ b _ := proof | return none
275+
if b.hasLooseBVars then return none
276+
go b
277+
278+
/--
279+
Performs a case-split using `c`.
280+
Remark: `numCases` and `isRec` are computed using `checkSplitStatus`.
281+
-/
282+
private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) (stopAtFirstFailure : Bool) : Action := fun goal _ kp => do
283+
let mvarDecl ← goal.mvarId.getDecl
284+
let numIndices := mvarDecl.lctx.numIndices
285+
let mvarId ← goal.mkAuxMVar
286+
let cExpr := c.getExpr
287+
let (mvarIds, goal) ← GoalM.run goal do
288+
let gen ← getGeneration cExpr
289+
let genNew := if numCases > 1 || isRec then gen+1 else gen
290+
saveSplitDiagInfo cExpr genNew numCases c.source
291+
markCaseSplitAsResolved cExpr
292+
trace_goal[grind.split] "{cExpr}, generation: {gen}"
293+
let mvarIds ← if let .imp e h _ := c then
294+
casesWithTrace mvarId (mkGrindEM (e.forallDomain h))
295+
else if (← isMatcherApp cExpr) then
296+
casesMatch mvarId cExpr
297+
else
298+
casesWithTrace mvarId (← mkCasesMajor cExpr)
299+
let subgoals := mvarIds.map fun mvarId => { goal with mvarId }
300+
let traceEnabled := (← getConfig).trace
301+
let mut seqNew : Array (List (TSyntax `grind)) := #[]
302+
let mut stuckNew : Array Goal := #[]
303+
for subgoal in subgoals do
304+
match (← kp subgoal) with
305+
| .stuck gs =>
306+
if stopAtFirstFailure then
307+
/-
308+
**Note**: We don't need to assign `goal.mvarId` when `stopAtFirstFailure = true`
309+
because the caller will not be able to process the all failure/stuck goals anyway.
310+
-/
311+
return .stuck gs
312+
else
313+
stuckNew := stuckNew ++ gs
314+
| .closed seq =>
315+
if let some falseProof ← getFalseProof? subgoal.mvarId then
316+
goal.mvarId.assignFalseProof falseProof
317+
return .closed seq
318+
else if !seq.isEmpty then
319+
/- **Note**: if the sequence is empty, it means the user will never see this goal. -/
320+
seqNew := seqNew.push seq
321+
if (← goal.mvarId.getType).isFalse then
322+
/- **Note**: We add the marker to assist `getFalseExpr?` -/
323+
goal.mvarId.assign (mkExpectedPropHint (← instantiateMVars (mkMVar mvarId)) (mkConst ``False))
324+
else
325+
goal.mvarId.assign (← instantiateMVars (mkMVar mvarId))
326+
if stuckNew.isEmpty then
327+
if traceEnabled then
328+
let seqListNew ← if h : seqNew.size = 1 then
329+
pure seqNew[0]
330+
else
331+
seqNew.toList.mapM fun s => mkGrindNext s
332+
let mut seqListNew := seqListNew
333+
let anchor ← goal.withContext <| getAnchor cExpr
334+
-- **TODO**: compute the exact number of digits
335+
let numDigits := 4
336+
let anchorPrefix := anchor >>> (64 - 16)
337+
let hexnum := mkNode `hexnum #[mkAtom (anchorToString numDigits anchorPrefix)]
338+
let cases ← `(grind| cases #$hexnum)
339+
seqListNew := cases :: seqListNew
340+
return .closed seqListNew
341+
else
342+
return .closed []
343+
else
344+
return .stuck stuckNew.toList
345+
346+
/--
347+
Selects a case-split from the list of candidates, performs the split and applies
348+
continuation to all subgoals.
349+
If a subgoal is solved without using new hypotheses, closes the original goal using this proof. That is,
350+
it performs non-chronological backtracking.
351+
If `stopsAtFirstFailure = true`, it stops the search as soon as the given continuation cannot solve a subgoal.
352+
-/
353+
def splitNext (stopAtFirstFailure := true) : Action := fun goal kna kp => do
354+
let (r, goal) ← GoalM.run goal selectNextSplit?
355+
let .some c numCases isRec _ := r
356+
| kna goal
357+
let cExpr := c.getExpr
358+
let gen := goal.getGeneration cExpr
359+
let x : Action := splitCore c numCases isRec stopAtFirstFailure >> intros gen
360+
x goal kna kp
361+
362+
end Action
363+
364+
/-!
365+
**------------------------------------------**
366+
**------------------------------------------**
367+
**TODO** Delete rest of the file
368+
**------------------------------------------**
369+
**------------------------------------------**
370+
-/
371+
245372
/--
246373
Performs a case-split using `c`.
247374
Remarks:

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -923,10 +923,15 @@ def Goal.getENode (goal : Goal) (e : Expr) : CoreM ENode := do
923923
def getENode (e : Expr) : GoalM ENode := do
924924
(← get).getENode e
925925

926+
def Goal.getGeneration (goal : Goal) (e : Expr) : Nat :=
927+
if let some n := goal.getENode? e then
928+
n.generation
929+
else
930+
0
931+
926932
/-- Returns the generation of the given term. Is assumes it has been internalized -/
927-
def getGeneration (e : Expr) : GoalM Nat := do
928-
let some n ← getENode? e | return 0
929-
return n.generation
933+
def getGeneration (e : Expr) : GoalM Nat :=
934+
return (← get).getGeneration e
930935

931936
/-- Returns `true` if `e` is in the equivalence class of `True`. -/
932937
def isEqTrue (e : Expr) : GoalM Bool := do
@@ -1244,7 +1249,10 @@ If type of `mvarId` is not `False`, then use `False.elim`.
12441249
def _root_.Lean.MVarId.assignFalseProof (mvarId : MVarId) (falseProof : Expr) : MetaM Unit := mvarId.withContext do
12451250
let target ← mvarId.getType
12461251
if target.isFalse then
1247-
mvarId.assign falseProof
1252+
/-
1253+
**Note**: We add the marker to assist `getFalseExpr?` used to implement
1254+
non-chronological backtracking. -/
1255+
mvarId.assign (mkExpectedPropHint falseProof (mkConst ``False))
12481256
else
12491257
mvarId.assign (← mkFalseElim target falseProof)
12501258

@@ -1341,12 +1349,6 @@ def propagateUp (e : Expr) : GoalM Unit := do
13411349
def propagateDown (e : Expr) : GoalM Unit := do
13421350
(← getMethods).propagateDown e
13431351

1344-
def Goal.getGeneration (goal : Goal) (e : Expr) : Nat :=
1345-
if let some n := goal.getENode? e then
1346-
n.generation
1347-
else
1348-
0
1349-
13501352
/-- Returns expressions in the given expression equivalence class. -/
13511353
partial def Goal.getEqc (goal : Goal) (e : Expr) (sort := false) : List Expr :=
13521354
let eqc := go e e #[]

tests/lean/run/grind_cutsat_trim_context.lean

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module
22
/--
33
trace: [grind.debug.proof] fun h h_1 h_2 h_3 h_4 h_5 h_6 h_7 h_8 =>
4-
let ctx := RArray.leaf (f 2);
5-
let p_1 := Poly.add 1 0 (Poly.num 0);
6-
let p_2 := Poly.add (-1) 0 (Poly.num 1);
7-
let p_3 := Poly.num 1;
8-
le_unsat ctx p_3 (eagerReduce (Eq.refl true)) (le_combine ctx p_2 p_1 p_3 (eagerReduce (Eq.refl true)) h_8 h_1)
4+
id
5+
(let ctx := RArray.leaf (f 2);
6+
let p_1 := Poly.add 1 0 (Poly.num 0);
7+
let p_2 := Poly.add (-1) 0 (Poly.num 1);
8+
let p_3 := Poly.num 1;
9+
le_unsat ctx p_3 (eagerReduce (Eq.refl true)) (le_combine ctx p_2 p_1 p_3 (eagerReduce (Eq.refl true)) h_8 h_1))
910
-/
1011
#guard_msgs in -- `cutsat` context should contain only `f 2`
1112
open Lean Int Linear in

tests/lean/run/grind_linarith_2.lean

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@ example [IntModule α] [LE α] [LT α] [IsPreorder α] [OrderedAdd α] (a b : α
88

99
/--
1010
trace: [grind.debug.proof] Classical.byContradiction fun h =>
11-
let ctx := RArray.leaf One.one;
12-
let p_1 := Poly.nil;
13-
let e_1 := Expr.zero;
14-
let e_2 := Expr.intMul 0 (Expr.var 0);
15-
let rctx := RArray.branch 1 (RArray.leaf a) (RArray.leaf b);
16-
let rp_1 := CommRing.Poly.num 0;
17-
let re_1 := (CommRing.Expr.var 0).add (CommRing.Expr.var 1);
18-
let re_2 := (CommRing.Expr.var 1).add (CommRing.Expr.var 0);
19-
diseq_unsat ctx
20-
(diseq_norm ctx e_2 e_1 p_1 (eagerReduce (Eq.refl true))
21-
(CommRing.diseq_norm rctx re_1 re_2 rp_1 (eagerReduce (Eq.refl true)) h))
11+
id
12+
(let ctx := RArray.leaf One.one;
13+
let p_1 := Poly.nil;
14+
let e_1 := Expr.zero;
15+
let e_2 := Expr.intMul 0 (Expr.var 0);
16+
let rctx := RArray.branch 1 (RArray.leaf a) (RArray.leaf b);
17+
let rp_1 := CommRing.Poly.num 0;
18+
let re_1 := (CommRing.Expr.var 0).add (CommRing.Expr.var 1);
19+
let re_2 := (CommRing.Expr.var 1).add (CommRing.Expr.var 0);
20+
diseq_unsat ctx
21+
(diseq_norm ctx e_2 e_1 p_1 (eagerReduce (Eq.refl true))
22+
(CommRing.diseq_norm rctx re_1 re_2 rp_1 (eagerReduce (Eq.refl true)) h)))
2223
-/
2324
#guard_msgs in
2425
open Linarith in

0 commit comments

Comments
 (0)