11mod mut_sparse_array ;
22use dep::sort::sort_advanced ;
33
4- unconstrained fn __sort_field_as_u32 (lhs : Field , rhs : Field ) -> bool {
4+ unconstrained fn __sort (lhs : u32 , rhs : u32 ) -> bool {
55 // lhs.lt(rhs)
6- lhs as u32 < rhs as u32
6+ lhs < rhs
77}
88
9- fn assert_sorted (lhs : Field , rhs : Field ) {
10- let result = (rhs - lhs - 1 );
11- result .assert_max_bit_size ::<32 >();
9+ fn assert_sorted (lhs : u32 , rhs : u32 ) {
10+ assert (lhs < rhs );
1211}
1312
1413/**
@@ -24,10 +23,10 @@ fn assert_sorted(lhs: Field, rhs: Field) {
2423 **/
2524struct MutSparseArrayBase <let N : u32 , T , ComparisonFuncs > {
2625 values : [T ; N + 3 ],
27- keys : [Field ; N + 2 ],
28- linked_keys : [Field ; N + 2 ],
29- tail_ptr : Field ,
30- maximum : Field ,
26+ keys : [u32 ; N + 2 ],
27+ linked_keys : [u32 ; N + 2 ],
28+ tail_ptr : u32 ,
29+ maximum : u32 ,
3130}
3231
3332struct U32RangeTraits {}
@@ -47,9 +46,9 @@ pub struct MutSparseArray<let N: u32, T> {
4746 * 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
4847 **/
4948pub struct SparseArray <let N : u32 , T > {
50- keys : [Field ; N + 2 ],
49+ keys : [u32 ; N + 2 ],
5150 values : [T ; N + 3 ],
52- maximum : Field , // can be up to 2^32
51+ maximum : u32 , // can be up to 2^32 - 1
5352}
5453impl <let N : u32 , T > SparseArray <N , T >
5554where
@@ -59,15 +58,16 @@ where
5958 /**
6059 * @brief construct a SparseArray
6160 **/
62- pub (crate ) fn create (_keys : [Field ; N ], _values : [T ; N ], size : Field ) -> Self {
61+ pub (crate ) fn create (_keys : [u32 ; N ], _values : [T ; N ], size : u32 ) -> Self {
62+ assert (size >= 1 );
6363 let _maximum = size - 1 ;
6464 let mut r : Self =
6565 SparseArray { keys : [0 ; N + 2 ], values : [T ::default (); N + 3 ], maximum : _maximum };
6666
6767 // for any valid index, we want to ensure the following is satified:
6868 // self.keys[X] <= index <= self.keys[X+1]
6969 // this requires us to sort hte keys, and insert a startpoint and endpoint
70- let sorted_keys = sort_advanced (_keys , __sort_field_as_u32 , assert_sorted );
70+ let sorted_keys = sort_advanced (_keys , __sort , assert_sorted );
7171
7272 // insert start and endpoints
7373 r .keys [0 ] = 0 ;
@@ -103,45 +103,41 @@ where
103103 // because `self.keys` is sorted, we can simply validate that
104104 // sorted_keys.sorted[0] < 2^32
105105 // sorted_keys.sorted[N-1] < maximum
106- sorted_keys .sorted [0 ].assert_max_bit_size ::<32 >();
107- _maximum .assert_max_bit_size ::<32 >();
108- (_maximum - sorted_keys .sorted [N - 1 ]).assert_max_bit_size ::<32 >();
106+ assert (_maximum >= sorted_keys .sorted [N - 1 ]);
109107 r
110108 }
111109
112110 /**
113111 * @brief determine whether `target` is present in `self.keys`
114112 * @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
115113 **/
116- unconstrained fn search_for_key (self , target : Field ) -> (Field , Field ) {
114+ unconstrained fn search_for_key (self , target : u32 ) -> (bool , u32 ) {
117115 let mut found = false ;
118- let mut found_index = 0 ;
116+ let mut found_index : u32 = 0 ;
119117 let mut previous_less_than_or_equal_to_target = false ;
120118 for i in 0 ..N + 2 {
121119 // if target = 0xffffffff we need to be able to add 1 here, so use u64
122120 let current_less_than_or_equal_to_target = self .keys [i ] as u64 <= target as u64 ;
123121 if (self .keys [i ] == target ) {
124122 found = true ;
125- found_index = i as Field ;
123+ found_index = i ;
126124 break ;
127125 }
128126 if (previous_less_than_or_equal_to_target & !current_less_than_or_equal_to_target ) {
129- found_index = i as Field - 1 ;
127+ found_index = i - 1 ;
130128 break ;
131129 }
132130 previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target ;
133131 }
134- (found as Field , found_index )
132+ (found , found_index )
135133 }
136134
137135 /**
138136 * @brief return element `idx` from the sparse array
139137 * @details cost is 14.5 gates per lookup
140138 **/
141- fn get (self , idx : Field ) -> T {
139+ fn get (self , idx : u32 ) -> T {
142140 let (found , found_index ) = unsafe { self .search_for_key (idx ) };
143- // bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
144- assert (found * found == found );
145141
146142 // OK! So we have the following cases to check
147143 // 1. if `found` then `self.keys[found_index] == idx`
@@ -152,15 +148,13 @@ where
152148 // combine the two into the following single statement:
153149 // `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
154150 let lhs = self .keys [found_index ];
155- let rhs = self .keys [found_index + 1 - found ];
156- let lhs_condition = idx - lhs - 1 + found ;
157- let rhs_condition = rhs - 1 + found - idx ;
158- lhs_condition .assert_max_bit_size ::<32 >();
159- rhs_condition .assert_max_bit_size ::<32 >();
151+ let rhs = self .keys [found_index + 1 - found as u32 ];
152+ assert (lhs + 1 - found as u32 <= idx );
153+ assert (idx <= rhs + found as u32 - 1 );
160154
161155 // self.keys[i] maps to self.values[i+1]
162156 // however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
163- let value_index = (found_index + 1 ) * found ;
157+ let value_index = (found_index + 1 ) * found as u32 ;
164158 self .values [value_index ]
165159 }
166160}
@@ -179,7 +173,7 @@ mod test {
179173
180174 for i in 0 ..100 {
181175 if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
182- assert (example .get (i as Field ) == 0 );
176+ assert (example .get (i ) == 0 );
183177 }
184178 }
185179 }
@@ -188,34 +182,35 @@ mod test {
188182 fn test_sparse_lookup_boundary_cases () {
189183 // what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
190184 let example = SparseArray ::create (
191- [0 , 99999 , 7 , 0xffffffff ],
185+ [0 , 99999 , 7 , 0xfffffffe ],
192186 [123 , 101112 , 789 , 456 ],
193- 0x100000000 ,
187+ 0xffffffff ,
194188 );
195189
196190 assert (example .get (0 ) == 123 );
197191 assert (example .get (99999 ) == 101112 );
198192 assert (example .get (7 ) == 789 );
199- assert (example .get (0xffffffff ) == 456 );
200- assert (example .get (0xfffffffe ) == 0 );
193+ assert (example .get (0xfffffffe ) == 456 );
194+ assert (example .get (0xfffffffd ) == 0 );
201195 }
202196
203- #[test(should_fail_with = "call to assert_max_bit_size" )]
197+ #[test(should_fail )]
204198 fn test_sparse_lookup_overflow () {
205199 let example = SparseArray ::create ([1 , 5 , 7 , 99999 ], [123 , 456 , 789 , 101112 ], 100000 );
206200
207201 assert (example .get (100000 ) == 0 );
208202 }
209203
204+ /**
210205 #[test(should_fail_with = "call to assert_max_bit_size")]
211206 fn test_sparse_lookup_boundary_case_overflow() {
212207 let example =
213208 SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0x100000000);
214209
215210 assert(example.get(0x100000000) == 0);
216211 }
217-
218- #[test(should_fail_with = "call to assert_max_bit_size" )]
212+ **/
213+ #[test(should_fail )]
219214 fn test_sparse_lookup_key_exceeds_maximum () {
220215 let example =
221216 SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0xffffffff );
@@ -236,7 +231,7 @@ mod test {
236231
237232 for i in 0 ..100 {
238233 if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
239- assert (example .get (i as Field ) == 0 );
234+ assert (example .get (i ) == 0 );
240235 }
241236 }
242237 }
@@ -272,7 +267,7 @@ mod test {
272267 assert (example .get (99 ) == values [1 ]);
273268 for i in 0 ..100 {
274269 if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
275- assert (example .get (i as Field ) == F ::default ());
270+ assert (example .get (i ) == F ::default ());
276271 }
277272 }
278273 }
0 commit comments