diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index 332b78fa..ca2532ee 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -42,12 +42,14 @@ single_assignment = { identifier ~ "=" ~ expression ~ ";" } array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" } -if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_clause? } +if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? } condition = { expression | assumed_bool_expr } assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" } +else_if_clause = { "else" ~ "if" ~ condition ~ "{" ~ statement* ~ "}" } + else_clause = { "else" ~ "{" ~ statement* ~ "}" } for_statement = { "for" ~ identifier ~ "in" ~ rev_clause? ~ expression ~ ".." ~ expression ~ unroll_clause? ~ "{" ~ statement* ~ "}" } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 094d963f..6464d417 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -3,6 +3,7 @@ use super::function::{FunctionCallParser, TupleExpressionParser}; use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ + SourceLineNumber, ir::HighLevelOperation, lang::{AssumeBoolean, Boolean, Condition, Expression, Line}, parser::{ @@ -88,14 +89,26 @@ impl Parse for IfStatementParser { let mut inner = pair.into_inner(); let condition = ConditionParser::parse(next_inner_pair(&mut inner, "if condition")?, ctx)?; - let mut then_branch = Vec::new(); - let mut else_branch = Vec::new(); + let mut then_branch: Vec = Vec::new(); + let mut else_if_branches: Vec<(Condition, Vec, SourceLineNumber)> = Vec::new(); + let mut else_branch: Vec = Vec::new(); for item in inner { match item.as_rule() { Rule::statement => { Self::add_statement_with_location(&mut then_branch, item, ctx)?; } + Rule::else_if_clause => { + let line_number = item.line_col().0; + let mut inner = item.into_inner(); + let else_if_condition = + ConditionParser::parse(next_inner_pair(&mut inner, "else if condition")?, ctx)?; + let mut else_if_branch = Vec::new(); + for else_if_item in inner { + Self::add_statement_with_location(&mut else_if_branch, else_if_item, ctx)?; + } + else_if_branches.push((else_if_condition, else_if_branch, line_number)); + } Rule::else_clause => { for else_item in item.into_inner() { if else_item.as_rule() == Rule::statement { @@ -107,10 +120,28 @@ impl Parse for IfStatementParser { } } + let mut outer_else_branch = Vec::new(); + let mut inner_else_branch = &mut outer_else_branch; + + for (else_if_condition, else_if_branch, line_number) in else_if_branches.into_iter() { + inner_else_branch.push(Line::IfCondition { + condition: else_if_condition, + then_branch: else_if_branch, + else_branch: Vec::new(), + line_number, + }); + inner_else_branch = match &mut inner_else_branch[0] { + Line::IfCondition { else_branch, .. } => else_branch, + _ => unreachable!("Expected Line::IfCondition"), + }; + } + + inner_else_branch.extend(else_branch); + Ok(Line::IfCondition { condition, then_branch, - else_branch, + else_branch: outer_else_branch, line_number, }) } diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 9df40be6..55039e89 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -259,17 +259,13 @@ fn test_mini_program_1() { for i in 0..N { if i == 0 { arr[i] = 10; + } else if i == 1 { + arr[i] = 20; + } else if i == 2 { + arr[i] = 30; } else { - if i == 1 { - arr[i] = 20; - } else { - if i == 2 { - arr[i] = 30; - } else { - i_plus_one = i + 1; - arr[i] = i_plus_one; - } - } + i_plus_one = i + 1; + arr[i] = i_plus_one; } } return; diff --git a/crates/rec_aggregation/recursion_program.lean_lang b/crates/rec_aggregation/recursion_program.lean_lang index 29f246cd..2d16ad25 100644 --- a/crates/rec_aggregation/recursion_program.lean_lang +++ b/crates/rec_aggregation/recursion_program.lean_lang @@ -377,21 +377,15 @@ fn merkle_verif_batch_dynamic(n_paths, leaves_digests, leave_positions, root, he if height == MERKLE_HEIGHT_0 { merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_0); return; - } else { - if height == MERKLE_HEIGHT_1 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1); - return; - } else { - if height == MERKLE_HEIGHT_2 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2); - return; - } else { - if height == MERKLE_HEIGHT_3 { - merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3); - return; - } - } - } + } else if height == MERKLE_HEIGHT_1 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1); + return; + } else if height == MERKLE_HEIGHT_2 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2); + return; + } else if height == MERKLE_HEIGHT_3 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3); + return; } print(12345555); @@ -725,21 +719,15 @@ fn sample_bits_dynamic(fs_state, n_samples, K) -> 2 { if n_samples == NUM_QUERIES_0 { new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_0, K); return new_fs_state, sampled_bits; - } else { - if n_samples == NUM_QUERIES_1 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K); - return new_fs_state, sampled_bits; - } else { - if n_samples == NUM_QUERIES_2 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_2, K); - return new_fs_state, sampled_bits; - } else { - if n_samples == NUM_QUERIES_3 { - new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K); - return new_fs_state, sampled_bits; - } - } - } + } else if n_samples == NUM_QUERIES_1 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K); + return new_fs_state, sampled_bits; + } else if n_samples == NUM_QUERIES_2 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_2, K); + return new_fs_state, sampled_bits; + } else if n_samples == NUM_QUERIES_3 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K); + return new_fs_state, sampled_bits; } print(n_samples); print(999333);