Skip to content

Commit c448cd3

Browse files
committed
completed range check impl
1 parent 08fea21 commit c448cd3

File tree

9 files changed

+108
-135
lines changed

9 files changed

+108
-135
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,8 @@ fn compile_lines(
615615
});
616616
}
617617
SimpleLine::RangeCheck { value, max } => {
618+
// x is the fp offset of the memory cell which contains the value
619+
// i.e. m[fp + x] contains value
618620
let x = match IntermediateValue::from_simple_expr(value, compiler) {
619621
IntermediateValue::MemoryAfterFp { offset } => offset.naive_eval().unwrap(),
620622
value::IntermediateValue::Fp => F::ZERO,
@@ -636,17 +638,14 @@ fn compile_lines(
636638
for_range_check: true,
637639
};
638640

639-
640-
// Step 2: ADD: m[m[fp + x]] + m[fp + j] == (t-1)
641-
// m[fp + j] == t - 1 - value
642-
//
641+
// Step 2: ADD: m[fp + x] + m[fp + j] == (t-1)
643642
// m[fp + j] == t - 1 - m[fp + x]
644-
let q = t - F::ONE;
643+
// Uses constraint solving to store t - 1 - m[fp + x] in m[fp + j]
645644
let step_2 = IntermediateInstruction::Computation {
646645
operation: Operation::Add,
647646
arg_a: IntermediateValue::MemoryAfterFp { offset: x.to_usize().into() },
648-
arg_c: IntermediateValue::MemoryAfterFp { offset: aux_j.into() }, // solve
649-
res: IntermediateValue::Constant(q.to_usize().into()), // t - 1
647+
arg_c: IntermediateValue::MemoryAfterFp { offset: aux_j.into() },
648+
res: IntermediateValue::Constant((t - F::ONE).to_usize().into()),
650649
};
651650

652651
// Step 3: DEREF: m[fp + k] == m[m[fp + j]]
@@ -657,20 +656,22 @@ fn compile_lines(
657656
for_range_check: true,
658657
};
659658

659+
// Insert the instructions
660+
instructions.extend_from_slice(
661+
&[
662+
// This is just the RangeCheck hint which does nothing
663+
IntermediateInstruction::RangeCheck {
664+
value: IntermediateValue::from_simple_expr(value, compiler),
665+
max: max.clone(),
666+
},
667+
// These are the steps that effectuate the range check
668+
step_1,
669+
step_2,
670+
step_3,
671+
]
672+
);
660673

661-
// TODO: handle undefined memory access error
662-
663-
//println!("aux_i: {}; {:?}", aux_i, step_1);
664-
//println!("aux_j: {}; {:?}", aux_j, step_2);
665-
//println!("aux_k: {}; {:?}", aux_k, step_3);
666-
667-
instructions.push(IntermediateInstruction::RangeCheck {
668-
value: IntermediateValue::from_simple_expr(value, compiler),
669-
max: max.clone(),
670-
});
671-
instructions.push(step_1);
672-
instructions.push(step_2);
673-
instructions.push(step_3);
674+
// Increase the stack size by 3 as we used 3 aux variables
674675
compiler.stack_size += 3;
675676
}
676677
}

crates/lean_prover/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ witness_generation.workspace = true
3333
vm_air.workspace = true
3434
multilinear-toolkit.workspace = true
3535
poseidon_circuit.workspace = true
36+
thiserror.workspace = true
3637

3738
[dev-dependencies]
3839
xmss.workspace = true

crates/lean_prover/src/prove_execution.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,22 @@ use whir_p3::{
2222
};
2323
use xmss::{Poseidon16History, Poseidon24History};
2424

25+
use thiserror::Error;
26+
27+
#[derive(Error, Debug)]
28+
pub enum ProveExecutionError {
29+
#[error("MemoryError")]
30+
MemoryError,
31+
}
32+
2533
pub fn prove_execution(
2634
bytecode: &Bytecode,
2735
(public_input, private_input): (&[F], &[F]),
2836
whir_config_builder: WhirConfigBuilder,
2937
no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory
3038
vm_profiler: bool,
3139
(poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History),
32-
) -> (Vec<PF<EF>>, usize, String) {
40+
) -> Result<(Vec<PF<EF>>, usize, String), ProveExecutionError> {
3341
let mut exec_summary = String::new();
3442
let ExecutionTrace {
3543
full_trace,
@@ -843,9 +851,14 @@ pub fn prove_execution(
843851
&mut base_memory_poly_eq_point,
844852
memory_poly_eq_point_alpha.square(),
845853
);
854+
855+
let memory_len = memory.len();
856+
if base_memory_indexes.iter().max().unwrap().to_usize() >= memory_len {
857+
return Err(ProveExecutionError::MemoryError);
858+
}
846859
let base_memory_pushforward = compute_pushforward(
847860
&base_memory_indexes,
848-
memory.len(),
861+
memory_len,
849862
&base_memory_poly_eq_point,
850863
);
851864

@@ -1188,9 +1201,9 @@ pub fn prove_execution(
11881201
&packed_pcs_witness_extension.packed_polynomial.by_ref(),
11891202
);
11901203

1191-
(
1204+
Ok((
11921205
prover_state.proof_data().to_vec(),
11931206
prover_state.proof_size(),
11941207
exec_summary,
1195-
)
1208+
))
11961209
}

crates/lean_prover/tests/hash_chain.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ fn benchmark_poseidon_chain() {
8181
false,
8282
(&vec![], &vec![]), // TODO poseidons precomputed
8383
)
84+
.unwrap()
8485
.0;
8586
let vm_time = time.elapsed();
8687
verify_execution(&bytecode, &public_input, proof_data, whir_config_builder()).unwrap();
Lines changed: 62 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,43 @@
1-
use lean_compiler::*;
2-
use lean_prover::{
3-
prove_execution::prove_execution, whir_config_builder,
4-
};
5-
use lean_vm::*;
1+
use lean_compiler::compile_program;
2+
use lean_vm::{F, DIMENSION, PUBLIC_INPUT_START};
3+
use lean_prover::{whir_config_builder, prove_execution::prove_execution};
4+
use whir_p3::WhirConfigBuilder;
65
use p3_field::PrimeCharacteristicRing;
76
use std::collections::BTreeSet;
87
use rand::Rng;
98
use rand_chacha::ChaCha20Rng;
109
use rand_chacha::rand_core::SeedableRng;
11-
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
1210

1311
const NO_VEC_RUNTIME_MEMORY: usize = 1 << 20;
1412

15-
fn critical_test_cases() -> (BTreeSet<(usize, usize)>, BTreeSet<(usize, usize)>) {
16-
let mut happy_test_cases = BTreeSet::<(usize, usize)>::new();
17-
let mut sad_test_cases = BTreeSet::<(usize, usize)>::new();
18-
19-
for t in 69..70 {
20-
for v in 0..t {
21-
if v < t {
22-
happy_test_cases.insert((v, t));
23-
} else {
24-
sad_test_cases.insert((v, t));
25-
}
26-
}
27-
//for v in 16777215..16777300 {
28-
//if v < t {
29-
//happy_test_cases.insert((v, t));
30-
//} else {
31-
//sad_test_cases.insert((v, t));
32-
//}
33-
//}
34-
}
13+
fn range_check_program(value: usize, max: usize) -> String {
14+
let program = format!(r#"
15+
fn func() {{
16+
x = 1;
17+
y = {value};
18+
value = x * y;
19+
range_check(value, {max});
20+
return;
21+
}}
22+
23+
fn main() {{
24+
x = 1;
25+
y = {value};
26+
value = x * y;
27+
range_check(value, {max});
3528
36-
(happy_test_cases, sad_test_cases)
29+
func();
30+
31+
if 0 == 0 {{
32+
a = 1;
33+
b = {value};
34+
c = a * b;
35+
range_check(c, {max});
36+
}}
37+
return;
38+
}}
39+
"#);
40+
program.to_string()
3741
}
3842

3943
fn random_test_cases(num_test_cases: usize) -> BTreeSet<(usize, usize)> {
@@ -62,43 +66,7 @@ fn random_test_cases(num_test_cases: usize) -> BTreeSet<(usize, usize)> {
6266
test_cases
6367
}
6468

65-
fn range_check_program(value: usize, max: usize) -> String {
66-
let program = format!(r#"
67-
const DIM = 5;
68-
const SECOND_POINT = 2;
69-
const SECOND_N_VARS = 7;
70-
const COMPRESSION = 1;
71-
const PERMUTATION = 0;
72-
73-
fn main() {{
74-
x = 1;
75-
y = {value};
76-
value = x * y;
77-
range_check(value, {max});
78-
79-
// Need to add the following to avoid a "TODO small GKR, no packing" error
80-
81-
for i in 10..50 {{
82-
x = malloc_vec(6);
83-
poseidon16(i + 3, i, x, PERMUTATION);
84-
poseidon24(i + 3, i, x + 2);
85-
dot_product(i*2, i, (x + 3) * 8, 1);
86-
dot_product(i*3, i + 7, (x + 4) * 8, 2);
87-
}}
88-
89-
for i in 0..1000 {{
90-
assert i != 1000;
91-
}}
92-
93-
return;
94-
}}
95-
"#);
96-
program.to_string()
97-
}
98-
99-
fn do_test_range_check(v: usize, t: usize) {
100-
let program_str = range_check_program(v, t);
101-
69+
fn prepare_inputs() -> (Vec<F>, Vec<F>) {
10270
const SECOND_POINT: usize = 2;
10371
const SECOND_N_VARS: usize = 7;
10472

@@ -117,61 +85,48 @@ fn do_test_range_check(v: usize, t: usize) {
11785
let private_input = (0..1 << 13)
11886
.map(|i| F::from_usize(i).square())
11987
.collect::<Vec<_>>();
88+
89+
(public_input, private_input)
90+
}
91+
92+
fn do_test_range_check(
93+
v: usize,
94+
t: usize,
95+
whir_config_builder: &WhirConfigBuilder,
96+
public_input: &Vec<F>,
97+
private_input: &Vec<F>
98+
) {
99+
let program_str = range_check_program(v, t);
120100

121101
let bytecode = compile_program(program_str);
122-
let _proof_data = prove_execution(
102+
103+
let r = prove_execution(
123104
&bytecode,
124-
(&public_input, &private_input),
125-
whir_config_builder(),
105+
(public_input, private_input),
106+
whir_config_builder.clone(),
126107
NO_VEC_RUNTIME_MEMORY,
127108
false,
128109
(&vec![], &vec![]),
129110
);
130-
}
131-
132-
#[test]
133-
fn test_prove_range_check_happy() {
134-
let (happy_test_cases, _sad_test_cases) = critical_test_cases();
135-
println!("Running {} test cases:", happy_test_cases.len());
136-
//happy_test_cases.par_iter().for_each(|(v, t)| {
137-
//do_test_range_check(*v, *t);
138-
//});
139-
for (v, t) in happy_test_cases {
140-
do_test_range_check(v, t);
111+
112+
if v < t {
113+
assert!(r.is_ok(), "Proof generation should work for v < t");
114+
} else {
115+
assert!(r.is_err(), "Proof generation should fail for v >= t");
141116
}
142-
}
143-
144-
//#[test]
145-
//fn test_prove_range_check_sad() {
146-
//let (_happy_test_cases, sad_test_cases) = critical_test_cases();
147-
//for (v, t) in sad_test_cases {
148-
//let result = std::panic::catch_unwind(|| {
149-
//do_test_range_check(v, t);
150-
//});
151-
//assert!(result.is_err(), "Expected panic for test case v={}, t={}", v, t);
152-
//}
153-
//}
154117

155-
#[test]
156-
#[should_panic]
157-
fn test_prove_range_check_sad_1() {
158-
do_test_range_check(0, 0);
159118
}
160119

161120
#[test]
162-
#[should_panic]
163-
fn test_prove_range_check_sad_2() {
164-
do_test_range_check(1, 0);
165-
}
121+
fn test_range_check() {
122+
let (public_input, private_input) = prepare_inputs();
123+
let whir_config_builder = whir_config_builder();
166124

167-
#[test]
168-
#[should_panic]
169-
fn test_prove_range_check_sad_3() {
170-
do_test_range_check(2, 1);
171-
}
125+
let test_cases = random_test_cases(500);
172126

173-
#[test]
174-
#[should_panic]
175-
fn test_prove_range_check_sad_4() {
176-
do_test_range_check(69, 65);
127+
println!("Running {} random test cases", test_cases.len());
128+
129+
for (v, t) in test_cases {
130+
do_test_range_check(v, t, &whir_config_builder, &public_input, &private_input);
131+
}
177132
}

crates/lean_prover/tests/test_zkvm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ fn test_zk_vm_helper(program_str: &str) {
109109
false,
110110
(&vec![], &vec![]),
111111
)
112+
.unwrap()
112113
.0;
113114
verify_execution(&bytecode, &public_input, proof_data, whir_config_builder()).unwrap();
114115
}

crates/lean_prover/witness_generation/src/execution_trace.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ pub fn get_execution_trace(
6363
// flag_a == 0
6464
addr_a = F::from_usize(fp) + field_repr[0]; // fp + operand_a
6565
}
66-
let value_a = memory.0[addr_a.to_usize()].unwrap();
66+
let value_a = memory.get(addr_a.to_usize()).unwrap_or(F::ZERO);
6767
let mut addr_b = F::ZERO;
6868
if field_repr[4].is_zero() {
6969
// flag_b == 0
7070
addr_b = F::from_usize(fp) + field_repr[1]; // fp + operand_b
7171
}
72-
let value_b = memory.0[addr_b.to_usize()].unwrap();
72+
let value_b = memory.get(addr_b.to_usize()).unwrap_or(F::ZERO);
7373

7474
let mut addr_c = F::ZERO;
7575
if field_repr[5].is_zero() {
@@ -80,7 +80,7 @@ pub fn get_execution_trace(
8080
assert_eq!(field_repr[2], operand_c); // debug purpose
8181
addr_c = value_a + operand_c;
8282
}
83-
let value_c = memory.0[addr_c.to_usize()].unwrap();
83+
let value_c = memory.get(addr_c.to_usize()).unwrap_or(F::ZERO);
8484

8585
for (j, field) in field_repr.iter().enumerate() {
8686
*trace_row[j] = *field;

crates/utils/src/misc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub fn transpose<F: Copy + Send + Sync>(
111111
width: usize,
112112
column_extra_capacity: usize,
113113
) -> Vec<Vec<F>> {
114-
assert!((matrix.len().is_multiple_of(width)));
114+
assert!(matrix.len().is_multiple_of(width));
115115
let height = matrix.len() / width;
116116
let res = vec![
117117
{

0 commit comments

Comments
 (0)