diff --git a/Cargo.lock b/Cargo.lock index 2861a2c4..5517d523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,7 +92,7 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backend" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#99573608916262f1643de9ce5d67f54a15d05524" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" dependencies = [ "fiat-shamir", "itertools", @@ -103,6 +103,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -176,7 +185,7 @@ dependencies = [ [[package]] name = "constraints-folder" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#99573608916262f1643de9ce5d67f54a15d05524" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" dependencies = [ "fiat-shamir", "p3-air", @@ -185,9 +194,9 @@ dependencies = [ [[package]] name = "convert_case" -version = "0.7.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb402b8d4c85569410425650ce3eddc7d698ed96d39a73f941b08fb63082f1e7" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" dependencies = [ "unicode-segmentation", ] @@ -238,22 +247,23 @@ dependencies = [ [[package]] name = "derive_more" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" 
-version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b" dependencies = [ "convert_case", "proc-macro2", "quote", + "rustc_version", "syn", "unicode-xid", ] @@ -283,11 +293,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "fiat-shamir" version = "0.1.0" -source = "git+https://github.com/leanEthereum/fiat-shamir.git#cd05e57a124fa29bb788cec49569c835dd60053a" +source = "git+https://github.com/leanEthereum/fiat-shamir.git#bcf23c766f2e930acf11e68777449483a55af077" dependencies = [ "p3-challenger", "p3-field", "p3-koala-bear", + "serde", ] [[package]] @@ -442,6 +453,7 @@ dependencies = [ "air", "colored", "derive_more", + "itertools", "lookup", "multilinear-toolkit", "num_enum", @@ -464,15 +476,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lookup" @@ -506,7 +518,7 @@ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "multilinear-toolkit" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#99573608916262f1643de9ce5d67f54a15d05524" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" dependencies = [ "backend", 
"constraints-folder", @@ -962,6 +974,7 @@ name = "rec_aggregation" version = "0.1.0" dependencies = [ "air", + "bincode", "lean_compiler", "lean_prover", "lean_vm", @@ -974,6 +987,7 @@ dependencies = [ "p3-symmetric", "p3-util", "rand", + "serde", "serde_json", "sub_protocols", "tracing", @@ -999,6 +1013,15 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1011,6 +1034,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1120,7 +1149,7 @@ dependencies = [ [[package]] name = "sumcheck" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#99573608916262f1643de9ce5d67f54a15d05524" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#62766141561550c3540f9f644085fec53d721f16" dependencies = [ "backend", "constraints-folder", @@ -1203,9 +1232,9 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -1214,9 +1243,9 @@ dependencies = [ [[package]] name 
= "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -1225,9 +1254,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -1259,9 +1288,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -1362,7 +1391,7 @@ dependencies = [ [[package]] name = "whir-p3" version = "0.1.0" -source = "git+https://github.com/TomWambsgans/whir-p3?branch=lean-multisig#e30463252c59e18c7cf09168299539b1e5b2b442" +source = "git+https://github.com/TomWambsgans/whir-p3?branch=lean-multisig#979e607d4c06725519ed3a4ef903fa1c7254d734" dependencies = [ "itertools", "multilinear-toolkit", @@ -1498,9 +1527,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" dependencies = [ "memchr", ] @@ -1553,18 +1582,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.30" +version = "0.8.31" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index eb064dd6..00cb8933 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,10 +67,13 @@ itertools = "0.14.0" colored = "3.0.0" tracing = "0.1.26" serde_json = "1.0.145" +serde = { version = "1.0.228", features = ["derive"] } +bincode = "1.3.3" num_enum = "0.7.5" tracing-subscriber = { version = "0.3.19", features = ["std", "env-filter"] } tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] } p3-koala-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } +p3-baby-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-poseidon2 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-symmetric = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } p3-air = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } @@ -92,25 +95,25 @@ multilinear-toolkit.workspace = true whir-p3.workspace = true # [patch."https://github.com/TomWambsgans/Plonky3.git"] -# p3-koala-bear = { path = "../zk/Plonky3/koala-bear" } -# p3-field = { path = "../zk/Plonky3/field" } -# p3-poseidon2 = { path = "../zk/Plonky3/poseidon2" } -# p3-symmetric = { path = "../zk/Plonky3/symmetric" } -# p3-air = { path = "../zk/Plonky3/air" } -# p3-merkle-tree = { path = "../zk/Plonky3/merkle-tree" } -# p3-commit = 
{ path = "../zk/Plonky3/commit" } -# p3-matrix = { path = "../zk/Plonky3/matrix" } -# p3-dft = { path = "../zk/Plonky3/dft" } -# p3-challenger = { path = "../zk/Plonky3/challenger" } -# p3-monty-31 = { path = "../zk/Plonky3/monty-31" } -# p3-maybe-rayon = { path = "../zk/Plonky3/maybe-rayon" } -# p3-util = { path = "../zk/Plonky3/util" } +# p3-koala-bear = { path = "../Plonky3/koala-bear" } +# p3-field = { path = "../Plonky3/field" } +# p3-poseidon2 = { path = "../Plonky3/poseidon2" } +# p3-symmetric = { path = "../Plonky3/symmetric" } +# p3-air = { path = "../Plonky3/air" } +# p3-merkle-tree = { path = "../Plonky3/merkle-tree" } +# p3-commit = { path = "../Plonky3/commit" } +# p3-matrix = { path = "../Plonky3/matrix" } +# p3-dft = { path = "../Plonky3/dft" } +# p3-challenger = { path = "../Plonky3/challenger" } +# p3-monty-31 = { path = "../Plonky3/monty-31" } +# p3-maybe-rayon = { path = "../Plonky3/maybe-rayon" } +# p3-util = { path = "../Plonky3/util" } # [patch."https://github.com/TomWambsgans/whir-p3.git"] -# whir-p3 = { path = "../zk/whir/fork-whir-p3" } +# whir-p3 = { path = "../whir-p3" } # [patch."https://github.com/leanEthereum/multilinear-toolkit.git"] -# multilinear-toolkit = { path = "../zk/multilinear-toolkit" } +# multilinear-toolkit = { path = "../multilinear-toolkit" } # [profile.release] # opt-level = 1 diff --git a/README.md b/README.md index f5a14852..55b7c3e9 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,12 @@ XMSS + minimal [zkVM](minimal_zkVM.pdf) = lightweight PQ signatures, with unbounded aggregation +## Status + +- branch [main](https://github.com/leanEthereum/leanMultisig): optimized for **prover efficiency** +- branch [lean-vm-simple](https://github.com/leanEthereum/leanMultisig/tree/lean-vm-simple): optimized for **simplicity** + +Both versions will eventually merge into one. 
## Proving System @@ -38,15 +44,17 @@ RUSTFLAGS='-C target-cpu=native' cargo run --release -- poseidon --log-n-perms 2 The full recursion program is not finished yet. Instead, we prove validity of a WHIR opening, with 25 variables, and rate = 1/4. ```console -RUSTFLAGS='-C target-cpu=native' cargo run --release -- recursion +RUSTFLAGS='-C target-cpu=native' cargo run --release -- recursion --count 8 ``` ![Alt text](docs/benchmark_graphs/graphs/recursive_whir_opening.svg) +Detail: before 4 December 2025, only 1 WHIR opening was benchmarked. Starting from now, we prove a dozen of openings together (to be closer to the n-to-1 aggregation scenario) and we report the proving time / WHIR. + ### XMSS aggregation ```console -RUSTFLAGS='-C target-cpu=native' cargo run --release -- xmss --n-signatures 1000 +RUSTFLAGS='-C target-cpu=native' cargo run --release -- xmss --n-signatures 1775 ``` [Trivial encoding](docs/XMSS_trivial_encoding.pdf) (for now). diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index 58ff6494..ad6b8409 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -40,6 +40,16 @@ where "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); + // crate::check_air_validity( + // air, + // &extra_data, + // &columns_f, + // &columns_ef, + // last_row_shifted_f, + // last_row_shifted_ef, + // ) + // .unwrap(); + let alpha = prover_state.sample(); // random challenge for batching constraints *extra_data.alpha_powers_mut() = alpha diff --git a/crates/air/tests/complex_air.rs b/crates/air/tests/complex_air.rs index 32b857c7..d2cc7c76 100644 --- a/crates/air/tests/complex_air.rs +++ b/crates/air/tests/complex_air.rs @@ -170,7 +170,7 @@ fn test_air_helper() { virtual_column_statement_prover, true, ); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let virtual_column_statement_verifier = if VIRTUAL_COLUMN { let virtual_column_evaluation_point = diff --git 
a/crates/air/tests/fib_air.rs b/crates/air/tests/fib_air.rs index 5b2f2977..968a3fe4 100644 --- a/crates/air/tests/fib_air.rs +++ b/crates/air/tests/fib_air.rs @@ -92,7 +92,7 @@ fn test_air_fibonacci() { None, true, ); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let (point_verifier, evaluations_remaining_to_verify_f, evaluations_remaining_to_verify_ef) = verify_air( &mut verifier_state, diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 7fdd4cf1..50f95b59 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -120,12 +120,10 @@ pub enum SimpleLine { /// and ai < 4, b < 2^7 - 1 /// The decomposition is unique, and always exists (except for x = -1) DecomposeCustom { - var: Var, // a pointer to 13 * len(to_decompose) field elements - to_decompose: Vec, - label: ConstMallocLabel, + args: Vec, }, - CounterHint { - var: Var, + PrivateInputStart { + result: Var, }, Print { line_info: String, @@ -654,23 +652,15 @@ fn simplify_lines( label, }); } - Line::DecomposeCustom { var, to_decompose } => { - assert!(!const_malloc.forbidden_vars.contains(var), "TODO"); - let simplified_to_decompose = to_decompose + Line::PrivateInputStart { result } => { + res.push(SimpleLine::PrivateInputStart { result: result.clone() }); + } + Line::DecomposeCustom { args } => { + let simplified_args = args .iter() .map(|expr| simplify_expr(expr, &mut res, counters, array_manager, const_malloc)) .collect::>(); - let label = const_malloc.counter; - const_malloc.counter += 1; - const_malloc.map.insert(var.clone(), label); - res.push(SimpleLine::DecomposeCustom { - var: var.clone(), - to_decompose: simplified_to_decompose, - label, - }); - } - Line::CounterHint { var } => { - res.push(SimpleLine::CounterHint { var: var.clone() }); + res.push(SimpleLine::DecomposeCustom { args: simplified_args }); } Line::Panic 
=> { res.push(SimpleLine::Panic); @@ -845,15 +835,19 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { on_new_expr(var, &internal_vars, &mut external_vars); } } - Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } => { for expr in to_decompose { on_new_expr(expr, &internal_vars, &mut external_vars); } internal_vars.insert(var.clone()); } - - Line::CounterHint { var } => { - internal_vars.insert(var.clone()); + Line::PrivateInputStart { result } => { + internal_vars.insert(result.clone()); + } + Line::DecomposeCustom { args } => { + for expr in args { + on_new_expr(expr, &internal_vars, &mut external_vars); + } } Line::ForLoop { iterator, @@ -1005,14 +999,19 @@ pub fn inline_lines(lines: &mut Vec, args: &BTreeMap, res inline_expr(var, args, inlining_count); } } - Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } => { for expr in to_decompose { inline_expr(expr, args, inlining_count); } inline_internal_var(var); } - Line::CounterHint { var } => { - inline_internal_var(var); + Line::DecomposeCustom { args: decompose_args } => { + for expr in decompose_args { + inline_expr(expr, args, inlining_count); + } + } + Line::PrivateInputStart { result } => { + inline_internal_var(result); } Line::ForLoop { iterator, @@ -1340,15 +1339,21 @@ fn replace_vars_for_unroll( replace_vars_for_unroll_in_expr(size, iterator, unroll_index, iterator_value, internal_vars); replace_vars_for_unroll_in_expr(vectorized_len, iterator, unroll_index, iterator_value, internal_vars); } - Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } => { assert!(var != iterator, "Weird"); *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); for expr in to_decompose { replace_vars_for_unroll_in_expr(expr, iterator, 
unroll_index, iterator_value, internal_vars); } } - Line::CounterHint { var } => { - *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}"); + Line::PrivateInputStart { result } => { + assert!(result != iterator, "Weird"); + *result = format!("@unrolled_{unroll_index}_{iterator_value}_{result}"); + } + Line::DecomposeCustom { args } => { + for expr in args { + replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars); + } } Line::Break | Line::Panic | Line::LocationReport { .. } => {} } @@ -1698,10 +1703,10 @@ fn get_function_called(lines: &[Line], function_called: &mut Vec) { | Line::Assert { .. } | Line::FunctionRet { .. } | Line::Precompile { .. } + | Line::PrivateInputStart { .. } | Line::Print { .. } | Line::DecomposeBits { .. } | Line::DecomposeCustom { .. } - | Line::CounterHint { .. } | Line::MAlloc { .. } | Line::Panic | Line::Break @@ -1783,14 +1788,19 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { replace_vars_by_const_in_expr(var, map); } } - Line::DecomposeBits { var, to_decompose } | Line::DecomposeCustom { var, to_decompose } => { + Line::DecomposeBits { var, to_decompose } => { assert!(!map.contains_key(var), "Variable {var} is a constant"); for expr in to_decompose { replace_vars_by_const_in_expr(expr, map); } } - Line::CounterHint { var } => { - assert!(!map.contains_key(var), "Variable {var} is a constant"); + Line::DecomposeCustom { args } => { + for expr in args { + replace_vars_by_const_in_expr(expr, map); + } + } + Line::PrivateInputStart { result } => { + assert!(!map.contains_key(result), "Variable {result} is a constant"); } Line::MAlloc { var, size, .. 
} => { assert!(!map.contains_key(var), "Variable {var} is a constant"); @@ -1863,24 +1873,12 @@ impl SimpleLine { .join(", ") ) } - Self::DecomposeCustom { - var: result, - to_decompose, - label: _, - } => { + Self::DecomposeCustom { args } => { format!( - "{} = decompose_custom({})", - result, - to_decompose - .iter() - .map(|expr| format!("{expr}")) - .collect::>() - .join(", ") + "decompose_custom({})", + args.iter().map(|expr| format!("{expr}")).collect::>().join(", ") ) } - Self::CounterHint { var: result } => { - format!("{result} = counter_hint()") - } Self::RawAccess { res, index, shift } => { format!("memory[{index} + {shift}] = {res}") } @@ -1967,6 +1965,9 @@ impl SimpleLine { Self::ConstMalloc { var, size, label: _ } => { format!("{var} = malloc({size})") } + Self::PrivateInputStart { result } => { + format!("private_input_start({result})") + } Self::Panic => "panic".to_string(), Self::LocationReport { .. } => Default::default(), }; diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index e77ac8ae..18cebdf8 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -435,7 +435,7 @@ fn compile_lines( } SimpleLine::Precompile { table, args, .. } => { - if *table == Table::poseidon24() { + if *table == Table::poseidon24_mem() { assert_eq!(args.len(), 3); } else { assert_eq!(args.len(), 4); @@ -511,36 +511,23 @@ fn compile_lines( label, ); } - SimpleLine::DecomposeCustom { - var, - to_decompose, - label, - } => { + SimpleLine::DecomposeCustom { args } => { + assert!(args.len() >= 3); + let decomposed = IntermediateValue::from_simple_expr(&args[0], compiler); + let remaining = IntermediateValue::from_simple_expr(&args[1], compiler); instructions.push(IntermediateInstruction::DecomposeCustom { - res_offset: compiler.stack_size, - to_decompose: to_decompose + decomposed, + remaining, + to_decompose: args[2..] 
.iter() .map(|expr| IntermediateValue::from_simple_expr(expr, compiler)) .collect(), }); - - handle_const_malloc( - declared_vars, - &mut instructions, - compiler, - var, - F::bits() * to_decompose.len(), - label, - ); } - SimpleLine::CounterHint { var } => { - declared_vars.insert(var.clone()); - instructions.push(IntermediateInstruction::CounterHint { - res_offset: compiler - .get_offset(&var.clone().into()) - .naive_eval() - .unwrap() - .to_usize(), + SimpleLine::PrivateInputStart { result } => { + declared_vars.insert(result.clone()); + instructions.push(IntermediateInstruction::PrivateInputStart { + res_offset: compiler.get_offset(&result.clone().into()), }); } SimpleLine::Print { line_info, content } => { @@ -690,11 +677,12 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { SimpleLine::TestZero { .. } => {} SimpleLine::HintMAlloc { var, .. } | SimpleLine::ConstMalloc { var, .. } - | SimpleLine::DecomposeBits { var, .. } - | SimpleLine::DecomposeCustom { var, .. } - | SimpleLine::CounterHint { var } => { + | SimpleLine::DecomposeBits { var, .. } => { internal_vars.insert(var.clone()); } + SimpleLine::PrivateInputStart { result } => { + internal_vars.insert(result.clone()); + } SimpleLine::RawAccess { res, .. } => { if let SimpleExpr::Var(var) = res { internal_vars.insert(var.clone()); @@ -712,6 +700,7 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { internal_vars.extend(find_internal_vars(else_branch)); } SimpleLine::Panic + | SimpleLine::DecomposeCustom { .. } | SimpleLine::Print { .. } | SimpleLine::FunctionRet { .. } | SimpleLine::Precompile { .. } diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 1b8934b3..0c2a3bd8 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -11,7 +11,7 @@ impl IntermediateInstruction { | Self::Print { .. } | Self::DecomposeBits { .. } | Self::DecomposeCustom { .. } - | Self::CounterHint { .. 
} + | Self::PrivateInputStart { .. } | Self::Inverse { .. } | Self::LocationReport { .. } => true, Self::Computation { .. } @@ -116,8 +116,6 @@ pub fn compile_to_low_level_bytecode( } } - let mut low_level_bytecode = Vec::new(); - for (label, pc) in label_to_pc.clone() { hints.entry(pc).or_insert_with(Vec::new).push(Hint::Label { label }); } @@ -129,19 +127,21 @@ pub fn compile_to_low_level_bytecode( match_first_block_starts, }; + let mut instructions = Vec::new(); + for (function_name, pc_start, block) in code_blocks { compile_block( &compiler, &function_name, &block, pc_start, - &mut low_level_bytecode, + &mut instructions, &mut hints, ); } Ok(Bytecode { - instructions: low_level_bytecode, + instructions, hints, starting_frame_memory, program, @@ -307,12 +307,19 @@ fn compile_block( }; hints.entry(pc).or_default().push(hint); } + IntermediateInstruction::PrivateInputStart { res_offset } => { + hints.entry(pc).or_default().push(Hint::PrivateInputStart { + res_offset: eval_const_expression_usize(&res_offset, compiler), + }); + } IntermediateInstruction::DecomposeCustom { - res_offset, + decomposed, + remaining, to_decompose, } => { let hint = Hint::DecomposeCustom { - res_offset, + decomposed: try_as_mem_or_constant(&decomposed).unwrap(), + remaining: try_as_mem_or_constant(&remaining).unwrap(), to_decompose: to_decompose .iter() .map(|expr| try_as_mem_or_constant(expr).unwrap()) @@ -320,10 +327,6 @@ fn compile_block( }; hints.entry(pc).or_default().push(hint); } - IntermediateInstruction::CounterHint { res_offset } => { - let hint = Hint::CounterHint { res_offset }; - hints.entry(pc).or_default().push(hint); - } IntermediateInstruction::Inverse { arg, res_offset } => { let hint = Hint::Inverse { arg: try_as_mem_or_constant(&arg).unwrap(), diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index 5b193356..f2e0829c 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ 
b/crates/lean_compiler/src/ir/instruction.rs @@ -56,11 +56,12 @@ pub enum IntermediateInstruction { /// and ai < 4, b < 2^7 - 1 /// The decomposition is unique, and always exists (except for x = -1) DecomposeCustom { - res_offset: usize, // m[fp + res_offset..fp + res_offset + 13 * len(to_decompose)] will store the decomposed values + decomposed: IntermediateValue, + remaining: IntermediateValue, to_decompose: Vec, }, - CounterHint { - res_offset: usize, + PrivateInputStart { + res_offset: ConstExpression, }, Print { line_info: String, // information about the line where the print occurs @@ -141,6 +142,9 @@ impl Display for IntermediateInstruction { write!(f, "jump {dest}") } } + Self::PrivateInputStart { res_offset } => { + write!(f, "m[fp + {res_offset}] = private_input_start()") + } Self::JumpIfNotZero { condition, dest, @@ -190,10 +194,11 @@ impl Display for IntermediateInstruction { write!(f, ")") } Self::DecomposeCustom { - res_offset, + decomposed, + remaining, to_decompose, } => { - write!(f, "m[fp + {res_offset}..] 
= decompose_custom(")?; + write!(f, "decompose_custom(m[fp + {decomposed}], m[fp + {remaining}], ")?; for (i, expr) in to_decompose.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -202,9 +207,6 @@ impl Display for IntermediateInstruction { } write!(f, ")") } - Self::CounterHint { res_offset } => { - write!(f, "m[fp + {res_offset}] = counter_hint()") - } Self::Print { line_info, content } => { write!(f, "print {line_info}: ")?; for (i, c) in content.iter().enumerate() { diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index ba1c001a..a9db3c76 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -368,11 +368,10 @@ pub enum Line { /// and ai < 4, b < 2^7 - 1 /// The decomposition is unique, and always exists (except for x = -1) DecomposeCustom { - var: Var, // a pointer to 13 * len(to_decompose) field elements - to_decompose: Vec, + args: Vec, }, - CounterHint { - var: Var, + PrivateInputStart { + result: Var, }, // noop, debug purpose only LocationReport { @@ -425,6 +424,9 @@ impl Line { Self::ArrayAssign { array, index, value } => { format!("{array}[{index}] = {value}") } + Self::PrivateInputStart { result } => { + format!("{result} = private_input_start()") + } Self::Assert(condition, _line_number) => format!("assert {condition}"), Self::IfCondition { condition, @@ -450,9 +452,6 @@ impl Line { format!("if {condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}") } } - Self::CounterHint { var } => { - format!("{var} = counter_hint({var})") - } Self::ForLoop { iterator, start, @@ -542,15 +541,10 @@ impl Line { .join(", ") ) } - Self::DecomposeCustom { var, to_decompose } => { + Self::DecomposeCustom { args } => { format!( - "{} = decompose_custom({})", - var, - to_decompose - .iter() - .map(|expr| expr.to_string()) - .collect::>() - .join(", ") + "decompose_custom({})", + args.iter().map(|expr| expr.to_string()).collect::>().join(", ") ) } Self::Break => "break".to_string(), diff 
--git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 7d3bd3d5..750d072b 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -43,6 +43,7 @@ pub fn compile_and_run( no_vec_runtime_memory, profiler, (&vec![], &vec![]), + Default::default(), ) .summary; println!("{summary}"); diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 4e43477e..77abe17a 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -199,20 +199,17 @@ impl FunctionCallParser { }) } "decompose_custom" => { - if args.is_empty() || return_data.len() != 1 { + if args.len() < 3 { return Err(SemanticError::new("Invalid decompose_custom call").into()); } - Ok(Line::DecomposeCustom { - var: return_data[0].clone(), - to_decompose: args, - }) + Ok(Line::DecomposeCustom { args }) } - "counter_hint" => { + "private_input_start" => { if !args.is_empty() || return_data.len() != 1 { - return Err(SemanticError::new("Invalid counter_hint call").into()); + return Err(SemanticError::new("Invalid private_input_start call").into()); } - Ok(Line::CounterHint { - var: return_data[0].clone(), + Ok(Line::PrivateInputStart { + result: return_data[0].clone(), }) } "panic" => { diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs index 5adeb3fd..e5cef782 100644 --- a/crates/lean_prover/src/common.rs +++ b/crates/lean_prover/src/common.rs @@ -1,6 +1,8 @@ +use std::collections::BTreeMap; + use multilinear_toolkit::prelude::*; use p3_koala_bear::{KOALABEAR_RC16_INTERNAL, KOALABEAR_RC24_INTERNAL}; -use poseidon_circuit::{GKRPoseidonResult, PoseidonGKRLayers, default_cube_layers}; +use poseidon_circuit::{PoseidonGKRLayers, default_cube_layers}; use sub_protocols::ColDims; use crate::*; @@ -16,7 +18,7 @@ pub(crate) fn get_base_dims( &PoseidonGKRLayers<16, N_COMMITED_CUBES_P16>, &PoseidonGKRLayers<24, 
N_COMMITED_CUBES_P24>, ), - table_heights: [TableHeight; N_TABLES], + table_heights: &BTreeMap, ) -> Vec> { let p16_default_cubes = default_cube_layers::(p16_gkr_layers); let p24_default_cubes = default_cube_layers::(p24_gkr_layers); @@ -27,16 +29,16 @@ pub(crate) fn get_base_dims( ], p16_default_cubes .iter() - .map(|&c| ColDims::padded(table_heights[Table::poseidon16().index()].n_rows_non_padded_maxed(), c)) + .map(|&c| ColDims::padded(table_heights[&Table::poseidon16_core()].n_rows_non_padded_maxed(), c)) .collect::>(), // commited cubes for poseidon16 p24_default_cubes .iter() - .map(|&c| ColDims::padded(table_heights[Table::poseidon24().index()].n_rows_non_padded_maxed(), c)) + .map(|&c| ColDims::padded(table_heights[&Table::poseidon24_core()].n_rows_non_padded_maxed(), c)) .collect::>(), ] .concat(); - for i in 0..N_TABLES { - dims.extend(ALL_TABLES[i].committed_dims(table_heights[i].n_rows_non_padded_maxed())); + for (table, height) in table_heights { + dims.extend(table.committed_dims(height.n_rows_non_padded_maxed())); } dims } @@ -64,20 +66,3 @@ fn split_at(stmt: &MultiEvaluation, start: usize, end: usize) -> Vec Vec>> { - vec![ - split_at(&p16_gkr.input_statements, 0, VECTOR_LEN), - split_at(&p16_gkr.input_statements, VECTOR_LEN, VECTOR_LEN * 2), - split_at(&p16_gkr.output_statements, 0, VECTOR_LEN), - split_at(&p16_gkr.output_statements, VECTOR_LEN, VECTOR_LEN * 2), - ] -} - -pub(crate) fn poseidon_24_vectorized_lookup_statements(p24_gkr: &GKRPoseidonResult) -> Vec>> { - vec![ - split_at(&p24_gkr.input_statements, 0, VECTOR_LEN), - split_at(&p24_gkr.input_statements, VECTOR_LEN, VECTOR_LEN * 2), - split_at(&p24_gkr.input_statements, VECTOR_LEN * 2, VECTOR_LEN * 3), - split_at(&p24_gkr.output_statements, VECTOR_LEN * 2, VECTOR_LEN * 3), - ] -} diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 2d9bae9a..2483dfc6 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ 
b/crates/lean_prover/src/prove_execution.rs @@ -1,4 +1,4 @@ -use std::array; +use std::collections::BTreeMap; use crate::common::*; use crate::*; @@ -10,6 +10,7 @@ use multilinear_toolkit::prelude::*; use p3_air::Air; use p3_util::{log2_ceil_usize, log2_strict_usize}; use poseidon_circuit::{PoseidonGKRLayers, prove_poseidon_gkr}; +use std::collections::VecDeque; use sub_protocols::*; use tracing::info_span; use utils::{build_prover_state, padd_with_zero_to_next_power_of_two}; @@ -23,7 +24,8 @@ pub fn prove_execution( no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory vm_profiler: bool, (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), -) -> (Vec>, String) { + merkle_path_hints: VecDeque>, +) -> (Proof, String) { let mut exec_summary = String::new(); let ExecutionTrace { traces, @@ -38,6 +40,7 @@ pub fn prove_execution( no_vec_runtime_memory, vm_profiler, (poseidons_16_precomputed, poseidons_24_precomputed), + merkle_path_hints, ) }); exec_summary = std::mem::take(&mut execution_result.summary); @@ -52,53 +55,62 @@ pub fn prove_execution( let private_memory = &memory[public_memory_size..non_zero_memory_size]; let log_public_memory = log2_strict_usize(public_memory.len()); - let _validity_proof_span = info_span!("Validity proof generation").entered(); + let mut prover_state = build_prover_state::(false); + prover_state.add_base_scalars( + &[ + vec![private_memory.len()], + traces.values().map(|t| t.n_rows_non_padded()).collect::>(), + ] + .concat() + .into_iter() + .map(F::from_usize) + .collect::>(), + ); + + // only keep tables with non-zero rows + let traces: BTreeMap<_, _> = traces + .into_iter() + .filter(|(table, trace)| trace.n_rows_non_padded() > 0 || table == &Table::execution() || table.is_poseidon()) + .collect(); let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); let p24_gkr_layers = PoseidonGKRLayers::<24, N_COMMITED_CUBES_P24>::build(None); 
let p16_witness = generate_poseidon_witness_helper( &p16_gkr_layers, - &traces[Table::poseidon16().index()], - POSEIDON_16_COL_INDEX_INPUT_START, - Some(&traces[Table::poseidon16().index()].base[POSEIDON_16_COL_COMPRESSION].clone()), + &traces[&Table::poseidon16_core()], + POSEIDON_16_CORE_COL_INPUT_START, + Some(&traces[&Table::poseidon16_core()].base[POSEIDON_16_CORE_COL_COMPRESSION].clone()), ); let p24_witness = generate_poseidon_witness_helper( &p24_gkr_layers, - &traces[Table::poseidon24().index()], - POSEIDON_24_COL_INDEX_INPUT_START, + &traces[&Table::poseidon24_core()], + POSEIDON_24_CORE_COL_INPUT_START, None, ); - let commitmenent_extension_helper: [_; N_TABLES] = array::from_fn(|i| { - (ALL_TABLES[i].n_commited_columns_ef() > 0).then(|| { - ExtensionCommitmentFromBaseProver::before_commitment( - ALL_TABLES[i] - .commited_columns_ef() - .iter() - .map(|&c| &traces[i].ext[c][..]) - .collect::>(), + let commitmenent_extension_helper = traces + .iter() + .filter(|(table, _)| table.n_commited_columns_ef() > 0) + .map(|(table, trace)| { + ( + *table, + ExtensionCommitmentFromBaseProver::before_commitment( + table + .commited_columns_ef() + .iter() + .map(|&c| &trace.ext[c][..]) + .collect::>(), + ), ) }) - }); - - let mut prover_state = build_prover_state::(false); - prover_state.add_base_scalars( - &[ - vec![private_memory.len()], - traces.iter().map(|t| t.n_rows_non_padded()).collect::>(), - ] - .concat() - .into_iter() - .map(F::from_usize) - .collect::>(), - ); + .collect::>(); let base_dims = get_base_dims( log_public_memory, private_memory.len(), (&p16_gkr_layers, &p24_gkr_layers), - array::from_fn(|i| traces[i].height), + &traces.iter().map(|(table, trace)| (*table, trace.height)).collect(), ); let mut base_pols = [ @@ -115,8 +127,8 @@ pub fn prove_execution( .collect::>(), ] .concat(); - for i in 0..N_TABLES { - base_pols.extend(ALL_TABLES[i].committed_columns(&traces[i], commitmenent_extension_helper[i].as_ref())); + for (table, trace) in &traces { + 
base_pols.extend(table.committed_columns(trace, commitmenent_extension_helper.get(table))); } // 1st Commitment @@ -128,44 +140,24 @@ pub fn prove_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, ); - let random_point_p16 = MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon16().index()].log_padded())); - let p16_gkr = prove_poseidon_gkr( - &mut prover_state, - &p16_witness, - random_point_p16.0.clone(), - UNIVARIATE_SKIPS, - &p16_gkr_layers, - ); - - let random_point_p24 = MultilinearPoint(prover_state.sample_vec(traces[Table::poseidon24().index()].log_padded())); - let p24_gkr = prove_poseidon_gkr( - &mut prover_state, - &p24_witness, - random_point_p24.0.clone(), - UNIVARIATE_SKIPS, - &p24_gkr_layers, - ); - let bus_challenge = prover_state.sample(); let fingerprint_challenge = prover_state.sample(); - let mut bus_quotients: [EF; N_TABLES] = Default::default(); - let mut air_points: [MultilinearPoint; N_TABLES] = Default::default(); - let mut evals_f: [Vec; N_TABLES] = Default::default(); - let mut evals_ef: [Vec; N_TABLES] = Default::default(); - - for table in ALL_TABLES { - let i = table.index(); - (bus_quotients[i], air_points[i], evals_f[i], evals_ef[i]) = prove_bus_and_air( - &mut prover_state, - &table, - &traces[i], - bus_challenge, - fingerprint_challenge, - ); + let mut bus_quotients: BTreeMap = Default::default(); + let mut air_points: BTreeMap> = Default::default(); + let mut evals_f: BTreeMap> = Default::default(); + let mut evals_ef: BTreeMap> = Default::default(); + + for (table, trace) in &traces { + let (this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) = + prove_bus_and_air(&mut prover_state, table, trace, bus_challenge, fingerprint_challenge); + bus_quotients.insert(*table, this_bus_quotient); + air_points.insert(*table, this_air_point); + evals_f.insert(*table, this_evals_f); + evals_ef.insert(*table, this_evals_ef); } - assert_eq!(bus_quotients.iter().copied().sum::(), EF::ZERO); + 
assert_eq!(bus_quotients.values().copied().sum::(), EF::ZERO); let bytecode_compression_challenges = MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); @@ -173,13 +165,13 @@ pub fn prove_execution( let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); let bytecode_lookup_claim_1 = Evaluation::new( - air_points[Table::execution().index()].clone(), - padd_with_zero_to_next_power_of_two(&evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS]) + air_points[&Table::execution()].clone(), + padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS]) .evaluate(&bytecode_compression_challenges), ); - let bytecode_poly_eq_point = eval_eq(&air_points[Table::execution().index()]); + let bytecode_poly_eq_point = eval_eq(&air_points[&Table::execution()]); let bytecode_pushforward = compute_pushforward( - &traces[Table::execution().index()].base[COL_INDEX_PC], + &traces[&Table::execution()].base[COL_INDEX_PC], folded_bytecode.len(), &bytecode_poly_eq_point, ); @@ -187,35 +179,45 @@ pub fn prove_execution( let normal_lookup_into_memory = NormalPackedLookupProver::step_1( &mut prover_state, &memory, - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_index_columns_f(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.normal_lookup_index_columns_f(trace)) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_index_columns_ef(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.normal_lookup_index_columns_ef(trace)) .collect(), - (0..N_TABLES) - .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_f()]) + traces + .iter() + .flat_map(|(table, trace)| vec![trace.n_rows_non_padded_maxed(); table.num_normal_lookups_f()]) .collect(), - (0..N_TABLES) - .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_ef()]) + traces + .iter() + .flat_map(|(table, trace)| 
vec![trace.n_rows_non_padded_maxed(); table.num_normal_lookups_ef()]) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_f()) + traces + .keys() + .flat_map(|table| table.normal_lookup_default_indexes_f()) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_ef()) + traces + .keys() + .flat_map(|table| table.normal_lookup_default_indexes_ef()) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_f_value_columns(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.normal_lookup_f_value_columns(trace)) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_ef_value_columns(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.normal_lookup_ef_value_columns(trace)) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookups_statements_f(&air_points[i], &evals_f[i])) + traces + .keys() + .flat_map(|table| table.normal_lookups_statements_f(&air_points[table], &evals_f[table])) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookups_statements_ef(&air_points[i], &evals_ef[i])) + traces + .keys() + .flat_map(|table| table.normal_lookups_statements_ef(&air_points[table], &evals_ef[table])) .collect(), LOG_SMALLEST_DECOMPOSITION_CHUNK, ); @@ -223,34 +225,26 @@ pub fn prove_execution( let vectorized_lookup_into_memory = VectorizedPackedLookupProver::<_, VECTOR_LEN>::step_1( &mut prover_state, &memory, - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].vector_lookup_index_columns(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.vector_lookup_index_columns(trace)) .collect(), - (0..N_TABLES) - .flat_map(|i| vec![traces[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()]) + traces + .iter() + .flat_map(|(table, trace)| vec![trace.n_rows_non_padded_maxed(); table.num_vector_lookups()]) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].vector_lookup_default_indexes()) + traces + 
.keys() + .flat_map(|table| table.vector_lookup_default_indexes()) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].vector_lookup_values_columns(&traces[i])) + traces + .iter() + .flat_map(|(table, trace)| table.vector_lookup_values_columns(trace)) + .collect(), + traces + .keys() + .flat_map(|table| table.vectorized_lookups_statements(&air_points[table], &evals_f[table])) .collect(), - { - let mut statements = vec![]; - for table in ALL_TABLES { - if table.identifier() == Table::poseidon16() { - statements.extend(poseidon_16_vectorized_lookup_statements(&p16_gkr)); // special case - continue; - } - if table.identifier() == Table::poseidon24() { - statements.extend(poseidon_24_vectorized_lookup_statements(&p24_gkr)); // special case - continue; - } - statements - .extend(table.vectorized_lookups_statements(&air_points[table.index()], &evals_f[table.index()])); - } - statements - }, LOG_SMALLEST_DECOMPOSITION_CHUNK, ); @@ -286,7 +280,7 @@ pub fn prove_execution( let bytecode_logup_star_statements = prove_logup_star( &mut prover_state, &MleRef::Extension(&folded_bytecode), - &traces[Table::execution().index()].base[COL_INDEX_PC], + &traces[&Table::execution()].base[COL_INDEX_PC], bytecode_lookup_claim_1.value, &bytecode_poly_eq_point, &bytecode_pushforward, @@ -298,43 +292,87 @@ pub fn prove_execution( vectorized_lookup_statements.on_table.clone(), ]; - let mut final_statements: [Vec<_>; N_TABLES] = Default::default(); - for i in 0..N_TABLES { - final_statements[i] = ALL_TABLES[i].committed_statements_prover( - &mut prover_state, - &air_points[i], - &evals_f[i], - commitmenent_extension_helper[i].as_ref(), - &mut normal_lookup_statements.on_indexes_f, - &mut normal_lookup_statements.on_indexes_ef, + let mut final_statements: BTreeMap>>> = Default::default(); + for table in traces.keys() { + final_statements.insert( + *table, + table.committed_statements_prover( + &mut prover_state, + &air_points[table], + &evals_f[table], + 
commitmenent_extension_helper.get(table), + &mut normal_lookup_statements.on_indexes_f, + &mut normal_lookup_statements.on_indexes_ef, + ), ); } assert!(normal_lookup_statements.on_indexes_f.is_empty()); assert!(normal_lookup_statements.on_indexes_ef.is_empty()); + let p16_gkr = prove_poseidon_gkr( + &mut prover_state, + &p16_witness, + air_points[&Table::poseidon16_core()].0.clone(), + UNIVARIATE_SKIPS, + &p16_gkr_layers, + ); + assert_eq!(&p16_gkr.output_statements.point, &air_points[&Table::poseidon16_core()]); + assert_eq!( + &p16_gkr.output_statements.values, + &evals_f[&Table::poseidon16_core()][POSEIDON_16_CORE_COL_OUTPUT_START..][..16] + ); + + let p24_gkr = prove_poseidon_gkr( + &mut prover_state, + &p24_witness, + air_points[&Table::poseidon24_core()].0.clone(), + UNIVARIATE_SKIPS, + &p24_gkr_layers, + ); + assert_eq!(&p24_gkr.output_statements.point, &air_points[&Table::poseidon24_core()]); + assert_eq!( + &p24_gkr.output_statements.values[16..], + &evals_f[&Table::poseidon24_core()][POSEIDON_24_CORE_COL_OUTPUT_START..][..8] + ); + { let mut cursor = 0; - for t in 0..N_TABLES { + for table in traces.keys() { for (statement, lookup) in vectorized_lookup_statements.on_indexes[cursor..] 
.iter() - .zip(ALL_TABLES[t].vector_lookups()) + .zip(table.vector_lookups()) { - final_statements[t][lookup.index].extend(statement.clone()); + final_statements.get_mut(table).unwrap()[lookup.index].extend(statement.clone()); } - cursor += ALL_TABLES[t].num_vector_lookups(); + cursor += table.num_vector_lookups(); } } let (initial_pc_statement, final_pc_statement) = - initial_and_final_pc_conditions(traces[Table::execution().index()].log_padded()); + initial_and_final_pc_conditions(traces[&Table::execution()].log_padded()); - final_statements[Table::execution().index()][ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)].extend( - vec![ + final_statements.get_mut(&Table::execution()).unwrap()[ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] + .extend(vec![ bytecode_logup_star_statements.on_indexes.clone(), initial_pc_statement, final_pc_statement, - ], - ); + ]); + let statements_p16_core = final_statements.get_mut(&Table::poseidon16_core()).unwrap(); + for (stmts, gkr_value) in statements_p16_core[POSEIDON_16_CORE_COL_INPUT_START..][..16] + .iter_mut() + .zip(&p16_gkr.input_statements.values) + { + stmts.push(Evaluation::new(p16_gkr.input_statements.point.clone(), *gkr_value)); + } + statements_p16_core[POSEIDON_16_CORE_COL_COMPRESSION].push(p16_gkr.on_compression_selector.unwrap()); + + let statements_p24_core = final_statements.get_mut(&Table::poseidon24_core()).unwrap(); + for (stmts, gkr_value) in statements_p24_core[POSEIDON_24_CORE_COL_INPUT_START..][..24] + .iter_mut() + .zip(&p24_gkr.input_statements.values) + { + stmts.push(Evaluation::new(p24_gkr.input_statements.point.clone(), *gkr_value)); + } // First Opening let mut all_base_statements = [ @@ -343,7 +381,7 @@ pub fn prove_execution( encapsulate_vec(p24_gkr.cubes_statements.split()), ] .concat(); - all_base_statements.extend(final_statements.into_iter().flatten()); + all_base_statements.extend(final_statements.into_values().flatten()); let global_statements_base = 
packed_pcs_global_statements_for_prover( &base_pols, @@ -380,7 +418,7 @@ pub fn prove_execution( &packed_pcs_witness_extension.packed_polynomial.by_ref(), ); - (prover_state.proof_data().to_vec(), exec_summary) + (prover_state.into_proof(), exec_summary) } fn prove_bus_and_air( @@ -400,16 +438,28 @@ fn prove_bus_and_air( let mut numerators = F::zero_vec(n_buses_padded * n_rows); for (bus, numerators_chunk) in t.buses().iter().zip(numerators.chunks_mut(n_rows)) { - assert!(bus.selector < trace.base.len()); - trace.base[bus.selector] - .par_iter() - .zip(numerators_chunk) - .for_each(|(&selector, v)| { - *v = match bus.direction { - BusDirection::Pull => -selector, - BusDirection::Push => selector, - } - }); + match bus.selector { + BusSelector::Column(selector_col) => { + assert!(selector_col < trace.base.len()); + trace.base[selector_col] + .par_iter() + .zip(numerators_chunk) + .for_each(|(&selector, v)| { + *v = match bus.direction { + BusDirection::Pull => -selector, + BusDirection::Push => selector, + } + }); + } + BusSelector::ConstantOne => { + numerators_chunk.par_iter_mut().for_each(|v| { + *v = match bus.direction { + BusDirection::Pull => F::NEG_ONE, + BusDirection::Push => F::ONE, + } + }); + } + } } let mut denominators = unsafe { uninitialized_vec(n_buses_padded * n_rows) }; @@ -528,10 +578,15 @@ fn prove_bus_and_air( } let extra_data = ExtraDataForBuses { - fingerprint_challenge_powers: powers_const(fingerprint_challenge), + fingerprint_challenge_powers: fingerprint_challenge.powers().collect_n(max_bus_width()), + fingerprint_challenge_powers_packed: EFPacking::::from(fingerprint_challenge) + .powers() + .collect_n(max_bus_width()), bus_beta, + bus_beta_packed: EFPacking::::from(bus_beta), alpha_powers: vec![], // filled later }; + let (air_point, evals_f, evals_ef) = info_span!("Table AIR proof", table = t.name()).in_scope(|| { macro_rules! 
prove_air_for_table { ($t:expr) => { diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 76ace533..a5be0708 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,4 +1,4 @@ -use std::array; +use std::collections::BTreeMap; use crate::common::*; use crate::*; @@ -21,10 +21,10 @@ use whir_p3::second_batched_whir_config_builder; pub fn verify_execution( bytecode: &Bytecode, public_input: &[F], - proof_data: Vec>, + proof: Proof, whir_config_builder: WhirConfigBuilder, ) -> Result<(), ProofError> { - let mut verifier_state = VerifierState::new(proof_data, build_challenger(), false); + let mut verifier_state = VerifierState::new(proof, build_challenger()); let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); let p24_gkr_layers = PoseidonGKRLayers::<24, N_COMMITED_CUBES_P24>::build(None); @@ -35,7 +35,15 @@ pub fn verify_execution( .map(|x| x.to_usize()) .collect::>(); let private_memory_len = dims[0]; - let table_heights: [TableHeight; N_TABLES] = array::from_fn(|i| TableHeight(dims[i + 1])); + let table_heights: BTreeMap = (0..N_TABLES) + .map(|i| (ALL_TABLES[i], TableHeight(dims[i + 1]))) + .collect(); + + // only keep tables with non-zero rows + let table_heights: BTreeMap<_, _> = table_heights + .into_iter() + .filter(|(table, height)| height.n_rows_non_padded() > 0 || table == &Table::execution() || table.is_poseidon()) + .collect(); let public_memory = build_public_memory(public_input); @@ -50,7 +58,7 @@ pub fn verify_execution( log_public_memory, private_memory_len, (&p16_gkr_layers, &p24_gkr_layers), - table_heights, + &table_heights, ); let parsed_commitment_base = packed_pcs_parse_commitment( &whir_config_builder, @@ -59,47 +67,29 @@ pub fn verify_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, )?; - let random_point_p16 = - 
MultilinearPoint(verifier_state.sample_vec(table_heights[Table::poseidon16().index()].log_padded())); - let p16_gkr = verify_poseidon_gkr( - &mut verifier_state, - table_heights[Table::poseidon16().index()].log_padded(), - &random_point_p16, - &p16_gkr_layers, - UNIVARIATE_SKIPS, - true, - ); - - let random_point_p24 = - MultilinearPoint(verifier_state.sample_vec(table_heights[Table::poseidon24().index()].log_padded())); - let p24_gkr = verify_poseidon_gkr( - &mut verifier_state, - table_heights[Table::poseidon24().index()].log_padded(), - &random_point_p24, - &p24_gkr_layers, - UNIVARIATE_SKIPS, - false, - ); - let bus_challenge = verifier_state.sample(); let fingerprint_challenge = verifier_state.sample(); - let mut bus_quotients: [EF; N_TABLES] = Default::default(); - let mut air_points: [MultilinearPoint; N_TABLES] = Default::default(); - let mut evals_f: [Vec; N_TABLES] = Default::default(); - let mut evals_ef: [Vec; N_TABLES] = Default::default(); + let mut bus_quotients: BTreeMap = Default::default(); + let mut air_points: BTreeMap> = Default::default(); + let mut evals_f: BTreeMap> = Default::default(); + let mut evals_ef: BTreeMap> = Default::default(); - for i in 0..N_TABLES { - (bus_quotients[i], air_points[i], evals_f[i], evals_ef[i]) = verify_bus_and_air( + for (table, height) in &table_heights { + let (this_bus_quotient, this_air_point, this_evals_f, this_evals_ef) = verify_bus_and_air( &mut verifier_state, - &ALL_TABLES[i], - table_heights[i], + table, + *height, bus_challenge, fingerprint_challenge, )?; + bus_quotients.insert(*table, this_bus_quotient); + air_points.insert(*table, this_air_point); + evals_f.insert(*table, this_evals_f); + evals_ef.insert(*table, this_evals_ef); } - if bus_quotients.iter().copied().sum::() != EF::ZERO { + if bus_quotients.values().copied().sum::() != EF::ZERO { return Err(ProofError::InvalidProof); } @@ -107,30 +97,36 @@ pub fn verify_execution( 
MultilinearPoint(verifier_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); let bytecode_lookup_claim_1 = Evaluation::new( - air_points[Table::execution().index()].clone(), - padd_with_zero_to_next_power_of_two(&evals_f[Table::execution().index()][..N_INSTRUCTION_COLUMNS]) + air_points[&Table::execution()].clone(), + padd_with_zero_to_next_power_of_two(&evals_f[&Table::execution()][..N_INSTRUCTION_COLUMNS]) .evaluate(&bytecode_compression_challenges), ); let normal_lookup_into_memory = NormalPackedLookupVerifier::step_1( &mut verifier_state, - (0..N_TABLES) - .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_f()]) + table_heights + .iter() + .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_normal_lookups_f()]) .collect(), - (0..N_TABLES) - .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_normal_lookups_ef()]) + table_heights + .iter() + .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_normal_lookups_ef()]) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_f()) + table_heights + .keys() + .flat_map(|table| table.normal_lookup_default_indexes_f()) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookup_default_indexes_ef()) + table_heights + .keys() + .flat_map(|table| table.normal_lookup_default_indexes_ef()) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookups_statements_f(&air_points[i], &evals_f[i])) + table_heights + .keys() + .flat_map(|table| table.normal_lookups_statements_f(&air_points[table], &evals_f[table])) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].normal_lookups_statements_ef(&air_points[i], &evals_ef[i])) + table_heights + .keys() + .flat_map(|table| table.normal_lookups_statements_ef(&air_points[table], &evals_ef[table])) .collect(), LOG_SMALLEST_DECOMPOSITION_CHUNK, &public_memory, // we need to pass the first few values 
of memory, public memory is enough @@ -138,28 +134,18 @@ pub fn verify_execution( let vectorized_lookup_into_memory = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( &mut verifier_state, - (0..N_TABLES) - .flat_map(|i| vec![table_heights[i].n_rows_non_padded_maxed(); ALL_TABLES[i].num_vector_lookups()]) + table_heights + .iter() + .flat_map(|(table, height)| vec![height.n_rows_non_padded_maxed(); table.num_vector_lookups()]) .collect(), - (0..N_TABLES) - .flat_map(|i| ALL_TABLES[i].vector_lookup_default_indexes()) + table_heights + .keys() + .flat_map(|table| table.vector_lookup_default_indexes()) + .collect(), + table_heights + .keys() + .flat_map(|table| table.vectorized_lookups_statements(&air_points[table], &evals_f[table])) .collect(), - { - let mut statements = vec![]; - for table in ALL_TABLES { - if table.identifier() == Table::poseidon16() { - statements.extend(poseidon_16_vectorized_lookup_statements(&p16_gkr)); // special case - continue; - } - if table.identifier() == Table::poseidon24() { - statements.extend(poseidon_24_vectorized_lookup_statements(&p24_gkr)); // special case - continue; - } - statements - .extend(table.vectorized_lookups_statements(&air_points[table.index()], &evals_f[table.index()])); - } - statements - }, LOG_SMALLEST_DECOMPOSITION_CHUNK, &public_memory, // we need to pass the first few values of memory, public memory is enough )?; @@ -191,7 +177,7 @@ pub fn verify_execution( let bytecode_logup_star_statements = verify_logup_star( &mut verifier_state, log2_ceil_usize(bytecode.instructions.len()), - table_heights[Table::execution().index()].log_padded(), + table_heights[&Table::execution()].log_padded(), &[bytecode_lookup_claim_1], EF::ONE, )?; @@ -206,43 +192,89 @@ pub fn verify_execution( vectorized_lookup_statements.on_table.clone(), ]; - let mut final_statements: [Vec<_>; N_TABLES] = Default::default(); - for i in 0..N_TABLES { - final_statements[i] = ALL_TABLES[i].committed_statements_verifier( - &mut verifier_state, - 
&air_points[i], - &evals_f[i], - &evals_ef[i], - &mut normal_lookup_statements.on_indexes_f, - &mut normal_lookup_statements.on_indexes_ef, - )?; + let mut final_statements: BTreeMap> = Default::default(); + for table in table_heights.keys() { + final_statements.insert( + *table, + table.committed_statements_verifier( + &mut verifier_state, + &air_points[table], + &evals_f[table], + &evals_ef[table], + &mut normal_lookup_statements.on_indexes_f, + &mut normal_lookup_statements.on_indexes_ef, + )?, + ); } assert!(normal_lookup_statements.on_indexes_f.is_empty()); assert!(normal_lookup_statements.on_indexes_ef.is_empty()); + let p16_gkr = verify_poseidon_gkr( + &mut verifier_state, + table_heights[&Table::poseidon16_core()].log_padded(), + &air_points[&Table::poseidon16_core()].0, + &p16_gkr_layers, + UNIVARIATE_SKIPS, + true, + ); + assert_eq!(&p16_gkr.output_statements.point, &air_points[&Table::poseidon16_core()]); + assert_eq!( + &p16_gkr.output_statements.values, + &evals_f[&Table::poseidon16_core()][POSEIDON_16_CORE_COL_OUTPUT_START..][..16] + ); + + let p24_gkr = verify_poseidon_gkr( + &mut verifier_state, + table_heights[&Table::poseidon24_core()].log_padded(), + &air_points[&Table::poseidon24_core()].0, + &p24_gkr_layers, + UNIVARIATE_SKIPS, + false, + ); + assert_eq!(&p24_gkr.output_statements.point, &air_points[&Table::poseidon24_core()]); + assert_eq!( + &p24_gkr.output_statements.values[16..], + &evals_f[&Table::poseidon24_core()][POSEIDON_24_CORE_COL_OUTPUT_START..][..8] + ); + { let mut cursor = 0; - for t in 0..N_TABLES { + for table in table_heights.keys() { for (statement, lookup) in vectorized_lookup_statements.on_indexes[cursor..] 
.iter() - .zip(ALL_TABLES[t].vector_lookups()) + .zip(table.vector_lookups()) { - final_statements[t][lookup.index].extend(statement.clone()); + final_statements.get_mut(table).unwrap()[lookup.index].extend(statement.clone()); } - cursor += ALL_TABLES[t].num_vector_lookups(); + cursor += table.num_vector_lookups(); } } let (initial_pc_statement, final_pc_statement) = - initial_and_final_pc_conditions(table_heights[Table::execution().index()].log_padded()); + initial_and_final_pc_conditions(table_heights[&Table::execution()].log_padded()); - final_statements[Table::execution().index()][ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)].extend( - vec![ + final_statements.get_mut(&Table::execution()).unwrap()[ExecutionTable.find_committed_column_index_f(COL_INDEX_PC)] + .extend(vec![ bytecode_logup_star_statements.on_indexes.clone(), initial_pc_statement, final_pc_statement, - ], - ); + ]); + let statements_p16_core = final_statements.get_mut(&Table::poseidon16_core()).unwrap(); + for (stmts, gkr_value) in statements_p16_core[POSEIDON_16_CORE_COL_INPUT_START..][..16] + .iter_mut() + .zip(&p16_gkr.input_statements.values) + { + stmts.push(Evaluation::new(p16_gkr.input_statements.point.clone(), *gkr_value)); + } + statements_p16_core[POSEIDON_16_CORE_COL_COMPRESSION].push(p16_gkr.on_compression_selector.unwrap()); + + let statements_p24_core = final_statements.get_mut(&Table::poseidon24_core()).unwrap(); + for (stmts, gkr_value) in statements_p24_core[POSEIDON_24_CORE_COL_INPUT_START..][..24] + .iter_mut() + .zip(&p24_gkr.input_statements.values) + { + stmts.push(Evaluation::new(p24_gkr.input_statements.point.clone(), *gkr_value)); + } let mut all_base_statements = [ vec![memory_statements], @@ -250,7 +282,7 @@ pub fn verify_execution( encapsulate_vec(p24_gkr.cubes_statements.split()), ] .concat(); - all_base_statements.extend(final_statements.into_iter().flatten()); + all_base_statements.extend(final_statements.into_values().flatten()); let 
global_statements_base = packed_pcs_global_statements_for_verifier( &base_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK, @@ -368,8 +400,12 @@ fn verify_bus_and_air( } let extra_data = ExtraDataForBuses { - fingerprint_challenge_powers: powers_const(fingerprint_challenge), + fingerprint_challenge_powers: fingerprint_challenge.powers().collect_n(max_bus_width()), + fingerprint_challenge_powers_packed: EFPacking::::from(fingerprint_challenge) + .powers() + .collect_n(max_bus_width()), bus_beta, + bus_beta_packed: EFPacking::::from(bus_beta), alpha_powers: vec![], // filled later }; diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs index 04794701..bd8821e0 100644 --- a/crates/lean_prover/tests/hash_chain.rs +++ b/crates/lean_prover/tests/hash_chain.rs @@ -63,6 +63,7 @@ fn benchmark_poseidon_chain() { 1 << (3 + LOG_CHAIN_LENGTH), false, (&vec![], &vec![]), + Default::default(), ) .no_vec_runtime_memory; @@ -74,6 +75,7 @@ fn benchmark_poseidon_chain() { no_vec_runtime_memory, false, (&vec![], &vec![]), // TODO poseidons precomputed + Default::default(), // TODO merkle path hints ) .0; let vm_time = time.elapsed(); diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs index 899c2fa3..0c1239e1 100644 --- a/crates/lean_prover/tests/test_zkvm.rs +++ b/crates/lean_prover/tests/test_zkvm.rs @@ -1,3 +1,5 @@ +use std::collections::VecDeque; + use lean_compiler::*; use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder}; use lean_vm::*; @@ -12,6 +14,10 @@ fn test_zk_vm_all_precompiles() { const COMPRESSION = 1; const PERMUTATION = 0; const N = 11; + const MERKLE_HEIGHT_1 = 10; + const LEAF_POS_1 = 781; + const MERKLE_HEIGHT_2 = 15; + const LEAF_POS_2 = 178; fn main() { pub_start = public_input_start; pub_start_vec = pub_start / 8; @@ -21,6 +27,11 @@ fn test_zk_vm_all_precompiles() { poseidon24(pub_start_vec + 7, pub_start_vec + 9, pub_start_vec + 10); 
dot_product_be(pub_start + 88, pub_start + 88 + N, pub_start + 1000, N); dot_product_ee(pub_start + 88 + N, pub_start + 88 + N * (DIM + 1), pub_start + 1000 + DIM, N); + merkle_verify((pub_start + 2000) / 8, LEAF_POS_1, (pub_start + 2000 + 8) / 8, MERKLE_HEIGHT_1); + merkle_verify((pub_start + 2000 + 16) / 8, LEAF_POS_2, (pub_start + 2000 + 24) / 8, MERKLE_HEIGHT_2); + index_res_slice_hash = 10000; + slice_hash(5, 6, index_res_slice_hash, 3); + eq_poly_base_ext(pub_start + 1100, pub_start +1100 + 3, pub_start + 1100 + (DIM + 1) * 3, 3); return; } @@ -66,7 +77,56 @@ fn test_zk_vm_all_precompiles() { public_input[1000..][..DIMENSION].copy_from_slice(dot_product_base_ext.as_basis_coefficients_slice()); public_input[1000 + DIMENSION..][..DIMENSION].copy_from_slice(dot_product_ext_ext.as_basis_coefficients_slice()); - test_zk_vm_helper(program_str, (&public_input, &[]), 0); + let slice_a: [F; 3] = rng.random(); + let slice_b: [EF; 3] = rng.random(); + let poly_eq = MultilinearPoint(slice_b.to_vec()) + .eq_poly_outside(&MultilinearPoint(slice_a.iter().map(|&x| EF::from(x)).collect())); + public_input[1100..][..3].copy_from_slice(&slice_a); + public_input[1100 + 3..][..3 * DIMENSION].copy_from_slice( + slice_b + .iter() + .flat_map(|&x| x.as_basis_coefficients_slice().to_vec()) + .collect::>() + .as_slice(), + ); + public_input[1100 + 3 + 3 * DIMENSION..][..DIMENSION].copy_from_slice(poly_eq.as_basis_coefficients_slice()); + + fn add_merkle_path( + rng: &mut StdRng, + public_input: &mut [F], + merkle_height: usize, + leaf_position: usize, + ) -> Vec<[F; 8]> { + let leaf: [F; VECTOR_LEN] = rng.random(); + public_input[..VECTOR_LEN].copy_from_slice(&leaf); + let mut merkle_path = Vec::new(); + let mut current_digest = leaf; + for i in 0..merkle_height { + let sibling: [F; VECTOR_LEN] = rng.random(); + merkle_path.push(sibling); + let (left, right) = if (leaf_position >> i) & 1 == 0 { + (current_digest, sibling) + } else { + (sibling, current_digest) + }; + current_digest = 
poseidon16_permute([left.to_vec(), right.to_vec()].concat().try_into().unwrap()) + [..VECTOR_LEN] + .try_into() + .unwrap(); + } + let root = current_digest; + public_input[VECTOR_LEN..][..VECTOR_LEN].copy_from_slice(&root); + merkle_path + } + + let merkle_path_1 = add_merkle_path(&mut rng, &mut public_input[2000..], 10, 781); + let merkle_path_2 = add_merkle_path(&mut rng, &mut public_input[2000 + 16..], 15, 178); + + let mut merkle_path_hints = VecDeque::new(); + merkle_path_hints.push_back(merkle_path_1); + merkle_path_hints.push_back(merkle_path_2); + + test_zk_vm_helper(program_str, (&public_input, &[]), 0, merkle_path_hints); } #[test] @@ -103,28 +163,34 @@ fn test_prove_fibonacci() { "#; let n = std::env::var("FIB_N") - .unwrap_or("100000".to_string()) + .unwrap_or("10000".to_string()) .parse::() .unwrap(); let program_str = program_str.replace("FIB_N_PLACEHOLDER", &n.to_string()); - test_zk_vm_helper(&program_str, (&[F::ZERO; 1 << 14], &[]), 0); + test_zk_vm_helper(&program_str, (&[F::ZERO; 1 << 14], &[]), 0, Default::default()); } -fn test_zk_vm_helper(program_str: &str, (public_input, private_input): (&[F], &[F]), no_vec_runtime_memory: usize) { +fn test_zk_vm_helper( + program_str: &str, + (public_input, private_input): (&[F], &[F]), + no_vec_runtime_memory: usize, + merkle_path_hints: VecDeque>, +) { utils::init_tracing(); let bytecode = compile_program(program_str.to_string()); let time = std::time::Instant::now(); - let (proof_data, summary) = prove_execution( + let (proof, summary) = prove_execution( &bytecode, (public_input, private_input), whir_config_builder(), no_vec_runtime_memory, false, (&vec![], &vec![]), + merkle_path_hints, ); let proof_time = time.elapsed(); - verify_execution(&bytecode, public_input, proof_data, whir_config_builder()).unwrap(); + verify_execution(&bytecode, public_input, proof, whir_config_builder()).unwrap(); println!("{summary}"); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); } diff --git 
a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index 83b9b048..eaf55515 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -1,12 +1,12 @@ use crate::instruction_encoder::field_representation; use lean_vm::*; use multilinear_toolkit::prelude::*; -use std::{array, iter::repeat_n}; +use std::{array, collections::BTreeMap, iter::repeat_n}; use utils::{ToUsize, transposed_par_iter_mut}; #[derive(Debug)] pub struct ExecutionTrace { - pub traces: [TableTrace; N_TABLES], + pub traces: BTreeMap, pub public_memory_size: usize, pub non_zero_memory_size: usize, pub memory: Vec, // of length a multiple of public_memory_size @@ -96,13 +96,16 @@ pub fn get_execution_trace(bytecode: &Bytecode, mut execution_result: ExecutionR let ExecutionResult { mut traces, .. } = execution_result; - traces[Table::execution().index()] = TableTrace { - base: Vec::from(main_trace), - ext: vec![], - height: Default::default(), - }; - for (trace, table) in traces.iter_mut().zip(ALL_TABLES) { - padd_table(&table, trace); + traces.insert( + Table::execution(), + TableTrace { + base: Vec::from(main_trace), + ext: vec![], + height: Default::default(), + }, + ); + for table in traces.keys().copied().collect::>() { + padd_table(&table, &mut traces); } ExecutionTrace { @@ -113,21 +116,24 @@ pub fn get_execution_trace(bytecode: &Bytecode, mut execution_result: ExecutionR } } -fn padd_table(t: &T, trace: &mut TableTrace) { +fn padd_table(table: &Table, traces: &mut BTreeMap) { + let trace = traces.get_mut(table).unwrap(); let h = trace.base[0].len(); trace .base .iter() .enumerate() - .for_each(|(i, col)| assert_eq!(col.len(), h, "column {}, table {}", i, t.name())); + .for_each(|(i, col)| assert_eq!(col.len(), h, "column {}, table {}", i, table.name())); trace.height = TableHeight(h); - + let padding_len = trace.height.padding_len(); + 
let padding_row_f = table.padding_row_f(); trace.base.par_iter_mut().enumerate().for_each(|(i, col)| { - col.extend(repeat_n(t.padding_row_f()[i], trace.height.padding_len())); + col.extend(repeat_n(padding_row_f[i], padding_len)); }); + let padding_row_ef = table.padding_row_ef(); trace.ext.par_iter_mut().enumerate().for_each(|(i, col)| { - col.extend(repeat_n(t.padding_row_ef()[i], trace.height.padding_len())); + col.extend(repeat_n(padding_row_ef[i], padding_len)); }); } diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 5a2effef..f91dc646 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -27,4 +27,5 @@ lookup.workspace = true thiserror.workspace = true derive_more.workspace = true multilinear-toolkit.workspace = true -num_enum.workspace = true \ No newline at end of file +num_enum.workspace = true +itertools.workspace = true \ No newline at end of file diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index f2589b8d..822119d7 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -1,7 +1,9 @@ +use std::collections::BTreeMap; + use crate::core::F; use crate::diagnostics::profiler::MemoryProfile; use crate::execution::Memory; -use crate::{N_TABLES, TableTrace}; +use crate::{N_TABLES, Table, TableTrace}; use thiserror::Error; #[derive(Debug)] @@ -11,7 +13,7 @@ pub struct ExecutionResult { pub memory: Memory, pub pcs: Vec, pub fps: Vec, - pub traces: [TableTrace; N_TABLES], + pub traces: BTreeMap, pub summary: String, pub memory_profile: Option, } diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index f9bed034..65f4eeeb 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -13,8 +13,7 @@ use crate::{ TableTrace, }; use multilinear_toolkit::prelude::*; -use std::array; -use std::collections::{BTreeMap, BTreeSet}; +use 
std::collections::{BTreeMap, BTreeSet, VecDeque}; use utils::{poseidon16_permute, poseidon24_permute, pretty_integer}; use xmss::{Poseidon16History, Poseidon24History}; @@ -58,6 +57,7 @@ pub fn execute_bytecode( no_vec_runtime_memory: usize, // size of the "non-vectorized" runtime memory profiling: bool, (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), + merkle_path_hints: VecDeque>, ) -> ExecutionResult { let mut std_out = String::new(); let mut instruction_history = ExecutionHistory::new(); @@ -69,6 +69,7 @@ pub fn execute_bytecode( no_vec_runtime_memory, profiling, (poseidons_16_precomputed, poseidons_24_precomputed), + merkle_path_hints, ) .unwrap_or_else(|(last_pc, err)| { let lines_history = &instruction_history.lines; @@ -147,6 +148,7 @@ fn execute_bytecode_helper( no_vec_runtime_memory: usize, profiling: bool, (poseidons_16_precomputed, poseidons_24_precomputed): (&Poseidon16History, &Poseidon24History), + mut merkle_path_hints: VecDeque>, ) -> Result { // set public memory let mut memory = Memory::new(build_public_memory(public_input)); @@ -187,8 +189,7 @@ fn execute_bytecode_helper( let mut n_poseidon16_precomputed_used = 0; let mut n_poseidon24_precomputed_used = 0; - // Events collected only in final execution - let mut traces: [TableTrace; N_TABLES] = array::from_fn(|i| TableTrace::new(&ALL_TABLES[i])); + let mut traces = BTreeMap::from_iter((0..N_TABLES).map(|i| (ALL_TABLES[i], TableTrace::new(&ALL_TABLES[i])))); let mut add_counts = 0; let mut mul_counts = 0; @@ -212,6 +213,7 @@ fn execute_bytecode_helper( for hint in bytecode.hints.get(&pc).unwrap_or(&vec![]) { let mut hint_ctx = HintExecutionContext { memory: &mut memory, + private_input_start: public_memory_size, fp, ap: &mut ap, ap_vec: &mut ap_vec, @@ -242,6 +244,7 @@ fn execute_bytecode_helper( jump_counts: &mut jump_counts, poseidon16_precomputed: poseidons_16_precomputed, poseidon24_precomputed: poseidons_24_precomputed, + merkle_path_hints: &mut 
merkle_path_hints, n_poseidon16_precomputed_used: &mut n_poseidon16_precomputed_used, n_poseidon24_precomputed_used: &mut n_poseidon24_precomputed_used, }; @@ -331,14 +334,13 @@ fn execute_bytecode_helper( summary.push('\n'); - if traces[Table::poseidon16().index()].base[0].len() + traces[Table::poseidon24().index()].base[0].len() > 0 { + if traces[&Table::poseidon16_core()].base[0].len() + traces[&Table::poseidon24_core()].base[0].len() > 0 { summary.push_str(&format!( "Poseidon2_16 calls: {}, Poseidon2_24 calls: {}, (1 poseidon per {} instructions)\n", - pretty_integer(traces[Table::poseidon16().index()].base[0].len()), - pretty_integer(traces[Table::poseidon24().index()].base[0].len()), + pretty_integer(traces[&Table::poseidon16_core()].base[0].len()), + pretty_integer(traces[&Table::poseidon24_core()].base[0].len()), cpu_cycles - / (traces[Table::poseidon16().index()].base[0].len() - + traces[Table::poseidon24().index()].base[0].len()) + / (traces[&Table::poseidon16_core()].base[0].len() + traces[&Table::poseidon24_core()].base[0].len()) )); } // if !dot_products.is_empty() { diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 09a5d034..cd8d1fe2 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -43,16 +43,11 @@ pub enum Hint { /// and ai < 4, b < 2^7 - 1 /// The decomposition is unique, and always exists (except for x = -1) DecomposeCustom { - /// Memory offset for results: m[fp + res_offset..fp + res_offset + 13 * len(to_decompose)] - res_offset: usize, + decomposed: MemOrConstant, + remaining: MemOrConstant, /// Values to decompose into custom representation to_decompose: Vec, }, - /// Provide a counter value - CounterHint { - /// Memory offset where counter result will be stored: m[fp + res_offset] - res_offset: usize, - }, /// Print debug information during execution Print { /// Source code location information @@ -60,21 +55,30 @@ pub enum Hint { /// Values to print content: Vec, }, + 
PrivateInputStart { + res_offset: usize, + }, /// Report source code location for debugging LocationReport { /// Source code location location: SourceLineNumber, }, /// Jump destination label (for debugging purposes) - Label { label: Label }, + Label { + label: Label, + }, /// Stack frame size (for memory profiling) - StackFrame { label: Label, size: usize }, + StackFrame { + label: Label, + size: usize, + }, } /// Execution state for hint processing #[derive(Debug)] pub struct HintExecutionContext<'a> { pub memory: &'a mut Memory, + pub private_input_start: usize, // normal pointer pub fp: usize, pub ap: &'a mut usize, pub ap_vec: &'a mut usize, @@ -162,25 +166,23 @@ impl Hint { } } Self::DecomposeCustom { - res_offset, + decomposed, + remaining, to_decompose, } => { - let mut memory_index = ctx.fp + *res_offset; + let mut memory_index_decomposed = decomposed.read_value(ctx.memory, ctx.fp)?.to_usize(); + let mut memory_index_remaining = remaining.read_value(ctx.memory, ctx.fp)?.to_usize(); for value_source in to_decompose { let value = value_source.read_value(ctx.memory, ctx.fp)?.to_usize(); for i in 0..12 { let value = F::from_usize((value >> (2 * i)) & 0b11); - ctx.memory.set(memory_index, value)?; - memory_index += 1; + ctx.memory.set(memory_index_decomposed, value)?; + memory_index_decomposed += 1; } - ctx.memory.set(memory_index, F::from_usize(value >> 24))?; - memory_index += 1; + ctx.memory.set(memory_index_remaining, F::from_usize(value >> 24))?; + memory_index_remaining += 1; } } - Self::CounterHint { res_offset } => { - ctx.memory.set(ctx.fp + *res_offset, F::from_usize(*ctx.counter_hint))?; - *ctx.counter_hint += 1; - } Self::Inverse { arg, res_offset } => { let value = arg.read_value(ctx.memory, ctx.fp)?; let result = value.try_inverse().unwrap_or(F::ZERO); @@ -223,6 +225,10 @@ impl Hint { .push(*ctx.cpu_cycles_before_new_line); *ctx.cpu_cycles_before_new_line = 0; } + Self::PrivateInputStart { res_offset } => { + ctx.memory + .set(ctx.fp + 
*res_offset, F::from_usize(ctx.private_input_start))?; + } Self::Label { .. } => {} Self::StackFrame { label, size } => { if ctx.profiling { @@ -257,6 +263,9 @@ impl Display for Hint { write!(f, "m[fp + {offset}] = request_memory({size})") } } + Self::PrivateInputStart { res_offset } => { + write!(f, "m[fp + {res_offset}] = private_input_start()") + } Self::DecomposeBits { res_offset, to_decompose, @@ -271,10 +280,11 @@ impl Display for Hint { write!(f, ")") } Self::DecomposeCustom { - res_offset, + decomposed, + remaining, to_decompose, } => { - write!(f, "m[fp + {res_offset}] = decompose_custom(")?; + write!(f, "decompose_custom(m[fp + {decomposed}], m[fp + {remaining}], ")?; for (i, v) in to_decompose.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -283,9 +293,6 @@ impl Display for Hint { } write!(f, ")") } - Self::CounterHint { res_offset } => { - write!(f, "m[fp + {res_offset}] = counter_hint()") - } Self::Print { line_info, content } => { write!(f, "print(")?; for (i, v) in content.iter().enumerate() { diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 2e2cd141..87629925 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -6,8 +6,9 @@ use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::Memory; use crate::tables::TableT; -use crate::{N_TABLES, Table, TableTrace}; +use crate::{Table, TableTrace}; use multilinear_toolkit::prelude::*; +use std::collections::{BTreeMap, VecDeque}; use std::fmt::{Display, Formatter}; use utils::ToUsize; @@ -64,13 +65,14 @@ pub struct InstructionContext<'a> { pub fp: &'a mut usize, pub pc: &'a mut usize, pub pcs: &'a Vec, - pub traces: &'a mut [TableTrace; N_TABLES], + pub traces: &'a mut BTreeMap, pub add_counts: &'a mut usize, pub mul_counts: &'a mut usize, pub deref_counts: &'a mut usize, pub jump_counts: &'a mut usize, pub poseidon16_precomputed: &'a [([F; 16], [F; 16])], pub poseidon24_precomputed: 
&'a [([F; 24], [F; 8])], + pub merkle_path_hints: &'a mut VecDeque>, pub n_poseidon16_precomputed_used: &'a mut usize, pub n_poseidon24_precomputed_used: &'a mut usize, } diff --git a/crates/lean_vm/src/tables/dot_product/air.rs b/crates/lean_vm/src/tables/dot_product/air.rs index a515845d..04b354bf 100644 --- a/crates/lean_vm/src/tables/dot_product/air.rs +++ b/crates/lean_vm/src/tables/dot_product/air.rs @@ -21,16 +21,16 @@ use p3_air::{Air, AirBuilder}; */ // F columns -pub(super) const DOT_PRODUCT_AIR_COL_FLAG: usize = 0; -pub(super) const DOT_PRODUCT_AIR_COL_LEN: usize = 1; -pub(super) const DOT_PRODUCT_AIR_COL_INDEX_A: usize = 2; -pub(super) const DOT_PRODUCT_AIR_COL_INDEX_B: usize = 3; -pub(super) const DOT_PRODUCT_AIR_COL_INDEX_RES: usize = 4; +pub(super) const COL_FLAG: usize = 0; +pub(super) const COL_LEN: usize = 1; +pub(super) const COL_INDEX_A: usize = 2; +pub(super) const COL_INDEX_B: usize = 3; +pub(super) const COL_INDEX_RES: usize = 4; // EF columns -pub(super) const DOT_PRODUCT_AIR_COL_VALUE_B: usize = 0; -pub(super) const DOT_PRODUCT_AIR_COL_VALUE_RES: usize = 1; -pub(super) const DOT_PRODUCT_AIR_COL_COMPUTATION: usize = 2; +pub(super) const COL_VALUE_B: usize = 0; +pub(super) const COL_VALUE_RES: usize = 1; +pub(super) const COL_COMPUTATION: usize = 2; pub(super) const fn dot_product_air_col_value_a(be: bool) -> usize { if be { 5 } else { 3 } @@ -60,15 +60,10 @@ impl Air for DotProductPrecompile { 8 } fn down_column_indexes_f(&self) -> Vec { - vec![ - DOT_PRODUCT_AIR_COL_FLAG, - DOT_PRODUCT_AIR_COL_LEN, - DOT_PRODUCT_AIR_COL_INDEX_A, - DOT_PRODUCT_AIR_COL_INDEX_B, - ] + vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B] } fn down_column_indexes_ef(&self) -> Vec { - vec![DOT_PRODUCT_AIR_COL_COMPUTATION] + vec![COL_COMPUTATION] } #[inline] @@ -78,20 +73,20 @@ impl Air for DotProductPrecompile { let down_f = builder.down_f(); let down_ef = builder.down_ef(); - let flag = up_f[DOT_PRODUCT_AIR_COL_FLAG].clone(); - let len = 
up_f[DOT_PRODUCT_AIR_COL_LEN].clone(); - let index_a = up_f[DOT_PRODUCT_AIR_COL_INDEX_A].clone(); - let index_b = up_f[DOT_PRODUCT_AIR_COL_INDEX_B].clone(); - let index_res = up_f[DOT_PRODUCT_AIR_COL_INDEX_RES].clone(); + let flag = up_f[COL_FLAG].clone(); + let len = up_f[COL_LEN].clone(); + let index_a = up_f[COL_INDEX_A].clone(); + let index_b = up_f[COL_INDEX_B].clone(); + let index_res = up_f[COL_INDEX_RES].clone(); let value_a = if BE { AB::EF::from(up_f[dot_product_air_col_value_a(BE)].clone()) // TODO embdding overhead } else { up_ef[dot_product_air_col_value_a(BE)].clone() }; - let value_b = up_ef[DOT_PRODUCT_AIR_COL_VALUE_B].clone(); - let res = up_ef[DOT_PRODUCT_AIR_COL_VALUE_RES].clone(); - let computation = up_ef[DOT_PRODUCT_AIR_COL_COMPUTATION].clone(); + let value_b = up_ef[COL_VALUE_B].clone(); + let res = up_ef[COL_VALUE_RES].clone(); + let computation = up_ef[COL_COMPUTATION].clone(); let flag_down = down_f[0].clone(); let len_down = down_f[1].clone(); @@ -104,10 +99,7 @@ impl Air for DotProductPrecompile { extra_data, AB::F::from_usize(self.identifier().index()), flag.clone(), - index_a.clone(), - index_b.clone(), - index_res.clone(), - len.clone(), + &[index_a.clone(), index_b.clone(), index_res.clone(), len.clone()], )); builder.assert_bool(flag.clone()); diff --git a/crates/lean_vm/src/tables/dot_product/exec.rs b/crates/lean_vm/src/tables/dot_product/exec.rs index e774b523..4b25d40a 100644 --- a/crates/lean_vm/src/tables/dot_product/exec.rs +++ b/crates/lean_vm/src/tables/dot_product/exec.rs @@ -24,7 +24,7 @@ pub(super) fn exec_dot_product_be( { { - let computation = &mut trace.ext[DOT_PRODUCT_AIR_COL_COMPUTATION]; + let computation = &mut trace.ext[COL_COMPUTATION]; computation.extend(EF::zero_vec(size)); let new_size = computation.len(); computation[new_size - 1] = slice_1[size - 1] * slice_0[size - 1]; @@ -34,16 +34,15 @@ pub(super) fn exec_dot_product_be( } } - trace.base[DOT_PRODUCT_AIR_COL_FLAG].push(F::ONE); - 
trace.base[DOT_PRODUCT_AIR_COL_FLAG].extend(F::zero_vec(size - 1)); - trace.base[DOT_PRODUCT_AIR_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_B] - .extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.base[COL_FLAG].push(F::ONE); + trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); + trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); + trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); + trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); trace.base[dot_product_air_col_value_a(true)].extend(slice_0); - trace.ext[DOT_PRODUCT_AIR_COL_VALUE_B].extend(slice_1); - trace.ext[DOT_PRODUCT_AIR_COL_VALUE_RES].extend(vec![dot_product_result; size]); + trace.ext[COL_VALUE_B].extend(slice_1); + trace.ext[COL_VALUE_RES].extend(vec![dot_product_result; size]); } Ok(()) @@ -76,7 +75,7 @@ pub(super) fn exec_dot_product_ee( { { - let computation = &mut trace.ext[DOT_PRODUCT_AIR_COL_COMPUTATION]; + let computation = &mut trace.ext[COL_COMPUTATION]; computation.extend(EF::zero_vec(size)); let new_size = computation.len(); computation[new_size - 1] = slice_1[size - 1] * slice_0[size - 1]; @@ -86,17 +85,15 @@ pub(super) fn exec_dot_product_ee( } } - trace.base[DOT_PRODUCT_AIR_COL_FLAG].push(F::ONE); - trace.base[DOT_PRODUCT_AIR_COL_FLAG].extend(F::zero_vec(size - 1)); - trace.base[DOT_PRODUCT_AIR_COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_A] - .extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); - 
trace.base[DOT_PRODUCT_AIR_COL_INDEX_B] - .extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); - trace.base[DOT_PRODUCT_AIR_COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.base[COL_FLAG].push(F::ONE); + trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); + trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i * DIMENSION))); + trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); + trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); trace.ext[dot_product_air_col_value_a(false)].extend(slice_0); - trace.ext[DOT_PRODUCT_AIR_COL_VALUE_B].extend(slice_1); - trace.ext[DOT_PRODUCT_AIR_COL_VALUE_RES].extend(vec![dot_product_result; size]); + trace.ext[COL_VALUE_B].extend(slice_1); + trace.ext[COL_VALUE_RES].extend(vec![dot_product_result; size]); } Ok(()) diff --git a/crates/lean_vm/src/tables/dot_product/mod.rs b/crates/lean_vm/src/tables/dot_product/mod.rs index 3954abca..982e9aa0 100644 --- a/crates/lean_vm/src/tables/dot_product/mod.rs +++ b/crates/lean_vm/src/tables/dot_product/mod.rs @@ -26,27 +26,17 @@ impl TableT for DotProductPrecompile { } fn commited_columns_f(&self) -> Vec { - let mut res = vec![ - DOT_PRODUCT_AIR_COL_FLAG, - DOT_PRODUCT_AIR_COL_LEN, - DOT_PRODUCT_AIR_COL_INDEX_A, - DOT_PRODUCT_AIR_COL_INDEX_B, - DOT_PRODUCT_AIR_COL_INDEX_RES, - ]; - if BE { - res.push(dot_product_air_col_value_a(BE)); - } - res + vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES] } fn commited_columns_ef(&self) -> Vec { - vec![DOT_PRODUCT_AIR_COL_COMPUTATION] + vec![COL_COMPUTATION] } fn normal_lookups_f(&self) -> Vec { if BE { vec![LookupIntoMemory { - index: DOT_PRODUCT_AIR_COL_INDEX_A, + index: COL_INDEX_A, values: dot_product_air_col_value_a(BE), }] } else { @@ -57,19 +47,19 @@ impl TableT for DotProductPrecompile { fn 
normal_lookups_ef(&self) -> Vec { let mut res = vec![ ExtensionFieldLookupIntoMemory { - index: DOT_PRODUCT_AIR_COL_INDEX_B, - values: DOT_PRODUCT_AIR_COL_VALUE_B, + index: COL_INDEX_B, + values: COL_VALUE_B, }, ExtensionFieldLookupIntoMemory { - index: DOT_PRODUCT_AIR_COL_INDEX_RES, - values: DOT_PRODUCT_AIR_COL_VALUE_RES, + index: COL_INDEX_RES, + values: COL_VALUE_RES, }, ]; if !BE { res.insert( 0, ExtensionFieldLookupIntoMemory { - index: DOT_PRODUCT_AIR_COL_INDEX_A, + index: COL_INDEX_A, values: dot_product_air_col_value_a(BE), }, ); @@ -85,13 +75,8 @@ impl TableT for DotProductPrecompile { vec![Bus { table: BusTable::Constant(self.identifier()), direction: BusDirection::Pull, - selector: DOT_PRODUCT_AIR_COL_FLAG, - data: vec![ - DOT_PRODUCT_AIR_COL_INDEX_A, - DOT_PRODUCT_AIR_COL_INDEX_B, - DOT_PRODUCT_AIR_COL_INDEX_RES, - DOT_PRODUCT_AIR_COL_LEN, - ], + selector: BusSelector::Column(COL_FLAG), + data: vec![COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES, COL_LEN], }] } @@ -119,7 +104,7 @@ impl TableT for DotProductPrecompile { aux: usize, ctx: &mut InstructionContext<'_>, ) -> Result<(), RunnerError> { - let trace = &mut ctx.traces[self.identifier().index()]; + let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); if BE { exec_dot_product_be(arg_a, arg_b, arg_c, aux, ctx.memory, trace) } else { diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs new file mode 100644 index 00000000..b41760f9 --- /dev/null +++ b/crates/lean_vm/src/tables/eq_poly_base_ext/air.rs @@ -0,0 +1,95 @@ +use crate::{ + DIMENSION, EF, ExtraDataForBuses, TableT, eval_virtual_bus_column, + tables::eq_poly_base_ext::EqPolyBaseExtPrecompile, +}; +use multilinear_toolkit::prelude::*; +use p3_air::{Air, AirBuilder}; + +// F columns +pub(super) const COL_FLAG: usize = 0; +pub(super) const COL_LEN: usize = 1; +pub(super) const COL_INDEX_A: usize = 2; +pub(super) const COL_INDEX_B: usize = 3; +pub(super) const COL_INDEX_RES: usize = 4; 
+pub(super) const COL_VALUE_A: usize = 5; + +// EF columns +pub(super) const COL_VALUE_B: usize = 0; +pub(super) const COL_VALUE_RES: usize = 1; +pub(super) const COL_COMPUTATION: usize = 2; + +pub(super) const N_COLS_F: usize = 6; +pub(super) const N_COLS_EF: usize = 3; + +impl Air for EqPolyBaseExtPrecompile { + type ExtraData = ExtraDataForBuses; + + fn n_columns_f_air(&self) -> usize { + N_COLS_F + } + fn n_columns_ef_air(&self) -> usize { + N_COLS_EF + } + fn degree(&self) -> usize { + 4 + } + fn n_constraints(&self) -> usize { + 8 + } + fn down_column_indexes_f(&self) -> Vec { + vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B] + } + fn down_column_indexes_ef(&self) -> Vec { + vec![COL_COMPUTATION] + } + + #[inline] + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up_f = builder.up_f(); + let up_ef = builder.up_ef(); + let down_f = builder.down_f(); + let down_ef = builder.down_ef(); + + let flag = up_f[COL_FLAG].clone(); + let len = up_f[COL_LEN].clone(); + let index_a = up_f[COL_INDEX_A].clone(); + let index_b = up_f[COL_INDEX_B].clone(); + let index_res = up_f[COL_INDEX_RES].clone(); + let value_a = up_f[COL_VALUE_A].clone(); + + let value_b = up_ef[COL_VALUE_B].clone(); + let res = up_ef[COL_VALUE_RES].clone(); + let computation = up_ef[COL_COMPUTATION].clone(); + + let flag_down = down_f[0].clone(); + let len_down = down_f[1].clone(); + let index_a_down = down_f[2].clone(); + let index_b_down = down_f[3].clone(); + + let computation_down = down_ef[0].clone(); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &[index_a.clone(), index_b.clone(), index_res.clone(), len.clone()], + )); + + builder.assert_bool(flag.clone()); + + let product_up = + value_b.clone() * value_a.clone() + (AB::EF::ONE - value_b.clone()) * (AB::F::ONE - value_a.clone()); + let not_flag_down = AB::F::ONE - flag_down.clone(); + builder.assert_eq_ef( + computation.clone(), + 
product_up.clone() * (computation_down * not_flag_down.clone() + flag_down.clone()), + ); + builder.assert_zero(not_flag_down.clone() * (len.clone() - (len_down + AB::F::ONE))); + builder.assert_zero(flag_down * (len - AB::F::ONE)); + let index_a_increment = AB::F::ONE; + builder.assert_zero(not_flag_down.clone() * (index_a - (index_a_down - index_a_increment))); + builder.assert_zero(not_flag_down * (index_b - (index_b_down - AB::F::from_usize(DIMENSION)))); + + builder.assert_zero_ef((computation - res) * flag); + } +} diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs new file mode 100644 index 00000000..898588e8 --- /dev/null +++ b/crates/lean_vm/src/tables/eq_poly_base_ext/exec.rs @@ -0,0 +1,46 @@ +use crate::EF; +use crate::F; +use crate::Memory; +use crate::RunnerError; +use crate::TableTrace; +use crate::tables::eq_poly_base_ext::*; +use utils::ToUsize; + +pub(super) fn exec_eq_poly_base_ext( + ptr_arg_0: F, + ptr_arg_1: F, + ptr_res: F, + size: usize, + memory: &mut Memory, + trace: &mut TableTrace, +) -> Result<(), RunnerError> { + assert!(size > 0); + + let slice_0 = memory.slice(ptr_arg_0.to_usize(), size)?; + let slice_1 = memory.get_continuous_slice_of_ef_elements(ptr_arg_1.to_usize(), size)?; + + let computation = &mut trace.ext[COL_COMPUTATION]; + computation.extend(EF::zero_vec(size)); + let new_size = computation.len(); + computation[new_size - 1] = + slice_1[size - 1] * slice_0[size - 1] + (EF::ONE - slice_1[size - 1]) * (F::ONE - slice_0[size - 1]); + for i in 0..size - 1 { + computation[new_size - 2 - i] = computation[new_size - 1 - i] + * (slice_1[size - 2 - i] * slice_0[size - 2 - i] + + (EF::ONE - slice_1[size - 2 - i]) * (F::ONE - slice_0[size - 2 - i])); + } + let final_result = computation[new_size - size]; + memory.set_ef_element(ptr_res.to_usize(), final_result)?; + + trace.base[COL_FLAG].push(F::ONE); + trace.base[COL_FLAG].extend(F::zero_vec(size - 1)); + 
trace.base[COL_LEN].extend(((1..=size).rev()).map(F::from_usize)); + trace.base[COL_INDEX_A].extend((0..size).map(|i| F::from_usize(ptr_arg_0.to_usize() + i))); + trace.base[COL_INDEX_B].extend((0..size).map(|i| F::from_usize(ptr_arg_1.to_usize() + i * DIMENSION))); + trace.base[COL_INDEX_RES].extend(vec![F::from_usize(ptr_res.to_usize()); size]); + trace.base[COL_VALUE_A].extend(slice_0); + trace.ext[COL_VALUE_B].extend(slice_1); + trace.ext[COL_VALUE_RES].extend(vec![final_result; size]); + + Ok(()) +} diff --git a/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs b/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs new file mode 100644 index 00000000..d58ac8f9 --- /dev/null +++ b/crates/lean_vm/src/tables/eq_poly_base_ext/mod.rs @@ -0,0 +1,93 @@ +use crate::{InstructionContext, tables::eq_poly_base_ext::exec::exec_eq_poly_base_ext, *}; +use multilinear_toolkit::prelude::*; + +mod air; +use air::*; +mod exec; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct EqPolyBaseExtPrecompile; + +impl TableT for EqPolyBaseExtPrecompile { + fn name(&self) -> &'static str { + "eq_poly_base_ext" + } + + fn identifier(&self) -> Table { + Table::eq_poly_base_ext() + } + + fn commited_columns_f(&self) -> Vec { + vec![COL_FLAG, COL_LEN, COL_INDEX_A, COL_INDEX_B, COL_INDEX_RES] + } + + fn commited_columns_ef(&self) -> Vec { + vec![COL_COMPUTATION] + } + + fn normal_lookups_f(&self) -> Vec { + vec![LookupIntoMemory { + index: COL_INDEX_A, + values: COL_VALUE_A, + }] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![ + ExtensionFieldLookupIntoMemory { + index: COL_INDEX_B, + values: COL_VALUE_B, + }, + ExtensionFieldLookupIntoMemory { + index: COL_INDEX_RES, + values: COL_VALUE_RES, + }, + ] + } + + fn vector_lookups(&self) -> Vec { + vec![] + } + + fn buses(&self) -> Vec { + vec![Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(COL_FLAG), + data: vec![COL_INDEX_A, COL_INDEX_B, 
COL_INDEX_RES, COL_LEN], + }] + } + + fn padding_row_f(&self) -> Vec { + [vec![ + F::ONE, // StartFlag + F::ONE, // Len + F::ZERO, // Index A + F::ZERO, // Index B + F::from_usize(ONE_VEC_PTR * VECTOR_LEN), // Index Res + F::ZERO, // Value A + ]] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![ + EF::ZERO, // Value B + EF::ONE, // Value Res + EF::ONE, // Computation + ] + } + + #[inline(always)] + fn execute( + &self, + arg_a: F, + arg_b: F, + arg_c: F, + aux: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); + exec_eq_poly_base_ext(arg_a, arg_b, arg_c, aux, ctx.memory, trace) + } +} diff --git a/crates/lean_vm/src/tables/execution/air.rs b/crates/lean_vm/src/tables/execution/air.rs index aca81b7a..a4136907 100644 --- a/crates/lean_vm/src/tables/execution/air.rs +++ b/crates/lean_vm/src/tables/execution/air.rs @@ -118,10 +118,7 @@ impl Air for ExecutionTable { extra_data, precompile_index.clone(), is_precompile.clone(), - nu_a.clone(), - nu_b.clone(), - nu_c.clone(), - aux.clone(), + &[nu_a.clone(), nu_b.clone(), nu_c.clone(), aux.clone()], )); builder.assert_zero(flag_a_minus_one * (addr_a.clone() - fp_plus_operand_a)); diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 643dcc7f..a6b95703 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -63,7 +63,7 @@ impl TableT for ExecutionTable { vec![Bus { table: BusTable::Variable(COL_INDEX_PRECOMPILE_INDEX), direction: BusDirection::Push, - selector: COL_INDEX_IS_PRECOMPILE, + selector: BusSelector::Column(COL_INDEX_IS_PRECOMPILE), data: vec![ COL_INDEX_EXEC_NU_A, COL_INDEX_EXEC_NU_B, diff --git a/crates/lean_vm/src/tables/merkle/mod.rs b/crates/lean_vm/src/tables/merkle/mod.rs new file mode 100644 index 00000000..03ba9657 --- /dev/null +++ b/crates/lean_vm/src/tables/merkle/mod.rs @@ -0,0 +1,333 @@ +use 
std::array; + +use crate::*; +use multilinear_toolkit::prelude::*; +use p3_air::Air; +use utils::{ToUsize, get_poseidon_16_of_zero, poseidon16_permute, to_big_endian_in_field}; + +// Does not support height = 1 (minimum height is 2) + +// "committed" columns +const COL_FLAG: ColIndex = 0; +const COL_INDEX_LEAF: ColIndex = 1; // vectorized pointer +const COL_LEAF_POSITION: ColIndex = 2; // (between 0 and 2^height - 1) +const COL_INDEX_ROOT: ColIndex = 3; // vectorized pointer +const COL_HEIGHT: ColIndex = 4; // merkle tree height + +const COL_ZERO: ColIndex = 5; // always equal to 0, TODO remove this +const COL_ONE: ColIndex = 6; // always equal to 1, TODO remove this +const COL_IS_LEFT: ColIndex = 7; // boolean, whether the current node is a left child +const COL_LOOKUP_MEM_INDEX: ColIndex = 8; // = COL_INDEX_LEAF if flag = 1, otherwise = COL_INDEX_ROOT + +const INITIAL_COLS_DATA_LEFT: ColIndex = 9; +const INITIAL_COLS_DATA_RIGHT: ColIndex = INITIAL_COLS_DATA_LEFT + VECTOR_LEN; +const INITIAL_COLS_DATA_RES: ColIndex = INITIAL_COLS_DATA_RIGHT + VECTOR_LEN; + +// "virtual" columns (vectorized lookups into memory) +const COL_LOOKUP_MEM_VALUES: ColIndex = INITIAL_COLS_DATA_RES + VECTOR_LEN; + +const TOTAL_N_COLS: usize = COL_LOOKUP_MEM_VALUES + VECTOR_LEN; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MerklePrecompile; + +impl TableT for MerklePrecompile { + fn name(&self) -> &'static str { + "merkle_verify" + } + + fn identifier(&self) -> Table { + Table::merkle() + } + + fn commited_columns_f(&self) -> Vec { + (0..COL_LOOKUP_MEM_VALUES).collect() + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![VectorLookupIntoMemory { + index: COL_LOOKUP_MEM_INDEX, + values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES + i), + }] + } + + fn buses(&self) -> Vec { + vec![ + Bus { + table: 
BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(COL_FLAG), + data: vec![COL_INDEX_LEAF, COL_LEAF_POSITION, COL_INDEX_ROOT, COL_HEIGHT], + }, + Bus { + table: BusTable::Constant(Table::poseidon16_core()), + direction: BusDirection::Push, + selector: BusSelector::ConstantOne, + data: [ + vec![COL_ONE], // Compression + (INITIAL_COLS_DATA_LEFT..INITIAL_COLS_DATA_LEFT + 8).collect::>(), + (INITIAL_COLS_DATA_RIGHT..INITIAL_COLS_DATA_RIGHT + 8).collect::>(), + (INITIAL_COLS_DATA_RES..INITIAL_COLS_DATA_RES + 8).collect::>(), + vec![COL_ZERO; VECTOR_LEN], // Padding + ] + .concat(), + }, + ] + } + + fn padding_row_f(&self) -> Vec { + let default_root = get_poseidon_16_of_zero()[..VECTOR_LEN].to_vec(); + [ + vec![ + F::ONE, // flag + F::ZERO, // index_leaf + F::ZERO, // leaf_position + F::from_usize(POSEIDON_16_NULL_HASH_PTR), // index_root + F::ONE, + F::ZERO, // col_zero + F::ONE, // col_one + F::ZERO, // is_left + F::from_usize(ZERO_VEC_PTR), // lookup_mem_index + ], + vec![F::ZERO; VECTOR_LEN], // data_left + vec![F::ZERO; VECTOR_LEN], // data_right + default_root.clone(), // data_res + vec![F::ZERO; VECTOR_LEN], // lookup_mem_values + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute( + &self, + index_leaf: F, + leaf_position: F, + index_root: F, + height: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + assert!(height >= 2); + + let leaf_position = leaf_position.to_usize(); + assert!(height > 0); + assert!(leaf_position < (1 << height)); + + let auth_path = ctx.merkle_path_hints.pop_front().unwrap(); + assert_eq!(auth_path.len(), height); + let mut leaf_position_bools = to_big_endian_in_field::(!leaf_position, height); + leaf_position_bools.reverse(); // little-endian + + let leaf = ctx.memory.get_vector(index_leaf.to_usize())?; + + { + let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; + 
trace[COL_FLAG].extend([vec![F::ONE], vec![F::ZERO; height - 1]].concat()); + trace[COL_INDEX_LEAF].extend(vec![index_leaf; height]); + trace[COL_LEAF_POSITION].extend((0..height).map(|d| F::from_usize(leaf_position >> d))); + trace[COL_INDEX_ROOT].extend(vec![index_root; height]); + trace[COL_HEIGHT].extend((1..=height).rev().map(F::from_usize)); + trace[COL_ZERO].extend(vec![F::ZERO; height]); + trace[COL_ONE].extend(vec![F::ONE; height]); + trace[COL_IS_LEFT].extend(leaf_position_bools); + trace[COL_LOOKUP_MEM_INDEX].extend([vec![index_leaf], vec![index_root; height - 1]].concat()); + } + + let mut current_hash = leaf; + for (d, neightbour) in auth_path.iter().enumerate() { + let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; + + let is_left = (leaf_position >> d) & 1 == 0; + + // TODO precompute (in parallel + SIMD) poseidons + + let (data_left, data_right) = if is_left { + (current_hash, *neightbour) + } else { + (*neightbour, current_hash) + }; + for i in 0..VECTOR_LEN { + trace[INITIAL_COLS_DATA_LEFT + i].push(data_left[i]); + trace[INITIAL_COLS_DATA_RIGHT + i].push(data_right[i]); + } + + let mut input = [F::ZERO; VECTOR_LEN * 2]; + input[..VECTOR_LEN].copy_from_slice(&data_left); + input[VECTOR_LEN..].copy_from_slice(&data_right); + + let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { + Some(precomputed) if precomputed.0 == input => { + *ctx.n_poseidon16_precomputed_used += 1; + precomputed.1 + } + _ => poseidon16_permute(input), + }; + + current_hash = output[..VECTOR_LEN].try_into().unwrap(); + for i in 0..VECTOR_LEN { + trace[INITIAL_COLS_DATA_RES + i].push(current_hash[i]); + } + + add_poseidon_16_core_row(ctx.traces, 1, input, current_hash, [F::ZERO; VECTOR_LEN], true); + } + let root = current_hash; + ctx.memory.set_vector(index_root.to_usize(), root)?; + + let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; + for i in 0..VECTOR_LEN { + trace[COL_LOOKUP_MEM_VALUES + 
i].extend([vec![leaf[i]], vec![root[i]; height - 1]].concat()); + } + + Ok(()) + } +} + +impl Air for MerklePrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + TOTAL_N_COLS + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 3 + } + fn down_column_indexes_f(&self) -> Vec { + (0..TOTAL_N_COLS - 2 * VECTOR_LEN).collect() + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + 12 + 5 * VECTOR_LEN + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[COL_FLAG].clone(); + let index_leaf = up[COL_INDEX_LEAF].clone(); + let leaf_position = up[COL_LEAF_POSITION].clone(); + let index_root = up[COL_INDEX_ROOT].clone(); + let height = up[COL_HEIGHT].clone(); + let col_zero = up[COL_ZERO].clone(); + let col_one = up[COL_ONE].clone(); + let is_left = up[COL_IS_LEFT].clone(); + let lookup_index = up[COL_LOOKUP_MEM_INDEX].clone(); + let data_left: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_LEFT + i].clone()); + let data_right: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RIGHT + i].clone()); + let data_res: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RES + i].clone()); + let lookup_values: [_; VECTOR_LEN] = array::from_fn(|i| up[COL_LOOKUP_MEM_VALUES + i].clone()); + + let down = builder.down_f(); + let flag_down = down[0].clone(); + let index_leaf_down = down[1].clone(); + let leaf_position_down = down[2].clone(); + let index_root_down = down[3].clone(); + let height_down = down[4].clone(); + let _col_zero_down = down[5].clone(); + let _col_one_down = down[6].clone(); + let is_left_down = down[7].clone(); + let _lookup_index_down = down[8].clone(); + let data_left_down: [_; VECTOR_LEN] = array::from_fn(|i| down[9 + i].clone()); + let data_right_down: [_; VECTOR_LEN] = array::from_fn(|i| down[9 + VECTOR_LEN + i].clone()); + + let mut core_bus_data = [AB::F::ZERO; 1 
+ 2 * 16]; + core_bus_data[0] = col_one.clone(); // Compression + core_bus_data[1..9].clone_from_slice(&data_left); + core_bus_data[9..17].clone_from_slice(&data_right); + core_bus_data[17..25].clone_from_slice(&data_res); + core_bus_data[25..].clone_from_slice(&vec![col_zero.clone(); VECTOR_LEN]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &[ + index_leaf.clone(), + leaf_position.clone(), + index_root.clone(), + height.clone(), + ], + )); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(Table::poseidon16_core().index()), + AB::F::ONE, + &core_bus_data, + )); + + // TODO double check constraints + + builder.assert_eq(col_one.clone(), AB::F::ONE); + builder.assert_eq(col_zero.clone(), AB::F::ZERO); + + builder.assert_bool(flag.clone()); + builder.assert_bool(is_left.clone()); + + let not_flag = AB::F::ONE - flag.clone(); + let not_flag_down = AB::F::ONE - flag_down.clone(); + let is_right = AB::F::ONE - is_left.clone(); + let is_right_down = AB::F::ONE - is_left_down.clone(); + + builder.assert_eq( + lookup_index.clone(), + flag.clone() * index_leaf.clone() + not_flag.clone() * index_root.clone(), + ); + + // Parameters should not change as long as the flag has not been switched back to 1: + builder.assert_zero(not_flag_down.clone() * (index_leaf_down.clone() - index_leaf.clone())); + builder.assert_zero(not_flag_down.clone() * (index_root_down.clone() - index_root.clone())); + + // decrease height by 1 each step + builder.assert_zero(not_flag_down.clone() * (height_down.clone() + AB::F::ONE - height.clone())); + + builder.assert_zero( + not_flag_down.clone() + * ((leaf_position_down.clone() * AB::F::TWO + is_right.clone()) - leaf_position.clone()), + ); + + // start (bottom of the tree) + let starts_and_is_left = flag.clone() * is_left.clone(); + for i in 0..VECTOR_LEN { + builder.assert_zero(starts_and_is_left.clone() * 
(data_left[i].clone() - lookup_values[i].clone())); + } + let starts_and_is_right = flag.clone() * is_right.clone(); + for i in 0..VECTOR_LEN { + builder.assert_zero(starts_and_is_right.clone() * (data_right[i].clone() - lookup_values[i].clone())); + } + + // transition (interior nodes) + let transition_left = not_flag_down.clone() * is_left_down.clone(); + for i in 0..VECTOR_LEN { + builder.assert_zero(transition_left.clone() * (data_left_down[i].clone() - data_res[i].clone())); + } + let transition_right = not_flag_down.clone() * is_right_down.clone(); + for i in 0..VECTOR_LEN { + builder.assert_zero(transition_right.clone() * (data_right_down[i].clone() - data_res[i].clone())); + } + + // end (top of the tree) + builder.assert_zero(flag_down.clone() * (height.clone() - AB::F::ONE)); // at last step, height should be 1 + builder.assert_zero(flag_down.clone() * leaf_position.clone() * (AB::F::ONE - leaf_position.clone())); // at last step, leaf position should be boolean + for i in 0..VECTOR_LEN { + builder + .assert_zero(not_flag.clone() * flag_down.clone() * (data_res[i].clone() - lookup_values[i].clone())); + } + } +} diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index a31d27c9..4f08ebde 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -16,5 +16,14 @@ pub use table_trait::*; mod execution; pub use execution::*; +mod merkle; +pub use merkle::*; + +mod slice_hash; +pub use slice_hash::*; + +mod eq_poly_base_ext; +pub use eq_poly_base_ext::*; + mod utils; pub(crate) use utils::*; diff --git a/crates/lean_vm/src/tables/poseidon_16/core.rs b/crates/lean_vm/src/tables/poseidon_16/core.rs new file mode 100644 index 00000000..15619c33 --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_16/core.rs @@ -0,0 +1,151 @@ +use std::collections::BTreeMap; + +use crate::*; +use multilinear_toolkit::prelude::*; +use p3_air::Air; +use utils::get_poseidon_16_of_zero; + +const POSEIDON_16_CORE_COL_FLAG: 
ColIndex = 0; +pub const POSEIDON_16_CORE_COL_COMPRESSION: ColIndex = 1; +pub const POSEIDON_16_CORE_COL_INPUT_START: ColIndex = 2; +// virtual via GKR +pub const POSEIDON_16_CORE_COL_OUTPUT_START: ColIndex = POSEIDON_16_CORE_COL_INPUT_START + 16; +// intermediate columns ("commited cubes") are not handled here + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon16CorePrecompile; + +impl TableT for Poseidon16CorePrecompile { + fn name(&self) -> &'static str { + "poseidon16_core" + } + + fn identifier(&self) -> Table { + Table::poseidon16_core() + } + + fn n_columns_f_total(&self) -> usize { + 2 + 16 * 2 + } + + fn commited_columns_f(&self) -> Vec { + [ + vec![POSEIDON_16_CORE_COL_FLAG, POSEIDON_16_CORE_COL_COMPRESSION], + (POSEIDON_16_CORE_COL_INPUT_START..POSEIDON_16_CORE_COL_INPUT_START + 16).collect::>(), + ] + .concat() + // (committed cubes are handled elsewhere) + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![] + } + + fn buses(&self) -> Vec { + vec![Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(POSEIDON_16_CORE_COL_FLAG), + data: [ + vec![POSEIDON_16_CORE_COL_COMPRESSION], + (POSEIDON_16_CORE_COL_INPUT_START..POSEIDON_16_CORE_COL_INPUT_START + 16).collect::>(), + (POSEIDON_16_CORE_COL_OUTPUT_START..POSEIDON_16_CORE_COL_OUTPUT_START + 16).collect::>(), + ] + .concat(), + }] + } + + fn padding_row_f(&self) -> Vec { + let mut poseidon_of_zero = *get_poseidon_16_of_zero(); + if POSEIDON_16_DEFAULT_COMPRESSION { + poseidon_of_zero[8..].fill(F::ZERO); + } + [ + vec![F::ZERO, F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION)], + vec![F::ZERO; 16], + poseidon_of_zero.to_vec(), + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute(&self, _: F, _: F, 
_: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { + unreachable!() + } +} + +impl Air for Poseidon16CorePrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + 2 + 16 * 2 + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 1 + } + fn down_column_indexes_f(&self) -> Vec { + vec![] + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + 1 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[POSEIDON_16_CORE_COL_FLAG].clone(); + let mut data = [AB::F::ZERO; 1 + 2 * 16]; + data[0] = up[POSEIDON_16_CORE_COL_COMPRESSION].clone(); + data[1..17].clone_from_slice(&up[POSEIDON_16_CORE_COL_INPUT_START..][..16]); + data[17..33].clone_from_slice(&up[POSEIDON_16_CORE_COL_OUTPUT_START..][..16]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &data, + )); + } +} + +pub fn add_poseidon_16_core_row( + traces: &mut BTreeMap, + multiplicity: usize, + input: [F; 16], + res_a: [F; 8], + res_b: [F; 8], + is_compression: bool, +) { + let trace = traces.get_mut(&Table::poseidon16_core()).unwrap(); + + trace.base[POSEIDON_16_CORE_COL_FLAG].push(F::from_usize(multiplicity)); + trace.base[POSEIDON_16_CORE_COL_COMPRESSION].push(F::from_bool(is_compression)); + for (i, value) in input.iter().enumerate() { + trace.base[POSEIDON_16_CORE_COL_INPUT_START + i].push(*value); + } + for (i, value) in res_a.iter().enumerate() { + trace.base[POSEIDON_16_CORE_COL_OUTPUT_START + i].push(*value); + } + for (i, value) in res_b.iter().enumerate() { + trace.base[POSEIDON_16_CORE_COL_OUTPUT_START + 8 + i].push(*value); + } +} diff --git a/crates/lean_vm/src/tables/poseidon_16/from_memory.rs b/crates/lean_vm/src/tables/poseidon_16/from_memory.rs new file mode 100644 index 00000000..445cf04f --- /dev/null +++ 
b/crates/lean_vm/src/tables/poseidon_16/from_memory.rs @@ -0,0 +1,246 @@ +use std::array; + +use crate::*; +use multilinear_toolkit::prelude::*; +use p3_air::Air; +use utils::{ToUsize, get_poseidon_16_of_zero, poseidon16_permute}; + +const POSEIDON_16_MEM_COL_FLAG: ColIndex = 0; +const POSEIDON_16_MEM_COL_INDEX_RES: ColIndex = 1; +const POSEIDON_16_MEM_COL_INDEX_RES_BIS: ColIndex = 2; // = if compressed { 0 } else { POSEIDON_16_COL_INDEX_RES + 1 } +const POSEIDON_16_MEM_COL_COMPRESSION: ColIndex = 3; +const POSEIDON_16_MEM_COL_INDEX_A: ColIndex = 4; +const POSEIDON_16_MEM_COL_INDEX_B: ColIndex = 5; +const POSEIDON_16_MEM_COL_INPUT_START: ColIndex = 6; +const POSEIDON_16_MEM_COL_OUTPUT_START: ColIndex = POSEIDON_16_MEM_COL_INPUT_START + 16; +// intermediate columns ("commited cubes") are not handled here + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon16MemPrecompile; + +impl TableT for Poseidon16MemPrecompile { + fn name(&self) -> &'static str { + "poseidon16" + } + + fn identifier(&self) -> Table { + Table::poseidon16_mem() + } + + fn n_columns_f_total(&self) -> usize { + 6 + 16 * 2 + } + + fn commited_columns_f(&self) -> Vec { + vec![ + POSEIDON_16_MEM_COL_FLAG, + POSEIDON_16_MEM_COL_INDEX_RES, + POSEIDON_16_MEM_COL_INDEX_RES_BIS, + POSEIDON_16_MEM_COL_COMPRESSION, + POSEIDON_16_MEM_COL_INDEX_A, + POSEIDON_16_MEM_COL_INDEX_B, + ] // (committed cubes are handled elsewhere) + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![ + VectorLookupIntoMemory { + index: POSEIDON_16_MEM_COL_INDEX_A, + values: array::from_fn(|i| POSEIDON_16_MEM_COL_INPUT_START + i), + }, + VectorLookupIntoMemory { + index: POSEIDON_16_MEM_COL_INDEX_B, + values: array::from_fn(|i| POSEIDON_16_MEM_COL_INPUT_START + VECTOR_LEN + i), + }, + VectorLookupIntoMemory { + index: 
POSEIDON_16_MEM_COL_INDEX_RES, + values: array::from_fn(|i| POSEIDON_16_MEM_COL_OUTPUT_START + i), + }, + VectorLookupIntoMemory { + index: POSEIDON_16_MEM_COL_INDEX_RES_BIS, + values: array::from_fn(|i| POSEIDON_16_MEM_COL_OUTPUT_START + VECTOR_LEN + i), + }, + ] + } + + fn buses(&self) -> Vec { + vec![ + Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(POSEIDON_16_MEM_COL_FLAG), + data: vec![ + POSEIDON_16_MEM_COL_INDEX_A, + POSEIDON_16_MEM_COL_INDEX_B, + POSEIDON_16_MEM_COL_INDEX_RES, + POSEIDON_16_MEM_COL_COMPRESSION, + ], + }, + Bus { + table: BusTable::Constant(Table::poseidon16_core()), + direction: BusDirection::Push, + selector: BusSelector::ConstantOne, + data: [ + vec![POSEIDON_16_MEM_COL_COMPRESSION], + (POSEIDON_16_MEM_COL_INPUT_START..POSEIDON_16_MEM_COL_INPUT_START + 16).collect::>(), + (POSEIDON_16_MEM_COL_OUTPUT_START..POSEIDON_16_MEM_COL_OUTPUT_START + 16) + .collect::>(), + ] + .concat(), + }, + ] + } + + fn padding_row_f(&self) -> Vec { + let mut poseidon_of_zero = *get_poseidon_16_of_zero(); + if POSEIDON_16_DEFAULT_COMPRESSION { + poseidon_of_zero[8..].fill(F::ZERO); + } + [ + vec![ + F::ZERO, + F::from_usize(POSEIDON_16_NULL_HASH_PTR), + F::from_usize(if POSEIDON_16_DEFAULT_COMPRESSION { + ZERO_VEC_PTR + } else { + 1 + POSEIDON_16_NULL_HASH_PTR + }), + F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION), + F::from_usize(ZERO_VEC_PTR), + F::from_usize(ZERO_VEC_PTR), + ], + vec![F::ZERO; 16], + poseidon_of_zero.to_vec(), + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute( + &self, + arg_a: F, + arg_b: F, + index_res_a: F, + is_compression: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + assert!(is_compression == 0 || is_compression == 1); + let is_compression = is_compression == 1; + let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); + + let arg0 = 
ctx.memory.get_vector(arg_a.to_usize())?; + let arg1 = ctx.memory.get_vector(arg_b.to_usize())?; + + let mut input = [F::ZERO; VECTOR_LEN * 2]; + input[..VECTOR_LEN].copy_from_slice(&arg0); + input[VECTOR_LEN..].copy_from_slice(&arg1); + + let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { + Some(precomputed) if precomputed.0 == input => { + *ctx.n_poseidon16_precomputed_used += 1; + precomputed.1 + } + _ => poseidon16_permute(input), + }; + + let res_a: [F; VECTOR_LEN] = output[..VECTOR_LEN].try_into().unwrap(); + let (index_res_b, res_b): (F, [F; VECTOR_LEN]) = if is_compression { + (F::from_usize(ZERO_VEC_PTR), [F::ZERO; VECTOR_LEN]) + } else { + (index_res_a + F::ONE, output[VECTOR_LEN..].try_into().unwrap()) + }; + + ctx.memory.set_vector(index_res_a.to_usize(), res_a)?; + ctx.memory.set_vector(index_res_b.to_usize(), res_b)?; + + trace.base[POSEIDON_16_MEM_COL_FLAG].push(F::ONE); + trace.base[POSEIDON_16_MEM_COL_INDEX_A].push(arg_a); + trace.base[POSEIDON_16_MEM_COL_INDEX_B].push(arg_b); + trace.base[POSEIDON_16_MEM_COL_INDEX_RES].push(index_res_a); + trace.base[POSEIDON_16_MEM_COL_INDEX_RES_BIS].push(index_res_b); + trace.base[POSEIDON_16_MEM_COL_COMPRESSION].push(F::from_bool(is_compression)); + for (i, value) in input.iter().enumerate() { + trace.base[POSEIDON_16_MEM_COL_INPUT_START + i].push(*value); + } + for (i, value) in res_a.iter().enumerate() { + trace.base[POSEIDON_16_MEM_COL_OUTPUT_START + i].push(*value); + } + for (i, value) in res_b.iter().enumerate() { + trace.base[POSEIDON_16_MEM_COL_OUTPUT_START + 8 + i].push(*value); + } + + add_poseidon_16_core_row(ctx.traces, 1, input, res_a, res_b, is_compression); + + Ok(()) + } +} + +impl Air for Poseidon16MemPrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + 6 + 16 * 2 + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 2 + } + fn down_column_indexes_f(&self) -> Vec { + vec![] + } + fn 
down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + 5 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[POSEIDON_16_MEM_COL_FLAG].clone(); + let index_res = up[POSEIDON_16_MEM_COL_INDEX_RES].clone(); + let index_res_bis = up[POSEIDON_16_MEM_COL_INDEX_RES_BIS].clone(); + let compression = up[POSEIDON_16_MEM_COL_COMPRESSION].clone(); + let index_a = up[POSEIDON_16_MEM_COL_INDEX_A].clone(); + let index_b = up[POSEIDON_16_MEM_COL_INDEX_B].clone(); + + let mut core_bus_data = [AB::F::ZERO; 1 + 2 * 16]; + core_bus_data[0] = up[POSEIDON_16_MEM_COL_COMPRESSION].clone(); + core_bus_data[1..17].clone_from_slice(&up[POSEIDON_16_MEM_COL_INPUT_START..][..16]); + core_bus_data[17..33].clone_from_slice(&up[POSEIDON_16_MEM_COL_OUTPUT_START..][..16]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &[index_a.clone(), index_b.clone(), index_res.clone(), compression.clone()], + )); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(Table::poseidon16_core().index()), + AB::F::ONE, + &core_bus_data, + )); + + builder.assert_bool(flag.clone()); + builder.assert_bool(compression.clone()); + builder.assert_eq(index_res_bis, (index_res + AB::F::ONE) * (AB::F::ONE - compression)); + } +} diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index ee9a36e5..b35d4130 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -1,222 +1,7 @@ -use std::array; +mod core; +pub use core::*; -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use utils::{ToUsize, get_poseidon_16_of_zero, poseidon16_permute}; +mod from_memory; +pub use from_memory::*; -const POSEIDON_16_DEFAULT_COMPRESSION: bool = true; - -const POSEIDON_16_COL_FLAG: ColIndex = 0; 
-const POSEIDON_16_COL_INDEX_RES: ColIndex = 1; -const POSEIDON_16_COL_INDEX_RES_BIS: ColIndex = 2; // = if compressed { 0 } else { POSEIDON_16_COL_INDEX_RES + 1 } -pub const POSEIDON_16_COL_COMPRESSION: ColIndex = 3; -const POSEIDON_16_COL_INDEX_A: ColIndex = 4; -const POSEIDON_16_COL_INDEX_B: ColIndex = 5; -pub const POSEIDON_16_COL_INDEX_INPUT_START: ColIndex = 6; -const POSEIDON_16_COL_INDEX_OUTPUT_START: ColIndex = POSEIDON_16_COL_INDEX_INPUT_START + 16; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon16Precompile; - -impl TableT for Poseidon16Precompile { - fn name(&self) -> &'static str { - "poseidon16" - } - - fn identifier(&self) -> Table { - Table::poseidon16() - } - - fn n_columns_f_total(&self) -> usize { - 6 + 16 * 2 - } - - fn commited_columns_f(&self) -> Vec { - vec![ - POSEIDON_16_COL_FLAG, - POSEIDON_16_COL_INDEX_RES, - POSEIDON_16_COL_INDEX_RES_BIS, - POSEIDON_16_COL_COMPRESSION, - POSEIDON_16_COL_INDEX_A, - POSEIDON_16_COL_INDEX_B, - ] // (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![ - VectorLookupIntoMemory { - index: POSEIDON_16_COL_INDEX_A, - values: array::from_fn(|i| POSEIDON_16_COL_INDEX_INPUT_START + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_COL_INDEX_B, - values: array::from_fn(|i| POSEIDON_16_COL_INDEX_INPUT_START + VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_COL_INDEX_RES, - values: array::from_fn(|i| POSEIDON_16_COL_INDEX_OUTPUT_START + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_16_COL_INDEX_RES_BIS, - values: array::from_fn(|i| POSEIDON_16_COL_INDEX_OUTPUT_START + VECTOR_LEN + i), - }, - ] - } - - fn buses(&self) -> Vec { - vec![Bus { - table: 
BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: POSEIDON_16_COL_FLAG, - data: vec![ - POSEIDON_16_COL_INDEX_A, - POSEIDON_16_COL_INDEX_B, - POSEIDON_16_COL_INDEX_RES, - POSEIDON_16_COL_COMPRESSION, - ], - }] - } - - fn padding_row_f(&self) -> Vec { - let mut poseidon_of_zero = *get_poseidon_16_of_zero(); - if POSEIDON_16_DEFAULT_COMPRESSION { - poseidon_of_zero[8..].fill(F::ZERO); - } - [ - vec![ - F::ZERO, - F::from_usize(POSEIDON_16_NULL_HASH_PTR), - F::from_usize(if POSEIDON_16_DEFAULT_COMPRESSION { - ZERO_VEC_PTR - } else { - 1 + POSEIDON_16_NULL_HASH_PTR - }), - F::from_bool(POSEIDON_16_DEFAULT_COMPRESSION), - F::from_usize(ZERO_VEC_PTR), - F::from_usize(ZERO_VEC_PTR), - ], - vec![F::ZERO; 16], - poseidon_of_zero.to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - index_res_a: F, - is_compression: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert!(is_compression == 0 || is_compression == 1); - let is_compression = is_compression == 1; - let trace = &mut ctx.traces[self.identifier().index()]; - - let arg0 = ctx.memory.get_vector(arg_a.to_usize())?; - let arg1 = ctx.memory.get_vector(arg_b.to_usize())?; - - let mut input = [F::ZERO; VECTOR_LEN * 2]; - input[..VECTOR_LEN].copy_from_slice(&arg0); - input[VECTOR_LEN..].copy_from_slice(&arg1); - - let output = match ctx.poseidon16_precomputed.get(*ctx.n_poseidon16_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon16_precomputed_used += 1; - precomputed.1 - } - _ => poseidon16_permute(input), - }; - - let res_a: [F; VECTOR_LEN] = output[..VECTOR_LEN].try_into().unwrap(); - let (index_res_b, res_b): (F, [F; VECTOR_LEN]) = if is_compression { - (F::from_usize(ZERO_VEC_PTR), [F::ZERO; VECTOR_LEN]) - } else { - (index_res_a + F::ONE, output[VECTOR_LEN..].try_into().unwrap()) - }; - - 
ctx.memory.set_vector(index_res_a.to_usize(), res_a)?; - ctx.memory.set_vector(index_res_b.to_usize(), res_b)?; - - trace.base[POSEIDON_16_COL_FLAG].push(F::ONE); - trace.base[POSEIDON_16_COL_INDEX_A].push(arg_a); - trace.base[POSEIDON_16_COL_INDEX_B].push(arg_b); - trace.base[POSEIDON_16_COL_INDEX_RES].push(index_res_a); - trace.base[POSEIDON_16_COL_INDEX_RES_BIS].push(index_res_b); - trace.base[POSEIDON_16_COL_COMPRESSION].push(F::from_bool(is_compression)); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_16_COL_INDEX_INPUT_START + i].push(*value); - } - for (i, value) in res_a.iter().enumerate() { - trace.base[POSEIDON_16_COL_INDEX_OUTPUT_START + i].push(*value); - } - for (i, value) in res_b.iter().enumerate() { - trace.base[POSEIDON_16_COL_INDEX_OUTPUT_START + 8 + i].push(*value); - } - Ok(()) - } -} - -impl Air for Poseidon16Precompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 6 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 2 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 4 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_16_COL_FLAG].clone(); - let index_res = up[POSEIDON_16_COL_INDEX_RES].clone(); - let index_res_bis = up[POSEIDON_16_COL_INDEX_RES_BIS].clone(); - let compression = up[POSEIDON_16_COL_COMPRESSION].clone(); - let index_a = up[POSEIDON_16_COL_INDEX_A].clone(); - let index_b = up[POSEIDON_16_COL_INDEX_B].clone(); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - index_a.clone(), - index_b.clone(), - index_res.clone(), - compression.clone(), - )); - - builder.assert_bool(flag.clone()); - builder.assert_bool(compression.clone()); - builder.assert_eq(index_res_bis, (index_res + 
AB::F::ONE) * (AB::F::ONE - compression)); - } -} +pub const POSEIDON_16_DEFAULT_COMPRESSION: bool = true; diff --git a/crates/lean_vm/src/tables/poseidon_24/core.rs b/crates/lean_vm/src/tables/poseidon_24/core.rs new file mode 100644 index 00000000..8263bc8b --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_24/core.rs @@ -0,0 +1,138 @@ +use std::collections::BTreeMap; + +use crate::*; +use multilinear_toolkit::prelude::*; +use p3_air::Air; +use utils::get_poseidon_24_of_zero; + +const POSEIDON_24_CORE_COL_FLAG: ColIndex = 0; +pub const POSEIDON_24_CORE_COL_INPUT_START: ColIndex = 1; +// virtual via GKR +pub const POSEIDON_24_CORE_COL_OUTPUT_START: ColIndex = POSEIDON_24_CORE_COL_INPUT_START + 24; +// intermediate columns ("commited cubes") are not handled here + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon24CorePrecompile; + +impl TableT for Poseidon24CorePrecompile { + fn name(&self) -> &'static str { + "poseidon24_core" + } + + fn identifier(&self) -> Table { + Table::poseidon24_core() + } + + fn n_columns_f_total(&self) -> usize { + 1 + 24 + 8 + } + + fn commited_columns_f(&self) -> Vec { + [ + vec![POSEIDON_24_CORE_COL_FLAG], + (POSEIDON_24_CORE_COL_INPUT_START..POSEIDON_24_CORE_COL_INPUT_START + 24).collect::>(), + ] + .concat() + // (committed cubes are handled elsewhere) + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![] + } + + fn buses(&self) -> Vec { + vec![Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(POSEIDON_24_CORE_COL_FLAG), + data: [ + (POSEIDON_24_CORE_COL_INPUT_START..POSEIDON_24_CORE_COL_INPUT_START + 24).collect::>(), + (POSEIDON_24_CORE_COL_OUTPUT_START..POSEIDON_24_CORE_COL_OUTPUT_START + 8).collect::>(), + ] + .concat(), + }] + } + + fn padding_row_f(&self) -> 
Vec { + [ + vec![F::ZERO], + vec![F::ZERO; 24], + get_poseidon_24_of_zero()[16..].to_vec(), + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute(&self, _: F, _: F, _: F, _: usize, _: &mut InstructionContext<'_>) -> Result<(), RunnerError> { + unreachable!() + } +} + +impl Air for Poseidon24CorePrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + 1 + 24 + 8 + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 2 + } + fn down_column_indexes_f(&self) -> Vec { + vec![] + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + 1 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[POSEIDON_24_CORE_COL_FLAG].clone(); + let mut data = [AB::F::ZERO; 24 + 8]; + data[0..24].clone_from_slice(&up[POSEIDON_24_CORE_COL_INPUT_START..][..24]); + data[24..32].clone_from_slice(&up[POSEIDON_24_CORE_COL_OUTPUT_START..][..8]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &data, + )); + } +} + +pub fn add_poseidon_24_core_row( + traces: &mut BTreeMap, + multiplicity: usize, + input: [F; 24], + res: [F; 8], +) { + let trace = traces.get_mut(&Table::poseidon24_core()).unwrap(); + + trace.base[POSEIDON_24_CORE_COL_FLAG].push(F::from_usize(multiplicity)); + for (i, value) in input.iter().enumerate() { + trace.base[POSEIDON_24_CORE_COL_INPUT_START + i].push(*value); + } + for (i, value) in res.iter().enumerate() { + trace.base[POSEIDON_24_CORE_COL_OUTPUT_START + i].push(*value); + } +} diff --git a/crates/lean_vm/src/tables/poseidon_24/from_memory.rs b/crates/lean_vm/src/tables/poseidon_24/from_memory.rs new file mode 100644 index 00000000..d8ff548a --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_24/from_memory.rs @@ -0,0 +1,220 @@ +use crate::*; +use 
multilinear_toolkit::prelude::*; +use p3_air::Air; +use std::array; +use utils::{ToUsize, get_poseidon_24_of_zero, poseidon24_permute}; + +const POSEIDON_24_MEM_COL_FLAG: ColIndex = 0; +const POSEIDON_24_MEM_COL_INDEX_A: ColIndex = 1; +const POSEIDON_24_MEM_COL_INDEX_A_BIS: ColIndex = 2; +const POSEIDON_24_MEM_COL_INDEX_B: ColIndex = 3; +const POSEIDON_24_MEM_COL_INDEX_RES: ColIndex = 4; +const POSEIDON_24_MEM_COL_INPUT_START: ColIndex = 5; +const POSEIDON_24_MEM_COL_OUTPUT_START: ColIndex = POSEIDON_24_MEM_COL_INPUT_START + 24; +// intermediate columns ("commited cubes") are not handled here + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Poseidon24MemPrecompile; + +impl TableT for Poseidon24MemPrecompile { + fn name(&self) -> &'static str { + "poseidon24" + } + + fn identifier(&self) -> Table { + Table::poseidon24_mem() + } + + fn n_columns_f_total(&self) -> usize { + 5 + 24 + 8 + } + + fn commited_columns_f(&self) -> Vec { + vec![ + POSEIDON_24_MEM_COL_FLAG, + POSEIDON_24_MEM_COL_INDEX_A, + POSEIDON_24_MEM_COL_INDEX_A_BIS, + POSEIDON_24_MEM_COL_INDEX_B, + POSEIDON_24_MEM_COL_INDEX_RES, + ] // indexes only here (committed cubes are handled elsewhere) + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![ + VectorLookupIntoMemory { + index: POSEIDON_24_MEM_COL_INDEX_A, + values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + i), + }, + VectorLookupIntoMemory { + index: POSEIDON_24_MEM_COL_INDEX_A_BIS, + values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + VECTOR_LEN + i), + }, + VectorLookupIntoMemory { + index: POSEIDON_24_MEM_COL_INDEX_B, + values: array::from_fn(|i| POSEIDON_24_MEM_COL_INPUT_START + 2 * VECTOR_LEN + i), + }, + VectorLookupIntoMemory { + index: POSEIDON_24_MEM_COL_INDEX_RES, + values: array::from_fn(|i| POSEIDON_24_MEM_COL_OUTPUT_START + i), 
+ }, + ] + } + + fn buses(&self) -> Vec { + vec![ + Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(POSEIDON_24_MEM_COL_FLAG), + data: vec![ + POSEIDON_24_MEM_COL_INDEX_A, + POSEIDON_24_MEM_COL_INDEX_B, + POSEIDON_24_MEM_COL_INDEX_RES, + ], + }, + Bus { + table: BusTable::Constant(Table::poseidon24_core()), + direction: BusDirection::Push, + selector: BusSelector::ConstantOne, + data: [ + (POSEIDON_24_MEM_COL_INPUT_START..POSEIDON_24_MEM_COL_INPUT_START + 24).collect::>(), + (POSEIDON_24_MEM_COL_OUTPUT_START..POSEIDON_24_MEM_COL_OUTPUT_START + 8).collect::>(), + ] + .concat(), + }, + ] + } + + fn padding_row_f(&self) -> Vec { + [ + vec![ + F::ZERO, + F::from_usize(ZERO_VEC_PTR), + F::from_usize(ZERO_VEC_PTR + 1), + F::from_usize(ZERO_VEC_PTR), + F::from_usize(POSEIDON_24_NULL_HASH_PTR), + ], + vec![F::ZERO; 24], + get_poseidon_24_of_zero()[16..].to_vec(), + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute( + &self, + arg_a: F, + arg_b: F, + res: F, + aux: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + assert_eq!(aux, 0); // no aux for poseidon24 + let trace = ctx.traces.get_mut(&self.identifier()).unwrap(); + + let arg0 = ctx.memory.get_vector(arg_a.to_usize())?; + let arg1 = ctx.memory.get_vector(1 + arg_a.to_usize())?; + let arg2 = ctx.memory.get_vector(arg_b.to_usize())?; + + let mut input = [F::ZERO; VECTOR_LEN * 3]; + input[..VECTOR_LEN].copy_from_slice(&arg0); + input[VECTOR_LEN..2 * VECTOR_LEN].copy_from_slice(&arg1); + input[2 * VECTOR_LEN..].copy_from_slice(&arg2); + + let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { + Some(precomputed) if precomputed.0 == input => { + *ctx.n_poseidon24_precomputed_used += 1; + precomputed.1 + } + _ => { + let output = poseidon24_permute(input); + output[2 * VECTOR_LEN..].try_into().unwrap() + } + }; + + 
ctx.memory.set_vector(res.to_usize(), output)?; + + trace.base[POSEIDON_24_MEM_COL_FLAG].push(F::ONE); + trace.base[POSEIDON_24_MEM_COL_INDEX_A].push(arg_a); + trace.base[POSEIDON_24_MEM_COL_INDEX_A_BIS].push(arg_a + F::ONE); + trace.base[POSEIDON_24_MEM_COL_INDEX_B].push(arg_b); + trace.base[POSEIDON_24_MEM_COL_INDEX_RES].push(res); + for (i, value) in input.iter().enumerate() { + trace.base[POSEIDON_24_MEM_COL_INPUT_START + i].push(*value); + } + for (i, value) in output.iter().enumerate() { + trace.base[POSEIDON_24_MEM_COL_OUTPUT_START + i].push(*value); + } + + add_poseidon_24_core_row(ctx.traces, 1, input, output); + + Ok(()) + } +} + +impl Air for Poseidon24MemPrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + 5 + 24 + 8 + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 2 + } + fn down_column_indexes_f(&self) -> Vec { + vec![] + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn n_constraints(&self) -> usize { + 4 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[POSEIDON_24_MEM_COL_FLAG].clone(); + let index_res = up[POSEIDON_24_MEM_COL_INDEX_RES].clone(); + let index_input_a = up[POSEIDON_24_MEM_COL_INDEX_A].clone(); + let index_input_a_bis = up[POSEIDON_24_MEM_COL_INDEX_A_BIS].clone(); + let index_b = up[POSEIDON_24_MEM_COL_INDEX_B].clone(); + + let mut core_bus_data = [AB::F::ZERO; 24 + 8]; + core_bus_data[0..24].clone_from_slice(&up[POSEIDON_24_MEM_COL_INPUT_START..][..24]); + core_bus_data[24..32].clone_from_slice(&up[POSEIDON_24_MEM_COL_OUTPUT_START..][..8]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &[index_input_a.clone(), index_b, index_res, AB::F::ZERO], + )); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(Table::poseidon24_core().index()), + AB::F::ONE, + 
&core_bus_data, + )); + + builder.assert_bool(flag); + builder.assert_eq(index_input_a_bis, index_input_a + AB::F::ONE); + } +} diff --git a/crates/lean_vm/src/tables/poseidon_24/mod.rs b/crates/lean_vm/src/tables/poseidon_24/mod.rs index 3573d6d7..96e662e4 100644 --- a/crates/lean_vm/src/tables/poseidon_24/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_24/mod.rs @@ -1,197 +1,5 @@ -use crate::*; -use multilinear_toolkit::prelude::*; -use p3_air::Air; -use std::array; -use utils::{ToUsize, get_poseidon_24_of_zero, poseidon24_permute}; +mod core; +pub use core::*; -const POSEIDON_24_COL_FLAG: ColIndex = 0; -const POSEIDON_24_COL_INDEX_A: ColIndex = 1; -const POSEIDON_24_COL_INDEX_A_BIS: ColIndex = 2; -const POSEIDON_24_COL_INDEX_B: ColIndex = 3; -const POSEIDON_24_COL_INDEX_RES: ColIndex = 4; -pub const POSEIDON_24_COL_INDEX_INPUT_START: ColIndex = 5; -const POSEIDON_24_COL_INDEX_OUTPUT_START: ColIndex = POSEIDON_24_COL_INDEX_INPUT_START + 24; -// intermediate columns ("commited cubes") are not handled here - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Poseidon24Precompile; - -impl TableT for Poseidon24Precompile { - fn name(&self) -> &'static str { - "poseidon24" - } - - fn identifier(&self) -> Table { - Table::poseidon24() - } - - fn n_columns_f_total(&self) -> usize { - 5 + 24 + 8 - } - - fn commited_columns_f(&self) -> Vec { - vec![ - POSEIDON_24_COL_FLAG, - POSEIDON_24_COL_INDEX_A, - POSEIDON_24_COL_INDEX_A_BIS, - POSEIDON_24_COL_INDEX_B, - POSEIDON_24_COL_INDEX_RES, - ] // indexes only here (committed cubes are handled elsewhere) - } - - fn commited_columns_ef(&self) -> Vec { - vec![] - } - - fn normal_lookups_f(&self) -> Vec { - vec![] - } - - fn normal_lookups_ef(&self) -> Vec { - vec![] - } - - fn vector_lookups(&self) -> Vec { - vec![ - VectorLookupIntoMemory { - index: POSEIDON_24_COL_INDEX_A, - values: array::from_fn(|i| POSEIDON_24_COL_INDEX_INPUT_START + i), - }, - VectorLookupIntoMemory { - index: 
POSEIDON_24_COL_INDEX_A_BIS, - values: array::from_fn(|i| POSEIDON_24_COL_INDEX_INPUT_START + VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_24_COL_INDEX_B, - values: array::from_fn(|i| POSEIDON_24_COL_INDEX_INPUT_START + 2 * VECTOR_LEN + i), - }, - VectorLookupIntoMemory { - index: POSEIDON_24_COL_INDEX_RES, - values: array::from_fn(|i| POSEIDON_24_COL_INDEX_OUTPUT_START + i), - }, - ] - } - - fn buses(&self) -> Vec { - vec![Bus { - table: BusTable::Constant(self.identifier()), - direction: BusDirection::Pull, - selector: POSEIDON_24_COL_FLAG, - data: vec![ - POSEIDON_24_COL_INDEX_A, - POSEIDON_24_COL_INDEX_B, - POSEIDON_24_COL_INDEX_RES, - ], - }] - } - - fn padding_row_f(&self) -> Vec { - [ - vec![ - F::ZERO, - F::from_usize(ZERO_VEC_PTR), - F::from_usize(ZERO_VEC_PTR + 1), - F::from_usize(ZERO_VEC_PTR), - F::from_usize(POSEIDON_24_NULL_HASH_PTR), - ], - vec![F::ZERO; 24], - get_poseidon_24_of_zero()[16..].to_vec(), - ] - .concat() - } - - fn padding_row_ef(&self) -> Vec { - vec![] - } - - #[inline(always)] - fn execute( - &self, - arg_a: F, - arg_b: F, - res: F, - aux: usize, - ctx: &mut InstructionContext<'_>, - ) -> Result<(), RunnerError> { - assert_eq!(aux, 0); // no aux for poseidon24 - let trace = &mut ctx.traces[self.identifier().index()]; - - let arg0 = ctx.memory.get_vector(arg_a.to_usize())?; - let arg1 = ctx.memory.get_vector(1 + arg_a.to_usize())?; - let arg2 = ctx.memory.get_vector(arg_b.to_usize())?; - - let mut input = [F::ZERO; VECTOR_LEN * 3]; - input[..VECTOR_LEN].copy_from_slice(&arg0); - input[VECTOR_LEN..2 * VECTOR_LEN].copy_from_slice(&arg1); - input[2 * VECTOR_LEN..].copy_from_slice(&arg2); - - let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { - Some(precomputed) if precomputed.0 == input => { - *ctx.n_poseidon24_precomputed_used += 1; - precomputed.1 - } - _ => { - let output = poseidon24_permute(input); - output[2 * VECTOR_LEN..].try_into().unwrap() - } - }; - - 
ctx.memory.set_vector(res.to_usize(), output)?; - - trace.base[POSEIDON_24_COL_FLAG].push(F::ONE); - trace.base[POSEIDON_24_COL_INDEX_A].push(arg_a); - trace.base[POSEIDON_24_COL_INDEX_A_BIS].push(arg_a + F::ONE); - trace.base[POSEIDON_24_COL_INDEX_B].push(arg_b); - trace.base[POSEIDON_24_COL_INDEX_RES].push(res); - for (i, value) in input.iter().enumerate() { - trace.base[POSEIDON_24_COL_INDEX_INPUT_START + i].push(*value); - } - for (i, value) in output.iter().enumerate() { - trace.base[POSEIDON_24_COL_INDEX_OUTPUT_START + i].push(*value); - } - - Ok(()) - } -} - -impl Air for Poseidon24Precompile { - type ExtraData = ExtraDataForBuses; - fn n_columns_f_air(&self) -> usize { - 5 - } - fn n_columns_ef_air(&self) -> usize { - 0 - } - fn degree(&self) -> usize { - 2 - } - fn down_column_indexes_f(&self) -> Vec { - vec![] - } - fn down_column_indexes_ef(&self) -> Vec { - vec![] - } - fn n_constraints(&self) -> usize { - 3 - } - fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { - let up = builder.up_f(); - let flag = up[POSEIDON_24_COL_FLAG].clone(); - let index_res = up[POSEIDON_24_COL_INDEX_RES].clone(); - let index_input_a = up[POSEIDON_24_COL_INDEX_A].clone(); - let index_input_a_bis = up[POSEIDON_24_COL_INDEX_A_BIS].clone(); - let index_b = up[POSEIDON_24_COL_INDEX_B].clone(); - - builder.eval_virtual_column(eval_virtual_bus_column::( - extra_data, - AB::F::from_usize(self.identifier().index()), - flag.clone(), - index_input_a.clone(), - index_b, - index_res, - AB::F::ZERO, - )); - builder.assert_bool(flag); - builder.assert_eq(index_input_a_bis, index_input_a + AB::F::ONE); - } -} +mod from_memory; +pub use from_memory::*; diff --git a/crates/lean_vm/src/tables/slice_hash/mod.rs b/crates/lean_vm/src/tables/slice_hash/mod.rs new file mode 100644 index 00000000..16156c68 --- /dev/null +++ b/crates/lean_vm/src/tables/slice_hash/mod.rs @@ -0,0 +1,292 @@ +use std::array; + +use crate::*; +use multilinear_toolkit::prelude::*; +use p3_air::Air; +use 
utils::{ToUsize, get_poseidon_24_of_zero, poseidon24_permute}; + + // Does not support len = 1 (minimum len is 2) + + // "committed" columns +const COL_FLAG: ColIndex = 0; +const COL_INDEX_SEED: ColIndex = 1; // vectorized pointer +const COL_INDEX_START: ColIndex = 2; // vectorized pointer +const COL_INDEX_START_BIS: ColIndex = 3; // = COL_INDEX_START + 1 +const COL_INDEX_RES: ColIndex = 4; // vectorized pointer +const COL_LEN: ColIndex = 5; + +const COL_LOOKUP_MEM_INDEX_SEED_OR_RES: ColIndex = 6; // = COL_INDEX_SEED if flag = 1, otherwise = COL_INDEX_RES +const INITIAL_COLS_DATA_RIGHT: ColIndex = 7; +const INITIAL_COLS_DATA_RES: ColIndex = INITIAL_COLS_DATA_RIGHT + VECTOR_LEN; + +// "virtual" columns (vectorized lookups into memory) +const COL_LOOKUP_MEM_VALUES_SEED_OR_RES: ColIndex = INITIAL_COLS_DATA_RES + VECTOR_LEN; // 8 columns +const COL_LOOKUP_MEM_VALUES_LEFT: ColIndex = COL_LOOKUP_MEM_VALUES_SEED_OR_RES + VECTOR_LEN; // 16 columns + +const TOTAL_N_COLS: usize = COL_LOOKUP_MEM_VALUES_LEFT + 2 * VECTOR_LEN; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SliceHashPrecompile; + +impl TableT for SliceHashPrecompile { + fn name(&self) -> &'static str { + "slice_hash" + } + + fn identifier(&self) -> Table { + Table::slice_hash() + } + + fn commited_columns_f(&self) -> Vec { + (0..COL_LOOKUP_MEM_VALUES_SEED_OR_RES).collect() + } + + fn commited_columns_ef(&self) -> Vec { + vec![] + } + + fn normal_lookups_f(&self) -> Vec { + vec![] + } + + fn normal_lookups_ef(&self) -> Vec { + vec![] + } + + fn vector_lookups(&self) -> Vec { + vec![ + VectorLookupIntoMemory { + index: COL_LOOKUP_MEM_INDEX_SEED_OR_RES, + values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i), + }, + VectorLookupIntoMemory { + index: COL_INDEX_START, + values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_LEFT + i), + }, + VectorLookupIntoMemory { + index: COL_INDEX_START_BIS, + values: array::from_fn(|i| COL_LOOKUP_MEM_VALUES_LEFT + VECTOR_LEN + i), + }, + 
] + } + + fn buses(&self) -> Vec { + vec![ + Bus { + table: BusTable::Constant(self.identifier()), + direction: BusDirection::Pull, + selector: BusSelector::Column(COL_FLAG), + data: vec![COL_INDEX_SEED, COL_INDEX_START, COL_INDEX_RES, COL_LEN], + }, + Bus { + table: BusTable::Constant(Table::poseidon24_core()), + direction: BusDirection::Push, + selector: BusSelector::ConstantOne, + data: [ + (COL_LOOKUP_MEM_VALUES_LEFT..COL_LOOKUP_MEM_VALUES_LEFT + 16).collect::>(), + (INITIAL_COLS_DATA_RIGHT..INITIAL_COLS_DATA_RIGHT + 8).collect::>(), + (INITIAL_COLS_DATA_RES..INITIAL_COLS_DATA_RES + 8).collect::>(), + ] + .concat(), + }, + ] + } + + fn padding_row_f(&self) -> Vec { + let default_hash = get_poseidon_24_of_zero()[2 * VECTOR_LEN..].to_vec(); + [ + vec![ + F::ONE, // flag + F::from_usize(ZERO_VEC_PTR), // index seed + F::from_usize(ZERO_VEC_PTR), // index_start + F::from_usize(ZERO_VEC_PTR + 1), // index_start_bis + F::from_usize(ZERO_VEC_PTR), // index_res + F::ONE, // len + F::from_usize(ZERO_VEC_PTR), // COL_LOOKUP_MEM_INDEX_SEED_OR_RES + ], + vec![F::ZERO; VECTOR_LEN], // INITIAL_COLS_DATA_RIGHT + default_hash, // INITIAL_COLS_DATA_RES + vec![F::ZERO; VECTOR_LEN], // COL_LOOKUP_MEM_VALUES_SEED_OR_RES + vec![F::ZERO; VECTOR_LEN * 2], // COL_LOOKUP_MEM_VALUES_LEFT + ] + .concat() + } + + fn padding_row_ef(&self) -> Vec { + vec![] + } + + #[inline(always)] + fn execute( + &self, + index_seed: F, + index_start: F, + index_res: F, + len: usize, + ctx: &mut InstructionContext<'_>, + ) -> Result<(), RunnerError> { + assert!(len >= 2); + + let seed = ctx.memory.get_vector(index_seed.to_usize())?; + let mut cap = seed; + for i in 0..len { + let index = index_start.to_usize() + i * 2; + + let mut input = [F::ZERO; VECTOR_LEN * 3]; + input[..VECTOR_LEN].copy_from_slice(&ctx.memory.get_vector(index)?); + input[VECTOR_LEN..VECTOR_LEN * 2].copy_from_slice(&ctx.memory.get_vector(index + 1)?); + input[VECTOR_LEN * 2..].copy_from_slice(&cap); + // let output: [F; VECTOR_LEN] = 
poseidon24_permute(input)[VECTOR_LEN * 2..].try_into().unwrap(); + + let output = match ctx.poseidon24_precomputed.get(*ctx.n_poseidon24_precomputed_used) { + Some(precomputed) if precomputed.0 == input => { + *ctx.n_poseidon24_precomputed_used += 1; + precomputed.1 + } + _ => poseidon24_permute(input)[VECTOR_LEN * 2..].try_into().unwrap(), + }; + + let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; + + for j in 0..VECTOR_LEN * 2 { + trace[COL_LOOKUP_MEM_VALUES_LEFT + j].push(input[j]); + } + for j in 0..VECTOR_LEN { + trace[INITIAL_COLS_DATA_RIGHT + j].push(cap[j]); + } + for j in 0..VECTOR_LEN { + trace[INITIAL_COLS_DATA_RES + j].push(output[j]); + } + + add_poseidon_24_core_row(ctx.traces, 1, input, output); + + cap = output; + } + let trace = &mut ctx.traces.get_mut(&self.identifier()).unwrap().base; + + let final_res = cap; + ctx.memory.set_vector(index_res.to_usize(), final_res)?; + + trace[COL_FLAG].extend([vec![F::ONE], vec![F::ZERO; len - 1]].concat()); + trace[COL_INDEX_SEED].extend(vec![index_seed; len]); + trace[COL_INDEX_START].extend((0..len).map(|i| index_start + F::from_usize(i * 2))); + trace[COL_INDEX_START_BIS].extend((0..len).map(|i| index_start + F::from_usize(i * 2 + 1))); + trace[COL_INDEX_RES].extend(vec![index_res; len]); + trace[COL_LEN].extend((1..=len).rev().map(F::from_usize)); + trace[COL_LOOKUP_MEM_INDEX_SEED_OR_RES].extend([vec![index_seed], vec![index_res; len - 1]].concat()); + for i in 0..VECTOR_LEN { + trace[COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i].extend([vec![seed[i]], vec![final_res[i]; len - 1]].concat()); + } + + Ok(()) + } +} + +impl Air for SliceHashPrecompile { + type ExtraData = ExtraDataForBuses; + fn n_columns_f_air(&self) -> usize { + TOTAL_N_COLS + } + fn n_columns_ef_air(&self) -> usize { + 0 + } + fn degree(&self) -> usize { + 3 + } + fn down_column_indexes_f(&self) -> Vec { + (0..INITIAL_COLS_DATA_RES).collect() + } + fn down_column_indexes_ef(&self) -> Vec { + vec![] + } + fn 
n_constraints(&self) -> usize { + 8 + 5 * VECTOR_LEN + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let up = builder.up_f(); + let flag = up[COL_FLAG].clone(); + let index_seed = up[COL_INDEX_SEED].clone(); + let index_start = up[COL_INDEX_START].clone(); + let index_start_bis = up[COL_INDEX_START_BIS].clone(); + let index_res = up[COL_INDEX_RES].clone(); + let len = up[COL_LEN].clone(); + let lookup_index_seed_or_res = up[COL_LOOKUP_MEM_INDEX_SEED_OR_RES].clone(); + let data_right: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RIGHT + i].clone()); + let data_res: [_; VECTOR_LEN] = array::from_fn(|i| up[INITIAL_COLS_DATA_RES + i].clone()); + let data_seed_or_res_lookup_values: [_; VECTOR_LEN] = + array::from_fn(|i| up[COL_LOOKUP_MEM_VALUES_SEED_OR_RES + i].clone()); + + let down = builder.down_f(); + let flag_down = down[0].clone(); + let index_seed_down = down[1].clone(); + let index_start_down = down[2].clone(); + let _index_start_bis_down = down[3].clone(); + let index_res_down = down[4].clone(); + let len_down = down[5].clone(); + let _lookup_index_seed_or_res_down = down[6].clone(); + let data_right_down: [_; VECTOR_LEN] = array::from_fn(|i| down[7 + i].clone()); + + let mut core_bus_data = [AB::F::ZERO; 24 + 8]; + core_bus_data[0..16].clone_from_slice(&up[COL_LOOKUP_MEM_VALUES_LEFT..][..16]); + core_bus_data[16..24].clone_from_slice(&up[INITIAL_COLS_DATA_RIGHT..][..8]); + core_bus_data[24..32].clone_from_slice(&up[INITIAL_COLS_DATA_RES..][..8]); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(self.identifier().index()), + flag.clone(), + &[index_seed.clone(), index_start.clone(), index_res.clone(), len.clone()], + )); + + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + AB::F::from_usize(Table::poseidon24_core().index()), + AB::F::ONE, + &core_bus_data, + )); + + // TODO double check constraints + + builder.assert_bool(flag.clone()); + + let not_flag = 
AB::F::ONE - flag.clone(); + let not_flag_down = AB::F::ONE - flag_down.clone(); + + builder.assert_eq( + lookup_index_seed_or_res.clone(), + flag.clone() * index_seed.clone() + not_flag.clone() * index_res.clone(), + ); + + // index_start_bis = index_start + 1 + builder.assert_eq(index_start_bis.clone(), index_start.clone() + AB::F::ONE); + + // Parameters should not change as long as the flag has not been switched back to 1: + builder.assert_zero(not_flag_down.clone() * (index_seed_down.clone() - index_seed.clone())); + builder.assert_zero(not_flag_down.clone() * (index_res_down.clone() - index_res.clone())); + + builder.assert_zero(not_flag_down.clone() * (index_start_down.clone() - (index_start.clone() + AB::F::TWO))); + + // decrease len by 1 each step + builder.assert_zero(not_flag_down.clone() * (len_down.clone() + AB::F::ONE - len.clone())); + + // start: ingest the seed + for i in 0..VECTOR_LEN { + builder.assert_zero(flag.clone() * (data_right[i].clone() - data_seed_or_res_lookup_values[i].clone())); + } + + // transition + for i in 0..VECTOR_LEN { + builder.assert_zero(not_flag_down.clone() * (data_res[i].clone() - data_right_down[i].clone())); + } + + // end + builder.assert_zero(flag_down.clone() * (len.clone() - AB::F::ONE)); // at last step, len should be 1 + for i in 0..VECTOR_LEN { + builder.assert_zero( + not_flag.clone() + * flag_down.clone() + * (data_res[i].clone() - data_seed_or_res_lookup_values[i].clone()), + ); + } + } +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 99252a03..4de9a519 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -3,13 +3,18 @@ use p3_air::Air; use crate::*; -pub const N_TABLES: usize = 5; +pub const N_TABLES: usize = 10; pub const ALL_TABLES: [Table; N_TABLES] = [ Table::execution(), Table::dot_product_be(), Table::dot_product_ee(), - Table::poseidon16(), - Table::poseidon24(), + Table::poseidon16_core(), + 
Table::poseidon16_mem(), + Table::poseidon24_core(), + Table::poseidon24_mem(), + Table::merkle(), + Table::slice_hash(), + Table::eq_poly_base_ext(), ]; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -18,8 +23,13 @@ pub enum Table { Execution(ExecutionTable), DotProductBE(DotProductPrecompile), DotProductEE(DotProductPrecompile), - Poseidon16(Poseidon16Precompile), - Poseidon24(Poseidon24Precompile), + Poseidon16Core(Poseidon16CorePrecompile), + Poseidon16Mem(Poseidon16MemPrecompile), + Poseidon24Core(Poseidon24CorePrecompile), + Poseidon24Mem(Poseidon24MemPrecompile), + Merkle(MerklePrecompile), + SliceHash(SliceHashPrecompile), + EqPolyBaseExt(EqPolyBaseExtPrecompile), } #[macro_export] @@ -29,9 +39,14 @@ macro_rules! delegate_to_inner { match $self { Self::DotProductBE(p) => p.$method($($($arg),*)?), Self::DotProductEE(p) => p.$method($($($arg),*)?), - Self::Poseidon16(p) => p.$method($($($arg),*)?), - Self::Poseidon24(p) => p.$method($($($arg),*)?), + Self::Poseidon16Core(p) => p.$method($($($arg),*)?), + Self::Poseidon16Mem(p) => p.$method($($($arg),*)?), + Self::Poseidon24Core(p) => p.$method($($($arg),*)?), + Self::Poseidon24Mem(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), + Self::Merkle(p) => p.$method($($($arg),*)?), + Self::SliceHash(p) => p.$method($($($arg),*)?), + Self::EqPolyBaseExt(p) => p.$method($($($arg),*)?), } }; // New pattern for applying a macro to the inner value @@ -39,9 +54,14 @@ macro_rules! 
delegate_to_inner { match $self { Table::DotProductBE(p) => $macro_name!(p), Table::DotProductEE(p) => $macro_name!(p), - Table::Poseidon16(p) => $macro_name!(p), - Table::Poseidon24(p) => $macro_name!(p), + Table::Poseidon16Core(p) => $macro_name!(p), + Table::Poseidon16Mem(p) => $macro_name!(p), + Table::Poseidon24Core(p) => $macro_name!(p), + Table::Poseidon24Mem(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), + Table::Merkle(p) => $macro_name!(p), + Table::SliceHash(p) => $macro_name!(p), + Table::EqPolyBaseExt(p) => $macro_name!(p), } }; } @@ -56,11 +76,26 @@ impl Table { pub const fn dot_product_ee() -> Self { Self::DotProductEE(DotProductPrecompile::) } - pub const fn poseidon16() -> Self { - Self::Poseidon16(Poseidon16Precompile) + pub const fn poseidon16_core() -> Self { + Self::Poseidon16Core(Poseidon16CorePrecompile) } - pub const fn poseidon24() -> Self { - Self::Poseidon24(Poseidon24Precompile) + pub const fn poseidon16_mem() -> Self { + Self::Poseidon16Mem(Poseidon16MemPrecompile) + } + pub const fn poseidon24_core() -> Self { + Self::Poseidon24Core(Poseidon24CorePrecompile) + } + pub const fn poseidon24_mem() -> Self { + Self::Poseidon24Mem(Poseidon24MemPrecompile) + } + pub const fn merkle() -> Self { + Self::Merkle(MerklePrecompile) + } + pub const fn slice_hash() -> Self { + Self::SliceHash(SliceHashPrecompile) + } + pub const fn eq_poly_base_ext() -> Self { + Self::EqPolyBaseExt(EqPolyBaseExtPrecompile) } pub fn embed(&self) -> PF { PF::from_usize(self.index()) @@ -68,6 +103,12 @@ impl Table { pub const fn index(&self) -> usize { unsafe { *(self as *const Self as *const usize) } } + pub fn is_poseidon(&self) -> bool { + matches!( + self, + Table::Poseidon16Core(_) | Table::Poseidon16Mem(_) | Table::Poseidon24Core(_) | Table::Poseidon24Mem(_) + ) + } } impl TableT for Table { @@ -144,6 +185,14 @@ impl Air for Table { } } +pub fn max_bus_width() -> usize { + 1 + ALL_TABLES + .iter() + .map(|table| table.buses().iter().map(|bus| 
bus.data.len()).max().unwrap()) + .max() + .unwrap() +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 6a5ea5dd..f9068e49 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -1,7 +1,7 @@ use crate::{EF, F, InstructionContext, RunnerError, Table, VECTOR_LEN}; use multilinear_toolkit::prelude::*; use p3_air::Air; -use std::{any::TypeId, array, mem::transmute_copy}; +use std::{any::TypeId, array, mem::transmute}; use utils::ToUsize; use sub_protocols::{ @@ -34,7 +34,7 @@ pub struct VectorLookupIntoMemory { pub values: [ColIndex; 8], } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BusDirection { Pull, Push, @@ -49,15 +49,21 @@ impl BusDirection { } } +#[derive(Debug)] +pub enum BusSelector { + Column(ColIndex), + ConstantOne, +} + #[derive(Debug)] pub struct Bus { pub direction: BusDirection, pub table: BusTable, - pub selector: ColIndex, + pub selector: BusSelector, pub data: Vec, // For now, we only supports F (base field) columns as bus data } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum BusTable { Constant(Table), Variable(ColIndex), @@ -103,10 +109,12 @@ impl TableTrace { } #[derive(Debug)] -pub struct ExtraDataForBuses { +pub struct ExtraDataForBuses>> { // GKR quotient challenges - pub fingerprint_challenge_powers: [EF; 5], + pub fingerprint_challenge_powers: Vec, + pub fingerprint_challenge_powers_packed: Vec>, pub bus_beta: EF, + pub bus_beta_packed: EFPacking, pub alpha_powers: Vec, } @@ -123,15 +131,20 @@ impl AlphaPowers for ExtraDataForBuses { } impl>> ExtraDataForBuses { - pub fn transmute_bus_data(&self) -> ([NewEF; 5], NewEF) { + pub fn transmute_bus_data(&self) -> (&Vec, &NewEF) { if TypeId::of::() == TypeId::of::() { - unsafe { transmute_copy::<_, _>(&(self.fingerprint_challenge_powers, self.bus_beta)) } + unsafe { + transmute::<(&Vec, 
&EF), (&Vec, &NewEF)>(( + &self.fingerprint_challenge_powers, + &self.bus_beta, + )) + } } else { assert_eq!(TypeId::of::(), TypeId::of::>()); unsafe { - transmute_copy::<_, _>(&( - self.fingerprint_challenge_powers.map(|c| EFPacking::::from(c)), - EFPacking::::from(self.bus_beta), + transmute::<(&Vec>, &EFPacking), (&Vec, &NewEF)>(( + &self.fingerprint_challenge_powers_packed, + &self.bus_beta_packed, )) } } @@ -426,7 +439,10 @@ impl Bus { fingerprint_challenge: EF, ) -> EF { let padding_row_f = table.padding_row_f(); - let default_selector = padding_row_f[self.selector]; + let default_selector = match &self.selector { + BusSelector::ConstantOne => F::ONE, + BusSelector::Column(col) => padding_row_f[*col], + }; let default_table = match &self.table { BusTable::Constant(t) => F::from_usize(t.index()), BusTable::Variable(col) => padding_row_f[*col], diff --git a/crates/lean_vm/src/tables/utils.rs b/crates/lean_vm/src/tables/utils.rs index 30a5f06b..222decf7 100644 --- a/crates/lean_vm/src/tables/utils.rs +++ b/crates/lean_vm/src/tables/utils.rs @@ -7,17 +7,17 @@ pub(crate) fn eval_virtual_bus_column> extra_data: &ExtraDataForBuses, precompile_index: AB::F, flag: AB::F, - arg_a: AB::F, - arg_b: AB::F, - arg_c: AB::F, - aux: AB::F, + data: &[AB::F], ) -> AB::EF { let (fingerprint_challenge_powers, bus_beta) = extra_data.transmute_bus_data::(); - let data = fingerprint_challenge_powers[1].clone() * arg_a - + fingerprint_challenge_powers[2].clone() * arg_b - + fingerprint_challenge_powers[3].clone() * arg_c - + fingerprint_challenge_powers[4].clone() * aux; - - (data + precompile_index) * bus_beta + flag + assert!(data.len() < fingerprint_challenge_powers.len()); + (fingerprint_challenge_powers[1..] 
+ .iter() + .zip(data) + .map(|(c, d)| c.clone() * d.clone()) + .sum::() + + precompile_index) + * bus_beta.clone() + + flag } diff --git a/crates/lean_vm/tests/test_lean_vm.rs b/crates/lean_vm/tests/test_lean_vm.rs deleted file mode 100644 index 5b123787..00000000 --- a/crates/lean_vm/tests/test_lean_vm.rs +++ /dev/null @@ -1,205 +0,0 @@ -use lean_vm::*; -use multilinear_toolkit::prelude::*; -use p3_util::log2_ceil_usize; -use std::collections::BTreeMap; - -// Pointers for precompile inputs allocated in public memory. -const POSEIDON16_ARG_A_PTR: usize = 6; -const POSEIDON16_ARG_B_PTR: usize = 7; -const POSEIDON24_ARG_A_PTR: usize = 11; // uses ptr and ptr + 1 -const POSEIDON24_ARG_B_PTR: usize = 13; -const DOT_ARG0_PTR: usize = 180; // normal pointer, len 2 -const DOT_ARG1_PTR: usize = 200; // normal pointer, len 2 -const MLE_COEFF_PTR: usize = 32; // interpreted with shift << n_vars -const MLE_POINT_PTR: usize = 15; // interpreted with shift << log_point_size - -// Offsets used in hints for storing result pointers at fp + offset. -const POSEIDON16_RES_OFFSET: usize = 0; -const POSEIDON24_RES_OFFSET: usize = 1; -const DOT_RES_OFFSET: usize = 2; -const MLE_RES_OFFSET: usize = 3; - -const DOT_PRODUCT_LEN: usize = 2; -const MLE_N_VARS: usize = 1; - -// Ensure public input covers the highest index used (dot product arg1 slice). 
-const MAX_MEMORY_INDEX: usize = DOT_ARG1_PTR + DOT_PRODUCT_LEN * DIMENSION - 1; -const PUBLIC_INPUT_LEN: usize = MAX_MEMORY_INDEX - NONRESERVED_PROGRAM_INPUT_START + 1; - -const POSEIDON16_ARG_A_VALUES: [u64; VECTOR_LEN] = [1, 2, 3, 4, 5, 6, 7, 8]; -const POSEIDON16_ARG_B_VALUES: [u64; VECTOR_LEN] = [101, 102, 103, 104, 105, 106, 107, 108]; -const POSEIDON24_ARG_A_VALUES: [[u64; VECTOR_LEN]; 2] = [ - [201, 202, 203, 204, 205, 206, 207, 208], - [211, 212, 213, 214, 215, 216, 217, 218], -]; -const POSEIDON24_ARG_B_VALUES: [u64; VECTOR_LEN] = [221, 222, 223, 224, 225, 226, 227, 228]; -const DOT_ARG0_VALUES: [[u64; DIMENSION]; DOT_PRODUCT_LEN] = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]; -const DOT_ARG1_VALUES: [[u64; DIMENSION]; DOT_PRODUCT_LEN] = [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]; -const MLE_COEFF_VALUES: [u64; 1 << MLE_N_VARS] = [7, 9]; -const MLE_POINT_VALUES: [u64; DIMENSION] = [21, 22, 23, 24, 25]; - -fn f(value: u64) -> F { - F::from_isize(value as isize) -} - -fn set_public_input_cell(public_input: &mut [F], memory_index: usize, value: F) { - assert!(memory_index >= NONRESERVED_PROGRAM_INPUT_START); - let idx = memory_index - NONRESERVED_PROGRAM_INPUT_START; - assert!(idx < public_input.len()); - public_input[idx] = value; -} - -fn set_vector(public_input: &mut [F], ptr: usize, values: &[u64]) { - assert_eq!(values.len(), VECTOR_LEN); - for (i, &value) in values.iter().enumerate() { - set_public_input_cell(public_input, ptr * VECTOR_LEN + i, f(value)); - } -} - -fn set_multivector(public_input: &mut [F], ptr: usize, chunks: &[&[u64]]) { - for (chunk_index, chunk) in chunks.iter().enumerate() { - assert_eq!(chunk.len(), VECTOR_LEN); - for (i, &value) in chunk.iter().enumerate() { - set_public_input_cell(public_input, (ptr + chunk_index) * VECTOR_LEN + i, f(value)); - } - } -} - -fn set_ef_slice(public_input: &mut [F], ptr: usize, elements: &[[u64; DIMENSION]]) { - for (i, coeffs) in elements.iter().enumerate() { - for (j, &value) in 
coeffs.iter().enumerate() { - set_public_input_cell(public_input, ptr + i * DIMENSION + j, f(value)); - } - } -} - -fn set_base_slice(public_input: &mut [F], start_index: usize, values: &[u64]) { - for (i, &value) in values.iter().enumerate() { - set_public_input_cell(public_input, start_index + i, f(value)); - } -} - -fn build_test_case() -> (Bytecode, Vec) { - let mut public_input = vec![F::ZERO; PUBLIC_INPUT_LEN]; - - set_vector(&mut public_input, POSEIDON16_ARG_A_PTR, &POSEIDON16_ARG_A_VALUES); - set_vector(&mut public_input, POSEIDON16_ARG_B_PTR, &POSEIDON16_ARG_B_VALUES); - - let poseidon24_chunks = [&POSEIDON24_ARG_A_VALUES[0][..], &POSEIDON24_ARG_A_VALUES[1][..]]; - set_multivector(&mut public_input, POSEIDON24_ARG_A_PTR, &poseidon24_chunks); - set_vector(&mut public_input, POSEIDON24_ARG_B_PTR, &POSEIDON24_ARG_B_VALUES); - - set_ef_slice(&mut public_input, DOT_ARG0_PTR, &DOT_ARG0_VALUES); - set_ef_slice(&mut public_input, DOT_ARG1_PTR, &DOT_ARG1_VALUES); - - let coeff_base = MLE_COEFF_PTR << MLE_N_VARS; - set_base_slice(&mut public_input, coeff_base, &MLE_COEFF_VALUES); - - let log_point_size = log2_ceil_usize(MLE_N_VARS * DIMENSION); - let point_base = MLE_POINT_PTR << log_point_size; - set_base_slice(&mut public_input, point_base, &MLE_POINT_VALUES); - - let mut hints = BTreeMap::new(); - hints.insert( - 0, - vec![Hint::RequestMemory { - function_name: Label::function("main"), - offset: POSEIDON16_RES_OFFSET, - size: MemOrConstant::Constant(f(2)), - vectorized: true, - vectorized_len: LOG_VECTOR_LEN + 1, - }], - ); - hints.insert( - 1, - vec![Hint::RequestMemory { - function_name: Label::function("main"), - offset: POSEIDON24_RES_OFFSET, - size: MemOrConstant::Constant(f(1)), - vectorized: true, - vectorized_len: LOG_VECTOR_LEN, - }], - ); - hints.insert( - 2, - vec![Hint::RequestMemory { - function_name: Label::function("main"), - offset: DOT_RES_OFFSET, - size: MemOrConstant::Constant(f(1)), - vectorized: false, - vectorized_len: 0, - }], - ); - 
hints.insert( - 3, - vec![Hint::RequestMemory { - function_name: Label::function("main"), - offset: MLE_RES_OFFSET, - size: MemOrConstant::Constant(f(1)), - vectorized: true, - vectorized_len: LOG_VECTOR_LEN, - }], - ); - - let instructions = vec![ - Instruction::Precompile { - table: Table::poseidon16(), - arg_a: MemOrConstant::Constant(f(POSEIDON16_ARG_A_PTR as u64)), - arg_b: MemOrConstant::Constant(f(POSEIDON16_ARG_B_PTR as u64)), - arg_c: MemOrFp::MemoryAfterFp { - offset: POSEIDON16_RES_OFFSET, - }, - aux: 0, // compression = false - }, - Instruction::Precompile { - table: Table::poseidon24(), - arg_a: MemOrConstant::Constant(f(POSEIDON24_ARG_A_PTR as u64)), - arg_b: MemOrConstant::Constant(f(POSEIDON24_ARG_B_PTR as u64)), - arg_c: MemOrFp::MemoryAfterFp { - offset: POSEIDON24_RES_OFFSET, - }, - aux: 0, // unused - }, - Instruction::Precompile { - table: Table::dot_product_ee(), - arg_a: MemOrConstant::Constant(f(DOT_ARG0_PTR as u64)), - arg_b: MemOrConstant::Constant(f(DOT_ARG1_PTR as u64)), - arg_c: MemOrFp::MemoryAfterFp { offset: DOT_RES_OFFSET }, - aux: DOT_PRODUCT_LEN, - }, - ]; - - let bytecode = Bytecode { - instructions, - hints, - starting_frame_memory: 512, - program: Default::default(), - function_locations: Default::default(), - }; - - (bytecode, public_input) -} - -fn run_program() -> (Bytecode, ExecutionResult) { - let (bytecode, public_input) = build_test_case(); - let result = execute_bytecode(&bytecode, (&public_input, &[]), 1 << 20, false, (&vec![], &vec![])); - println!("{}", result.summary); - (bytecode, result) -} -#[test] -fn test_memory_operations() { - let mut memory = Memory::empty(); - assert!(memory.set(0, F::from_usize(42)).is_ok()); - assert_eq!(memory.get(0).unwrap(), F::from_usize(42)); -} - -#[test] -fn test_operation_compute() { - use crate::Operation; - - let add = Operation::Add; - let mul = Operation::Mul; - - assert_eq!(add.compute(F::from_usize(2), F::from_usize(3)), F::from_usize(5)); - 
assert_eq!(mul.compute(F::from_usize(2), F::from_usize(3)), F::from_usize(6)); -} diff --git a/crates/lookup/src/logup_star.rs b/crates/lookup/src/logup_star.rs index 9b22e17c..64bdcf0d 100644 --- a/crates/lookup/src/logup_star.rs +++ b/crates/lookup/src/logup_star.rs @@ -277,12 +277,13 @@ mod tests { ); println!("Proving logup_star took {} ms", time.elapsed().as_millis()); - let mut verifier_state = build_verifier_state(&prover_state); + let last_prover_state = prover_state.challenger().state(); + let mut verifier_state = build_verifier_state(prover_state); let verifier_statements = verify_logup_star(&mut verifier_state, log_table_len, log_indexes_len, &[claim], EF::ONE).unwrap(); assert_eq!(&verifier_statements, &prover_statements); - assert_eq!(prover_state.challenger().state(), verifier_state.challenger().state()); + assert_eq!(last_prover_state, verifier_state.challenger().state()); assert_eq!( indexes.evaluate(&verifier_statements.on_indexes.point), diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs index e72a4929..724fa8e8 100644 --- a/crates/lookup/src/quotient_gkr.rs +++ b/crates/lookup/src/quotient_gkr.rs @@ -220,7 +220,7 @@ fn sum_quotients_helper( ) -> Vec> { assert_eq!(numerators_and_denominators.len(), n_groups); let n = numerators_and_denominators[0].len(); - assert!(n.is_power_of_two() && n >= 2); + assert!(n.is_power_of_two() && n >= 2, "n = {}", n); let mut new_numerators = Vec::new(); let mut new_denominators = Vec::new(); let (prev_numerators, prev_denominators) = numerators_and_denominators.split_at(n_groups / 2); @@ -320,7 +320,7 @@ mod tests { ); println!("Proving time: {:?}", time.elapsed()); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let verifier_statements = verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); assert_eq!(&verifier_statements, &prover_statements); diff --git a/crates/poseidon_circuit/src/tests.rs 
b/crates/poseidon_circuit/src/tests.rs index 1d4a03d4..f55e35bc 100644 --- a/crates/poseidon_circuit/src/tests.rs +++ b/crates/poseidon_circuit/src/tests.rs @@ -25,19 +25,22 @@ const COMPRESSION_OUTPUT_WIDTH: usize = 8; #[test] fn test_poseidon_benchmark() { - run_poseidon_benchmark::<16, 0, 3>(12, false); - run_poseidon_benchmark::<16, 0, 3>(12, true); - run_poseidon_benchmark::<16, 16, 3>(12, false); - run_poseidon_benchmark::<16, 16, 3>(12, true); + run_poseidon_benchmark::<16, 0, 3>(12, false, false); + run_poseidon_benchmark::<16, 0, 3>(12, true, false); + run_poseidon_benchmark::<16, 16, 3>(12, false, false); + run_poseidon_benchmark::<16, 16, 3>(12, true, false); } pub fn run_poseidon_benchmark( log_n_poseidons: usize, compress: bool, + tracing: bool, ) where KoalaBearInternalLayerParameters: InternalLayerBaseParameters, { - init_tracing(); + if tracing { + init_tracing(); + } precompute_dft_twiddles::(1 << 24); let whir_config_builder = WhirConfigBuilder { @@ -167,7 +170,7 @@ pub fn run_poseidon_benchmark 1 { } fn eq_mle_extension_base_dynamic(a, b, n) -> 1 { - if n == N_VARS - FOLDING_FACTOR_0 { - res = eq_mle_extension_base_const(a, b, N_VARS - FOLDING_FACTOR_0); - return res; - } - if n == N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 { - res = eq_mle_extension_base_const(a, b, N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1); - return res; - } - if n == N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 - FOLDING_FACTOR_2 { - res = eq_mle_extension_base_const(a, b, N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 - FOLDING_FACTOR_2); - return res; - } - TODO_eq_mle_extension_base_dynamicc = n; - print(TODO_eq_mle_extension_base_dynamicc); - panic(); -} - -fn eq_mle_extension_base_const(a, b, const n) -> 1 { - // a: base - // b: extension - - buff = malloc(n*DIM); - - for i in 0..n unroll { - ai = a[i]; - bi = b + i*DIM; - buffi = buff + i*DIM; - ai_double = ai * 2; - buffi[0] = 1 + ai_double * bi[0] - ai - bi[0]; - for j in 1..DIM unroll { - buffi[j] = ai_double * 
bi[j] - bi[j]; - } - } - - prods = malloc(n*DIM); - assert_eq_extension(buff, prods); - for i in 0..n - 1 unroll { - mul_extension(prods + i*DIM, buff + (i + 1)*DIM, prods + (i + 1)*DIM); + res = malloc(DIM); + match n { + 0 => { } // unreachable + 1 => { eq_poly_base_ext(a, b, res, 1); } + 2 => { eq_poly_base_ext(a, b, res, 2); } + 3 => { eq_poly_base_ext(a, b, res, 3); } + 4 => { eq_poly_base_ext(a, b, res, 4); } + 5 => { eq_poly_base_ext(a, b, res, 5); } + 6 => { eq_poly_base_ext(a, b, res, 6); } + 7 => { eq_poly_base_ext(a, b, res, 7); } + 8 => { eq_poly_base_ext(a, b, res, 8); } + 9 => { eq_poly_base_ext(a, b, res, 9); } + 10 => { eq_poly_base_ext(a, b, res, 10); } + 11 => { eq_poly_base_ext(a, b, res, 11); } + 12 => { eq_poly_base_ext(a, b, res, 12); } + 13 => { eq_poly_base_ext(a, b, res, 13); } + 14 => { eq_poly_base_ext(a, b, res, 14); } + 15 => { eq_poly_base_ext(a, b, res, 15); } + 16 => { eq_poly_base_ext(a, b, res, 16); } + 17 => { eq_poly_base_ext(a, b, res, 17); } + 18 => { eq_poly_base_ext(a, b, res, 18); } + 19 => { eq_poly_base_ext(a, b, res, 19); } + 20 => { eq_poly_base_ext(a, b, res, 20); } + 21 => { eq_poly_base_ext(a, b, res, 21); } + 22 => { eq_poly_base_ext(a, b, res, 22); } } - return prods + (n - 1) * DIM; + return res; } fn expand_from_univariate_dynamic(alpha, n) -> 1 { - if n == N_VARS - FOLDING_FACTOR_0 { - res = expand_from_univariate_const(alpha, N_VARS - FOLDING_FACTOR_0); - return res; - } - if n == N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 { - res = expand_from_univariate_const(alpha, N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1); - return res; - } - if n == N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 - FOLDING_FACTOR_2 { - res = expand_from_univariate_const(alpha, N_VARS - FOLDING_FACTOR_0 - FOLDING_FACTOR_1 - FOLDING_FACTOR_2); - return res; + match n { + 0 => { } // unreachable + 1 => { res = expand_from_univariate_const(alpha, 1); } + 2 => { res = expand_from_univariate_const(alpha, 2); } + 3 => { res = 
expand_from_univariate_const(alpha, 3); } + 4 => { res = expand_from_univariate_const(alpha, 4); } + 5 => { res = expand_from_univariate_const(alpha, 5); } + 6 => { res = expand_from_univariate_const(alpha, 6); } + 7 => { res = expand_from_univariate_const(alpha, 7); } + 8 => { res = expand_from_univariate_const(alpha, 8); } + 9 => { res = expand_from_univariate_const(alpha, 9); } + 10 => { res = expand_from_univariate_const(alpha, 10); } + 11 => { res = expand_from_univariate_const(alpha, 11); } + 12 => { res = expand_from_univariate_const(alpha, 12); } + 13 => { res = expand_from_univariate_const(alpha, 13); } + 14 => { res = expand_from_univariate_const(alpha, 14); } + 15 => { res = expand_from_univariate_const(alpha, 15); } + 16 => { res = expand_from_univariate_const(alpha, 16); } + 17 => { res = expand_from_univariate_const(alpha, 17); } + 18 => { res = expand_from_univariate_const(alpha, 18); } + 19 => { res = expand_from_univariate_const(alpha, 19); } + 20 => { res = expand_from_univariate_const(alpha, 20); } + 21 => { res = expand_from_univariate_const(alpha, 21); } + 22 => { res = expand_from_univariate_const(alpha, 22); } } - TODO_expand_from_univariate_dynamic = n; - print(TODO_expand_from_univariate_dynamic); - panic(); + return res; } fn expand_from_univariate_const(alpha, const n) -> 1 { @@ -285,8 +293,10 @@ fn sumcheck(fs_state, n_steps, claimed_sum) -> 3 { fn sample_stir_indexes_and_fold(fs_state, num_queries, merkle_leaves_in_basefield, folding_factor, two_pow_folding_factor, domain_size, prev_root, folding_randomness, grinding_bits) -> 3 { + folded_domain_size = domain_size - folding_factor; + fs_state_8 = fs_grinding(fs_state, grinding_bits); - fs_state_9, stir_challenges_indexes = sample_bits(fs_state_8, num_queries); + fs_state_9, stir_challenges_indexes = sample_bits_dynamic(fs_state_8, num_queries, folded_domain_size); answers = malloc(num_queries); // a vector of vectorized pointers, each pointing to `two_pow_folding_factor` field elements 
(base if first rounds, extension otherwise) fs_states_b = malloc(num_queries + 1); @@ -307,47 +317,10 @@ fn sample_stir_indexes_and_fold(fs_state, num_queries, merkle_leaves_in_basefiel fs_state_10 = fs_states_b[num_queries]; leaf_hashes = malloc(num_queries); // a vector of vectorized pointers, each pointing to 1 chunk of 8 field elements - for i in 0..num_queries { - answer = answers[i]; - internal_states = malloc(1 + (n_chuncks_per_answer / 2)); // "/ 2" because with poseidon24 we hash 2 chuncks of 8 field elements at each permutation - internal_states[0] = pointer_to_zero_vector; // initial state - for j in 0..n_chuncks_per_answer / 2 { - h24 = malloc_vec(1); - poseidon24(answer + (2*j), internal_states[j], h24); - internal_states[j + 1] = h24; - } - leaf_hashes[i] = internal_states[n_chuncks_per_answer / 2]; - } - - folded_domain_size = domain_size - folding_factor; - - fs_states_c = malloc(num_queries + 1); - fs_states_c[0] = fs_state_10; + batch_hash_slice_dynamic(num_queries, answers, leaf_hashes, n_chuncks_per_answer); - for i in 0..num_queries { - fs_state_11, merkle_path = fs_hint(fs_states_c[i], folded_domain_size); - fs_states_c[i + 1] = fs_state_11; - - stir_index_bits = stir_challenges_indexes[i]; // a pointer to 31 bits - - states = malloc(1 + folded_domain_size); - states[0] = leaf_hashes[i]; - for j in 0..folded_domain_size { - if stir_index_bits[j] == 1 { - left = merkle_path + j; - right = states[j]; - } else { - left = states[j]; - right = merkle_path + j; - } - state_j_plus_1 = malloc_vec(1); - poseidon16(left, right, state_j_plus_1, COMPRESSION); - states[j + 1] = state_j_plus_1; - } - assert_eq_vec(states[folded_domain_size], prev_root); - } - - fs_state_11 = fs_states_c[num_queries]; + // Merkle verification + merkle_verif_batch_dynamic(num_queries, leaf_hashes, stir_challenges_indexes + num_queries, prev_root, folded_domain_size); folds = malloc(num_queries * DIM); @@ -359,18 +332,80 @@ fn sample_stir_indexes_and_fold(fs_state, 
num_queries, merkle_leaves_in_basefiel } } else { for i in 0..num_queries { - dot_product_dynamic(answers[i] * 8, poly_eq, folds + i*DIM, two_pow_folding_factor); + dot_product_ee_dynamic(answers[i] * 8, poly_eq, folds + i*DIM, two_pow_folding_factor); } } circle_values = malloc(num_queries); // ROOT^each_stir_index for i in 0..num_queries { stir_index_bits = stir_challenges_indexes[i]; - circle_value = unit_root_pow(folded_domain_size, stir_index_bits); + circle_value = unit_root_pow_dynamic(folded_domain_size, stir_index_bits); circle_values[i] = circle_value; } - return fs_state_11, circle_values, folds; + return fs_state_10, circle_values, folds; +} + +fn batch_hash_slice_dynamic(num_queries, all_data_to_hash, all_resulting_hashes, len) { + if len == DIM * 2 { + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM); + return; + } + if len == 16 { + batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, 8); + return; + } + TODO_batch_hash_slice_dynamic = len; + print(77777123); + print(len); + panic(); +} + +fn batch_hash_slice_const(num_queries, all_data_to_hash, all_resulting_hashes, const half_len) { + for i in 0..num_queries { + data = all_data_to_hash[i]; + res = malloc_vec(1); + slice_hash(pointer_to_zero_vector, data, res, half_len); + all_resulting_hashes[i] = res; + } + return; +} + +fn merkle_verif_batch_dynamic(n_paths, leaves_digests, leave_positions, root, height) { + if height == MERKLE_HEIGHT_0 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_0); + return; + } + if height == MERKLE_HEIGHT_1 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1); + return; + } + if height == MERKLE_HEIGHT_2 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2); + return; + } + if height == MERKLE_HEIGHT_3 { + merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3); 
+ return; + } + + print(12345555); + print(height); + panic(); +} + +fn merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, const height) { + // n_paths: F + // leaves_digests: pointer to a slice of n_paths vectorized pointers, each pointing to 1 chunk of 8 field elements + // leave_positions: pointer to a slice of n_paths field elements (each < 2^height) + // root: vectorized pointer to 1 chunk of 8 field elements + // height: F + + for i in 0..n_paths { + merkle_verify(leaves_digests[i], leave_positions[i], root, height); + } + + return; } @@ -387,7 +422,7 @@ fn whir_round(fs_state, prev_root, folding_factor, two_pow_folding_factor, merkl combination_randomness_powers = powers(combination_randomness_gen, num_queries + 1); // "+ 1" because of one OOD sample claimed_sum_supplement_side = malloc(5); - dot_product_dynamic(folds, combination_randomness_powers + DIM, claimed_sum_supplement_side, num_queries); + dot_product_ee_dynamic(folds, combination_randomness_powers + DIM, claimed_sum_supplement_side, num_queries); claimed_sum_supplement = add_extension_ret(claimed_sum_supplement_side, ood_eval); new_claimed_sum_b = add_extension_ret(claimed_sum_supplement, new_claimed_sum_a); @@ -420,35 +455,34 @@ fn powers(alpha, n) -> 1 { return res; } -fn unit_root_pow(domain_size, index_bits) -> 1 { +fn unit_root_pow_dynamic(domain_size, index_bits) -> 1 { // index_bits is a pointer to domain_size bits - if domain_size == 19 { - res = unit_root_pow_const(19, index_bits); - return res; - } - if domain_size == 18 { - res = unit_root_pow_const(18, index_bits); - return res; - } - if domain_size == 17 { - res = unit_root_pow_const(17, index_bits); - return res; - } - if domain_size == 16 { - res = unit_root_pow_const(16, index_bits); - return res; + match domain_size { + 0 => { } // unreachable + 1 => { res = unit_root_pow_const(1, index_bits); } + 2 => { res = unit_root_pow_const(2, index_bits); } + 3 => { res = unit_root_pow_const(3, index_bits); } + 4 => { 
res = unit_root_pow_const(4, index_bits); } + 5 => { res = unit_root_pow_const(5, index_bits); } + 6 => { res = unit_root_pow_const(6, index_bits); } + 7 => { res = unit_root_pow_const(7, index_bits); } + 8 => { res = unit_root_pow_const(8, index_bits); } + 9 => { res = unit_root_pow_const(9, index_bits); } + 10 => { res = unit_root_pow_const(10, index_bits); } + 11 => { res = unit_root_pow_const(11, index_bits); } + 12 => { res = unit_root_pow_const(12, index_bits); } + 13 => { res = unit_root_pow_const(13, index_bits); } + 14 => { res = unit_root_pow_const(14, index_bits); } + 15 => { res = unit_root_pow_const(15, index_bits); } + 16 => { res = unit_root_pow_const(16, index_bits); } + 17 => { res = unit_root_pow_const(17, index_bits); } + 18 => { res = unit_root_pow_const(18, index_bits); } + 19 => { res = unit_root_pow_const(19, index_bits); } + 20 => { res = unit_root_pow_const(20, index_bits); } + 21 => { res = unit_root_pow_const(21, index_bits); } + 22 => { res = unit_root_pow_const(22, index_bits); } } - if domain_size == 15 { - res = unit_root_pow_const(15, index_bits); - return res; - } - if domain_size == 20 { - res = unit_root_pow_const(20, index_bits); - return res; - } - UNIMPLEMENTED = 0; - print(UNIMPLEMENTED, domain_size); - panic(); + return res; } fn unit_root_pow_const(const domain_size, index_bits) -> 1 { @@ -460,7 +494,7 @@ fn unit_root_pow_const(const domain_size, index_bits) -> 1 { return prods[domain_size - 1]; } -fn dot_product_dynamic(a, b, res, n) { +fn dot_product_ee_dynamic(a, b, res, n) { if n == 16 { dot_product_ee(a, b, res, 16); return; @@ -498,34 +532,11 @@ fn dot_product_dynamic(a, b, res, n) { return; } - TODO_dot_product_dynamic = 0; - print(TODO_dot_product_dynamic, n); + TODO_dot_product_ee_dynamic = 0; + print(TODO_dot_product_ee_dynamic, n); panic(); } -fn dot_product_base_extension(a, b, res, const n) { - // a is a normal pointer to n F elements - // b is a normal pointer to n continous EF elements - // res is a normal 
pointer to 1 EF element - - prods = malloc(n * DIM); - for i in 0..n unroll { - for j in 0..DIM unroll { - prods[i * DIM + j] = a[i] * b[i * DIM + j]; - } - } - my_buff = malloc(n * DIM); - for i in 0..DIM unroll { - my_buff[n * i] = prods[i]; - for j in 0..n - 1 unroll { - my_buff[(n * i) + j + 1] = my_buff[(n * i) + j] + prods[i + ((j + 1) * DIM)]; - } - res[i] = my_buff[(n * i) + n - 1]; - } - - return; -} - fn poly_eq_extension(point, n, two_pow_n) -> 1 { // Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] @@ -549,31 +560,118 @@ fn poly_eq_extension(point, n, two_pow_n) -> 1 { return res; } -fn poly_eq_base(point, n, two_pow_n) -> 1 { +fn poly_eq_base(point, n) -> 1 { + match n { + 0 => { } // unreachable + 1 => { res = poly_eq_base_1(point); } + 2 => { res = poly_eq_base_2(point); } + 3 => { res = poly_eq_base_3(point); } + 4 => { res = poly_eq_base_4(point); } + 5 => { res = poly_eq_base_5(point); } + 6 => { res = poly_eq_base_6(point); } + 7 => { res = poly_eq_base_7(point); } + } + return res; +} + +fn poly_eq_base_7(point) -> 1 { + // n = 7 // return a (normal) pointer to 2^n base field elements, corresponding to the "equality polynomial" at point // Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy] - if n == 0 { - // base case - res = malloc(1); - res[0] = 1; - return res; + res = malloc(128); + + inner_res = poly_eq_base_6(point + 1); + + for i in 0..64 unroll { + res[64 + i] = inner_res[i] * point[0]; + res[i] = inner_res[i] - res[64 + i]; } + + return res; +} + +fn poly_eq_base_6(point) -> 1 { + // n = 6 + res = malloc(64); + + inner_res = poly_eq_base_5(point + 1); + + for i in 0..32 unroll { + res[32 + i] = inner_res[i] * point[0]; + res[i] = inner_res[i] - res[32 + i]; + } + + return res; +} - res = malloc(two_pow_n); +fn poly_eq_base_5(point) -> 1 { + // n = 5 + res = malloc(32); + + inner_res = poly_eq_base_4(point + 1); + + for i in 0..16 unroll { + res[16 + i] = inner_res[i] * point[0]; + res[i] = 
inner_res[i] - res[16 + i]; + } + + return res; +} - inner_res = poly_eq_base(point + 1, n - 1, two_pow_n / 2); +fn poly_eq_base_4(point) -> 1 { + // n = 4 + res = malloc(16); + + inner_res = poly_eq_base_3(point + 1); + + for i in 0..8 unroll { + res[8 + i] = inner_res[i] * point[0]; + res[i] = inner_res[i] - res[8 + i]; + } + + return res; +} - two_pow_n_minus_1 = two_pow_n / 2; +fn poly_eq_base_3(point) -> 1 { + // n = 3 + res = malloc(8); + + inner_res = poly_eq_base_2(point + 1); + + for i in 0..4 unroll { + res[4 + i] = inner_res[i] * point[0]; + res[i] = inner_res[i] - res[4 + i]; + } + + return res; +} - for i in 0..two_pow_n_minus_1 { - res[two_pow_n_minus_1 + i] = inner_res[i] * point[0]; - res[i] = inner_res[i] - res[two_pow_n_minus_1 + i]; +fn poly_eq_base_2(point) -> 1 { + // n = 2 + res = malloc(4); + + inner_res = poly_eq_base_1(point + 1); + + for i in 0..2 unroll { + res[2 + i] = inner_res[i] * point[0]; + res[i] = inner_res[i] - res[2 + i]; } return res; } +fn poly_eq_base_1(point) -> 1 { + // n = 1 + // Base case: eq(x) = [1 - x, x] + res = malloc(2); + + res[1] = point[0]; + res[0] = 1 - res[1]; + + return res; +} + fn pow(a, b) -> 1 { if b == 0 { @@ -584,21 +682,46 @@ fn pow(a, b) -> 1 { } } -fn sample_bits(fs_state, n) -> 2 { - // return the updated fs_state, and a pointer to n pointers, each pointing to 31 (boolean) field elements - samples = malloc(n); - new_fs_state = fs_sample_helper(fs_state, n, samples); - sampled_bits = malloc(n); - for i in 0..n { - bits = checked_decompose_bits(samples[i]); +fn sample_bits_dynamic(fs_state, n_samples, K) -> 2 { + if n_samples == NUM_QUERIES_0 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_0, K); + return new_fs_state, sampled_bits; + } + if n_samples == NUM_QUERIES_1 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K); + return new_fs_state, sampled_bits; + } + if n_samples == NUM_QUERIES_2 { + new_fs_state, sampled_bits = 
sample_bits_const(fs_state, NUM_QUERIES_2, K); + return new_fs_state, sampled_bits; + } + if n_samples == NUM_QUERIES_3 { + new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K); + return new_fs_state, sampled_bits; + } + print(n_samples); + print(999333); + panic(); +} + +fn sample_bits_const(fs_state, const n_samples, K) -> 2 { + // return the updated fs_state, and a pointer to n pointers, each pointing to 31 (boolean) field elements, + // ... followed by the n corresponding sampled field elements (where we only look at the first K bits) + samples = malloc(n_samples); + new_fs_state = fs_sample_helper(fs_state, n_samples, samples); + sampled_bits = malloc(n_samples * 2); + for i in 0..n_samples unroll { + bits, partial_sum = checked_decompose_bits(samples[i], K); sampled_bits[i] = bits; + sampled_bits[n_samples + i] = partial_sum; } return new_fs_state, sampled_bits; } -fn checked_decompose_bits(a) -> 1 { - // return a pointer to bits of a +fn checked_decompose_bits(a, k) -> 2 { + // return a pointer to the 31 bits of a + // .. 
and the partial value, reading the first K bits bits = decompose_bits(a); // hint for i in 0..F_BITS unroll { @@ -610,7 +733,8 @@ fn checked_decompose_bits(a) -> 1 { sums[i] = sums[i - 1] + bits[i] * 2**i; } assert a == sums[F_BITS - 1]; - return bits; + partial_sum = sums[k - 1]; + return bits, partial_sum; } fn degree_two_polynomial_sum_at_0_and_1(coeffs) -> 1 { @@ -660,7 +784,7 @@ fn fs_new(transcript) -> 1 { return fs_state; } - fn fs_grinding(fs_state, bits) -> 1 { +fn fs_grinding(fs_state, bits) -> 1 { // WARNING: should not be called 2 times in a row without duplexing in between if bits == 0 { @@ -687,7 +811,7 @@ fn fs_new(transcript) -> 1 { l_updated_ptr = l_r_updated* 8; sampled = l_updated_ptr[7]; - sampled_bits = checked_decompose_bits(sampled); + sampled_bits, _ = checked_decompose_bits(sampled, 1); // 1 is useless here, could be anything for i in 0..bits { assert sampled_bits[i] == 0; } @@ -695,6 +819,7 @@ fn fs_new(transcript) -> 1 { } fn less_than_8(a) inline -> 1 { + // TODO range check if a * (a - 1) * (a - 2) * (a - 3) * (a - 4) * (a - 5) * (a - 6) * (a - 7) == 0 { return 1; // a < 8 } else { @@ -767,11 +892,29 @@ fn fs_hint(fs_state, n) -> 2 { } fn fs_receive_ef(fs_state, n) -> 2 { + match n { + 0 => { } // unreachable + 1 => { final_fs_state, res = fs_receive_ef_const(fs_state, 1); } + 2 => { final_fs_state, res = fs_receive_ef_const(fs_state, 2); } + 3 => { final_fs_state, res = fs_receive_ef_const(fs_state, 3); } + 4 => { final_fs_state, res = fs_receive_ef_const(fs_state, 4); } + 5 => { final_fs_state, res = fs_receive_ef_const(fs_state, 5); } + 6 => { final_fs_state, res = fs_receive_ef_const(fs_state, 6); } + 7 => { final_fs_state, res = fs_receive_ef_const(fs_state, 7); } + 8 => { final_fs_state, res = fs_receive_ef_const(fs_state, 8); } + 9 => { final_fs_state, res = fs_receive_ef_const(fs_state, 9); } + 10 => { final_fs_state, res = fs_receive_ef_const(fs_state, 10); } + } + return final_fs_state, res; +} + + +fn 
fs_receive_ef_const(fs_state, const n) -> 2 { // return the updated fs_state, and a (normal) pointer to n consecutive EF elements final_fs_state = fs_observe(fs_state, n); res = malloc(n * DIM); // TODO optimize with dot_product - for i in 0..n { // TODO unroll in most cases + for i in 0..n unroll { ptr = (fs_state[0] + i) * 8; for j in 0..DIM unroll { res[i * DIM + j] = ptr[j]; diff --git a/crates/rec_aggregation/src/recursion.rs b/crates/rec_aggregation/src/recursion.rs index 47032ca5..b3473710 100644 --- a/crates/rec_aggregation/src/recursion.rs +++ b/crates/rec_aggregation/src/recursion.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::path::Path; use std::time::Instant; @@ -10,14 +11,13 @@ use multilinear_toolkit::prelude::*; use rand::Rng; use rand::SeedableRng; use rand::rngs::StdRng; -use utils::{ - build_prover_state, build_verifier_state, padd_with_zero_to_next_multiple_of, padd_with_zero_to_next_power_of_two, -}; +use utils::build_challenger; +use utils::{build_prover_state, padd_with_zero_to_next_multiple_of}; use whir_p3::{FoldingFactor, SecurityAssumption, WhirConfig, WhirConfigBuilder, precompute_dft_twiddles}; const NUM_VARIABLES: usize = 25; -pub fn run_whir_recursion_benchmark() { +pub fn run_whir_recursion_benchmark(tracing: bool, n_recursions: usize) { let src_file = Path::new(env!("CARGO_MANIFEST_DIR")).join("recursion_program.lean_lang"); let mut program_str = std::fs::read_to_string(src_file).unwrap(); let recursion_config_builder = WhirConfigBuilder { @@ -30,6 +30,8 @@ pub fn run_whir_recursion_benchmark() { rs_domain_initial_reduction_factor: 3, }; + program_str = program_str.replace("N_RECURSIONS_PLACEHOLDER", &n_recursions.to_string()); + let mut recursion_config = WhirConfig::::new(recursion_config_builder.clone(), NUM_VARIABLES); // TODO remove overriding this @@ -55,6 +57,11 @@ pub fn run_whir_recursion_benchmark() { .replace( &format!("GRINDING_BITS_{}_PLACEHOLDER", recursion_config.n_rounds()), 
&recursion_config.final_pow_bits.to_string(), + ) + .replace("N_VARS_PLACEHOLDER", &NUM_VARIABLES.to_string()) + .replace( + "LOG_INV_RATE_PLACEHOLDER", + &recursion_config_builder.starting_log_inv_rate.to_string(), ); assert_eq!(recursion_config.n_rounds(), 3); // this is hardcoded in the program above for round in 0..=recursion_config.n_rounds() { @@ -82,10 +89,19 @@ pub fn run_whir_recursion_benchmark() { precompute_dft_twiddles::(1 << 24); let witness = recursion_config.commit(&mut prover_state, &polynomial); + recursion_config.prove(&mut prover_state, statement.clone(), witness, &polynomial.by_ref()); + let whir_proof = prover_state.into_proof(); + + { + let mut verifier_state = VerifierState::new(whir_proof.clone(), build_challenger()); + let parsed_commitment = recursion_config.parse_commitment::(&mut verifier_state).unwrap(); + recursion_config + .verify(&mut verifier_state, &parsed_commitment, statement) + .unwrap(); + } - let mut public_input = prover_state.proof_data().to_vec(); - let commitment_size = public_input.len(); - assert_eq!(commitment_size, 16); + let commitment_size = 16; + let mut public_input = whir_proof.proof_data[..commitment_size].to_vec(); public_input.extend(padd_with_zero_to_next_multiple_of( &point .iter() @@ -93,92 +109,69 @@ pub fn run_whir_recursion_benchmark() { .collect::>(), VECTOR_LEN, )); - public_input.extend(padd_with_zero_to_next_power_of_two( + public_input.extend(padd_with_zero_to_next_multiple_of( >::as_basis_coefficients_slice(&eval), + VECTOR_LEN, )); - recursion_config.prove(&mut prover_state, statement.clone(), witness, &polynomial.by_ref()); - - let first_folding_factor = recursion_config_builder.folding_factor.at_round(0); - - // to align the first merkle leaves (in base field) (required to appropriately call the precompile multilinear_eval) - let mut proof_data_padding = (1 << first_folding_factor) - - ((NONRESERVED_PROGRAM_INPUT_START - + public_input.len() - + { - // sumcheck polys - first_folding_factor * 3 * 
VECTOR_LEN - } - + { - // merkle root - VECTOR_LEN - } - + { - // grinding witness - VECTOR_LEN - } - + { - // ood answer - VECTOR_LEN - }) - % (1 << first_folding_factor)); - assert_eq!(proof_data_padding % 8, 0); - proof_data_padding /= 8; + public_input.extend(whir_proof.proof_data[commitment_size..].to_vec()); - program_str = program_str - .replace( - "PADDING_FOR_INITIAL_MERKLE_LEAVES_PLACEHOLDER", - &proof_data_padding.to_string(), - ) - .replace("N_VARS_PLACEHOLDER", &NUM_VARIABLES.to_string()) - .replace( - "LOG_INV_RATE_PLACEHOLDER", - &recursion_config_builder.starting_log_inv_rate.to_string(), - ); - - public_input.extend(F::zero_vec(proof_data_padding * 8)); + assert!(public_input.len().is_multiple_of(VECTOR_LEN)); + program_str = program_str.replace( + "WHIR_PROOF_SIZE_PLACEHOLDER", + &(public_input.len() / VECTOR_LEN).to_string(), + ); - public_input.extend(prover_state.proof_data()[commitment_size..].to_vec()); + public_input = std::iter::repeat_n(public_input, n_recursions).flatten().collect(); - { - let mut verifier_state = build_verifier_state(&prover_state); - let parsed_commitment = recursion_config.parse_commitment::(&mut verifier_state).unwrap(); - recursion_config - .verify(&mut verifier_state, &parsed_commitment, statement) - .unwrap(); + if tracing { + utils::init_tracing(); } - utils::init_tracing(); let bytecode = compile_program(program_str); + let mut merkle_path_hints = VecDeque::new(); + for _ in 0..n_recursions { + merkle_path_hints.extend(whir_proof.merkle_hints.clone()); + } + // in practice we will precompute all the possible values // (depending on the number of recursions + the number of xmss signatures) // (or even better: find a linear relation) - let no_vec_runtime_memory = - execute_bytecode(&bytecode, (&public_input, &[]), 1 << 20, false, (&vec![], &vec![])).no_vec_runtime_memory; + let no_vec_runtime_memory = execute_bytecode( + &bytecode, + (&public_input, &[]), + 1 << 20, + false, + (&vec![], &vec![]), // TODO + 
merkle_path_hints.clone(), + ) + .no_vec_runtime_memory; let time = Instant::now(); - let (proof_data, summary) = prove_execution( + let (proof, summary) = prove_execution( &bytecode, (&public_input, &[]), whir_config_builder(), no_vec_runtime_memory, false, (&vec![], &vec![]), // TODO precompute poseidons + merkle_path_hints, ); + let proof_size = proof.proof_size; let proving_time = time.elapsed(); - verify_execution(&bytecode, &public_input, proof_data.clone(), whir_config_builder()).unwrap(); + verify_execution(&bytecode, &public_input, proof, whir_config_builder()).unwrap(); println!("{summary}"); println!( - "WHIR recursion, proving time: {} ms, proof size: {} KiB (not optimized)", - proving_time.as_millis(), - proof_data.len() * F::bits() / (8 * 1024) + "Proving time: {} ms / WHIR recursion, proof size: {} KiB (not optimized)", + proving_time.as_millis() / n_recursions as u128, + proof_size * F::bits() / (8 * 1024) ); } #[test] fn test_whir_recursion() { - run_whir_recursion_benchmark(); + run_whir_recursion_benchmark(false, 1); } diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 9cf7c520..086d8758 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -4,18 +4,17 @@ use lean_prover::{prove_execution::prove_execution, verify_execution::verify_exe use lean_vm::*; use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; +use std::collections::VecDeque; use std::path::Path; use std::sync::OnceLock; use std::time::Instant; use tracing::{info_span, instrument}; use whir_p3::precompute_dft_twiddles; use xmss::{ - Poseidon16History, Poseidon24History, V, XMSS_MAX_LOG_LIFETIME, XmssPublicKey, XmssSignature, - xmss_generate_phony_signatures, xmss_verify_with_poseidon_trace, + Poseidon16History, Poseidon24History, V, XMSS_MAX_LOG_LIFETIME, XMSS_MIN_LOG_LIFETIME, XmssPublicKey, + XmssSignature, xmss_generate_phony_signatures, 
xmss_verify_with_poseidon_trace, }; -const XMSS_SIG_SIZE_VEC_PADDED: usize = (V + 1 + XMSS_MAX_LOG_LIFETIME) + XMSS_MAX_LOG_LIFETIME.div_ceil(8); - static XMSS_AGGREGATION_PROGRAM: OnceLock = OnceLock::new(); fn get_xmss_aggregation_program() -> &'static XmssAggregationProgram { @@ -26,36 +25,41 @@ pub fn xmss_setup_aggregation_program() { let _ = get_xmss_aggregation_program(); } -fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message_hash: [F; 8]) -> Vec { +// vectorized +fn xmss_sig_size_in_memory() -> usize { + 1 + V +} + +fn build_public_input(xmss_pub_keys: &[XmssPublicKey], message_hash: [F; 8], slot: u64) -> Vec { let mut public_input = message_hash.to_vec(); public_input.extend(xmss_pub_keys.iter().flat_map(|pk| pk.merkle_root)); public_input.extend(xmss_pub_keys.iter().map(|pk| F::from_usize(pk.log_lifetime))); - public_input.extend(F::zero_vec( - xmss_pub_keys.len().next_multiple_of(8) - xmss_pub_keys.len(), - )); + public_input.extend( + xmss_pub_keys + .iter() + .map(|pk| F::from_u64(slot.checked_sub(pk.first_slot).unwrap())), // index in merkle tree + ); let min_public_input_size = (1 << LOG_SMALLEST_DECOMPOSITION_CHUNK) - NONRESERVED_PROGRAM_INPUT_START; public_input.extend(F::zero_vec(min_public_input_size.saturating_sub(public_input.len()))); - let private_input_start = - F::from_usize((public_input.len() + 8 + NONRESERVED_PROGRAM_INPUT_START).next_power_of_two()); public_input.splice( 0..0, [ vec![ - private_input_start, F::from_usize(xmss_pub_keys.len()), - F::from_usize(XMSS_SIG_SIZE_VEC_PADDED), + F::from_usize(xmss_sig_size_in_memory()), ], - vec![F::ZERO; 5], + vec![F::ZERO; 6], ] .concat(), ); public_input } -fn build_private_input(all_signatures: &[XmssSignature], xmss_pub_keys: &[XmssPublicKey]) -> Vec { +fn build_private_input(all_signatures: &[XmssSignature]) -> (Vec, VecDeque>) { let mut private_input = vec![]; - for (signature, pubkey) in all_signatures.iter().zip(xmss_pub_keys) { + let mut merkle_path_hints = 
VecDeque::>::new(); + for signature in all_signatures { let initial_private_input_len = private_input.len(); private_input.extend(signature.wots_signature.randomness.to_vec()); private_input.extend( @@ -65,19 +69,13 @@ fn build_private_input(all_signatures: &[XmssSignature], xmss_pub_keys: &[XmssPu .iter() .flat_map(|digest| digest.to_vec()), ); - private_input.extend(signature.merkle_proof.iter().copied().flatten()); - let wots_index = signature.slot.checked_sub(pubkey.first_slot).unwrap(); - private_input.extend((0..pubkey.log_lifetime).map(|i| { - if (wots_index >> i).is_multiple_of(2) { - F::ONE - } else { - F::ZERO - } - })); + let sig_size = private_input.len() - initial_private_input_len; - private_input.extend(F::zero_vec(XMSS_SIG_SIZE_VEC_PADDED * VECTOR_LEN - sig_size)); + private_input.extend(F::zero_vec(xmss_sig_size_in_memory() * VECTOR_LEN - sig_size)); + + merkle_path_hints.push_back(signature.merkle_proof.clone()); } - private_input + (private_input, merkle_path_hints) } #[derive(Debug, Clone)] @@ -91,7 +89,7 @@ impl XmssAggregationProgram { pub fn compute_non_vec_memory(&self, log_lifetimes: &[usize]) -> usize { log_lifetimes .iter() - .map(|&ll| self.no_vec_mem_per_log_lifetime[ll - 1]) + .map(|&ll| self.no_vec_mem_per_log_lifetime[ll - XMSS_MIN_LOG_LIFETIME]) .sum::() + self.default_no_vec_mem } @@ -104,7 +102,7 @@ fn compile_xmss_aggregation_program() -> XmssAggregationProgram { let bytecode = compile_program(program_str); let default_no_vec_mem = exec_phony_xmss(&bytecode, &[]).no_vec_runtime_memory; let mut no_vec_mem_per_log_lifetime = vec![]; - for log_lifetime in 1..=XMSS_MAX_LOG_LIFETIME { + for log_lifetime in XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME { let no_vec_mem = exec_phony_xmss(&bytecode, &[log_lifetime]).no_vec_runtime_memory; no_vec_mem_per_log_lifetime.push(no_vec_mem.checked_sub(default_no_vec_mem).unwrap()); } @@ -120,7 +118,7 @@ fn compile_xmss_aggregation_program() -> XmssAggregationProgram { for _ in 0..n_sanity_checks 
{ let n_sigs = rng.random_range(1..=25); let log_lifetimes = (0..n_sigs) - .map(|_| rng.random_range(1..=XMSS_MAX_LOG_LIFETIME)) + .map(|_| rng.random_range(XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) .collect::>(); let result = exec_phony_xmss(&res.bytecode, &log_lifetimes); assert_eq!( @@ -136,32 +134,36 @@ fn compile_xmss_aggregation_program() -> XmssAggregationProgram { fn exec_phony_xmss(bytecode: &Bytecode, log_lifetimes: &[usize]) -> ExecutionResult { let mut rng = StdRng::seed_from_u64(0); let message_hash: [F; 8] = rng.random(); - let slot = 1 << 33; + let slot = 1111; let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(log_lifetimes, message_hash, slot); - let public_input = build_public_input(&xmss_pub_keys, message_hash); - let private_input = build_private_input(&all_signatures, &xmss_pub_keys); + let public_input = build_public_input(&xmss_pub_keys, message_hash, slot); + let (private_input, merkle_path_hints) = build_private_input(&all_signatures); execute_bytecode( bytecode, (&public_input, &private_input), 1 << 21, false, (&vec![], &vec![]), + merkle_path_hints, ) } -pub fn run_xmss_benchmark(log_lifetimes: &[usize]) { - utils::init_tracing(); +pub fn run_xmss_benchmark(log_lifetimes: &[usize], tracing: bool) { + if tracing { + utils::init_tracing(); + } xmss_setup_aggregation_program(); precompute_dft_twiddles::(1 << 24); let mut rng = StdRng::seed_from_u64(0); let message_hash: [F; 8] = rng.random(); - let slot = 1 << 33; + let slot = 1111; + let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(log_lifetimes, message_hash, slot); let time = Instant::now(); let (proof_data, n_field_elements_in_proof, summary) = - xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message_hash).unwrap(); + xmss_aggregate_signatures_helper(&xmss_pub_keys, &all_signatures, message_hash, slot).unwrap(); let proving_time = time.elapsed(); xmss_verify_aggregated_signatures(&xmss_pub_keys, message_hash, &proof_data, 
slot).unwrap(); @@ -187,14 +189,14 @@ pub fn xmss_aggregate_signatures( message_hash: [F; 8], slot: u64, ) -> Result, XmssAggregateError> { - let _ = slot; // TODO - Ok(xmss_aggregate_signatures_helper(xmss_pub_keys, all_signatures, message_hash)?.0) + Ok(xmss_aggregate_signatures_helper(xmss_pub_keys, all_signatures, message_hash, slot)?.0) } fn xmss_aggregate_signatures_helper( xmss_pub_keys: &[XmssPublicKey], all_signatures: &[XmssSignature], message_hash: [F; 8], + slot: u64, ) -> Result<(Vec, usize, String), XmssAggregateError> { if xmss_pub_keys.len() != all_signatures.len() { return Err(XmssAggregateError::WrongSignatureCount); @@ -206,28 +208,22 @@ fn xmss_aggregate_signatures_helper( precompute_poseidons(xmss_pub_keys, all_signatures, &message_hash) .ok_or(XmssAggregateError::InvalidSigature)?; - let public_input = build_public_input(xmss_pub_keys, message_hash); - let private_input = build_private_input(all_signatures, xmss_pub_keys); + let public_input = build_public_input(xmss_pub_keys, message_hash, slot); + let (private_input, merkle_path_hints) = build_private_input(all_signatures); - let (proof_field_elements, summary) = prove_execution( + let (proof, summary) = prove_execution( &program.bytecode, (&public_input, &private_input), whir_config_builder(), program.compute_non_vec_memory(&xmss_pub_keys.iter().map(|pk| pk.log_lifetime).collect::>()), false, (&poseidons_16_precomputed, &poseidons_24_precomputed), + merkle_path_hints, ); - let proof_bytes = info_span!("Proof serialization").in_scope(|| { - let mut buff = unsafe { uninitialized_vec(proof_field_elements.len() * 4) }; - buff.par_chunks_exact_mut(4).enumerate().for_each(|(i, chunk)| { - let fe = proof_field_elements[i]; - chunk.copy_from_slice(&fe.as_canonical_u32().to_be_bytes()); - }); - buff - }); + let proof_bytes = info_span!("Proof serialization").in_scope(|| bincode::serialize(&proof).unwrap()); - Ok((proof_bytes, proof_field_elements.len(), summary)) + Ok((proof_bytes, proof.proof_size, 
summary)) } pub fn xmss_verify_aggregated_signatures( @@ -239,25 +235,13 @@ pub fn xmss_verify_aggregated_signatures( let _ = slot; // TODO let program = get_xmss_aggregation_program(); - let proof_field_elements = info_span!("Proof deserialization").in_scope(|| { - proof_bytes - .par_chunks_exact(4) - .map(|chunk| { - let mut arr = [0u8; 4]; - arr.copy_from_slice(chunk); - F::from_u32(u32::from_be_bytes(arr)) - }) - .collect::>() - }); + let proof = info_span!("Proof deserialization") + .in_scope(|| bincode::deserialize(proof_bytes)) + .map_err(|_| ProofError::InvalidProof)?; - let public_input = build_public_input(xmss_pub_keys, message_hash); + let public_input = build_public_input(xmss_pub_keys, message_hash, slot); - verify_execution( - &program.bytecode, - &public_input, - proof_field_elements, - whir_config_builder(), - ) + verify_execution(&program.bytecode, &public_input, proof, whir_config_builder()) } #[instrument(skip_all)] @@ -290,7 +274,7 @@ fn test_xmss_aggregate() { let n_xmss = 10; let mut rng = StdRng::seed_from_u64(0); let log_lifetimes = (0..n_xmss) - .map(|_| rng.random_range(1..=XMSS_MAX_LOG_LIFETIME)) + .map(|_| rng.random_range(XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME)) .collect::>(); - run_xmss_benchmark(&log_lifetimes); + run_xmss_benchmark(&log_lifetimes, false); } diff --git a/crates/rec_aggregation/xmss_aggregate.lean_lang b/crates/rec_aggregation/xmss_aggregate.lean_lang index ab8c8841..70becb9d 100644 --- a/crates/rec_aggregation/xmss_aggregate.lean_lang +++ b/crates/rec_aggregation/xmss_aggregate.lean_lang @@ -1,6 +1,3 @@ -// Public input: message_hash | all_public_keys | bitield -// Private input: signatures = (randomness | chain_tips | merkle_path) - const COMPRESSION = 1; const PERMUTATION = 0; @@ -8,40 +5,50 @@ const V = 66; const W = 4; const TARGET_SUM = 118; +const V_HALF = V / 2; // V should be even + fn main() { public_input_start_ = public_input_start; - private_input_start = public_input_start_[0]; - n_signatures = 
public_input_start_[1]; - sig_size = public_input_start_[2]; // vectorized + n_signatures = public_input_start_[0]; + sig_size = public_input_start_[1]; // vectorized message_hash = public_input_start / 8 + 1; all_public_keys = message_hash + 1; all_log_lifetimes = (all_public_keys + n_signatures) * 8; - signatures_start = private_input_start / 8; + all_merkle_indexes = all_log_lifetimes + n_signatures; + + signatures_start_no_vec = private_input_start(); + signatures_start = signatures_start_no_vec / 8; for i in 0..n_signatures { xmss_public_key = all_public_keys + i; signature = signatures_start + i * sig_size; log_lifetime = all_log_lifetimes[i]; - xmss_public_key_recovered = xmss_recover_pub_key(message_hash, signature, log_lifetime); + merkle_index = all_merkle_indexes[i]; + xmss_public_key_recovered = xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index); assert_eq_vec(xmss_public_key, xmss_public_key_recovered); } return; } -fn xmss_recover_pub_key(message_hash, signature, log_lifetime) -> 1 { +fn xmss_recover_pub_key(message_hash, signature, log_lifetime, merkle_index) -> 1 { // message_hash: vectorized pointers (of length 1) - // signature: vectorized pointer = randomness | chain_tips | merkle_neighbours | merkle_are_left + // signature: vectorized pointer = randomness | chain_tips // return a vectorized pointer (of length 1), the hashed xmss public key randomness = signature; // vectorized chain_tips = signature + 1; // vectorized - merkle_neighbours = chain_tips + V; // vectorized - merkle_are_left = (merkle_neighbours + log_lifetime) * 8; // non-vectorized // 1) We encode message_hash + randomness into the d-th layer of the hypercube compressed = malloc_vec(1); poseidon16(message_hash, randomness, compressed, COMPRESSION); compressed_ptr = compressed * 8; - decomposed = decompose_custom(compressed_ptr[0], compressed_ptr[1], compressed_ptr[2], compressed_ptr[3], compressed_ptr[4], compressed_ptr[5]); + compressed_vals = malloc(6); + 
dot_product_ee(compressed_ptr, pointer_to_one_vector * 8, compressed_vals, 1); + compressed_vals[5] = compressed_ptr[5]; + + encoding = malloc(12 * 6); + remaining = malloc(6); + + decompose_custom(encoding, remaining, compressed_vals[0], compressed_vals[1], compressed_vals[2], compressed_vals[3], compressed_vals[4], compressed_vals[5]); // check that the decomposition is correct for i in 0..6 unroll { @@ -49,14 +56,14 @@ fn xmss_recover_pub_key(message_hash, signature, log_lifetime) -> 1 { // TODO Implem range check (https://github.com/leanEthereum/leanMultisig/issues/52) // For now we use dummy instructions to replicate exactly the cost - // assert decomposed[i * 13 + j] < 4; + // assert encoding[i * 12 + j] < 4; dummy_0 = 88888888; assert dummy_0 == 88888888; assert dummy_0 == 88888888; assert dummy_0 == 88888888; } - // assert decomposed[i * 13 + 12] < 2^7 - 1; + // assert remaining[i] < 2^7 - 1; dummy_1 = 88888888; dummy_2 = 88888888; dummy_3 = 88888888; @@ -64,21 +71,14 @@ fn xmss_recover_pub_key(message_hash, signature, log_lifetime) -> 1 { assert dummy_2 == 88888888; assert dummy_3 == 88888888; - partial_sums = malloc(12); - partial_sums[0] = decomposed[i * 13]; - for j in 1..12 unroll { - partial_sums[j] = partial_sums[j - 1] + (decomposed[i * 13 + j]) * 4**j; + partial_sums = malloc(13); + partial_sums[0] = remaining[i] * 2**24; + for j in 1..13 unroll { + partial_sums[j] = partial_sums[j - 1] + encoding[i * 12 + (j-1)] * 4**(j-1); } - assert partial_sums[11] + (decomposed[i * 13 + 12]) * 4**12 == compressed_ptr[i]; + assert partial_sums[12] == compressed_vals[i]; } - encoding = malloc(12 * 6); - for i in 0..6 unroll { - for j in 0..12 unroll { - encoding[i * 12 + j] = decomposed[i * 13 + j]; - } - } - // we need to check the target sum sums = malloc(V); sums[0] = encoding[0]; @@ -126,70 +126,48 @@ fn xmss_recover_pub_key(message_hash, signature, log_lifetime) -> 1 { } } - public_key_hashed = malloc_vec(V / 2); - poseidon24(public_key, 
pointer_to_zero_vector, public_key_hashed); - - for i in 1..V / 2 unroll { - poseidon24(public_key + (2*i), public_key_hashed + (i - 1), public_key_hashed + i); - } - - wots_pubkey_hashed = public_key_hashed + (V / 2 - 1); + wots_pubkey_hashed = malloc_vec(1); + slice_hash(pointer_to_zero_vector, public_key, wots_pubkey_hashed, V_HALF); - // TODO unroll + merkle_root = malloc_vec(1); match log_lifetime { - 0 => { merkle_hash = verify_merkle_path(0, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 1 => { merkle_hash = verify_merkle_path(1, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 2 => { merkle_hash = verify_merkle_path(2, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 3 => { merkle_hash = verify_merkle_path(3, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 4 => { merkle_hash = verify_merkle_path(4, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 5 => { merkle_hash = verify_merkle_path(5, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 6 => { merkle_hash = verify_merkle_path(6, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 7 => { merkle_hash = verify_merkle_path(7, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 8 => { merkle_hash = verify_merkle_path(8, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 9 => { merkle_hash = verify_merkle_path(9, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 10 => { merkle_hash = verify_merkle_path(10, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 11 => { merkle_hash = verify_merkle_path(11, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 12 => { merkle_hash = verify_merkle_path(12, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 13 => { merkle_hash = verify_merkle_path(13, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 14 => { merkle_hash = verify_merkle_path(14, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 15 => { merkle_hash 
= verify_merkle_path(15, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 16 => { merkle_hash = verify_merkle_path(16, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 17 => { merkle_hash = verify_merkle_path(17, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 18 => { merkle_hash = verify_merkle_path(18, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 19 => { merkle_hash = verify_merkle_path(19, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 20 => { merkle_hash = verify_merkle_path(20, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 21 => { merkle_hash = verify_merkle_path(21, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 22 => { merkle_hash = verify_merkle_path(22, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 23 => { merkle_hash = verify_merkle_path(23, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 24 => { merkle_hash = verify_merkle_path(24, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 25 => { merkle_hash = verify_merkle_path(25, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 26 => { merkle_hash = verify_merkle_path(26, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 27 => { merkle_hash = verify_merkle_path(27, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 28 => { merkle_hash = verify_merkle_path(28, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 29 => { merkle_hash = verify_merkle_path(29, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 30 => { merkle_hash = verify_merkle_path(30, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 31 => { merkle_hash = verify_merkle_path(31, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - 32 => { merkle_hash = verify_merkle_path(32, merkle_are_left, wots_pubkey_hashed, merkle_neighbours); } - } - - return merkle_hash; -} + 0 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 0); } + 1 => 
{ merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 1); } + 2 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 2); } + 3 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 3); } + 4 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 4); } + 5 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 5); } + 6 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 6); } + 7 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 7); } + 8 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 8); } + 9 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 9); } + 10 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 10); } + 11 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 11); } + 12 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 12); } + 13 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 13); } + 14 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 14); } + 15 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 15); } + 16 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 16); } + 17 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 17); } + 18 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 18); } + 19 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 19); } + 20 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 20); } + 21 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 21); } + 22 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 22); } + 23 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 23); } + 24 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 24); } + 25 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 25); } + 26 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 
26); } + 27 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 27); } + 28 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 28); } + 29 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 29); } + 30 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 30); } + 31 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 31); } + 32 => { merkle_verify(wots_pubkey_hashed, merkle_index, merkle_root, 32); } + } + -fn verify_merkle_path(const height, merkle_are_left, leaf_hash, merkle_neighbours) -> 1 { - merkle_hashes = malloc_vec(height); - if merkle_are_left[0] == 1 { - poseidon16(leaf_hash, merkle_neighbours, merkle_hashes, COMPRESSION); - } else { - poseidon16(merkle_neighbours, leaf_hash, merkle_hashes, COMPRESSION); - } - for h in 1..height unroll { - if merkle_are_left[h] == 1 { - poseidon16(merkle_hashes + (h-1), merkle_neighbours + h, merkle_hashes + h, COMPRESSION); - } else { - poseidon16(merkle_neighbours + h, merkle_hashes + (h-1), merkle_hashes + h, COMPRESSION); - } - } - return merkle_hashes + (height - 1); + return merkle_root; } fn assert_eq_vec(x, y) inline { diff --git a/crates/sub_protocols/src/commit_extension_from_base.rs b/crates/sub_protocols/src/commit_extension_from_base.rs index 66aa6c40..cc677fc8 100644 --- a/crates/sub_protocols/src/commit_extension_from_base.rs +++ b/crates/sub_protocols/src/commit_extension_from_base.rs @@ -66,7 +66,7 @@ impl ExtensionCommitmentFromBaseVerifier { let mut statements_remaning_to_verify = Vec::new(); for (chunk, claim_value) in sub_evals.chunks_exact(EF::DIMENSION).zip(&claim.values) { - if dot_product_with_base(&sub_evals) != *claim_value { + if dot_product_with_base(chunk) != *claim_value { return Err(ProofError::InvalidProof); } statements_remaning_to_verify.extend( diff --git a/crates/sub_protocols/src/packed_pcs.rs b/crates/sub_protocols/src/packed_pcs.rs index 85ee8626..1362e6f3 100644 --- a/crates/sub_protocols/src/packed_pcs.rs +++ 
b/crates/sub_protocols/src/packed_pcs.rs @@ -670,7 +670,7 @@ mod tests { &witness.packed_polynomial.by_ref(), ); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let parsed_commitment = packed_pcs_parse_commitment( &whir_config_builder, diff --git a/crates/sub_protocols/tests/test_generic_packed_lookup.rs b/crates/sub_protocols/tests/test_generic_packed_lookup.rs index 2cc2e07e..2c4e166f 100644 --- a/crates/sub_protocols/tests/test_generic_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_generic_packed_lookup.rs @@ -73,7 +73,7 @@ fn test_generic_packed_lookup() { let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let packed_lookup_verifier = GenericPackedLookupVerifier::step_1( &mut verifier_state, diff --git a/crates/sub_protocols/tests/test_normal_packed_lookup.rs b/crates/sub_protocols/tests/test_normal_packed_lookup.rs index 89002f1f..458509a1 100644 --- a/crates/sub_protocols/tests/test_normal_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_normal_packed_lookup.rs @@ -106,7 +106,7 @@ fn test_normal_packed_lookup() { let remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let packed_lookup_verifier = NormalPackedLookupVerifier::step_1( &mut verifier_state, diff --git a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs index cb7b08bd..8721962e 100644 --- a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs +++ b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs @@ -79,7 +79,7 @@ fn test_vectorized_packed_lookup() { let 
remaining_claims_to_prove = packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); - let mut verifier_state = build_verifier_state(&prover_state); + let mut verifier_state = build_verifier_state(prover_state); let packed_lookup_verifier = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( &mut verifier_state, diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index 9254bb74..480b7d4c 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -67,10 +67,6 @@ macro_rules! assert_eq_many { }; } -pub fn powers_const(base: F) -> [F; N] { - base.powers().collect_n(N).try_into().unwrap() -} - #[instrument(skip_all)] pub fn transpose(matrix: &[F], width: usize, column_extra_capacity: usize) -> Vec> { assert!((matrix.len().is_multiple_of(width))); diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index 6e0ca371..29030d98 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -19,13 +19,9 @@ pub fn build_prover_state>(padding: bool) -> Prove } pub fn build_verifier_state>( - prover_state: &ProverState, + prover_state: ProverState, ) -> VerifierState { - VerifierState::new( - prover_state.proof_data().to_vec(), - build_challenger(), - prover_state.has_padding(), - ) + VerifierState::new(prover_state.into_proof(), build_challenger()) } pub trait ToUsize { diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs index 9708b37a..3c2e8ee2 100644 --- a/crates/xmss/src/lib.rs +++ b/crates/xmss/src/lib.rs @@ -18,6 +18,7 @@ pub const W: usize = 4; pub const CHAIN_LENGTH: usize = 1 << W; pub const D: usize = 80; pub const TARGET_SUM: usize = V * (W - 1) - D; +pub const XMSS_MIN_LOG_LIFETIME: usize = 2; pub const XMSS_MAX_LOG_LIFETIME: usize = 30; pub type Poseidon16History = Vec<([F; 16], [F; 16])>; diff --git a/crates/xmss/src/phony_xmss.rs b/crates/xmss/src/phony_xmss.rs index 923f3df2..e2f1dacc 100644 --- a/crates/xmss/src/phony_xmss.rs +++ b/crates/xmss/src/phony_xmss.rs @@ -66,6 
+66,8 @@ pub fn xmss_generate_phony_signatures( .par_iter() .enumerate() .map(|(i, &log_lifetime)| { + assert!(log_lifetime >= XMSS_MIN_LOG_LIFETIME); + assert!(log_lifetime <= XMSS_MAX_LOG_LIFETIME); let mut rng = StdRng::seed_from_u64(i as u64); let first_slot = slot - rng.random_range(0..(1 << log_lifetime).min(slot)); let xmss_secret_key = PhonyXmssSecretKey::random(&mut rng, first_slot, log_lifetime, slot); diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs index 2f46de15..927513a0 100644 --- a/crates/xmss/src/xmss.rs +++ b/crates/xmss/src/xmss.rs @@ -48,7 +48,7 @@ pub fn xmss_key_gen( if first_slot >= (1 << XMSS_MAX_LOG_LIFETIME) { return Err(XmssKeyGenError::FirstSlotTooLarge); } - if log_lifetime == 0 { + if log_lifetime < XMSS_MIN_LOG_LIFETIME { return Err(XmssKeyGenError::LogLifetimeTooSmall); } if log_lifetime > XMSS_MAX_LOG_LIFETIME { diff --git a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg b/docs/benchmark_graphs/graphs/recursive_whir_opening.svg index 46021aea..b6f1deba 100644 --- a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg +++ b/docs/benchmark_graphs/graphs/recursive_whir_opening.svg @@ -1,12 +1,12 @@ - + - 2025-11-15T22:38:00.887118 + 2025-12-04T21:10:16.841222 image/svg+xml @@ -21,18 +21,18 @@ - - @@ -40,23 +40,23 @@ z - +" clip-path="url(#p9e6f5d0f00)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 144.865637 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 277.365343 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 409.865049 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 454.031618 34.3575 - + - + - + @@ -525,18 +525,18 @@ L 498.198186 34.3575 - + - + - + @@ -547,18 +547,18 @@ L 542.364755 34.3575 - + - + - + @@ -569,18 +569,18 @@ L 586.531324 34.3575 - + - + - + - + - + - + @@ -638,27 +638,115 @@ L 674.864461 34.3575 + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - - + - + - + - - + + - + - + - + - + @@ -697,24 +785,24 @@ L 700.1025 142.22227 - - + + - + - - + - + - + @@ -722,19 +810,19 @@ L -2 0 - - + + - + - + - + - + @@ -742,19 +830,19 @@ L 700.1025 292.131693 - - + + - + - + - + - + @@ -762,19 +850,19 @@ L 700.1025 254.365067 - - + + - + - + - + - + @@ -782,19 +870,19 @@ L 700.1025 227.569219 - - + + - + - + - + - + @@ -802,19 +890,19 @@ L 700.1025 206.784744 - - + + - + - + - + - + @@ -822,19 +910,19 @@ L 700.1025 189.802592 - - + + - + - + - + - + @@ -842,19 +930,19 @@ L 700.1025 175.444387 - - + + - + - + - + - + @@ -862,19 +950,19 @@ L 700.1025 163.006744 - - + + - + - + - + - + @@ -882,19 +970,19 @@ L 700.1025 152.035966 - - + + - + - + - + - + @@ -902,28 +990,28 @@ L 700.1025 77.659796 - - + + - + - + - + - + - + - + - + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - + + - - - - - + + + + + - - + + + + + + + + + + + - - - - + - + - - - + - + - + - + - - + - + - + - + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + + diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated.svg b/docs/benchmark_graphs/graphs/xmss_aggregated.svg index c7580dd0..39c78623 100644 --- a/docs/benchmark_graphs/graphs/xmss_aggregated.svg +++ b/docs/benchmark_graphs/graphs/xmss_aggregated.svg @@ -1,12 +1,12 @@ - + - 2025-11-15T22:38:01.014383 + 2025-12-04T21:10:16.906816 image/svg+xml @@ -21,18 +21,18 @@ - - @@ -40,23 +40,23 @@ z - +" clip-path="url(#p74f34e2e3d)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 134.882157 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 270.095392 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 405.308627 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 450.379706 
34.3575 - + - + - + @@ -525,18 +525,18 @@ L 495.450784 34.3575 - + - + - + @@ -547,18 +547,18 @@ L 540.521863 34.3575 - + - + - + @@ -569,18 +569,18 @@ L 585.592941 34.3575 - + - + - + - + - + - + @@ -638,45 +638,133 @@ L 675.735098 34.3575 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - - + - + - + - - + + - + - + - + - + @@ -684,19 +772,19 @@ L 701.49 310.646072 - - + + - + - + - + - + @@ -704,19 +792,19 @@ L 701.49 249.248612 - - + + - + - + - + - + @@ -724,19 +812,19 @@ L 701.49 187.851151 - - + + - + - + - + - + @@ -744,19 +832,19 @@ L 701.49 126.453691 - - + + - + - + - + - + @@ -765,26 +853,26 @@ L 701.49 65.05623 - - + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - + + - - - - - - + + + + + + - - + + + + + + + + + + + - - - - + - + - - - + - + - + - + - - + - + - + - + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + + diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg b/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg index fb5888b6..8407b476 100644 --- a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg +++ b/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg @@ -1,12 +1,12 @@ - + - 2025-11-15T22:38:01.125000 + 2025-12-04T21:10:16.970754 image/svg+xml @@ -21,18 +21,18 @@ - - @@ -40,23 +40,23 @@ z - +" clip-path="url(#pab2b291bd7)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 131.567377 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 267.540319 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 403.51326 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 448.837574 34.3575 - + - + - + @@ -525,18 +525,18 @@ L 494.161887 34.3575 - + - + - + @@ -547,18 +547,18 @@ 
L 539.486201 34.3575 - + - + - + @@ -569,18 +569,18 @@ L 584.810515 34.3575 - + - + - + - + - + - + @@ -638,27 +638,115 @@ L 675.459142 34.3575 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - - + - + - + - - + + - + - + - + - + @@ -695,19 +783,19 @@ L 701.35875 325.426942 - - + + - + - + - + - + @@ -715,19 +803,19 @@ L 701.35875 278.810352 - - + + - + - + - + - + @@ -735,19 +823,19 @@ L 701.35875 232.193762 - - + + - + - + - + - + @@ -756,19 +844,19 @@ L 701.35875 185.577171 - - + + - + - + - + - + @@ -777,19 +865,19 @@ L 701.35875 138.960581 - - + + - + - + - + - + @@ -798,19 +886,19 @@ L 701.35875 92.34399 - - + + - + - + - + - + @@ -819,27 +907,27 @@ L 701.35875 45.7274 - - + + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + - - + + - - - - - - - + + + + + + + + + + + + + + + + - - + + - - - - + - + - - - + - + - + - + - - + - + - + - + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + + diff --git a/docs/benchmark_graphs/main.py b/docs/benchmark_graphs/main.py index 9cc81efe..e13aa390 100644 --- a/docs/benchmark_graphs/main.py +++ b/docs/benchmark_graphs/main.py @@ -5,7 +5,7 @@ # uv run python main.py -N_DAYS_SHOWN = 100 +N_DAYS_SHOWN = 130 plt.rcParams.update({ 'font.size': 12, # Base font size @@ -17,48 +17,44 @@ }) -def create_duration_graph(data, target=None, target_label=None, title="", y_legend="", file="", label1="Series 1", label2=None, log_scale=False): - dates = [] - values1 = [] - values2 = [] +def create_duration_graph(data, target=None, target_label=None, title="", y_legend="", file="", labels=None, log_scale=False): + if labels is None: + labels = ["Series 1"] + + # Number of curves based on tuple length + num_curves = len(data[0]) - 1 if data else 1 # -1 for the date - # Check if data contains triplets or pairs - has_second_curve = 
len(data[0]) == 3 if data else False + dates = [] + values = [[] for _ in range(num_curves)] for item in data: - if has_second_curve: - day, perf1, perf2 = item - dates.append(datetime.strptime(day, '%Y-%m-%d')) - values1.append(perf1) - values2.append(perf2) - else: - day, perf1 = item - dates.append(datetime.strptime(day, '%Y-%m-%d')) - values1.append(perf1) - - color = '#2E86AB' - color2 = '#A23B72' # Different color for second curve + dates.append(datetime.strptime(item[0], '%Y-%m-%d')) + for i in range(num_curves): + values[i].append(item[i + 1]) + + colors = ['#2E86AB', '#A23B72', '#28A745', '#FF6F00', '#6A1B9A'] + markers = ['o', 's', '^', 'D', 'v'] _, ax = plt.subplots(figsize=(10, 6)) - # Filter out None values for first curve - dates1_filtered = [d for d, v in zip(dates, values1) if v is not None] - values1_filtered = [v for v in values1 if v is not None] + all_values = [] - ax.plot(dates1_filtered, values1_filtered, marker='o', linewidth=2, - markersize=7, color=color, label=label1) + for i in range(num_curves): + if i >= len(labels): + break # No label provided for this curve - # Plot second curve if it exists - if has_second_curve and label2 is not None: - # Filter out None values for second curve - dates2_filtered = [d for d, v in zip(dates, values2) if v is not None] - values2_filtered = [v for v in values2 if v is not None] + # Filter out None values + dates_filtered = [d for d, v in zip(dates, values[i]) if v is not None] + values_filtered = [v for v in values[i] if v is not None] - ax.plot(dates2_filtered, values2_filtered, marker='s', linewidth=2, - markersize=7, color=color2, label=label2) - all_values = values1_filtered + values2_filtered - else: - all_values = values1_filtered + if values_filtered: # Only plot if there's data + ax.plot(dates_filtered, values_filtered, + marker=markers[i % len(markers)], + linewidth=2, + markersize=7, + color=colors[i % len(colors)], + label=labels[i]) + all_values.extend(values_filtered) min_date = min(dates) 
max_date = max(dates) @@ -73,32 +69,27 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege plt.setp(ax.xaxis.get_majorticklabels(), rotation=50, ha='right') if target is not None and target_label is not None: - ax.axhline(y=target, color=color, linestyle='--', + ax.axhline(y=target, color='#555555', linestyle='--', linewidth=2, label=target_label) ax.set_ylabel(y_legend, fontsize=12) ax.set_title(title, fontsize=16, pad=15) - ax.grid(True, alpha=0.3, which='both') # Grid for both major and minor ticks + ax.grid(True, alpha=0.3, which='both') ax.legend() - # Set log scale if requested if log_scale: ax.set_yscale('log') - # Set locators for major and minor ticks ax.yaxis.set_major_locator(LogLocator(base=10.0, numticks=15)) ax.yaxis.set_minor_locator(LogLocator(base=10.0, subs=range(1, 10), numticks=100)) - # Format labels ax.yaxis.set_major_formatter(ScalarFormatter()) ax.yaxis.set_minor_formatter(ScalarFormatter()) ax.yaxis.get_major_formatter().set_scientific(False) ax.yaxis.get_minor_formatter().set_scientific(False) - # Show minor tick labels ax.tick_params(axis='y', which='minor', labelsize=10) else: - # Adjust y-limit to accommodate both curves (only for linear scale) if all_values: max_value = max(all_values) if target is not None: @@ -113,7 +104,7 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege create_duration_graph( data=[ - ('2025-08-27', 85000, None), + ('2025-08-27', 85000, None,), ('2025-08-30', 95000, None), ('2025-09-09', 108000, None), ('2025-09-14', 108000, None), @@ -130,90 +121,89 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege title="Raw Poseidon2", y_legend="Poseidons proven / s", file="raw_poseidons", - label1="i9-12900H", - label2="mac m4 max", + labels=["i9-12900H", "mac m4 max"], log_scale=False ) create_duration_graph( data=[ - ('2025-08-27', 2.7, None), - ('2025-09-07', 1.4, None), - ('2025-09-09', 1.32, None), - ('2025-09-10', 0.970, 
None), - ('2025-09-14', 0.825, None), - ('2025-09-28', 0.725, None), - ('2025-10-01', 0.685, None), - ('2025-10-03', 0.647, None), - ('2025-10-12', 0.569, None), - ('2025-10-13', 0.521, None), - ('2025-10-18', 0.411, 0.320), - ('2025-10-27', 0.425, 0.330), - ('2025-11-15', 0.417, 0.330), + ('2025-08-27', 2.7, None, None), + ('2025-09-07', 1.4, None, None), + ('2025-09-09', 1.32, None, None), + ('2025-09-10', 0.970, None, None), + ('2025-09-14', 0.825, None, None), + ('2025-09-28', 0.725, None, None), + ('2025-10-01', 0.685, None, None), + ('2025-10-03', 0.647, None, None), + ('2025-10-12', 0.569, None, None), + ('2025-10-13', 0.521, None, None), + ('2025-10-18', 0.411, 0.320, None), + ('2025-10-27', 0.425, 0.330, None), + ('2025-11-15', 0.417, 0.330, None), + ('2025-12-04', None, 0.097, 0.130), ], target=0.1, target_label="Target (0.1 s)", title="Recursive WHIR opening (log scale)", y_legend="Proving time (s)", file="recursive_whir_opening", - label1="i9-12900H", - label2="mac m4 max", + labels=["i9-12900H", "mac m4 max", "mac m4 max | lean-vm-simple"], log_scale=True ) create_duration_graph( data=[ - ('2025-08-27', 35, None), - ('2025-09-02', 37, None), - ('2025-09-03', 53, None), - ('2025-09-09', 62, None), - ('2025-09-10', 76, None), - ('2025-09-14', 107, None), - ('2025-09-28', 137, None), - ('2025-10-01', 172, None), - ('2025-10-03', 177, None), - ('2025-10-07', 193, None), - ('2025-10-12', 214, None), - ('2025-10-13', 234, None), - ('2025-10-18', 255, 465), - ('2025-10-27', 314, 555), - ('2025-11-02', 350, 660), - ('2025-11-15', 380, 720), + ('2025-08-27', 35, None, None), + ('2025-09-02', 37, None, None), + ('2025-09-03', 53, None, None), + ('2025-09-09', 62, None, None), + ('2025-09-10', 76, None, None), + ('2025-09-14', 107, None, None), + ('2025-09-28', 137, None, None), + ('2025-10-01', 172, None, None), + ('2025-10-03', 177, None, None), + ('2025-10-07', 193, None, None), + ('2025-10-12', 214, None, None), + ('2025-10-13', 234, None, None), + 
('2025-10-18', 255, 465, None), + ('2025-10-27', 314, 555, None), + ('2025-11-02', 350, 660, None), + ('2025-11-15', 380, 720, None), + ('2025-12-04', None, 940, 755), ], target=1000, target_label="Target (1000 XMSS/s)", title="number of XMSS aggregated / s", y_legend="", file="xmss_aggregated", - label1="i9-12900H", - label2="mac m4 max" + labels=["i9-12900H", "mac m4 max", "mac m4 max | lean-vm-simple"] ) create_duration_graph( data=[ - ('2025-08-27', 14.2 / 0.92, None), - ('2025-09-02', 13.5 / 0.82, None), - ('2025-09-03', 9.4 / 0.82, None), - ('2025-09-09', 8.02 / 0.72, None), - ('2025-09-10', 6.53 / 0.72, None), - ('2025-09-14', 4.65 / 0.72, None), - ('2025-09-28', 3.63 / 0.63, None), - ('2025-10-01', 2.9 / 0.42, None), - ('2025-10-03', 2.81 / 0.42, None), - ('2025-10-07', 2.59 / 0.42, None), - ('2025-10-12', 2.33 / 0.40, None), - ('2025-10-13', 2.13 / 0.38, None), - ('2025-10-18', 1.96 / 0.37, 1.07 / 0.12), - ('2025-10-27', (610_000 / 157) / 314, (1_250_000 / 157) / 555), - ('2025-10-29', (650_000 / 157) / 314, (1_300_000 / 157) / 555), - ('2025-11-02', (650_000 / 157) / 350, (1_300_000 / 157) / 660), - ('2025-11-15', (650_000 / 157) / 380, (1_300_000 / 157) / 720), + ('2025-08-27', 14.2 / 0.92, None, None), + ('2025-09-02', 13.5 / 0.82, None, None), + ('2025-09-03', 9.4 / 0.82, None, None), + ('2025-09-09', 8.02 / 0.72, None, None), + ('2025-09-10', 6.53 / 0.72, None, None), + ('2025-09-14', 4.65 / 0.72, None, None), + ('2025-09-28', 3.63 / 0.63, None, None), + ('2025-10-01', 2.9 / 0.42, None, None), + ('2025-10-03', 2.81 / 0.42, None, None), + ('2025-10-07', 2.59 / 0.42, None, None), + ('2025-10-12', 2.33 / 0.40, None, None), + ('2025-10-13', 2.13 / 0.38, None, None), + ('2025-10-18', 1.96 / 0.37, 1.07 / 0.12, None), + ('2025-10-27', (610_000 / 157) / 314, (1_250_000 / 157) / 555, None), + ('2025-10-29', (650_000 / 157) / 314, (1_300_000 / 157) / 555, None), + ('2025-11-02', (650_000 / 157) / 350, (1_300_000 / 157) / 660, None), + ('2025-11-15', (650_000 / 
157) / 380, (1_300_000 / 157) / 720, None), + ('2025-12-04', None, (1_300_000 / 157) / 940, (1_300_000 / 157) / 755), ], target=2, target_label="Target (2x)", title="XMSS aggregated: zkVM overhead vs raw Poseidons", y_legend="", file="xmss_aggregated_overhead", - label1="i9-12900H", - label2="mac m4 max" + labels=["i9-12900H", "mac m4 max", "mac m4 max | lean-vm-simple"] ) \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 71fad9af..3dce06aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub use multilinear_toolkit::prelude::{ pub use rec_aggregation::xmss_aggregate::{xmss_aggregate_signatures, xmss_verify_aggregated_signatures}; pub use xmss::{ XMSS_MAX_LOG_LIFETIME, + XMSS_MIN_LOG_LIFETIME, XmssPublicKey, XmssSecretKey, xmss_generate_phony_signatures, // useful for tests @@ -59,9 +60,9 @@ mod tests { // (Actually, no need to call it if `xmss_aggregation_setup_prover` was already called) xmss_aggregation_setup_verifier(); - let log_lifetimes = (10..=XMSS_MAX_LOG_LIFETIME).collect::>(); + let log_lifetimes = (XMSS_MIN_LOG_LIFETIME..=XMSS_MAX_LOG_LIFETIME).collect::>(); let message_hash: [F; 8] = std::array::from_fn(|i| F::from_usize(i * 7)); - let slot = 1 << 33; + let slot = 77777; let (xmss_pub_keys, all_signatures) = xmss_generate_phony_signatures(&log_lifetimes, message_hash, slot); diff --git a/src/main.rs b/src/main.rs index 6d284f5e..7aed3d56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,13 +9,22 @@ enum Cli { Xmss { #[arg(long)] n_signatures: usize, + #[arg(long, help = "Enable tracing")] + tracing: bool, }, #[command(about = "Run 1 WHIR recursive proof")] - Recursion, + Recursion { + #[arg(long, help = "Enable tracing")] + tracing: bool, + #[arg(long, default_value_t = 1, help = "Number of recursions")] + count: usize, + }, #[command(about = "Prove validity of Poseidon2 permutations over 16 field elements")] Poseidon { #[arg(long, help = "log2(number of Poseidons)")] log_n_perms: usize, + #[arg(long, help = "Enable 
tracing")] + tracing: bool, }, } @@ -23,15 +32,18 @@ fn main() { let cli = Cli::parse(); match cli { - Cli::Xmss { n_signatures } => { + Cli::Xmss { n_signatures, tracing } => { let log_lifetimes = (0..n_signatures).map(|_| XMSS_MAX_LOG_LIFETIME).collect::>(); - run_xmss_benchmark(&log_lifetimes); + run_xmss_benchmark(&log_lifetimes, tracing); } - Cli::Recursion => { - run_whir_recursion_benchmark(); + Cli::Recursion { tracing, count } => { + run_whir_recursion_benchmark(tracing, count); } - Cli::Poseidon { log_n_perms: log_count } => { - run_poseidon_benchmark::<16, 16, 3>(log_count, false); + Cli::Poseidon { + log_n_perms: log_count, + tracing, + } => { + run_poseidon_benchmark::<16, 16, 3>(log_count, false, tracing); } } }