Skip to content

Commit 3352477

Browse files
committed
Add u128 divmod proof and clean Lean warnings
1 parent 0bde06f commit 3352477

File tree

4 files changed

+3192
-2
lines changed

4 files changed

+3192
-2
lines changed

MidenLean/Proofs/U128/Common.lean

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,23 @@ theorem felt_ofNat_add (n m : Nat) :
568568
Felt.ofNat (n + m) = Felt.ofNat n + Felt.ofNat m := by
569569
simp [Felt.ofNat, Nat.cast_add]
570570

571+
/-- Reducing an addend modulo `m` before taking the outer modulo does not change the result:
`(x + y % m) % m = (x + y) % m`.

This is exactly the library lemma `Nat.add_mod_mod`; it is re-exported here under a local name
so the carry-chain proofs in this file can rewrite with it directly. -/
theorem add_mod_right_mod (x y m : Nat) : (x + y % m) % m = (x + y) % m := by
  -- After the rewrite both sides are syntactically equal, so `rw` closes the goal with `rfl`.
  rw [Nat.add_mod_mod]
574+
575+
/-- Nested modulo reduction under two additions can be flattened before the outer modulo:
`(a + ((b + c % m) % m)) % m = (a + (b + c)) % m`.

The proof peels the reductions off from the outside in, re-associating the sum so that
`add_mod_right_mod` applies at each step. -/
theorem add_add_mod_right_mod (a b c m : Nat) :
    (a + ((b + c % m) % m)) % m = (a + (b + c)) % m := by
  calc
    -- Drop the inner `% m` wrapped around `b + c % m`.
    (a + ((b + c % m) % m)) % m = (a + (b + c % m)) % m := by
      rw [add_mod_right_mod]
    -- Re-associate so that `c % m` becomes the right-hand addend.
    _ = ((a + b) + c % m) % m := by
      rw [Nat.add_assoc]
    -- Drop the remaining `% m` on `c`.
    _ = ((a + b) + c) % m := by
      rw [add_mod_right_mod]
    -- Restore the association used in the statement.
    _ = (a + (b + c)) % m := by
      rw [Nat.add_assoc]
587+
571588
-- ============================================================================
572589
-- Carry chain bridging lemmas for addition
573590
-- ============================================================================
@@ -617,6 +634,238 @@ theorem u128_add_carry_bridge (a b : U128) :
617634
simp only [HAdd.hAdd, Add.add, U128.ofNat, U128.toNat] at cc co ⊢
618635
exact ⟨congrArg _ cc.1, congrArg _ cc.2.1, congrArg _ cc.2.2.1, congrArg _ cc.2.2.2, co⟩
619636

637+
-- ============================================================================
638+
-- Carry-chain bridging lemmas for q * b + r (used by u128 divmod)
639+
-- ============================================================================
640+
641+
/-- Column 0 of the `q * b + r` carry chain: the unreduced sum `q0 * b0 + r0`.
No reduction modulo `2^32` is applied here; later columns divide this value by `2 ^ 32`
to obtain the carry. -/
def u128DivmodCol0 (q0 b0 r0 : Nat) : Nat :=
  q0 * b0 + r0
644+
645+
/-- Column 1 of the `q * b + r` carry chain: the cross products `q0 * b1 + q1 * b0`, the
remainder limb `r1`, and the carry out of column 0 (`u128DivmodCol0 q0 b0 r0 / 2 ^ 32`).
The value is left unreduced. -/
def u128DivmodCol1 (q0 q1 b0 b1 r0 r1 : Nat) : Nat :=
  q0 * b1 + q1 * b0 + r1 + u128DivmodCol0 q0 b0 r0 / 2 ^ 32
648+
649+
/-- Column 2 of the `q * b + r` carry chain: the products `q0 * b2 + q1 * b1 + q2 * b0`, the
remainder limb `r2`, and the carry out of column 1 (`u128DivmodCol1 … / 2 ^ 32`).
The value is left unreduced. -/
def u128DivmodCol2 (q0 q1 q2 b0 b1 b2 r0 r1 r2 : Nat) : Nat :=
  q0 * b2 + q1 * b1 + q2 * b0 + r2 + u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2 ^ 32
652+
653+
/-- The carry input accumulated before the final multiply-add in column 2:
`q2 * b0 + q1 * b1` plus the carry out of column 1 (`u128DivmodCol1 … / 2 ^ 32`).
Compared with `u128DivmodCol2`, this omits the `q0 * b2` product and the `r2` limb. -/
def u128DivmodCol2CarryIn (q0 q1 q2 b0 b1 r0 r1 : Nat) : Nat :=
  q2 * b0 + q1 * b1 + u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2 ^ 32
656+
657+
/-- Column 3 of the `q * b + r` carry chain: the products
`q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0`, the remainder limb `r3`, and the carry out of
column 2 (`u128DivmodCol2 … / 2 ^ 32`). The value is left unreduced. -/
def u128DivmodCol3 (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat) : Nat :=
  q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3 +
    u128DivmodCol2 q0 q1 q2 b0 b1 b2 r0 r1 r2 / 2 ^ 32
661+
662+
/-- Raw little-endian u128 value from four limbs: `a0` is the least significant limb and
`a3` the most significant, weighted by `2 ^ 0`, `2 ^ 32`, `2 ^ 64` and `2 ^ 96`.
The limbs are plain `Nat`s; no bound on them is enforced by this definition. -/
def u128RawValue (a0 a1 a2 a3 : Nat) : Nat :=
  a3 * 2 ^ 96 + a2 * 2 ^ 64 + a1 * 2 ^ 32 + a0
665+
666+
/-- The column-2 carry input, scaled by `base ^ 2` (with `base = 2 ^ 32`), is bounded by the
full raw value of `q * b + r`.

Proof outline: each carry `x / base` satisfies `(x / base) * base ≤ x`, so the scaled carry
input is at most the low-order part of the expanded product (`hpre`); the remaining terms of
the expansion are natural numbers and hence non-negative (`hexpand`, `hextra`). -/
private theorem u128DivmodCol2CarryIn_mul_pow_le
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat) :
    let base := 2 ^ 32
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 ≤
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 := by
  let base := 2 ^ 32
  -- Carry out of column 0, scaled back up, does not exceed column 0 itself.
  have hcarry0 : ((q0 * b0 + r0) / base) * base ≤ q0 * b0 + r0 := by
    exact Nat.div_mul_le_self _ _
  -- Carry out of column 1, scaled by `base ^ 2`, is bounded by columns 0 and 1 combined.
  have hcarry1 :
      (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 ≤
        (q0 * b1 + q1 * b0 + r1) * base + (q0 * b0 + r0) := by
    have hdiv : (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ≤
        u128DivmodCol1 q0 q1 b0 b1 r0 r1 := by
      exact Nat.div_mul_le_self _ _
    have hdiv' := Nat.mul_le_mul_right base hdiv
    have hbase :
        (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 ≤
          u128DivmodCol1 q0 q1 b0 b1 r0 r1 * base := by
      simpa [pow_two, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm] using hdiv'
    have hrepr :
        u128DivmodCol1 q0 q1 b0 b1 r0 r1 * base =
          (q0 * b1 + q1 * b0 + r1) * base + ((q0 * b0 + r0) / base) * base := by
      unfold u128DivmodCol1 u128DivmodCol0
      ring
    rw [hrepr] at hbase
    exact le_trans hbase (by omega)
  -- The scaled carry input is bounded by the low-order part of the expanded product.
  have hpre :
      u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 ≤
        (q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2 := by
    have hmain := Nat.add_le_add_left hcarry1 ((q2 * b0 + q1 * b1) * base ^ 2)
    have hsum :
        u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 =
          (q2 * b0 + q1 * b1) * base ^ 2 +
            (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 := by
      unfold u128DivmodCol2CarryIn
      ring
    rw [hsum]
    simpa [Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hmain
  -- Full expansion of `q * b + r`, split into the part bounded above and the remaining terms.
  have hexpand :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 + u128RawValue r0 r1 r2 r3 =
        ((q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2) +
          ((q0 * b2 + r2) * base ^ 2 +
            (q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3) * base ^ 3 +
            (q1 * b3 + q2 * b2 + q3 * b1) * base ^ 4 +
            (q2 * b3 + q3 * b2) * base ^ 5 +
            (q3 * b3) * base ^ 6) := by
    unfold u128RawValue
    ring
  -- The remaining terms are natural numbers, hence non-negative.
  have hextra :
      0 ≤ (q0 * b2 + r2) * base ^ 2 +
        (q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3) * base ^ 3 +
        (q1 * b3 + q2 * b2 + q3 * b1) * base ^ 4 +
        (q2 * b3 + q3 * b2) * base ^ 5 +
        (q3 * b3) * base ^ 6 := by
    have h0 : 0 ≤ (q0 * b2 + r2) * base ^ 2 := Nat.zero_le _
    have h1 : 0 ≤ (q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3) * base ^ 3 := Nat.zero_le _
    have h2 : 0 ≤ (q1 * b3 + q2 * b2 + q3 * b1) * base ^ 4 := Nat.zero_le _
    have h3 : 0 ≤ (q2 * b3 + q3 * b2) * base ^ 5 := Nat.zero_le _
    have h4 : 0 ≤ (q3 * b3) * base ^ 6 := Nat.zero_le _
    omega
  have hbound :
      (q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2 ≤
        u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 + u128RawValue r0 r1 r2 r3 := by
    nlinarith [hexpand, hextra]
  exact le_trans hpre hbound
732+
733+
/-- The carry input accumulated before the final multiply-add in column 2 fits below `2^64`
whenever `q * b + r` fits in 128 bits.

Argued by contradiction: if the carry input were at least `2 ^ 64`, then scaled by `2 ^ 64`
it would be at least `2 ^ 128`, while `u128DivmodCol2CarryIn_mul_pow_le` bounds that scaled
value by `q * b + r`, contradicting `htotal`. -/
theorem u128DivmodCol2CarryIn_lt_2_64
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat)
    (htotal :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 < 2 ^ 128) :
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 < 2 ^ 64 := by
  let base := 2 ^ 32
  have hmul := u128DivmodCol2CarryIn_mul_pow_le q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3
  by_contra hnot
  -- Negating the goal gives `2 ^ 64 ≤ carryIn`, restated with `base ^ 2`.
  have hge : base ^ 2 ≤ u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 := by
    simpa [base, pow_two] using Nat.not_lt.mp hnot
  -- Scaling by `base ^ 2` pushes the lower bound up to `2 ^ 128`.
  have hpow : 2 ^ 128 ≤ u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 := by
    have hmul' := Nat.mul_le_mul_right (base ^ 2) hge
    simpa [base, pow_two, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm] using hmul'
  exact (not_lt_of_ge (le_trans hpow hmul)) htotal
750+
751+
/-- Dividing the column-2 carry input by `2^32` yields a u32 carry word, i.e. the quotient is
strictly below `2 ^ 32`, whenever `q * b + r` fits in 128 bits.

Follows from `u128DivmodCol2CarryIn_lt_2_64` by writing `2 ^ 64` as `2 ^ 32 * 2 ^ 32`. -/
theorem u128DivmodCol2CarryIn_div_lt_2_32
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat)
    (htotal :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 < 2 ^ 128) :
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 / 2 ^ 32 < 2 ^ 32 := by
  have hlt := u128DivmodCol2CarryIn_lt_2_64 q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 htotal
  -- Restate the bound as a product so `Nat.div_lt_of_lt_mul` applies.
  have hlt' : u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 < (2 ^ 32) * (2 ^ 32) := by
    simpa [pow_two] using hlt
  exact Nat.div_lt_of_lt_mul hlt'
762+
763+
/-- If `q * b` fits below `2^128`, every product term that contributes only above the
low 128 bits must vanish.

Each of the six conjuncts is proved the same way: if `qi * bj` were nonzero it would be at
least `1`, so the single term `(qi * 2^(32*i)) * (bj * 2^(32*j))` with `i + j ≥ 4` would already
be at least `2 ^ 128`; that term is bounded by the full product, contradicting `hmul`. -/
theorem u128HighTermsZeroOfMulLt
    (q0 q1 q2 q3 b0 b1 b2 b3 : Nat)
    (hq0 : q0 < 2 ^ 32) (hq1 : q1 < 2 ^ 32) (hq2 : q2 < 2 ^ 32) (hq3 : q3 < 2 ^ 32)
    (hb0 : b0 < 2 ^ 32) (hb1 : b1 < 2 ^ 32) (hb2 : b2 < 2 ^ 32) (hb3 : b3 < 2 ^ 32)
    (hmul :
      (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
        (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) < 2 ^ 128) :
    q1 * b3 = 0 ∧
    q2 * b2 = 0 ∧
    q3 * b1 = 0 ∧
    q2 * b3 = 0 ∧
    q3 * b2 = 0 ∧
    q3 * b3 = 0 := by
  refine ⟨?_, ?_, ?_, ?_, ?_, ?_⟩
  -- `q1 * b3`: weight `2 ^ 32 * 2 ^ 96 = 2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q1 * b3 := by omega
    have hbound :
        2 ^ 128 ≤
          (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
            (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) := by
      have hq : q1 * 2 ^ 32 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
      have hb : b3 * 2 ^ 96 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
      have hmain : 2 ^ 128 ≤ (q1 * 2 ^ 32) * (b3 * 2 ^ 96) := by
        calc
          2 ^ 128 = 1 * 1 * 2 ^ 128 := by ring
          _ ≤ q1 * b3 * 2 ^ 128 := by
            have : 1 ≤ q1 * b3 := by omega
            nlinarith
          _ = (q1 * 2 ^ 32) * (b3 * 2 ^ 96) := by ring
      exact le_trans hmain (Nat.mul_le_mul hq hb)
    omega
  -- `q2 * b2`: weight `2 ^ 64 * 2 ^ 64 = 2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q2 * b2 := by omega
    have hq : q2 * 2 ^ 64 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
    have hb : b2 * 2 ^ 64 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
    have hmain : 2 ^ 128 ≤ (q2 * 2 ^ 64) * (b2 * 2 ^ 64) := by
      calc
        2 ^ 128 = 1 * 1 * 2 ^ 128 := by ring
        _ ≤ q2 * b2 * 2 ^ 128 := by
          have : 1 ≤ q2 * b2 := by omega
          nlinarith
        _ = (q2 * 2 ^ 64) * (b2 * 2 ^ 64) := by ring
    exact (not_lt_of_ge (le_trans hmain (Nat.mul_le_mul hq hb))) hmul |> False.elim
  -- `q3 * b1`: weight `2 ^ 96 * 2 ^ 32 = 2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q3 * b1 := by omega
    have hq : q3 * 2 ^ 96 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
    have hb : b1 * 2 ^ 32 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
    have hmain : 2 ^ 128 ≤ (q3 * 2 ^ 96) * (b1 * 2 ^ 32) := by
      calc
        2 ^ 128 = 1 * 1 * 2 ^ 128 := by ring
        _ ≤ q3 * b1 * 2 ^ 128 := by
          have : 1 ≤ q3 * b1 := by omega
          nlinarith
        _ = (q3 * 2 ^ 96) * (b1 * 2 ^ 32) := by ring
    exact (not_lt_of_ge (le_trans hmain (Nat.mul_le_mul hq hb))) hmul |> False.elim
  -- `q2 * b3`: weight `2 ^ 64 * 2 ^ 96 = 2 ^ 160`, which exceeds `2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q2 * b3 := by omega
    have hq : q2 * 2 ^ 64 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
    have hb : b3 * 2 ^ 96 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
    have hmain : 2 ^ 160 ≤ (q2 * 2 ^ 64) * (b3 * 2 ^ 96) := by
      calc
        2 ^ 160 = 1 * 1 * 2 ^ 160 := by ring
        _ ≤ q2 * b3 * 2 ^ 160 := by
          have : 1 ≤ q2 * b3 := by omega
          nlinarith
        _ = (q2 * 2 ^ 64) * (b3 * 2 ^ 96) := by ring
    have : 2 ^ 128 ≤
        (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
          (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) := by
      exact le_trans (by omega) (le_trans hmain (Nat.mul_le_mul hq hb))
    omega
  -- `q3 * b2`: weight `2 ^ 96 * 2 ^ 64 = 2 ^ 160`, which exceeds `2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q3 * b2 := by omega
    have hq : q3 * 2 ^ 96 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
    have hb : b2 * 2 ^ 64 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
    have hmain : 2 ^ 160 ≤ (q3 * 2 ^ 96) * (b2 * 2 ^ 64) := by
      calc
        2 ^ 160 = 1 * 1 * 2 ^ 160 := by ring
        _ ≤ q3 * b2 * 2 ^ 160 := by
          have : 1 ≤ q3 * b2 := by omega
          nlinarith
        _ = (q3 * 2 ^ 96) * (b2 * 2 ^ 64) := by ring
    have : 2 ^ 128 ≤
        (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
          (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) := by
      exact le_trans (by omega) (le_trans hmain (Nat.mul_le_mul hq hb))
    omega
  -- `q3 * b3`: weight `2 ^ 96 * 2 ^ 96 = 2 ^ 192`, which exceeds `2 ^ 128`.
  · by_contra hneq
    have hpos : 0 < q3 * b3 := by omega
    have hq : q3 * 2 ^ 96 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
    have hb : b3 * 2 ^ 96 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
    have hmain : 2 ^ 192 ≤ (q3 * 2 ^ 96) * (b3 * 2 ^ 96) := by
      calc
        2 ^ 192 = 1 * 1 * 2 ^ 192 := by ring
        _ ≤ q3 * b3 * 2 ^ 192 := by
          have : 1 ≤ q3 * b3 := by omega
          nlinarith
        _ = (q3 * 2 ^ 96) * (b3 * 2 ^ 96) := by ring
    have : 2 ^ 128 ≤
        (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
          (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) := by
      exact le_trans (by omega) (le_trans hmain (Nat.mul_le_mul hq hb))
    omega
868+
620869
-- ============================================================================
621870
-- Comparison bridging lemmas
622871
-- ============================================================================

0 commit comments

Comments
 (0)