Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 21 additions & 5 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ pub enum SimpleLine {
LocationReport {
location: SourceLineNumber,
},
RangeCheck {
value: SimpleExpr,
max: ConstExpression,
},
}

pub fn simplify_program(mut program: Program) -> SimpleProgram {
Expand Down Expand Up @@ -748,6 +752,14 @@ fn simplify_lines(
location: *location,
});
}
Line::RangeCheck { value, max } => {
let simplified_value =
simplify_expr(value, &mut res, counters, array_manager, const_malloc);
res.push(SimpleLine::RangeCheck {
value: simplified_value,
max: max.clone(),
});
}
}
}

Expand Down Expand Up @@ -970,7 +982,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
on_new_expr(index, &internal_vars, &mut external_vars);
on_new_expr(value, &internal_vars, &mut external_vars);
}
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
}
}

Expand Down Expand Up @@ -1141,7 +1153,7 @@ pub fn inline_lines(
inline_expr(index, args, inlining_count);
inline_expr(value, args, inlining_count);
}
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
}
}
for (i, new_lines) in lines_to_replace.into_iter().rev() {
Expand Down Expand Up @@ -1635,7 +1647,7 @@ fn replace_vars_for_unroll(
Line::CounterHint { var } => {
*var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}");
}
Line::Break | Line::Panic | Line::LocationReport { .. } => {}
Line::Break | Line::Panic | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
}
}
}
Expand Down Expand Up @@ -2027,7 +2039,8 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec<String>) {
| Line::MAlloc { .. }
| Line::Panic
| Line::Break
| Line::LocationReport { .. } => {}
| Line::LocationReport { .. }
| Line::RangeCheck { .. } => {}
}
}
}
Expand Down Expand Up @@ -2136,7 +2149,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
assert!(!map.contains_key(var), "Variable {var} is a constant");
replace_vars_by_const_in_expr(size, map);
}
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
}
}
}
Expand Down Expand Up @@ -2333,6 +2346,9 @@ impl SimpleLine {
}
Self::Panic => "panic".to_string(),
Self::LocationReport { .. } => Default::default(),
Self::RangeCheck { value, max } => {
format!("range_check({value}, {max})")
}
};
format!("{spaces}{line_str}")
}
Expand Down
74 changes: 73 additions & 1 deletion crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ fn compile_lines(
shift_0,
shift_1: shift.clone(),
res: res.to_mem_after_fp_or_constant(compiler),
for_range_check: false,
});
}

Expand Down Expand Up @@ -457,6 +458,7 @@ fn compile_lines(
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp {
offset: compiler.get_offset(&ret_var.clone().into()),
},
for_range_check: false,
});
}

Expand Down Expand Up @@ -638,6 +640,72 @@ fn compile_lines(
location: *location,
});
}
SimpleLine::RangeCheck { value, max } => {
// x is the fp offset of the memory cell which contains the value
// i.e. m[fp + x] contains value
let x = match IntermediateValue::from_simple_expr(value, compiler) {
IntermediateValue::MemoryAfterFp { offset } => offset.naive_eval().unwrap(),
value::IntermediateValue::Fp => F::ZERO,
value::IntermediateValue::Constant(_) => unimplemented!(),
};

let t = max.naive_eval().unwrap();
let aux_i = compiler.stack_size;
let aux_j = compiler.stack_size + 1;
let aux_k = compiler.stack_size + 2;

// Step 1: DEREF: m[fp + i] == m[m[fp + x]]
// DEREF: m[fp + i] == m[value]

let step_1 = IntermediateInstruction::Deref {
shift_0: ConstExpression::scalar(x.to_usize()),
shift_1: ConstExpression::from(0),
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp {
offset: aux_i.into(),
},
for_range_check: true,
};

// Step 2: ADD: m[fp + x] + m[fp + j] == (t-1)
// m[fp + j] == t - 1 - m[fp + x]
// Uses constraint solving to store t - 1 - m[fp + x] in m[fp + j]
let step_2 = IntermediateInstruction::Computation {
operation: Operation::Add,
arg_a: IntermediateValue::MemoryAfterFp {
offset: x.to_usize().into(),
},
arg_c: IntermediateValue::MemoryAfterFp {
offset: aux_j.into(),
},
res: IntermediateValue::Constant((t - F::ONE).to_usize().into()),
};

// Step 3: DEREF: m[fp + k] == m[m[fp + j]]
let step_3 = IntermediateInstruction::Deref {
shift_0: ConstExpression::scalar(aux_j),
shift_1: ConstExpression::from(0),
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp {
offset: aux_k.into(),
},
for_range_check: true,
};

// Insert the instructions
instructions.extend_from_slice(&[
// This is just the RangeCheck hint which does nothing
IntermediateInstruction::RangeCheck {
value: IntermediateValue::from_simple_expr(value, compiler),
max: max.clone(),
},
// These are the steps that effectuate the range check
step_1,
step_2,
step_3,
]);

// Increase the stack size by 3 as we used 3 aux variables
compiler.stack_size += 3;
}
}
}

Expand Down Expand Up @@ -715,11 +783,13 @@ fn setup_function_call(
res: IntermediaryMemOrFpOrConstant::Constant(ConstExpression::label(
return_label.clone(),
)),
for_range_check: false,
},
IntermediateInstruction::Deref {
shift_0: new_fp_pos.into(),
shift_1: ConstExpression::one(),
res: IntermediaryMemOrFpOrConstant::Fp,
for_range_check: false,
},
];

Expand All @@ -729,6 +799,7 @@ fn setup_function_call(
shift_0: new_fp_pos.into(),
shift_1: (2 + i).into(),
res: arg.to_mem_after_fp_or_constant(compiler),
for_range_check: false,
});
}

Expand Down Expand Up @@ -837,7 +908,8 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet<Var> {
| SimpleLine::Print { .. }
| SimpleLine::FunctionRet { .. }
| SimpleLine::Precompile { .. }
| SimpleLine::LocationReport { .. } => {}
| SimpleLine::LocationReport { .. }
| SimpleLine::RangeCheck { .. } => {}
}
}
internal_vars
Expand Down
14 changes: 12 additions & 2 deletions crates/lean_compiler/src/c_compile_final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ impl IntermediateInstruction {
| Self::DecomposeCustom { .. }
| Self::CounterHint { .. }
| Self::Inverse { .. }
| Self::LocationReport { .. } => true,
| Self::LocationReport { .. }
| Self::RangeCheck { .. } => true,
Self::Computation { .. }
| Self::Panic
| Self::Deref { .. }
Expand Down Expand Up @@ -264,6 +265,7 @@ fn compile_block(
shift_0,
shift_1,
res,
for_range_check,
} => {
low_level_bytecode.push(Instruction::Deref {
shift_0: eval_const_expression(&shift_0, compiler).to_usize(),
Expand All @@ -279,6 +281,7 @@ fn compile_block(
MemOrFpOrConstant::Constant(eval_const_expression(&c, compiler))
}
},
for_range_check,
});
}
IntermediateInstruction::JumpIfNotZero {
Expand Down Expand Up @@ -407,8 +410,15 @@ fn compile_block(
let hint = Hint::LocationReport { location };
hints.entry(pc).or_default().push(hint);
}
IntermediateInstruction::RangeCheck { value, max } => {
let hint = Hint::RangeCheck {
value: value.try_into_mem_or_fp(compiler).unwrap(),
// TODO: support max being an IntermediateValue
max: MemOrConstant::Constant(eval_const_expression(&max, compiler)),
};
hints.entry(pc).or_default().push(hint);
}
}

if !instruction.is_hint() {
pc += 1;
}
Expand Down
7 changes: 5 additions & 2 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ statement = {
return_statement |
break_statement |
continue_statement |
function_call |
range_check_statement |
assert_eq_statement |
assert_not_eq_statement
assert_not_eq_statement |
function_call // Placed at the end so that it doesn't override other statements like range_check and assert_eq
}

return_statement = { "return" ~ (tuple_expression)? ~ ";" }
Expand Down Expand Up @@ -61,6 +62,8 @@ var_list = { identifier ~ ("," ~ identifier)* }
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }

range_check_statement = { "range_check" ~ "(" ~ expression ~ "," ~ expression ~ ")" ~ ";" }

// Expressions
tuple_expression = { expression ~ ("," ~ expression)* }
expression = { neq_expr }
Expand Down
14 changes: 13 additions & 1 deletion crates/lean_compiler/src/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub enum IntermediateInstruction {
shift_0: ConstExpression,
shift_1: ConstExpression,
res: IntermediaryMemOrFpOrConstant,
for_range_check: bool,
}, // res = m[m[fp + shift_0]]
Panic,
Jump {
Expand Down Expand Up @@ -86,6 +87,10 @@ pub enum IntermediateInstruction {
LocationReport {
location: SourceLineNumber,
},
RangeCheck {
value: IntermediateValue,
max: ConstExpression,
},
}

impl IntermediateInstruction {
Expand Down Expand Up @@ -152,7 +157,11 @@ impl Display for IntermediateInstruction {
shift_0,
shift_1,
res,
} => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"),
for_range_check,
} => write!(
f,
"{res} = m[m[fp + {shift_0}] + {shift_1}] for_range_check: {for_range_check}"
),
Self::Panic => write!(f, "panic"),
Self::Jump { dest, updated_fp } => {
if let Some(fp) = updated_fp {
Expand Down Expand Up @@ -256,6 +265,9 @@ impl Display for IntermediateInstruction {
Ok(())
}
Self::LocationReport { .. } => Ok(()),
Self::RangeCheck { value, max } => {
write!(f, "range_check({value}, {max})")
}
}
}
}
5 changes: 5 additions & 0 deletions crates/lean_compiler/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ pub enum Line {
LocationReport {
location: SourceLineNumber,
},
RangeCheck {
value: Expression,
max: ConstExpression,
},
}
impl Display for Expression {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -591,6 +595,7 @@ impl Line {
}
Self::Break => "break".to_string(),
Self::Panic => "panic".to_string(),
Self::RangeCheck { value, max } => format!("range_check({value}, {max})"),
};
format!("{spaces}{line_str}")
}
Expand Down
22 changes: 21 additions & 1 deletion crates/lean_compiler/src/parser/parsers/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::literal::ConstExprParser;
use super::{Parse, ParseContext, next_inner_pair};
use crate::{
ir::HighLevelOperation,
lang::{AssumeBoolean, Boolean, Condition, Expression, Line},
lang::{AssumeBoolean, Boolean, Condition, ConstExpression, Expression, Line},
parser::{
error::{ParseResult, SemanticError},
grammar::{ParsePair, Rule},
Expand All @@ -28,6 +28,7 @@ impl Parse<Line> for StatementParser {
Rule::function_call => FunctionCallParser::parse(inner, ctx),
Rule::assert_eq_statement => AssertEqParser::parse(inner, ctx),
Rule::assert_not_eq_statement => AssertNotEqParser::parse(inner, ctx),
Rule::range_check_statement => RangeCheckParser::parse(inner, ctx),
Rule::break_statement => Ok(Line::Break),
Rule::continue_statement => {
Err(SemanticError::new("Continue statement not implemented yet").into())
Expand Down Expand Up @@ -327,3 +328,22 @@ impl Parse<Line> for AssertNotEqParser {
))
}
}

/// Parser for range check statements.
pub struct RangeCheckParser;

impl Parse<Line> for RangeCheckParser {
fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<Line> {
let mut inner = pair.into_inner();
let value =
ExpressionParser::parse(next_inner_pair(&mut inner, "range check value")?, ctx)?;
let max_expr =
ExpressionParser::parse(next_inner_pair(&mut inner, "range check max")?, ctx)?;

// Convert the max expression to a const expression
let max = ConstExpression::try_from(max_expr)
.map_err(|_| SemanticError::new("Range check maximum must be a constant expression"))?;

Ok(Line::RangeCheck { value, max })
}
}
Loading
Loading