Skip to content

Commit bb013ec

Browse files
committed
range check implemented in compiler; tests WIP
1 parent 94ce7d5 commit bb013ec

File tree

15 files changed

+372
-19
lines changed

15 files changed

+372
-19
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ pub enum SimpleLine {
154154
LocationReport {
155155
location: SourceLineNumber,
156156
},
157+
RangeCheck {
158+
value: SimpleExpr,
159+
max: ConstExpression,
160+
},
157161
}
158162

159163
pub fn simplify_program(mut program: Program) -> SimpleProgram {
@@ -738,6 +742,14 @@ fn simplify_lines(
738742
location: *location,
739743
});
740744
}
745+
Line::RangeCheck { value, max } => {
746+
let simplified_value =
747+
simplify_expr(value, &mut res, counters, array_manager, const_malloc);
748+
res.push(SimpleLine::RangeCheck {
749+
value: simplified_value,
750+
max: max.clone(),
751+
});
752+
}
741753
}
742754
}
743755

@@ -958,7 +970,7 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
958970
on_new_expr(index, &internal_vars, &mut external_vars);
959971
on_new_expr(value, &internal_vars, &mut external_vars);
960972
}
961-
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
973+
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
962974
}
963975
}
964976

@@ -1127,7 +1139,7 @@ pub fn inline_lines(
11271139
inline_expr(index, args, inlining_count);
11281140
inline_expr(value, args, inlining_count);
11291141
}
1130-
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
1142+
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
11311143
}
11321144
}
11331145
for (i, new_lines) in lines_to_replace.into_iter().rev() {
@@ -1612,7 +1624,7 @@ fn replace_vars_for_unroll(
16121624
Line::CounterHint { var } => {
16131625
*var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}");
16141626
}
1615-
Line::Break | Line::Panic | Line::LocationReport { .. } => {}
1627+
Line::Break | Line::Panic | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
16161628
}
16171629
}
16181630
}
@@ -2002,7 +2014,8 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec<String>) {
20022014
| Line::MAlloc { .. }
20032015
| Line::Panic
20042016
| Line::Break
2005-
| Line::LocationReport { .. } => {}
2017+
| Line::LocationReport { .. }
2018+
| Line::RangeCheck { .. } => {}
20062019
}
20072020
}
20082021
}
@@ -2110,7 +2123,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
21102123
assert!(!map.contains_key(var), "Variable {var} is a constant");
21112124
replace_vars_by_const_in_expr(size, map);
21122125
}
2113-
Line::Panic | Line::Break | Line::LocationReport { .. } => {}
2126+
Line::Panic | Line::Break | Line::LocationReport { .. } | Line::RangeCheck { .. } => {}
21142127
}
21152128
}
21162129
}
@@ -2305,6 +2318,9 @@ impl SimpleLine {
23052318
}
23062319
Self::Panic => "panic".to_string(),
23072320
Self::LocationReport { .. } => Default::default(),
2321+
Self::RangeCheck { value, max } => {
2322+
format!("range_check({value}, {max})")
2323+
}
23082324
};
23092325
format!("{spaces}{line_str}")
23102326
}

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{F, a_simplify_lang::*, ir::*, lang::*, precompiles::*};
22
use lean_vm::*;
3-
use p3_field::Field;
3+
use p3_field::{Field, PrimeCharacteristicRing};
44
use std::{
55
borrow::Borrow,
66
collections::{BTreeMap, BTreeSet},
@@ -395,6 +395,7 @@ fn compile_lines(
395395
shift_0,
396396
shift_1: shift.clone(),
397397
res: res.to_mem_after_fp_or_constant(compiler),
398+
for_range_check: false,
398399
});
399400
}
400401

@@ -432,6 +433,7 @@ fn compile_lines(
432433
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp {
433434
offset: compiler.get_offset(&ret_var.clone().into()),
434435
},
436+
for_range_check: false,
435437
});
436438
}
437439

@@ -612,6 +614,65 @@ fn compile_lines(
612614
location: *location,
613615
});
614616
}
617+
SimpleLine::RangeCheck { value, max } => {
618+
let x = match IntermediateValue::from_simple_expr(value, compiler) {
619+
IntermediateValue::MemoryAfterFp { offset } => offset.naive_eval().unwrap(),
620+
value::IntermediateValue::Fp => F::ZERO,
621+
value::IntermediateValue::Constant(_) => unimplemented!(),
622+
};
623+
624+
let t = max.naive_eval().unwrap();
625+
let aux_i: usize = compiler.stack_size + 0;
626+
let aux_j: usize = compiler.stack_size + 1;
627+
let aux_k: usize = compiler.stack_size + 2;
628+
629+
// Step 1: DEREF: m[fp + i] == m[m[fp + x]]
630+
// DEREF: m[fp + i] == m[value]
631+
632+
let step_1 = IntermediateInstruction::Deref {
633+
shift_0: ConstExpression::scalar(x.to_usize()),
634+
shift_1: ConstExpression::from(0),
635+
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: aux_i.into() },
636+
for_range_check: true,
637+
};
638+
639+
640+
// Step 2: ADD: m[m[fp + x]] + m[fp + j] == (t-1)
641+
// m[fp + j] == t - 1 - value
642+
//
643+
// m[fp + j] == t - 1 - m[fp + x]
644+
let q = t - F::ONE;
645+
let step_2 = IntermediateInstruction::Computation {
646+
operation: Operation::Add,
647+
arg_a: IntermediateValue::MemoryAfterFp { offset: x.to_usize().into() },
648+
arg_c: IntermediateValue::MemoryAfterFp { offset: aux_j.into() }, // solve
649+
res: IntermediateValue::Constant(q.to_usize().into()), // t - 1
650+
};
651+
652+
// Step 3: DEREF: m[fp + k] == m[m[fp + j]]
653+
let step_3 = IntermediateInstruction::Deref {
654+
shift_0: ConstExpression::scalar(aux_j),
655+
shift_1: ConstExpression::from(0),
656+
res: IntermediaryMemOrFpOrConstant::MemoryAfterFp { offset: aux_k.into() },
657+
for_range_check: true,
658+
};
659+
660+
661+
// TODO: handle undefined memory access error
662+
663+
//println!("aux_i: {}; {:?}", aux_i, step_1);
664+
//println!("aux_j: {}; {:?}", aux_j, step_2);
665+
//println!("aux_k: {}; {:?}", aux_k, step_3);
666+
667+
instructions.push(IntermediateInstruction::RangeCheck {
668+
value: IntermediateValue::from_simple_expr(value, compiler),
669+
max: max.clone(),
670+
});
671+
instructions.push(step_1);
672+
instructions.push(step_2);
673+
instructions.push(step_3);
674+
compiler.stack_size += 3;
675+
}
615676
}
616677
}
617678

@@ -689,11 +750,13 @@ fn setup_function_call(
689750
res: IntermediaryMemOrFpOrConstant::Constant(ConstExpression::label(
690751
return_label.clone(),
691752
)),
753+
for_range_check: false,
692754
},
693755
IntermediateInstruction::Deref {
694756
shift_0: new_fp_pos.into(),
695757
shift_1: ConstExpression::one(),
696758
res: IntermediaryMemOrFpOrConstant::Fp,
759+
for_range_check: false,
697760
},
698761
];
699762

@@ -703,6 +766,7 @@ fn setup_function_call(
703766
shift_0: new_fp_pos.into(),
704767
shift_1: (2 + i).into(),
705768
res: arg.to_mem_after_fp_or_constant(compiler),
769+
for_range_check: false,
706770
});
707771
}
708772

@@ -811,7 +875,8 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet<Var> {
811875
| SimpleLine::Print { .. }
812876
| SimpleLine::FunctionRet { .. }
813877
| SimpleLine::Precompile { .. }
814-
| SimpleLine::LocationReport { .. } => {}
878+
| SimpleLine::LocationReport { .. }
879+
| SimpleLine::RangeCheck { .. } => {}
815880
}
816881
}
817882
internal_vars

crates/lean_compiler/src/c_compile_final.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ impl IntermediateInstruction {
1313
| Self::DecomposeCustom { .. }
1414
| Self::CounterHint { .. }
1515
| Self::Inverse { .. }
16-
| Self::LocationReport { .. } => true,
16+
| Self::LocationReport { .. }
17+
| Self::RangeCheck { .. } => true,
1718
Self::Computation { .. }
1819
| Self::Panic
1920
| Self::Deref { .. }
@@ -22,7 +23,7 @@ impl IntermediateInstruction {
2223
| Self::Poseidon2_16 { .. }
2324
| Self::Poseidon2_24 { .. }
2425
| Self::DotProduct { .. }
25-
| Self::MultilinearEval { .. } => false,
26+
| Self::MultilinearEval { .. } => false
2627
}
2728
}
2829
}
@@ -238,6 +239,7 @@ fn compile_block(
238239
shift_0,
239240
shift_1,
240241
res,
242+
for_range_check,
241243
} => {
242244
low_level_bytecode.push(Instruction::Deref {
243245
shift_0: eval_const_expression(&shift_0, compiler).to_usize(),
@@ -253,6 +255,7 @@ fn compile_block(
253255
MemOrFpOrConstant::Constant(eval_const_expression(&c, compiler))
254256
}
255257
},
258+
for_range_check,
256259
});
257260
}
258261
IntermediateInstruction::JumpIfNotZero {
@@ -380,8 +383,15 @@ fn compile_block(
380383
let hint = Hint::LocationReport { location };
381384
hints.entry(pc).or_default().push(hint);
382385
}
386+
IntermediateInstruction::RangeCheck { value, max } => {
387+
let hint = Hint::RangeCheck {
388+
value: value.try_into_mem_or_fp(compiler).unwrap(),
389+
// TODO: support max being an IntermediateValue
390+
max: MemOrConstant::Constant(eval_const_expression(&max, compiler)),
391+
};
392+
hints.entry(pc).or_default().push(hint);
393+
}
383394
}
384-
385395
if !instruction.is_hint() {
386396
pc += 1;
387397
}

crates/lean_compiler/src/grammar.pest

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ statement = {
2424
return_statement |
2525
break_statement |
2626
continue_statement |
27-
function_call |
27+
range_check_statement |
2828
assert_eq_statement |
29-
assert_not_eq_statement
29+
assert_not_eq_statement |
30+
function_call // Placed at the end so that it doesn't override other statements like range_check and assert_eq
3031
}
3132

3233
return_statement = { "return" ~ (tuple_expression)? ~ ";" }
@@ -61,6 +62,8 @@ var_list = { identifier ~ ("," ~ identifier)* }
6162
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
6263
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
6364

65+
range_check_statement = { "range_check" ~ "(" ~ expression ~ "," ~ expression ~ ")" ~ ";" }
66+
6467
// Expressions
6568
tuple_expression = { expression ~ ("," ~ expression)* }
6669
expression = { neq_expr }

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub enum IntermediateInstruction {
1717
shift_0: ConstExpression,
1818
shift_1: ConstExpression,
1919
res: IntermediaryMemOrFpOrConstant,
20+
for_range_check: bool,
2021
}, // res = m[m[fp + shift_0]]
2122
Panic,
2223
Jump {
@@ -86,6 +87,10 @@ pub enum IntermediateInstruction {
8687
LocationReport {
8788
location: SourceLineNumber,
8889
},
90+
RangeCheck {
91+
value: IntermediateValue,
92+
max: ConstExpression,
93+
},
8994
}
9095

9196
impl IntermediateInstruction {
@@ -152,7 +157,8 @@ impl Display for IntermediateInstruction {
152157
shift_0,
153158
shift_1,
154159
res,
155-
} => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]"),
160+
for_range_check,
161+
} => write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}] for_range_check: {for_range_check}"),
156162
Self::Panic => write!(f, "panic"),
157163
Self::Jump { dest, updated_fp } => {
158164
if let Some(fp) = updated_fp {
@@ -256,6 +262,9 @@ impl Display for IntermediateInstruction {
256262
Ok(())
257263
}
258264
Self::LocationReport { .. } => Ok(()),
265+
Self::RangeCheck { value, max } => {
266+
write!(f, "range_check({value}, {max})")
267+
}
259268
}
260269
}
261270
}

crates/lean_compiler/src/lang.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,10 @@ pub enum Line {
390390
LocationReport {
391391
location: SourceLineNumber,
392392
},
393+
RangeCheck {
394+
value: Expression,
395+
max: ConstExpression,
396+
},
393397
}
394398
impl Display for Expression {
395399
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@@ -585,6 +589,7 @@ impl Line {
585589
}
586590
Self::Break => "break".to_string(),
587591
Self::Panic => "panic".to_string(),
592+
Self::RangeCheck { value, max } => format!("range_check({value}, {max})"),
588593
};
589594
format!("{spaces}{line_str}")
590595
}

crates/lean_compiler/src/parser/parsers/statement.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use super::literal::ConstExprParser;
44
use super::{Parse, ParseContext, next_inner_pair};
55
use crate::{
66
ir::HighLevelOperation,
7-
lang::{AssumeBoolean, Boolean, Condition, Expression, Line},
7+
lang::{AssumeBoolean, Boolean, Condition, ConstExpression, Expression, Line},
88
parser::{
99
error::{ParseResult, SemanticError},
1010
grammar::{ParsePair, Rule},
@@ -28,6 +28,7 @@ impl Parse<Line> for StatementParser {
2828
Rule::function_call => FunctionCallParser::parse(inner, ctx),
2929
Rule::assert_eq_statement => AssertEqParser::parse(inner, ctx),
3030
Rule::assert_not_eq_statement => AssertNotEqParser::parse(inner, ctx),
31+
Rule::range_check_statement => RangeCheckParser::parse(inner, ctx),
3132
Rule::break_statement => Ok(Line::Break),
3233
Rule::continue_statement => {
3334
Err(SemanticError::new("Continue statement not implemented yet").into())
@@ -318,3 +319,22 @@ impl Parse<Line> for AssertNotEqParser {
318319
Ok(Line::Assert(Boolean::Different { left, right }))
319320
}
320321
}
322+
323+
/// Parser for range check statements.
324+
pub struct RangeCheckParser;
325+
326+
impl Parse<Line> for RangeCheckParser {
327+
fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<Line> {
328+
let mut inner = pair.into_inner();
329+
let value =
330+
ExpressionParser::parse(next_inner_pair(&mut inner, "range check value")?, ctx)?;
331+
let max_expr =
332+
ExpressionParser::parse(next_inner_pair(&mut inner, "range check max")?, ctx)?;
333+
334+
// Convert the max expression to a const expression
335+
let max = ConstExpression::try_from(max_expr)
336+
.map_err(|_| SemanticError::new("Range check maximum must be a constant expression"))?;
337+
338+
Ok(Line::RangeCheck { value, max })
339+
}
340+
}

0 commit comments

Comments
 (0)