@@ -254,36 +254,53 @@ pub fn verify<Value: IValue>(
254254/// where the sum is over all the components that use the relation.
255255///
256256/// To avoid overflows when computing the sum, we check
257- /// sum(uses_per_row * (floor(num_rows / DIV) + 1)) < floor(P / DIV)
257+ /// sum(uses_per_row * (floor(num_rows / DIV) + 1)) <= floor(P / DIV)
258258/// where DIV = 2 ** RELATION_USES_NUM_ROWS_SHIFT
259259fn check_relation_uses < Value : IValue > (
260260 context : & mut Context < impl IValue > ,
261261 statement : & impl Statement < Value > ,
262262 component_sizes_bits : & [ Simd ] ,
263- ) -> HashMap < & ' static str , Var > {
263+ ) -> HashMap < String , Var > {
264264 let components = statement. get_components ( ) ;
265265
266- // Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) cannot overflow even for the
267- // maximal num_rows (num_rows = P).
266+ let component_size_upper_bound = 1u64 << component_sizes_bits. len ( ) ;
267+ let shifted_component_size_upper_bound =
268+ ( component_size_upper_bound >> RELATION_USES_NUM_ROWS_SHIFT ) + 1 ;
269+ let shifted_use_count_upper_bound = P >> 1 ;
270+
271+ // Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) < shifted_use_count_upper_bound
272+ // even for the maximal num_rows (num_rows = component_size_upper_bound). This fact is used
273+ // later in this function when comparing the sum to floor(P / DIV).
268274 // This is a sanity check that `RELATION_USES_NUM_ROWS_SHIFT` is large enough for the given
269275 // statement, it does not depend on the specific assignment.
270276 let mut max_shifted_uses_per_relation = HashMap :: < & str , u64 > :: new ( ) ;
271277 for component in components. iter ( ) {
272278 for relation_use in component. relation_uses_per_row ( ) {
273279 let entry = max_shifted_uses_per_relation. entry ( relation_use. relation_id ) . or_insert ( 0 ) ;
274- * entry += relation_use. uses * ( ( ( P >> RELATION_USES_NUM_ROWS_SHIFT ) + 1 ) as u64 ) ;
280+ * entry = entry
281+ . checked_add ( relation_use. uses * shifted_component_size_upper_bound)
282+ . expect ( "Shifted num rows upper bound computation overflowed" ) ;
275283 }
276284 }
277- assert ! ( max_shifted_uses_per_relation. values( ) . all( |count| * count < ( P as u64 ) ) ) ;
285+ assert ! (
286+ max_shifted_uses_per_relation
287+ . values( )
288+ . all( |count| * count < shifted_use_count_upper_bound. into( ) )
289+ ) ;
278290
279- // Compute floor(num_rows / DIV) for all components
280- let shifted_component_sizes = match component_sizes_bits. get ( RELATION_USES_NUM_ROWS_SHIFT ..) {
281- Some ( high_bits) => Simd :: combine_bits ( context, high_bits) ,
282- None => Simd :: zero ( context, components. len ( ) ) ,
291+ // Compute floor(num_rows / DIV) + 1 for all components
292+ let shifted_component_sizes_p1 = match component_sizes_bits. get ( RELATION_USES_NUM_ROWS_SHIFT ..)
293+ {
294+ Some ( high_bits) => {
295+ let one = Simd :: one ( context, components. len ( ) ) ;
296+ let shifted_component_sizes = Simd :: combine_bits ( context, high_bits) ;
297+ Simd :: add ( context, & shifted_component_sizes, & one)
298+ }
299+ None => Simd :: one ( context, components. len ( ) ) ,
283300 } ;
284301 // A variable in the Simd vector might be unused in the case where all the corresponding
285302 // components don't use any relations.
286- Simd :: mark_partly_used ( context, & shifted_component_sizes ) ;
303+ Simd :: mark_partly_used ( context, & shifted_component_sizes_p1 ) ;
287304
288305 // Sum uses_per_row * (floor(num_rows / DIV) + 1) for all relations
289306 let mut shifted_relation_uses = HashMap :: new ( ) ;
@@ -292,25 +309,47 @@ fn check_relation_uses<Value: IValue>(
292309 if relation_uses. is_empty ( ) {
293310 continue ;
294311 }
295- let shifted_size = Simd :: unpack_idx ( context, & shifted_component_sizes , i) ;
312+ let shifted_size_p1 = Simd :: unpack_idx ( context, & shifted_component_sizes_p1 , i) ;
296313 for relation_use in relation_uses {
297- let entry =
298- shifted_relation_uses. entry ( relation_use. relation_id ) . or_insert ( context. zero ( ) ) ;
299- let uses_per_row =
300- context. constant ( TryInto :: < u32 > :: try_into ( relation_use. uses ) . unwrap ( ) . into ( ) ) ;
301- * entry = eval ! ( context, ( * entry) + ( ( ( shifted_size) + ( 1 ) ) * ( uses_per_row) ) ) ;
314+ let uses_per_row = context. constant ( u32:: try_from ( relation_use. uses ) . unwrap ( ) . into ( ) ) ;
315+
316+ let shifted_uses_upper_bound = eval ! ( context, ( shifted_size_p1) * ( uses_per_row) ) ;
317+
318+ shifted_relation_uses
319+ . entry ( relation_use. relation_id . to_string ( ) )
320+ . and_modify ( |entry| {
321+ * entry = eval ! ( context, ( * entry) + ( shifted_uses_upper_bound) ) ;
322+ } )
323+ . or_insert ( shifted_uses_upper_bound) ;
302324 }
303325 }
304326
305- // Verify that the sum is less than floor(P / DIV) by expressing it as a
306- // floor(log2(P / DIV))-bit number
307327 let shifted_use_counts = shifted_relation_uses
308328 . iter ( )
309329 . sorted_by_key ( |( k, _v) | * k)
310330 . map ( |( _k, v) | M31Wrapper :: new_unsafe ( * v) )
311331 . collect_vec ( ) ;
312332 let shifted_use_counts = Simd :: pack ( context, & shifted_use_counts) ;
313- extract_bits ( context, & shifted_use_counts, ( P >> RELATION_USES_NUM_ROWS_SHIFT ) . ilog2 ( ) ) ;
333+
334+ // Verify that the sum is at most floor(P / DIV) by checking that floor(P / DIV) - sum is
335+ // positive or zero.
336+ let shifted_max_allowed_use_counts = P >> RELATION_USES_NUM_ROWS_SHIFT ;
337+ let shifted_max_allowed_use_counts_simd =
338+ Simd :: repeat ( context, shifted_max_allowed_use_counts. into ( ) , shifted_use_counts. len ( ) ) ;
339+ let diff = Simd :: sub ( context, & shifted_max_allowed_use_counts_simd, & shifted_use_counts) ;
340+
341+ // If the difference is positive, it will fit in this many bits.
342+ let positive_diff_bits = shifted_max_allowed_use_counts. ilog2 ( ) + 1 ;
343+
344+ // Make sure that if the difference is negative, it won't fit in positive_diff_bits bits. Use
345+ // the check that sum < shifted_use_count_upper_bound from above.
346+ assert ! (
347+ P + shifted_max_allowed_use_counts - shifted_use_count_upper_bound
348+ > ( 1 << positive_diff_bits)
349+ ) ;
350+
351+ // Verify that the diff fits in positive_diff_bits bits.
352+ extract_bits ( context, & diff, positive_diff_bits) ;
314353 shifted_relation_uses
315354}
316355
0 commit comments