Skip to content

Commit 8a1b6e0

Browse files
authored
feat: compress generated grind tactic sequences using <;> (#10808)
This PR implements support for compressing auto-generated `grind` tactic sequences.
1 parent 7087c4a commit 8a1b6e0

File tree

1 file changed

+61
-12
lines changed

1 file changed

+61
-12
lines changed

src/Lean/Meta/Tactic/Grind/Split.lean

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,58 @@ where
275275
if b.hasLooseBVars then return none
276276
go b
277277

278+
/-- Returns `true` if the given list can be compressed using `<;>` at `splitCore` -/
279+
private def isCompressibleSeq (seq : List (TSyntax `grind)) : Bool :=
280+
seq.all fun tac => match tac with
281+
| `(grind| next $_* => $_:grindSeq) => false
282+
| _ => true
283+
284+
/--
285+
Given `[t₁, ..., tₙ]`, returns `t₁ <;> ... <;> tₙ`
286+
-/
287+
private def mkAndThenSeq (seq : List (TSyntax `grind)) : CoreM (TSyntax `grind) := do
288+
match seq with
289+
| [] => `(grind| done)
290+
| [tac] => return tac
291+
| tac :: seq =>
292+
let seq ← mkAndThenSeq seq
293+
`(grind| $tac:grind <;> $seq:grind)
294+
295+
private def mkCasesAndThen (cases : TSyntax `grind) (seq : List (TSyntax `grind)) : CoreM (TSyntax `grind) := do
296+
match seq with
297+
| [] => return cases
298+
| seq =>
299+
let seq ← mkAndThenSeq seq
300+
`(grind| $cases:grind <;> $seq:grind)
301+
302+
private def isCompressibleAlts (alts : Array (List (TSyntax `grind))) : Bool :=
303+
if _ : alts.size > 0 then
304+
let alt := alts[0]
305+
isCompressibleSeq alt && alts.all (· == alt)
306+
else
307+
true
308+
309+
private def mkCasesResultSeq (cases : TSyntax `grind) (alts : Array (List (TSyntax `grind)))
310+
(compress : Bool) : CoreM (List (TSyntax `grind)) := do
311+
if compress && isCompressibleAlts alts then
312+
if h : alts.size > 0 then
313+
return [(← mkCasesAndThen cases alts[0]!)]
314+
else
315+
return [cases]
316+
else
317+
let seq ← if h : alts.size = 1 then
318+
pure alts[0]
319+
else
320+
alts.toList.mapM fun s => mkGrindNext s
321+
return cases :: seq
322+
278323
/--
279324
Performs a case-split using `c`.
280325
Remark: `numCases` and `isRec` are computed using `checkSplitStatus`.
281326
-/
282-
private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) (stopAtFirstFailure : Bool) : Action := fun goal _ kp => do
283-
let mvarDecl ← goal.mvarId.getDecl
284-
let numIndices := mvarDecl.lctx.numIndices
327+
private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
328+
(stopAtFirstFailure : Bool)
329+
(compress : Bool) : Action := fun goal _ kp => do
285330
let mvarId ← goal.mkAuxMVar
286331
let cExpr := c.getExpr
287332
let (mvarIds, goal) ← GoalM.run goal do
@@ -325,19 +370,13 @@ private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) (stopAtFir
325370
goal.mvarId.assign (← instantiateMVars (mkMVar mvarId))
326371
if stuckNew.isEmpty then
327372
if traceEnabled then
328-
let seqListNew ← if h : seqNew.size = 1 then
329-
pure seqNew[0]
330-
else
331-
seqNew.toList.mapM fun s => mkGrindNext s
332-
let mut seqListNew := seqListNew
333373
let anchor ← goal.withContext <| getAnchor cExpr
334374
-- **TODO**: compute the exact number of digits
335375
let numDigits := 4
336376
let anchorPrefix := anchor >>> (64 - 16)
337377
let hexnum := mkNode `hexnum #[mkAtom (anchorToString numDigits anchorPrefix)]
338378
let cases ← `(grind| cases #$hexnum)
339-
seqListNew := cases :: seqListNew
340-
return .closed seqListNew
379+
return .closed (← mkCasesResultSeq cases seqNew compress)
341380
else
342381
return .closed []
343382
else
@@ -348,15 +387,25 @@ Selects a case-split from the list of candidates, performs the split and applies
348387
continuation to all subgoals.
349388
If a subgoal is solved without using new hypotheses, closes the original goal using this proof. That is,
350389
it performs non-chronological backtracking.
390+
351391
If `stopsAtFirstFailure = true`, it stops the search as soon as the given continuation cannot solve a subgoal.
392+
393+
If `compress = true`, then it uses `<;>` to generate the resulting tactic sequence if all subgoal sequences are
394+
identical. For example, suppose that the following sequence is generated with `compress = false`
395+
```
396+
cases #50fc
397+
next => lia
398+
next => lia
399+
```
400+
Then with `compress = true` it generates `cases #50fc <;> lia`
352401
-/
353-
def splitNext (stopAtFirstFailure := true) : Action := fun goal kna kp => do
402+
def splitNext (stopAtFirstFailure := true) (compress := true) : Action := fun goal kna kp => do
354403
let (r, goal) ← GoalM.run goal selectNextSplit?
355404
let .some c numCases isRec _ := r
356405
| kna goal
357406
let cExpr := c.getExpr
358407
let gen := goal.getGeneration cExpr
359-
let x : Action := splitCore c numCases isRec stopAtFirstFailure >> intros gen
408+
let x : Action := splitCore c numCases isRec stopAtFirstFailure compress >> intros gen
360409
x goal kna kp
361410

362411
end Action

0 commit comments

Comments
 (0)