diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index f985a3cd..c4934d5c 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -177,15 +177,16 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { v.clone() }) .collect::>(); - new_functions.insert( - name.clone(), - SimpleFunction { - name: name.clone(), - arguments, - n_returned_vars: func.n_returned_vars, - instructions: simplified_instructions, - }, - ); + let simplified_function = SimpleFunction { + name: name.clone(), + arguments, + n_returned_vars: func.n_returned_vars, + instructions: simplified_instructions, + }; + if !func.assume_always_returns { + check_function_always_returns(&simplified_function); + } + new_functions.insert(name.clone(), simplified_function); const_malloc.map.clear(); } SimpleProgram { @@ -193,6 +194,39 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram { } } +/// Analyzes a simplified function to verify that it returns on each code path. +fn check_function_always_returns(func: &SimpleFunction) { + check_block_always_returns(&func.name, &func.instructions); +} + +fn check_block_always_returns(function_name: &String, instructions: &[SimpleLine]) { + match instructions.last() { + Some(SimpleLine::Match { value: _, arms }) => { + for arm in arms { + check_block_always_returns(function_name, arm); + } + } + Some(SimpleLine::IfNotZero { + condition: _, + then_branch, + else_branch, + line_number: _, + }) => { + check_block_always_returns(function_name, then_branch); + check_block_always_returns(function_name, else_branch); + } + Some(SimpleLine::FunctionRet { return_data: _ }) => { + // good + } + Some(SimpleLine::Panic) => { + // good + } + _ => { + panic!("Cannot prove that function always returns: {function_name}"); + } + } +} + /// Analyzes the program to verify that each variable is defined in each context where it is used. fn check_program_scoping(program: &Program) { for (_, function) in program.functions.iter() { @@ -1864,6 +1898,7 @@ fn handle_const_arguments_helper( inlined: false, body: new_body, n_returned_vars: func.n_returned_vars, + assume_always_returns: func.assume_always_returns, }, ); } diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 8de5a975..332b78fa 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -7,7 +7,8 @@ program = { SOI ~ constant_declaration* ~ function+ ~ EOI } constant_declaration = { "const" ~ identifier ~ "=" ~ expression ~ ";" } // Functions -function = { "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ inlined_statement? ~ return_count? ~ "{" ~ statement* ~ "}" } +function = { pragma? ~ "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ inlined_statement? ~ return_count? ~ "{" ~ statement* ~ "}" } +pragma = { "#![assume_always_returns]" } parameter_list = { parameter ~ ("," ~ parameter)* } parameter = { (const_keyword)? ~ identifier } const_keyword = { "const" } diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 252fd680..227ae09e 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -19,6 +19,7 @@ pub struct Function { pub inlined: bool, pub n_returned_vars: usize, pub body: Vec, + pub assume_always_returns: bool, } impl Function { diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 77abe17a..484a1bc2 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -17,7 +17,14 @@ pub struct FunctionParser; impl Parse for FunctionParser { fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let mut inner = pair.into_inner(); + let mut inner = pair.into_inner().peekable(); + let assume_always_returns = match inner.peek().map(|x| x.as_rule()) { + Some(Rule::pragma) => { + inner.next(); + true + } + _ => false, + }; let name = next_inner_pair(&mut inner, "function name")?.as_str().to_string(); let mut arguments = Vec::new(); @@ -53,6 +60,7 @@ impl Parse for FunctionParser { inlined, n_returned_vars, body, + assume_always_returns, }) } } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 3142c6d7..9df40be6 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -44,6 +44,7 @@ fn test_wrong_n_returned_vars_1() { let program = r#" fn main() { a, b = f(); + return; } fn f() -> 1 { @@ -59,6 +60,7 @@ fn test_wrong_n_returned_vars_2() { let program = r#" fn main() { a = f(); + return; } fn f() -> 1 { @@ -68,6 +70,45 @@ fn test_wrong_n_returned_vars_2() { compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); } +#[test] +#[should_panic] +fn test_no_return() { + let program = r#" + fn main() { + a = f(); + return; + } + + fn f() -> 1 { + } + + fn g() -> 1 { + return 0; + } + "#; + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); +} + +#[test] +fn test_assumed_return() { + let program = r#" + fn main() { + a = f(); + return; + } + + #![assume_always_returns] + fn f() -> 1 { + if 1 == 1 { + return 0; + } else { + print(1); + } + } + "#; + compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false); +} + #[test] fn test_fibonacci_program() { // a program to check the value of the 30th Fibonacci number (832040)