Skip to content

Commit bb30b64

Browse files
committed
Cleaner code structure, use state not exceptions
1 parent 715c469 commit bb30b64

File tree

3 files changed

+105
-81
lines changed

3 files changed

+105
-81
lines changed

src/Lean/Meta/Tactic/Simp/Rewrite.lean

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -21,88 +21,87 @@ namespace Lean.Meta.Simp
2121
def currentlyLoopChecking : SimpM Bool := do
2222
return !(← getContext).loopCheckStack.isEmpty
2323

24-
def setLoopCache (thm : SimpTheorem) (r : Bool) : SimpM Unit := do
24+
def getLoopCache (thm : SimpTheorem) : SimpM (Option LoopProtectionResult) := do
25+
return (← get).loopProtectionCache.lookup? thm
26+
27+
def setLoopCache (thm : SimpTheorem) (r : LoopProtectionResult) : SimpM Unit := do
2528
modifyThe State fun s => { s with loopProtectionCache := s.loopProtectionCache.insert thm r }
2629

27-
@[inline] def withPreservedLoopCache (x : SimpM α) : SimpM α := do
28-
-- Recall that `cache.map₁` should be used linearly but `cache.map₂` is great for copies.
29-
let saved := (← get).loopProtectionCache
30-
try x
31-
finally modify fun s => { s with loopProtectionCache := saved }
30+
def setLoopCacheOkIfUnset (thm : SimpTheorem) : SimpM Unit := do
31+
unless (← getLoopCache thm).isSome do
32+
setLoopCache thm .ok
33+
34+
def setLoopCacheLoop (loop : Array SimpTheorem): SimpM Unit := do
35+
let thm := loop[0]!
36+
assert! (← getLoopCache thm).isNone
37+
assert! loop.size > 0
38+
setLoopCache thm (.loop loop)
39+
40+
def unlessWarnedBefore (thm : SimpTheorem) (k : SimpM Unit) : SimpM Unit := do
41+
unless (← get).loopProtectionCache.warned thm do
42+
modifyThe State fun s => { s with loopProtectionCache := s.loopProtectionCache.setWarned thm }
43+
k
44+
45+
def mkLoopWarningMsg (thms : Array SimpTheorem) : SimpM MessageData := do
46+
if thms.size = 1 then
47+
return m!"Ignoring looping simp theorem: {← ppOrigin thms[0]!.origin}"
48+
else
49+
return m! "Ignoring jointly looping simp theorems: \
50+
{.andList (← thms.mapM (ppOrigin ·.origin)).toList}"
51+
52+
private def rotations (a : Array α) : Array (Array α) := Id.run do
53+
let mut r : Array (Array α) := #[]
54+
for i in [:a.size] do
55+
r := r.push (a[i:] ++ a[:i])
56+
return r
3257

3358
def checkLoops (thm : SimpTheorem) : SimpM Bool := do
3459
let cfg ← getConfig
3560
-- No loop checking when disabled or in single pass mode
3661
if !cfg.loopProtection || cfg.singlePass then return true
3762

38-
-- Check cache
39-
if let some r := (← get).loopProtectionCache.lookup? thm then
40-
return r
41-
42-
withTraceNode `Meta.Tactic.simp.loopProtection (return m!"{exceptEmoji ·} loop-checking {← ppSimpTheorem thm}") do
43-
44-
let checkRhs : SimpM Unit := do
45-
withPushingLoopCheck thm do
46-
withFreshCache do
47-
let type ← inferType (← thm.getValue)
48-
forallTelescopeReducing type fun _xs type => do
49-
let rhs := (← whnf type).appArg!
50-
-- We ignore the result for now. We could return it to `tryTheoremCore` to avoid
51-
-- re-simplifying the right-hand side, but that would require some more refactoring
52-
let _ ← simp rhs
53-
54-
let seenThms := (← getContext).loopCheckStack
55-
if seenThms.isEmpty then
56-
-- This is the main entry into loop checking
57-
58-
-- Accept permutating and local theorems without checking
59-
if thm.perm then return true
60-
if thm.proof.hasFVar then return true
61-
62-
-- Check the right-hand side, turn thrown errors into logged warnigns.
63-
try
64-
withPreservedLoopCache do
65-
checkRhs
66-
setLoopCache thm true
67-
pure true
68-
catch e =>
69-
-- This catches all errors, but ideally we only catch the error thrown above.
70-
-- Can we achieve that without hacks?
71-
logWarning e.toMessageData
72-
setLoopCache thm false
73-
pure false
74-
else
75-
let checkingThmId := seenThms.getLast!
76-
-- We are in the process of checking `checkingThmId` for loops
77-
78-
-- Disable all local theorems and all permutating theorems
79-
if thm.perm then return false
80-
if thm.proof.hasFVar then return false
8163

82-
if thm == checkingThmId then
83-
-- We found a loop starting with `checkingThmId`!
84-
if seenThms matches [_] then
85-
throwError "Ignoring looping simp theorem: {← ppOrigin thm.origin}"
86-
else
87-
throwError "Ignoring jointly looping simp theorems: \
88-
{.andList (← seenThms.reverse.mapM (ppOrigin ·.origin))}"
89-
90-
if seenThms.contains thm then
91-
-- Starting with `checkingThmId`, we run into a loop, but the loop does
92-
-- not actually involve `checkingThmId`. Stop rewriting, but do not complain.
93-
-- We update the cache to avoid looping during checking.
94-
-- Since this is not reportd, we throw away the cache updates
95-
-- at the end of the loop checking.
96-
setLoopCache thm false
97-
return false
64+
-- Permutating and local theorems are never checked, so accept when starting
65+
-- a loop check, and ignore when inside a loop check
66+
if thm.perm || thm.proof.hasFVar then
67+
return !(← currentlyLoopChecking)
9868

99-
checkRhs
100-
-- Check cache again, we may have found a loop for this one
101-
if let some r := (← get).loopProtectionCache.lookup? thm then
102-
return r
69+
-- Check cache
70+
if (← getLoopCache thm).isNone then
71+
withTraceNode `Meta.Tactic.simp.loopProtection (fun _ => return m!"loop-checking {← ppSimpTheorem thm}") do
72+
73+
-- Checking for a loop
74+
let seenThms := (← getContext).loopCheckStack
75+
if let some idx := seenThms.idxOf? thm then
76+
let loopThms := (seenThms.take (idx + 1)).toArray.reverse
77+
assert! loopThms[0]! == thm
78+
trace[Meta.Tactic.simp.loopProtection] "loop detected: {.andList (← loopThms.mapM (ppOrigin ·.origin)).toList}"
79+
(rotations loopThms).forM setLoopCacheLoop
10380
else
104-
setLoopCache thm true
105-
return true
81+
-- Check the right-hand side
82+
withPushingLoopCheck thm do
83+
withFreshCache do
84+
let type ← inferType (← thm.getValue)
85+
forallTelescopeReducing type fun _xs type => do
86+
let rhs := (← whnf type).appArg!
87+
-- We ignore the result for now. We could return it to `tryTheoremCore` to avoid
88+
-- re-simplifying the right-hand side, but that would require some more refactoring
89+
let _ ← simp rhs
90+
-- If we made it this far without finding a loop, this theorem is fine
91+
setLoopCacheOkIfUnset thm
92+
93+
-- Now the cache tells us if this was looping
94+
if let some (.loop thms) ← getLoopCache thm then
95+
-- Only when this is the starting point and we have not warned before: report the loop
96+
unless (← currentlyLoopChecking) do
97+
unlessWarnedBefore thm do
98+
if let .stx _ ref := thm.origin then
99+
logWarningAt ref (← mkLoopWarningMsg thms)
100+
else
101+
logWarning (← mkLoopWarningMsg thms)
102+
return false
103+
else
104+
return true
106105

107106
/--
108107
Helper type for implementing `discharge?'`

src/Lean/Meta/Tactic/Simp/Types.lean

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,26 @@ def UsedSimps.insert (s : UsedSimps) (thmId : Origin) : UsedSimps :=
225225
def UsedSimps.toArray (s : UsedSimps) : Array Origin :=
226226
s.map.toArray.qsort (·.2 < ·.2) |>.map (·.1)
227227

228+
inductive LoopProtectionResult where
229+
| ok
230+
| loop (loop : Array SimpTheorem)
231+
228232
structure LoopProtectionCache where
229-
map : PHashMap Expr Bool := {}
233+
map : PHashMap Expr LoopProtectionResult := {}
234+
warnedSet : PHashSet Expr := {}
230235
deriving Inhabited
231236

232-
def LoopProtectionCache.lookup? (c : LoopProtectionCache) (thm : SimpTheorem) : Option Bool :=
237+
def LoopProtectionCache.lookup? (c : LoopProtectionCache) (thm : SimpTheorem) : Option LoopProtectionResult :=
233238
c.map.find? thm.proof
234239

235-
def LoopProtectionCache.insert (c : LoopProtectionCache) (thm : SimpTheorem) (b : Bool) : LoopProtectionCache :=
236-
{ c with map := c.map.insert thm.proof b }
240+
def LoopProtectionCache.insert (c : LoopProtectionCache) (thm : SimpTheorem) (r : LoopProtectionResult) : LoopProtectionCache :=
241+
{ c with map := c.map.insert thm.proof r }
242+
243+
def LoopProtectionCache.warned (c : LoopProtectionCache) (thm : SimpTheorem) : Bool :=
244+
c.warnedSet.contains thm.proof
245+
246+
def LoopProtectionCache.setWarned (c : LoopProtectionCache) (thm : SimpTheorem) : LoopProtectionCache :=
247+
{ c with warnedSet := c.warnedSet.insert thm.proof }
237248

238249
structure Diagnostics where
239250
/-- Number of times each simp theorem has been used/applied. -/

tests/lean/run/simpLoopProtection.lean

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ error: unsolved goals
1414
⊢ a = 23
1515
-/
1616
#guard_msgs in
17-
example : id a = 23 := by simp +loopProtection -failIfUnchanged [aa]
17+
example : id a = 23 := by
18+
simp +loopProtection -failIfUnchanged [aa]
1819

1920
/--
2021
warning: Ignoring jointly looping simp theorems: ab and ba
@@ -25,6 +26,17 @@ error: unsolved goals
2526
#guard_msgs in
2627
example : a = 23 := by simp +loopProtection -failIfUnchanged [ab, ba]
2728

29+
/--
30+
warning: Ignoring jointly looping simp theorems: ab and ba
31+
---
32+
warning: Ignoring jointly looping simp theorems: ba and ab
33+
---
34+
error: unsolved goals
35+
⊢ a = 2 * b
36+
-/
37+
#guard_msgs in
38+
example : a = 2*b := by simp +loopProtection -failIfUnchanged [ab, ba]
39+
2840
/--
2941
warning: Ignoring jointly looping simp theorems: ← ba and ← ab
3042
---
@@ -55,8 +67,8 @@ theorem id'_eq (n : Nat) : id' n = n := testSorry
5567
theorem id'_eq_bad (n : Nat) : id' n = id' (id' n) := testSorry
5668

5769
/--
58-
trace: [Meta.Tactic.simp.loopProtection] ✅️ loop-checking id'_eq:1000
59-
[Meta.Tactic.simp.loopProtection] ✅️ loop-checking eq_self:1000
70+
trace: [Meta.Tactic.simp.loopProtection] loop-checking id'_eq:1000
71+
[Meta.Tactic.simp.loopProtection] loop-checking eq_self:1000
6072
-/
6173
#guard_msgs in
6274
set_option trace.Meta.Tactic.simp.loopProtection true in
@@ -68,8 +80,9 @@ warning: Ignoring looping simp theorem: id'_eq_bad
6880
error: unsolved goals
6981
⊢ id' 1 + id' 2 = id' 3
7082
---
71-
trace: [Meta.Tactic.simp.loopProtection] ✅️ loop-checking id'_eq_bad:1000
72-
[Meta.Tactic.simp.loopProtection] ❌️ loop-checking id'_eq_bad:1000
83+
trace: [Meta.Tactic.simp.loopProtection] loop-checking id'_eq_bad:1000
84+
[Meta.Tactic.simp.loopProtection] loop-checking id'_eq_bad:1000
85+
[Meta.Tactic.simp.loopProtection] loop detected: id'_eq_bad
7386
-/
7487
#guard_msgs in
7588
set_option trace.Meta.Tactic.simp.loopProtection true in
@@ -164,6 +177,7 @@ P : Nat → Prop
164177
-/
165178
#guard_msgs in
166179
example : c > 0 := by simp only [c, ac]
180+
167181
/--
168182
warning: Ignoring looping simp theorem: ac
169183
---
@@ -172,7 +186,7 @@ P : Nat → Prop
172186
⊢ a > 0
173187
-/
174188
#guard_msgs in
175-
example : d > 0 := by simp only [c, ac, dc]
189+
example : d > 0 := by simp only [dc, c, ac]
176190
/--
177191
warning: Ignoring looping simp theorem: ac
178192
---

0 commit comments

Comments
 (0)