Skip to content

Commit 9090f24

Browse files
committed
chore: popcount circuit
1 parent 9846bde commit 9090f24

File tree

2 files changed

+786
-0
lines changed

2 files changed

+786
-0
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Luisa Cicolini, Siddharth Bhat, Henrik Böving
5+
-/
6+
7+
prelude
8+
public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Const
9+
public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Sub
10+
public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Eq
11+
public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Extract
12+
public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.ZeroExtend
13+
public import Std.Sat.AIG.If
14+
15+
/-!
16+
This module contains the implementation of a bitblaster for `BitVec.popCount`.
17+
-/
18+
19+
namespace Std.Tactic.BVDecide
20+
21+
open Std.Sat
22+
23+
variable [Hashable α] [DecidableEq α]
24+
25+
namespace BVExpr
26+
namespace bitblast
27+
28+
/-- We extract a single bit in position `start` and extend it to haev width `w`-/
29+
def blastExtractAndExtend (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat) : AIG.RefVecEntry α w :=
30+
-- extract 1 bit starting from start
31+
let targetExtract : ExtractTarget aig 1 := {vec := x, start := start}
32+
let res := blastExtract aig targetExtract
33+
let aig := res.aig
34+
let extract := res.vec
35+
-- zero-extend the extracted portion to have
36+
let targetExtend : AIG.ExtendTarget aig w := {vec := extract, w := 1}
37+
let res := blastZeroExtend aig targetExtend
38+
let aig := res.aig
39+
let extend := res.vec
40+
⟨aig, extend⟩
41+
42+
theorem extractAndExtend_le_size (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat) :
43+
aig.decls.size ≤ (blastExtractAndExtend aig x start).aig.decls.size := by
44+
unfold blastExtractAndExtend
45+
dsimp only
46+
apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := blastZeroExtend)
47+
apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := blastExtract)
48+
omega
49+
50+
theorem extractAndExtend_decl_eq (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat):
51+
∀ (idx : Nat) (h1) (h2),
52+
(blastExtractAndExtend aig x start).aig.decls[idx]'h2 = aig.decls[idx]'h1 := by
53+
generalize hres : blastExtractAndExtend aig x start = res
54+
unfold blastExtractAndExtend at hres
55+
dsimp only at hres
56+
rw [← hres]
57+
intros
58+
rw [AIG.LawfulVecOperator.decl_eq (f := blastZeroExtend)]
59+
rw [AIG.LawfulVecOperator.decl_eq (f := blastExtract)]
60+
(expose_names; exact h1)
61+
62+
/-- We extract one bit at a time from the initial vector and zero-extend them to width `w`,
63+
appending the result to `acc` which eventually will have size `w * w`-/
64+
def blastExtractAndExtendPopulate (aig : AIG α) (idx : Nat) (x : AIG.RefVec aig w)
65+
(acc : AIG.RefVec aig (w * idx)) (hlt : idx ≤ w)
66+
: AIG.RefVecEntry α (w * w) :=
67+
if hidx : idx < w then
68+
let res := blastExtractAndExtend aig x idx
69+
let aigRes := res.aig
70+
let bv := res.vec
71+
have := extractAndExtend_le_size aig x idx
72+
let acc := acc.cast (aig2 := aigRes) this
73+
let x := x.cast (aig2 := aigRes) this
74+
let acc := acc.append bv
75+
have hcast : w * (idx + 1) = w * idx + w := by simp [Nat.mul_add]
76+
have acc := hcast▸acc
77+
blastExtractAndExtendPopulate (aigRes) (idx + 1) (x := x) (acc := acc) (by omega)
78+
else
79+
have : idx = w := by omega
80+
have hcast : w * idx = w * w := by rw [this]
81+
⟨aig, hcast▸acc⟩
82+
83+
theorem extractAndExtendPopulate_le_size (aig : AIG α) (idx : Nat) (x : AIG.RefVec aig w) (acc : AIG.RefVec aig (w * idx)) (hlt : idx ≤ w):
84+
aig.decls.size ≤ (blastExtractAndExtendPopulate aig idx x acc hlt).aig.decls.size := by
85+
unfold blastExtractAndExtendPopulate
86+
dsimp only
87+
split
88+
· apply Nat.le_trans ?_ (by apply extractAndExtendPopulate_le_size)
89+
apply extractAndExtend_le_size
90+
· simp
91+
92+
theorem extractAndExtendPopulate_decl_eq (aig : AIG α) (idx' : Nat) (x : AIG.RefVec aig w) (acc : AIG.RefVec aig (w * idx')) (hlt : idx' ≤ w):
93+
∀ (idx : Nat) (h1) (h2),
94+
(blastExtractAndExtendPopulate aig idx' x acc hlt).aig.decls[idx]'h2 = aig.decls[idx]'h1 := by
95+
generalize hres : blastExtractAndExtendPopulate aig idx' x acc hlt = res
96+
unfold blastExtractAndExtendPopulate at hres
97+
dsimp only at hres
98+
split at hres
99+
· rw [← hres]
100+
intros
101+
rw [extractAndExtendPopulate_decl_eq, extractAndExtend_decl_eq]
102+
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastZeroExtend)
103+
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastExtract)
104+
omega
105+
· simp [← hres]
106+
107+
/-- Given a vector of references belonging to the same AIG `oldParSum`, we create a node to add the `curr`-th couple of elements and push the add node to `newParSum` -/
108+
def blastAddVec (aig : AIG α) (usedNodes validNodes : Nat)
109+
(oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w))
110+
(hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) :
111+
AIG.RefVecEntry α (((validNodes+1)/2) * w) :=
112+
if hc1 : usedNodes < validNodes then
113+
-- rhs
114+
let rhs := if h : usedNodes + 1 < validNodes then
115+
let targetExtract : ExtractTarget aig w := {vec := oldParSum, start := (usedNodes + 1) * w}
116+
let res := blastExtract aig targetExtract
117+
let aig := res.aig
118+
have := AIG.LawfulVecOperator.le_size (f := blastExtract) (input := targetExtract)
119+
let oldParSum := oldParSum.cast this
120+
let newParSum := newParSum.cast this
121+
res.vec
122+
else blastConst aig (w := w) 0
123+
-- lhs
124+
let targetExtract : ExtractTarget aig w := {vec := oldParSum, start := usedNodes * w}
125+
let res := blastExtract aig targetExtract
126+
let aig := res.aig
127+
let lhs := res.vec
128+
have := AIG.LawfulVecOperator.le_size (f := blastExtract) ..
129+
let oldParSum := oldParSum.cast this
130+
let newParSum := newParSum.cast this
131+
let rhs := rhs.cast this
132+
-- add
133+
let res := blastAdd aig ⟨lhs, rhs⟩
134+
let aig := res.aig
135+
let add := res.vec
136+
have := AIG.LawfulVecOperator.le_size (f := blastAdd) ..
137+
let oldParSum := oldParSum.cast this
138+
let newParSum := newParSum.cast this
139+
let rhs := rhs.cast this
140+
let lhs := lhs.cast this
141+
let newVec := newParSum.append add
142+
have hcast : usedNodes / 2 * w + w = (usedNodes + 2) / 2 * w := by
143+
simp [show usedNodes / 2 * w + w = usedNodes / 2 * w + 1 * w by omega,
144+
show (usedNodes + 2) / 2 = usedNodes/2 + 1 by omega, Nat.add_mul]
145+
blastAddVec aig (usedNodes + 2) validNodes oldParSum (hcast▸newVec) hval (by omega) (by omega)
146+
else
147+
have hor : usedNodes = validNodes ∨ usedNodes = validNodes + 1 := by omega
148+
have hcast : usedNodes / 2 * w = (validNodes + 1) / 2 * w := by
149+
simp_all
150+
rcases hor with hor|hor
151+
· simp [hor] at *
152+
rw [show (validNodes + 1) / 2 = validNodes / 2 by omega]
153+
· simp [hor] at *
154+
let newParSum := hcast▸newParSum
155+
⟨aig, newParSum⟩
156+
157+
theorem addVec_le_size (aig : AIG α) (usedNodes validNodes: Nat)
158+
(oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w))
159+
(hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) :
160+
aig.decls.size ≤ (blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod).aig.decls.size := by
161+
unfold blastAddVec
162+
dsimp only
163+
split
164+
· simp
165+
<;> (refine Nat.le_trans ?_ (by apply addVec_le_size); apply AIG.LawfulVecOperator.le_size)
166+
· simp
167+
168+
theorem addVec_decl_eq (aig : AIG α) (usedNodes validNodes: Nat)
169+
(oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w))
170+
(hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) :
171+
∀ (idx : Nat) (h1) (h2),
172+
(blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod).aig.decls[idx]'h1 = aig.decls[idx]'h2 := by
173+
generalize hres : blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod = res
174+
unfold blastAddVec at hres
175+
dsimp only at hres
176+
split at hres
177+
· simp at hres
178+
· rw [← hres]
179+
intros
180+
rw [addVec_decl_eq]
181+
· apply AIG.LawfulVecOperator.decl_eq (f := blastAdd)
182+
· apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastAdd)
183+
assumption
184+
· simp [← hres]
185+
186+
187+
/-- We first extend all the single bits in the input BitVec w to have width `w`, then compute
188+
the parallel prefix sum given these bits.-/
189+
def blastPopCount (aig : AIG α) (x : AIG.RefVec aig w) : AIG.RefVecEntry α w :=
190+
if hw : 1 < w then
191+
-- init
192+
let initAcc := blastConst (aig := aig) (w := 0) (val := 0)
193+
let res := blastExtractAndExtendPopulate aig 0 x initAcc (by omega)
194+
let aig := res.aig
195+
let extendedBits := res.vec
196+
have := extractAndExtendPopulate_le_size
197+
let x := x.cast (aig2 := res.aig) (by apply this)
198+
go aig w extendedBits (by omega) (by omega) (by omega)
199+
else
200+
if hw' : 0 < w then
201+
⟨aig, x⟩
202+
else
203+
let zero := blastConst aig (w := w) 0
204+
⟨aig, zero⟩
205+
where
206+
go (aig : AIG α) (validNodes : Nat) (parSum : AIG.RefVec aig (validNodes * w))
207+
(hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) : AIG.RefVecEntry α w :=
208+
if hlt : 1 < validNodes then
209+
have hcastZero : 0 = 0 / 2 * w := by omega
210+
let initAcc := blastConst (aig := aig) (w := 0) (val := 0)
211+
let res := blastAddVec aig 0 validNodes parSum (hcastZero▸initAcc) hval (by omega) (by omega)
212+
let aig := res.aig
213+
let parSum := res.vec
214+
go (aig := aig) (validNodes := (validNodes + 1)/2) (parSum := parSum) (hin := hin) (hval := by omega) (by omega)
215+
else
216+
have hcast : validNodes * w = w := by
217+
simp [show validNodes = 1 by omega]
218+
⟨aig, hcast▸parSum⟩
219+
220+
221+
theorem blastPopCount.go_le_size (aig : AIG α) (validNodes : Nat) (parSum : AIG.RefVec aig (validNodes * w))
222+
(hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) :
223+
aig.decls.size ≤ (go aig validNodes parSum hin hval hval').aig.decls.size := by
224+
unfold go
225+
dsimp only
226+
split
227+
· refine Nat.le_trans ?_ (by apply go_le_size)
228+
apply addVec_le_size
229+
· simp
230+
231+
232+
theorem blastPopCount.go_le_size' (aig : AIG α) (input : aig.RefVec w) (h : 1 < w) :
233+
let initAcc := blastConst (aig := aig) (w := 0) (val := 0)
234+
aig.decls.size ≤
235+
(go (blastExtractAndExtendPopulate aig 0 input initAcc (by omega)).aig w
236+
(blastExtractAndExtendPopulate aig 0 input initAcc (by omega)).vec
237+
h (by omega) (by omega)).aig.decls.size:= by
238+
unfold go
239+
dsimp only
240+
split
241+
· refine Nat.le_trans ?_ (by apply go_le_size)
242+
refine Nat.le_trans ?_ (by apply addVec_le_size)
243+
apply extractAndExtendPopulate_le_size
244+
· omega
245+
246+
theorem blastPopCount.go_decl_eq {w : Nat} (validNodes : Nat) (aig : AIG α) (parSum : AIG.RefVec aig (validNodes * w))
247+
(hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) : ∀ (idx : Nat) h1 h2,
248+
(go aig validNodes parSum hin hval hval').aig.decls[idx]'h1 =
249+
aig.decls[idx]'h2 := by
250+
generalize hgo : go aig validNodes parSum hin hval hval' = res
251+
unfold go at hgo
252+
dsimp only at hgo
253+
split at hgo
254+
· rw [← hgo]
255+
intros idx hidx hidx'
256+
rw [go_decl_eq]
257+
· apply addVec_decl_eq
258+
· let initAcc := blastConst (aig := aig) (w := 0) (val := 0)
259+
have hcast : 0 = 0/2*w := by omega
260+
have := addVec_le_size aig 0 validNodes parSum (hcast▸initAcc) (by omega) (by omega) (by omega)
261+
exact Nat.lt_of_lt_of_le hidx' this
262+
· simp [← hgo]
263+
264+
instance : AIG.LawfulVecOperator α AIG.RefVec blastPopCount where
265+
le_size := by
266+
intros
267+
unfold blastPopCount
268+
split
269+
· apply blastPopCount.go_le_size'
270+
· split <;> simp
271+
decl_eq := by sorry
272+
-- intros
273+
-- unfold blastPopCount
274+
-- dsimp only
275+
-- expose_names
276+
-- split
277+
-- · let initAcc := blastConst (aig := aig) (w := 0) (val := 0)
278+
-- have := extractAndExtendPopulate_le_size (idx := 0) aig input initAcc (by omega)
279+
-- rw [blastPopCount.go_decl_eq]
280+
-- apply extractAndExtendPopulate_decl_eq (idx' := 0) aig input
281+
-- exact Nat.lt_of_lt_of_le h1 this
282+
-- · split <;> simp
283+
284+
end bitblast
285+
end BVExpr
286+
287+
end Std.Tactic.BVDecide

0 commit comments

Comments
 (0)