@@ -13,6 +13,7 @@ public import Lean.Meta.Tactic.Refl
1313public import Lean.Meta.Tactic.SolveByElim
1414public import Lean.Meta.Tactic.TryThis
1515public import Lean.Util.Heartbeats
16+ public import Lean.Elab.Parallel
1617
1718public section
1819
@@ -286,61 +287,44 @@ def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult)
286287 (type? := r.newGoal.toLOption) (origSpan? := ← getRef)
287288 (checkState? := checkState?.getD (← saveState))
288289
289- structure RewriteResultConfig where
290- stopAtRfl : Bool
291- max : Nat
292- minHeartbeats : Nat
293- goal : MVarId
294- target : Expr
295- side : SideConditions := .solveByElim
296- mctx : MetavarContext
290+ /--
291+ Find lemmas which can rewrite the goal.
297292
298- def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
299- (xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
300- let mut seen := seen
301- let mut acc := acc
302- for (lem, symm, weight) in xs do
303- if (← getRemainingHeartbeats) < cfg.minHeartbeats then
304- return acc
305- if acc.size ≥ cfg.max then
306- return acc
307- let res ←
308- withoutModifyingState <| withMCtx cfg.mctx do
309- rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight
310- match res with
311- | none => continue
312- | some r =>
313- let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
314- if seen.contains s then
315- continue
316- let rfl? ← dischargableWithRfl? r.mctx r.result.eNew
317- if cfg.stopAtRfl then
318- if rfl? then
319- return #[r]
320- else
321- seen := seen.insert s ()
322- acc := acc.push r
323- else
324- seen := seen.insert s ()
325- acc := acc.push r
326- return acc
327-
328- /-- Find lemmas which can rewrite the goal. -/
293+ Runs all candidates in parallel, iterates through results in order.
294+ Cancels remaining tasks and returns immediately if `stopAtRfl` is true and
295+ an rfl-closeable result is found. Collects up to `max` unique results.
296+ -/
329297def findRewrites (hyps : Array (Expr × Bool × Nat))
330298 (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection))
331299 (goal : MVarId) (target : Expr)
332300 (forbidden : NameSet := ∅) (side : SideConditions := .solveByElim)
333- (stopAtRfl : Bool) (max : Nat := 20 )
334- (leavePercentHeartbeats : Nat := 10 ) : MetaM (List RewriteResult) := do
301+ (stopAtRfl : Bool) (max : Nat := 20 ) : MetaM (List RewriteResult) := do
335302 let mctx ← getMCtx
336303 let candidates ← rewriteCandidates hyps moduleRef target forbidden
337- let minHeartbeats : Nat ←
338- if (← getMaxHeartbeats) = 0 then
339- pure 0
340- else
341- pure <| leavePercentHeartbeats * (← getRemainingHeartbeats) / 100
342- let cfg : RewriteResultConfig :=
343- { stopAtRfl, minHeartbeats, max, mctx, goal, target, side }
344- return (← takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList
304+ -- Create parallel jobs for each candidate
305+ let jobs := candidates.toList.map fun (lem, symm, weight) => do
306+ withoutModifyingState <| withMCtx mctx do
307+ let some r ← rwLemma mctx goal target side lem symm weight
308+ | return none
309+ let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
310+ return some (r, s)
311+ let (cancel, iter) ← MetaM.parIterWithCancelChunked jobs (maxTasks := 128 )
312+ let mut seen : Std.HashMap String Unit := {}
313+ let mut acc : Array RewriteResult := Array.mkEmpty max
314+ for result in iter.allowNontermination do
315+ if acc.size ≥ max then
316+ cancel
317+ break
318+ match result with
319+ | .error _ => continue
320+ | .ok none => continue
321+ | .ok (some (r, s)) =>
322+ if seen.contains s then continue
323+ seen := seen.insert s ()
324+ if stopAtRfl && r.rfl? then
325+ cancel
326+ return [r]
327+ acc := acc.push r
328+ return acc.toList
345329
346330end Lean.Meta.Rewrites
0 commit comments