Skip to content

Commit 87a6604

Browse files
committed
chore: tests
1 parent 8cb00a0 commit 87a6604

File tree

3 files changed

+581
-0
lines changed

3 files changed

+581
-0
lines changed

generate_popcount_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
3+
output_file = open('test_popcount_correctness.lean', 'w')
4+
5+
output_file.write('import Std.Tactic.BVDecide\n')
6+
7+
width = 9
8+
9+
for n in range(0, pow(2, width)):
10+
popcount_golden = n.bit_count()
11+
output_file.write("example {x : BitVec "+str(width)+"} (h : x = "+str(n)+"#"+str(width)+") : x.popCount = "+str(popcount_golden)+" := by bv_decide\n")
12+
13+
output_file.close()

test_popcount_bv_correctness.lean

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import Init.System.IO
2+
import Init.Data.BitVec
3+
4+
5+
open BitVec
6+
7+
set_option diagnostics true
8+
9+
10+
def test : IO Unit := do
11+
let w := 5
12+
for xx in [0 : 2^w] do
13+
let x := BitVec.ofNat w xx
14+
let bbpop : BitVec w := x.popCountParSum
15+
let bvpop : BitVec w := x.popCount
16+
-- IO.print f!"\nNaive popCount returned {pop}, bitblaster circuit returned{bbpop}, bvPop returned {bvpop}"
17+
if bbpop.toNat ≠ bvpop.toNat then IO.print f!"\nFAIL"
18+
19+
IO.println ""
20+
21+
#eval! test
22+
23+
24+
-- x.popCountAuxRec (setWidth (w + 1) (extractLsb' 0 1 x)) 1
25+
26+
-- (setWidth (w) x).popCountAuxRec (setWidth (w) (extractLsb' 0 1 (setWidth (w) x))) 1 + (x.extractLsb' w 1).setWidth (w + 1)
27+
28+
def scatter (xs : BitVec (n * w)) : List (BitVec w) :=
29+
List.map (fun i => xs.extractLsb' (i * w) w) (List.range n)
30+
31+
def sumVecs (xs : List (BitVec w)) : BitVec w :=
32+
xs.foldl (fun acc x => acc + x) 0#w
33+
34+
/-- zero extend each of the bits `x[i]`, and produce a packed bitvector. -/
35+
def extractAndExtendPopulate (x : BitVec w) : BitVec (w * w) :=
36+
let res := BitVec.extractAndExtendPopulateAux 0 x 0#0 (by omega) (by intros; omega)
37+
res
38+
39+
-- setWidth (w + 1 + 1) (sumVecs (setWidth (w + 1) x).extractAndExtendPopulate.scatter) +
40+
-- setWidth (w + 1 + 1) (extractLsb' (w + 1) 1 x) =
41+
-- sumVecs x.extractAndExtendPopulate.scatter
42+
43+
def test1 : IO Unit := do
44+
let w := 5
45+
let wlt := 4
46+
for xx in [0 : 2^w] do
47+
let x := BitVec.ofNat w xx
48+
let lhs := setWidth w (sumVecs (scatter (extractAndExtendPopulate (setWidth wlt x)))) +
49+
setWidth w (extractLsb' wlt 1 x)
50+
let rhs := sumVecs (scatter (extractAndExtendPopulate x))
51+
if lhs ≠ rhs then IO.print f!"\nFAIL with x = {x}, where pop1 = {lhs.toNat} and pop2 = {rhs.toNat}"
52+
53+
IO.println ""
54+
55+
#eval! test1

0 commit comments

Comments
 (0)