Skip to content

Commit ef0e9c6

Browse files
committed
chore: wip
1 parent 9df3765 commit ef0e9c6

File tree

1 file changed

+107
-20
lines changed

1 file changed

+107
-20
lines changed

src/Init/Data/BitVec/Bitblast.lean

Lines changed: 107 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,15 +3024,7 @@ theorem extractAndExtendPopulate_sumPackedVec_eq_add (x : BitVec (w + 1)):
30243024
-- need lemma to push extractandextend populate inside the appned afer whichb sumpackedvec simp lemmas should just rewrite, we finish the goal by usign ih
30253025
-- the popcount equals the result of summing the packed vector.
30263026
theorem popCount_eq_sumPackedVec (x : BitVec w) :
3027-
x.popCount = sumPackedVec (extractAndExtendPopulate w x) := by
3028-
induction w
3029-
· case zero => simp
3030-
· case succ w' ihw' =>
3031-
rw [popCount_append]
3032-
simp
3033-
rw [ihw']
3034-
apply extractAndExtendPopulate_sumPackedVec_eq_add
3035-
3027+
x.popCount = sumPackedVec (extractAndExtendPopulate w x) := by sorry
30363028
/-!
30373029
we should keep the proof strategy intact and not delete written proofs, we should fix the defs correctly
30383030
so thatproofs that are not about EEP gothrough while proofs about EEP we can sorry and this is bc the new def
@@ -3140,7 +3132,7 @@ theorem popCount_eq_popCountParSum {x : BitVec w} :
31403132
def recursive_addition (x : BitVec (l * w)) (remaining_elements : Nat) : BitVec w :=
31413133
match remaining_elements with
31423134
| 0 => 0
3143-
| n + 1 => x.extractLsb' n w + x.recursive_addition n
3135+
| n + 1 => x.extractLsb' (n * w) w + x.recursive_addition n
31443136

31453137
/-- given a flattened list of bitvectors `old_layer`, produce a `new_layer` adding
31463138
the elements of `old_layer` two-by-two. -/
@@ -3201,12 +3193,6 @@ def pps_layer {w : Nat} (iter_num : Nat) (old_layer : BitVec (old_length * w))
32013193
proof_new_layer_elements_eq_old_layer_add
32023194
termination_by old_length - (iter_num * 2)
32033195

3204-
theorem recursive_addition_concat {a : BitVec ((a_length + 1) * w)} :
3205-
((a.extractLsb' (a_length * w) w ++ a.extractLsb' 0 (a_length * w)).cast (by simp [Nat.add_mul]; omega)).recursive_addition
3206-
(l := a_length + 1) (w := w) (a_length + 1)
3207-
=
3208-
a.extractLsb' (a_length * w) w + (a.extractLsb' 0 (a_length * w)).recursive_addition (l := a_length) (w := w) a_length
3209-
:= by sorry
32103196

32113197
theorem recursive_addition_concat_of_lt_two {a : BitVec (a_length * w)} (h : 2 ≤ a_length):
32123198
((a.extractLsb' ((a_length - 1) * w) w ++ (a.extractLsb' ((a_length - 1 - 1) * w) w ++ a.extractLsb' 0 ((a_length - 1 - 1) * w))).cast
@@ -3240,6 +3226,89 @@ theorem extractLsb'_append_extractLsb'_eq_of_lt (a : BitVec (a_length * w)) (ha
32403226
· simp [show (a_length - 1) * w + (i - (a_length - 1) * w) = i by omega]
32413227
rw [← getLsbD_eq_getElem]
32423228

3229+
theorem recursive_addition_succ (x : BitVec (l * w)) (n : Nat) :
3230+
x.recursive_addition (n + 1) = x.extractLsb' (n * w) w + x.recursive_addition n := by rfl
3231+
3232+
theorem recursive_addition_eq_of_le (a : BitVec (length * w)) (h : r ≤ length):
3233+
a.recursive_addition r =
3234+
(extractLsb' 0 (r * w) a).recursive_addition r := by sorry
3235+
3236+
theorem recursive_addition_eq_of_le' (a : BitVec (length * w)) (h : r ≤ length) (hk : r ≤ k ):
3237+
a.recursive_addition r =
3238+
(extractLsb' 0 (k * w) a).recursive_addition r := by sorry
3239+
3240+
3241+
theorem cast_recursive_addition_eq_of_le' {a_length' : Nat} {newEl : BitVec w}:
3242+
let hcast : w + (a_length'.succ + 1 - 1) * w = (a_length'.succ + 1) * w := by
3243+
simp [Nat.add_mul]
3244+
omega
3245+
(BitVec.cast hcast (newEl ++ extractLsb' 0 (a_length'.succ * w) a)).recursive_addition a_length' =
3246+
(extractLsb' 0 (a_length' * w) a).recursive_addition a_length' := by
3247+
induction a_length'
3248+
· simp [recursive_addition]
3249+
· case _ a_length' iha =>
3250+
simp [recursive_addition_succ]
3251+
3252+
sorry
3253+
3254+
3255+
3256+
theorem recursive_addition_concat {a : BitVec (a_length * w)} (ha : 0 < a_length) :
3257+
let hc : w + (a_length - 1) * w = a_length * w := by
3258+
simp [Nat.sub_mul]; rw [← Nat.add_sub_assoc (by exact Nat.le_mul_of_pos_left w ha)]; omega
3259+
((a.extractLsb' ((a_length - 1) * w) w ++ a.extractLsb' 0 ((a_length - 1) * w)).cast hc).recursive_addition
3260+
(l := a_length) (w := w) (a_length)
3261+
=
3262+
a.extractLsb' ((a_length - 1) * w) w + (a.extractLsb' 0 ((a_length - 1) * w)).recursive_addition
3263+
(l := (a_length - 1)) (w := w) (a_length - 1)
3264+
:= by
3265+
let newEl := a.extractLsb' ((a_length - 1) * w) w
3266+
rw [show a.extractLsb' ((a_length - 1) * w) w = newEl by rfl]
3267+
induction a_length
3268+
· omega
3269+
· case _ a_length' iha' =>
3270+
simp
3271+
conv =>
3272+
rhs
3273+
unfold recursive_addition
3274+
split
3275+
· ext k hk
3276+
simp [newEl, getLsbD_append, recursive_addition]
3277+
omega
3278+
· case _ a_length' =>
3279+
simp only [recursive_addition_succ]
3280+
have hc1 : w + (a_length'.succ + 1 - 1) * w = (a_length'.succ + 1) * w := by simp [Nat.add_mul, Nat.add_assoc]; omega
3281+
have hadd1 :
3282+
extractLsb' ((a_length' + 1) * w) w (BitVec.cast hc1 (newEl ++ extractLsb' 0 ((a_length' + 1) * w) a)) =
3283+
newEl := by
3284+
ext k hk
3285+
simp [getLsbD_append]
3286+
have : ¬ ((a_length' + 1) * w + k < (a_length' + 1) * w) := by omega
3287+
simp [this, ← getLsbD_eq_getElem]
3288+
rw [hadd1]
3289+
have hc2 : w + (a_length'.succ + 1 - 1) * w = (a_length'.succ + 1) * w := by simp [Nat.add_mul, Nat.add_assoc]; omega
3290+
have hadd2 :
3291+
extractLsb' (a_length' * w) w (BitVec.cast hc2 (newEl ++ extractLsb' 0 ((a_length' + 1) * w) a)) =
3292+
extractLsb' (a_length' * w) w a := by
3293+
ext k hk
3294+
simp [getLsbD_append, show a_length' * w + k < (a_length' + 1) * w by simp [Nat.add_mul]; omega]
3295+
rw [hadd2]
3296+
have hadd3 : extractLsb' (a_length' * w) w (extractLsb' 0 ((a_length' + 1) * w) a) =
3297+
extractLsb' (a_length' * w) w a := by
3298+
ext k hk
3299+
simp [Nat.add_mul]
3300+
intros
3301+
omega
3302+
simp at iha'
3303+
specialize iha' (a := extractLsb' 0 ((a_length' + 1) * w) a)
3304+
rw [hadd3]
3305+
congr 2
3306+
rw [cast_recursive_addition_eq_of_le' ]
3307+
rw [← recursive_addition_eq_of_le]
3308+
rw [← recursive_addition_eq_of_le']
3309+
<;> omega
3310+
omega
3311+
32433312
theorem rec_add_eq_rec_add_iff
32443313
(a : BitVec (a_length * w))
32453314
(halen : a_length = (b_length + 1) / 2)
@@ -3259,10 +3328,28 @@ theorem rec_add_eq_rec_add_iff
32593328
simp [ha, hb, recursive_addition]
32603329
· case _ n' ihn =>
32613330
rw [extractLsb'_append_extractLsb'_eq_of_lt (a := a) (by omega)]
3262-
have : a_length - 1 + 1 = a_length := by omega
3263-
rw [this] at *
3264-
rw [recursive_addition_concat (a_length := a_length - 1) (w := w) (a := a)]
3265-
sorry
3331+
rw [recursive_addition_concat (by omega)]
3332+
have hadd1 := hadd (i := a_length - 1) (by omega) (by omega)
3333+
rw [hadd1]
3334+
split
3335+
· case _ ht =>
3336+
let op1 := extractLsb' (2 * (a_length - 1) * w) w b
3337+
let op2 := extractLsb' ((2 * (a_length - 1) + 1) * w) w b
3338+
let taila := extractLsb' 0 ((a_length - 1) * w) a
3339+
conv =>
3340+
rhs
3341+
rw [extractLsb'_append_extractLsb'_eq_of_lt (a := b) (by omega)]
3342+
rw [extractLsb'_append_extractLsb'_eq_of_lt (a := b) (by omega)]
3343+
rw [show extractLsb' (2 * (a_length - 1) * w) w b = op1 by rfl]
3344+
rw [show extractLsb' ((2 * (a_length - 1) + 1) * w) w b = op2 by rfl]
3345+
rw [show extractLsb' 0 ((a_length - 1) * w) a = taila by rfl]
3346+
3347+
have hext1 : (extractLsb' ((b_length - 1) * w) w b ++ extractLsb' 0 ((b_length - 1) * w) b) =
3348+
(extractLsb' ((b_length - 1) * w) w b ++ extractLsb' 0 ((b_length - 1) * w) b) := by
3349+
sorry
3350+
sorry
3351+
· case _ hf =>
3352+
sorry
32663353

32673354
/-- construct the parallel prefix sum circuit of the flattend bitvectors in `l` -/
32683355
def pps (l : BitVec (l_length * w)) (k: BitVec w)

0 commit comments

Comments
 (0)