Skip to content

Commit d6fc6e6

Browse files
kim-emclaude
andcommitted
perf: parallelize rw? tactic
Use `MetaM.parIterWithCancel` to try all candidate rewrites in parallel while preserving deterministic result ordering. When an rfl-closeable result is found (and `stopAtRfl` is true), or the maximum number of results is reached, remaining tasks are cancelled. This removes the old sequential `takeListAux` implementation along with the heartbeat-based early termination and `RewriteResultConfig` structure. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 272f0f5 commit d6fc6e6

File tree

1 file changed

+33
-49
lines changed

1 file changed

+33
-49
lines changed

src/Lean/Meta/Tactic/Rewrites.lean

Lines changed: 33 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public import Lean.Meta.Tactic.Refl
1313
public import Lean.Meta.Tactic.SolveByElim
1414
public import Lean.Meta.Tactic.TryThis
1515
public import Lean.Util.Heartbeats
16+
public import Lean.Elab.Parallel
1617

1718
public section
1819

@@ -286,61 +287,44 @@ def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult)
286287
(type? := r.newGoal.toLOption) (origSpan? := ← getRef)
287288
(checkState? := checkState?.getD (← saveState))
288289

289-
structure RewriteResultConfig where
290-
stopAtRfl : Bool
291-
max : Nat
292-
minHeartbeats : Nat
293-
goal : MVarId
294-
target : Expr
295-
side : SideConditions := .solveByElim
296-
mctx : MetavarContext
290+
/--
291+
Find lemmas which can rewrite the goal.
297292
298-
def takeListAux (cfg : RewriteResultConfig) (seen : Std.HashMap String Unit) (acc : Array RewriteResult)
299-
(xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
300-
let mut seen := seen
301-
let mut acc := acc
302-
for (lem, symm, weight) in xs do
303-
if (← getRemainingHeartbeats) < cfg.minHeartbeats then
304-
return acc
305-
if acc.size ≥ cfg.max then
306-
return acc
307-
let res ←
308-
withoutModifyingState <| withMCtx cfg.mctx do
309-
rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight
310-
match res with
311-
| none => continue
312-
| some r =>
313-
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
314-
if seen.contains s then
315-
continue
316-
let rfl? ← dischargableWithRfl? r.mctx r.result.eNew
317-
if cfg.stopAtRfl then
318-
if rfl? then
319-
return #[r]
320-
else
321-
seen := seen.insert s ()
322-
acc := acc.push r
323-
else
324-
seen := seen.insert s ()
325-
acc := acc.push r
326-
return acc
327-
328-
/-- Find lemmas which can rewrite the goal. -/
293+
Runs all candidates in parallel, iterates through results in order.
294+
Cancels remaining tasks and returns immediately if `stopAtRfl` is true and
295+
an rfl-closeable result is found. Collects up to `max` unique results.
296+
-/
329297
def findRewrites (hyps : Array (Expr × Bool × Nat))
330298
(moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection))
331299
(goal : MVarId) (target : Expr)
332300
(forbidden : NameSet := ∅) (side : SideConditions := .solveByElim)
333-
(stopAtRfl : Bool) (max : Nat := 20)
334-
(leavePercentHeartbeats : Nat := 10) : MetaM (List RewriteResult) := do
301+
(stopAtRfl : Bool) (max : Nat := 20) : MetaM (List RewriteResult) := do
335302
let mctx ← getMCtx
336303
let candidates ← rewriteCandidates hyps moduleRef target forbidden
337-
let minHeartbeats : Nat ←
338-
if (← getMaxHeartbeats) = 0 then
339-
pure 0
340-
else
341-
pure <| leavePercentHeartbeats * (← getRemainingHeartbeats) / 100
342-
let cfg : RewriteResultConfig :=
343-
{ stopAtRfl, minHeartbeats, max, mctx, goal, target, side }
344-
return (← takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList
304+
-- Create parallel jobs for each candidate
305+
let jobs := candidates.toList.map fun (lem, symm, weight) => do
306+
withoutModifyingState <| withMCtx mctx do
307+
let some r ← rwLemma mctx goal target side lem symm weight
308+
| return none
309+
let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult
310+
return some (r, s)
311+
let (cancel, iter) ← MetaM.parIterWithCancelChunked jobs (maxTasks := 128)
312+
let mut seen : Std.HashMap String Unit := {}
313+
let mut acc : Array RewriteResult := Array.mkEmpty max
314+
for result in iter.allowNontermination do
315+
if acc.size ≥ max then
316+
cancel
317+
break
318+
match result with
319+
| .error _ => continue
320+
| .ok none => continue
321+
| .ok (some (r, s)) =>
322+
if seen.contains s then continue
323+
seen := seen.insert s ()
324+
if stopAtRfl && r.rfl? then
325+
cancel
326+
return [r]
327+
acc := acc.push r
328+
return acc.toList
345329

346330
end Lean.Meta.Rewrites

0 commit comments

Comments
 (0)