1- use ethereum_types :: U256 ;
1+ use alloy :: primitives :: { U256 , U512 , ruint :: UintTryFrom } ;
22
33/// Computes `x * q / d` rounding down.
44///
@@ -13,7 +13,9 @@ pub fn mul_ratio(x: U256, q: U256, d: U256) -> Option<U256> {
1313 return Some ( res / d) ;
1414 }
1515
16- x. full_mul ( q) . checked_div ( d. into ( ) ) ?. try_into ( ) . ok ( )
16+ // SAFETY: at this point !d.is_zero() upholds
17+ let div = ( x. widening_mul ( q) ) / U512 :: from ( d) ;
18+ U256 :: uint_try_from ( div) . ok ( )
1719}
1820
1921/// Computes `x * q / d` rounding up.
@@ -26,12 +28,85 @@ pub fn mul_ratio_ceil(x: U256, q: U256, d: U256) -> Option<U256> {
2628
2729 // fast path when math in U256 doesn't overflow
2830 if let Some ( p) = x. checked_mul ( q) {
29- let ( div, rem) = p . div_mod ( d) ;
30- return div. checked_add ( u8 :: from ( !rem. is_zero ( ) ) . into ( ) ) ;
31+ let ( div, rem) = ( p / d , p % d) ;
32+ return div. checked_add ( U256 :: from ( !rem. is_zero ( ) ) ) ;
3133 }
3234
33- let p = x. full_mul ( q) ;
34- let ( div, rem) = p. div_mod ( d. into ( ) ) ;
35- let result = U256 :: try_from ( div) . ok ( ) ?;
36- result. checked_add ( u8:: from ( !rem. is_zero ( ) ) . into ( ) )
35+ let p = x. widening_mul ( q) ;
36+ let d = U512 :: from ( d) ;
37+ // SAFETY: at this point !d.is_zero() upholds
38+ let ( div, rem) = ( p / d, p % d) ;
39+
40+ let result = U256 :: uint_try_from ( div) . ok ( ) ?;
41+ result. checked_add ( U256 :: from ( !rem. is_zero ( ) ) )
42+ }
43+
44+ #[ cfg( test) ]
45+ mod test {
46+ use {
47+ crate :: util:: math:: { mul_ratio, mul_ratio_ceil} ,
48+ alloy:: primitives:: U256 ,
49+ } ;
50+
51+ #[ test]
52+ fn mul_ratio_zero ( ) {
53+ assert ! ( mul_ratio( U256 :: from( 10 ) , U256 :: from( 10 ) , U256 :: ZERO ) . is_none( ) ) ;
54+ }
55+
56+ #[ test]
57+ fn mul_ratio_overflow ( ) {
58+ assert ! ( mul_ratio( U256 :: MAX , U256 :: from( 2 ) , U256 :: ONE ) . is_none( ) ) ;
59+ }
60+
61+ #[ test]
62+ fn mul_ratio_ceil_zero ( ) {
63+ assert ! ( mul_ratio_ceil( U256 :: from( 10 ) , U256 :: from( 10 ) , U256 :: ZERO ) . is_none( ) ) ;
64+ }
65+
66+ #[ test]
67+ fn mul_ratio_ceil_overflow ( ) {
68+ assert ! ( mul_ratio_ceil( U256 :: MAX , U256 :: from( 2 ) , U256 :: ONE ) . is_none( ) ) ;
69+ }
70+
71+ #[ test]
72+ fn mul_ratio_normal ( ) {
73+ // Exact division
74+ assert_eq ! (
75+ mul_ratio( U256 :: from( 100 ) , U256 :: from( 5 ) , U256 :: from( 10 ) ) ,
76+ Some ( U256 :: from( 50 ) )
77+ ) ;
78+
79+ // Division with remainder (rounds down)
80+ assert_eq ! (
81+ mul_ratio( U256 :: from( 100 ) , U256 :: from( 3 ) , U256 :: from( 10 ) ) ,
82+ Some ( U256 :: from( 30 ) )
83+ ) ;
84+
85+ // Large values that don't overflow
86+ assert_eq ! (
87+ mul_ratio( U256 :: from( u128 :: MAX ) , U256 :: from( 2 ) , U256 :: from( 4 ) ) ,
88+ Some ( U256 :: from( u128 :: MAX / 2 ) )
89+ ) ;
90+ }
91+
92+ #[ test]
93+ fn mul_ratio_ceil_normal ( ) {
94+ // Exact division (no rounding needed)
95+ assert_eq ! (
96+ mul_ratio_ceil( U256 :: from( 100 ) , U256 :: from( 5 ) , U256 :: from( 10 ) ) ,
97+ Some ( U256 :: from( 50 ) )
98+ ) ;
99+
100+ // Division with remainder (rounds up)
101+ assert_eq ! (
102+ mul_ratio_ceil( U256 :: from( 10 ) , U256 :: from( 3 ) , U256 :: from( 4 ) ) ,
103+ Some ( U256 :: from( 8 ) )
104+ ) ;
105+
106+ // Large values that don't overflow
107+ assert_eq ! (
108+ mul_ratio_ceil( U256 :: from( u128 :: MAX ) , U256 :: from( 2 ) , U256 :: from( 4 ) ) ,
109+ Some ( U256 :: from( u128 :: MAX / 2 + 1 ) )
110+ ) ;
111+ }
37112}
0 commit comments