@@ -568,6 +568,23 @@ theorem felt_ofNat_add (n m : Nat) :
568568 Felt.ofNat (n + m) = Felt.ofNat n + Felt.ofNat m := by
569569 simp [Felt.ofNat, Nat.cast_add]
570570
/-- `(x + y % m) % m = (x + y) % m`: the right-hand addend may be reduced modulo `m`
before the outer reduction without changing the result.

This is the core lemma `Nat.add_mod_mod`, re-exported under a local name. -/
theorem add_mod_right_mod (x y m : Nat) : (x + y % m) % m = (x + y) % m :=
  Nat.add_mod_mod x y m
574+
/-- Two nested `% m` reductions sitting under two additions can both be dropped before the
outer `% m`: `(a + (b + c % m) % m) % m = (a + (b + c)) % m`. -/
theorem add_add_mod_right_mod (a b c m : Nat) :
    (a + ((b + c % m) % m)) % m = (a + (b + c)) % m := by
  -- 1. drop the middle `% m`; 2. reassociate so `c % m` is the right addend;
  -- 3. drop the inner `% m`; 4. reassociate back to the shape of the right-hand side.
  rw [Nat.add_mod_mod, ← Nat.add_assoc, Nat.add_mod_mod, Nat.add_assoc]
587+
571588-- ============================================================================
572589-- Carry chain bridging lemmas for addition
573590-- ============================================================================
@@ -617,6 +634,238 @@ theorem u128_add_carry_bridge (a b : U128) :
617634 simp only [HAdd.hAdd, Add.add, U128.ofNat, U128.toNat] at cc co ⊢
618635 exact ⟨congrArg _ cc.1 , congrArg _ cc.2 .1 , congrArg _ cc.2 .2 .1 , congrArg _ cc.2 .2 .2 , co⟩
619636
637+ -- ============================================================================
638+ -- Carry-chain bridging lemmas for q * b + r (used by u128 divmod)
639+ -- ============================================================================
640+
/-- Unreduced column-0 accumulator of the `q * b + r` carry chain: `q0 * b0 + r0`.

No reduction modulo `2^32` is applied here; the next column consumes this value
divided by `2^32` as its carry. -/
def u128DivmodCol0 (q0 b0 r0 : Nat) : Nat :=
  q0 * b0 + r0
644+
/-- Unreduced column-1 accumulator of the `q * b + r` carry chain: the cross products
`q0 * b1 + q1 * b0`, the limb `r1`, and the carry `u128DivmodCol0 q0 b0 r0 / 2^32`
coming out of column 0. -/
def u128DivmodCol1 (q0 q1 b0 b1 r0 r1 : Nat) : Nat :=
  q0 * b1 + q1 * b0 + r1 + u128DivmodCol0 q0 b0 r0 / 2 ^ 32
648+
/-- Unreduced column-2 accumulator of the `q * b + r` carry chain: the products
`q0 * b2 + q1 * b1 + q2 * b0`, the limb `r2`, and the carry
`u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2^32` coming out of column 1. -/
def u128DivmodCol2 (q0 q1 q2 b0 b1 b2 r0 r1 r2 : Nat) : Nat :=
  q0 * b2 + q1 * b1 + q2 * b0 + r2 + u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2 ^ 32
652+
/-- Partial column-2 sum available before the last multiply-add:
`q2 * b0 + q1 * b1` plus the carry `u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2^32`.

Compared with `u128DivmodCol2`, the terms `q0 * b2` and `r2` are not included. -/
def u128DivmodCol2CarryIn (q0 q1 q2 b0 b1 r0 r1 : Nat) : Nat :=
  q2 * b0 + q1 * b1 + u128DivmodCol1 q0 q1 b0 b1 r0 r1 / 2 ^ 32
656+
/-- Unreduced column-3 accumulator of the `q * b + r` carry chain: the products
`q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0`, the limb `r3`, and the carry
`u128DivmodCol2 q0 q1 q2 b0 b1 b2 r0 r1 r2 / 2^32` coming out of column 2. -/
def u128DivmodCol3 (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat) : Nat :=
  q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3 +
    u128DivmodCol2 q0 q1 q2 b0 b1 b2 r0 r1 r2 / 2 ^ 32
661+
/-- Natural number assembled from four limbs in little-endian order: `a0` has weight `1`,
`a1` weight `2^32`, `a2` weight `2^64`, and `a3` weight `2^96`.

"Raw" because the limbs are arbitrary naturals; no `< 2^32` bound is imposed here. -/
def u128RawValue (a0 a1 a2 a3 : Nat) : Nat :=
  a3 * 2 ^ 96 + a2 * 2 ^ 64 + a1 * 2 ^ 32 + a0
665+
/-- The column-2 carry input, scaled by `(2^32)^2 = 2^64`, never exceeds the full value
`q * b + r` built from the raw limbs.

Proof outline: each floor division `x / base * base` is bounded by `x`, so the scaled carry
input is at most the partial sum of the schoolbook expansion up to weight `base^2`
(`hpre`). The remaining terms of the expansion are natural numbers, so the partial sum is
at most the whole (`hbound`). -/
private theorem u128DivmodCol2CarryIn_mul_pow_le
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat) :
    let base := 2 ^ 32
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 ≤
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 := by
  let base := 2 ^ 32
  -- Column-0 carry, scaled back up, is at most the column-0 accumulator.
  have hcarry0 : ((q0 * b0 + r0) / base) * base ≤ q0 * b0 + r0 := by
    exact Nat.div_mul_le_self _ _
  -- Column-1 carry, scaled to weight `base^2`, is at most columns 0 and 1 combined.
  have hcarry1 :
      (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 ≤
        (q0 * b1 + q1 * b0 + r1) * base + (q0 * b0 + r0) := by
    have hdiv : (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ≤
        u128DivmodCol1 q0 q1 b0 b1 r0 r1 := by
      exact Nat.div_mul_le_self _ _
    have hdiv' := Nat.mul_le_mul_right base hdiv
    have hbase :
        (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 ≤
          u128DivmodCol1 q0 q1 b0 b1 r0 r1 * base := by
      simpa [pow_two, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm] using hdiv'
    have hrepr :
        u128DivmodCol1 q0 q1 b0 b1 r0 r1 * base =
          (q0 * b1 + q1 * b0 + r1) * base + ((q0 * b0 + r0) / base) * base := by
      unfold u128DivmodCol1 u128DivmodCol0
      ring
    rw [hrepr] at hbase
    -- `omega` closes the gap using `hcarry0`, treating the products as atoms.
    exact le_trans hbase (by omega)
  -- The scaled carry input is bounded by the partial expansion up to weight `base^2`.
  have hpre :
      u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 ≤
        (q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2 := by
    have hmain := Nat.add_le_add_left hcarry1 ((q2 * b0 + q1 * b1) * base ^ 2)
    have hsum :
        u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 =
          (q2 * b0 + q1 * b1) * base ^ 2 +
            (u128DivmodCol1 q0 q1 b0 b1 r0 r1 / base) * base ^ 2 := by
      unfold u128DivmodCol2CarryIn
      ring
    rw [hsum]
    simpa [Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hmain
  -- Full schoolbook expansion of `q * b + r`, split as (partial sum) + (everything else).
  have hexpand :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 + u128RawValue r0 r1 r2 r3 =
        ((q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2) +
          ((q0 * b2 + r2) * base ^ 2 +
            (q0 * b3 + q1 * b2 + q2 * b1 + q3 * b0 + r3) * base ^ 3 +
            (q1 * b3 + q2 * b2 + q3 * b1) * base ^ 4 +
            (q2 * b3 + q3 * b2) * base ^ 5 +
            (q3 * b3) * base ^ 6) := by
    unfold u128RawValue
    ring
  -- After rewriting with `hexpand` the goal is literally `X ≤ X + Y` over `Nat`.
  have hbound :
      (q0 * b0 + r0) + (q0 * b1 + q1 * b0 + r1) * base + (q2 * b0 + q1 * b1) * base ^ 2 ≤
        u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 + u128RawValue r0 r1 r2 r3 := by
    rw [hexpand]
    exact Nat.le_add_right _ _
  exact le_trans hpre hbound
732+
/-- If `q * b + r` (built from raw limbs via `u128RawValue`) is below `2^128`, then the
column-2 carry input `u128DivmodCol2CarryIn` is below `2^64`.

Argued by contradiction: a carry input of at least `2^64`, scaled by `2^64`, would already
reach `2^128`, while `u128DivmodCol2CarryIn_mul_pow_le` bounds that scaled value by
`q * b + r`. -/
theorem u128DivmodCol2CarryIn_lt_2_64
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat)
    (htotal :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 < 2 ^ 128) :
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 < 2 ^ 64 := by
  let base := 2 ^ 32
  by_contra hbig
  -- Negated goal, restated with `base ^ 2` in place of `2 ^ 64`.
  have hlow : base ^ 2 ≤ u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 := by
    simpa [base, pow_two] using Nat.not_lt.mp hbig
  -- Multiply both sides by `base ^ 2`; the left side becomes `2 ^ 128`.
  have hscaled : 2 ^ 128 ≤ u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 * base ^ 2 := by
    simpa [base, pow_two, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm] using
      Nat.mul_le_mul_right (base ^ 2) hlow
  -- The scaled carry input is bounded by the full value, which `htotal` caps below `2 ^ 128`.
  have hfit := u128DivmodCol2CarryIn_mul_pow_le q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3
  exact (not_lt_of_ge (le_trans hscaled hfit)) htotal
750+
/-- If `q * b + r` (built from raw limbs via `u128RawValue`) is below `2^128`, then the
column-2 carry input divided by `2^32` is below `2^32`. -/
theorem u128DivmodCol2CarryIn_div_lt_2_32
    (q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 : Nat)
    (htotal :
      u128RawValue q0 q1 q2 q3 * u128RawValue b0 b1 b2 b3 +
        u128RawValue r0 r1 r2 r3 < 2 ^ 128) :
    u128DivmodCol2CarryIn q0 q1 q2 b0 b1 r0 r1 / 2 ^ 32 < 2 ^ 32 := by
  -- `m / n < k` follows from `m < n * k`; here `n = k = 2 ^ 32`.
  apply Nat.div_lt_of_lt_mul
  have hword :=
    u128DivmodCol2CarryIn_lt_2_64 q0 q1 q2 q3 b0 b1 b2 b3 r0 r1 r2 r3 htotal
  -- `2 ^ 64` and `2 ^ 32 * 2 ^ 32` normalise to the same numeral.
  simpa [pow_two] using hword
762+
/-- Shared core of `u128HighTermsZeroOfMulLt`: if `qi * 2^i ≤ Q`, `bj * 2^j ≤ B` and
`i + j ≥ 128`, then a nonzero `qi * bj` forces `Q * B ≥ 2^128`. -/
private theorem two_pow_128_le_mul_of_limb_product
    {Q B qi bj i j : Nat} (hij : 128 ≤ i + j)
    (hq : qi * 2 ^ i ≤ Q) (hb : bj * 2 ^ j ≤ B) (hne : qi * bj ≠ 0) :
    2 ^ 128 ≤ Q * B := by
  have hone : 1 ≤ qi * bj := Nat.pos_of_ne_zero hne
  calc
    2 ^ 128 ≤ 2 ^ (i + j) := Nat.pow_le_pow_right (by decide) hij
    _ = 1 * 2 ^ (i + j) := (Nat.one_mul _).symm
    _ ≤ qi * bj * 2 ^ (i + j) := Nat.mul_le_mul_right _ hone
    _ = (qi * 2 ^ i) * (bj * 2 ^ j) := by ring
    _ ≤ Q * B := Nat.mul_le_mul hq hb

/-- If `q * b` fits below `2^128`, every limb product whose weight is at least `2^128`
must vanish: `q1 * b3`, `q2 * b2`, `q3 * b1` (weight `2^128`), `q2 * b3`, `q3 * b2`
(weight `2^160`) and `q3 * b3` (weight `2^192`).

Each conjunct is proved by contradiction through `two_pow_128_le_mul_of_limb_product`:
a nonzero product would push `q * b` to at least `2^128`, contradicting `hmul`.
The limb-range hypotheses `hq0 … hb3` are part of the signature for callers; the argument
below only relies on `hmul`. -/
theorem u128HighTermsZeroOfMulLt
    (q0 q1 q2 q3 b0 b1 b2 b3 : Nat)
    (hq0 : q0 < 2 ^ 32) (hq1 : q1 < 2 ^ 32) (hq2 : q2 < 2 ^ 32) (hq3 : q3 < 2 ^ 32)
    (hb0 : b0 < 2 ^ 32) (hb1 : b1 < 2 ^ 32) (hb2 : b2 < 2 ^ 32) (hb3 : b3 < 2 ^ 32)
    (hmul :
      (q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0) *
        (b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0) < 2 ^ 128) :
    q1 * b3 = 0 ∧
      q2 * b2 = 0 ∧
      q3 * b1 = 0 ∧
      q2 * b3 = 0 ∧
      q3 * b2 = 0 ∧
      q3 * b3 = 0 := by
  -- Every weighted limb is bounded by the four-limb sum it is part of.
  have hQ1 : q1 * 2 ^ 32 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
  have hQ2 : q2 * 2 ^ 64 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
  have hQ3 : q3 * 2 ^ 96 ≤ q3 * 2 ^ 96 + q2 * 2 ^ 64 + q1 * 2 ^ 32 + q0 := by omega
  have hB1 : b1 * 2 ^ 32 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
  have hB2 : b2 * 2 ^ 64 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
  have hB3 : b3 * 2 ^ 96 ≤ b3 * 2 ^ 96 + b2 * 2 ^ 64 + b1 * 2 ^ 32 + b0 := by omega
  refine ⟨?_, ?_, ?_, ?_, ?_, ?_⟩
  · -- `q1 * b3`: weight `2 ^ (32 + 96)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ1 hB3 hne))
  · -- `q2 * b2`: weight `2 ^ (64 + 64)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ2 hB2 hne))
  · -- `q3 * b1`: weight `2 ^ (96 + 32)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ3 hB1 hne))
  · -- `q2 * b3`: weight `2 ^ (64 + 96)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ2 hB3 hne))
  · -- `q3 * b2`: weight `2 ^ (96 + 64)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ3 hB2 hne))
  · -- `q3 * b3`: weight `2 ^ (96 + 96)`.
    by_contra hne
    exact absurd hmul (Nat.not_lt.mpr
      (two_pow_128_le_mul_of_limb_product (by decide) hQ3 hB3 hne))
868+
620869-- ============================================================================
621870-- Comparison bridging lemmas
622871-- ============================================================================
0 commit comments