Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions SSA/Experimental/Bits/Fast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ import SSA.Experimental.Bits.Fast.MBA
import SSA.Experimental.Bits.Fast.Profile
import SSA.Experimental.Bits.Fast.Reflect
import SSA.Experimental.Bits.Fast.Tests
import SSA.Experimental.Bits.Fast.ZextSext

83 changes: 67 additions & 16 deletions SSA/Experimental/Bits/Fast/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ given a value for the free variables in `t`.
Note that we don't keep track of how many free variable occur in `t`,
so eval requires us to give a value for each possible variable.
-/
def Term.eval (t : Term) (vars : List BitStream) : BitStream :=
def Term.eval (t : Term .bv) (vars : List BitStream) : BitStream :=
match t with
| var n => vars.getD n default
| zero => BitStream.zero
Expand All @@ -35,16 +35,17 @@ def Term.eval (t : Term) (vars : List BitStream) : BitStream :=
| shiftL t n => BitStream.shiftLeft (Term.eval t vars) n
-- | repeatBit t => BitStream.repeatBit (Term.eval t vars)


/--
Evaluate a term `t` to the BitStream it represents.

This differs from `Term.eval` in that `Term.evalFin` uses `Term.arity` to
determine the number of free variables that occur in the given term,
and only require that many bitstream values to be given in `vars`.
-/
@[simp] def Term.evalFin (t : Term) (vars : Fin (arity t) → BitStream) : BitStream :=
@[simp] def Term.evalFin (t : Term .bv) (vars : Fin (arity t) → BitStream) : BitStream :=
match t with
| var n => vars (Fin.last n)
| var n => vars (⟨n, by simp⟩) -- Fin.last n)
| zero => BitStream.zero
| one => BitStream.one
| negOne => BitStream.negOne
Expand All @@ -61,7 +62,7 @@ and only require that many bitstream values to be given in `vars`.
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ ^^^ x₂
| not t => ~~~(t.evalFin vars)
| not t => ~~~(t.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i)))
| add t₁ t₂ =>
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Expand All @@ -70,24 +71,61 @@ and only require that many bitstream values to be given in `vars`.
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ - x₂
| neg t => -(Term.evalFin t vars)
| neg t => -(Term.evalFin t (fun i => vars (Fin.castLE (by simp [arity]) i)))
-- | incr t => BitStream.incr (Term.evalFin t vars)
-- | decr t => BitStream.decr (Term.evalFin t vars)
| shiftL t n => BitStream.shiftLeft (Term.evalFin t vars) n
| shiftL t n => BitStream.shiftLeft (Term.evalFin t (fun i => vars (Fin.castLE (by simp [arity]) i))) n
-- | repeatBit t => BitStream.repeatBit (Term.evalFin t vars)


/--
Evaluate a term `t` to the BitStream it represents,
given a value for the free variables in `t`.

Note that we don't keep track of how many free variable occur in `t`,
so eval requires us to give a value for each possible variable.
-/
def BTerm.eval (t : BTerm) (vars : List BitStream) : BitStream :=
match t with
| tru => BitStream.negOne
| fals => BitStream.zero
| xor a b => a.eval vars ^^^ b.eval vars
| msb x => x.eval vars
| var n => vars.getD n default

def BTerm.evalFin (t : BTerm) (vars : Fin (arity t) → BitStream) : BitStream :=
match t with
| tru => BitStream.negOne
| fals => BitStream.zero
| xor a b =>
let x₁ := a.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := b.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ ^^^ x₂
| msb x => x.evalFin vars
| var n => vars (Fin.last n)

/--
If they are equal so far, then `t1 ^^^ t2`.scanOr will be 0.
-/
def Predicate.evalEq (t₁ t₂ : BitStream) : BitStream := (t₁ ^^^ t₂).concat false |>.scanOr
def Predicate.evalBitstreamEq (t₁ t₂ : BitStream) : BitStream := (t₁ ^^^ t₂) |>.scanOr

/--
If they are equal so far, then `t1 ^^^ t2`.scanOr will be 0.
-/
def Predicate.evalBitstreamNeq (t₁ t₂ : BitStream) : BitStream := (t₁.nxor t₂) |>.scanAnd

/--
If they are equal so far, then `t1 ^^^ t2`.scanOr will be 0.
-/
def Predicate.evalBVEq (t₁ t₂ : BitStream) : BitStream := (t₁ ^^^ t₂).concat false |>.scanOr
/--
If they have been equal so far, then `BitStream.nxor t₁ t₂`.scanAnd will be 1.
Start by assuming that they are not (not equal) i.e. that they are equal, and the
initial value of the preciate is false / `1`.
If their values ever differ, then we know that we will have `a[i] == b[i]` to be `false`.
From this point onward, they will always disagree, and thus the predicate should become `0`.
-/
def Predicate.evalNeq (t₁ t₂ : BitStream) : BitStream := (t₁.nxor t₂).concat true |>.scanAnd
def Predicate.evalBVNeq (t₁ t₂ : BitStream) : BitStream := (t₁.nxor t₂).concat true |>.scanAnd

/-
If they have been `0` so far, then `t1 &&& t2 |>.scanOr` will be `1`.
Expand Down Expand Up @@ -152,21 +190,25 @@ def Predicate.eval (p : Predicate) (vars : List BitStream) : BitStream :=
| .width .ge n => BitStream.falseIffGe n
| lor p q => Predicate.evalLor (p.eval vars) (q.eval vars)
| land p q => Predicate.evalLand (p.eval vars) (q.eval vars)
| binary .eq t₁ t₂ => Predicate.evalEq (t₁.eval vars) (t₂.eval vars)
/- boolean operations. -/
| boolBinary .eq t₁ t₂ => Predicate.evalBitstreamEq (t₁.eval vars) (t₂.eval vars)
| boolBinary .neq t₁ t₂ => Predicate.evalBitstreamNeq (t₁.eval vars) (t₂.eval vars)
/- bitstream operations. -/
| binary .eq t₁ t₂ => Predicate.evalBVEq (t₁.eval vars) (t₂.eval vars)
/-
If it is ever not equal, then we want to stay not equals for ever.
So, if the 'a = b' returns 'false' at some index 'i', we will stay false
for all indexes '≥ i'.
-/
| binary .neq t1 t2 => Predicate.evalNeq (t1.eval vars) (t2.eval vars)
| binary .neq t1 t2 => Predicate.evalBVNeq (t1.eval vars) (t2.eval vars)
| binary .ult t₁ t₂ => Predicate.evalUlt (t₁.eval vars) (t₂.eval vars)
| binary .ule t₁ t₂ =>
Predicate.evalLor
(Predicate.evalEq (t₁.eval vars) (t₂.eval vars))
(Predicate.evalBVEq (t₁.eval vars) (t₂.eval vars))
(Predicate.evalUlt (t₁.eval vars) (t₂.eval vars))
| binary .slt t₁ t₂ => Predicate.evalSlt (t₁.eval vars) (t₂.eval vars)
| binary .sle t₁ t₂ => Predicate.evalLor
(Predicate.evalEq (t₁.eval vars) (t₂.eval vars))
(Predicate.evalBVEq (t₁.eval vars) (t₂.eval vars))
(Predicate.evalSlt (t₁.eval vars) (t₂.eval vars))

@[simp]
Expand Down Expand Up @@ -198,14 +240,23 @@ match p with
| .width .le n => BitStream.falseIffLe n
| .width .gt n => BitStream.falseIffGt n
| .width .ge n => BitStream.falseIffGe n
| .boolBinary .eq t₁ t₂ =>
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalBitstreamEq x₁ x₂
| .boolBinary .neq t₁ t₂ =>
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalBitstreamNeq x₁ x₂
-- ~~~ (x₁ ^^^ x₂)
| .binary .eq t₁ t₂ =>
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalEq x₁ x₂
Predicate.evalBVEq x₁ x₂
| .binary .neq t₁ t₂ =>
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalNeq x₁ x₂
Predicate.evalBVNeq x₁ x₂
| .land p q =>
-- if both `p` and `q` are logically true (i.e. the predicate is `false`),
-- only then should we return a `false`.
Expand All @@ -224,12 +275,12 @@ match p with
| .binary .sle p q =>
let x₁ := p.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := q.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalLor (Predicate.evalSlt x₁ x₂) (Predicate.evalEq x₁ x₂)
Predicate.evalLor (Predicate.evalSlt x₁ x₂) (Predicate.evalBVEq x₁ x₂)
| .binary .ult p q =>
let x₁ := p.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := q.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
(Predicate.evalUlt x₁ x₂)
| .binary .ule p q =>
let x₁ := p.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := q.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
Predicate.evalLor (Predicate.evalUlt x₁ x₂) (Predicate.evalEq x₁ x₂)
Predicate.evalLor (Predicate.evalUlt x₁ x₂) (Predicate.evalBVEq x₁ x₂)
81 changes: 66 additions & 15 deletions SSA/Experimental/Bits/Fast/FiniteStateMachine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def repeatBit : FSM Unit where
end FSM

/-- An `FSMTermSolution `t` is an FSM with a witness that the FSM evaluates to the same value as `t` does -/
structure FSMTermSolution (t : Term) extends FSM (Fin t.arity) where
structure FSMTermSolution (t : Term .bv) extends FSM (Fin t.arity) where
( good : t.evalFin = toFSM.eval )


Expand All @@ -811,7 +811,7 @@ def composeUnaryAux
/-- Compose two automata together, where `q` is an FSMTermSolution -/
def composeUnary
(p : FSM Unit)
{t : Term}
{t : Term .bv}
(q : FSMTermSolution t) :
FSM (Fin t.arity) := composeUnaryAux p q.toFSM

Expand All @@ -830,7 +830,7 @@ def composeBinaryAux
/-- Compose two binary opeators -/
def composeBinary
(p : FSM Bool)
{t₁ t₂ : Term}
{t₁ t₂ : Term .bv}
(q₁ : FSMTermSolution t₁)
(q₂ : FSMTermSolution t₂) :
FSM (Fin (max t₁.arity t₂.arity)) := composeBinaryAux p q₁.toFSM q₂.toFSM
Expand All @@ -845,7 +845,7 @@ def composeBinary

@[simp] lemma composeUnary_eval
(p : FSM Unit)
{t : Term}
{t : Term .bv}
(q : FSMTermSolution t)
(x : Fin t.arity → BitStream) :
(composeUnary p q).eval x = p.eval (λ _ => t.evalFin x) := by
Expand All @@ -865,7 +865,7 @@ def composeBinary

@[simp] lemma composeBinary_eval
(p : FSM Bool)
{t₁ t₂ : Term}
{t₁ t₂ : Term .bv}
(q₁ : FSMTermSolution t₁)
(q₂ : FSMTermSolution t₂)
(x : Fin (max t₁.arity t₂.arity) → BitStream) :
Expand Down Expand Up @@ -1150,18 +1150,22 @@ private theorem falseUptoIncluding_eq_false_iff (n : Nat) (i : Nat) {env : Fin 0

end FSM


structure FSMBTermSolution (b : BTerm) extends FSM (Fin b.arity) where
( good : b.evalFin = toFSM.eval )

open Term

/--
Note that **this is the value that is run by decide**.
-/
def termEvalEqFSM : ∀ (t : Term), FSMTermSolution t
def termEvalEqFSM : ∀ (t : Term .bv), FSMTermSolution t
| ofNat n =>
{ toFSM := FSM.ofNat n,
good := by ext; simp [Term.evalFin] }
| var n =>
{ toFSM := FSM.var n,
good := by ext; simp [Term.evalFin] }
good := by ext; simp [Term.evalFin]; simp [Fin.last] }
| zero =>
{ toFSM := FSM.zero,
good := by ext; simp [Term.evalFin] }
Expand Down Expand Up @@ -1268,7 +1272,7 @@ def fsmUlt (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l)) :=
composeUnaryAux (FSM.ls true) <| (composeUnaryAux FSM.not <| composeBinaryAux FSM.borrow a b)

@[simp]
theorem eval_fsmUlt_eq_evalFin_Predicate_ult (t₁ t₂ : Term) :
theorem eval_fsmUlt_eq_evalFin_Predicate_ult (t₁ t₂ : Term .bv) :
(fsmUlt (termEvalEqFSM t₁).toFSM (termEvalEqFSM t₂).toFSM).eval = (Predicate.binary .ult t₁ t₂).evalFin := by
ext x i
generalize ha : termEvalEqFSM t₁ = a
Expand All @@ -1280,24 +1284,24 @@ def fsmEq (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l)) :=

/-- Evaluation FSM.eq is the same as evaluating Predicate.eq.evalFin. -/
@[simp]
theorem eval_fsmEq_eq_evalFin_Predicate_eq (t₁ t₂ : Term) :
theorem eval_fsmEq_eq_evalFin_Predicate_eq (t₁ t₂ : Term .bv) :
(fsmEq (termEvalEqFSM t₁).toFSM (termEvalEqFSM t₂).toFSM).eval = (Predicate.binary .eq t₁ t₂).evalFin := by
ext x i
generalize ha : termEvalEqFSM t₁ = a
generalize hb : termEvalEqFSM t₂ = b
simp [fsmEq, Predicate.evalEq, a.good, b.good]
simp [fsmEq, Predicate.evalBVEq, a.good, b.good]

def fsmNeq (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l)) :=
composeUnaryAux FSM.scanAnd <| composeUnaryAux (FSM.ls true) <| composeBinaryAux FSM.nxor a b

/-- Evaluation FSM.eq is the same as evaluating Predicate.eq.evalFin. -/
@[simp]
theorem eval_fsmNeq_eq_evalFin_Predicate_neq (t₁ t₂ : Term) :
theorem eval_fsmNeq_eq_evalFin_Predicate_neq (t₁ t₂ : Term .bv) :
(fsmNeq (termEvalEqFSM t₁).toFSM (termEvalEqFSM t₂).toFSM).eval = (Predicate.binary .neq t₁ t₂).evalFin := by
ext x i
generalize ha : termEvalEqFSM t₁ = a
generalize hb : termEvalEqFSM t₂ = b
simp [fsmNeq, Predicate.evalNeq, a.good, b.good]
simp [fsmNeq, Predicate.evalBVNeq, a.good, b.good]

def fsmLand (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l)) :=
composeBinaryAux FSM.or a b
Expand All @@ -1314,7 +1318,7 @@ def fsmUle (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l ⊔ (k ⊔ l)
def fsmMsbEq (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l)) :=
composeUnaryAux (FSM.ls false) <| composeBinaryAux FSM.xor a b

-- theorem fsmMsbNeq_eq_Predicate_MsbNeq (t₁ t₂ : Term) :
-- theorem fsmMsbNeq_eq_Predicate_MsbNeq (t₁ t₂ : Term .bv) :
-- (Predicate.msbNeq t₁ t₂).evalFin = (fsmMsbNeq (termEvalEqFSM t₁).toFSM (termEvalEqFSM t₂).toFSM).eval := sorry

def fsmSlt (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l ⊔ (k ⊔ l))) :=
Expand All @@ -1333,6 +1337,35 @@ def fsmSle (a : FSM (Fin k)) (b : FSM (Fin l)) : FSM (Fin (k ⊔ l ⊔ (k ⊔ l)
let out := fsmLor slt eq
out

def btermEvalEqFSM : ∀ (t : BTerm), FSMBTermSolution t
| .var n =>
{ toFSM := FSM.var n,
good := by ext; simp [BTerm.evalFin] }
| .msb t => {
toFSM := (termEvalEqFSM t).toFSM,
good := by
simp [BTerm.evalFin]
rw [← FSMTermSolution.good]
}
| .tru => {
toFSM := FSM.negOne
good := by
ext i
simp [BTerm.evalFin]
}
| .fals => {
toFSM := FSM.zero
good := by
ext i
simp [BTerm.evalFin]
}
| .xor a b => {
toFSM := composeBinaryAux FSM.xor (btermEvalEqFSM a).toFSM (btermEvalEqFSM b).toFSM
good := by
ext i
simp [BTerm.evalFin]
rw [(btermEvalEqFSM a).good, (btermEvalEqFSM b).good]
}

/--
Evaluating the eq predicate equals the FSM value.
Expand Down Expand Up @@ -1435,7 +1468,7 @@ def predicateEvalEqFSM : ∀ (p : Predicate), FSMPredicateSolution p
Predicate.evalLor, fsmLor,
Predicate.evalSlt, fsmSlt, Predicate.evalUlt,
fsmUlt, a.good, b.good, Predicate.evalMsbEq, fsmMsbEq,
Predicate.evalEq, fsmEq, a.good, b.good]
Predicate.evalBVEq, fsmEq, a.good, b.good]
}
| .binary .ult t₁ t₂ =>
let a := termEvalEqFSM t₁
Expand All @@ -1455,8 +1488,26 @@ def predicateEvalEqFSM : ∀ (p : Predicate), FSMPredicateSolution p
toFSM := fsmUle a.toFSM b.toFSM
good := by
ext x i
simp [fsmUle, fsmUlt, fsmEq, fsmLor, a.good, b.good, Predicate.evalLor, Predicate.evalUlt, Predicate.evalEq]
simp [fsmUle, fsmUlt, fsmEq, fsmLor, a.good, b.good, Predicate.evalLor, Predicate.evalUlt, Predicate.evalBVEq]
}
| .boolBinary .eq a b =>
let a := btermEvalEqFSM a
let b := btermEvalEqFSM b
{
toFSM := composeUnaryAux FSM.scanOr (composeBinaryAux FSM.xor a.toFSM b.toFSM)
good := by
ext x i
simp [Predicate.evalBitstreamEq, a.good, b.good]
}
| .boolBinary .neq a b =>
let a := btermEvalEqFSM a
let b := btermEvalEqFSM b
{
toFSM := composeUnaryAux FSM.scanAnd <| composeBinaryAux FSM.nxor a.toFSM b.toFSM
good := by
ext x i
simp [Predicate.evalBitstreamNeq, a.good, b.good]
}

/-- info: 'predicateEvalEqFSM' depends on axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms predicateEvalEqFSM
Expand Down
Loading
Loading