@@ -155,6 +155,38 @@ namespace Lean.Core.CoreM
155155
156156open Std.Iterators
157157
158+ /--
159+ Internal state for an iterator over chunked tasks for CoreM.
160+ Yields individual results while internally managing chunk boundaries.
161+ -/
162+ private structure ChunkedTaskIterator (α : Type ) where
163+ chunkTasks : List (Task (CoreM (List (Except Exception α))))
164+ currentResults : List (Except Exception α)
165+
166+ private instance {α : Type} : Iterator (ChunkedTaskIterator α) CoreM (Except Exception α) where
167+ IsPlausibleStep _
168+ | .yield _ _ => True
169+ | .skip _ => True -- Allow skip for empty chunks
170+ | .done => True
171+ step it := do
172+ match it.internalState.currentResults with
173+ | r :: rest =>
174+ pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } CoreM (Except Exception α)) r, trivial⟩
175+ | [] =>
176+ match it.internalState.chunkTasks with
177+ | [] => pure <| .deflate ⟨.done, trivial⟩
178+ | task :: rest =>
179+ try
180+ let chunkResults ← task.get
181+ match chunkResults with
182+ | [] =>
183+ -- Empty chunk, skip to try next
184+ pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)), trivial⟩
185+ | r :: rs =>
186+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } CoreM (Except Exception α)) r, trivial⟩
187+ catch e =>
188+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } CoreM (Except Exception α)) (.error e), trivial⟩
189+
158190/--
159191Runs a list of CoreM computations in parallel and returns:
160192* a combined cancellation hook for all tasks, and
@@ -180,6 +212,29 @@ def parIterWithCancel {α : Type} (jobs : List (CoreM α)) := do
180212 pure (Except.error e)
181213 return (combinedCancel, iterWithErrors)
182214
215+ /--
216+ Runs a list of CoreM computations in parallel with chunking and returns:
217+ * a combined cancellation hook for all tasks, and
218+ * an iterator that yields results in original order.
219+
220+ Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
221+ Each chunk runs its jobs sequentially, but chunks run in parallel.
222+
223+ **Parameters:**
224+ - `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
225+ - `minChunkSize`: Minimum jobs per chunk. Default 1.
226+ -/
227+ def parIterWithCancelChunked {α : Type } (jobs : List (CoreM α))
228+ (maxTasks : Nat := 0 ) (minChunkSize : Nat := 1 ) := do
229+ let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
230+ let chunks := toChunks jobs chunkSize
231+ let chunkJobs : List (CoreM (List (Except Exception α))) :=
232+ chunks.map fun (chunk : List (CoreM α)) => chunk.mapM (observing ·)
233+ let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
234+ let combinedCancel := cancels.forM id
235+ let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) CoreM (Except Exception α)
236+ return (combinedCancel, flatIter)
237+
183238/--
184239Runs a list of CoreM computations in parallel (without cancellation hook).
185240
@@ -347,6 +402,41 @@ namespace Lean.Meta.MetaM
347402
348403open Std.Iterators
349404
405+ /--
406+ Internal state for an iterator over chunked tasks for MetaM.
407+ Yields individual results while internally managing chunk boundaries.
408+ -/
409+ structure ChunkedTaskIterator (α : Type ) where
410+ chunkTasks : List (Task (MetaM (List (Except Exception α))))
411+ currentResults : List (Except Exception α)
412+
413+ instance {α : Type} : Iterator (ChunkedTaskIterator α) MetaM (Except Exception α) where
414+ IsPlausibleStep _
415+ | .yield _ _ => True
416+ | .skip _ => True
417+ | .done => True
418+ step it := do
419+ match it.internalState.currentResults with
420+ | r :: rest =>
421+ pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } MetaM (Except Exception α)) r, trivial⟩
422+ | [] =>
423+ match it.internalState.chunkTasks with
424+ | [] => pure <| .deflate ⟨.done, trivial⟩
425+ | task :: rest =>
426+ try
427+ let chunkResults ← task.get
428+ match chunkResults with
429+ | [] =>
430+ pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)), trivial⟩
431+ | r :: rs =>
432+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } MetaM (Except Exception α)) r, trivial⟩
433+ catch e =>
434+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } MetaM (Except Exception α)) (.error e), trivial⟩
435+
436+ instance {α : Type} {n : Type → Type u} [Monad n] [MonadLiftT MetaM n] :
437+ IteratorLoopPartial (ChunkedTaskIterator α) MetaM n :=
438+ .defaultImplementation
439+
350440/-- Internal: run jobs in parallel without chunking, returning state. -/
351441private def parCore {α : Type } (jobs : List (MetaM α)) :
352442 MetaM (List (Except Exception (α × Meta.SavedState))) := do
@@ -467,7 +557,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
467557def parIterWithCancel {α : Type } (jobs : List (MetaM α)) := do
468558 let (cancels, tasks) := (← jobs.mapM asTask).unzip
469559 let combinedCancel := cancels.forM id
470- -- Create iterator that processes tasks sequentially
471560 let iterWithErrors := tasks.iter.mapM fun (task : Task (MetaM α)) => do
472561 try
473562 let result ← task.get
@@ -476,6 +565,29 @@ def parIterWithCancel {α : Type} (jobs : List (MetaM α)) := do
476565 pure (Except.error e)
477566 return (combinedCancel, iterWithErrors)
478567
568+ /--
569+ Runs a list of MetaM computations in parallel with chunking and returns:
570+ * a combined cancellation hook for all tasks, and
571+ * an iterator that yields results in original order.
572+
573+ Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
574+ Each chunk runs its jobs sequentially, but chunks run in parallel.
575+
576+ **Parameters:**
577+ - `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
578+ - `minChunkSize`: Minimum jobs per chunk. Default 1.
579+ -/
580+ def parIterWithCancelChunked {α : Type } (jobs : List (MetaM α))
581+ (maxTasks : Nat := 0 ) (minChunkSize : Nat := 1 ) := do
582+ let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
583+ let chunks := toChunks jobs chunkSize
584+ let chunkJobs : List (MetaM (List (Except Exception α))) :=
585+ chunks.map fun (chunk : List (MetaM α)) => chunk.mapM (observing ·)
586+ let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
587+ let combinedCancel := cancels.forM id
588+ let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) MetaM (Except Exception α)
589+ return (combinedCancel, flatIter)
590+
479591/--
480592Runs a list of MetaM computations in parallel (without cancellation hook).
481593
@@ -540,6 +652,37 @@ namespace Lean.Elab.Term.TermElabM
540652
541653open Std.Iterators
542654
655+ /--
656+ Internal state for an iterator over chunked tasks for TermElabM.
657+ Yields individual results while internally managing chunk boundaries.
658+ -/
659+ private structure ChunkedTaskIterator (α : Type ) where
660+ chunkTasks : List (Task (TermElabM (List (Except Exception α))))
661+ currentResults : List (Except Exception α)
662+
663+ private instance {α : Type} : Iterator (ChunkedTaskIterator α) TermElabM (Except Exception α) where
664+ IsPlausibleStep _
665+ | .yield _ _ => True
666+ | .skip _ => True
667+ | .done => True
668+ step it := do
669+ match it.internalState.currentResults with
670+ | r :: rest =>
671+ pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TermElabM (Except Exception α)) r, trivial⟩
672+ | [] =>
673+ match it.internalState.chunkTasks with
674+ | [] => pure <| .deflate ⟨.done, trivial⟩
675+ | task :: rest =>
676+ try
677+ let chunkResults ← task.get
678+ match chunkResults with
679+ | [] =>
680+ pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)), trivial⟩
681+ | r :: rs =>
682+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TermElabM (Except Exception α)) r, trivial⟩
683+ catch e =>
684+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TermElabM (Except Exception α)) (.error e), trivial⟩
685+
543686/--
544687Runs a list of TermElabM computations in parallel and returns:
545688* a combined cancellation hook for all tasks, and
@@ -557,7 +700,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
557700def parIterWithCancel {α : Type } (jobs : List (TermElabM α)) := do
558701 let (cancels, tasks) := (← jobs.mapM asTask).unzip
559702 let combinedCancel := cancels.forM id
560- -- Create iterator that processes tasks sequentially
561703 let iterWithErrors := tasks.iter.mapM fun (task : Task (TermElabM α)) => do
562704 try
563705 let result ← task.get
@@ -566,6 +708,34 @@ def parIterWithCancel {α : Type} (jobs : List (TermElabM α)) := do
566708 pure (Except.error e)
567709 return (combinedCancel, iterWithErrors)
568710
711+ /--
712+ Runs a list of TermElabM computations in parallel with chunking and returns:
713+ * a combined cancellation hook for all tasks, and
714+ * an iterator that yields results in original order.
715+
716+ Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
717+ Each chunk runs its jobs sequentially, but chunks run in parallel.
718+
719+ **Parameters:**
720+ - `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
721+ - `minChunkSize`: Minimum jobs per chunk. Default 1.
722+ -/
723+ def parIterWithCancelChunked {α : Type } (jobs : List (TermElabM α))
724+ (maxTasks : Nat := 0 ) (minChunkSize : Nat := 1 ) := do
725+ let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
726+ let chunks := toChunks jobs chunkSize
727+ let chunkJobs : List (TermElabM (List (Except Exception α))) :=
728+ chunks.map fun (chunk : List (TermElabM α)) => chunk.mapM fun job => do
729+ try
730+ let a ← job
731+ pure (.ok a)
732+ catch e =>
733+ pure (.error e)
734+ let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
735+ let combinedCancel := cancels.forM id
736+ let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TermElabM (Except Exception α)
737+ return (combinedCancel, flatIter)
738+
569739/--
570740Runs a list of TermElabM computations in parallel (without cancellation hook).
571741
@@ -741,6 +911,37 @@ namespace Lean.Elab.Tactic.TacticM
741911
742912open Std.Iterators
743913
914+ /--
915+ Internal state for an iterator over chunked tasks for TacticM.
916+ Yields individual results while internally managing chunk boundaries.
917+ -/
918+ private structure ChunkedTaskIterator (α : Type ) where
919+ chunkTasks : List (Task (TacticM (List (Except Exception α))))
920+ currentResults : List (Except Exception α)
921+
922+ private instance {α : Type} : Iterator (ChunkedTaskIterator α) TacticM (Except Exception α) where
923+ IsPlausibleStep _
924+ | .yield _ _ => True
925+ | .skip _ => True
926+ | .done => True
927+ step it := do
928+ match it.internalState.currentResults with
929+ | r :: rest =>
930+ pure <| .deflate ⟨.yield (toIterM { it.internalState with currentResults := rest } TacticM (Except Exception α)) r, trivial⟩
931+ | [] =>
932+ match it.internalState.chunkTasks with
933+ | [] => pure <| .deflate ⟨.done, trivial⟩
934+ | task :: rest =>
935+ try
936+ let chunkResults ← task.get
937+ match chunkResults with
938+ | [] =>
939+ pure <| .deflate ⟨.skip (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)), trivial⟩
940+ | r :: rs =>
941+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := rs } TacticM (Except Exception α)) r, trivial⟩
942+ catch e =>
943+ pure <| .deflate ⟨.yield (toIterM { chunkTasks := rest, currentResults := [] } TacticM (Except Exception α)) (.error e), trivial⟩
944+
744945/--
745946Runs a list of TacticM computations in parallel and returns:
746947* a combined cancellation hook for all tasks, and
@@ -758,7 +959,6 @@ The iterator will terminate after all jobs complete (assuming they all do comple
758959def parIterWithCancel {α : Type } (jobs : List (TacticM α)) := do
759960 let (cancels, tasks) := (← jobs.mapM asTask).unzip
760961 let combinedCancel := cancels.forM id
761- -- Create iterator that processes tasks sequentially
762962 let iterWithErrors := tasks.iter.mapM fun (task : Task (TacticM α)) => do
763963 try
764964 let result ← task.get
@@ -767,6 +967,34 @@ def parIterWithCancel {α : Type} (jobs : List (TacticM α)) := do
767967 pure (Except.error e)
768968 return (combinedCancel, iterWithErrors)
769969
970+ /--
971+ Runs a list of TacticM computations in parallel with chunking and returns:
972+ * a combined cancellation hook for all tasks, and
973+ * an iterator that yields results in original order.
974+
975+ Unlike `parIterWithCancel`, this groups jobs into chunks to reduce task overhead.
976+ Each chunk runs its jobs sequentially, but chunks run in parallel.
977+
978+ **Parameters:**
979+ - `maxTasks`: Maximum number of parallel tasks (chunks). Default 0 means one task per job.
980+ - `minChunkSize`: Minimum jobs per chunk. Default 1.
981+ -/
982+ def parIterWithCancelChunked {α : Type } (jobs : List (TacticM α))
983+ (maxTasks : Nat := 0 ) (minChunkSize : Nat := 1 ) := do
984+ let chunkSize := computeChunkSize jobs.length maxTasks minChunkSize
985+ let chunks := toChunks jobs chunkSize
986+ let chunkJobs : List (TacticM (List (Except Exception α))) :=
987+ chunks.map fun (chunk : List (TacticM α)) => chunk.mapM fun job => do
988+ try
989+ let a ← job
990+ pure (.ok a)
991+ catch e =>
992+ pure (.error e)
993+ let (cancels, tasks) := (← chunkJobs.mapM asTask).unzip
994+ let combinedCancel := cancels.forM id
995+ let flatIter := toIterM (ChunkedTaskIterator.mk tasks []) TacticM (Except Exception α)
996+ return (combinedCancel, flatIter)
997+
770998/--
771999Runs a list of TacticM computations in parallel (without cancellation hook).
7721000
0 commit comments