Skip to content

Commit 11f7d6d

Browse files
authored
feat: reuse simp cache in grind (#8483)
This PR ensures `grind` reuses the `simp` cache between different calls. Recall that `grind` uses `simp` to normalize terms during internalization.
1 parent e2fc9ba commit 11f7d6d

File tree

6 files changed

+44
-15
lines changed

6 files changed

+44
-15
lines changed

src/Lean/Meta/Tactic/Grind/Main.lean

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,29 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
5454
prop e
5555
}
5656

57+
-- A `simp` discharger that does not use assumptions.
58+
-- We use it to make sure we don't have to reset the `simp` cache used in `grind`.
59+
private def discharge? (e : Expr) : SimpM (Option Expr) := do
60+
let e := e.cleanupAnnotations
61+
let r ← Simp.simp e
62+
if let some p ← Simp.dischargeRfl r.expr then
63+
return some (mkApp4 (mkConst ``Eq.mpr [levelZero]) e r.expr (← r.getProof) p)
64+
else if r.expr.isTrue then
65+
return some (← mkOfEqTrue (← r.getProof))
66+
else
67+
return none
68+
5769
def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM α := do
5870
let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {}
5971
let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState
6072
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
6173
let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState
6274
let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState
6375
let simprocs := params.normProcs
76+
let simpMethods := Simp.mkMethods simprocs discharge? (wellBehavedDischarge := true)
6477
let simp := params.norm
6578
let config := params.config
66-
x (← mkMethods fallback).toMethodsRef { config, simprocs, simp }
79+
x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp }
6780
|>.run' { scState, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr }
6881

6982
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
@@ -167,7 +180,7 @@ def main (mvarId : MVarId) (params : Params) (fallback : Fallback) : MetaM Resul
167180
let issues := (← get).issues
168181
let trace := (← get).trace
169182
let counters := (← get).counters
170-
let simp := (← get).simpStats
183+
let simp := { (← get).simp with }
171184
if failure?.isNone then
172185
-- If there are no failures and diagnostics are enabled, we still report the performance counters.
173186
if (← isDiagnosticsEnabled) then

src/Lean/Meta/Tactic/Grind/MatchCond.lean

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ private partial def isStatisfied (e : Expr) : GoalM Bool := do
262262
e := b
263263
return false
264264

265+
-- TODO: we don't have support for offset equalities
266+
265267
/-- Constructs a proof for a satisfied `match`-expression condition. -/
266268
private partial def mkMatchCondProof? (e : Expr) : GoalM (Option Expr) := do
267269
let_expr Grind.MatchCond f ← e | return none
@@ -286,6 +288,8 @@ where
286288
| reportIssue! "found term that has not been internalized{indentExpr lhs}\nwhile trying to construct a proof for `MatchCond`{indentExpr e}"
287289
return none
288290
let isHEq := α?.isSome
291+
unless (← hasSameType root.self rhs) do
292+
return none
289293
let h ← if isHEq then
290294
mkEqOfHEq (← mkHEqTrans (← mkHEqProof root.self lhs) h) (check := false)
291295
else

src/Lean/Meta/Tactic/Grind/Simp.lean

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ import Lean.Meta.Tactic.Grind.MarkNestedProofs
1414
import Lean.Meta.Tactic.Grind.Canon
1515

1616
namespace Lean.Meta.Grind
17+
1718
/-- Simplifies the given expression using the `grind` simprocs and normalization theorems. -/
18-
private def simpCore (e : Expr) : GrindM Simp.Result := do
19-
let simpStats := (← get).simpStats
20-
let (r, simpStats) ← Meta.simp e (← readThe Context).simp (← readThe Context).simprocs (stats := simpStats)
21-
modify fun s => { s with simpStats }
19+
private def simpCore (e : Expr) : GrindM Simp.Result := do profileitM Exception "grind simp" (← getOptions) do
20+
let simp ← modifyGet fun s => (s.simp, { s with simp := {} })
21+
let ctx := (← readThe Context).simp
22+
let (r, simp) ← Simp.mainCore e ctx simp (methods := (← readThe Context).simpMethods)
23+
modify fun s => { s with simp }
2224
return r
2325

2426
/--

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ register_builtin_option grind.warning : Bool := {
5656
/-- Context for `GrindM` monad. -/
5757
structure Context where
5858
simp : Simp.Context
59-
simprocs : Array Simp.Simprocs
59+
simpMethods : Simp.Methods
6060
config : Grind.Config
6161
/--
6262
If `cheapCases` is `true`, `grind` only applies `cases` to types that contain
@@ -124,7 +124,7 @@ structure State where
124124
Remark: we currently do not reuse congruence theorems
125125
-/
126126
congrThms : PHashMap CongrTheoremCacheKey CongrTheorem := {}
127-
simpStats : Simp.Stats := {}
127+
simp : Simp.State := {}
128128
trueExpr : Expr
129129
falseExpr : Expr
130130
natZExpr : Expr

src/Lean/Meta/Tactic/Simp/Main.lean

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -822,17 +822,22 @@ private def updateUsedSimpsWithZetaDelta (ctx : Context) (stats : Stats) : MetaM
822822
else
823823
x
824824

825-
def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do
825+
def mainCore (e : Expr) (ctx : Context) (s : State := {}) (methods : Methods := {}) : MetaM (Result × State) := do
826826
let ctx ← ctx.setLctxInitIndices
827827
withSimpContext ctx do
828-
let (r, s) ← go e methods.toMethodsRef ctx |>.run { stats with }
828+
let (r, s) ← go e methods.toMethodsRef ctx |>.run s
829829
trace[Meta.Tactic.simp.numSteps] "{s.numSteps}"
830-
let s ← updateUsedSimpsWithZetaDelta ctx { s with }
830+
let stats ← updateUsedSimpsWithZetaDelta ctx { s with }
831+
let s := { s with diag := stats.diag, usedTheorems := stats.usedTheorems }
831832
return (r, s)
832833
where
833834
go (e : Expr) : SimpM Result :=
834835
withCatchingRuntimeEx (simp e)
835836

837+
def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do
838+
let (r, s) ← mainCore e ctx { stats with } methods
839+
return (r, { s with })
840+
836841
def dsimpMain (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Expr × Stats) := do
837842
withSimpContext ctx do
838843
let (r, s) ← go e methods.toMethodsRef ctx |>.run { stats with }
@@ -845,11 +850,16 @@ where
845850
end Simp
846851
open Simp (SimprocsArray Stats)
847852

853+
def simpCore (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none)
854+
(s : Simp.State := {}) : MetaM (Simp.Result × Simp.State) := do profileitM Exception "simp" (← getOptions) do
855+
match discharge? with
856+
| none => Simp.mainCore e ctx s (methods := Simp.mkDefaultMethodsCore simprocs)
857+
| some d => Simp.mainCore e ctx s (methods := Simp.mkMethods simprocs d (wellBehavedDischarge := false))
858+
848859
def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none)
849860
(stats : Stats := {}) : MetaM (Simp.Result × Stats) := do profileitM Exception "simp" (← getOptions) do
850-
match discharge? with
851-
| none => Simp.main e ctx stats (methods := Simp.mkDefaultMethodsCore simprocs)
852-
| some d => Simp.main e ctx stats (methods := Simp.mkMethods simprocs d (wellBehavedDischarge := false))
861+
let (r, s) ← simpCore e ctx simprocs discharge? { stats with }
862+
return (r, { s with })
853863

854864
def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[])
855865
(stats : Stats := {}) : MetaM (Expr × Stats) := do profileitM Exception "dsimp" (← getOptions) do

tests/lean/run/grind_heartbeats.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ macro_rules
1212
| `(gen! $n:num) => `(op (f $n) (gen! $(Lean.quote (n.getNat - 1))))
1313

1414
/--
15-
trace: [grind.issues] (deterministic) timeout at `isDefEq`, maximum number of heartbeats (5000) has been reached
15+
trace: [grind.issues] (deterministic) timeout at `grind`, maximum number of heartbeats (5000) has been reached
1616
Use `set_option maxHeartbeats <num>` to set the limit.
1717
1818
Additional diagnostic information may be available using the `set_option diagnostics true` command.

0 commit comments

Comments
 (0)