Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 245 additions & 25 deletions crates/lean_compiler/src/a_simplify_lang.rs

Large diffs are not rendered by default.

399 changes: 183 additions & 216 deletions crates/lean_compiler/src/b_compile_intermediate.rs

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ return_count = { "->" ~ number }

// Statements
statement = {
forward_declaration |
single_assignment |
array_assign |
if_statement |
Expand All @@ -34,6 +35,8 @@ return_statement = { "return" ~ (tuple_expression)? ~ ";" }
break_statement = { "break" ~ ";" }
continue_statement = { "continue" ~ ";" }

forward_declaration = { "var" ~ identifier ~ ";" }

single_assignment = { identifier ~ "=" ~ expression ~ ";" }

array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" }
Expand Down
32 changes: 31 additions & 1 deletion crates/lean_compiler/src/lang.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use lean_vm::*;
use multilinear_toolkit::prelude::*;
use p3_util::log2_ceil_usize;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use utils::ToUsize;

Expand Down Expand Up @@ -307,6 +307,9 @@ pub enum Line {
value: Expression,
arms: Vec<(usize, Vec<Self>)>,
},
ForwardDeclaration {
var: Var,
},
Assignment {
var: Var,
value: Expression,
Expand Down Expand Up @@ -378,6 +381,30 @@ pub enum Line {
location: SourceLineNumber,
},
}

/// A context specifying which variables are in scope.
pub struct Context {
/// A list of lexical scopes, innermost scope last.
pub scopes: Vec<Scope>,
}

impl Context {
pub fn defines(&self, var: &Var) -> bool {
for scope in self.scopes.iter() {
if scope.vars.contains(var) {
return true;
}
}
false
}
}

#[derive(Default)]
pub struct Scope {
/// A set of declared variables.
pub vars: BTreeSet<Var>,
}

impl Display for Expression {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down Expand Up @@ -418,6 +445,9 @@ impl Line {
.join("\n");
format!("match {value} {{\n{arms_str}\n{spaces}}}")
}
Self::ForwardDeclaration { var } => {
format!("var {var}")
}
Self::Assignment { var, value } => {
format!("{var} = {value}")
}
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn compile_program(program: String) -> Bytecode {
let (parsed_program, function_locations) = parse_program(&program).unwrap();
// println!("Parsed program: {}", parsed_program.to_string());
let simple_program = simplify_program(parsed_program);
// println!("Simplified program: {}", simple_program.to_string());
// println!("Simplified program: {}", simple_program);
let intermediate_bytecode = compile_to_intermediate_bytecode(simple_program).unwrap();
// println!("Intermediate Bytecode:\n\n{}", intermediate_bytecode.to_string());

Expand Down
12 changes: 12 additions & 0 deletions crates/lean_compiler/src/parser/parsers/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl Parse<Line> for StatementParser {
let inner = next_inner_pair(&mut pair.into_inner(), "statement body")?;

match inner.as_rule() {
Rule::forward_declaration => ForwardDeclarationParser::parse(inner, ctx),
Rule::single_assignment => AssignmentParser::parse(inner, ctx),
Rule::array_assign => ArrayAssignParser::parse(inner, ctx),
Rule::if_statement => IfStatementParser::parse(inner, ctx),
Expand All @@ -35,6 +36,17 @@ impl Parse<Line> for StatementParser {
}
}

/// Parser for forward declarations of variables.
pub struct ForwardDeclarationParser;

impl Parse<Line> for ForwardDeclarationParser {
fn parse(pair: ParsePair<'_>, _ctx: &mut ParseContext) -> ParseResult<Line> {
let mut inner = pair.into_inner();
let var = next_inner_pair(&mut inner, "variable name")?.as_str().to_string();
Ok(Line::ForwardDeclaration { var })
}
}

/// Parser for variable assignments.
pub struct AssignmentParser;

Expand Down
66 changes: 66 additions & 0 deletions crates/lean_compiler/tests/test_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,26 @@ fn test_inlined() {
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}

#[test]
fn test_inlined_2() {
let program = r#"
fn main() {
b = is_one();
c = b;
return;
}

fn is_one() inline -> 1 {
if 1 {
return 1;
} else {
return 0;
}
}
"#;
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}

#[test]
fn test_match() {
let program = r#"
Expand Down Expand Up @@ -433,6 +453,29 @@ fn test_match() {
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}

#[test]
fn test_match_shrink() {
let program = r#"
fn main() {
match 1 {
0 => {
y = 90;
}
1 => {
y = 10;
z = func_2(y);
}
}
return;
}

fn func_2(x) inline -> 1 {
return x * x;
}
"#;
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}

// #[test]
// fn inline_bug_mre() {
// let program = r#"
Expand Down Expand Up @@ -523,3 +566,26 @@ fn test_nested_inline_functions() {

compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}

#[test]
fn test_const_and_nonconst_malloc_sharing_name() {
let program = r#"
fn main() {
f(1);
return;
}

fn f(n) {
if 0 == 0 {
res = malloc(2);
res[1] = 0;
return;
} else {
res = malloc(n * 1);
return;
}
}
"#;

compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
}
4 changes: 2 additions & 2 deletions crates/lean_vm/src/diagnostics/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ pub enum RunnerError {
#[error("Computation invalid: {0} != {1}")]
NotEqual(F, F),

#[error("Undefined memory access")]
UndefinedMemory,
#[error("Undefined memory access: {0}")]
UndefinedMemory(usize),

#[error("Program counter out of bounds")]
PCOutOfBounds,
Expand Down
6 changes: 5 additions & 1 deletion crates/lean_vm/src/execution/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ impl Memory {
///
/// Returns an error if the address is uninitialized
pub fn get(&self, index: usize) -> Result<F, RunnerError> {
self.0.get(index).copied().flatten().ok_or(RunnerError::UndefinedMemory)
self.0
.get(index)
.copied()
.flatten()
.ok_or(RunnerError::UndefinedMemory(index))
}

/// Sets a value at a memory address
Expand Down
2 changes: 1 addition & 1 deletion crates/lean_vm/src/execution/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn test_basic_memory_operations() {
assert_eq!(memory.get(5).unwrap(), F::from_usize(42));

// Test undefined memory access
assert!(matches!(memory.get(1), Err(RunnerError::UndefinedMemory)));
assert!(matches!(memory.get(1), Err(RunnerError::UndefinedMemory(1))));
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion crates/lookup/src/quotient_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ fn sum_quotients_helper<F: PrimeCharacteristicRing + Sync + Send + Copy>(
) -> Vec<Vec<F>> {
assert_eq!(numerators_and_denominators.len(), n_groups);
let n = numerators_and_denominators[0].len();
assert!(n.is_power_of_two() && n >= 2, "n = {}", n);
assert!(n.is_power_of_two() && n >= 2, "n = {n}");
let mut new_numerators = Vec::new();
let mut new_denominators = Vec::new();
let (prev_numerators, prev_denominators) = numerators_and_denominators.split_at(n_groups / 2);
Expand Down
34 changes: 20 additions & 14 deletions crates/rec_aggregation/recursion_program.lean_lang
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ fn eq_mle_extension_base_dynamic(a, b, n) -> 1 {
}

fn expand_from_univariate_dynamic(alpha, n) -> 1 {
var res;
match n {
0 => { } // unreachable
1 => { res = expand_from_univariate_const(alpha, 1); }
Expand Down Expand Up @@ -302,22 +303,23 @@ fn sample_stir_indexes_and_fold(fs_state, num_queries, merkle_leaves_in_basefiel
fs_states_b = malloc(num_queries + 1);
fs_states_b[0] = fs_state_9;

var n_chunks_per_answer;
// the number of chunk of 8 field elements per merkle leaf opened
if merkle_leaves_in_basefield == 1 {
n_chuncks_per_answer = two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield
n_chunks_per_answer = two_pow_folding_factor / 8; // "/ 8" because initial merkle leaves are in the basefield
} else {
n_chuncks_per_answer = two_pow_folding_factor * DIM / 8;
n_chunks_per_answer = two_pow_folding_factor * DIM / 8;
}

for i in 0..num_queries {
new_fs_state, answer = fs_hint(fs_states_b[i], n_chuncks_per_answer);
new_fs_state, answer = fs_hint(fs_states_b[i], n_chunks_per_answer);
fs_states_b[i + 1] = new_fs_state;
answers[i] = answer;
}
fs_state_10 = fs_states_b[num_queries];

leaf_hashes = malloc(num_queries); // a vector of vectorized pointers, each pointing to 1 chunk of 8 field elements
batch_hash_slice_dynamic(num_queries, answers, leaf_hashes, n_chuncks_per_answer);
batch_hash_slice_dynamic(num_queries, answers, leaf_hashes, n_chunks_per_answer);

// Merkle verification
merkle_verif_batch_dynamic(num_queries, leaf_hashes, stir_challenges_indexes + num_queries, prev_root, folded_domain_size);
Expand Down Expand Up @@ -457,6 +459,7 @@ fn powers(alpha, n) -> 1 {

fn unit_root_pow_dynamic(domain_size, index_bits) -> 1 {
// index_bits is a pointer to domain_size bits
var res;
match domain_size {
0 => { } // unreachable
1 => { res = unit_root_pow_const(1, index_bits); }
Expand Down Expand Up @@ -544,23 +547,24 @@ fn poly_eq_extension(point, n, two_pow_n) -> 1 {
res = malloc(DIM);
set_to_one(res);
return res;
}

res = malloc(two_pow_n * DIM);
} else {
res = malloc(two_pow_n * DIM);

inner_res = poly_eq_extension(point + DIM, n - 1, two_pow_n / 2);
inner_res = poly_eq_extension(point + DIM, n - 1, two_pow_n / 2);

two_pow_n_minus_1 = two_pow_n / 2;
two_pow_n_minus_1 = two_pow_n / 2;

for i in 0..two_pow_n_minus_1 {
mul_extension(point, inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM);
sub_extension(inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM, res + i*DIM);
for i in 0..two_pow_n_minus_1 {
mul_extension(point, inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM);
sub_extension(inner_res + i*DIM, res + (two_pow_n_minus_1 + i) * DIM, res + i*DIM);
}

return res;
}

return res;
}

fn poly_eq_base(point, n) -> 1 {
var res;
match n {
0 => { } // unreachable
1 => { res = poly_eq_base_1(point); }
Expand Down Expand Up @@ -892,6 +896,8 @@ fn fs_hint(fs_state, n) -> 2 {
}

fn fs_receive_ef(fs_state, n) -> 2 {
var final_fs_state;
var res;
match n {
0 => { } // unreachable
1 => { final_fs_state, res = fs_receive_ef_const(fs_state, 1); }
Expand Down
3 changes: 1 addition & 2 deletions crates/rec_aggregation/src/xmss_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ fn compile_xmss_aggregation_program() -> XmssAggregationProgram {
assert_eq!(
result.no_vec_runtime_memory,
res.compute_non_vec_memory(&log_lifetimes),
"inconsistent no-vec memory for log_lifetimes : {:?}: non linear formula, TODO",
log_lifetimes
"inconsistent no-vec memory for log_lifetimes : {log_lifetimes:?}: non linear formula, TODO",
);
}
res
Expand Down
2 changes: 1 addition & 1 deletion crates/rec_aggregation/xmss_aggregate.lean_lang
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,4 @@ fn assert_eq_vec(x, y) inline {
dot_product_ee(ptr_x, pointer_to_one_vector * 8, ptr_y, 1);
dot_product_ee(ptr_x + 3, pointer_to_one_vector * 8, ptr_y + 3, 1);
return;
}
}