Skip to content

Commit 272f0f5

Browse files
kim-em and claude committed
feat: add chunked variants of parIterWithCancel
Add `parIterWithCancelChunked` functions for CoreM, MetaM, TermElabM, and TacticM that support chunking jobs into groups to reduce task creation overhead. The original `parIterWithCancel` functions remain unchanged for backward compatibility. The new chunked variants accept `maxTasks` and `minChunkSize` parameters to control parallelism. This enables PRs that use `parIterWithCancel` (like parallel library search and rewrites) to benefit from chunking by switching to the new `parIterWithCancelChunked` function with `maxTasks := 128`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 1d3fda4 commit 272f0f5

File tree

1 file changed

+231
-3
lines changed

1 file changed

+231
-3
lines changed

src/Lean/Elab/Parallel.lean

Lines changed: 231 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,38 @@ namespace Lean.Core.CoreM
155155

156156
open Std.Iterators
157157

158+
/--
159+
Internal state for an iterator over chunked tasks for CoreM.
160+
Yields individual results while internally managing chunk boundaries.
161+
-/
162+
private structure ChunkedTaskIterator (α : Type) where
163+
chunkTasks : List (Task (CoreM (List (Except Exception α))))
164+
currentResults : List (Except Exception α)
165+
166+
private instance {α : Type} : Iterator (ChunkedTaskIterator α) CoreM (Except Exception α) where
167+
IsPlausibleStep _
168+
| .yield _ _ => True
169+
| .skip _ => True -- Allow skip for empty chunks
170+
| .done => True
171+
step it := do
172+
match it.internalState.currentResults with
173+
| r :: rest =>
174+
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } CoreM (Except Exception α)) r, trivial⟩
175+
| [] =>
176+
match it.internalState.chunkTasks with
177+
| [] => pure <| .deflate ⟨.done, trivial⟩
178+
| task :: rest =>
179+
try
180+
let chunkResults ← task.get
181+
match chunkResults with
182+
| [] =>
183+
-- Empty chunk, skip to try next
184+
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)), trivial⟩
185+
| r :: rs =>
186+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } CoreM (Except Exception α)) r, trivial⟩
187+
catch e =>
188+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)) (.error e), trivial⟩
189+
158190
/--
159191
Runs a list of CoreM computations in parallel and returns:
160192
* a combined cancellation hook for all tasks, and
@@ -180,6 +212,29 @@ def parIterWithCancel {α : Type} (jobs : List (CoreM α)) := do
180212
pure (Except.error e)
181213
return (combinedCancel, iterWithErrors)
182214

215+
/--
216+
Runs a list of CoreM computations in parallel with chunking and returns:
217+
* a combined cancellation hook for all tasks, and
218+
* an iterator that yields results in original order.
219+
220+
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
221+
Each chunk runs its jobs sequentially, but chunks run in parallel.
222+
223+
**Parameters:**
224+
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
225+
- `minChunkSize`: Minimum jobs per chunk. Default 1.
226+
-/
227+
def parIterWithCancelChunked {α : Type} (jobs : List (CoreM α))
228+
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
229+
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
230+
let chunks := toChunks jobs chunkSize
231+
let chunkJobs : List (CoreM (List (Except Exception α))) :=
232+
chunks.map fun (chunk : List (CoreM α)) => chunk.mapM (observing ·)
233+
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
234+
let combinedCancel := cancels.forM id
235+
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) CoreM (Except Exception α)
236+
return (combinedCancel, flatIter)
237+
183238
/--
184239
Runs a list of CoreM computations in parallel (without cancellation hook).
185240
@@ -347,6 +402,41 @@ namespace Lean.Meta.MetaM
347402

348403
open Std.Iterators
349404

405+
/--
406+
Internal state for an iterator over chunked tasks for MetaM.
407+
Yields individual results while internally managing chunk boundaries.
408+
-/
409+
structure ChunkedTaskIterator (α : Type) where
410+
chunkTasks : List (Task (MetaM (List (Except Exception α))))
411+
currentResults : List (Except Exception α)
412+
413+
instance {α : Type} : Iterator (ChunkedTaskIterator α) MetaM (Except Exception α) where
414+
IsPlausibleStep _
415+
| .yield _ _ => True
416+
| .skip _ => True
417+
| .done => True
418+
step it := do
419+
match it.internalState.currentResults with
420+
| r :: rest =>
421+
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } MetaM (Except Exception α)) r, trivial⟩
422+
| [] =>
423+
match it.internalState.chunkTasks with
424+
| [] => pure <| .deflate ⟨.done, trivial⟩
425+
| task :: rest =>
426+
try
427+
let chunkResults ← task.get
428+
match chunkResults with
429+
| [] =>
430+
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)), trivial⟩
431+
| r :: rs =>
432+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } MetaM (Except Exception α)) r, trivial⟩
433+
catch e =>
434+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)) (.error e), trivial⟩
435+
436+
instance {α : Type} {n : Type → Type u} [Monad n] [MonadLiftT MetaM n] :
437+
IteratorLoopPartial (ChunkedTaskIterator α) MetaM n :=
438+
.defaultImplementation
439+
350440
/-- Internal: run jobs in parallel without chunking, returning state. -/
351441
private def parCore {α : Type} (jobs : List (MetaM α)) :
352442
MetaM (List (Except Exception (α × Meta.SavedState))) := do
@@ -467,7 +557,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
467557
def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
468558
let (cancels, tasks) := (← jobs.mapM asTask).unzip
469559
let combinedCancel := cancels.forM id
470-
-- Create iterator that processes tasks sequentially
471560
let iterWithErrors := tasks.iter.mapM fun (task : Task (MetaM α)) => do
472561
try
473562
let result ← task.get
@@ -476,6 +565,29 @@ def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
476565
pure (Except.error e)
477566
return (combinedCancel, iterWithErrors)
478567

568+
/--
569+
Runs a list of MetaM computations in parallel with chunking and returns:
570+
* a combined cancellation hook for all tasks, and
571+
* an iterator that yields results in original order.
572+
573+
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
574+
Each chunk runs its jobs sequentially, but chunks run in parallel.
575+
576+
**Parameters:**
577+
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
578+
- `minChunkSize`: Minimum jobs per chunk. Default 1.
579+
-/
580+
def parIterWithCancelChunked {α : Type} (jobs : List (MetaM α))
581+
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
582+
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
583+
let chunks := toChunks jobs chunkSize
584+
let chunkJobs : List (MetaM (List (Except Exception α))) :=
585+
chunks.map fun (chunk : List (MetaM α)) => chunk.mapM (observing ·)
586+
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
587+
let combinedCancel := cancels.forM id
588+
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) MetaM (Except Exception α)
589+
return (combinedCancel, flatIter)
590+
479591
/--
480592
Runs a list of MetaM computations in parallel (without cancellation hook).
481593
@@ -540,6 +652,37 @@ namespace Lean.Elab.Term.TermElabM
540652

541653
open Std.Iterators
542654

655+
/--
656+
Internal state for an iterator over chunked tasks for TermElabM.
657+
Yields individual results while internally managing chunk boundaries.
658+
-/
659+
private structure ChunkedTaskIterator (α : Type) where
660+
chunkTasks : List (Task (TermElabM (List (Except Exception α))))
661+
currentResults : List (Except Exception α)
662+
663+
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TermElabM (Except Exception α) where
664+
IsPlausibleStep _
665+
| .yield _ _ => True
666+
| .skip _ => True
667+
| .done => True
668+
step it := do
669+
match it.internalState.currentResults with
670+
| r :: rest =>
671+
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TermElabM (Except Exception α)) r, trivial⟩
672+
| [] =>
673+
match it.internalState.chunkTasks with
674+
| [] => pure <| .deflate ⟨.done, trivial⟩
675+
| task :: rest =>
676+
try
677+
let chunkResults ← task.get
678+
match chunkResults with
679+
| [] =>
680+
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)), trivial⟩
681+
| r :: rs =>
682+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TermElabM (Except Exception α)) r, trivial⟩
683+
catch e =>
684+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)) (.error e), trivial⟩
685+
543686
/--
544687
Runs a list of TermElabM computations in parallel and returns:
545688
* a combined cancellation hook for all tasks, and
@@ -557,7 +700,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
557700
def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
558701
let (cancels, tasks) := (← jobs.mapM asTask).unzip
559702
let combinedCancel := cancels.forM id
560-
-- Create iterator that processes tasks sequentially
561703
let iterWithErrors := tasks.iter.mapM fun (task : Task (TermElabM α)) => do
562704
try
563705
let result ← task.get
@@ -566,6 +708,34 @@ def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
566708
pure (Except.error e)
567709
return (combinedCancel, iterWithErrors)
568710

711+
/--
712+
Runs a list of TermElabM computations in parallel with chunking and returns:
713+
* a combined cancellation hook for all tasks, and
714+
* an iterator that yields results in original order.
715+
716+
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
717+
Each chunk runs its jobs sequentially, but chunks run in parallel.
718+
719+
**Parameters:**
720+
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
721+
- `minChunkSize`: Minimum jobs per chunk. Default 1.
722+
-/
723+
def parIterWithCancelChunked {α : Type} (jobs : List (TermElabM α))
724+
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
725+
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
726+
let chunks := toChunks jobs chunkSize
727+
let chunkJobs : List (TermElabM (List (Except Exception α))) :=
728+
chunks.map fun (chunk : List (TermElabM α)) => chunk.mapM fun job => do
729+
try
730+
let a ← job
731+
pure (.ok a)
732+
catch e =>
733+
pure (.error e)
734+
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
735+
let combinedCancel := cancels.forM id
736+
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TermElabM (Except Exception α)
737+
return (combinedCancel, flatIter)
738+
569739
/--
570740
Runs a list of TermElabM computations in parallel (without cancellation hook).
571741
@@ -741,6 +911,37 @@ namespace Lean.Elab.Tactic.TacticM
741911

742912
open Std.Iterators
743913

914+
/--
915+
Internal state for an iterator over chunked tasks for TacticM.
916+
Yields individual results while internally managing chunk boundaries.
917+
-/
918+
private structure ChunkedTaskIterator (α : Type) where
919+
chunkTasks : List (Task (TacticM (List (Except Exception α))))
920+
currentResults : List (Except Exception α)
921+
922+
private instance {α : Type} : Iterator (ChunkedTaskIterator α) TacticM (Except Exception α) where
923+
IsPlausibleStep _
924+
| .yield _ _ => True
925+
| .skip _ => True
926+
| .done => True
927+
step it := do
928+
match it.internalState.currentResults with
929+
| r :: rest =>
930+
pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TacticM (Except Exception α)) r, trivial⟩
931+
| [] =>
932+
match it.internalState.chunkTasks with
933+
| [] => pure <| .deflate ⟨.done, trivial⟩
934+
| task :: rest =>
935+
try
936+
let chunkResults ← task.get
937+
match chunkResults with
938+
| [] =>
939+
pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)), trivial⟩
940+
| r :: rs =>
941+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TacticM (Except Exception α)) r, trivial⟩
942+
catch e =>
943+
pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)) (.error e), trivial⟩
944+
744945
/--
745946
Runs a list of TacticM computations in parallel and returns:
746947
* a combined cancellation hook for all tasks, and
@@ -758,7 +959,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
758959
def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
759960
let (cancels, tasks) := (← jobs.mapM asTask).unzip
760961
let combinedCancel := cancels.forM id
761-
-- Create iterator that processes tasks sequentially
762962
let iterWithErrors := tasks.iter.mapM fun (task : Task (TacticM α)) => do
763963
try
764964
let result ← task.get
@@ -767,6 +967,34 @@ def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
767967
pure (Except.error e)
768968
return (combinedCancel, iterWithErrors)
769969

970+
/--
971+
Runs a list of TacticM computations in parallel with chunking and returns:
972+
* a combined cancellation hook for all tasks, and
973+
* an iterator that yields results in original order.
974+
975+
Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
976+
Each chunk runs its jobs sequentially, but chunks run in parallel.
977+
978+
**Parameters:**
979+
- `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
980+
- `minChunkSize`: Minimum jobs per chunk. Default 1.
981+
-/
982+
def parIterWithCancelChunked {α : Type} (jobs : List (TacticM α))
983+
(maxTasks : Nat := 0) (minChunkSize : Nat := 1) := do
984+
let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
985+
let chunks := toChunks jobs chunkSize
986+
let chunkJobs : List (TacticM (List (Except Exception α))) :=
987+
chunks.map fun (chunk : List (TacticM α)) => chunk.mapM fun job => do
988+
try
989+
let a ← job
990+
pure (.ok a)
991+
catch e =>
992+
pure (.error e)
993+
let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
994+
let combinedCancel := cancels.forM id
995+
let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TacticM (Except Exception α)
996+
return (combinedCancel, flatIter)
997+
770998
/--
771999
Runs a list of TacticM computations in parallel (without cancellation hook).
7721000

0 commit comments

Comments
 (0)