diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index b172b538..84a1d307 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -126,13 +126,14 @@ fn compile_function( function: &SimpleFunction, compiler: &mut Compiler, ) -> Result, String> { - let mut internal_vars = find_internal_vars(&function.instructions); - - internal_vars.retain(|var| !function.arguments.contains(var)); + let internal_vars: InternalVars = + find_internal_vars(&function.instructions, &|var: &Var| -> bool { + function.arguments.contains(var) + }); // memory layout: pc, fp, args, return_vars, internal_vars let mut stack_pos = 2; // Reserve space for pc and fp - let mut var_positions = BTreeMap::new(); + let mut var_positions: BTreeMap = BTreeMap::new(); for (i, var) in function.arguments.iter().enumerate() { var_positions.insert(var.clone(), stack_pos + i); @@ -141,10 +142,7 @@ fn compile_function( stack_pos += function.n_returned_vars; - for (i, var) in internal_vars.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += internal_vars.len(); + stack_pos = layout_internal_vars(&internal_vars, &mut var_positions, stack_pos); compiler.func_name = function.name.clone(); compiler.var_positions = var_positions; @@ -161,6 +159,37 @@ fn compile_function( ) } +fn layout_internal_vars( + internal_vars: &InternalVars, + var_positions: &mut BTreeMap, + initial_stack_pos: usize, +) -> usize { + let mut stack_pos = initial_stack_pos; + match internal_vars { + InternalVars::One(var) => { + if !var_positions.contains_key(var) { + var_positions.insert(var.clone(), stack_pos); + stack_pos += 1; + } + } + InternalVars::AllOf(children) => { + for child in children { + stack_pos = layout_internal_vars(child, var_positions, stack_pos); + } + } + InternalVars::OneOf(children) => { + // TODO: this is wrong b/c it can result in the same stack pos + // being doubly assigned when a name is shared between branches? + let mut new_stack_poss: Vec = Vec::new(); + for child in children { + new_stack_poss.push(layout_internal_vars(child, var_positions, stack_pos)); + } + stack_pos = *new_stack_poss.iter().max().unwrap_or(&stack_pos); + } + } + stack_pos +} + fn compile_lines( function_name: &Label, lines: &[SimpleLine], @@ -795,50 +824,166 @@ fn compile_function_ret( }); } -fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { - let mut internal_vars = BTreeSet::new(); +enum InternalVars { + One(Var), + AllOf(Vec), + OneOf(Vec), +} + +fn find_internal_vars(lines: &[SimpleLine], exclude: &F) -> InternalVars +where + F: Fn(&Var) -> bool, +{ + let mut internal_vars: Vec = Vec::new(); + + // Scan outside of conditional statements first, so that any variables shared + // between branches of conditional statements and also outside of the conditional + // statement will be assigned a consistent stack location. for line in lines { match line { - SimpleLine::Match { arms, .. } => { - for arm in arms { - internal_vars.extend(find_internal_vars(arm)); + SimpleLine::Match { value, .. } => { + if let SimpleExpr::Var(var) = value && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); } } - SimpleLine::Assignment { var, .. } => { - if let VarOrConstMallocAccess::Var(var) = var { - internal_vars.insert(var.clone()); + SimpleLine::Assignment { var, arg0, arg1, .. } => { + if let VarOrConstMallocAccess::Var(var) = var + && !exclude(var) + { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = arg0 && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = arg1 && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + SimpleLine::TestZero { arg0, arg1, .. } => { + if let SimpleExpr::Var(var) = arg0 && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = arg1 && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + SimpleLine::HintMAlloc { var, size, vectorized_len, .. } => { + if !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = size && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = vectorized_len && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + SimpleLine::DecomposeBits { var, to_decompose, .. } + | SimpleLine::DecomposeCustom { var, to_decompose, .. } => { + if !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + for expr in to_decompose { + if let SimpleExpr::Var(var) = expr && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } } } - SimpleLine::TestZero { .. } => {} - SimpleLine::HintMAlloc { var, .. } - | SimpleLine::ConstMalloc { var, .. } - | SimpleLine::DecomposeBits { var, .. } - | SimpleLine::DecomposeCustom { var, .. } + SimpleLine::ConstMalloc { var, .. } | SimpleLine::CounterHint { var } => { - internal_vars.insert(var.clone()); + if !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } } - SimpleLine::RawAccess { res, .. } => { - if let SimpleExpr::Var(var) = res { - internal_vars.insert(var.clone()); + SimpleLine::RawAccess { res, index, .. } => { + if let SimpleExpr::Var(var) = res && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + if let SimpleExpr::Var(var) = index && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + SimpleLine::FunctionCall { return_data, args, .. } => { + internal_vars.extend( + return_data + .iter() + .filter(|&var| !exclude(var)) + .cloned() + .map(InternalVars::One), + ); + for arg in args { + if let SimpleExpr::Var(var) = arg && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + } + SimpleLine::FunctionRet { return_data } => { + for expr in return_data { + if let SimpleExpr::Var(var) = expr && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + } + SimpleLine::Precompile { args, .. } => { + for arg in args { + if let SimpleExpr::Var(var) = arg && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } + } + } + SimpleLine::Print { content, .. } => { + for expr in content { + if let SimpleExpr::Var(var) = expr && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } } } - SimpleLine::FunctionCall { return_data, .. } => { - internal_vars.extend(return_data.iter().cloned()); + SimpleLine::IfNotZero { condition, .. } => { + if let SimpleExpr::Var(var) = condition && !exclude(var) { + internal_vars.push(InternalVars::One(var.clone())); + } } + SimpleLine::Panic + | SimpleLine::LocationReport { .. } => {} + } + } + + // Having scanned outside of conditional statements, scan inside conditional statements. + for line in lines { + match line { SimpleLine::IfNotZero { then_branch, else_branch, .. } => { - internal_vars.extend(find_internal_vars(then_branch)); - internal_vars.extend(find_internal_vars(else_branch)); + internal_vars.push(InternalVars::OneOf(vec![ + find_internal_vars(then_branch, exclude), + find_internal_vars(else_branch, exclude), + ])); } - SimpleLine::Panic + SimpleLine::Match { arms, .. } => { + let mut branch_vars: Vec = Vec::new(); + for arm in arms { + branch_vars.push(find_internal_vars(arm, exclude)); + } + internal_vars.push(InternalVars::OneOf(branch_vars)); + } + SimpleLine::Assignment { .. } + | SimpleLine::TestZero { .. } + | SimpleLine::HintMAlloc { .. } + | SimpleLine::ConstMalloc { .. } + | SimpleLine::DecomposeBits { .. } + | SimpleLine::DecomposeCustom { .. } + | SimpleLine::CounterHint { .. } + | SimpleLine::RawAccess { .. } + | SimpleLine::FunctionCall { .. } + | SimpleLine::Panic | SimpleLine::Print { .. } | SimpleLine::FunctionRet { .. } | SimpleLine::Precompile { .. } | SimpleLine::LocationReport { .. } => {} } } - internal_vars + + InternalVars::AllOf(internal_vars) } diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs index 9733e38a..9b2e9f53 100644 --- a/crates/lean_prover/tests/test_zkvm.rs +++ b/crates/lean_prover/tests/test_zkvm.rs @@ -135,6 +135,6 @@ fn test_zk_vm_helper( ); let proof_time = time.elapsed(); verify_execution(&bytecode, public_input, proof_data, whir_config_builder()).unwrap(); - println!("{}", summary); + println!("{summary}"); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); }