Skip to content

Commit b68ac99

Browse files
authored
feat: try? uses parallelism (#11365)
This PR enables parallelism in `try?`. Currently, we replace the `attempt_all` stages (there are two, one for builtin tactics including `grind` and `simp_all`, and a second one for all user extensions) with parallel versions. We do not (yet?) change the behaviour of `first` based stages.
1 parent e1f8c14 commit b68ac99

File tree

4 files changed

+122
-25
lines changed

4 files changed

+122
-25
lines changed

src/Init/Try.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ syntax (name := tryTrace) "try?" optConfig : tactic
5555
/-- Helper internal tactic for implementing the tactic `try?`. -/
5656
syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
5757

58+
/-- Helper internal tactic for implementing the tactic `try?` with parallel execution. -/
59+
syntax (name := attemptAllPar) "attempt_all_par " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
60+
5861
/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/
5962
syntax (name := tryResult) "try_suggestions " tactic* : tactic
6063

src/Lean/Elab/Parallel.lean

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -194,24 +194,24 @@ def parIterGreedy {α : Type} (jobs : List (CoreM α)) :=
194194

195195
/--
196196
Runs a list of CoreM computations in parallel and collects results in the original order,
197-
including the state after each task completes.
197+
including the saved state after each task completes.
198198
199199
Unlike `parIter`, this waits for all tasks to complete and returns results
200200
in the same order as the input list, not in completion order.
201201
202-
Results are wrapped in `Except Exception (α × Core.State)` so that errors in individual
202+
Results are wrapped in `Except Exception (α × Core.SavedState)` so that errors in individual
203203
tasks don't stop the collection - you can observe all results including which tasks failed.
204204
205205
The final CoreM state is restored to the initial state (before tasks ran).
206206
-/
207-
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.State))) := do
207+
def par {α : Type} (jobs : List (CoreM α)) : CoreM (List (Except Exception (α × Core.SavedState))) := do
208208
let initialState ← get
209209
let tasks ← jobs.mapM asTask'
210210
let mut results := []
211211
for task in tasks do
212212
let resultWithState ← observing do
213213
let result ← task.get
214-
pure (result, (← get))
214+
pure (result, (← saveState))
215215
results := resultWithState :: results
216216
set initialState
217217
return results.reverse
@@ -261,25 +261,24 @@ open Std.Iterators
261261

262262
/--
263263
Runs a list of MetaM computations in parallel and collects results in the original order,
264-
including the state after each task completes.
264+
including the saved state after each task completes.
265265
266266
Unlike `parIter`, this waits for all tasks to complete and returns results
267267
in the same order as the input list, not in completion order.
268268
269-
Results are wrapped in `Except Exception (α × Meta.State)` so that errors in individual
269+
Results are wrapped in `Except Exception (α × Meta.SavedState)` so that errors in individual
270270
tasks don't stop the collection - you can observe all results including which tasks failed.
271271
272272
The final MetaM state is restored to the initial state (before tasks ran).
273-
Note: Only Meta.State is captured/reverted, not Core.State or IO effects.
274273
-/
275-
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.State))) := do
274+
def par {α : Type} (jobs : List (MetaM α)) : MetaM (List (Except Exception (α × Meta.SavedState))) := do
276275
let initialState ← get
277276
let tasks ← jobs.mapM asTask'
278277
let mut results := []
279278
for task in tasks do
280279
let resultWithState ← observing do
281280
let result ← task.get
282-
pure (result, (← get))
281+
pure (result, (← saveState))
283282
results := resultWithState :: results
284283
set initialState
285284
return results.reverse
@@ -465,27 +464,24 @@ def parIterGreedy {α : Type} (jobs : List (TermElabM α)) :=
465464

466465
/--
467466
Runs a list of TermElabM computations in parallel and collects results in the original order,
468-
including the state after each task completes.
467+
including the saved state after each task completes.
469468
470469
Unlike `parIter`, this waits for all tasks to complete and returns results
471470
in the same order as the input list, not in completion order.
472471
473-
Results are wrapped in `Except Exception (α × Term.State)` so that errors in individual
472+
Results are wrapped in `Except Exception (α × Term.SavedState)` so that errors in individual
474473
tasks don't stop the collection - you can observe all results including which tasks failed.
475474
476475
The final TermElabM state is restored to the initial state (before tasks ran).
477-
Note: Only Term.State is captured/reverted, not Meta.State, Core.State or IO effects.
478476
-/
479-
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.State))) := do
477+
def par {α : Type} (jobs : List (TermElabM α)) : TermElabM (List (Except Exception (α × Term.SavedState))) := do
480478
let initialState ← get
481479
let tasks ← jobs.mapM asTask'
482480
let mut results := []
483481
for task in tasks do
484-
-- Note: We use try/catch instead of `observing` here because TermElabM's `observing`
485-
-- returns `TermElabResult` (not `Except`), which includes SavedState that we don't need.
486482
try
487483
let result ← task.get
488-
let taskState ← get
484+
let taskState ← saveState
489485
results := .ok (result, taskState) :: results
490486
catch e =>
491487
results := .error e :: results
@@ -605,27 +601,24 @@ def parIterGreedy {α : Type} (jobs : List (TacticM α)) :=
605601

606602
/--
607603
Runs a list of TacticM computations in parallel and collects results in the original order,
608-
including the state after each task completes.
604+
including the saved state after each task completes.
609605
610606
Unlike `parIter`, this waits for all tasks to complete and returns results
611607
in the same order as the input list, not in completion order.
612608
613-
Results are wrapped in `Except Exception (α × Tactic.State)` so that errors in individual
609+
Results are wrapped in `Except Exception (α × Tactic.SavedState)` so that errors in individual
614610
tasks don't stop the collection - you can observe all results including which tasks failed.
615611
616612
The final TacticM state is restored to the initial state (before tasks ran).
617-
Note: Only Tactic.State is captured/reverted, not Term.State, Meta.State, Core.State or IO effects.
618613
-/
619-
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.State))) := do
614+
def par {α : Type} (jobs : List (TacticM α)) : TacticM (List (Except Exception (α × Tactic.SavedState))) := do
620615
let initialState ← get
621616
let tasks ← jobs.mapM asTask'
622617
let mut results := []
623618
for task in tasks do
624-
-- Note: We use try/catch instead of `observing` here because TacticM's `observing`
625-
-- (inherited from TermElabM) returns `TermElabResult`, not `Except`.
626619
try
627620
let result ← task.get
628-
let taskState ← get
621+
let taskState ← Tactic.saveState
629622
results := .ok (result, taskState) :: results
630623
catch e =>
631624
results := .error e :: results

src/Lean/Elab/Tactic/Try.lean

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public import Lean.Meta.Tactic.Try
1010
public import Lean.Elab.Tactic.SimpTrace
1111
public import Lean.Elab.Tactic.LibrarySearch
1212
public import Lean.Elab.Tactic.Grind.Main
13+
public import Lean.Elab.Parallel
1314
meta import Lean.Elab.Command
1415
public section
1516
namespace Lean.Elab.Tactic
@@ -697,6 +698,39 @@ where
697698
else
698699
throwError "`attempt_all` failed"
699700

701+
/-- `evalSuggest` for `attempt_all_par` tactic (parallel version). -/
702+
private partial def evalSuggestAttemptAllPar (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
703+
unless (← read).terminal do
704+
throwError "invalid occurrence of `attempt_all_par` in non-terminal position for `try?` script{indentD (← read).root}"
705+
706+
let ctx ← read
707+
708+
-- Create jobs that each try one tactic and return the suggestion
709+
let jobs : List (TacticM (TSyntax `tactic)) := tacs.toList.map fun tacSeq =>
710+
withOriginalHeartbeats (evalSuggestTacticSeq tacSeq) ctx
711+
712+
-- Run all jobs in parallel - par returns (result, SavedState) for each
713+
let results ← TacticM.par jobs
714+
715+
-- Collect successful results (maintaining order)
716+
let mut acc : Array (TSyntax `tactic) := #[]
717+
let mut firstSaved? : Option SavedState := none
718+
for result in results do
719+
match result with
720+
| .ok (tac, s) =>
721+
trace[try.debug] "`attempt_all_par` argument succeeded{indentD tac}"
722+
acc := appendSuggestion acc tac
723+
if firstSaved?.isNone then
724+
firstSaved? := some s
725+
| .error _ => pure ()
726+
727+
-- Restore first successful state and return suggestions
728+
if let some saved := firstSaved? then
729+
saved.restore
730+
mkTrySuggestions acc
731+
else
732+
throwError "`attempt_all_par` failed"
733+
700734
private partial def evalSuggestDefault (tac : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := do
701735
let kind := tac.raw.getKind
702736
match (← getEvalFns kind) with
@@ -743,6 +777,7 @@ private partial def evalSuggestImpl : TryTactic := fun tac => do
743777
| `(tactic| ($tac:tacticSeq)) => evalSuggestTacticSeq tac
744778
| `(tactic| try $tac:tacticSeq) => evalSuggestTry tac
745779
| `(tactic| attempt_all $[| $tacs]*) => evalSuggestAttemptAll tacs
780+
| `(tactic| attempt_all_par $[| $tacs]*) => evalSuggestAttemptAllPar tacs
746781
| _ =>
747782
let k := tac.raw.getKind
748783
if k == ``Parser.Tactic.seq1 then
@@ -910,7 +945,7 @@ private unsafe def mkTryEvalSuggestStxUnsafe (goal : MVarId) (info : Try.Info) :
910945
let simp ← mkSimpStx
911946
let grind ← mkGrindStx info
912947

913-
let atomic ← `(tactic| attempt_all | $simple:tactic | $simp:tactic | $grind:tactic | simp_all)
948+
let atomic ← `(tactic| attempt_all_par | $simple:tactic | $simp:tactic | $grind:tactic | simp_all)
914949
let atomicSuggestions ← mkAtomicWithSuggestionsStx
915950
let funInds ← mkAllFunIndStx info atomic
916951
let inds ← mkAllIndStx info atomic
@@ -934,7 +969,7 @@ private unsafe def mkTryEvalSuggestStxUnsafe (goal : MVarId) (info : Try.Info) :
934969
if userTactics.isEmpty then
935970
`(tactic| first | $atomic:tactic | $atomicSuggestions:tactic | $funInds:tactic | $inds:tactic | $extra:tactic)
936971
else
937-
let userAttemptAll ← `(tactic| attempt_all $[| $userTactics:tactic]*)
972+
let userAttemptAll ← `(tactic| attempt_all_par $[| $userTactics:tactic]*)
938973
`(tactic| first | $atomic:tactic | $atomicSuggestions:tactic | $funInds:tactic | $inds:tactic | $extra:tactic | $userAttemptAll:tactic)
939974

940975
@[implemented_by mkTryEvalSuggestStxUnsafe]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/-
2+
Test that try? runs user suggestion tactics in parallel via attempt_all_par.
3+
4+
This test uses IO.stdGenRef (a builtin_initialize ref) to demonstrate parallelism:
5+
- Tactic 1 (high prio): waits 1000ms, then checks if the random seed was changed
6+
- Tactic 2 (low prio): immediately sets the seed to a magic value, then succeeds
7+
8+
If sequential: Tactic 1 executes first, waits, seed unchanged, fails.
9+
Then tactic 2 executes, sets seed, succeeds. Only one suggestion.
10+
If parallel: Both tactics start together. Tactic 2 sets seed immediately.
11+
Tactic 1 waits 1000ms, sees changed seed, succeeds. Two suggestions.
12+
-/
13+
module
14+
public import Lean
15+
public meta import Lean.Elab.Tactic.Try
16+
17+
open Lean Meta Elab Tactic Try
18+
19+
-- A goal that built-in tactics won't solve
20+
inductive ParallelTestGoal : Prop where
21+
| mk : ParallelTestGoal
22+
23+
-- Magic seed value to signal parallelism
24+
meta def magicSeed : Nat := 314159265
25+
26+
-- Tactic that waits, then checks if seed was changed
27+
elab "wait_and_check_seed" : tactic => do
28+
IO.sleep 1000
29+
let gen ← IO.stdGenRef.get
30+
let expected := mkStdGen magicSeed
31+
if gen.s1 == expected.s1 && gen.s2 == expected.s2 then
32+
evalTactic (← `(tactic| exact ParallelTestGoal.mk))
33+
else
34+
throwError "seed not changed (sequential execution detected)"
35+
36+
-- Tactic that immediately sets seed and succeeds
37+
elab "set_seed_and_succeed" : tactic => do
38+
IO.setRandSeed magicSeed
39+
evalTactic (← `(tactic| exact ParallelTestGoal.mk))
40+
41+
-- Register both tactics as user suggestions
42+
-- High priority tactic: reset seed first (to ensure clean state), then return waiting tactic
43+
@[local try_suggestion 900]
44+
meta def waitAndCheckSolver (_goal : MVarId) (_info : Try.Info) : MetaM (Array (TSyntax `tactic)) := do
45+
-- Reset to a different seed to ensure we're testing actual communication
46+
IO.setRandSeed 0
47+
return #[← `(tactic| wait_and_check_seed)]
48+
49+
-- Low priority tactic returns the seed-setting tactic
50+
@[local try_suggestion 800]
51+
meta def setFlagSolver (_goal : MVarId) (_info : Try.Info) : MetaM (Array (TSyntax `tactic)) := do
52+
return #[← `(tactic| set_seed_and_succeed)]
53+
54+
-- If parallel: both tactics succeed (tactic 1 sees seed change after waiting)
55+
-- If sequential: only tactic 2 succeeds (tactic 1 sees unchanged seed)
56+
--
57+
-- EXPECTED ON MASTER (sequential): Only one suggestion
58+
-- EXPECTED ON try_par (parallel): Two suggestions
59+
/--
60+
info: Try these:
61+
[apply] wait_and_check_seed
62+
[apply] set_seed_and_succeed
63+
-/
64+
#guard_msgs in
65+
example : ParallelTestGoal := by
66+
try?

0 commit comments

Comments
 (0)