Skip to content

Commit fd97add

Browse files
committed
feat: improve case-split heuristic in grind
This PR improves the case-split heuristics in `grind`. In this PR, we do not increment the number of case splits in the first case. The idea is to leverage non-chronological backtracking: if the first case is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close the original goal directly. In this scenario, the case-split was "free", it didn't contribute to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be irrelevant. The new heuristic addresses the second example in #11545
1 parent eee58f4 commit fd97add

File tree

3 files changed

+64
-10
lines changed

3 files changed

+64
-10
lines changed

src/Lean/Meta/Tactic/Grind/Split.lean

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,10 @@ where
185185
match cs with
186186
| [] =>
187187
modify fun s => { s with split.candidates := cs'.reverse }
188-
if let .some _ numCases isRec _ := c? then
189-
let numSplits := (← get).split.num
190-
-- We only increase the number of splits if there is more than one case or it is recursive.
191-
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
188+
if let .some .. := c? then
192189
-- Remark: we reset `numEmatch` after each case split.
193190
-- We should consider other strategies in the future.
194-
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
191+
modify fun s => { s with ematch.num := 0 }
195192
return c?
196193
| c::cs =>
197194
if !(← checkAnchorRefs c) then
@@ -422,10 +419,24 @@ def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
422419
pure 0
423420
return (mvarIds, numDigits)
424421
let numSubgoals := mvarIds.length
425-
let subgoals := mvarIds.mapIdx fun i mvarId => { goal with
426-
mvarId
427-
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace
428-
}
422+
/-
423+
**Split counter heuristic**: We do not increment `numSplits` for the first case (`i = 0`)
424+
of a non-recursive split. This leverages non-chronological backtracking: if the first case
425+
is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close
426+
the original goal directly. In this scenario, the case-split was "free", it didn't contribute
427+
to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be
428+
irrelevant.
429+
430+
For recursive types or subsequent cases (`i > 0`), we always increment the counter since
431+
these represent genuine branches in the proof search.
432+
-/
433+
let subgoals := mvarIds.mapIdx fun i mvarId =>
434+
let numSplits := goal.split.num
435+
let numSplits := if i > 0 || isRec then numSplits + 1 else numSplits
436+
{ goal with
437+
mvarId
438+
split.num := numSplits
439+
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace }
429440
let mut seqNew : Array (List (TSyntax `grind)) := #[]
430441
let mut stuckNew : Array Goal := #[]
431442
for subgoal in subgoals do

tests/lean/run/grind_11081.lean

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ open List
2525

2626
/--
2727
error: `grind` failed
28-
case grind.1.1.1.1.1.1.1.1.1
28+
case grind.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1
2929
α : Type
3030
inst : DecidableEq α
3131
l₁ l₂ : List α
@@ -66,6 +66,44 @@ left_8 : l₁ ~ l₁.diff l₂
6666
right_8 : ∀ (a : α), count a l₁ = count a (l₁.diff l₂)
6767
left_9 : l₁ ~ l₂
6868
right_9 : ∀ (a : α), count a l₁ = count a l₂
69+
left_10 : filter p l₁ ~ filter p (l₁.diff l₂ ++ l₂)
70+
right_10 : ∀ (a : α), count a (filter p l₁) = count a (filter p (l₁.diff l₂ ++ l₂))
71+
left_11 : filter p (l₁.diff l₂ ++ l₂) ~ filter p l₁
72+
right_11 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p l₁)
73+
left_12 : l₁.diff l₂ ++ l₂ ~ l₂ ++ (l₁.diff l₂ ++ l₂)
74+
right_12 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₂ ++ (l₁.diff l₂ ++ l₂))
75+
left_13 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁
76+
right_13 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁)
77+
left_14 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)
78+
right_14 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂))
79+
left_15 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₂
80+
right_15 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₂)
81+
left_16 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁.diff l₂
82+
right_16 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂)
83+
left_17 : l₁.diff l₂ ++ l₂ ~ l₁ ++ (l₁.diff l₂ ++ l₂)
84+
right_17 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁ ++ (l₁.diff l₂ ++ l₂))
85+
left_18 : filter p (l₁.diff l₂ ++ l₂) ~ filter p (l₁.diff l₂)
86+
right_18 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p (l₁.diff l₂))
87+
left_19 : filter p (l₁.diff l₂) ~ filter p (l₁.diff l₂ ++ l₂)
88+
right_19 : ∀ (a : α), count a (filter p (l₁.diff l₂)) = count a (filter p (l₁.diff l₂ ++ l₂))
89+
left_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p l₁)
90+
right_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p (l₁.diff l₂ ++ l₂))
91+
left_21 : l₁.diff l₂ ++ l₂ ++ l₁.diff l₂ ~ l₁.diff l₂ ++ l₂
92+
right_21 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂) = count a (l₁.diff l₂ ++ l₂)
93+
left_22 : l₁.diff l₂ ++ l₂ ++ l₁ ~ l₁.diff l₂ ++ l₂
94+
right_22 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁) = count a (l₁.diff l₂ ++ l₂)
95+
left_23 : l₁.diff l₂ ++ l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂
96+
right_23 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂)
97+
left_24 : l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
98+
right_24 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
99+
left_25 : l₁ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
100+
right_25 : ∀ (a : α), count a (l₁ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
101+
left_26 : l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
102+
right_26 : ∀ (a : α), count a (l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
103+
left_27 : l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
104+
right_27 : ∀ (a : α), count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
105+
left_28 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)
106+
right_28 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂))
69107
⊢ False
70108
-/
71109
#guard_msgs in

tests/lean/run/grind_11539_2.lean

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
example (a b : Nat) (f g : Nat → Nat)
2+
(hf : (∀ i ≤ a, f i ≤ f (i + 1)) ∧ f 0 = 0)
3+
(hg : (∀ i ≤ b, g i ≤ g (i + 1)) ∧ g 0 = 0 ∧ g b = 0) :
4+
g (a + b - a) = 0 := by
5+
grind

0 commit comments

Comments
 (0)