diff --git a/crates/cairo_air/src/lib.rs b/crates/cairo_air/src/lib.rs index a032a1cb..b8243960 100644 --- a/crates/cairo_air/src/lib.rs +++ b/crates/cairo_air/src/lib.rs @@ -3,6 +3,7 @@ pub mod component_utils; pub mod components; pub mod preprocessed_columns; pub mod privacy; +pub mod relations; pub mod sample_evaluations; pub mod statement; pub mod utils; diff --git a/crates/cairo_air/src/relations.rs b/crates/cairo_air/src/relations.rs new file mode 100644 index 00000000..a21fd584 --- /dev/null +++ b/crates/cairo_air/src/relations.rs @@ -0,0 +1,3 @@ +pub const MEMORY_ADDRESS_TO_ID_RELATION_NAME: &str = "MemoryAddressToId"; +pub const MEMORY_ID_TO_BIG_RELATION_NAME: &str = "MemoryIdToBig"; +pub const OPCODES_RELATION_NAME: &str = "Opcodes"; diff --git a/crates/cairo_air/src/statement.rs b/crates/cairo_air/src/statement.rs index 20a5c8cb..6013495c 100644 --- a/crates/cairo_air/src/statement.rs +++ b/crates/cairo_air/src/statement.rs @@ -16,6 +16,7 @@ use circuits_stark_verifier::logup::logup_use_term; use circuits_stark_verifier::proof_from_stark_proof::pack_into_qm31s; use circuits_stark_verifier::verify::RELATION_USES_NUM_ROWS_SHIFT; use itertools::{Itertools, chain, izip, zip_eq}; +use num_traits::Zero; use stwo::core::fields::qm31::QM31; use stwo_cairo_common::builtins::{ BITWISE_BUILTIN_MEMORY_CELLS, PEDERSEN_BUILTIN_MEMORY_CELLS, POSEIDON_BUILTIN_MEMORY_CELLS, @@ -25,6 +26,9 @@ use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId; use crate::all_components::all_components; use crate::preprocessed_columns::MAX_SEQUENCE_LOG_SIZE; +use crate::relations::{ + MEMORY_ADDRESS_TO_ID_RELATION_NAME, MEMORY_ID_TO_BIG_RELATION_NAME, OPCODES_RELATION_NAME, +}; use circuits::context::{Context, Var}; use circuits::ivalue::{IValue, qm31_from_u32s}; use circuits::simd::Simd; @@ -50,6 +54,44 @@ pub const PUBLIC_DATA_LEN: usize = const LIMB_BITS: usize = 9; const SMALL_VALUE_BITS: u32 = 27; +#[derive(Default)] +struct LogupSum { + var: Option, + use_counts: HashMap, +} + +impl LogupSum { + pub fn add_use_term( + &mut self, + context: &mut Context, + relation_name: &str, + term: Var, + ) { + *self.use_counts.entry(relation_name.to_string()).or_insert(0) += 1; + self.var = self.var.map(|var| eval!(context, (var) + (term))).or(Some(term)); + } + + pub fn add_yield_term(&mut self, context: &mut Context, term: Var) { + self.var = self.var.map(|var| eval!(context, (var) - (term))).or(Some(term)); + } + + pub fn add_sum(&mut self, context: &mut Context, sum: &LogupSum) { + for (relation_name, count) in sum.use_counts.iter() { + *self.use_counts.entry(relation_name.clone()).or_insert(0) += count; + } + let sum_var = sum.sum(context); + self.var = self.var.map(|var| eval!(context, (var) + (sum_var))).or(Some(sum_var)); + } + + pub fn sum(&self, context: &mut Context) -> Var { + self.var.unwrap_or(context.constant(QM31::zero())) + } + + pub fn use_counts(&self) -> &HashMap { + &self.use_counts + } +} + pub struct CasmState { pub pc: T, pub ap: T, @@ -337,7 +379,7 @@ impl Statement for CairoStatement { &self, context: &mut Context, interaction_elements: [Var; 2], - ) -> Var { + ) -> (Var, HashMap) { let program_as_constants = self .program .iter() @@ -391,7 +433,7 @@ impl Statement for CairoStatement { &self, context: &mut Context, component_sizes: &[Var], - shifted_relation_uses: &HashMap<&'static str, Var>, + shifted_relation_uses: &HashMap, ) { let PublicData { initial_state, final_state, public_memory: _ } = &self.public_data; @@ -448,15 +490,15 @@ impl Statement for CairoStatement { } } -pub fn segment_ranges_logup_sum( +fn segment_ranges_logup_sum( context: &mut Context, interaction_elements: [Var; 2], segment_ranges: &[SegmentRange; N_SEGMENTS], mut argument_address: Var, mut return_value_address: Var, -) -> Var { +) -> LogupSum { let one = context.one(); - let mut sum = context.zero(); + let mut sum = LogupSum::default(); for (i, segment_range) in segment_ranges.iter().enumerate() { if i != 0 { argument_address = eval!(context, (argument_address) + (one)); @@ -471,7 +513,8 @@ pub fn segment_ranges_logup_sum( segment_range.start.id, &start_value_limbs, ); - sum = eval!(context, (sum) + (segment_start_logup_term)); + sum.add_sum(context, &segment_start_logup_term); + let end_value_limbs = split_27bit_to_9bit_limbs(context, segment_range.end.value); let segment_end_logup_term = public_memory_logup_terms( context, @@ -480,7 +523,7 @@ pub fn segment_ranges_logup_sum( segment_range.end.id, &end_value_limbs, ); - sum = eval!(context, (sum) + (segment_end_logup_term)); + sum.add_sum(context, &segment_end_logup_term); } sum @@ -494,7 +537,8 @@ fn public_memory_logup_terms<'a>( address: Var, id: Var, value_limbs: impl IntoIterator, -) -> Var { +) -> LogupSum { + let mut sum = LogupSum::default(); let memory_address_to_id_relation_id = context.constant(MEMORY_ADDRESS_TO_ID_RELATION_ID.into()); let address_to_id_logup_term = logup_use_term( @@ -502,23 +546,25 @@ fn public_memory_logup_terms<'a>( &[memory_address_to_id_relation_id, address, id], interaction_elements, ); + sum.add_use_term(context, MEMORY_ADDRESS_TO_ID_RELATION_NAME, address_to_id_logup_term); let memory_id_to_big_relation_id = context.constant(MEMORY_ID_TO_BIG_RELATION_ID.into()); let elements = chain!([memory_id_to_big_relation_id, id], value_limbs.into_iter().cloned()).collect_vec(); let id_to_value_logup_term = logup_use_term(context, &elements, interaction_elements); - eval!(context, (address_to_id_logup_term) + (id_to_value_logup_term)) + sum.add_use_term(context, MEMORY_ID_TO_BIG_RELATION_NAME, id_to_value_logup_term); + sum } -pub fn memory_segment_logup_sum( +fn memory_segment_logup_sum( context: &mut Context, interaction_elements: [Var; 2], start_address: Var, ids: &[Var], memory_values: &[[M31Wrapper; MEMORY_VALUES_LIMBS]], -) -> Var { +) -> LogupSum { let one = context.one(); - let mut sum = context.zero(); + let mut sum = LogupSum::default(); let mut address = start_address; for (i, (&id, value_limbs)) in zip_eq(ids, memory_values).enumerate() { @@ -533,7 +579,7 @@ pub fn memory_segment_logup_sum( id, value_limbs.iter().map(|limb| limb.get()), ); - sum = eval!(context, (sum) + (logup_term)); + sum.add_sum(context, &logup_term); } sum @@ -545,7 +591,7 @@ pub fn public_logup_sum( program: &[[M31Wrapper; MEMORY_VALUES_LIMBS]], outputs: &[[M31Wrapper; MEMORY_VALUES_LIMBS]], interaction_elements: [Var; 2], -) -> Var { +) -> (Var, HashMap) { let PublicData { initial_state, final_state, @@ -556,7 +602,9 @@ pub fn public_logup_sum( let final_state_logup_term = public_data.final_state.logup_term(context, interaction_elements); let initial_state_logup_term = public_data.initial_state.logup_term(context, interaction_elements); - let mut sum = eval!(context, (final_state_logup_term) - (initial_state_logup_term)); + let mut sum = LogupSum::default(); + sum.add_use_term(context, OPCODES_RELATION_NAME, final_state_logup_term); + sum.add_yield_term(context, initial_state_logup_term); let one = context.one(); let safe_call_addresses = vec![ @@ -575,7 +623,7 @@ pub fn public_logup_sum( for (address, id, value_limbs) in izip!(safe_call_addresses, safe_call_ids, safe_call_values) { let logup_term = public_memory_logup_terms(context, interaction_elements, address, *id, value_limbs); - sum = eval!(context, (sum) + (logup_term)); + sum.add_sum(context, &logup_term); } let argument_address = initial_ap; @@ -588,7 +636,7 @@ pub fn public_logup_sum( argument_address, return_value_address, ); - sum = eval!(context, (sum) + (segment_ranges_logup_sum)); + sum.add_sum(context, &segment_ranges_logup_sum); let output_logup_sum = memory_segment_logup_sum( context, @@ -597,7 +645,7 @@ pub fn public_logup_sum( output_ids, outputs, ); - sum = eval!(context, (sum) + (output_logup_sum)); + sum.add_sum(context, &output_logup_sum); let program_logup_sum = memory_segment_logup_sum( context, @@ -606,7 +654,7 @@ pub fn public_logup_sum( program_ids, program, ); - sum = eval!(context, (sum) + (program_logup_sum)); + sum.add_sum(context, &program_logup_sum); - sum + (sum.sum(context), sum.use_counts().clone()) } diff --git a/crates/cairo_air/src/statement_test.rs b/crates/cairo_air/src/statement_test.rs index 87cce950..cc370202 100644 --- a/crates/cairo_air/src/statement_test.rs +++ b/crates/cairo_air/src/statement_test.rs @@ -244,7 +244,7 @@ fn test_public_logup_sum() { let public_data = PublicData { initial_state, final_state, public_memory }; // Call public_logup_sum - let result = public_logup_sum( + let (result, _use_counts) = public_logup_sum( &mut context, &public_data, &program[..], diff --git a/crates/circuit_air/src/relations.rs b/crates/circuit_air/src/relations.rs index 95ffdec1..4ed78283 100644 --- a/crates/circuit_air/src/relations.rs +++ b/crates/circuit_air/src/relations.rs @@ -22,3 +22,5 @@ pub const VERIFY_BITWISE_XOR_8_RELATION_ID: u32 = 112558620; pub const VERIFY_BITWISE_XOR_8_B_RELATION_ID: u32 = 521092554; pub const VERIFY_BITWISE_XOR_9_RELATION_ID: u32 = 95781001; pub const VERIFY_BITWISE_XOR_12_RELATION_ID: u32 = 648362599; + +pub const GATE_RELATION_NAME: &str = "gate"; diff --git a/crates/circuit_air/src/statement.rs b/crates/circuit_air/src/statement.rs index 071d91c4..79c2f7bc 100644 --- a/crates/circuit_air/src/statement.rs +++ b/crates/circuit_air/src/statement.rs @@ -1,11 +1,13 @@ use crate::blake2s_consts::blake2s_initial_state; +use std::collections::HashMap; + use crate::circuit_eval_components::{ blake_g, blake_gate, blake_output, blake_round, blake_round_sigma, range_check_15, range_check_16, triple_xor_32, verify_bitwise_xor_4, verify_bitwise_xor_7, verify_bitwise_xor_8, verify_bitwise_xor_9, verify_bitwise_xor_12, }; use crate::components::{eq::CircuitEqComponent, qm31_ops::CircuitQm31OpsComponent}; -use crate::relations::{BLAKE_STATE_RELATION_ID, GATE_RELATION_ID}; +use crate::relations::{BLAKE_STATE_RELATION_ID, GATE_RELATION_ID, GATE_RELATION_NAME}; use circuits::blake::HashValue; use circuits::context::{Context, Var}; use circuits::eval; @@ -74,7 +76,8 @@ impl Statement for CircuitStatement { &self, context: &mut Context, interaction_elements: [Var; 2], - ) -> Var { + ) -> (Var, HashMap) { + let mut use_counts = HashMap::new(); let mut sum = context.zero(); // Output gates public logup sum contribution. @@ -98,6 +101,8 @@ impl Statement for CircuitStatement { ); sum = eval!(context, (sum) + (term)); } + *use_counts.entry(GATE_RELATION_NAME.to_string()).or_insert(0) += + self.output_addresses.len() as u64; // Blake IV public logup sum contribution. if self.n_blake_gates > 0 { @@ -123,7 +128,7 @@ impl Statement for CircuitStatement { sum = eval!(context, (sum) - (blake_iv_yield)); } - sum + (sum, use_counts) } fn get_preprocessed_column_ids(&self) -> Vec { diff --git a/crates/stark_verifier/src/statement.rs b/crates/stark_verifier/src/statement.rs index 83f93647..70bd4ee0 100644 --- a/crates/stark_verifier/src/statement.rs +++ b/crates/stark_verifier/src/statement.rs @@ -46,9 +46,13 @@ pub trait Statement { /// Returns the expected preprocessed trace root as circuit variables. fn get_preprocessed_root(&self, context: &mut Context) -> HashValue; - /// Returns the part of the logup sum determined by the public statement. - fn public_logup_sum(&self, context: &mut Context, interaction_elements: [Var; 2]) - -> Var; + /// Returns the part of the logup sum determined by the public statement, and + /// the number of uses it contains for each relation. + fn public_logup_sum( + &self, + context: &mut Context, + interaction_elements: [Var; 2], + ) -> (Var, HashMap); /// Returns statement-specific named parameters passed to component constraint evaluators. fn public_params(&self, _context: &mut Context) -> HashMap { @@ -60,7 +64,7 @@ pub trait Statement { &self, _context: &mut Context, _component_sizes: &[Var], - _shifted_relation_uses: &HashMap<&'static str, Var>, + _shifted_relation_uses: &HashMap, ) { } } diff --git a/crates/stark_verifier/src/verify.rs b/crates/stark_verifier/src/verify.rs index b16257c8..0d22109a 100644 --- a/crates/stark_verifier/src/verify.rs +++ b/crates/stark_verifier/src/verify.rs @@ -113,7 +113,8 @@ pub fn verify( // Pick the interaction elements. let [interaction_z, interaction_alpha] = channel.draw_two_qm31s(context); - let public_logup_sum = statement.public_logup_sum(context, [interaction_z, interaction_alpha]); + let (public_logup_sum, public_logup_uses) = + statement.public_logup_sum(context, [interaction_z, interaction_alpha]); validate_logup_sum(context, &enable_bits, public_logup_sum, &proof.claim.claimed_sums); channel.mix_qm31s(context, proof.claim.claimed_sums.iter().cloned()); @@ -127,7 +128,8 @@ pub fn verify( // Draw a random point for the OODS. let oods_point = channel.draw_point(context); - let shifted_relation_uses = check_relation_uses(context, statement, &component_sizes_bits); + let shifted_relation_uses = + check_relation_uses(context, statement, &public_logup_uses, &component_sizes_bits); let unpacked_component_sizes = Simd::unpack(context, &component_sizes); statement.verify_claim(context, &unpacked_component_sizes, &shifted_relation_uses); @@ -259,22 +261,38 @@ pub fn verify( fn check_relation_uses( context: &mut Context, statement: &impl Statement, + extra_uses: &HashMap, component_sizes_bits: &[Simd], -) -> HashMap<&'static str, Var> { +) -> HashMap { let components = statement.get_components(); - // Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) cannot overflow even for the - // maximal num_rows (num_rows = P). + let component_size_upper_bound = 1u64 << component_sizes_bits.len(); + let shifted_component_size_upper_bound = + (component_size_upper_bound >> RELATION_USES_NUM_ROWS_SHIFT) + 1; + let shifted_use_count_upper_bound = P >> 1; + + // Check that sum(uses_per_row * (floor(num_rows / DIV) + 1)) < shifted_use_count_upper_bound + // even for the maximal num_rows (num_rows = component_size_upper_bound). This fact is used + // later in this function when comparing the sum to floor(P / DIV). // This is a sanity check that `RELATION_USES_NUM_ROWS_SHIFT` is large enough for the given // statement, it does not depend on the specific assignment. let mut max_shifted_uses_per_relation = HashMap::<&str, u64>::new(); for component in components.iter() { for relation_use in component.relation_uses_per_row() { let entry = max_shifted_uses_per_relation.entry(relation_use.relation_id).or_insert(0); - *entry += relation_use.uses * (((P >> RELATION_USES_NUM_ROWS_SHIFT) + 1) as u64); + *entry = + entry.checked_add(relation_use.uses * shifted_component_size_upper_bound).unwrap(); } } - assert!(max_shifted_uses_per_relation.values().all(|count| *count < (P as u64))); + for (relation, num_uses) in extra_uses { + let entry = max_shifted_uses_per_relation.entry(relation).or_insert(0); + *entry = entry.checked_add((num_uses >> RELATION_USES_NUM_ROWS_SHIFT) + 1).unwrap(); + } + assert!( + max_shifted_uses_per_relation + .values() + .all(|count| *count < shifted_use_count_upper_bound.into()) + ); // Compute floor(num_rows / DIV) for all components let shifted_component_sizes = match component_sizes_bits.get(RELATION_USES_NUM_ROWS_SHIFT..) { @@ -285,8 +303,16 @@ fn check_relation_uses( // components don't use any relations. Simd::mark_partly_used(context, &shifted_component_sizes); - // Sum uses_per_row * (floor(num_rows / DIV) + 1) for all relations + // Sum uses_per_row * (floor(num_rows / DIV) + 1) + extra_uses for all relations let mut shifted_relation_uses = HashMap::new(); + for (relation, num_uses) in extra_uses { + shifted_relation_uses.insert( + relation.to_string(), + context.constant( + ((u32::try_from(*num_uses).unwrap() >> RELATION_USES_NUM_ROWS_SHIFT) + 1).into(), + ), + ); + } for (i, component) in components.iter().enumerate() { let relation_uses = component.relation_uses_per_row(); if relation_uses.is_empty() { @@ -294,23 +320,44 @@ fn check_relation_uses( } let shifted_size = Simd::unpack_idx(context, &shifted_component_sizes, i); for relation_use in relation_uses { - let entry = - shifted_relation_uses.entry(relation_use.relation_id).or_insert(context.zero()); - let uses_per_row = - context.constant(TryInto::::try_into(relation_use.uses).unwrap().into()); - *entry = eval!(context, (*entry) + (((shifted_size) + (1)) * (uses_per_row))); + let uses_per_row = context.constant(u32::try_from(relation_use.uses).unwrap().into()); + + let shifted_uses_upper_bound = eval!(context, ((shifted_size) + (1)) * (uses_per_row)); + + shifted_relation_uses + .entry(relation_use.relation_id.to_string()) + .and_modify(|entry| { + *entry = eval!(context, (*entry) + (shifted_uses_upper_bound)); + }) + .or_insert(shifted_uses_upper_bound); } } - // Verify that the sum is less than floor(P / DIV) by expressing it as a - // floor(log2(P / DIV))-bit number let shifted_use_counts = shifted_relation_uses .iter() .sorted_by_key(|(k, _v)| *k) .map(|(_k, v)| M31Wrapper::new_unsafe(*v)) .collect_vec(); let shifted_use_counts = Simd::pack(context, &shifted_use_counts); - extract_bits(context, &shifted_use_counts, (P >> RELATION_USES_NUM_ROWS_SHIFT).ilog2()); + + // Verify that the sum is less than floor(P / DIV) by checking that floor(P / DIV) - 1 - sum is + // positive or zero. + let shifted_max_allowed_use_counts = Simd::repeat( + context, + ((P >> RELATION_USES_NUM_ROWS_SHIFT) - 1).into(), + shifted_use_counts.len(), + ); + let diff = Simd::sub(context, &shifted_max_allowed_use_counts, &shifted_use_counts); + + // If the difference is positive, it will fit in this many bits. + let positive_diff_bits = (P >> RELATION_USES_NUM_ROWS_SHIFT).ilog2() + 1; + + // Make sure that if the difference is negative, it won't fit in positive_diff_bits bits. Use + // the check that sum < shifted_use_count_upper_bound from above. + assert!(P - shifted_use_count_upper_bound > (1 << positive_diff_bits)); + + // Verify that the diff fits in positive_diff_bits bits. + extract_bits(context, &diff, positive_diff_bits); shifted_relation_uses } diff --git a/crates/stark_verifier_examples/src/simple_statement.rs b/crates/stark_verifier_examples/src/simple_statement.rs index 01372e7c..7b662633 100644 --- a/crates/stark_verifier_examples/src/simple_statement.rs +++ b/crates/stark_verifier_examples/src/simple_statement.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use circuits::blake::HashValue; use itertools::zip_eq; use num_traits::One; @@ -134,7 +136,7 @@ impl Statement for SimpleStatement { &self, context: &mut Context, interaction_elements: [Var; 2], - ) -> Var { + ) -> (Var, HashMap) { let mut sum = context.zero(); for (component, log_n_instances) in zip_eq(&self.components, &COMPONENT_LOG_SIZES) { @@ -145,7 +147,8 @@ impl Statement for SimpleStatement { squared_fibonacci_public_logup_sum(context, interaction_elements, *log_n_instances); sum = eval!(context, (sum) + (fib_logup_sum)); } - sum + // The sum contains only yields - no uses + (sum, HashMap::new()) } fn get_preprocessed_column_ids(&self) -> Vec {