@@ -13,8 +13,10 @@ public import Init.Data.BitVec.Basic
1313public import Init.Data.BitVec.Folds
1414public import Init.Data.BitVec.Lemmas
1515public import Init.Data.ByteArray.Basic
16+ public import Init.Data.ByteArray.Lemmas
1617public import Init.Data.Function
1718public import Init.Data.List.Basic
19+ public import Init.Data.UInt.Lemmas
1820public import Init.Data.Vector.Basic
1921public import Init.Data.Vector.Lemmas
2022
@@ -53,6 +55,7 @@ theorem ofFnLE_nil (f : Fin 0 → Bool) : ofFnLE f = nil := by
5355 contradiction
5456
5557
58+ @[simp]
5659theorem getElem_ofFnLE (f : Fin w → Bool) (i : Nat) (h : i < w) :
5760 (ofFnLE f)[i] = f ⟨i, h⟩ := by
5861 unfold ofFnLE
@@ -86,6 +89,7 @@ theorem ofFnBE_nil (f : Fin 0 → Bool) : ofFnBE f = nil := by
8689 ext
8790 contradiction
8891
92+ @[simp]
8993theorem getElem_ofFnBE (f : Fin w → Bool) (i : Nat) (h : i < w) :
9094 (ofFnBE f)[i] = f ⟨w - 1 - i, by omega⟩ := by
9195 simp [ofFnBE, getElem_ofFnLE]
@@ -527,14 +531,6 @@ theorem toList_ofFn (f : Fin w → Bool) : (ofFnLE f).toListLE = List.ofFn f :=
527531theorem toArray_ofFn (f : Fin w → Bool) : (ofFnLE f).toArrayLE = Array.ofFn f := by
528532 ext <;> simp [getElem_ofFnLE]
529533
530- /--
531- Building a bitvector from its own indexing function is the identity.
532- -/
533- @[simp]
534- theorem ofFnLE_getElem (x : BitVec w) : ofFnLE (fun i => x[i]) = x := by
535- ext i
536- simp [getElem_ofFnLE]
537-
538534/--
539535Convert a bitvector to a vector of bools (LSB first).
540536
@@ -660,7 +656,6 @@ theorem toVectorLE_nil : toVectorLE nil = #v[] := by
660656theorem toVector_concat (x : BitVec w) (b : Bool) :
661657 toVectorLE (concat x b) = Vector.mk (#[b] ++ x.toArrayLE) (by simp +arith) := by
662658 apply Vector.ext
663- intro i
664659 simp [toVectorLE, Vector.getElem_mk]
665660
666661theorem toList_toVectorLE (x : BitVec w) : x.toVectorLE.toList = x.toListLE := by
@@ -682,7 +677,7 @@ Examples:
682677def toBytesLE (x : BitVec w) : ByteArray :=
683678 let numBytes := (w + 7 ) / 8
684679 ByteArray.mk <| Array.ofFn fun (i : Fin numBytes) =>
685- ((x >>> (i.val * 8 )).toNat &&& 0xFF ).toUInt8
680+ UInt8.ofBitVec (ofFnLE fun (j : Fin 8 ) => x.getLsbD (i.val * 8 + j.val))
686681
687682/--
688683Convert a bitvector to a byte array (big-endian).
@@ -697,10 +692,8 @@ Examples:
697692 -/
698693def toBytesBE (x : BitVec w) : ByteArray :=
699694 let numBytes := (w + 7 ) / 8
700- let xn := x.toNat
701695 ByteArray.mk <| Array.ofFn fun (i : Fin numBytes) =>
702- let shift := (numBytes - 1 - i) * 8
703- UInt8.ofNat ((xn >>> shift) &&& 0xFF )
696+ UInt8.ofBitVec (BitVec.ofFnLE fun (j : Fin 8 ) => x.getLsbD ((numBytes - 1 - i.val) * 8 + j.val))
704697
705698@[local simp]
706699theorem size_toBytesLE (x : BitVec w) : x.toBytesLE.size = (w + 7 ) / 8 := by
@@ -721,9 +714,7 @@ Examples:
721714* `ofBytesLE (ByteArray.mk #[0xFF, 0x0F]) = 0xFFF#16`
722715 -/
723716def ofBytesLE (bytes : ByteArray) : BitVec (bytes.size * 8 ) :=
724- let w := bytes.size * 8
725- (List.range bytes.size).foldl (init := (0 : BitVec w)) fun acc i =>
726- acc ||| (BitVec.ofNat w bytes[i]!.toNat <<< (i * 8 ))
717+ ofFnLE fun i => bytes[i.val / 8 ].toBitVec[i.val % 8 ]
727718
728719/--
729720Build a bitvector from a byte array (big-endian).
@@ -736,8 +727,127 @@ Examples:
736727* `ofBytesBE (ByteArray.mk #[0x0F, 0xFF]) = 0xFFF#16`
737728 -/
738729def ofBytesBE (bytes : ByteArray) : BitVec (bytes.size * 8 ) :=
739- let w := bytes.size * 8
740- (List.range bytes.size).foldl (init := (0 : BitVec w)) fun acc i =>
741- acc ||| (BitVec.ofNat w bytes[i]!.toNat <<< (w - i * 8 - 8 ))
742-
730+ ofFnLE fun i => bytes[bytes.size - 1 - i.val / 8 ].toBitVec[i.val % 8 ]
731+
732+ theorem getElem_toBytesLE (x : BitVec w) (i : Nat) (h : i < (w + 7 ) / 8 ) :
733+ x.toBytesLE[i]'(size_toBytesLE _ ▸ h) =
734+ UInt8.ofBitVec (ofFnLE fun (j : Fin 8 ) => x.getLsbD (i * 8 + j)) := by
735+ simp [toBytesLE, ByteArray.getElem_eq_getElem_data]
736+
737+ theorem getElem_toBytesBE (x : BitVec w) (i : Nat) (h : i < (w + 7 ) / 8 ) :
738+ x.toBytesBE[i]'(size_toBytesBE _ ▸ h) =
739+ UInt8.ofBitVec (ofFnLE fun (j : Fin 8 ) => x.getLsbD (((w + 7 ) / 8 - 1 - i) * 8 + j)) := by
740+ simp [toBytesBE, ByteArray.getElem_eq_getElem_data]
741+
742+ /-! ### Bit access lemmas -/
743+
744+ theorem getElem_ofBytesLE (bytes : ByteArray) (j : Nat) (h : j < bytes.size * 8 ) :
745+ (ofBytesLE bytes)[j] = bytes[j / 8 ].toBitVec[j % 8 ] := by
746+ simp [ofBytesLE]
747+
748+ theorem getElem_ofBytesBE (bytes : ByteArray) (j : Nat) (h : j < bytes.size * 8 ) :
749+ (ofBytesBE bytes)[j] = bytes[bytes.size - 1 - j / 8 ].toBitVec[j % 8 ] := by
750+ simp [ofBytesBE]
751+
752+ /-! ### Round-trip theorems -/
753+
754+ @[simp]
755+ theorem toBytesLE_ofBytesLE (bytes : ByteArray) :
756+ (ofBytesLE bytes).toBytesLE = bytes := by
757+ ext1
758+ apply Array.ext
759+ · simp only [ByteArray.size_data, size_toBytesLE]
760+ omega
761+ · intro i hi hi'
762+ rw [
763+ ←ByteArray.getElem_eq_getElem_data,
764+ getElem_toBytesLE (h := by simp_all),
765+ UInt8.eq_iff_toBitVec_eq]
766+ ext j
767+ simp only [getElem_ofFnLE]
768+ have : i * 8 + j < bytes.size * 8 := by
769+ simp_all only [ByteArray.size_data, size_toBytesLE]
770+ omega
771+ simp [getLsbD_eq_getElem this, getElem_ofBytesLE]
772+ congr <;> omega
773+
774+
775+ @[simp]
776+ theorem toBytesBE_ofBytesBE (bytes : ByteArray) :
777+ (ofBytesBE bytes).toBytesBE = bytes := by
778+ ext1
779+ apply Array.ext
780+ · simp [size_toBytesBE]
781+ omega
782+ · intro i hi hi'
783+ rw [
784+ ←ByteArray.getElem_eq_getElem_data,
785+ getElem_toBytesBE (h := by simp_all),
786+ UInt8.eq_iff_toBitVec_eq]
787+ apply eq_of_getLsbD_eq
788+ intro j hj
789+ have sz_eq : (bytes.size * 8 + 7 ) / 8 = bytes.size :=
790+ Nat.div_eq_of_lt_le (by omega) (by omega)
791+ have : ((bytes.size * 8 + 7 ) / 8 - 1 - i) * 8 + j < bytes.size * 8 := by
792+ rw [sz_eq]
793+ have : i < bytes.size := hi'
794+ omega
795+ simp [getLsbD_eq_getElem this, getElem_ofBytesBE]
796+ simp only [sz_eq]
797+ have idx_eq : bytes.size - 1 - ((bytes.size - 1 - i) * 8 + j) / 8 = i := calc
798+ bytes.size - 1 - ((bytes.size - 1 - i) * 8 + j) / 8
799+ _ = bytes.size - 1 - (bytes.size - 1 - i) := by
800+ congr 1
801+ exact Nat.div_eq_of_lt_le (by omega) (by omega)
802+ _ = i := by
803+ have : i < bytes.size := hi'
804+ omega
805+ simp [idx_eq, hj, ByteArray.getElem_eq_getElem_data, Nat.mod_eq_of_lt hj]
806+
807+ @[simp]
808+ theorem ofBytesLE_toBytesLE (x : BitVec w) :
809+ ofBytesLE x.toBytesLE = x.zeroExtend (x.toBytesLE.size * 8 ) := by
810+ apply eq_of_getLsbD_eq
811+ intro i hi
812+ rw [size_toBytesLE] at hi
813+ simp only [ofBytesLE, getLsbD_ofFnLE, size_toBytesLE]
814+ rw [dif_pos hi]
815+ have h_idx : i / 8 < (w + 7 ) / 8 := by
816+ apply Nat.div_lt_of_lt_mul
817+ rw [Nat.mul_comm]
818+ exact hi
819+ rw [getElem_toBytesLE (h := h_idx)]
820+ simp only [getElem_ofFnLE]
821+ have : i / 8 * 8 + i % 8 = i := by omega
822+ rw [this]
823+ simp [zeroExtend, hi]
824+
825+ @[simp]
826+ theorem ofBytesBE_toBytesBE (x : BitVec w) :
827+ ofBytesBE x.toBytesBE = x.zeroExtend (x.toBytesBE.size * 8 ) := by
828+ apply eq_of_getLsbD_eq
829+ intro i hi
830+ simp only [ofBytesBE, getLsbD_ofFnLE]
831+ rw [dif_pos hi]
832+ have h_idx : x.toBytesBE.size - 1 - i / 8 < (w + 7 ) / 8 := by
833+ rw [size_toBytesBE]
834+ have : i / 8 < (w + 7 ) / 8 := by
835+ apply Nat.div_lt_of_lt_mul
836+ rw [Nat.mul_comm]
837+ rw [size_toBytesBE] at hi
838+ exact hi
839+ omega
840+ rw [getElem_toBytesBE (h := h_idx)]
841+ simp only [getElem_ofFnLE]
842+ have : ((w + 7 ) / 8 - 1 - ((w + 7 ) / 8 - 1 - i / 8 )) * 8 + i % 8 = i := by
843+ rw [size_toBytesBE] at hi
844+ have : i / 8 < (w + 7 ) / 8 := by
845+ apply Nat.div_lt_of_lt_mul
846+ rw [Nat.mul_comm]
847+ exact hi
848+ omega
849+ rw [size_toBytesBE]
850+ rw [this]
851+ rw [size_toBytesBE] at hi
852+ simp [zeroExtend, hi]
743853end BitVec
0 commit comments