Skip to content

Commit 5c9f78b

Browse files
committed
chore: case zero ind
1 parent 000dd04 commit 5c9f78b

File tree

1 file changed

+17
-91
lines changed

1 file changed

+17
-91
lines changed

src/Init/Data/BitVec/Bitblast.lean

Lines changed: 17 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,7 +3221,7 @@ theorem recursive_addition_concat_of_lt_two {a : BitVec (a_length * w)} (h : 2
32213221
a.extractLsb' ((a_length - 1) * w) w + a.extractLsb' ((a_length - 1 - 1) * w) w + (a.extractLsb' 0 ((a_length - 1 - 1) * w)).recursive_addition (l := a_length - 2) (w := w) (a_length - 1 - 1)
32223222
:= by sorry
32233223

3224-
theorem extractLsb'_extractLsb'_eq_of_lt (a : BitVec w) (hlt : i + k < len) :
3224+
theorem extractLsb'_extractLsb'_eq_of_lt (a : BitVec w) (hlt : i + k len) :
32253225
extractLsb' i k (extractLsb' 0 len a) =
32263226
extractLsb' i k a := by
32273227
ext j hj
@@ -3235,98 +3235,24 @@ theorem rec_add_eq_rec_add_iff
32353235
(b : BitVec (b_length * w))
32363236
(hadd : ∀ (i : Nat) (hi : i < (b_length + 1) / 2) (hi' : 2 * i < b_length),
32373237
extractLsb' (i * w) w a =
3238-
extractLsb' (2 * i * w) w b + if h : 2 * i + 1 < b_length then extractLsb' ((2 * i + 1) * w) w b else 0) :
3239-
recursive_addition a ((b_length + 1) / 2) = recursive_addition b b_length := by
3240-
induction a_length generalizing b_length b
3241-
· have hblen : b_length = 0 := by omega
3242-
simp [hblen, recursive_addition]
3243-
· case _ a_length' iha =>
3244-
have ha_eq : a = ((a.extractLsb' ((a_length' + 1 - 1) * w) w) ++ (a.extractLsb' 0 ((a_length' + 1 - 1) * w))).cast (by simp [Nat.add_mul]; omega) := by
3245-
ext j hj
3246-
simp only [getElem_cast, getElem_append, getElem_extractLsb']
3247-
split
3248-
· simp [← getLsbD_eq_getElem]
3249-
· rw [show (a_length' + 1 - 1) * w + (j - (a_length' + 1 - 1) * w) = j by omega, ← getLsbD_eq_getElem]
3250-
rw [ha_eq]
3251-
simp [← halen]
3252-
rw [recursive_addition_concat (a_length := a_length')]
3253-
rw [hadd (i := a_length') (by omega) (by omega)]
3254-
split
3255-
· case _ htrue =>
3256-
have hblenle: 2 ≤ b_length := by omega
3257-
-- b[2 * (a.length - 1)] + b[2 * (a.length - 1) + 1]
3258-
have heq : 2 * a_length' + 1 = b_length - 1 := by omega
3259-
have heq' : 2 * a_length' = b_length - 1 - 1 := by omega
3260-
have h_cast : w + (w + (b_length - 1 - 1) * w) = b_length * w := by
3261-
simp [show b_length = (a_length'+ 1) * 2 by omega]
3262-
simp [Nat.add_mul]
3263-
rw [Nat.mul_comm a_length' 2]
3264-
omega
3265-
have hb_eq : b = ((b.extractLsb' ((b_length - 1) * w) w) ++
3266-
((b.extractLsb' ((b_length - 1 - 1) * w) w) ++
3267-
(b.extractLsb' 0 ((b_length - 1 - 1) * w)))).cast h_cast := by
3268-
ext j hj
3269-
simp [getElem_append]
3270-
have : w + (b_length - 1 - 1) * w = (b_length - 1) * w := by
3271-
rw [show w + (b_length - 1 - 1) * w = 1 * w + (b_length - 1 - 1) * w by omega]
3272-
rw [← Nat.add_mul]
3273-
rw [show 1 + (b_length - 1 - 1) = b_length - 1 by omega]
3274-
rw [this]
3275-
split
3276-
· split
3277-
· rw [← getLsbD_eq_getElem]
3278-
· rw [show (b_length - 1 - 1) * w + (j - (b_length - 1 - 1) * w) = j by omega,
3279-
← getLsbD_eq_getElem]
3280-
· rw [show (b_length - 1) * w + (j - (b_length - 1) * w) = j by omega, ← getLsbD_eq_getElem]
3281-
have := recursive_addition_concat_of_lt_two (a_length := b_length) (a := b) (by omega)
3282-
conv =>
3283-
rhs
3284-
rw [hb_eq]
3285-
rw [this]
3286-
rw [heq, heq']
3287-
have : extractLsb' ((b_length - 1 - 1) * w) w b + extractLsb' ((b_length - 1) * w) w b =
3288-
extractLsb' ((b_length - 1) * w) w b + extractLsb' ((b_length - 1 - 1) * w) w b := by
3289-
exact
3290-
BitVec.add_comm (extractLsb' ((b_length - 1 - 1) * w) w b)
3291-
(extractLsb' ((b_length - 1) * w) w b)
3292-
rw [this]
3293-
specialize iha (b_length := b_length - 1 - 1)
3294-
(extractLsb' 0 (a_length' * w) a)
3295-
(by omega)
3296-
(extractLsb' 0 ((b_length - 1 - 1) * w) b)
3297-
rw [← iha]
3298-
· congr
3299-
omega
3300-
· intros i hi hi'
3301-
have hlt : i < a_length' := by omega
3302-
3303-
have heq'' : extractLsb' (i * w) w (extractLsb' 0 (a_length' * w) a) =
3304-
extractLsb' (i * w) w a := by
3305-
rw [extractLsb'_extractLsb'_eq_of_lt]
3306-
simp at *
3307-
have h1 : i ≤ a_length' - 1 := by omega
3308-
have h2 := Nat.mul_le_mul_right (n := i) (m := a_length' - 1) (k := w) (by omega)
3309-
have h3 : i * w ≤ a_length' * w - w := by
3310-
simp [Nat.sub_mul] at h2
3311-
omega
3312-
have : 0 < w := by sorry
3313-
have : 0 < a_length' := by sorry
3314-
have h4 : i * w + w ≤ a_length' * w := by
3315-
refine add_le_of_le_sub (by
3316-
refine Nat.le_mul_of_pos_left w ?_
3317-
omega
3318-
) h3
3319-
have : i * w + w ≤ a_length' * w := by omega
3320-
omega
3321-
3322-
sorry
3323-
sorry
3324-
· sorry
3238+
extractLsb' (2 * i * w) w b + if h : 2 * i + 1 < b_length then extractLsb' ((2 * i + 1) * w) w b else 0)
3239+
(hw : 0 < w)
3240+
(hlen : 0 < a_length)
3241+
(n : Nat)
3242+
(hn : n = a_length)
3243+
:
3244+
recursive_addition a a_length = recursive_addition b b_length := by
3245+
induction n generalizing a b
3246+
· have hb : b_length = 0 := by omega
3247+
have ha : a_length = 0 := by omega
3248+
simp [ha, hb, recursive_addition]
3249+
·
3250+
sorry
33253251

33263252
/-- construct the parallel prefix sum circuit of the flattend bitvectors in `l` -/
33273253
def pps (l : BitVec (l_length * w)) (k: BitVec w)
33283254
(proof : recursive_addition l l_length = k)
3329-
(proof_length : 0 < l_length) :
3255+
(proof_length : 0 < l_length) (hw : 0 < w) :
33303256
{ls : BitVec (1 * w) // recursive_addition ls 1 = k} :=
33313257
if h : l_length = 1 then
33323258
⟨l.cast (by simp [h]), by
@@ -3339,8 +3265,8 @@ def pps (l : BitVec (l_length * w)) (k: BitVec w)
33393265
let l_length' := (l_length + 1) / 2
33403266
let proof_sum_eq : recursive_addition new_layer ((l_length + 1) / 2) = k := by
33413267
rw [← proof]
3342-
apply rec_add_eq_rec_add_iff (a := new_layer) (by omega) (b := l) (by omega)
3268+
apply rec_add_eq_rec_add_iff (a := new_layer) (by omega) (b := l) (by omega) (by omega) (by omega) (n := (l_length + 1) / 2) (by omega)
33433269
let proof_new_layer_length : 0 < l_length' := by omega
3344-
pps new_layer k proof_sum_eq proof_new_layer_length
3270+
pps new_layer k proof_sum_eq proof_new_layer_length hw
33453271

33463272
end BitVec

0 commit comments

Comments
 (0)