Skip to content

Commit 4e6b301

Browse files
authored
check for wrong number of returned vars (#107)
* check for wrong number of returned vars in function calls * check for wrong number of returned vars in return statements
1 parent be12a66 commit 4e6b301

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram {
160160
for (name, func) in &program.functions {
161161
let mut array_manager = ArrayManager::default();
162162
let simplified_instructions = simplify_lines(
163+
&program.functions,
164+
func.n_returned_vars,
163165
&func.body,
164166
&mut counters,
165167
&mut new_functions,
@@ -435,7 +437,10 @@ impl ArrayManager {
435437
}
436438
}
437439

440+
#[allow(clippy::too_many_arguments)]
438441
fn simplify_lines(
442+
functions: &BTreeMap<String, Function>,
443+
n_returned_vars: usize,
439444
lines: &[Line],
440445
counters: &mut Counters,
441446
new_functions: &mut BTreeMap<String, SimpleFunction>,
@@ -455,6 +460,8 @@ fn simplify_lines(
455460
for (i, (pattern, statements)) in arms.iter().enumerate() {
456461
assert_eq!(*pattern, i, "match patterns should be consecutive, starting from 0");
457462
simple_arms.push(simplify_lines(
463+
functions,
464+
n_returned_vars,
458465
statements,
459466
counters,
460467
new_functions,
@@ -606,6 +613,8 @@ fn simplify_lines(
606613

607614
let mut array_manager_then = array_manager.clone();
608615
let then_branch_simplified = simplify_lines(
616+
functions,
617+
n_returned_vars,
609618
then_branch,
610619
counters,
611620
new_functions,
@@ -617,6 +626,8 @@ fn simplify_lines(
617626
array_manager_else.valid = array_manager.valid.clone(); // Crucial: remove the access added in the IF branch
618627

619628
let else_branch_simplified = simplify_lines(
629+
functions,
630+
n_returned_vars,
620631
else_branch,
621632
counters,
622633
new_functions,
@@ -666,6 +677,8 @@ fn simplify_lines(
666677
let mut body_copy = body.clone();
667678
replace_vars_for_unroll(&mut body_copy, iterator, unroll_index, i, &internal_variables);
668679
unrolled_lines.extend(simplify_lines(
680+
functions,
681+
0,
669682
&body_copy,
670683
counters,
671684
new_functions,
@@ -689,6 +702,8 @@ fn simplify_lines(
689702
let valid_aux_vars_in_array_manager_before = array_manager.valid.clone();
690703
array_manager.valid.clear();
691704
let simplified_body = simplify_lines(
705+
functions,
706+
0,
692707
body,
693708
counters,
694709
new_functions,
@@ -763,6 +778,16 @@ fn simplify_lines(
763778
return_data,
764779
line_number,
765780
} => {
781+
let function = functions
782+
.get(function_name)
783+
.expect("Function used but not defined: {function_name}");
784+
if return_data.len() != function.n_returned_vars {
785+
panic!(
786+
"Expected {} returned vars in call to {function_name}",
787+
function.n_returned_vars
788+
);
789+
}
790+
766791
let simplified_args = args
767792
.iter()
768793
.map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc))
@@ -776,6 +801,11 @@ fn simplify_lines(
776801
}
777802
Line::FunctionRet { return_data } => {
778803
assert!(!in_a_loop, "Function return inside a loop is not currently supported");
804+
assert!(
805+
return_data.len() == n_returned_vars,
806+
"Wrong number of return values in return statement; expected {n_returned_vars} but got {}",
807+
return_data.len()
808+
);
779809
let simplified_return_data = return_data
780810
.iter()
781811
.map(|ret| simplify_expr(ret, &mut res, counters, array_manager, const_malloc))

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,36 @@ fn test_duplicate_constant_name() {
3838
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
3939
}
4040

41+
#[test]
42+
#[should_panic]
43+
fn test_wrong_n_returned_vars_1() {
44+
let program = r#"
45+
fn main() {
46+
a, b = f();
47+
}
48+
49+
fn f() -> 1 {
50+
return 0;
51+
}
52+
"#;
53+
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
54+
}
55+
56+
#[test]
57+
#[should_panic]
58+
fn test_wrong_n_returned_vars_2() {
59+
let program = r#"
60+
fn main() {
61+
a = f();
62+
}
63+
64+
fn f() -> 1 {
65+
return 0, 1;
66+
}
67+
"#;
68+
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
69+
}
70+
4171
#[test]
4272
fn test_fibonacci_program() {
4373
// a program to check the value of the 30th Fibonacci number (832040)

0 commit comments

Comments
 (0)