@@ -3221,7 +3221,7 @@ theorem recursive_addition_concat_of_lt_two {a : BitVec (a_length * w)} (h : 2
32213221 a.extractLsb' ((a_length - 1 ) * w) w + a.extractLsb' ((a_length - 1 - 1 ) * w) w + (a.extractLsb' 0 ((a_length - 1 - 1 ) * w)).recursive_addition (l := a_length - 2 ) (w := w) (a_length - 1 - 1 )
32223222 := by sorry
32233223
3224- theorem extractLsb'_extractLsb'_eq_of_lt (a : BitVec w) (hlt : i + k < len) :
3224+ theorem extractLsb'_extractLsb'_eq_of_lt (a : BitVec w) (hlt : i + k ≤ len) :
32253225 extractLsb' i k (extractLsb' 0 len a) =
32263226 extractLsb' i k a := by
32273227 ext j hj
@@ -3235,98 +3235,24 @@ theorem rec_add_eq_rec_add_iff
32353235 (b : BitVec (b_length * w))
32363236 (hadd : ∀ (i : Nat) (hi : i < (b_length + 1 ) / 2 ) (hi' : 2 * i < b_length),
32373237 extractLsb' (i * w) w a =
3238- extractLsb' (2 * i * w) w b + if h : 2 * i + 1 < b_length then extractLsb' ((2 * i + 1 ) * w) w b else 0 ) :
3239- recursive_addition a ((b_length + 1 ) / 2 ) = recursive_addition b b_length := by
3240- induction a_length generalizing b_length b
3241- · have hblen : b_length = 0 := by omega
3242- simp [hblen, recursive_addition]
3243- · case _ a_length' iha =>
3244- have ha_eq : a = ((a.extractLsb' ((a_length' + 1 - 1 ) * w) w) ++ (a.extractLsb' 0 ((a_length' + 1 - 1 ) * w))).cast (by simp [Nat.add_mul]; omega) := by
3245- ext j hj
3246- simp only [getElem_cast, getElem_append, getElem_extractLsb']
3247- split
3248- · simp [← getLsbD_eq_getElem]
3249- · rw [show (a_length' + 1 - 1 ) * w + (j - (a_length' + 1 - 1 ) * w) = j by omega, ← getLsbD_eq_getElem]
3250- rw [ha_eq]
3251- simp [← halen]
3252- rw [recursive_addition_concat (a_length := a_length')]
3253- rw [hadd (i := a_length') (by omega) (by omega)]
3254- split
3255- · case _ htrue =>
3256- have hblenle: 2 ≤ b_length := by omega
3257- -- b[2 * (a.length - 1)] + b[2 * (a.length - 1) + 1]
3258- have heq : 2 * a_length' + 1 = b_length - 1 := by omega
3259- have heq' : 2 * a_length' = b_length - 1 - 1 := by omega
3260- have h_cast : w + (w + (b_length - 1 - 1 ) * w) = b_length * w := by
3261- simp [show b_length = (a_length'+ 1 ) * 2 by omega]
3262- simp [Nat.add_mul]
3263- rw [Nat.mul_comm a_length' 2 ]
3264- omega
3265- have hb_eq : b = ((b.extractLsb' ((b_length - 1 ) * w) w) ++
3266- ((b.extractLsb' ((b_length - 1 - 1 ) * w) w) ++
3267- (b.extractLsb' 0 ((b_length - 1 - 1 ) * w)))).cast h_cast := by
3268- ext j hj
3269- simp [getElem_append]
3270- have : w + (b_length - 1 - 1 ) * w = (b_length - 1 ) * w := by
3271- rw [show w + (b_length - 1 - 1 ) * w = 1 * w + (b_length - 1 - 1 ) * w by omega]
3272- rw [← Nat.add_mul]
3273- rw [show 1 + (b_length - 1 - 1 ) = b_length - 1 by omega]
3274- rw [this ]
3275- split
3276- · split
3277- · rw [← getLsbD_eq_getElem]
3278- · rw [show (b_length - 1 - 1 ) * w + (j - (b_length - 1 - 1 ) * w) = j by omega,
3279- ← getLsbD_eq_getElem]
3280- · rw [show (b_length - 1 ) * w + (j - (b_length - 1 ) * w) = j by omega, ← getLsbD_eq_getElem]
3281- have := recursive_addition_concat_of_lt_two (a_length := b_length) (a := b) (by omega)
3282- conv =>
3283- rhs
3284- rw [hb_eq]
3285- rw [this ]
3286- rw [heq, heq']
3287- have : extractLsb' ((b_length - 1 - 1 ) * w) w b + extractLsb' ((b_length - 1 ) * w) w b =
3288- extractLsb' ((b_length - 1 ) * w) w b + extractLsb' ((b_length - 1 - 1 ) * w) w b := by
3289- exact
3290- BitVec.add_comm (extractLsb' ((b_length - 1 - 1 ) * w) w b)
3291- (extractLsb' ((b_length - 1 ) * w) w b)
3292- rw [this ]
3293- specialize iha (b_length := b_length - 1 - 1 )
3294- (extractLsb' 0 (a_length' * w) a)
3295- (by omega)
3296- (extractLsb' 0 ((b_length - 1 - 1 ) * w) b)
3297- rw [← iha]
3298- · congr
3299- omega
3300- · intros i hi hi'
3301- have hlt : i < a_length' := by omega
3302-
3303- have heq'' : extractLsb' (i * w) w (extractLsb' 0 (a_length' * w) a) =
3304- extractLsb' (i * w) w a := by
3305- rw [extractLsb'_extractLsb'_eq_of_lt]
3306- simp at *
3307- have h1 : i ≤ a_length' - 1 := by omega
3308- have h2 := Nat.mul_le_mul_right (n := i) (m := a_length' - 1 ) (k := w) (by omega)
3309- have h3 : i * w ≤ a_length' * w - w := by
3310- simp [Nat.sub_mul] at h2
3311- omega
3312- have : 0 < w := by sorry
3313- have : 0 < a_length' := by sorry
3314- have h4 : i * w + w ≤ a_length' * w := by
3315- refine add_le_of_le_sub (by
3316- refine Nat.le_mul_of_pos_left w ?_
3317- omega
3318- ) h3
3319- have : i * w + w ≤ a_length' * w := by omega
3320- omega
3321-
3322- sorry
3323- sorry
3324- · sorry
3238+ extractLsb' (2 * i * w) w b + if h : 2 * i + 1 < b_length then extractLsb' ((2 * i + 1 ) * w) w b else 0 )
3239+ (hw : 0 < w)
3240+ (hlen : 0 < a_length)
3241+ (n : Nat)
3242+ (hn : n = a_length)
3243+ :
3244+ recursive_addition a a_length = recursive_addition b b_length := by
3245+ induction n generalizing a b
3246+ · have hb : b_length = 0 := by omega
3247+ have ha : a_length = 0 := by omega
3248+ simp [ha, hb, recursive_addition]
3249+ ·
3250+ sorry
33253251
33263252/-- construct the parallel prefix sum circuit of the flattend bitvectors in `l` -/
33273253def pps (l : BitVec (l_length * w)) (k: BitVec w)
33283254 (proof : recursive_addition l l_length = k)
3329- (proof_length : 0 < l_length) :
3255+ (proof_length : 0 < l_length) (hw : 0 < w) :
33303256 {ls : BitVec (1 * w) // recursive_addition ls 1 = k} :=
33313257 if h : l_length = 1 then
33323258 ⟨l.cast (by simp [h]), by
@@ -3339,8 +3265,8 @@ def pps (l : BitVec (l_length * w)) (k: BitVec w)
33393265 let l_length' := (l_length + 1 ) / 2
33403266 let proof_sum_eq : recursive_addition new_layer ((l_length + 1 ) / 2 ) = k := by
33413267 rw [← proof]
3342- apply rec_add_eq_rec_add_iff (a := new_layer) (by omega) (b := l) (by omega)
3268+ apply rec_add_eq_rec_add_iff (a := new_layer) (by omega) (b := l) (by omega) ( by omega) ( by omega) (n := (l_length + 1 ) / 2 ) ( by omega)
33433269 let proof_new_layer_length : 0 < l_length' := by omega
3344- pps new_layer k proof_sum_eq proof_new_layer_length
3270+ pps new_layer k proof_sum_eq proof_new_layer_length hw
33453271
33463272end BitVec
0 commit comments