@@ -2,8 +2,8 @@ use crate::{
22 Counter , F ,
33 ir:: HighLevelOperation ,
44 lang:: {
5- Boolean , ConstExpression , ConstMallocLabel , Expression , Function , Line , Program ,
6- SimpleExpr , Var ,
5+ AssumeBoolean , Boolean , Condition , ConstExpression , ConstMallocLabel , ConstantValue ,
6+ Expression , Function , Line , Program , SimpleExpr , Var ,
77 } ,
88 precompiles:: Precompile ,
99} ;
@@ -97,6 +97,12 @@ pub enum SimpleLine {
9797 then_branch : Vec < Self > ,
9898 else_branch : Vec < Self > ,
9999 } ,
100+ TestZero {
101+ // Test that the result of the given operation is zero
102+ operation : HighLevelOperation ,
103+ arg0 : SimpleExpr ,
104+ arg1 : SimpleExpr ,
105+ } ,
100106 FunctionCall {
101107 function_name : String ,
102108 args : Vec < SimpleExpr > ,
@@ -353,26 +359,68 @@ fn simplify_lines(
353359 then_branch,
354360 else_branch,
355361 } => {
356- // Transform if a == b then X else Y into if a != b then Y else X
362+ let ( condition_simplified, then_branch, else_branch) = match condition {
363+ Condition :: Comparison ( condition) => {
364+ // Transform if a == b then X else Y into if a != b then Y else X
365+
366+ let ( left, right, then_branch, else_branch) = match condition {
367+ Boolean :: Equal { left, right } => {
368+ ( left, right, else_branch, then_branch)
369+ } // switched
370+ Boolean :: Different { left, right } => {
371+ ( left, right, then_branch, else_branch)
372+ }
373+ } ;
374+
375+ let left_simplified =
376+ simplify_expr ( left, & mut res, counters, array_manager, const_malloc) ;
377+ let right_simplified =
378+ simplify_expr ( right, & mut res, counters, array_manager, const_malloc) ;
379+
380+ let diff_var = format ! ( "@diff_{}" , counters. aux_vars) ;
381+ counters. aux_vars += 1 ;
382+ res. push ( SimpleLine :: Assignment {
383+ var : diff_var. clone ( ) . into ( ) ,
384+ operation : HighLevelOperation :: Sub ,
385+ arg0 : left_simplified,
386+ arg1 : right_simplified,
387+ } ) ;
388+ ( diff_var. into ( ) , then_branch, else_branch)
389+ }
390+ Condition :: Expression ( condition, assume_boolean) => {
391+ let condition_simplified = simplify_expr (
392+ condition,
393+ & mut res,
394+ counters,
395+ array_manager,
396+ const_malloc,
397+ ) ;
357398
358- let ( left, right, then_branch, else_branch) = match condition {
359- Boolean :: Equal { left, right } => ( left, right, else_branch, then_branch) , // switched
360- Boolean :: Different { left, right } => ( left, right, then_branch, else_branch) ,
361- } ;
399+ match assume_boolean {
400+ AssumeBoolean :: AssumeBoolean => { }
401+ AssumeBoolean :: DoNotAssumeBoolean => {
402+ // Check condition_simplified is boolean
403+ let one_minus_condition_var = format ! ( "@aux_{}" , counters. aux_vars) ;
404+ counters. aux_vars += 1 ;
405+ res. push ( SimpleLine :: Assignment {
406+ var : one_minus_condition_var. clone ( ) . into ( ) ,
407+ operation : HighLevelOperation :: Sub ,
408+ arg0 : SimpleExpr :: Constant ( ConstExpression :: Value (
409+ ConstantValue :: Scalar ( 1 ) ,
410+ ) ) ,
411+ arg1 : condition_simplified. clone ( ) ,
412+ } ) ;
413+ res. push ( SimpleLine :: TestZero {
414+ operation : HighLevelOperation :: Mul ,
415+ arg0 : condition_simplified. clone ( ) ,
416+ arg1 : one_minus_condition_var. into ( ) ,
417+ } ) ;
418+ }
419+ }
362420
363- let left_simplified =
364- simplify_expr ( left, & mut res, counters, array_manager, const_malloc) ;
365- let right_simplified =
366- simplify_expr ( right, & mut res, counters, array_manager, const_malloc) ;
367-
368- let diff_var = format ! ( "@diff_{}" , counters. aux_vars) ;
369- counters. aux_vars += 1 ;
370- res. push ( SimpleLine :: Assignment {
371- var : diff_var. clone ( ) . into ( ) ,
372- operation : HighLevelOperation :: Sub ,
373- arg0 : left_simplified,
374- arg1 : right_simplified,
375- } ) ;
421+ ( condition_simplified, then_branch, else_branch)
422+ }
423+ } ;
376424
377425 let forbidden_vars_before = const_malloc. forbidden_vars . clone ( ) ;
378426
@@ -417,7 +465,7 @@ fn simplify_lines(
417465 . collect ( ) ;
418466
419467 res. push ( SimpleLine :: IfNotZero {
420- condition : diff_var . into ( ) ,
468+ condition : condition_simplified ,
421469 then_branch : then_branch_simplified,
422470 else_branch : else_branch_simplified,
423471 } ) ;
@@ -787,12 +835,20 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
787835 }
788836 } ;
789837
790- let on_new_condition =
791- |condition : & Boolean , internal_vars : & BTreeSet < Var > , external_vars : & mut BTreeSet < Var > | {
792- let ( Boolean :: Equal { left, right } | Boolean :: Different { left, right } ) = condition;
793- on_new_expr ( left, internal_vars, external_vars) ;
794- on_new_expr ( right, internal_vars, external_vars) ;
795- } ;
838+ let on_new_condition = |condition : & Condition ,
839+ internal_vars : & BTreeSet < Var > ,
840+ external_vars : & mut BTreeSet < Var > | {
841+ match condition {
842+ Condition :: Comparison ( Boolean :: Equal { left, right } )
843+ | Condition :: Comparison ( Boolean :: Different { left, right } ) => {
844+ on_new_expr ( left, internal_vars, external_vars) ;
845+ on_new_expr ( right, internal_vars, external_vars) ;
846+ }
847+ Condition :: Expression ( expr, _assume_boolean) => {
848+ on_new_expr ( expr, internal_vars, external_vars) ;
849+ }
850+ }
851+ } ;
796852
797853 for line in lines {
798854 match line {
@@ -839,7 +895,11 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
839895 internal_vars. extend ( return_data. iter ( ) . cloned ( ) ) ;
840896 }
841897 Line :: Assert ( condition) => {
842- on_new_condition ( condition, & internal_vars, & mut external_vars) ;
898+ on_new_condition (
899+ & Condition :: Comparison ( condition. clone ( ) ) ,
900+ & internal_vars,
901+ & mut external_vars,
902+ ) ;
843903 }
844904 Line :: FunctionRet { return_data } => {
845905 for ret in return_data {
@@ -944,12 +1004,17 @@ pub fn inline_lines(
9441004 res : & [ Var ] ,
9451005 inlining_count : usize ,
9461006) {
947- let inline_condition = |condition : & mut Boolean | {
948- let ( Boolean :: Equal { left, right } | Boolean :: Different { left, right } ) = condition ;
1007+ let inline_comparison = |comparison : & mut Boolean | {
1008+ let ( Boolean :: Equal { left, right } | Boolean :: Different { left, right } ) = comparison ;
9491009 inline_expr ( left, args, inlining_count) ;
9501010 inline_expr ( right, args, inlining_count) ;
9511011 } ;
9521012
1013+ let inline_condition = |condition : & mut Condition | match condition {
1014+ Condition :: Comparison ( comparison) => inline_comparison ( comparison) ,
1015+ Condition :: Expression ( expr, _assume_boolean) => inline_expr ( expr, args, inlining_count) ,
1016+ } ;
1017+
9531018 let inline_internal_var = |var : & mut Var | {
9541019 assert ! (
9551020 !args. contains_key( var) ,
@@ -994,7 +1059,7 @@ pub fn inline_lines(
9941059 }
9951060 }
9961061 Line :: Assert ( condition) => {
997- inline_condition ( condition) ;
1062+ inline_comparison ( condition) ;
9981063 }
9991064 Line :: FunctionRet { return_data } => {
10001065 assert_eq ! ( return_data. len( ) , res. len( ) ) ;
@@ -1368,24 +1433,40 @@ fn replace_vars_for_unroll(
13681433 ) ;
13691434 }
13701435 Line :: IfCondition {
1371- condition : Boolean :: Equal { left , right } | Boolean :: Different { left , right } ,
1436+ condition,
13721437 then_branch,
13731438 else_branch,
13741439 } => {
1375- replace_vars_for_unroll_in_expr (
1376- left,
1377- iterator,
1378- unroll_index,
1379- iterator_value,
1380- internal_vars,
1381- ) ;
1382- replace_vars_for_unroll_in_expr (
1383- right,
1384- iterator,
1385- unroll_index,
1386- iterator_value,
1387- internal_vars,
1388- ) ;
1440+ match condition {
1441+ Condition :: Comparison (
1442+ Boolean :: Equal { left, right } | Boolean :: Different { left, right } ,
1443+ ) => {
1444+ replace_vars_for_unroll_in_expr (
1445+ left,
1446+ iterator,
1447+ unroll_index,
1448+ iterator_value,
1449+ internal_vars,
1450+ ) ;
1451+ replace_vars_for_unroll_in_expr (
1452+ right,
1453+ iterator,
1454+ unroll_index,
1455+ iterator_value,
1456+ internal_vars,
1457+ ) ;
1458+ }
1459+ Condition :: Expression ( expr, _assume_bool) => {
1460+ replace_vars_for_unroll_in_expr (
1461+ expr,
1462+ iterator,
1463+ unroll_index,
1464+ iterator_value,
1465+ internal_vars,
1466+ ) ;
1467+ }
1468+ }
1469+
13891470 replace_vars_for_unroll (
13901471 then_branch,
13911472 iterator,
@@ -1972,10 +2053,14 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
19722053 else_branch,
19732054 } => {
19742055 match condition {
1975- Boolean :: Equal { left, right } | Boolean :: Different { left, right } => {
2056+ Condition :: Comparison ( Boolean :: Equal { left, right } )
2057+ | Condition :: Comparison ( Boolean :: Different { left, right } ) => {
19762058 replace_vars_by_const_in_expr ( left, map) ;
19772059 replace_vars_by_const_in_expr ( right, map) ;
19782060 }
2061+ Condition :: Expression ( expr, _assume_boolean) => {
2062+ replace_vars_by_const_in_expr ( expr, map) ;
2063+ }
19792064 }
19802065 replace_vars_by_const_in_lines ( then_branch, map) ;
19812066 replace_vars_by_const_in_lines ( else_branch, map) ;
@@ -2116,6 +2201,13 @@ impl SimpleLine {
21162201 Self :: RawAccess { res, index, shift } => {
21172202 format ! ( "memory[{index} + {shift}] = {res}" )
21182203 }
2204+ Self :: TestZero {
2205+ operation,
2206+ arg0,
2207+ arg1,
2208+ } => {
2209+ format ! ( "0 = {arg0} {operation} {arg1}" )
2210+ }
21192211 Self :: IfNotZero {
21202212 condition,
21212213 then_branch,
0 commit comments