|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Luisa Cicolini, Siddharth Bhat, Henrik Böving |
| 5 | +-/ |
| 6 | + |
| 7 | +prelude |
| 8 | +public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Const |
| 9 | +public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Sub |
| 10 | +public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Eq |
| 11 | +public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.Extract |
| 12 | +public import Std.Tactic.BVDecide.Bitblast.BVExpr.Circuit.Impl.Operations.ZeroExtend |
| 13 | +public import Std.Sat.AIG.If |
| 14 | + |
| 15 | +/-! |
| 16 | +This module contains the implementation of a bitblaster for `BitVec.popCount`. |
| 17 | +-/ |
| 18 | + |
| 19 | +namespace Std.Tactic.BVDecide |
| 20 | + |
| 21 | +open Std.Sat |
| 22 | + |
| 23 | +variable [Hashable α] [DecidableEq α] |
| 24 | + |
| 25 | +namespace BVExpr |
| 26 | +namespace bitblast |
| 27 | + |
| 28 | +/-- We extract a single bit in position `start` and extend it to haev width `w`-/ |
| 29 | +def blastExtractAndExtend (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat) : AIG.RefVecEntry α w := |
| 30 | + -- extract 1 bit starting from start |
| 31 | + let targetExtract : ExtractTarget aig 1 := {vec := x, start := start} |
| 32 | + let res := blastExtract aig targetExtract |
| 33 | + let aig := res.aig |
| 34 | + let extract := res.vec |
| 35 | + -- zero-extend the extracted portion to have |
| 36 | + let targetExtend : AIG.ExtendTarget aig w := {vec := extract, w := 1} |
| 37 | + let res := blastZeroExtend aig targetExtend |
| 38 | + let aig := res.aig |
| 39 | + let extend := res.vec |
| 40 | + ⟨aig, extend⟩ |
| 41 | + |
| 42 | +theorem extractAndExtend_le_size (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat) : |
| 43 | + aig.decls.size ≤ (blastExtractAndExtend aig x start).aig.decls.size := by |
| 44 | + unfold blastExtractAndExtend |
| 45 | + dsimp only |
| 46 | + apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := blastZeroExtend) |
| 47 | + apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := blastExtract) |
| 48 | + omega |
| 49 | + |
| 50 | +theorem extractAndExtend_decl_eq (aig : AIG α) (x : AIG.RefVec aig w) (start : Nat): |
| 51 | + ∀ (idx : Nat) (h1) (h2), |
| 52 | + (blastExtractAndExtend aig x start).aig.decls[idx]'h2 = aig.decls[idx]'h1 := by |
| 53 | + generalize hres : blastExtractAndExtend aig x start = res |
| 54 | + unfold blastExtractAndExtend at hres |
| 55 | + dsimp only at hres |
| 56 | + rw [← hres] |
| 57 | + intros |
| 58 | + rw [AIG.LawfulVecOperator.decl_eq (f := blastZeroExtend)] |
| 59 | + rw [AIG.LawfulVecOperator.decl_eq (f := blastExtract)] |
| 60 | + (expose_names; exact h1) |
| 61 | + |
| 62 | +/-- We extract one bit at a time from the initial vector and zero-extend them to width `w`, |
| 63 | + appending the result to `acc` which eventually will have size `w * w`-/ |
| 64 | +def blastExtractAndExtendPopulate (aig : AIG α) (idx : Nat) (x : AIG.RefVec aig w) |
| 65 | + (acc : AIG.RefVec aig (w * idx)) (hlt : idx ≤ w) |
| 66 | + : AIG.RefVecEntry α (w * w) := |
| 67 | + if hidx : idx < w then |
| 68 | + let res := blastExtractAndExtend aig x idx |
| 69 | + let aigRes := res.aig |
| 70 | + let bv := res.vec |
| 71 | + have := extractAndExtend_le_size aig x idx |
| 72 | + let acc := acc.cast (aig2 := aigRes) this |
| 73 | + let x := x.cast (aig2 := aigRes) this |
| 74 | + let acc := acc.append bv |
| 75 | + have hcast : w * (idx + 1) = w * idx + w := by simp [Nat.mul_add] |
| 76 | + have acc := hcast▸acc |
| 77 | + blastExtractAndExtendPopulate (aigRes) (idx + 1) (x := x) (acc := acc) (by omega) |
| 78 | + else |
| 79 | + have : idx = w := by omega |
| 80 | + have hcast : w * idx = w * w := by rw [this] |
| 81 | + ⟨aig, hcast▸acc⟩ |
| 82 | + |
| 83 | +theorem extractAndExtendPopulate_le_size (aig : AIG α) (idx : Nat) (x : AIG.RefVec aig w) (acc : AIG.RefVec aig (w * idx)) (hlt : idx ≤ w): |
| 84 | + aig.decls.size ≤ (blastExtractAndExtendPopulate aig idx x acc hlt).aig.decls.size := by |
| 85 | + unfold blastExtractAndExtendPopulate |
| 86 | + dsimp only |
| 87 | + split |
| 88 | + · apply Nat.le_trans ?_ (by apply extractAndExtendPopulate_le_size) |
| 89 | + apply extractAndExtend_le_size |
| 90 | + · simp |
| 91 | + |
| 92 | +theorem extractAndExtendPopulate_decl_eq (aig : AIG α) (idx' : Nat) (x : AIG.RefVec aig w) (acc : AIG.RefVec aig (w * idx')) (hlt : idx' ≤ w): |
| 93 | + ∀ (idx : Nat) (h1) (h2), |
| 94 | + (blastExtractAndExtendPopulate aig idx' x acc hlt).aig.decls[idx]'h2 = aig.decls[idx]'h1 := by |
| 95 | + generalize hres : blastExtractAndExtendPopulate aig idx' x acc hlt = res |
| 96 | + unfold blastExtractAndExtendPopulate at hres |
| 97 | + dsimp only at hres |
| 98 | + split at hres |
| 99 | + · rw [← hres] |
| 100 | + intros |
| 101 | + rw [extractAndExtendPopulate_decl_eq, extractAndExtend_decl_eq] |
| 102 | + apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastZeroExtend) |
| 103 | + apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastExtract) |
| 104 | + omega |
| 105 | + · simp [← hres] |
| 106 | + |
| 107 | +/-- Given a vector of references belonging to the same AIG `oldParSum`, we create a node to add the `curr`-th couple of elements and push the add node to `newParSum` -/ |
| 108 | +def blastAddVec (aig : AIG α) (usedNodes validNodes : Nat) |
| 109 | + (oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w)) |
| 110 | + (hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) : |
| 111 | + AIG.RefVecEntry α (((validNodes+1)/2) * w) := |
| 112 | + if hc1 : usedNodes < validNodes then |
| 113 | + -- rhs |
| 114 | + let rhs := if h : usedNodes + 1 < validNodes then |
| 115 | + let targetExtract : ExtractTarget aig w := {vec := oldParSum, start := (usedNodes + 1) * w} |
| 116 | + let res := blastExtract aig targetExtract |
| 117 | + let aig := res.aig |
| 118 | + have := AIG.LawfulVecOperator.le_size (f := blastExtract) (input := targetExtract) |
| 119 | + let oldParSum := oldParSum.cast this |
| 120 | + let newParSum := newParSum.cast this |
| 121 | + res.vec |
| 122 | + else blastConst aig (w := w) 0 |
| 123 | + -- lhs |
| 124 | + let targetExtract : ExtractTarget aig w := {vec := oldParSum, start := usedNodes * w} |
| 125 | + let res := blastExtract aig targetExtract |
| 126 | + let aig := res.aig |
| 127 | + let lhs := res.vec |
| 128 | + have := AIG.LawfulVecOperator.le_size (f := blastExtract) .. |
| 129 | + let oldParSum := oldParSum.cast this |
| 130 | + let newParSum := newParSum.cast this |
| 131 | + let rhs := rhs.cast this |
| 132 | + -- add |
| 133 | + let res := blastAdd aig ⟨lhs, rhs⟩ |
| 134 | + let aig := res.aig |
| 135 | + let add := res.vec |
| 136 | + have := AIG.LawfulVecOperator.le_size (f := blastAdd) .. |
| 137 | + let oldParSum := oldParSum.cast this |
| 138 | + let newParSum := newParSum.cast this |
| 139 | + let rhs := rhs.cast this |
| 140 | + let lhs := lhs.cast this |
| 141 | + let newVec := newParSum.append add |
| 142 | + have hcast : usedNodes / 2 * w + w = (usedNodes + 2) / 2 * w := by |
| 143 | + simp [show usedNodes / 2 * w + w = usedNodes / 2 * w + 1 * w by omega, |
| 144 | + show (usedNodes + 2) / 2 = usedNodes/2 + 1 by omega, Nat.add_mul] |
| 145 | + blastAddVec aig (usedNodes + 2) validNodes oldParSum (hcast▸newVec) hval (by omega) (by omega) |
| 146 | + else |
| 147 | + have hor : usedNodes = validNodes ∨ usedNodes = validNodes + 1 := by omega |
| 148 | + have hcast : usedNodes / 2 * w = (validNodes + 1) / 2 * w := by |
| 149 | + simp_all |
| 150 | + rcases hor with hor|hor |
| 151 | + · simp [hor] at * |
| 152 | + rw [show (validNodes + 1) / 2 = validNodes / 2 by omega] |
| 153 | + · simp [hor] at * |
| 154 | + let newParSum := hcast▸newParSum |
| 155 | + ⟨aig, newParSum⟩ |
| 156 | + |
| 157 | +theorem addVec_le_size (aig : AIG α) (usedNodes validNodes: Nat) |
| 158 | + (oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w)) |
| 159 | + (hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) : |
| 160 | + aig.decls.size ≤ (blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod).aig.decls.size := by |
| 161 | + unfold blastAddVec |
| 162 | + dsimp only |
| 163 | + split |
| 164 | + · simp |
| 165 | + <;> (refine Nat.le_trans ?_ (by apply addVec_le_size); apply AIG.LawfulVecOperator.le_size) |
| 166 | + · simp |
| 167 | + |
| 168 | +theorem addVec_decl_eq (aig : AIG α) (usedNodes validNodes: Nat) |
| 169 | + (oldParSum : AIG.RefVec aig (validNodes * w)) (newParSum : AIG.RefVec aig ((usedNodes / 2) * w)) |
| 170 | + (hval : validNodes ≤ w) (hused : usedNodes ≤ validNodes + 1) (hmod : usedNodes % 2 = 0) : |
| 171 | + ∀ (idx : Nat) (h1) (h2), |
| 172 | + (blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod).aig.decls[idx]'h1 = aig.decls[idx]'h2 := by |
| 173 | + generalize hres : blastAddVec aig usedNodes validNodes oldParSum newParSum hval hused hmod = res |
| 174 | + unfold blastAddVec at hres |
| 175 | + dsimp only at hres |
| 176 | + split at hres |
| 177 | + · simp at hres |
| 178 | + · rw [← hres] |
| 179 | + intros |
| 180 | + rw [addVec_decl_eq] |
| 181 | + · apply AIG.LawfulVecOperator.decl_eq (f := blastAdd) |
| 182 | + · apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastAdd) |
| 183 | + assumption |
| 184 | + · simp [← hres] |
| 185 | + |
| 186 | + |
| 187 | +/-- We first extend all the single bits in the input BitVec w to have width `w`, then compute |
| 188 | +the parallel prefix sum given these bits.-/ |
| 189 | +def blastPopCount (aig : AIG α) (x : AIG.RefVec aig w) : AIG.RefVecEntry α w := |
| 190 | + if hw : 1 < w then |
| 191 | + -- init |
| 192 | + let initAcc := blastConst (aig := aig) (w := 0) (val := 0) |
| 193 | + let res := blastExtractAndExtendPopulate aig 0 x initAcc (by omega) |
| 194 | + let aig := res.aig |
| 195 | + let extendedBits := res.vec |
| 196 | + have := extractAndExtendPopulate_le_size |
| 197 | + let x := x.cast (aig2 := res.aig) (by apply this) |
| 198 | + go aig w extendedBits (by omega) (by omega) (by omega) |
| 199 | + else |
| 200 | + if hw' : 0 < w then |
| 201 | + ⟨aig, x⟩ |
| 202 | + else |
| 203 | + let zero := blastConst aig (w := w) 0 |
| 204 | + ⟨aig, zero⟩ |
| 205 | +where |
| 206 | + go (aig : AIG α) (validNodes : Nat) (parSum : AIG.RefVec aig (validNodes * w)) |
| 207 | + (hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) : AIG.RefVecEntry α w := |
| 208 | + if hlt : 1 < validNodes then |
| 209 | + have hcastZero : 0 = 0 / 2 * w := by omega |
| 210 | + let initAcc := blastConst (aig := aig) (w := 0) (val := 0) |
| 211 | + let res := blastAddVec aig 0 validNodes parSum (hcastZero▸initAcc) hval (by omega) (by omega) |
| 212 | + let aig := res.aig |
| 213 | + let parSum := res.vec |
| 214 | + go (aig := aig) (validNodes := (validNodes + 1)/2) (parSum := parSum) (hin := hin) (hval := by omega) (by omega) |
| 215 | + else |
| 216 | + have hcast : validNodes * w = w := by |
| 217 | + simp [show validNodes = 1 by omega] |
| 218 | + ⟨aig, hcast▸parSum⟩ |
| 219 | + |
| 220 | + |
| 221 | +theorem blastPopCount.go_le_size (aig : AIG α) (validNodes : Nat) (parSum : AIG.RefVec aig (validNodes * w)) |
| 222 | + (hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) : |
| 223 | + aig.decls.size ≤ (go aig validNodes parSum hin hval hval').aig.decls.size := by |
| 224 | + unfold go |
| 225 | + dsimp only |
| 226 | + split |
| 227 | + · refine Nat.le_trans ?_ (by apply go_le_size) |
| 228 | + apply addVec_le_size |
| 229 | + · simp |
| 230 | + |
| 231 | + |
| 232 | +theorem blastPopCount.go_le_size' (aig : AIG α) (input : aig.RefVec w) (h : 1 < w) : |
| 233 | + let initAcc := blastConst (aig := aig) (w := 0) (val := 0) |
| 234 | + aig.decls.size ≤ |
| 235 | + (go (blastExtractAndExtendPopulate aig 0 input initAcc (by omega)).aig w |
| 236 | + (blastExtractAndExtendPopulate aig 0 input initAcc (by omega)).vec |
| 237 | + h (by omega) (by omega)).aig.decls.size:= by |
| 238 | + unfold go |
| 239 | + dsimp only |
| 240 | + split |
| 241 | + · refine Nat.le_trans ?_ (by apply go_le_size) |
| 242 | + refine Nat.le_trans ?_ (by apply addVec_le_size) |
| 243 | + apply extractAndExtendPopulate_le_size |
| 244 | + · omega |
| 245 | + |
| 246 | +theorem blastPopCount.go_decl_eq {w : Nat} (validNodes : Nat) (aig : AIG α) (parSum : AIG.RefVec aig (validNodes * w)) |
| 247 | + (hin : 1 < w) (hval : validNodes ≤ w) (hval' : 0 < validNodes) : ∀ (idx : Nat) h1 h2, |
| 248 | + (go aig validNodes parSum hin hval hval').aig.decls[idx]'h1 = |
| 249 | + aig.decls[idx]'h2 := by |
| 250 | + generalize hgo : go aig validNodes parSum hin hval hval' = res |
| 251 | + unfold go at hgo |
| 252 | + dsimp only at hgo |
| 253 | + split at hgo |
| 254 | + · rw [← hgo] |
| 255 | + intros idx hidx hidx' |
| 256 | + rw [go_decl_eq] |
| 257 | + · apply addVec_decl_eq |
| 258 | + · let initAcc := blastConst (aig := aig) (w := 0) (val := 0) |
| 259 | + have hcast : 0 = 0/2*w := by omega |
| 260 | + have := addVec_le_size aig 0 validNodes parSum (hcast▸initAcc) (by omega) (by omega) (by omega) |
| 261 | + exact Nat.lt_of_lt_of_le hidx' this |
| 262 | + · simp [← hgo] |
| 263 | + |
| 264 | +instance : AIG.LawfulVecOperator α AIG.RefVec blastPopCount where |
| 265 | + le_size := by |
| 266 | + intros |
| 267 | + unfold blastPopCount |
| 268 | + split |
| 269 | + · apply blastPopCount.go_le_size' |
| 270 | + · split <;> simp |
| 271 | + decl_eq := by sorry |
| 272 | + -- intros |
| 273 | + -- unfold blastPopCount |
| 274 | + -- dsimp only |
| 275 | + -- expose_names |
| 276 | + -- split |
| 277 | + -- · let initAcc := blastConst (aig := aig) (w := 0) (val := 0) |
| 278 | + -- have := extractAndExtendPopulate_le_size (idx := 0) aig input initAcc (by omega) |
| 279 | + -- rw [blastPopCount.go_decl_eq] |
| 280 | + -- apply extractAndExtendPopulate_decl_eq (idx' := 0) aig input |
| 281 | + -- exact Nat.lt_of_lt_of_le h1 this |
| 282 | + -- · split <;> simp |
| 283 | + |
| 284 | +end bitblast |
| 285 | +end BVExpr |
| 286 | + |
| 287 | +end Std.Tactic.BVDecide |
0 commit comments