|
| 1 | +import Init.System.IO |
| 2 | +import Init.Data.BitVec |
| 3 | + |
| 4 | + |
| 5 | +open BitVec |
| 6 | + |
| 7 | + set_option diagnostics true |
| 8 | + |
| 9 | + |
| 10 | +def test : IO Unit := do |
| 11 | + let w := 5 |
| 12 | + for xx in [0 : 2^w] do |
| 13 | + let x := BitVec.ofNat w xx |
| 14 | + let bbpop : BitVec w := x.popCountParSum |
| 15 | + let bvpop : BitVec w := x.popCount |
| 16 | + -- IO.print f!"\nNaive popCount returned {pop}, bitblaster circuit returned{bbpop}, bvPop returned {bvpop}" |
| 17 | + if bbpop.toNat ≠ bvpop.toNat then IO.print f!"\nFAIL" |
| 18 | + |
| 19 | + IO.println "" |
| 20 | + |
| 21 | +#eval! test |
| 22 | + |
| 23 | + |
| 24 | +-- x.popCountAuxRec (setWidth (w + 1) (extractLsb' 0 1 x)) 1 |
| 25 | + |
| 26 | +-- (setWidth (w) x).popCountAuxRec (setWidth (w) (extractLsb' 0 1 (setWidth (w) x))) 1 + (x.extractLsb' w 1).setWidth (w + 1) |
| 27 | + |
| 28 | +def scatter (xs : BitVec (n * w)) : List (BitVec w) := |
| 29 | + List.map (fun i => xs.extractLsb' (i * w) w) (List.range n) |
| 30 | + |
| 31 | +def sumVecs (xs : List (BitVec w)) : BitVec w := |
| 32 | + xs.foldl (fun acc x => acc + x) 0#w |
| 33 | + |
| 34 | +/-- zero extend each of the bits `x[i]`, and produce a packed bitvector. -/ |
| 35 | +def extractAndExtendPopulate (x : BitVec w) : BitVec (w * w) := |
| 36 | + let res := BitVec.extractAndExtendPopulateAux 0 x 0#0 (by omega) (by intros; omega) |
| 37 | + res |
| 38 | + |
| 39 | +-- setWidth (w + 1 + 1) (sumVecs (setWidth (w + 1) x).extractAndExtendPopulate.scatter) + |
| 40 | +-- setWidth (w + 1 + 1) (extractLsb' (w + 1) 1 x) = |
| 41 | +-- sumVecs x.extractAndExtendPopulate.scatter |
| 42 | + |
| 43 | +def test1 : IO Unit := do |
| 44 | + let w := 5 |
| 45 | + let wlt := 4 |
| 46 | + for xx in [0 : 2^w] do |
| 47 | + let x := BitVec.ofNat w xx |
| 48 | + let lhs := setWidth w (sumVecs (scatter (extractAndExtendPopulate (setWidth wlt x)))) + |
| 49 | + setWidth w (extractLsb' wlt 1 x) |
| 50 | + let rhs := sumVecs (scatter (extractAndExtendPopulate x)) |
| 51 | + if lhs ≠ rhs then IO.print f!"\nFAIL with x = {x}, where pop1 = {lhs.toNat} and pop2 = {rhs.toNat}" |
| 52 | + |
| 53 | + IO.println "" |
| 54 | + |
| 55 | +#eval! test1 |
0 commit comments