diff --git a/Cargo.lock b/Cargo.lock index 769fc395..6be023f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -384,6 +384,7 @@ dependencies = [ "poseidon_circuit", "rand", "sub_protocols", + "thiserror", "tracing", "utils", "vm_air", diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 5ea1e3be..f5e07e4f 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -156,6 +156,10 @@ pub enum SimpleLine { LocationReport { location: SourceLineNumber, }, + RangeCheck { + value: SimpleExpr, + max: ConstExpression, + }, } pub fn simplify_program(mut program: Program) -> SimpleProgram { @@ -748,6 +752,14 @@ fn simplify_lines( location: *location, }); } + Line::RangeCheck { value, max } => { + let simplified_value = + simplify_expr(value, &mut res, counters, array_manager, const_malloc); + res.push(SimpleLine::RangeCheck { + value: simplified_value, + max: max.clone(), + }); + } } } @@ -970,7 +982,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { on_new_expr(index, &internal_vars, &mut external_vars); on_new_expr(value, &internal_vars, &mut external_vars); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {} } } @@ -1141,7 +1153,7 @@ pub fn inline_lines( inline_expr(index, args, inlining_count); inline_expr(value, args, inlining_count); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {} } } for (i, new_lines) in lines_to_replace.into_iter().rev() { @@ -1635,7 +1647,7 @@ fn replace_vars_for_unroll( Line::CounterHint { var } => { *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); } - Line::Break | Line::Panic | Line::LocationReport { .. } => {} + Line::Break | Line::Panic | Line::LocationReport { .. } | Line::RangeCheck { .. } => {} } } } @@ -2027,7 +2039,8 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec) { | Line::MAlloc { .. } | Line::Panic | Line::Break - | Line::LocationReport { .. } => {} + | Line::LocationReport { .. } + | Line::RangeCheck { .. } => {} } } } @@ -2136,7 +2149,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(size, map); } - Line::Panic | Line::Break | Line::LocationReport { .. } => {} + Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {} } } } @@ -2333,6 +2346,9 @@ impl SimpleLine { } Self::Panic => "panic".to_string(), Self::LocationReport { .. } => Default::default(), + Self::RangeCheck { value, max } => { + format!("range_check({value}, {max})") + } }; format!("{spaces}{line_str}") } diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index b172b538..d4e803fa 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -419,6 +419,7 @@ fn compile_lines( shift_0, shift_1: shift.clone(), res: res.to_mem_after_fp_or_constant(compiler), + for_range_check: false, }); } @@ -457,6 +458,7 @@ fn compile_lines( res: IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: compiler.get_offset(&ret_var.clone().into()), }, + for_range_check: false, }); } @@ -638,6 +640,72 @@ fn compile_lines( location: *location, }); } + SimpleLine::RangeCheck { value, max } => { + // x is the fp offset of the memory cell which contains the value + // i.e. m[fp + x] contains value + let x = match IntermediateValue::from_simple_expr(value, compiler) { + IntermediateValue::MemoryAfterFp { offset } => offset.naive_eval().unwrap(), + value::IntermediateValue::Fp => F::ZERO, + value::IntermediateValue::Constant(_) => unimplemented!(), + }; + + let t = max.naive_eval().unwrap(); + let aux_i = compiler.stack_size; + let aux_j = compiler.stack_size + 1; + let aux_k = compiler.stack_size + 2; + + // Step 1: DEREF: m[fp + i] == m[m[fp + x]] + // DEREF: m[fp + i] == m[value] + + let step_1 = IntermediateInstruction::Deref { + shift_0: ConstExpression::scalar(x.to_usize()), + shift_1: ConstExpression::from(0), + res: IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: aux_i.into(), + }, + for_range_check: true, + }; + + // Step 2: ADD: m[fp + x] + m[fp + j] == (t-1) + // m[fp + j] == t - 1 - m[fp + x] + // Uses constraint solving to store t - 1 - m[fp + x] in m[fp + j] + let step_2 = IntermediateInstruction::Computation { + operation: Operation::Add, + arg_a: IntermediateValue::MemoryAfterFp { + offset: x.to_usize().into(), + }, + arg_c: IntermediateValue::MemoryAfterFp { + offset: aux_j.into(), + }, + res: IntermediateValue::Constant((t - F::ONE).to_usize().into()), + }; + + // Step 3: DEREF: m[fp + k] == m[m[fp + j]] + let step_3 = IntermediateInstruction::Deref { + shift_0: ConstExpression::scalar(aux_j), + shift_1: ConstExpression::from(0), + res: IntermediaryMemOrFpOrConstant::MemoryAfterFp { + offset: aux_k.into(), + }, + for_range_check: true, + }; + + // Insert the instructions + instructions.extend_from_slice(&[ + // This is just the RangeCheck hint which does nothing + IntermediateInstruction::RangeCheck { + value: IntermediateValue::from_simple_expr(value, compiler), + max: max.clone(), + }, + // These are the steps that effectuate the range check + step_1, + step_2, + step_3, + ]); + + // Increase the stack size by 3 as we used 3 aux variables + compiler.stack_size += 3; + } } } @@ -715,11 +783,13 @@ fn setup_function_call( res: IntermediaryMemOrFpOrConstant::Constant(ConstExpression::label( return_label.clone(), )), + for_range_check: false, }, IntermediateInstruction::Deref { shift_0: new_fp_pos.into(), shift_1: ConstExpression::one(), res: IntermediaryMemOrFpOrConstant::Fp, + for_range_check: false, }, ]; @@ -729,6 +799,7 @@ fn setup_function_call( shift_0: new_fp_pos.into(), shift_1: (2 + i).into(), res: arg.to_mem_after_fp_or_constant(compiler), + for_range_check: false, }); } @@ -837,7 +908,8 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { | SimpleLine::Print { .. } | SimpleLine::FunctionRet { .. } | SimpleLine::Precompile { .. } - | SimpleLine::LocationReport { .. } => {} + | SimpleLine::LocationReport { .. } + | SimpleLine::RangeCheck { .. } => {} } } internal_vars diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 1dcc450c..ba5117c3 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -13,7 +13,8 @@ impl IntermediateInstruction { | Self::DecomposeCustom { .. } | Self::CounterHint { .. } | Self::Inverse { .. } - | Self::LocationReport { .. } => true, + | Self::LocationReport { .. } + | Self::RangeCheck { .. } => true, Self::Computation { .. } | Self::Panic | Self::Deref { .. } @@ -264,6 +265,7 @@ fn compile_block( shift_0, shift_1, res, + for_range_check, } => { low_level_bytecode.push(Instruction::Deref { shift_0: eval_const_expression(&shift_0, compiler).to_usize(), @@ -279,6 +281,7 @@ fn compile_block( MemOrFpOrConstant::Constant(eval_const_expression(&c, compiler)) } }, + for_range_check, }); } IntermediateInstruction::JumpIfNotZero { @@ -407,8 +410,15 @@ fn compile_block( let hint = Hint::LocationReport { location }; hints.entry(pc).or_default().push(hint); } + IntermediateInstruction::RangeCheck { value, max } => { + let hint = Hint::RangeCheck { + value: value.try_into_mem_or_fp(compiler).unwrap(), + // TODO: support max being an IntermediateValue + max: MemOrConstant::Constant(eval_const_expression(&max, compiler)), + }; + hints.entry(pc).or_default().push(hint); + } } - if !instruction.is_hint() { pc += 1; } diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index d6fc7d1a..61282407 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -24,9 +24,10 @@ statement = { return_statement | break_statement | continue_statement | - function_call | + range_check_statement | assert_eq_statement | - assert_not_eq_statement + assert_not_eq_statement | + function_call // Placed at the end so that it doesn't override other statements like range_check and assert_eq } return_statement = { "return" ~ (tuple_expression)? ~ ";" } @@ -61,6 +62,8 @@ var_list = { identifier ~ ("," ~ identifier)* } assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" } assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" } +range_check_statement = { "range_check" ~ "(" ~ expression ~ "," ~ expression ~ ")" ~ ";" } + // Expressions tuple_expression = { expression ~ ("," ~ expression)* } expression = { neq_expr } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index 40e73770..a1922112 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -17,6 +17,7 @@ pub enum IntermediateInstruction { shift_0: ConstExpression, shift_1: ConstExpression, res: IntermediaryMemOrFpOrConstant, + for_range_check: bool, }, // res = m[m[fp + shift_0]] Panic, Jump { @@ -86,6 +87,10 @@ pub enum IntermediateInstruction { LocationReport { location: SourceLineNumber, }, + RangeCheck { + value: IntermediateValue, + max: ConstExpression, + }, } impl IntermediateInstruction { @@ -152,7 +157,11 @@ impl Display for IntermediateInstruction { shift_0, shift_1, res, - } => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"), + for_range_check, + } => write!( + f, + "{res} = m[m[fp + {shift_0}] + {shift_1}] for_range_check: {for_range_check}" + ), Self::Panic => write!(f, "panic"), Self::Jump { dest, updated_fp } => { if let Some(fp) = updated_fp { @@ -256,6 +265,9 @@ impl Display for IntermediateInstruction { Ok(()) } Self::LocationReport { .. } => Ok(()), + Self::RangeCheck { value, max } => { + write!(f, "range_check({value}, {max})") + } } } } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 5690b8db..f5bac024 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -393,6 +393,10 @@ pub enum Line { LocationReport { location: SourceLineNumber, }, + RangeCheck { + value: Expression, + max: ConstExpression, + }, } impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -591,6 +595,7 @@ impl Line { } Self::Break => "break".to_string(), Self::Panic => "panic".to_string(), + Self::RangeCheck { value, max } => format!("range_check({value}, {max})"), }; format!("{spaces}{line_str}") } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index de055693..fdf1d2d1 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -4,7 +4,7 @@ use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ ir::HighLevelOperation, - lang::{AssumeBoolean, Boolean, Condition, Expression, Line}, + lang::{AssumeBoolean, Boolean, Condition, ConstExpression, Expression, Line}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -28,6 +28,7 @@ impl Parse for StatementParser { Rule::function_call => FunctionCallParser::parse(inner, ctx), Rule::assert_eq_statement => AssertEqParser::parse(inner, ctx), Rule::assert_not_eq_statement => AssertNotEqParser::parse(inner, ctx), + Rule::range_check_statement => RangeCheckParser::parse(inner, ctx), Rule::break_statement => Ok(Line::Break), Rule::continue_statement => { Err(SemanticError::new("Continue statement not implemented yet").into()) @@ -327,3 +328,22 @@ impl Parse for AssertNotEqParser { )) } } + +/// Parser for range check statements. +pub struct RangeCheckParser; + +impl Parse for RangeCheckParser { + fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let mut inner = pair.into_inner(); + let value = + ExpressionParser::parse(next_inner_pair(&mut inner, "range check value")?, ctx)?; + let max_expr = + ExpressionParser::parse(next_inner_pair(&mut inner, "range check max")?, ctx)?; + + // Convert the max expression to a const expression + let max = ConstExpression::try_from(max_expr) + .map_err(|_| SemanticError::new("Range check maximum must be a constant expression"))?; + + Ok(Line::RangeCheck { value, max }) + } +} diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 86e56670..bec66b9d 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -4,6 +4,45 @@ use utils::{poseidon16_permute, poseidon24_permute}; const DEFAULT_NO_VEC_RUNTIME_MEMORY: usize = 1 << 15; +// TODO: create more test programs +fn range_check_program(value: usize, max: usize) -> String { + let program = format!( + r#" + //fn func(val) {{ + //if 0 == 0 {{ + //range_check(val, {max}); + //}} + //abc = 0; + //range_check(abc, {max}); + //return; + //}} + + fn main() {{ + val = {value}; + //func(val); + range_check(val, {max}); + //range_check(val, {max}); + return; + }} + "# + ); + program.to_string() +} + +#[test] +fn test_compile_range_check() { + let program = range_check_program(1000, 100000); + let bytecode = compile_program(program.clone()); + println!("{}", bytecode); + compile_and_run( + program.to_string(), + (&[], &[]), + DEFAULT_NO_VEC_RUNTIME_MEMORY, + false, + ); + //TODO test more +} + #[test] #[should_panic] fn test_duplicate_function_name() { diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index d8189e2a..a5e35d44 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -29,6 +29,7 @@ witness_generation.workspace = true vm_air.workspace = true multilinear-toolkit.workspace = true poseidon_circuit.workspace = true +thiserror.workspace = true [dev-dependencies] -xmss.workspace = true \ No newline at end of file +xmss.workspace = true diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 3b1b7acf..6c0e8281 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -52,6 +52,15 @@ pub fn prove_execution( (poseidons_16_precomputed, poseidons_24_precomputed), ) }); + + // Fill undefined memory cells used for range checks + for (dest_ptr, src_ptr) in &execution_result.range_check_cells_to_fill { + if execution_result.memory.get(*dest_ptr).is_err() { + let value = execution_result.memory.get(*src_ptr).unwrap_or(F::ZERO); + execution_result.memory.set(*dest_ptr, value).unwrap(); + } + } + exec_summary = std::mem::take(&mut execution_result.summary); info_span!("Building execution trace") .in_scope(|| get_execution_trace(bytecode, execution_result)) diff --git a/crates/lean_prover/tests/test_range_check.rs b/crates/lean_prover/tests/test_range_check.rs new file mode 100644 index 00000000..8716e6f4 --- /dev/null +++ b/crates/lean_prover/tests/test_range_check.rs @@ -0,0 +1,160 @@ +use lean_compiler::compile_program; +use lean_prover::verify_execution::verify_execution; +use lean_prover::{prove_execution::prove_execution, whir_config_builder}; +use lean_vm::{DIMENSION, F, NONRESERVED_PROGRAM_INPUT_START}; +use multilinear_toolkit::prelude::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::collections::BTreeSet; +use whir_p3::WhirConfigBuilder; + +const NO_VEC_RUNTIME_MEMORY: usize = 1 << 20; + +fn range_check_program(value: usize, max: usize) -> String { + let program = format!( + r#" + fn func() {{ + x = 1; + y = {value}; + value = x * y; + range_check(value, {max}); + return; + }} + + fn main() {{ + x = 1; + y = {value}; + value = x * y; + range_check(value, {max}); + + func(); + + if 0 == 0 {{ + a = 1; + b = {value}; + c = a * b; + range_check(c, {max}); + }} + return; + }} + "# + ); + program.to_string() +} + +fn random_test_cases(num_test_cases: usize, valid: bool) -> BTreeSet<(usize, usize)> { + let t_max = 1 << 16; + let mut rng = StdRng::seed_from_u64(0); + + let mut test_cases = BTreeSet::new(); + + while test_cases.len() < num_test_cases { + let t = rng.random_range(1..t_max); + let v = if valid { + rng.random_range(0..t) + } else { + rng.random_range(t..1 << 31) + }; + + test_cases.insert((v, t)); + } + + test_cases +} + +fn prepare_inputs() -> (Vec, Vec) { + const SECOND_POINT: usize = 2; + const SECOND_N_VARS: usize = 7; + + let mut public_input = (0..(1 << 13) - NONRESERVED_PROGRAM_INPUT_START) + .map(F::from_usize) + .collect::>(); + + public_input[SECOND_POINT * (SECOND_N_VARS * DIMENSION).next_power_of_two() + + SECOND_N_VARS * DIMENSION + - NONRESERVED_PROGRAM_INPUT_START + ..(SECOND_POINT + 1) * (SECOND_N_VARS * DIMENSION).next_power_of_two() + - NONRESERVED_PROGRAM_INPUT_START] + .iter_mut() + .for_each(|x| *x = F::ZERO); + + let private_input = (0..1 << 13) + .map(|i| F::from_usize(i).square()) + .collect::>(); + + (public_input, private_input) +} + +fn do_test_range_check( + v: usize, + t: usize, + whir_config_builder: &WhirConfigBuilder, + public_input: &[F], + private_input: &[F], +) { + let program_str = range_check_program(v, t); + + let bytecode = compile_program(program_str); + + let (proof_data, _, _) = prove_execution( + &bytecode, + (public_input, private_input), + whir_config_builder.clone(), + NO_VEC_RUNTIME_MEMORY, + false, + (&vec![], &vec![]), + ); + verify_execution( + &bytecode, + public_input, + proof_data, + whir_config_builder.clone(), + ) + .unwrap(); +} + +#[test] +fn test_range_check_valid() { + test_range_check_random(100, true); +} + +#[test] +#[should_panic] +fn test_range_check_invalid() { + test_range_check_random(1, false); +} + +fn test_range_check_random(num_test_cases: usize, valid: bool) { + let (public_input, private_input) = prepare_inputs(); + let whir_config_builder = whir_config_builder(); + + let test_cases = random_test_cases(num_test_cases, valid); + + println!("Running {} random test cases", test_cases.len()); + + for (v, t) in test_cases { + println!("v: {v}; t: {t}"); + do_test_range_check(v, t, &whir_config_builder, &public_input, &private_input); + } +} + +#[test] +fn test_range_check_valid_1() { + let (public_input, private_input) = prepare_inputs(); + let whir_config_builder = whir_config_builder(); + do_test_range_check( + 3716, + 20122, + &whir_config_builder, + &public_input, + &private_input, + ); +} + +#[test] +#[should_panic] +fn test_range_check_invalid_1() { + let (public_input, private_input) = prepare_inputs(); + let whir_config_builder = whir_config_builder(); + do_test_range_check(1, 0, &whir_config_builder, &public_input, &private_input); +} diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index f130bf55..5c7b1fa7 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -49,6 +49,7 @@ pub fn get_execution_trace( .zip(execution_result.fps.par_iter()) .for_each(|((trace_row, &pc), &fp)| { let instruction = &bytecode.instructions[pc]; + //println!("instruction: {}", instruction); let field_repr = field_representation(instruction); let mut addr_a = F::ZERO; @@ -56,13 +57,13 @@ pub fn get_execution_trace( // flag_a == 0 addr_a = F::from_usize(fp) + field_repr[0]; // fp + operand_a } - let value_a = memory.0[addr_a.to_usize()].unwrap(); + let value_a = memory.get(addr_a.to_usize()).unwrap_or(F::ZERO); let mut addr_b = F::ZERO; if field_repr[4].is_zero() { // flag_b == 0 addr_b = F::from_usize(fp) + field_repr[1]; // fp + operand_b } - let value_b = memory.0[addr_b.to_usize()].unwrap(); + let value_b = memory.get(addr_b.to_usize()).unwrap_or(F::ZERO); let mut addr_c = F::ZERO; if field_repr[5].is_zero() { @@ -73,7 +74,7 @@ pub fn get_execution_trace( assert_eq!(field_repr[2], operand_c); // debug purpose addr_c = value_a + operand_c; } - let value_c = memory.0[addr_c.to_usize()].unwrap(); + let value_c = memory.get(addr_c.to_usize()).unwrap_or(F::ZERO); for (j, field) in field_repr.iter().enumerate() { *trace_row[j] = *field; diff --git a/crates/lean_prover/witness_generation/src/instruction_encoder.rs b/crates/lean_prover/witness_generation/src/instruction_encoder.rs index dca85d9c..de8588c5 100644 --- a/crates/lean_prover/witness_generation/src/instruction_encoder.rs +++ b/crates/lean_prover/witness_generation/src/instruction_encoder.rs @@ -28,6 +28,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { shift_0, shift_1, res, + for_range_check: _, } => { fields[COL_INDEX_DEREF] = F::ONE; fields[COL_INDEX_FLAG_A] = F::ZERO; diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index cc8ed616..5519d53b 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -4,6 +4,7 @@ use crate::execution::Memory; use crate::witness::{ WitnessDotProduct, WitnessMultilinearEval, WitnessPoseidon16, WitnessPoseidon24, }; +use std::collections::BTreeSet; use thiserror::Error; #[derive(Debug, Clone, Error)] @@ -48,4 +49,5 @@ pub struct ExecutionResult { pub multilinear_evals: Vec, pub summary: String, pub memory_profile: Option, + pub range_check_cells_to_fill: Vec<(usize, usize)>, } diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index 41898992..d0848c63 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -214,6 +214,7 @@ fn execute_bytecode_helper( let mut counter_hint = 0; let mut cpu_cycles_before_new_line = 0; + let mut range_check_cells_to_fill = Vec::new(); while pc != bytecode.ending_pc { if pc >= bytecode.instructions.len() { @@ -264,6 +265,7 @@ fn execute_bytecode_helper( poseidon24_precomputed: poseidons_24_precomputed, n_poseidon16_precomputed_used: &mut n_poseidon16_precomputed_used, n_poseidon24_precomputed_used: &mut n_poseidon24_precomputed_used, + range_check_cells_to_fill: &mut range_check_cells_to_fill, }; instruction.execute_instruction(&mut instruction_ctx)?; } @@ -426,5 +428,6 @@ fn execute_bytecode_helper( multilinear_evals, summary, memory_profile: if profiling { Some(mem_profile) } else { None }, + range_check_cells_to_fill, }) } diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 68cccfc0..96d186ba 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -1,7 +1,7 @@ use crate::core::{F, LOG_VECTOR_LEN, Label, SourceLineNumber, VECTOR_LEN}; use crate::diagnostics::{MemoryObject, MemoryObjectType, MemoryProfile, RunnerError}; use crate::execution::{ExecutionHistory, Memory}; -use crate::isa::operands::MemOrConstant; +use crate::isa::operands::{MemOrConstant, MemOrFp}; use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; use utils::{ToUsize, pretty_integer}; @@ -67,6 +67,10 @@ pub enum Hint { }, /// Jump destination label (for debugging purposes) Label { label: Label }, + + /// Range check + RangeCheck { value: MemOrFp, max: MemOrConstant }, + /// Stack frame size (for memory profiling) StackFrame { label: Label, size: usize }, } @@ -227,6 +231,7 @@ impl Hint { *ctx.cpu_cycles_before_new_line = 0; } Self::Label { .. } => {} + Self::RangeCheck { .. } => {} Self::StackFrame { label, size } => { if ctx.profiling { ctx.memory_profile.objects.insert( @@ -313,6 +318,9 @@ impl Display for Hint { Self::Label { label } => { write!(f, "label: {label}") } + Self::RangeCheck { value, max } => { + write!(f, "range_check({value}, {max})") + } Self::StackFrame { label, size } => { write!(f, "stack frame for {label} size {size}") } diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 548a17df..b0352b0a 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -38,6 +38,7 @@ pub enum Instruction { shift_1: usize, /// Result destination (can be memory, frame pointer, or constant) res: MemOrFpOrConstant, + for_range_check: bool, }, /// Conditional jump instruction for control flow @@ -118,6 +119,7 @@ pub struct InstructionContext<'a> { pub poseidon24_precomputed: &'a [([F; 24], [F; 8])], pub n_poseidon16_precomputed_used: &'a mut usize, pub n_poseidon24_precomputed_used: &'a mut usize, + pub range_check_cells_to_fill: &'a mut Vec<(usize, usize)>, } impl Instruction { @@ -175,12 +177,29 @@ impl Instruction { shift_0, shift_1, res, + for_range_check, } => { if res.is_value_unknown(ctx.memory, *ctx.fp) { let memory_address_res = res.memory_address(*ctx.fp)?; let ptr = ctx.memory.get(*ctx.fp + shift_0)?; - let value = ctx.memory.get(ptr.to_usize() + shift_1)?; - ctx.memory.set(memory_address_res, value)?; + + if *for_range_check { + let ptr_usize = ptr.to_usize(); + let value = ctx.memory.get(ptr_usize + shift_1); + if let Ok(value) = value { + ctx.memory.set(memory_address_res, value)?; + } else { + // Ignore the UndefinedMemory error from get(). Also, indicate/"hint" + // to the prover that it needs to be filled later on with either 0 or + // some other value which a later instruction will set + ctx.range_check_cells_to_fill + .push((memory_address_res, ptr_usize)); + } + } else { + // For non-range check derefs, allow the error to bubble up + let value = ctx.memory.get(ptr.to_usize() + shift_1)?; + ctx.memory.set(memory_address_res, value)?; + } } else { let value = res.read_value(ctx.memory, *ctx.fp)?; let ptr = ctx.memory.get(*ctx.fp + shift_0)?; @@ -199,6 +218,7 @@ impl Instruction { } => { let condition_value = condition.read_value(ctx.memory, *ctx.fp)?; assert!([F::ZERO, F::ONE].contains(&condition_value),); + if condition_value == F::ZERO { *ctx.pc += 1; } else { @@ -423,8 +443,12 @@ impl Display for Instruction { shift_0, shift_1, res, + for_range_check, } => { - write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]") + write!( + f, + "{res} = m[m[fp + {shift_0}] + {shift_1}] for_range_check: {for_range_check}" + ) } Self::DotProduct { arg0, diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index ec443243..db3042d6 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -96,7 +96,7 @@ pub fn transpose( width: usize, column_extra_capacity: usize, ) -> Vec> { - assert!((matrix.len().is_multiple_of(width))); + assert!(matrix.len().is_multiple_of(width)); let height = matrix.len() / width; let res = vec![ {