Skip to content

Commit e8cf28a

Browse files
committed
Fix audit comments in check_relation_uses
1 parent 0f461f5 commit e8cf28a

3 files changed

Lines changed: 61 additions & 22 deletions

File tree

crates/cairo_air/src/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<Value: IValue> Statement<Value> for CairoStatement<Value> {
391391
&self,
392392
context: &mut Context<Value>,
393393
component_sizes: &[Var],
394-
shifted_relation_uses: &HashMap<&'static str, Var>,
394+
shifted_relation_uses: &HashMap<String, Var>,
395395
) {
396396
let PublicData { initial_state, final_state, public_memory: _ } = &self.public_data;
397397

crates/stark_verifier/src/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub trait Statement<Value: IValue> {
6060
&self,
6161
_context: &mut Context<Value>,
6262
_component_sizes: &[Var],
63-
_shifted_relation_uses: &HashMap<&'static str, Var>,
63+
_shifted_relation_uses: &HashMap<String, Var>,
6464
) {
6565
}
6666
}

crates/stark_verifier/src/verify.rs

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -254,36 +254,53 @@ pub fn verify<Value: IValue>(
254254
/// where the sum is over all the components that use the relation.
255255
///
256256
/// To avoid overflows when computing the sum, we check
257-
/// sum(uses_per_row * (floor(num_rows / DIV) + 1)) < floor(P / DIV)
257+
/// sum(uses_per_row * (floor(num_rows / DIV) + 1)) <= floor(P / DIV)
258258
/// where DIV = 2 ** RELATION_USES_NUM_ROWS_SHIFT
259259
fn check_relation_uses<Value: IValue>(
260260
context: &mut Context<impl IValue>,
261261
statement: &impl Statement<Value>,
262262
component_sizes_bits: &[Simd],
263-
) -> HashMap<&'static str, Var> {
263+
) -> HashMap<String, Var> {
264264
let components = statement.get_components();
265265

266-
// Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) cannot overflow even for the
267-
// maximal num_rows (num_rows = P).
266+
let component_size_upper_bound = 1u64 << component_sizes_bits.len();
267+
let shifted_component_size_upper_bound =
268+
(component_size_upper_bound >> RELATION_USES_NUM_ROWS_SHIFT) + 1;
269+
let shifted_use_count_upper_bound = P >> 1;
270+
271+
// Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) < shifted_use_count_upper_bound
272+
// even for the maximal num_rows (num_rows = component_size_upper_bound). This fact is used
273+
// later in this function when comparing the sum to floor(P / DIV).
268274
// This is a sanity check that `RELATION_USES_NUM_ROWS_SHIFT` is large enough for the given
269275
// statement, it does not depend on the specific assignment.
270276
let mut max_shifted_uses_per_relation = HashMap::<&str, u64>::new();
271277
for component in components.iter() {
272278
for relation_use in component.relation_uses_per_row() {
273279
let entry = max_shifted_uses_per_relation.entry(relation_use.relation_id).or_insert(0);
274-
*entry += relation_use.uses * (((P >> RELATION_USES_NUM_ROWS_SHIFT) + 1) as u64);
280+
*entry = entry
281+
.checked_add(relation_use.uses * shifted_component_size_upper_bound)
282+
.expect("Shifted num rows upper bound computation overflowed");
275283
}
276284
}
277-
assert!(max_shifted_uses_per_relation.values().all(|count| *count < (P as u64)));
285+
assert!(
286+
max_shifted_uses_per_relation
287+
.values()
288+
.all(|count| *count < shifted_use_count_upper_bound.into())
289+
);
278290

279-
// Compute floor(num_rows / DIV) for all components
280-
let shifted_component_sizes = match component_sizes_bits.get(RELATION_USES_NUM_ROWS_SHIFT..) {
281-
Some(high_bits) => Simd::combine_bits(context, high_bits),
282-
None => Simd::zero(context, components.len()),
291+
// Compute floor(num_rows / DIV) + 1 for all components
292+
let shifted_component_sizes_p1 = match component_sizes_bits.get(RELATION_USES_NUM_ROWS_SHIFT..)
293+
{
294+
Some(high_bits) => {
295+
let one = Simd::one(context, components.len());
296+
let shifted_component_sizes = Simd::combine_bits(context, high_bits);
297+
Simd::add(context, &shifted_component_sizes, &one)
298+
}
299+
None => Simd::one(context, components.len()),
283300
};
284301
// A variable in the Simd vector might be unused in the case where all the corresponding
285302
// components don't use any relations.
286-
Simd::mark_partly_used(context, &shifted_component_sizes);
303+
Simd::mark_partly_used(context, &shifted_component_sizes_p1);
287304

288305
// Sum uses_per_row * (floor(num_rows / DIV) + 1) for all relations
289306
let mut shifted_relation_uses = HashMap::new();
@@ -292,25 +309,47 @@ fn check_relation_uses<Value: IValue>(
292309
if relation_uses.is_empty() {
293310
continue;
294311
}
295-
let shifted_size = Simd::unpack_idx(context, &shifted_component_sizes, i);
312+
let shifted_size_p1 = Simd::unpack_idx(context, &shifted_component_sizes_p1, i);
296313
for relation_use in relation_uses {
297-
let entry =
298-
shifted_relation_uses.entry(relation_use.relation_id).or_insert(context.zero());
299-
let uses_per_row =
300-
context.constant(TryInto::<u32>::try_into(relation_use.uses).unwrap().into());
301-
*entry = eval!(context, (*entry) + (((shifted_size) + (1)) * (uses_per_row)));
314+
let uses_per_row = context.constant(u32::try_from(relation_use.uses).unwrap().into());
315+
316+
let shifted_uses_upper_bound = eval!(context, (shifted_size_p1) * (uses_per_row));
317+
318+
shifted_relation_uses
319+
.entry(relation_use.relation_id.to_string())
320+
.and_modify(|entry| {
321+
*entry = eval!(context, (*entry) + (shifted_uses_upper_bound));
322+
})
323+
.or_insert(shifted_uses_upper_bound);
302324
}
303325
}
304326

305-
// Verify that the sum is less than floor(P / DIV) by expressing it as a
306-
// floor(log2(P / DIV))-bit number
307327
let shifted_use_counts = shifted_relation_uses
308328
.iter()
309329
.sorted_by_key(|(k, _v)| *k)
310330
.map(|(_k, v)| M31Wrapper::new_unsafe(*v))
311331
.collect_vec();
312332
let shifted_use_counts = Simd::pack(context, &shifted_use_counts);
313-
extract_bits(context, &shifted_use_counts, (P >> RELATION_USES_NUM_ROWS_SHIFT).ilog2());
333+
334+
// Verify that the sum is at most floor(P / DIV) by checking that floor(P / DIV) - sum is
335+
// positive or zero.
336+
let shifted_max_allowed_use_counts = P >> RELATION_USES_NUM_ROWS_SHIFT;
337+
let shifted_max_allowed_use_counts_simd =
338+
Simd::repeat(context, shifted_max_allowed_use_counts.into(), shifted_use_counts.len());
339+
let diff = Simd::sub(context, &shifted_max_allowed_use_counts_simd, &shifted_use_counts);
340+
341+
// If the difference is positive, it will fit in this many bits.
342+
let positive_diff_bits = shifted_max_allowed_use_counts.ilog2() + 1;
343+
344+
// Make sure that if the difference is negative, it won't fit in positive_diff_bits bits. Use
345+
// the check that sum < shifted_use_count_upper_bound from above.
346+
assert!(
347+
P + shifted_max_allowed_use_counts - shifted_use_count_upper_bound
348+
> (1 << positive_diff_bits)
349+
);
350+
351+
// Verify that the diff fits in positive_diff_bits bits.
352+
extract_bits(context, &diff, positive_diff_bits);
314353
shifted_relation_uses
315354
}
316355

0 commit comments

Comments
 (0)