diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 4c86348d..f985a3cd 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -160,6 +160,8 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { for (name, func) in &program.functions { let mut array_manager = ArrayManager::default(); let simplified_instructions = simplify_lines( + &program.functions, + func.n_returned_vars, &func.body, &mut counters, &mut new_functions, @@ -435,7 +437,10 @@ impl ArrayManager { } } +#[allow(clippy::too_many_arguments)] fn simplify_lines( + functions: &BTreeMap, + n_returned_vars: usize, lines: &[Line], counters: &mut Counters, new_functions: &mut BTreeMap, @@ -455,6 +460,8 @@ fn simplify_lines( for (i, (pattern, statements)) in arms.iter().enumerate() { assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0"); simple_arms.push(simplify_lines( + functions, + n_returned_vars, statements, counters, new_functions, @@ -606,6 +613,8 @@ fn simplify_lines( let mut array_manager_then = array_manager.clone(); let then_branch_simplified = simplify_lines( + functions, + n_returned_vars, then_branch, counters, new_functions, @@ -617,6 +626,8 @@ fn simplify_lines( array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch let else_branch_simplified = simplify_lines( + functions, + n_returned_vars, else_branch, counters, new_functions, @@ -666,6 +677,8 @@ fn simplify_lines( let mut body_copy = body.clone(); replace_vars_for_unroll(&mut body_copy, iterator, unroll_index, i, &internal_variables); unrolled_lines.extend(simplify_lines( + functions, + 0, &body_copy, counters, new_functions, @@ -689,6 +702,8 @@ fn simplify_lines( let valid_aux_vars_in_array_manager_before = array_manager.valid.clone(); array_manager.valid.clear(); let simplified_body = simplify_lines( + functions, + 0, body, counters, new_functions, @@ -763,6 +778,16 @@ fn simplify_lines( return_data, line_number, } => { + let function = functions + .get(function_name) + .expect("Function used but not defined: {function_name}"); + if return_data.len() != function.n_returned_vars { + panic!( + "Expected {} returned vars in call to {function_name}", + function.n_returned_vars + ); + } + let simplified_args = args .iter() .map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc)) @@ -776,6 +801,11 @@ fn simplify_lines( } Line::FunctionRet { return_data } => { assert!(!in_a_loop, "Function return inside a loop is not currently supported"); + assert!( + return_data.len() == n_returned_vars, + "Wrong number of return values in return statement; expected {n_returned_vars} but got {}", + return_data.len() + ); let simplified_return_data = return_data .iter() .map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc)) diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2267bbe9..3142c6d7 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -38,6 +38,36 @@ fn test_duplicate_constant_name() { compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } +#[test] +#[should_panic] +fn test_wrong_n_returned_vars_1() { + let program = r#" + fn main() { + a, b = f(); + } + + fn f() -> 1 { + return 0; + } + "#; + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); +} + +#[test] +#[should_panic] +fn test_wrong_n_returned_vars_2() { + let program = r#" + fn main() { + a = f(); + } + + fn f() -> 1 { + return 0, 1; + } + "#; + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); +} + #[test] fn test_fibonacci_program() { // a program to check the value of the 30th Fibonacci number (832040)