diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 50f95b59..4c86348d 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2,8 +2,8 @@ use crate::{ Counter, F, ir::HighLevelOperation, lang::{ - AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Expression, Function, - Line, Program, SimpleExpr, Var, + AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context, Expression, + Function, Line, Program, Scope, SimpleExpr, Var, }, }; use lean_vm::{SourceLineNumber, Table, TableT}; @@ -72,6 +72,9 @@ pub enum SimpleLine { value: SimpleExpr, arms: Vec>, // patterns = 0, 1, ... }, + ForwardDeclaration { + var: Var, + }, Assignment { var: VarOrConstMallocAccess, operation: HighLevelOperation, @@ -148,6 +151,7 @@ pub enum SimpleLine { } pub fn simplify_program(mut program: Program) -> SimpleProgram { + check_program_scoping(&program); handle_inlined_functions(&mut program); handle_const_arguments(&mut program); let mut new_functions = BTreeMap::new(); @@ -187,6 +191,218 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { } } +/// Analyzes the program to verify that each variable is defined in each context where it is used. +fn check_program_scoping(program: &Program) { + for (_, function) in program.functions.iter() { + let mut scope = Scope { vars: BTreeSet::new() }; + for (arg, _) in function.arguments.iter() { + scope.vars.insert(arg.clone()); + } + let mut ctx = Context { scopes: vec![scope] }; + + check_block_scoping(&function.body, &mut ctx); + } +} + +/// Analyzes the block to verify that each variable is defined in each context where it is used. 
+///
+/// Nested blocks (`match` arms, `if`/`else` branches, `for` bodies) are checked inside a fresh
+/// [`Scope`] that is popped again afterwards, so names declared there are not recorded in the
+/// enclosing scope. The scope of a `for` body starts out containing the loop iterator.
+///
+/// # Panics
+///
+/// Panics if a line declares a variable that is already present in the innermost scope, or if
+/// `ctx.scopes` is empty when a declaration is processed. The `check_*_scoping` helpers called
+/// on expressions, booleans and conditions panic on variables that `ctx` does not define.
+fn check_block_scoping(block: &[Line], ctx: &mut Context) {
+    for line in block {
+        match line {
+            Line::ForwardDeclaration { var } => declare_in_innermost_scope(ctx, var),
+            Line::Match { value, arms } => {
+                check_expr_scoping(value, ctx);
+                for (_, arm) in arms {
+                    // Each arm is checked in its own scope.
+                    ctx.scopes.push(Scope { vars: BTreeSet::new() });
+                    check_block_scoping(arm, ctx);
+                    ctx.scopes.pop();
+                }
+            }
+            Line::Assignment { var, value } => {
+                // The right-hand side is checked before `var` is declared, so `x = x + 1`
+                // is only accepted when an enclosing scope already defines `x`.
+                check_expr_scoping(value, ctx);
+                declare_in_innermost_scope(ctx, var);
+            }
+            Line::ArrayAssign { array, index, value } => {
+                check_simple_expr_scoping(array, ctx);
+                check_expr_scoping(index, ctx);
+                check_expr_scoping(value, ctx);
+            }
+            Line::Assert(boolean, _) => check_boolean_scoping(boolean, ctx),
+            Line::IfCondition {
+                condition,
+                then_branch,
+                else_branch,
+                line_number: _,
+            } => {
+                check_condition_scoping(condition, ctx);
+                for branch in [then_branch, else_branch] {
+                    // Both branches get independent scopes.
+                    ctx.scopes.push(Scope { vars: BTreeSet::new() });
+                    check_block_scoping(branch, ctx);
+                    ctx.scopes.pop();
+                }
+            }
+            Line::ForLoop {
+                iterator,
+                start,
+                end,
+                body,
+                rev: _,
+                unroll: _,
+                line_number: _,
+            } => {
+                // The bounds are evaluated outside the loop scope; only the body sees the
+                // iterator.
+                check_expr_scoping(start, ctx);
+                check_expr_scoping(end, ctx);
+                ctx.scopes.push(Scope {
+                    vars: BTreeSet::from([iterator.clone()]),
+                });
+                check_block_scoping(body, ctx);
+                ctx.scopes.pop();
+            }
+            Line::FunctionCall {
+                function_name: _,
+                args,
+                return_data,
+                line_number: _,
+            } => {
+                for arg in args {
+                    check_expr_scoping(arg, ctx);
+                }
+                // Every returned value declares a new variable in the current scope.
+                for var in return_data {
+                    declare_in_innermost_scope(ctx, var);
+                }
+            }
+            Line::FunctionRet { return_data } => {
+                for expr in return_data {
+                    check_expr_scoping(expr, ctx);
+                }
+            }
+            Line::Precompile { table: _, args } => {
+                for arg in args {
+                    check_expr_scoping(arg, ctx);
+                }
+            }
+            Line::Break | Line::Panic | Line::LocationReport { .. } => {}
+            Line::Print { line_info: _, content } => {
+                for expr in content {
+                    check_expr_scoping(expr, ctx);
+                }
+            }
+            Line::MAlloc {
+                var,
+                size,
+                vectorized: _,
+                vectorized_len,
+            } => {
+                check_expr_scoping(size, ctx);
+                check_expr_scoping(vectorized_len, ctx);
+                declare_in_innermost_scope(ctx, var);
+            }
+            Line::DecomposeBits { var, to_decompose } => {
+                for expr in to_decompose {
+                    check_expr_scoping(expr, ctx);
+                }
+                declare_in_innermost_scope(ctx, var);
+            }
+            Line::DecomposeCustom { args } => {
+                for arg in args {
+                    check_expr_scoping(arg, ctx);
+                }
+            }
+            Line::PrivateInputStart { result } => declare_in_innermost_scope(ctx, result),
+        }
+    }
+}
+
+/// Records `var` as declared in the innermost scope of `ctx`.
+///
+/// # Panics
+///
+/// Panics if `ctx.scopes` is empty, or if `var` is already present in the innermost scope.
+fn declare_in_innermost_scope(ctx: &mut Context, var: &Var) {
+    let innermost = ctx
+        .scopes
+        .last_mut()
+        .expect("scope stack must not be empty while checking a block");
+    // `insert` returns `false` when the value was already present, which is exactly the
+    // redeclaration case.
+    assert!(
+        innermost.vars.insert(var.clone()),
+        "Variable declared multiple times in the same scope: {var}",
+    );
+}
+
+/// Analyzes the expression to verify that each variable is defined in the given context.
+fn check_expr_scoping(expr: &Expression, ctx: &Context) { + match expr { + Expression::Value(simple_expr) => { + check_simple_expr_scoping(simple_expr, ctx); + } + Expression::ArrayAccess { array, index } => { + check_simple_expr_scoping(array, ctx); + check_expr_scoping(index, ctx); + } + Expression::Binary { + left, + operation: _, + right, + } => { + check_expr_scoping(left, ctx); + check_expr_scoping(right, ctx); + } + Expression::Log2Ceil { value } => { + check_expr_scoping(value, ctx); + } + } +} + +/// Analyzes the simple expression to verify that each variable is defined in the given context. +fn check_simple_expr_scoping(expr: &SimpleExpr, ctx: &Context) { + match expr { + SimpleExpr::Var(v) => { + assert!(ctx.defines(v), "Variable used but not defined: {v}"); + } + SimpleExpr::Constant(_) => {} + SimpleExpr::ConstMallocAccess { .. } => {} + } +} + +fn check_boolean_scoping(boolean: &Boolean, ctx: &Context) { + match boolean { + Boolean::Equal { left, right } | Boolean::Different { left, right } => { + check_expr_scoping(left, ctx); + check_expr_scoping(right, ctx); + } + } +} + +fn check_condition_scoping(condition: &Condition, ctx: &Context) { + match condition { + Condition::Expression(expr, _) => { + check_expr_scoping(expr, ctx); + } + Condition::Comparison(boolean) => { + check_boolean_scoping(boolean, ctx); + } + } +} + #[derive(Debug, Clone, Default)] struct Counters { aux_vars: usize, @@ -205,7 +421,6 @@ struct ArrayManager { pub struct ConstMalloc { counter: usize, map: BTreeMap, - forbidden_vars: BTreeSet, // vars shared between branches of an if/else } impl ArrayManager { @@ -231,6 +446,9 @@ fn simplify_lines( let mut res = Vec::new(); for line in lines { match line { + Line::ForwardDeclaration { var } => { + res.push(SimpleLine::ForwardDeclaration { var: var.clone() }); + } Line::Match { value, arms } => { let simple_value = simplify_expr(value, &mut res, counters, array_manager, const_malloc); let mut simple_arms = vec![]; @@ -320,7 +538,7 
@@ fn simplify_lines( } else if let Ok(right) = right.clone().try_into() { (right, left) } else { - unreachable!("Weird: {:?}, {:?}", left, right) + panic!("Unsupported equality assertion: {left:?}, {right:?}") }; res.push(SimpleLine::Assignment { var, @@ -386,17 +604,6 @@ fn simplify_lines( } }; - let forbidden_vars_before = const_malloc.forbidden_vars.clone(); - - let then_internal_vars = find_variable_usage(then_branch).0; - let else_internal_vars = find_variable_usage(else_branch).0; - let new_forbidden_vars = then_internal_vars - .intersection(&else_internal_vars) - .cloned() - .collect::>(); - - const_malloc.forbidden_vars.extend(new_forbidden_vars); - let mut array_manager_then = array_manager.clone(); let then_branch_simplified = simplify_lines( then_branch, @@ -418,8 +625,6 @@ fn simplify_lines( const_malloc, ); - const_malloc.forbidden_vars = forbidden_vars_before; - *array_manager = array_manager_else.clone(); // keep the intersection both branches array_manager.valid = array_manager @@ -612,12 +817,8 @@ fn simplify_lines( let simplified_size = simplify_expr(size, &mut res, counters, array_manager, const_malloc); let simplified_vectorized_len = simplify_expr(vectorized_len, &mut res, counters, array_manager, const_malloc); - if simplified_size.is_constant() && !*vectorized && const_malloc.forbidden_vars.contains(var) { - println!("TODO: Optimization missed: Requires to align const malloc in if/else branches"); - } match simplified_size { - SimpleExpr::Constant(const_size) if !*vectorized && !const_malloc.forbidden_vars.contains(var) => { - // TODO do this optimization even if we are in an if/else branch + SimpleExpr::Constant(const_size) if !*vectorized => { let label = const_malloc.counter; const_malloc.counter += 1; const_malloc.map.insert(var.clone(), label); @@ -638,7 +839,6 @@ fn simplify_lines( } } Line::DecomposeBits { var, to_decompose } => { - assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); let simplified_to_decompose = to_decompose 
.iter() .map(|expr| simplify_expr(expr, &mut res, counters, array_manager, const_malloc)) @@ -772,6 +972,9 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { for line in lines { match line { + Line::ForwardDeclaration { var } => { + internal_vars.insert(var.clone()); + } Line::Match { value, arms } => { on_new_expr(value, &internal_vars, &mut external_vars); for (_, statements) in arms { @@ -921,7 +1124,7 @@ pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res let inline_internal_var = |var: &mut Var| { assert!( !args.contains_key(var), - "Variable {var} is both an argument and assigned in the inlined function" + "Variable {var} is both an argument and declared in the inlined function" ); *var = format!("@inlined_var_{inlining_count}_{var}"); }; @@ -929,6 +1132,9 @@ pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res let mut lines_to_replace = vec![]; for (i, line) in lines.iter_mut().enumerate() { match line { + Line::ForwardDeclaration { var } => { + inline_internal_var(var); + } Line::Match { value, arms } => { inline_expr(value, args, inlining_count); for (_, statements) in arms { @@ -1239,6 +1445,9 @@ fn replace_vars_for_unroll( replace_vars_for_unroll(statements, iterator, unroll_index, iterator_value, internal_vars); } } + Line::ForwardDeclaration { var } => { + *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + } Line::Assignment { var, value } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); @@ -1444,6 +1653,10 @@ fn handle_inlined_functions_helper( if let Some(func) = inlined_functions.get(&*function_name) { let mut inlined_lines = vec![]; + for var in return_data.iter() { + inlined_lines.push(Line::ForwardDeclaration { var: var.clone() }); + } + let mut simplified_args = vec![]; for arg in args { if let Expression::Value(simple_expr) = arg { @@ -1699,6 +1912,7 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec) { 
get_function_called(body, function_called); } Line::Assignment { .. } + | Line::ForwardDeclaration { .. } | Line::ArrayAssign { .. } | Line::Assert { .. } | Line::FunctionRet { .. } @@ -1724,6 +1938,9 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_lines(statements, map); } } + Line::ForwardDeclaration { var } => { + assert!(!map.contains_key(var), "Variable {var} is a constant"); + } Line::Assignment { var, value } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); replace_vars_by_const_in_expr(value, map); @@ -1831,6 +2048,9 @@ impl SimpleLine { fn to_string_with_indent(&self, indent: usize) -> String { let spaces = " ".repeat(indent); let line_str = match self { + Self::ForwardDeclaration { var } => { + format!("var {var}") + } Self::Match { value, arms } => { let arms_str = arms .iter() @@ -1880,7 +2100,7 @@ impl SimpleLine { ) } Self::RawAccess { res, index, shift } => { - format!("memory[{index} + {shift}] = {res}") + format!("{res} = memory[{index} + {shift}]") } Self::TestZero { operation, arg0, arg1 } => { format!("0 = {arg0} {operation} {arg1}") diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 18cebdf8..b9ff27ec 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -1,10 +1,7 @@ use crate::{F, a_simplify_lang::*, ir::*, lang::*}; use lean_vm::*; use multilinear_toolkit::prelude::*; -use std::{ - borrow::Borrow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use utils::ToUsize; #[derive(Default)] @@ -14,31 +11,56 @@ struct Compiler { if_counter: usize, call_counter: usize, func_name: String, - var_positions: BTreeMap, // var -> memory offset from fp + stack_frame_layout: StackFrameLayout, args_count: usize, stack_size: usize, + stack_pos: usize, +} + +#[derive(Default)] +struct StackFrameLayout { + // Innermost lexical 
scope last + scopes: Vec, +} + +#[derive(Default)] +struct ScopeLayout { + var_positions: BTreeMap, // var -> memory offset from fp const_mallocs: BTreeMap, // const_malloc_label -> start = memory offset from fp } impl Compiler { + fn is_in_scope(&self, var: &Var) -> bool { + for scope in self.stack_frame_layout.scopes.iter() { + if let Some(_offset) = scope.var_positions.get(var) { + return true; + } + } + false + } + fn get_offset(&self, var: &VarOrConstMallocAccess) -> ConstExpression { match var { - VarOrConstMallocAccess::Var(var) => (*self - .var_positions - .get(var) - .unwrap_or_else(|| panic!("Variable {var} not in scope"))) - .into(), - VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => ConstExpression::Binary { - left: Box::new( - self.const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, + VarOrConstMallocAccess::Var(var) => { + for scope in self.stack_frame_layout.scopes.iter().rev() { + if let Some(offset) = scope.var_positions.get(var) { + return (*offset).into(); + } + } + panic!("Variable {var} not in scope"); + } + VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => { + for scope in self.stack_frame_layout.scopes.iter().rev() { + if let Some(base) = scope.const_mallocs.get(malloc_label) { + return ConstExpression::Binary { + left: Box::new((*base).into()), + operation: HighLevelOperation::Add, + right: Box::new((*offset).clone()), + }; + } + } + panic!("Const malloc {malloc_label} not in scope"); + } } } } @@ -68,18 +90,10 @@ impl IntermediateValue { }, SimpleExpr::Constant(c) => Self::Constant(c.clone()), SimpleExpr::ConstMallocAccess { malloc_label, offset } => Self::MemoryAfterFp { - offset: ConstExpression::Binary { - left: Box::new( - compiler - .const_mallocs - .get(malloc_label) - .copied() - .unwrap_or_else(|| panic!("Const malloc {malloc_label} 
not in scope")) - .into(), - ), - operation: HighLevelOperation::Add, - right: Box::new(offset.clone()), - }, + offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *malloc_label, + offset: offset.clone(), + }), }, } } @@ -110,38 +124,31 @@ fn compile_function( function: &SimpleFunction, compiler: &mut Compiler, ) -> Result, String> { - let mut internal_vars = find_internal_vars(&function.instructions); - - internal_vars.retain(|var| !function.arguments.contains(var)); - // memory layout: pc, fp, args, return_vars, internal_vars let mut stack_pos = 2; // Reserve space for pc and fp - let mut var_positions = BTreeMap::new(); + let function_scope_layout = ScopeLayout::default(); + compiler.stack_frame_layout = StackFrameLayout { + scopes: vec![function_scope_layout], + }; + let function_scope_layout = &mut compiler.stack_frame_layout.scopes[0]; for (i, var) in function.arguments.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); + function_scope_layout.var_positions.insert(var.clone(), stack_pos + i); } stack_pos += function.arguments.len(); stack_pos += function.n_returned_vars; - for (i, var) in internal_vars.iter().enumerate() { - var_positions.insert(var.clone(), stack_pos + i); - } - stack_pos += internal_vars.len(); - compiler.func_name = function.name.clone(); - compiler.var_positions = var_positions; + compiler.stack_pos = stack_pos; compiler.stack_size = stack_pos; compiler.args_count = function.arguments.len(); - let mut declared_vars: BTreeSet = function.arguments.iter().cloned().collect(); compile_lines( &Label::function(function.name.clone()), &function.instructions, compiler, None, - &mut declared_vars, ) } @@ -150,29 +157,44 @@ fn compile_lines( lines: &[SimpleLine], compiler: &mut Compiler, final_jump: Option