Skip to content

Commit 3d30792

Browse files
authored
refactor: make constant folding more robust for future bugs (#11044)
This PR enforces users of the constant folder API to provide proofs of their algebraic properties, thus hopefully avoiding bugs such as #11042 and #11043 in the future.
1 parent 1fa67d0 commit 3d30792

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

src/Lean/Compiler/LCNF/Simp/ConstantFold.lean

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def Folder.mkBinaryDecisionProcedure [Literal α] [Literal β] {r : α → β
224224
/--
225225
Provide a folder for an operation with a left neutral element.
226226
-/
227-
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
227+
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
228+
(_h : ∀ x, op neutral x = x := by simp) : Folder := fun args => do
228229
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
229230
let some arg₁ ← getLit fvarId₁ | return none
230231
unless arg₁ == neutral do return none
@@ -233,7 +234,8 @@ def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args
233234
/--
234235
Provide a folder for an operation with a right neutral element.
235236
-/
236-
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
237+
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
238+
(_h : ∀ x, op x neutral = x := by simp) : Folder := fun args => do
237239
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
238240
let some arg₂ ← getLit fvarId₂ | return none
239241
unless arg₂ == neutral do return none
@@ -242,7 +244,8 @@ def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun arg
242244
/--
243245
Provide a folder for an operation with a left annihilator.
244246
-/
245-
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
247+
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) (op : α → α → α)
248+
(_h : ∀ x, op annihilator x = zero := by simp) : Folder := fun args => do
246249
let #[.fvar fvarId, _] := args | return none
247250
let some arg ← getLit fvarId | return none
248251
unless arg == annihilator do return none
@@ -251,7 +254,8 @@ def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α)
251254
/--
252255
Provide a folder for an operation with a right annihilator.
253256
-/
254-
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
257+
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) (op : α → α → α)
258+
(_h : ∀ x, op x annihilator = zero := by simp) : Folder := fun args => do
255259
let #[_, .fvar fvarId] := args | return none
256260
let some arg ← getLit fvarId | return none
257261
unless arg == annihilator do return none
@@ -299,14 +303,20 @@ def Folder.first (folders : Array Folder) : Folder := fun exprs => do
299303
/--
300304
Provide a folder for an operation that has the same left and right neutral element.
301305
-/
302-
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) : Folder :=
303-
Folder.first #[Folder.leftNeutral neutral, Folder.rightNeutral neutral]
306+
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
307+
(_h1 : ∀ x, op neutral x = x := by simp) (_h2 : ∀ x, op x neutral = x := by simp) : Folder :=
308+
Folder.first #[Folder.leftNeutral neutral op _h1, Folder.rightNeutral neutral op _h2]
304309

305310
/--
306311
Provide a folder for an operation that has the same left and right annihilator.
307312
-/
308-
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder :=
309-
Folder.first #[Folder.leftAnnihilator annihilator zero, Folder.rightAnnihilator annihilator zero]
313+
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α)
314+
(op : α → α → α) (_h1 : ∀ x, op annihilator x = zero := by simp)
315+
(_h2 : ∀ x, op x annihilator = zero := by simp) : Folder :=
316+
Folder.first #[
317+
Folder.leftAnnihilator annihilator zero op _h1,
318+
Folder.rightAnnihilator annihilator zero op _h2
319+
]
310320

311321
/--
312322
Literal folders for higher order datastructures.
@@ -350,27 +360,27 @@ All arithmetic folders.
350360
-/
351361
def arithmeticFolders : List (Name × Folder) := [
352362
(``Nat.succ, Folder.mkUnary Nat.succ),
353-
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
354-
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
355-
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
356-
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
357-
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
358-
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftAnnihilator 0 0, Folder.rightNeutral 0]),
359-
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8)]),
360-
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16)]),
361-
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32)]),
362-
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64)]),
363+
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0 (· + ·)]),
364+
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8) (· + ·)]),
365+
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16) (· + ·)]),
366+
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32) (· + ·)]),
367+
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64) (· + ·)]),
368+
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftAnnihilator 0 0 (· - ·), Folder.rightNeutral 0 (· - ·)]),
369+
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8) (· - ·)]),
370+
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16) (· - ·)]),
371+
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32) (· - ·)]),
372+
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64) (· - ·)]),
363373
-- We don't convert Nat multiplication by a power of 2 into a left shift, because the fast path
364374
-- for multiplication isn't any slower than a fast path for left shift that checks for overflow.
365-
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0, Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
366-
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0, Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
367-
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0, Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
368-
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0, Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
369-
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1, Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
370-
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
371-
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
372-
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
373-
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]),
375+
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8) (· * ·), Folder.leftRightAnnihilator (0 : UInt8) 0 (· * ·), Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
376+
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16) (· * ·), Folder.leftRightAnnihilator (0 : UInt16) 0 (· * ·), Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
377+
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32) (· * ·), Folder.leftRightAnnihilator (0 : UInt32) 0 (· * ·), Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
378+
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64) (· * ·), Folder.leftRightAnnihilator (0 : UInt64) 0 (· * ·), Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
379+
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1 (· / ·), Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
380+
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8) (· / ·), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
381+
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16) (· / ·), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
382+
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32) (· / ·), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
383+
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64) (· / ·), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]),
374384
(``Nat.pow, foldNatPow),
375385
(``Nat.nextPowerOfTwo, Folder.mkUnary Nat.nextPowerOfTwo),
376386
]
@@ -413,7 +423,7 @@ def conversionFolders : List (Name × Folder) := [
413423
All string folders.
414424
-/
415425
def stringFolders : List (Name × Folder) := [
416-
(``String.append, Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral ""]),
426+
(``String.append, Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral "" (· ++ ·)]),
417427
(``String.length, Folder.mkUnary String.length),
418428
(``String.push, Folder.mkBinary String.push)
419429
]

0 commit comments

Comments
 (0)