Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cairo_air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod component_utils;
pub mod components;
pub mod preprocessed_columns;
pub mod privacy;
pub mod relations;
pub mod sample_evaluations;
pub mod statement;
pub mod utils;
Expand Down
3 changes: 3 additions & 0 deletions crates/cairo_air/src/relations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub const MEMORY_ADDRESS_TO_ID_RELATION_NAME: &str = "MemoryAddressToId";
pub const MEMORY_ID_TO_BIG_RELATION_NAME: &str = "MemoryIdToBig";
pub const OPCODES_RELATION_NAME: &str = "Opcodes";
88 changes: 68 additions & 20 deletions crates/cairo_air/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use circuits_stark_verifier::logup::logup_use_term;
use circuits_stark_verifier::proof_from_stark_proof::pack_into_qm31s;
use circuits_stark_verifier::verify::RELATION_USES_NUM_ROWS_SHIFT;
use itertools::{Itertools, chain, izip, zip_eq};
use num_traits::Zero;
use stwo::core::fields::qm31::QM31;
use stwo_cairo_common::builtins::{
BITWISE_BUILTIN_MEMORY_CELLS, PEDERSEN_BUILTIN_MEMORY_CELLS, POSEIDON_BUILTIN_MEMORY_CELLS,
Expand All @@ -25,6 +26,9 @@ use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId;

use crate::all_components::all_components;
use crate::preprocessed_columns::MAX_SEQUENCE_LOG_SIZE;
use crate::relations::{
MEMORY_ADDRESS_TO_ID_RELATION_NAME, MEMORY_ID_TO_BIG_RELATION_NAME, OPCODES_RELATION_NAME,
};
use circuits::context::{Context, Var};
use circuits::ivalue::{IValue, qm31_from_u32s};
use circuits::simd::Simd;
Expand All @@ -50,6 +54,44 @@ pub const PUBLIC_DATA_LEN: usize =
const LIMB_BITS: usize = 9;
const SMALL_VALUE_BITS: u32 = 27;

#[derive(Default)]
struct LogupSum {
var: Option<Var>,
use_counts: HashMap<String, u64>,
}

impl LogupSum {
pub fn add_use_term(
&mut self,
context: &mut Context<impl IValue>,
relation_name: &str,
term: Var,
) {
*self.use_counts.entry(relation_name.to_string()).or_insert(0) += 1;
self.var = self.var.map(|var| eval!(context, (var) + (term))).or(Some(term));
}

pub fn add_yield_term(&mut self, context: &mut Context<impl IValue>, term: Var) {
self.var = self.var.map(|var| eval!(context, (var) - (term))).or(Some(term));
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add_yield_term wrong sign when var is None

Medium Severity

add_yield_term uses .or(Some(term)) as the fallback when self.var is None, which sets the sum to +term instead of -term. A yield term represents a subtraction, so an empty LogupSum getting a yield term first would produce the wrong sign. Currently the only call site (line 607) always calls add_use_term first, so the None path isn't hit — but the method's contract is broken for that case, making it a latent correctness issue.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e82d50f. Configure here.


pub fn add_sum(&mut self, context: &mut Context<impl IValue>, sum: &LogupSum) {
for (relation_name, count) in sum.use_counts.iter() {
*self.use_counts.entry(relation_name.clone()).or_insert(0) += count;
}
let sum_var = sum.sum(context);
self.var = self.var.map(|var| eval!(context, (var) + (sum_var))).or(Some(sum_var));
}

pub fn sum(&self, context: &mut Context<impl IValue>) -> Var {
self.var.unwrap_or(context.constant(QM31::zero()))
}

pub fn use_counts(&self) -> &HashMap<String, u64> {
&self.use_counts
}
}

pub struct CasmState<T> {
pub pc: T,
pub ap: T,
Expand Down Expand Up @@ -337,7 +379,7 @@ impl<Value: IValue> Statement<Value> for CairoStatement<Value> {
&self,
context: &mut Context<Value>,
interaction_elements: [Var; 2],
) -> Var {
) -> (Var, HashMap<String, u64>) {
let program_as_constants = self
.program
.iter()
Expand Down Expand Up @@ -391,7 +433,7 @@ impl<Value: IValue> Statement<Value> for CairoStatement<Value> {
&self,
context: &mut Context<Value>,
component_sizes: &[Var],
shifted_relation_uses: &HashMap<&'static str, Var>,
shifted_relation_uses: &HashMap<String, Var>,
) {
let PublicData { initial_state, final_state, public_memory: _ } = &self.public_data;

Expand Down Expand Up @@ -448,15 +490,15 @@ impl<Value: IValue> Statement<Value> for CairoStatement<Value> {
}
}

pub fn segment_ranges_logup_sum(
fn segment_ranges_logup_sum(
context: &mut Context<impl IValue>,
interaction_elements: [Var; 2],
segment_ranges: &[SegmentRange<Var>; N_SEGMENTS],
mut argument_address: Var,
mut return_value_address: Var,
) -> Var {
) -> LogupSum {
let one = context.one();
let mut sum = context.zero();
let mut sum = LogupSum::default();
for (i, segment_range) in segment_ranges.iter().enumerate() {
if i != 0 {
argument_address = eval!(context, (argument_address) + (one));
Expand All @@ -471,7 +513,8 @@ pub fn segment_ranges_logup_sum(
segment_range.start.id,
&start_value_limbs,
);
sum = eval!(context, (sum) + (segment_start_logup_term));
sum.add_sum(context, &segment_start_logup_term);

let end_value_limbs = split_27bit_to_9bit_limbs(context, segment_range.end.value);
let segment_end_logup_term = public_memory_logup_terms(
context,
Expand All @@ -480,7 +523,7 @@ pub fn segment_ranges_logup_sum(
segment_range.end.id,
&end_value_limbs,
);
sum = eval!(context, (sum) + (segment_end_logup_term));
sum.add_sum(context, &segment_end_logup_term);
}

sum
Expand All @@ -494,31 +537,34 @@ fn public_memory_logup_terms<'a>(
address: Var,
id: Var,
value_limbs: impl IntoIterator<Item = &'a Var>,
) -> Var {
) -> LogupSum {
let mut sum = LogupSum::default();
let memory_address_to_id_relation_id =
context.constant(MEMORY_ADDRESS_TO_ID_RELATION_ID.into());
let address_to_id_logup_term = logup_use_term(
context,
&[memory_address_to_id_relation_id, address, id],
interaction_elements,
);
sum.add_use_term(context, MEMORY_ADDRESS_TO_ID_RELATION_NAME, address_to_id_logup_term);

let memory_id_to_big_relation_id = context.constant(MEMORY_ID_TO_BIG_RELATION_ID.into());
let elements =
chain!([memory_id_to_big_relation_id, id], value_limbs.into_iter().cloned()).collect_vec();
let id_to_value_logup_term = logup_use_term(context, &elements, interaction_elements);
eval!(context, (address_to_id_logup_term) + (id_to_value_logup_term))
sum.add_use_term(context, MEMORY_ID_TO_BIG_RELATION_NAME, id_to_value_logup_term);
sum
}

pub fn memory_segment_logup_sum(
fn memory_segment_logup_sum(
context: &mut Context<impl IValue>,
interaction_elements: [Var; 2],
start_address: Var,
ids: &[Var],
memory_values: &[[M31Wrapper<Var>; MEMORY_VALUES_LIMBS]],
) -> Var {
) -> LogupSum {
let one = context.one();
let mut sum = context.zero();
let mut sum = LogupSum::default();

let mut address = start_address;
for (i, (&id, value_limbs)) in zip_eq(ids, memory_values).enumerate() {
Expand All @@ -533,7 +579,7 @@ pub fn memory_segment_logup_sum(
id,
value_limbs.iter().map(|limb| limb.get()),
);
sum = eval!(context, (sum) + (logup_term));
sum.add_sum(context, &logup_term);
}

sum
Expand All @@ -545,7 +591,7 @@ pub fn public_logup_sum(
program: &[[M31Wrapper<Var>; MEMORY_VALUES_LIMBS]],
outputs: &[[M31Wrapper<Var>; MEMORY_VALUES_LIMBS]],
interaction_elements: [Var; 2],
) -> Var {
) -> (Var, HashMap<String, u64>) {
let PublicData {
initial_state,
final_state,
Expand All @@ -556,7 +602,9 @@ pub fn public_logup_sum(
let final_state_logup_term = public_data.final_state.logup_term(context, interaction_elements);
let initial_state_logup_term =
public_data.initial_state.logup_term(context, interaction_elements);
let mut sum = eval!(context, (final_state_logup_term) - (initial_state_logup_term));
let mut sum = LogupSum::default();
sum.add_use_term(context, OPCODES_RELATION_NAME, final_state_logup_term);
sum.add_yield_term(context, initial_state_logup_term);

let one = context.one();
let safe_call_addresses = vec![
Expand All @@ -575,7 +623,7 @@ pub fn public_logup_sum(
for (address, id, value_limbs) in izip!(safe_call_addresses, safe_call_ids, safe_call_values) {
let logup_term =
public_memory_logup_terms(context, interaction_elements, address, *id, value_limbs);
sum = eval!(context, (sum) + (logup_term));
sum.add_sum(context, &logup_term);
}

let argument_address = initial_ap;
Expand All @@ -588,7 +636,7 @@ pub fn public_logup_sum(
argument_address,
return_value_address,
);
sum = eval!(context, (sum) + (segment_ranges_logup_sum));
sum.add_sum(context, &segment_ranges_logup_sum);

let output_logup_sum = memory_segment_logup_sum(
context,
Expand All @@ -597,7 +645,7 @@ pub fn public_logup_sum(
output_ids,
outputs,
);
sum = eval!(context, (sum) + (output_logup_sum));
sum.add_sum(context, &output_logup_sum);

let program_logup_sum = memory_segment_logup_sum(
context,
Expand All @@ -606,7 +654,7 @@ pub fn public_logup_sum(
program_ids,
program,
);
sum = eval!(context, (sum) + (program_logup_sum));
sum.add_sum(context, &program_logup_sum);

sum
(sum.sum(context), sum.use_counts().clone())
}
2 changes: 1 addition & 1 deletion crates/cairo_air/src/statement_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ fn test_public_logup_sum() {
let public_data = PublicData { initial_state, final_state, public_memory };

// Call public_logup_sum
let result = public_logup_sum(
let (result, _use_counts) = public_logup_sum(
&mut context,
&public_data,
&program[..],
Expand Down
2 changes: 2 additions & 0 deletions crates/circuit_air/src/relations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ pub const VERIFY_BITWISE_XOR_8_RELATION_ID: u32 = 112558620;
pub const VERIFY_BITWISE_XOR_8_B_RELATION_ID: u32 = 521092554;
pub const VERIFY_BITWISE_XOR_9_RELATION_ID: u32 = 95781001;
pub const VERIFY_BITWISE_XOR_12_RELATION_ID: u32 = 648362599;

pub const GATE_RELATION_NAME: &str = "gate";
11 changes: 8 additions & 3 deletions crates/circuit_air/src/statement.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::blake2s_consts::blake2s_initial_state;
use std::collections::HashMap;

use crate::circuit_eval_components::{
blake_g, blake_gate, blake_output, blake_round, blake_round_sigma, range_check_15,
range_check_16, triple_xor_32, verify_bitwise_xor_4, verify_bitwise_xor_7,
verify_bitwise_xor_8, verify_bitwise_xor_9, verify_bitwise_xor_12,
};
use crate::components::{eq::CircuitEqComponent, qm31_ops::CircuitQm31OpsComponent};
use crate::relations::{BLAKE_STATE_RELATION_ID, GATE_RELATION_ID};
use crate::relations::{BLAKE_STATE_RELATION_ID, GATE_RELATION_ID, GATE_RELATION_NAME};
use circuits::blake::HashValue;
use circuits::context::{Context, Var};
use circuits::eval;
Expand Down Expand Up @@ -74,7 +76,8 @@ impl<Value: IValue> Statement<Value> for CircuitStatement<Value> {
&self,
context: &mut Context<Value>,
interaction_elements: [Var; 2],
) -> Var {
) -> (Var, HashMap<String, u64>) {
let mut use_counts = HashMap::new();
let mut sum = context.zero();

// Output gates public logup sum contribution.
Expand All @@ -98,6 +101,8 @@ impl<Value: IValue> Statement<Value> for CircuitStatement<Value> {
);
sum = eval!(context, (sum) + (term));
}
*use_counts.entry(GATE_RELATION_NAME.to_string()).or_insert(0) +=
self.output_addresses.len() as u64;

// Blake IV public logup sum contribution.
if self.n_blake_gates > 0 {
Expand All @@ -123,7 +128,7 @@ impl<Value: IValue> Statement<Value> for CircuitStatement<Value> {
sum = eval!(context, (sum) - (blake_iv_yield));
}

sum
(sum, use_counts)
}

fn get_preprocessed_column_ids(&self) -> Vec<PreProcessedColumnId> {
Expand Down
12 changes: 8 additions & 4 deletions crates/stark_verifier/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ pub trait Statement<Value: IValue> {
/// Returns the expected preprocessed trace root as circuit variables.
fn get_preprocessed_root(&self, context: &mut Context<Value>) -> HashValue<Var>;

/// Returns the part of the logup sum determined by the public statement.
fn public_logup_sum(&self, context: &mut Context<Value>, interaction_elements: [Var; 2])
-> Var;
/// Returns the part of the logup sum determined by the public statement, and
/// the number of uses it contains for each relation.
fn public_logup_sum(
&self,
context: &mut Context<Value>,
interaction_elements: [Var; 2],
) -> (Var, HashMap<String, u64>);

/// Returns statement-specific named parameters passed to component constraint evaluators.
fn public_params(&self, _context: &mut Context<Value>) -> HashMap<String, Var> {
Expand All @@ -60,7 +64,7 @@ pub trait Statement<Value: IValue> {
&self,
_context: &mut Context<Value>,
_component_sizes: &[Var],
_shifted_relation_uses: &HashMap<&'static str, Var>,
_shifted_relation_uses: &HashMap<String, Var>,
) {
}
}
Loading
Loading