Skip to content

Commit 7330d74

Browse files
authored
issue/76: allow arbitrary expressions as conditions in if (#78)
1 parent 05a8f7b commit 7330d74

File tree

12 files changed

+270
-85
lines changed

12 files changed

+270
-85
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 139 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use crate::{
22
Counter, F,
33
ir::HighLevelOperation,
44
lang::{
5-
Boolean, ConstExpression, ConstMallocLabel, Expression, Function, Line, Program,
6-
SimpleExpr, Var,
5+
AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue,
6+
Expression, Function, Line, Program, SimpleExpr, Var,
77
},
88
precompiles::Precompile,
99
};
@@ -97,6 +97,12 @@ pub enum SimpleLine {
9797
then_branch: Vec<Self>,
9898
else_branch: Vec<Self>,
9999
},
100+
TestZero {
101+
// Test that the result of the given operation is zero
102+
operation: HighLevelOperation,
103+
arg0: SimpleExpr,
104+
arg1: SimpleExpr,
105+
},
100106
FunctionCall {
101107
function_name: String,
102108
args: Vec<SimpleExpr>,
@@ -353,26 +359,68 @@ fn simplify_lines(
353359
then_branch,
354360
else_branch,
355361
} => {
356-
// Transform if a == b then X else Y into if a != b then Y else X
362+
let (condition_simplified, then_branch, else_branch) = match condition {
363+
Condition::Comparison(condition) => {
364+
// Transform if a == b then X else Y into if a != b then Y else X
365+
366+
let (left, right, then_branch, else_branch) = match condition {
367+
Boolean::Equal { left, right } => {
368+
(left, right, else_branch, then_branch)
369+
} // switched
370+
Boolean::Different { left, right } => {
371+
(left, right, then_branch, else_branch)
372+
}
373+
};
374+
375+
let left_simplified =
376+
simplify_expr(left, &mut res, counters, array_manager, const_malloc);
377+
let right_simplified =
378+
simplify_expr(right, &mut res, counters, array_manager, const_malloc);
379+
380+
let diff_var = format!("@diff_{}", counters.aux_vars);
381+
counters.aux_vars += 1;
382+
res.push(SimpleLine::Assignment {
383+
var: diff_var.clone().into(),
384+
operation: HighLevelOperation::Sub,
385+
arg0: left_simplified,
386+
arg1: right_simplified,
387+
});
388+
(diff_var.into(), then_branch, else_branch)
389+
}
390+
Condition::Expression(condition, assume_boolean) => {
391+
let condition_simplified = simplify_expr(
392+
condition,
393+
&mut res,
394+
counters,
395+
array_manager,
396+
const_malloc,
397+
);
357398

358-
let (left, right, then_branch, else_branch) = match condition {
359-
Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched
360-
Boolean::Different { left, right } => (left, right, then_branch, else_branch),
361-
};
399+
match assume_boolean {
400+
AssumeBoolean::AssumeBoolean => {}
401+
AssumeBoolean::DoNotAssumeBoolean => {
402+
// Check condition_simplified is boolean
403+
let one_minus_condition_var = format!("@aux_{}", counters.aux_vars);
404+
counters.aux_vars += 1;
405+
res.push(SimpleLine::Assignment {
406+
var: one_minus_condition_var.clone().into(),
407+
operation: HighLevelOperation::Sub,
408+
arg0: SimpleExpr::Constant(ConstExpression::Value(
409+
ConstantValue::Scalar(1),
410+
)),
411+
arg1: condition_simplified.clone(),
412+
});
413+
res.push(SimpleLine::TestZero {
414+
operation: HighLevelOperation::Mul,
415+
arg0: condition_simplified.clone(),
416+
arg1: one_minus_condition_var.into(),
417+
});
418+
}
419+
}
362420

363-
let left_simplified =
364-
simplify_expr(left, &mut res, counters, array_manager, const_malloc);
365-
let right_simplified =
366-
simplify_expr(right, &mut res, counters, array_manager, const_malloc);
367-
368-
let diff_var = format!("@diff_{}", counters.aux_vars);
369-
counters.aux_vars += 1;
370-
res.push(SimpleLine::Assignment {
371-
var: diff_var.clone().into(),
372-
operation: HighLevelOperation::Sub,
373-
arg0: left_simplified,
374-
arg1: right_simplified,
375-
});
421+
(condition_simplified, then_branch, else_branch)
422+
}
423+
};
376424

377425
let forbidden_vars_before = const_malloc.forbidden_vars.clone();
378426

@@ -417,7 +465,7 @@ fn simplify_lines(
417465
.collect();
418466

419467
res.push(SimpleLine::IfNotZero {
420-
condition: diff_var.into(),
468+
condition: condition_simplified,
421469
then_branch: then_branch_simplified,
422470
else_branch: else_branch_simplified,
423471
});
@@ -787,12 +835,20 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
787835
}
788836
};
789837

790-
let on_new_condition =
791-
|condition: &Boolean, internal_vars: &BTreeSet<Var>, external_vars: &mut BTreeSet<Var>| {
792-
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition;
793-
on_new_expr(left, internal_vars, external_vars);
794-
on_new_expr(right, internal_vars, external_vars);
795-
};
838+
let on_new_condition = |condition: &Condition,
839+
internal_vars: &BTreeSet<Var>,
840+
external_vars: &mut BTreeSet<Var>| {
841+
match condition {
842+
Condition::Comparison(Boolean::Equal { left, right })
843+
| Condition::Comparison(Boolean::Different { left, right }) => {
844+
on_new_expr(left, internal_vars, external_vars);
845+
on_new_expr(right, internal_vars, external_vars);
846+
}
847+
Condition::Expression(expr, _assume_boolean) => {
848+
on_new_expr(expr, internal_vars, external_vars);
849+
}
850+
}
851+
};
796852

797853
for line in lines {
798854
match line {
@@ -839,7 +895,11 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
839895
internal_vars.extend(return_data.iter().cloned());
840896
}
841897
Line::Assert(condition) => {
842-
on_new_condition(condition, &internal_vars, &mut external_vars);
898+
on_new_condition(
899+
&Condition::Comparison(condition.clone()),
900+
&internal_vars,
901+
&mut external_vars,
902+
);
843903
}
844904
Line::FunctionRet { return_data } => {
845905
for ret in return_data {
@@ -944,12 +1004,17 @@ pub fn inline_lines(
9441004
res: &[Var],
9451005
inlining_count: usize,
9461006
) {
947-
let inline_condition = |condition: &mut Boolean| {
948-
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition;
1007+
let inline_comparison = |comparison: &mut Boolean| {
1008+
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = comparison;
9491009
inline_expr(left, args, inlining_count);
9501010
inline_expr(right, args, inlining_count);
9511011
};
9521012

1013+
let inline_condition = |condition: &mut Condition| match condition {
1014+
Condition::Comparison(comparison) => inline_comparison(comparison),
1015+
Condition::Expression(expr, _assume_boolean) => inline_expr(expr, args, inlining_count),
1016+
};
1017+
9531018
let inline_internal_var = |var: &mut Var| {
9541019
assert!(
9551020
!args.contains_key(var),
@@ -994,7 +1059,7 @@ pub fn inline_lines(
9941059
}
9951060
}
9961061
Line::Assert(condition) => {
997-
inline_condition(condition);
1062+
inline_comparison(condition);
9981063
}
9991064
Line::FunctionRet { return_data } => {
10001065
assert_eq!(return_data.len(), res.len());
@@ -1368,24 +1433,40 @@ fn replace_vars_for_unroll(
13681433
);
13691434
}
13701435
Line::IfCondition {
1371-
condition: Boolean::Equal { left, right } | Boolean::Different { left, right },
1436+
condition,
13721437
then_branch,
13731438
else_branch,
13741439
} => {
1375-
replace_vars_for_unroll_in_expr(
1376-
left,
1377-
iterator,
1378-
unroll_index,
1379-
iterator_value,
1380-
internal_vars,
1381-
);
1382-
replace_vars_for_unroll_in_expr(
1383-
right,
1384-
iterator,
1385-
unroll_index,
1386-
iterator_value,
1387-
internal_vars,
1388-
);
1440+
match condition {
1441+
Condition::Comparison(
1442+
Boolean::Equal { left, right } | Boolean::Different { left, right },
1443+
) => {
1444+
replace_vars_for_unroll_in_expr(
1445+
left,
1446+
iterator,
1447+
unroll_index,
1448+
iterator_value,
1449+
internal_vars,
1450+
);
1451+
replace_vars_for_unroll_in_expr(
1452+
right,
1453+
iterator,
1454+
unroll_index,
1455+
iterator_value,
1456+
internal_vars,
1457+
);
1458+
}
1459+
Condition::Expression(expr, _assume_bool) => {
1460+
replace_vars_for_unroll_in_expr(
1461+
expr,
1462+
iterator,
1463+
unroll_index,
1464+
iterator_value,
1465+
internal_vars,
1466+
);
1467+
}
1468+
}
1469+
13891470
replace_vars_for_unroll(
13901471
then_branch,
13911472
iterator,
@@ -1972,10 +2053,14 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
19722053
else_branch,
19732054
} => {
19742055
match condition {
1975-
Boolean::Equal { left, right } | Boolean::Different { left, right } => {
2056+
Condition::Comparison(Boolean::Equal { left, right })
2057+
| Condition::Comparison(Boolean::Different { left, right }) => {
19762058
replace_vars_by_const_in_expr(left, map);
19772059
replace_vars_by_const_in_expr(right, map);
19782060
}
2061+
Condition::Expression(expr, _assume_boolean) => {
2062+
replace_vars_by_const_in_expr(expr, map);
2063+
}
19792064
}
19802065
replace_vars_by_const_in_lines(then_branch, map);
19812066
replace_vars_by_const_in_lines(else_branch, map);
@@ -2116,6 +2201,13 @@ impl SimpleLine {
21162201
Self::RawAccess { res, index, shift } => {
21172202
format!("memory[{index} + {shift}] = {res}")
21182203
}
2204+
Self::TestZero {
2205+
operation,
2206+
arg0,
2207+
arg1,
2208+
} => {
2209+
format!("0 = {arg0} {operation} {arg1}")
2210+
}
21192211
Self::IfNotZero {
21202212
condition,
21212213
then_branch,

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,21 @@ fn compile_lines(
184184
}
185185
}
186186

187+
SimpleLine::TestZero {
188+
operation,
189+
arg0,
190+
arg1,
191+
} => {
192+
instructions.push(IntermediateInstruction::computation(
193+
*operation,
194+
IntermediateValue::from_simple_expr(arg0, compiler),
195+
IntermediateValue::from_simple_expr(arg1, compiler),
196+
IntermediateValue::Constant(0.into()),
197+
));
198+
199+
mark_vars_as_declared(&[arg0, arg1], declared_vars);
200+
}
201+
187202
SimpleLine::Match { value, arms } => {
188203
let match_index = compiler.match_blocks.len();
189204
let end_label = Label::match_end(match_index);
@@ -768,6 +783,7 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet<Var> {
768783
internal_vars.insert(var.clone());
769784
}
770785
}
786+
SimpleLine::TestZero { .. } => {}
771787
SimpleLine::HintMAlloc { var, .. }
772788
| SimpleLine::ConstMalloc { var, .. }
773789
| SimpleLine::DecomposeBits { var, .. }

crates/lean_compiler/src/grammar.pest

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" }
4040

4141
if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_clause? }
4242

43-
condition = {condition_eq | condition_diff}
44-
condition_eq = { expression ~ "==" ~ expression }
45-
condition_diff = { expression ~ "!=" ~ expression }
43+
condition = { expression | assumed_bool_expr }
44+
45+
assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" }
4646

4747
else_clause = { "else" ~ "{" ~ statement* ~ "}" }
4848

@@ -58,12 +58,14 @@ function_call = { function_res? ~ identifier ~ "(" ~ tuple_expression? ~ ")" ~ "
5858
function_res = { var_list ~ "=" }
5959
var_list = { identifier ~ ("," ~ identifier)* }
6060

61-
assert_eq_statement = { "assert" ~ expression ~ "==" ~ expression ~ ";" }
62-
assert_not_eq_statement = { "assert" ~ expression ~ "!=" ~ expression ~ ";" }
61+
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
62+
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
6363

6464
// Expressions
6565
tuple_expression = { expression ~ ("," ~ expression)* }
66-
expression = { add_expr }
66+
expression = { neq_expr }
67+
neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* }
68+
eq_expr = { add_expr ~ ("==" ~ add_expr)* }
6769
add_expr = { sub_expr ~ ("+" ~ sub_expr)* }
6870
sub_expr = { mul_expr ~ ("-" ~ mul_expr)* }
6971
mul_expr = { mod_expr ~ ("*" ~ mod_expr)* }
@@ -85,4 +87,4 @@ constant_value = { number | "public_input_start" }
8587

8688
// Lexical elements
8789
identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
88-
number = @{ ASCII_DIGIT+ }
90+
number = @{ ASCII_DIGIT+ }

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ impl IntermediateInstruction {
120120
arg_c,
121121
res: arg_a,
122122
},
123-
HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(),
123+
HighLevelOperation::Exp
124+
| HighLevelOperation::Mod
125+
| HighLevelOperation::Equal
126+
| HighLevelOperation::NotEqual => unreachable!(),
124127
}
125128
}
126129

crates/lean_compiler/src/ir/operation.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,29 @@ pub enum HighLevelOperation {
2323
Exp,
2424
/// Modulo operation (only for constant expressions).
2525
Mod,
26+
/// Equality comparison
27+
Equal,
28+
/// Non-equality comparison
29+
NotEqual,
2630
}
2731

2832
impl HighLevelOperation {
2933
pub fn eval(&self, a: F, b: F) -> F {
3034
match self {
35+
Self::Equal => {
36+
if a == b {
37+
F::ONE
38+
} else {
39+
F::ZERO
40+
}
41+
}
42+
Self::NotEqual => {
43+
if a != b {
44+
F::ONE
45+
} else {
46+
F::ZERO
47+
}
48+
}
3149
Self::Add => a + b,
3250
Self::Mul => a * b,
3351
Self::Sub => a - b,
@@ -41,6 +59,8 @@ impl HighLevelOperation {
4159
impl Display for HighLevelOperation {
4260
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
4361
match self {
62+
Self::Equal => write!(f, "=="),
63+
Self::NotEqual => write!(f, "!="),
4464
Self::Add => write!(f, "+"),
4565
Self::Mul => write!(f, "*"),
4666
Self::Sub => write!(f, "-"),

0 commit comments

Comments
 (0)