Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions src/Lean/Compiler/LCNF/Simp/ConstantFold.lean
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def Folder.mkBinaryDecisionProcedure [Literal α] [Literal β] {r : α → β
/--
Provide a folder for an operation with a left neutral element.
-/
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
(_h : ∀ x, op neutral x = x := by simp) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₁ ← getLit fvarId₁ | return none
unless arg₁ == neutral do return none
Expand All @@ -233,7 +234,8 @@ def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args
/--
Provide a folder for an operation with a right neutral element.
-/
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
(_h : ∀ x, op x neutral = x := by simp) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₂ ← getLit fvarId₂ | return none
unless arg₂ == neutral do return none
Expand All @@ -242,7 +244,8 @@ def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun arg
/--
Provide a folder for an operation with a left annihilator.
-/
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) (op : α → α → α)
(_h : ∀ x, op annihilator x = zero := by simp) : Folder := fun args => do
let #[.fvar fvarId, _] := args | return none
let some arg ← getLit fvarId | return none
unless arg == annihilator do return none
Expand All @@ -251,7 +254,8 @@ def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α)
/--
Provide a folder for an operation with a right annihilator.
-/
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) (op : α → α → α)
(_h : ∀ x, op x annihilator = zero := by simp) : Folder := fun args => do
let #[_, .fvar fvarId] := args | return none
let some arg ← getLit fvarId | return none
unless arg == annihilator do return none
Expand Down Expand Up @@ -299,14 +303,20 @@ def Folder.first (folders : Array Folder) : Folder := fun exprs => do
/--
Provide a folder for an operation that has the same left and right neutral element.
-/
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) : Folder :=
Folder.first #[Folder.leftNeutral neutral, Folder.rightNeutral neutral]
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) (op : α → α → α)
(_h1 : ∀ x, op neutral x = x := by simp) (_h2 : ∀ x, op x neutral = x := by simp) : Folder :=
Folder.first #[Folder.leftNeutral neutral op _h1, Folder.rightNeutral neutral op _h2]

/--
Provide a folder for an operation that has the same left and right annihilator.
-/
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder :=
Folder.first #[Folder.leftAnnihilator annihilator zero, Folder.rightAnnihilator annihilator zero]
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α)
(op : α → α → α) (_h1 : ∀ x, op annihilator x = zero := by simp)
(_h2 : ∀ x, op x annihilator = zero := by simp) : Folder :=
Folder.first #[
Folder.leftAnnihilator annihilator zero op _h1,
Folder.rightAnnihilator annihilator zero op _h2
]

/--
Literal folders for higher order datastructures.
Expand Down Expand Up @@ -350,27 +360,27 @@ All arithmetic folders.
-/
def arithmeticFolders : List (Name × Folder) := [
(``Nat.succ, Folder.mkUnary Nat.succ),
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftAnnihilator 0 0, Folder.rightNeutral 0]),
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8)]),
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16)]),
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32)]),
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64)]),
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0 (· + ·)]),
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8) (· + ·)]),
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16) (· + ·)]),
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32) (· + ·)]),
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64) (· + ·)]),
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftAnnihilator 0 0 (· - ·), Folder.rightNeutral 0 (· - ·)]),
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8) (· - ·)]),
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16) (· - ·)]),
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32) (· - ·)]),
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64) (· - ·)]),
-- We don't convert Nat multiplication by a power of 2 into a left shift, because the fast path
-- for multiplication isn't any slower than a fast path for left shift that checks for overflow.
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0, Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0, Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0, Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0, Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1, Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8) (· * ·), Folder.leftRightAnnihilator (0 : UInt8) 0 (· * ·), Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16) (· * ·), Folder.leftRightAnnihilator (0 : UInt16) 0 (· * ·), Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32) (· * ·), Folder.leftRightAnnihilator (0 : UInt32) 0 (· * ·), Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64) (· * ·), Folder.leftRightAnnihilator (0 : UInt64) 0 (· * ·), Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1 (· / ·), Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8) (· / ·), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16) (· / ·), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32) (· / ·), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64) (· / ·), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``Nat.pow, foldNatPow),
(``Nat.nextPowerOfTwo, Folder.mkUnary Nat.nextPowerOfTwo),
]
Expand Down Expand Up @@ -413,7 +423,7 @@ def conversionFolders : List (Name × Folder) := [
All string folders.
-/
def stringFolders : List (Name × Folder) := [
(``String.append, Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral ""]),
(``String.append, Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral "" (· ++ ·)]),
(``String.length, Folder.mkUnary String.length),
(``String.push, Folder.mkBinary String.push)
]
Expand Down
Loading