diff --git a/Cargo.lock b/Cargo.lock index f734ff0d..769fc395 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,9 +17,7 @@ version = "0.1.0" dependencies = [ "multilinear-toolkit", "p3-air", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-uni-stark", "p3-util", "rand", @@ -68,22 +66,22 @@ dependencies = [ [[package]] name = "anstyle-query" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.10" +version = "3.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -95,7 +93,7 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backend" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#43c8955c2ca07c1c64962f190aab38c28e53adbb" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git?branch=modular-precompiles#8bf0cfebcddd4292a0be65d795c43da680d70336" dependencies = [ "fiat-shamir", "itertools", @@ -179,12 +177,11 @@ dependencies = [ [[package]] name = "constraints-folder" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#43c8955c2ca07c1c64962f190aab38c28e53adbb" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git?branch=modular-precompiles#8bf0cfebcddd4292a0be65d795c43da680d70336" dependencies = [ "fiat-shamir", "p3-air", "p3-field", - "p3-matrix", ] [[package]] @@ -232,9 +229,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", @@ -281,7 +278,7 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "fiat-shamir" version = "0.1.0" -source = "git+https://github.com/leanEthereum/fiat-shamir.git#211b12c35c9742c3d2ec0477381954208f97986c" +source = "git+https://github.com/leanEthereum/fiat-shamir.git?branch=modular-precompiles#922ea9dc0c91456efc7fb08d1caa3c3362447d8e" dependencies = [ "p3-challenger", "p3-field", @@ -290,9 +287,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.9" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -343,15 +340,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lean-multisig" -version = "0.1.0" -dependencies = [ - "clap", - "poseidon_circuit", - "rec_aggregation", -] - [[package]] name = "lean_compiler" version = "0.1.0" @@ -362,18 +350,14 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", - "packed_pcs", "pest", "pest_derive", "rand", - "rayon", + "sub_protocols", "tracing", "utils", "whir-p3", @@ -391,19 +375,15 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", - "packed_pcs", "pest", "pest_derive", "poseidon_circuit", "rand", - "rayon", + "sub_protocols", "tracing", "utils", "vm_air", @@ -423,18 +403,14 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", - "packed_pcs", "pest", "pest_derive", "rand", - "rayon", + "sub_protocols", "thiserror", "tracing", "utils", @@ -460,11 +436,9 @@ version = "0.1.0" dependencies = [ "multilinear-toolkit", "p3-challenger", - "p3-field", "p3-koala-bear", "p3-util", "rand", - "rayon", "tracing", "utils", "whir-p3", @@ -485,10 +459,19 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "modular-precompiles" +version = "0.1.0" +dependencies = [ + "clap", + "poseidon_circuit", + "rec_aggregation", +] + [[package]] name = "multilinear-toolkit" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#43c8955c2ca07c1c64962f190aab38c28e53adbb" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git?branch=modular-precompiles#8bf0cfebcddd4292a0be65d795c43da680d70336" dependencies = [ "backend", "constraints-folder", @@ -551,7 +534,7 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "p3-air" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-field", "p3-matrix", @@ -560,7 +543,7 @@ dependencies = [ [[package]] name = "p3-baby-bear" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-field", "p3-mds", @@ -573,7 +556,7 @@ dependencies = [ [[package]] name = "p3-challenger" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -585,7 +568,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-challenger", @@ -599,7 +582,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-field", @@ -612,7 +595,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "num-bigint", @@ -627,7 +610,7 @@ dependencies = [ [[package]] name = "p3-interpolation" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-field", "p3-matrix", @@ -638,7 +621,7 @@ dependencies = [ [[package]] name = "p3-koala-bear" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "num-bigint", @@ -654,7 +637,7 @@ dependencies = [ [[package]] name = "p3-matrix" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-field", @@ -669,7 +652,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "rayon", ] @@ -677,7 +660,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-dft", "p3-field", @@ -689,7 +672,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-commit", @@ -706,7 +689,7 @@ dependencies = [ [[package]] name = "p3-monty-31" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "num-bigint", @@ -728,7 +711,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "p3-field", "p3-mds", @@ -737,24 +720,10 @@ dependencies = [ "rand", ] -[[package]] -name = "p3-poseidon2-air" -version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" -dependencies = [ - "p3-air", - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-poseidon2", - "rand", - "tracing", -] - [[package]] name = "p3-symmetric" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-field", @@ -764,7 +733,7 @@ dependencies = [ [[package]] name = "p3-uni-stark" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "itertools", "p3-air", @@ -782,27 +751,12 @@ dependencies = [ [[package]] name = "p3-util" version = "0.3.0" -source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=lean-multisig#59a799ce00907553aeef3d7cf38f616023e9ad5f" +source = "git+https://github.com/TomWambsgans/Plonky3.git?branch=modular-precompiles#5475ea9b0b6274517a3299c730e2a363f56116ea" dependencies = [ "rayon", "serde", ] -[[package]] -name = "packed_pcs" -version = "0.1.0" -dependencies = [ - "multilinear-toolkit", - "p3-field", - "p3-koala-bear", - "p3-util", - "rand", - "rayon", - "tracing", - "utils", - "whir-p3", -] - [[package]] name = "paste" version = "1.0.15" @@ -863,12 +817,11 @@ name = "poseidon_circuit" version = "0.1.0" dependencies = [ "multilinear-toolkit", - "p3-field", "p3-koala-bear", "p3-monty-31", "p3-poseidon2", - "packed_pcs", "rand", + "sub_protocols", "tracing", "utils", "whir-p3", @@ -894,9 +847,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.41" +version = "1.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" dependencies = [ "proc-macro2", ] @@ -968,17 +921,13 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", - "packed_pcs", "rand", - "rayon", "serde_json", + "sub_protocols", "tracing", "utils", "whir-p3", @@ -1089,26 +1038,40 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "sub_protocols" +version = "0.1.0" +dependencies = [ + "derive_more", + "lookup", + "multilinear-toolkit", + "p3-koala-bear", + "p3-util", + "rand", + "tracing", + "utils", + "whir-p3", +] + [[package]] name = "sumcheck" version = "0.3.0" -source = "git+https://github.com/leanEthereum/multilinear-toolkit.git#43c8955c2ca07c1c64962f190aab38c28e53adbb" +source = "git+https://github.com/leanEthereum/multilinear-toolkit.git?branch=modular-precompiles#8bf0cfebcddd4292a0be65d795c43da680d70336" dependencies = [ "backend", "constraints-folder", "fiat-shamir", "p3-air", "p3-field", - "p3-matrix", "p3-util", "rayon", ] [[package]] name = "syn" -version = "2.0.108" +version = "2.0.110" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917" +checksum = "a99801b5bd34ede4cf3fc688c5919368fea4e4814a4664359503e6015b280aea" dependencies = [ "proc-macro2", "quote", @@ -1271,15 +1234,11 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", "rand", - "rayon", "tracing", "tracing-forest", "tracing-subscriber", @@ -1308,22 +1267,18 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", + "p3-uni-stark", "p3-util", - "packed_pcs", "pest", "pest_derive", "rand", - "rayon", + "sub_protocols", "tracing", "utils", "whir-p3", - "witness_generation", "xmss", ] @@ -1339,7 +1294,7 @@ dependencies = [ [[package]] name = "whir-p3" version = "0.1.0" -source = "git+https://github.com/TomWambsgans/whir-p3?branch=lean-multisig#57c8d12015260a6aca706a021364225e5ebdd636" +source = "git+https://github.com/TomWambsgans/whir-p3?branch=modular-precompiles#9e4698dfb4b0c273dba8dc067ad2cd8f5730a592" dependencies = [ "itertools", "multilinear-toolkit", @@ -1397,16 +1352,7 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", + "windows-targets", ] [[package]] @@ -1424,31 +1370,14 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] @@ -1457,96 +1386,48 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - [[package]] name = "wit-bindgen" version = "0.46.0" @@ -1565,22 +1446,19 @@ dependencies = [ "multilinear-toolkit", "p3-air", "p3-challenger", - "p3-field", "p3-koala-bear", - "p3-matrix", "p3-monty-31", "p3-poseidon2", - "p3-poseidon2-air", "p3-symmetric", "p3-util", - "packed_pcs", "pest", "pest_derive", "poseidon_circuit", "rand", - "rayon", + "sub_protocols", "tracing", "utils", + "vm_air", "whir-p3", "xmss", ] @@ -1589,7 +1467,7 @@ dependencies = [ name = "xmss" version = "0.1.0" dependencies = [ - "p3-field", + "multilinear-toolkit", "p3-koala-bear", "p3-util", "rand", diff --git a/Cargo.toml b/Cargo.toml index eae85faa..38a659d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "lean-multisig" +name = "modular-precompiles" version.workspace = true edition.workspace = true @@ -8,7 +8,11 @@ version = "0.1.0" edition = "2024" [workspace] -members = ["crates/*", "crates/lean_prover/vm_air", "crates/lean_prover/witness_generation"] +members = [ + "crates/*", + "crates/lean_prover/vm_air", + "crates/lean_prover/witness_generation", +] [workspace.lints] rust.missing_debug_implementations = "warn" @@ -44,7 +48,7 @@ air = { path = "crates/air" } utils = { path = "crates/utils" } lean_vm = { path = "crates/lean_vm" } xmss = { path = "crates/xmss" } -packed_pcs = { path = "crates/packed_pcs" } +sub_protocols = { path = "crates/sub_protocols" } lookup = { path = "crates/lookup" } lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } @@ -58,32 +62,28 @@ thiserror = "2.0" clap = { version = "4.3.10", features = ["derive"] } rand = "0.9.2" sha3 = "0.10.8" -rayon = "1.5.1" derive_more = { version = "2.0.1", features = ["full"] } pest = "2.7" pest_derive = "2.7" +itertools = "0.10.5" colored = "3.0.0" tracing = "0.1.26" serde_json = "*" tracing-subscriber = { version = "0.3.19", features = ["std", "env-filter"] } tracing-forest = { version = "0.2.0", features = ["ansi", "smallvec"] } -p3-koala-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-baby-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-field = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-poseidon2 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-matrix = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-blake3 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-symmetric = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-air = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-uni-stark = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-poseidon2-air = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-goldilocks = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-challenger = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-util = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } -p3-monty-31 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "lean-multisig" } +p3-koala-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-baby-bear = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-poseidon2 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-symmetric = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-air = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-uni-stark = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-goldilocks = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-challenger = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-util = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } +p3-monty-31 = { git = "https://github.com/TomWambsgans/Plonky3.git", branch = "modular-precompiles" } -whir-p3 = { git = "https://github.com/TomWambsgans/whir-p3", branch = "lean-multisig" } -multilinear-toolkit = { git = "https://github.com/leanEthereum/multilinear-toolkit.git" } +whir-p3 = { git = "https://github.com/TomWambsgans/whir-p3", branch = "modular-precompiles" } +multilinear-toolkit = { git = "https://github.com/leanEthereum/multilinear-toolkit.git", branch = "modular-precompiles" } [dependencies] clap.workspace = true @@ -94,11 +94,12 @@ poseidon_circuit.workspace = true # p3-koala-bear = { path = "../zk/Plonky3/koala-bear" } # p3-field = { path = "../zk/Plonky3/field" } # p3-poseidon2 = { path = "../zk/Plonky3/poseidon2" } -# p3-matrix = { path = "../zk/Plonky3/matrix" } # p3-symmetric = { path = "../zk/Plonky3/symmetric" } # p3-air = { path = "../zk/Plonky3/air" } # p3-uni-stark = { path = "../zk/Plonky3/uni-stark" } -# p3-poseidon2-air = { path = "../zk/Plonky3/poseidon2-air" } +# p3-merkle-tree = { path = "../zk/Plonky3/merkle-tree" } +# p3-commit = { path = "../zk/Plonky3/commit" } +# p3-matrix = { path = "../zk/Plonky3/matrix" } # p3-dft = { path = "../zk/Plonky3/dft" } # p3-challenger = { path = "../zk/Plonky3/challenger" } # p3-monty-31 = { path = "../zk/Plonky3/monty-31" } diff --git a/crates/air/Cargo.toml b/crates/air/Cargo.toml index 04d287b5..28e36568 100644 --- a/crates/air/Cargo.toml +++ b/crates/air/Cargo.toml @@ -7,16 +7,13 @@ edition.workspace = true workspace = true [dependencies] -p3-field.workspace = true tracing.workspace = true utils.workspace = true p3-air.workspace = true p3-uni-stark.workspace = true -p3-matrix.workspace = true p3-util.workspace = true multilinear-toolkit.workspace = true [dev-dependencies] p3-koala-bear.workspace = true -p3-matrix.workspace = true rand.workspace = true \ No newline at end of file diff --git a/crates/air/src/lib.rs b/crates/air/src/lib.rs index fb02e7cd..da983579 100644 --- a/crates/air/src/lib.rs +++ b/crates/air/src/lib.rs @@ -3,15 +3,18 @@ use ::utils::ConstraintChecker; use multilinear_toolkit::prelude::*; use p3_air::Air; -use p3_field::ExtensionField; use p3_uni_stark::SymbolicAirBuilder; mod prove; -pub mod table; +mod table; mod uni_skip_utils; mod utils; mod verify; +pub use prove::*; +pub use table::*; +pub use verify::*; + #[cfg(test)] pub mod tests; diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index db1f18c3..f3400ade 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -1,19 +1,12 @@ use std::any::TypeId; use multilinear_toolkit::prelude::*; -use p3_air::BaseAir; -use p3_field::{ExtensionField, Field, cyclic_subgroup_known_order}; use p3_util::{log2_ceil_usize, log2_strict_usize}; use tracing::{info_span, instrument}; -use utils::{ - FSProver, add_multilinears_inplace, fold_multilinear_chunks, multilinears_linear_combination, -}; +use utils::{FSProver, fold_multilinear_chunks, multilinears_linear_combination}; use crate::MyAir; -use crate::{ - uni_skip_utils::{matrix_down_folded, matrix_up_folded}, - utils::{column_down, column_up, columns_up_and_down}, -}; +use crate::{uni_skip_utils::matrix_next_mle_folded, utils::column_shifted}; use super::table::AirTable; @@ -24,16 +17,17 @@ cf https://eprint.iacr.org/2023/552.pdf and https://solvable.group/posts/super-a */ #[instrument(name = "prove air", skip_all)] -fn prove_air< - 'a, +pub fn prove_air< WF: ExtensionField>, // witness field EF: ExtensionField> + ExtensionField, A: MyAir + 'static, >( prover_state: &mut FSProver>, - univariate_skips: usize, table: &AirTable, - witness: &[&'a [WF]], + univariate_skips: usize, + witness: &[&[WF]], + last_row_shifted: &[WF], + virtual_column_statement: Option>, // point should be randomness generated after committing to the columns ) -> (MultilinearPoint, Vec) { let n_rows = witness[0].len(); assert!(witness.iter().all(|col| col.len() == n_rows)); @@ -43,143 +37,127 @@ fn prove_air< "TODO handle the case UNIVARIATE_SKIPS >= log_length" ); - let structured_air = >>::structured(&table.air); - let constraints_batching_scalar = prover_state.sample(); - let constraints_batching_scalars = - cyclic_subgroup_known_order(constraints_batching_scalar, table.n_constraints) - .collect::>(); + let constraints_batching_scalars = constraints_batching_scalar + .powers() + .take(table.n_constraints + virtual_column_statement.is_some() as usize) + .collect(); let n_sc_rounds = log_n_rows + 1 - univariate_skips; - let zerocheck_challenges = prover_state.sample_vec(n_sc_rounds); - - let columns_for_zero_check: MleGroup<'_, EF> = if TypeId::of::() == TypeId::of::>() { - let columns = unsafe { std::mem::transmute::<&[&[WF]], &[&[PF]]>(witness) }; - if structured_air { - MleGroupOwned::Base(columns_up_and_down(columns)).into() - } else { - MleGroupRef::Base(columns.to_vec()).into() - } + let zerocheck_challenges = virtual_column_statement + .as_ref() + .map(|st| st.point.0.clone()) + .unwrap_or_else(|| prover_state.sample_vec(n_sc_rounds)); + assert_eq!(zerocheck_challenges.len(), n_sc_rounds); + + let shifted_rows = table + .columns_with_shift() + .par_iter() + .zip_eq(last_row_shifted) + .map(|(&col_index, &final_value)| column_shifted(witness[col_index], final_value)) + .collect::>(); + + let mut columns_up_down = witness.to_vec(); // orginal columns, followed by shifted ones + columns_up_down.extend(shifted_rows.iter().map(Vec::as_slice)); + let columns_up_down_group: MleGroupRef<'_, EF> = if TypeId::of::() == TypeId::of::>() + { + let columns = + unsafe { std::mem::transmute::, Vec<&[PF]>>(columns_up_down.clone()) }; + MleGroupRef::<'_, EF>::Base(columns) } else { assert!(TypeId::of::() == TypeId::of::()); - let columns = unsafe { std::mem::transmute::<&[&'a [WF]], &[&'a [EF]]>(witness) }; - if structured_air { - MleGroupOwned::Extension(columns_up_and_down(columns)).into() - } else { - MleGroupRef::Extension(columns.to_vec()).into() - } + let columns = + unsafe { std::mem::transmute::, Vec<&[EF]>>(columns_up_down.clone()) }; + MleGroupRef::<'_, EF>::Extension(columns) }; - let columns_for_zero_check_packed = columns_for_zero_check.by_ref().pack(); + let columns_up_down_packed = columns_up_down_group.pack(); let (outer_sumcheck_challenge, inner_sums, _) = info_span!("zerocheck").in_scope(|| { sumcheck_prove( univariate_skips, - columns_for_zero_check_packed, + columns_up_down_packed, &table.air, &constraints_batching_scalars, Some((zerocheck_challenges, None)), - true, + virtual_column_statement.is_none(), prover_state, - EF::ZERO, + virtual_column_statement + .as_ref() + .map(|st| st.value) + .unwrap_or_else(|| EF::ZERO), None, ) }); prover_state.add_extension_scalars(&inner_sums); - if structured_air { - open_structured_columns( - prover_state, - univariate_skips, - witness, - &outer_sumcheck_challenge, - ) - } else { - unreachable!() - } -} - -impl>, A: MyAir + 'static> AirTable { - #[instrument(name = "air: prove in base", skip_all)] - pub fn prove_base( - &self, - prover_state: &mut FSProver>, - univariate_skips: usize, - witness: &[&[PF]], - ) -> (MultilinearPoint, Vec) { - prove_air::, EF, A>(prover_state, univariate_skips, self, witness) - } - - #[instrument(name = "air: prove in extension", skip_all)] - pub fn prove_extension( - &self, - prover_state: &mut FSProver>, - univariate_skips: usize, - witness: &[&[EF]], - ) -> (MultilinearPoint, Vec) { - prove_air::(prover_state, univariate_skips, self, witness) - } + open_columns( + prover_state, + univariate_skips, + &table.columns_with_shift(), + witness, + &outer_sumcheck_challenge, + ) } #[instrument(skip_all)] -fn open_structured_columns> + ExtensionField, IF: Field>( +fn open_columns> + ExtensionField, IF: Field>( prover_state: &mut FSProver>, univariate_skips: usize, - witness: &[&[IF]], + columns_with_shift: &[usize], + columns: &[&[IF]], outer_sumcheck_challenge: &[EF], ) -> (MultilinearPoint, Vec) { - let n_columns = witness.len(); - let n_rows = witness[0].len(); - let log_n_rows = log2_strict_usize(n_rows); - let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_columns)); - let alpha = prover_state.sample(); + let n_up_down_columns = columns.len() + columns_with_shift.len(); + let batching_scalars = prover_state.sample_vec(log2_ceil_usize(n_up_down_columns)); - let poly_eq_batching_scalars = eval_eq(&batching_scalars); + let eval_eq_batching_scalars = eval_eq(&batching_scalars)[..n_up_down_columns].to_vec(); - let batched_column = - multilinears_linear_combination(witness, &poly_eq_batching_scalars[..n_columns]); - - let batched_column_mixed = info_span!("mixing up / down").in_scope(|| { - let mut batched_column_mixed = column_down(&batched_column); - add_multilinears_inplace( - &mut batched_column_mixed, - &scale_poly(&column_up(&batched_column), alpha), - ); - batched_column_mixed - }); + let batched_column_up = + multilinears_linear_combination(columns, &eval_eq_batching_scalars[..columns.len()]); + let batched_column_down = multilinears_linear_combination( + &columns_with_shift + .iter() + .map(|&i| columns[i]) + .collect::>(), + &eval_eq_batching_scalars[columns.len()..], + ); - // TODO do not recompute this (we can deduce it from already computed values) let sub_evals = info_span!("fold_multilinear_chunks").in_scope(|| { - fold_multilinear_chunks( - &batched_column_mixed, - &MultilinearPoint( - outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(), - ), - ) + let sub_evals_up = fold_multilinear_chunks( + &batched_column_up, + &MultilinearPoint(outer_sumcheck_challenge[1..].to_vec()), + ); + let sub_evals_down = fold_multilinear_chunks( + &column_shifted(&batched_column_down, EF::ZERO), + &MultilinearPoint(outer_sumcheck_challenge[1..].to_vec()), + ); + sub_evals_up + .iter() + .zip(sub_evals_down.iter()) + .map(|(&up, &down)| up + down) + .collect::>() }); prover_state.add_extension_scalars(&sub_evals); let epsilons = prover_state.sample_vec(univariate_skips); - let point = [ - epsilons, - outer_sumcheck_challenge[1..log_n_rows - univariate_skips + 1].to_vec(), - ] - .concat(); + let inner_sum = sub_evals.evaluate(&MultilinearPoint(epsilons.clone())); - // TODO do not recompute this (we can deduce it from already computed values) - let inner_sum = info_span!("mixed column eval") - .in_scope(|| batched_column_mixed.evaluate(&MultilinearPoint(point.clone()))); + let point = [epsilons, outer_sumcheck_challenge[1..].to_vec()].concat(); - let mut mat_up = matrix_up_folded(&point, alpha); - matrix_down_folded(&point, &mut mat_up); + // TODO opti in case of flat AIR (no need of `matrix_next_mle_folded`) + let matrix_up = eval_eq(&point); + let matrix_down = matrix_next_mle_folded(&point); let inner_mle = info_span!("packing").in_scope(|| { MleGroupOwned::ExtensionPacked(vec![ - pack_extension(&mat_up), - pack_extension(&batched_column), + pack_extension(&matrix_up), + pack_extension(&batched_column_up), + pack_extension(&matrix_down), + pack_extension(&batched_column_down), ]) }); @@ -187,7 +165,7 @@ fn open_structured_columns> + ExtensionField, IF: sumcheck_prove::( 1, inner_mle, - &ProductComputation, + &MySumcheck, &[], None, false, @@ -198,8 +176,8 @@ fn open_structured_columns> + ExtensionField, IF: }); let evaluations_remaining_to_prove = info_span!("final evals").in_scope(|| { - witness - .iter() + columns + .par_iter() .map(|col| col.evaluate(&inner_challenges)) .collect::>() }); @@ -207,3 +185,33 @@ fn open_structured_columns> + ExtensionField, IF: (inner_challenges, evaluations_remaining_to_prove) } + +struct MySumcheck; + +impl>, EF: ExtensionField> SumcheckComputation + for MySumcheck +{ + fn eval(&self, point: &[IF], _: &[EF]) -> EF { + if TypeId::of::() == TypeId::of::() { + let point = unsafe { std::mem::transmute::<&[IF], &[EF]>(point) }; + point[0] * point[1] + point[2] * point[3] + } else { + unreachable!() + } + } + fn degree(&self) -> usize { + 2 + } +} + +impl>> SumcheckComputationPacked for MySumcheck { + fn eval_packed_base(&self, _: &[PFPacking], _: &[EF]) -> EFPacking { + unreachable!() + } + fn eval_packed_extension(&self, point: &[EFPacking], _: &[EF]) -> EFPacking { + point[0] * point[1] + point[2] * point[3] + } + fn degree(&self) -> usize { + 2 + } +} diff --git a/crates/air/src/table.rs b/crates/air/src/table.rs index ed706929..91396439 100644 --- a/crates/air/src/table.rs +++ b/crates/air/src/table.rs @@ -1,10 +1,8 @@ use std::{any::TypeId, marker::PhantomData, mem::transmute}; use p3_air::BaseAir; -use p3_field::{ExtensionField, Field}; use multilinear_toolkit::prelude::*; -use p3_matrix::dense::RowMajorMatrixView; use p3_uni_stark::get_symbolic_constraints; use tracing::instrument; use utils::ConstraintChecker; @@ -21,7 +19,7 @@ pub struct AirTable { impl>, A: MyAir> AirTable { pub fn new(air: A) -> Self { - let symbolic_constraints = get_symbolic_constraints(&air, 0, 0); + let symbolic_constraints = get_symbolic_constraints(&air); let n_constraints = symbolic_constraints.len(); let constraint_degree = Iterator::max(symbolic_constraints.iter().map(|c| c.degree_multiple())).unwrap(); @@ -37,10 +35,15 @@ impl>, A: MyAir> AirTable { >>::width(&self.air) } + pub fn columns_with_shift(&self) -> Vec { + >>::columns_with_shift(&self.air) + } + #[instrument(name = "Check trace validity", skip_all)] pub fn check_trace_validity>>( &self, witness: &[&[IF]], + last_row: &[IF], ) -> Result<(), String> where EF: ExtensionField, @@ -50,7 +53,7 @@ impl>, A: MyAir> AirTable { if witness.len() != self.n_columns() { return Err("Invalid number of columns".to_string()); } - let handle_errors = |row: usize, constraint_checker: &mut ConstraintChecker<'_, IF, EF>| { + let handle_errors = |row: usize, constraint_checker: &ConstraintChecker<'_, IF, EF>| { if !constraint_checker.errors.is_empty() { return Err(format!( "Trace is not valid at row {}: contraints not respected: {}", @@ -65,70 +68,59 @@ impl>, A: MyAir> AirTable { } Ok(()) }; - if >>::structured(&self.air) { - for row in 0..n_rows - 1 { - let up = (0..self.n_columns()) - .map(|j| witness[j][row]) - .collect::>(); - let down = (0..self.n_columns()) - .map(|j| witness[j][row + 1]) - .collect::>(); - let up_and_down = [up, down].concat(); - let mut constraints_checker = ConstraintChecker:: { - main: RowMajorMatrixView::new(&up_and_down, self.n_columns()), - constraint_index: 0, - errors: Vec::new(), - field: PhantomData, - }; - if TypeId::of::() == TypeId::of::() { - unsafe { - self.air.eval(transmute::< - &mut ConstraintChecker<'_, IF, EF>, - &mut ConstraintChecker<'_, EF, EF>, - >(&mut constraints_checker)); - } - } else { - assert_eq!(TypeId::of::(), TypeId::of::>()); - unsafe { - self.air.eval(transmute::< - &mut ConstraintChecker<'_, IF, EF>, - &mut ConstraintChecker<'_, PF, EF>, - >(&mut constraints_checker)); - } - } - handle_errors(row, &mut constraints_checker)?; + for row in 0..n_rows - 1 { + let up = (0..self.n_columns()) + .map(|j| witness[j][row]) + .collect::>(); + let down = self + .columns_with_shift() + .iter() + .map(|j| witness[*j][row + 1]) + .collect::>(); + let up_and_down = [up, down].concat(); + let constraints_checker = self.eval_transition::(&up_and_down); + handle_errors(row, &constraints_checker)?; + } + // last transition: + let up = (0..self.n_columns()) + .map(|j| witness[j][n_rows - 1]) + .collect::>(); + assert_eq!(last_row.len(), self.columns_with_shift().len()); + let up_and_down = [up, last_row.to_vec()].concat(); + let constraints_checker = self.eval_transition::(&up_and_down); + handle_errors(n_rows - 1, &constraints_checker)?; + Ok(()) + } + + fn eval_transition<'a, IF: ExtensionField>>( + &self, + up_and_down: &'a [IF], + ) -> ConstraintChecker<'a, IF, EF> + where + EF: ExtensionField, + { + let mut constraints_checker = ConstraintChecker:: { + main: up_and_down, + constraint_index: 0, + errors: Vec::new(), + field: PhantomData, + }; + if TypeId::of::() == TypeId::of::() { + unsafe { + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, EF, EF>, + >(&mut constraints_checker)); } } else { - #[allow(clippy::needless_range_loop)] - for row in 0..n_rows { - let up = (0..self.n_columns()) - .map(|j| witness[j][row]) - .collect::>(); - let mut constraints_checker = ConstraintChecker { - main: RowMajorMatrixView::new(&up, self.n_columns()), - constraint_index: 0, - errors: Vec::new(), - field: PhantomData, - }; - if TypeId::of::() == TypeId::of::() { - unsafe { - self.air.eval(transmute::< - &mut ConstraintChecker<'_, IF, EF>, - &mut ConstraintChecker<'_, EF, EF>, - >(&mut constraints_checker)); - } - } else { - assert_eq!(TypeId::of::(), TypeId::of::>()); - unsafe { - self.air.eval(transmute::< - &mut ConstraintChecker<'_, IF, EF>, - &mut ConstraintChecker<'_, PF, EF>, - >(&mut constraints_checker)); - } - } - handle_errors(row, &mut constraints_checker)?; + assert_eq!(TypeId::of::(), TypeId::of::>()); + unsafe { + self.air.eval(transmute::< + &mut ConstraintChecker<'_, IF, EF>, + &mut ConstraintChecker<'_, PF, EF>, + >(&mut constraints_checker)); } } - Ok(()) + constraints_checker } } diff --git a/crates/air/src/tests.rs b/crates/air/src/tests.rs index 0d931270..93b136f4 100644 --- a/crates/air/src/tests.rs +++ b/crates/air/src/tests.rs @@ -1,66 +1,129 @@ -use std::borrow::Borrow; +use std::{ + any::TypeId, + mem::{transmute, transmute_copy}, +}; use multilinear_toolkit::prelude::*; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeCharacteristicRing; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; -use p3_matrix::Matrix; +use p3_uni_stark::SymbolicExpression; use rand::{Rng, SeedableRng, rngs::StdRng}; use utils::{build_prover_state, build_verifier_state}; -use crate::table::AirTable; +use crate::{prove::prove_air, table::AirTable, verify::verify_air}; const UNIVARIATE_SKIPS: usize = 3; +const N_COLS_WITHOUT_SHIFT: usize = 2; + type F = KoalaBear; type EF = QuinticExtensionFieldKB; -struct ExampleStructuredAir; +struct ExampleStructuredAir< + const N_COLUMNS: usize, + const N_PREPROCESSED_COLUMNS: usize, + const VIRTUAL_COLUMN: bool, +>; -impl BaseAir - for ExampleStructuredAir +impl + BaseAir for ExampleStructuredAir { fn width(&self) -> usize { N_COLUMNS } - fn structured(&self) -> bool { - true - } fn degree(&self) -> usize { - N_PREPROCESSED_COLUMNS + N_PREPROCESSED_COLUMNS - N_COLS_WITHOUT_SHIFT + } + fn columns_with_shift(&self) -> Vec { + [ + (0..N_PREPROCESSED_COLUMNS - N_COLS_WITHOUT_SHIFT).collect::>(), + (N_PREPROCESSED_COLUMNS..N_COLUMNS).collect::>(), + ] + .concat() } } -impl Air - for ExampleStructuredAir +impl< + AB: AirBuilder, + const N_COLUMNS: usize, + const N_PREPROCESSED_COLUMNS: usize, + const VIRTUAL_COLUMN: bool, +> Air for ExampleStructuredAir +where + AB::Var: 'static, + AB::Expr: 'static, + AB::FinalOutput: 'static, { #[inline] fn eval(&self, builder: &mut AB) { let main = builder.main(); - let up = main.row_slice(0).expect("The matrix is empty?"); - let up: &[AB::Var] = (*up).borrow(); - assert_eq!(up.len(), N_COLUMNS); - let down = main.row_slice(1).expect("The matrix is empty?"); - let down: &[AB::Var] = (*down).borrow(); - assert_eq!(down.len(), N_COLUMNS); + let up = main[..N_COLUMNS].to_vec(); + let down = main[N_COLUMNS..].to_vec(); + assert_eq!(down.len(), N_COLUMNS - N_COLS_WITHOUT_SHIFT); + + if VIRTUAL_COLUMN { + // virtual column = col_0 * col_1 + col_2 + builder.add_custom(>::eval_custom( + self, + &[ + up[0].clone().into(), + up[1].clone().into(), + up[2].clone().into(), + ], + )); + } for j in N_PREPROCESSED_COLUMNS..N_COLUMNS { builder.assert_eq( - down[j].clone(), + down[j - N_COLS_WITHOUT_SHIFT].clone(), up[j].clone() + AB::F::from_usize(j) - + (0..N_PREPROCESSED_COLUMNS) + + (0..N_PREPROCESSED_COLUMNS - N_COLS_WITHOUT_SHIFT) .map(|k| AB::Expr::from(down[k].clone())) .product::(), ); } } + + fn eval_custom(&self, inputs: &[::Expr]) -> ::FinalOutput { + assert_eq!(inputs.len(), 3); + let type_id_final_output = TypeId::of::<::FinalOutput>(); + let type_id_expr = TypeId::of::<::Expr>(); + let type_id_f = TypeId::of::(); + let type_id_ef = TypeId::of::(); + let type_id_f_packing = TypeId::of::>(); + let type_id_ef_packing = TypeId::of::>(); + + if type_id_expr == type_id_f && type_id_final_output == type_id_ef { + let inputs = unsafe { transmute::<&[::Expr], &[F]>(inputs) }; + let res = EF::from(inputs[0] * inputs[1] + inputs[2]); + unsafe { transmute_copy::::FinalOutput>(&res) } + } else if type_id_expr == type_id_ef && type_id_final_output == type_id_ef { + let inputs = unsafe { transmute::<&[::Expr], &[EF]>(inputs) }; + let res = inputs[0] * inputs[1] + inputs[2]; + unsafe { transmute_copy::::FinalOutput>(&res) } + } else if type_id_expr == type_id_ef_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[EFPacking]>(inputs) }; + let res = inputs[0] * inputs[1] + inputs[2]; + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else if type_id_expr == type_id_f_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[PFPacking]>(inputs) }; + let res = EFPacking::::from(inputs[0] * inputs[1] + inputs[2]); + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else { + assert_eq!(type_id_expr, TypeId::of::>()); + unsafe { transmute_copy(&SymbolicExpression::::default()) } + } + } } fn generate_structured_trace( - log_length: usize, + n_rows: usize, ) -> Vec> { - let n_rows = 1 << log_length; let mut trace = vec![]; let mut rng = StdRng::seed_from_u64(0); for _ in 0..N_PREPROCESSED_COLUMNS { @@ -77,7 +140,7 @@ fn generate_structured_trace(), ); @@ -88,23 +151,97 @@ fn generate_structured_trace(); + test_air_helper::(); +} + +fn test_air_helper() { const N_COLUMNS: usize = 17; - const N_PREPROCESSED_COLUMNS: usize = 3; + const N_PREPROCESSED_COLUMNS: usize = 5; + const _: () = assert!(N_PREPROCESSED_COLUMNS > N_COLS_WITHOUT_SHIFT); let log_n_rows = 12; + let n_rows = 1 << log_n_rows; let mut prover_state = build_prover_state::(); - let columns = generate_structured_trace::(log_n_rows); - let columns_ref = columns.iter().map(|col| col.as_slice()).collect::>(); + let columns_plus_one = + generate_structured_trace::(n_rows + 1); + let columns_ref = columns_plus_one + .iter() + .map(|col| &col[..n_rows]) + .collect::>(); + let mut last_row = columns_plus_one + .iter() + .map(|col| col[n_rows]) + .collect::>(); + last_row.drain(N_PREPROCESSED_COLUMNS - N_COLS_WITHOUT_SHIFT..N_PREPROCESSED_COLUMNS); + let last_row_ef = last_row.iter().map(|&v| EF::from(v)).collect::>(); + + let virtual_column_statement_prover = if VIRTUAL_COLUMN { + let virtual_column = columns_ref[0] + .iter() + .zip(columns_ref[1].iter()) + .zip(columns_ref[2].iter()) + .map(|((&a, &b), &c)| a * b + c) + .collect::>(); + let virtual_column_evaluation_point = + MultilinearPoint(prover_state.sample_vec(log_n_rows + 1 - UNIVARIATE_SKIPS)); + let selectors = univariate_selectors(UNIVARIATE_SKIPS); + let virtual_column_value = evaluate_univariate_multilinear::<_, _, _, true>( + &virtual_column, + &virtual_column_evaluation_point, + &selectors, + None, + ); + prover_state.add_extension_scalar(virtual_column_value); + + Some(Evaluation::new( + virtual_column_evaluation_point.0.clone(), + virtual_column_value, + )) + } else { + None + }; + + let table = AirTable::::new(ExampleStructuredAir::< + N_COLUMNS, + N_PREPROCESSED_COLUMNS, + VIRTUAL_COLUMN, + > {}); - let table = AirTable::::new(ExampleStructuredAir::); - table.check_trace_validity(&columns_ref).unwrap(); - let (point_prover, evaluations_remaining_to_prove) = - table.prove_base(&mut prover_state, UNIVARIATE_SKIPS, &columns_ref); + table.check_trace_validity(&columns_ref, &last_row).unwrap(); + + let (point_prover, evaluations_remaining_to_prove) = prove_air( + &mut prover_state, + &table, + UNIVARIATE_SKIPS, + &columns_ref, + &last_row, + virtual_column_statement_prover, + ); let mut verifier_state = build_verifier_state(&prover_state); - let (point_verifier, evaluations_remaining_to_verify) = table - .verify(&mut verifier_state, UNIVARIATE_SKIPS, log_n_rows) - .unwrap(); + + let virtual_column_statement_verifier = if VIRTUAL_COLUMN { + let virtual_column_evaluation_point = + MultilinearPoint(verifier_state.sample_vec(log_n_rows + 1 - UNIVARIATE_SKIPS)); + let virtual_column_value = verifier_state.next_extension_scalar().unwrap(); + Some(Evaluation::new( + virtual_column_evaluation_point.0.clone(), + virtual_column_value, + )) + } else { + None + }; + + let (point_verifier, evaluations_remaining_to_verify) = verify_air( + &mut verifier_state, + &table, + UNIVARIATE_SKIPS, + log_n_rows, + &last_row_ef, + virtual_column_statement_verifier, + ) + .unwrap(); assert_eq!(point_prover, point_verifier); assert_eq!( &evaluations_remaining_to_prove, @@ -112,7 +249,7 @@ fn test_structured_air() { ); for i in 0..N_COLUMNS { assert_eq!( - columns[i].evaluate(&point_prover), + columns_ref[i].evaluate(&point_prover), evaluations_remaining_to_verify[i] ); } diff --git a/crates/air/src/uni_skip_utils.rs b/crates/air/src/uni_skip_utils.rs index eae55259..aa7958b8 100644 --- a/crates/air/src/uni_skip_utils.rs +++ b/crates/air/src/uni_skip_utils.rs @@ -2,18 +2,9 @@ use multilinear_toolkit::prelude::*; use tracing::instrument; #[instrument(skip_all)] -pub fn matrix_up_folded>>(outer_challenges: &[F], alpha: F) -> Vec { - let n = outer_challenges.len(); - let mut folded = eval_eq_scaled(outer_challenges, alpha); - let outer_challenges_prod: F = outer_challenges.iter().copied().product(); - folded[(1 << n) - 1] -= outer_challenges_prod * alpha; - folded[(1 << n) - 2] += outer_challenges_prod * alpha; - folded -} - -#[instrument(skip_all)] -pub fn matrix_down_folded>>(outer_challenges: &[F], dest: &mut [F]) { +pub fn matrix_next_mle_folded>>(outer_challenges: &[F]) -> Vec { let n = outer_challenges.len(); + let mut res = F::zero_vec(1 << n); for k in 0..n { let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1]) * outer_challenges[n - k..].iter().copied().product::(); @@ -21,9 +12,36 @@ pub fn matrix_down_folded>>(outer_challenges: &[F], dest for (mut i, v) in eq_mle.iter_mut().enumerate() { i <<= k + 1; i += 1 << k; - dest[i] += *v; + res[i] += *v; + } + } + res +} + +#[cfg(test)] +mod tests { + use utils::to_big_endian_in_field; + + use crate::utils::next_mle; + + use super::*; + type F = p3_koala_bear::KoalaBear; + + #[test] + fn test_matrix_down_folded() { + let n_vars = 5; + for x in 0..1 << n_vars { + let x_bools = to_big_endian_in_field::(x, n_vars); + let matrix = matrix_next_mle_folded(&x_bools); + for y in 0..1 << n_vars { + let y_bools = to_big_endian_in_field::(y, n_vars); + let expected = F::from_bool(x + 1 == y); + assert_eq!( + matrix.evaluate(&MultilinearPoint(y_bools.clone())), + expected + ); + assert_eq!(next_mle(&[x_bools.clone(), y_bools].concat()), expected); + } } } - // bottom left corner: - dest[(1 << n) - 1] += outer_challenges.iter().copied().product::(); } diff --git a/crates/air/src/utils.rs b/crates/air/src/utils.rs index 5ff34789..c34093ea 100644 --- a/crates/air/src/utils.rs +++ b/crates/air/src/utils.rs @@ -1,66 +1,18 @@ use multilinear_toolkit::prelude::*; -use p3_field::Field; -use tracing::instrument; - -pub(crate) fn matrix_up_lde(point: &[F]) -> F { - /* - Matrix UP: - - (1 0 0 0 ... 0 0 0) - (0 1 0 0 ... 0 0 0) - (0 0 1 0 ... 0 0 0) - (0 0 0 1 ... 0 0 0) - ... ... ... - (0 0 0 0 ... 1 0 0) - (0 0 0 0 ... 0 1 0) - (0 0 0 0 ... 0 1 0) - - Square matrix of size self.n_columns x sef.n_columns - As a multilinear polynomial in 2 * log_length variables: - - self.n_columns first variables -> encoding the row index - - self.n_columns last variables -> encoding the column index - */ - - assert_eq!(point.len() % 2, 0); - let n = point.len() / 2; - let (s1, s2) = point.split_at(n); - MultilinearPoint(s1.to_vec()).eq_poly_outside(&MultilinearPoint(s2.to_vec())) - + point[..point.len() - 1].iter().copied().product::() - * (F::ONE - point[point.len() - 1] * F::TWO) -} - -pub(crate) fn matrix_down_lde(point: &[F]) -> F { - /* - Matrix DOWN: - - (0 1 0 0 ... 0 0 0) - (0 0 1 0 ... 0 0 0) - (0 0 0 1 ... 0 0 0) - (0 0 0 0 ... 0 0 0) - (0 0 0 0 ... 0 0 0) - ... ... ... - (0 0 0 0 ... 0 1 0) - (0 0 0 0 ... 0 0 1) - (0 0 0 0 ... 0 0 1) - - Square matrix of size self.n_columns x sef.n_columns - As a multilinear polynomial in 2 * log_length variables: - - self.n_columns first variables -> encoding the row index - - self.n_columns last variables -> encoding the column index - - TODO OPTIMIZATIOn: - the lde currently is in log(table_length)^2, but it could be log(table_length) using a recursive construction - (However it is not representable as a polynomial in this case, but as a fraction instead) - - */ - next_mle(point) + point.iter().copied().product::() - - // bottom right corner -} /// Returns a multilinear polynomial in 2n variables that evaluates to 1 /// if and only if the second n-bit vector is equal to the first vector plus one (viewed as big-endian integers). /// +/// (0 1 0 0 ... 0 0 0) +/// (0 0 1 0 ... 0 0 0) +/// (0 0 0 1 ... 0 0 0) +/// (0 0 0 0 ... 0 0 0) +/// (0 0 0 0 ... 0 0 0) +/// ... ... ... +/// (0 0 0 0 ... 0 1 0) +/// (0 0 0 0 ... 0 0 1) +/// (0 0 0 0 ... 0 0 0) +/// /// # Arguments /// - `point`: A slice of 2n field elements representing two n-bit vectors concatenated. /// The first n elements are `x` (original vector), the last n are `y` (candidate successor). @@ -82,7 +34,7 @@ pub(crate) fn matrix_down_lde(point: &[F]) -> F { /// /// # Returns /// Field element: 1 if y = x + 1, 0 otherwise. -fn next_mle(point: &[F]) -> F { +pub(crate) fn next_mle(point: &[F]) -> F { // Check that the point length is even: we split into x and y of equal length. assert_eq!( point.len() % 2, @@ -130,29 +82,9 @@ fn next_mle(point: &[F]) -> F { .sum() } -#[instrument(skip_all, fields(len = columns.len(), col_len = columns[0].len()))] -pub(crate) fn columns_up_and_down(columns: &[&[F]]) -> Vec> { - (0..columns.len() * 2) - .into_par_iter() - .map(|i| { - if i < columns.len() { - column_up(columns[i]) - } else { - column_down(columns[i - columns.len()]) - } - }) - .collect() -} - -pub(crate) fn column_up(column: &[F]) -> Vec { - let mut up = parallel_clone_vec(column); - up[column.len() - 1] = up[column.len() - 2]; - up -} - -pub(crate) fn column_down(column: &[F]) -> Vec { +pub(crate) fn column_shifted(column: &[F], final_value: F) -> Vec { let mut down = unsafe { uninitialized_vec(column.len()) }; parallel_clone(&column[1..], &mut down[..column.len() - 1]); - down[column.len() - 1] = down[column.len() - 2]; + down[column.len() - 1] = final_value; down } diff --git a/crates/air/src/verify.rs b/crates/air/src/verify.rs index 5b2b2ef0..f0ec9fb0 100644 --- a/crates/air/src/verify.rs +++ b/crates/air/src/verify.rs @@ -1,25 +1,27 @@ use multilinear_toolkit::prelude::*; use p3_air::BaseAir; -use p3_field::{ExtensionField, cyclic_subgroup_known_order, dot_product}; use p3_util::log2_ceil_usize; -use crate::{ - MyAir, - utils::{matrix_down_lde, matrix_up_lde}, -}; +use crate::{MyAir, utils::next_mle}; use super::table::AirTable; -fn verify_air>, A: MyAir>( +pub fn verify_air>, A: MyAir>( verifier_state: &mut FSVerifier>, table: &AirTable, univariate_skips: usize, log_n_rows: usize, + last_row: &[EF], + virtual_column_statement: Option>, // point should be randomness generated after committing to the columns ) -> Result<(MultilinearPoint, Vec), ProofError> { let constraints_batching_scalar = verifier_state.sample(); - let n_zerocheck_challenges = log_n_rows + 1 - univariate_skips; - let global_zerocheck_challenges = verifier_state.sample_vec(n_zerocheck_challenges); + let n_sc_rounds = log_n_rows + 1 - univariate_skips; + let zerocheck_challenges = virtual_column_statement + .as_ref() + .map(|st| st.point.0.clone()) + .unwrap_or_else(|| verifier_state.sample_vec(n_sc_rounds)); + assert_eq!(zerocheck_challenges.len(), n_sc_rounds); let (sc_sum, outer_statement) = sumcheck_verify_with_univariate_skip::( verifier_state, @@ -27,7 +29,12 @@ fn verify_air>, A: MyAir>( log_n_rows, univariate_skips, )?; - if sc_sum != EF::ZERO { + if sc_sum + != virtual_column_statement + .as_ref() + .map(|st| st.value) + .unwrap_or(EF::ZERO) + { return Err(ProofError::InvalidProof); } @@ -36,23 +43,18 @@ fn verify_air>, A: MyAir>( .map(|s| s.evaluate(outer_statement.point[0])) .collect::>(); - let inner_sums = verifier_state.next_extension_scalars_vec( - if >>::structured(&table.air) { - 2 * table.n_columns() - } else { - table.n_columns() - }, - )?; + let inner_sums = verifier_state + .next_extension_scalars_vec(table.n_columns() + table.columns_with_shift().len())?; - let constraint_evals = SumcheckComputation::eval( - &table.air, - &inner_sums, - &cyclic_subgroup_known_order(constraints_batching_scalar, table.n_constraints) - .collect::>(), - ); + let constraints_batching_scalars = constraints_batching_scalar + .powers() + .take(table.n_constraints + virtual_column_statement.is_some() as usize) + .collect(); + let constraint_evals = + SumcheckComputation::eval(&table.air, &inner_sums, &constraints_batching_scalars); if eq_poly_with_skip( - &global_zerocheck_challenges, + &zerocheck_challenges, &outer_statement.point, univariate_skips, ) * constraint_evals @@ -60,107 +62,49 @@ fn verify_air>, A: MyAir>( { return Err(ProofError::InvalidProof); } - let structured_air = >>::structured(&table.air); - - if structured_air { - verify_structured_columns( - verifier_state, - table.n_columns(), - univariate_skips, - &inner_sums, - &Evaluation::new( - outer_statement.point[1..log_n_rows - univariate_skips + 1].to_vec(), - outer_statement.value, - ), - &outer_selector_evals, - log_n_rows, - ) - } else { - verify_unstructured_columns( - verifier_state, - univariate_skips, - inner_sums, - &outer_statement.point, - &outer_selector_evals, - log_n_rows, - ) - } -} - -impl>, A: MyAir> AirTable { - pub fn verify( - &self, - verifier_state: &mut FSVerifier>, - univariate_skips: usize, - log_n_rows: usize, - ) -> Result<(MultilinearPoint, Vec), ProofError> { - verify_air::(verifier_state, self, univariate_skips, log_n_rows) - } -} - -fn verify_unstructured_columns>>( - verifier_state: &mut FSVerifier>, - univariate_skips: usize, - inner_sums: Vec, - outer_sumcheck_point: &MultilinearPoint, - outer_selector_evals: &[EF], - log_n_rows: usize, -) -> Result<(MultilinearPoint, Vec), ProofError> { - let n_columns = inner_sums.len(); - let columns_batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns)); - - let sub_evals = verifier_state.next_extension_scalars_vec(1 << univariate_skips)?; - - if dot_product::( - sub_evals.iter().copied(), - outer_selector_evals.iter().copied(), - ) != dot_product::( - inner_sums.iter().copied(), - eval_eq(&columns_batching_scalars).iter().copied(), - ) { - return Err(ProofError::InvalidProof); - } - - let epsilons = MultilinearPoint(verifier_state.sample_vec(univariate_skips)); - let common_point = MultilinearPoint( - [ - epsilons.0.clone(), - outer_sumcheck_point[1..log_n_rows - univariate_skips + 1].to_vec(), - ] - .concat(), - ); - - let evaluations_remaining_to_verify = verifier_state.next_extension_scalars_vec(n_columns)?; - - if sub_evals.evaluate(&epsilons) - != dot_product( - eval_eq(&columns_batching_scalars).into_iter(), - evaluations_remaining_to_verify.iter().copied(), - ) - { - return Err(ProofError::InvalidProof); - } - Ok((common_point, evaluations_remaining_to_verify)) + open_columns( + verifier_state, + table.n_columns(), + univariate_skips, + &table.columns_with_shift(), + inner_sums, + &Evaluation::new(outer_statement.point[1..].to_vec(), outer_statement.value), + &outer_selector_evals, + log_n_rows, + last_row, + ) } #[allow(clippy::too_many_arguments)] // TODO -fn verify_structured_columns>>( +fn open_columns>>( verifier_state: &mut FSVerifier>, n_columns: usize, univariate_skips: usize, - all_inner_sums: &[EF], + columns_with_shift: &[usize], + mut evals_up_and_down: Vec, outer_sumcheck_challenge: &Evaluation, outer_selector_evals: &[EF], log_n_rows: usize, + last_row: &[EF], ) -> Result<(MultilinearPoint, Vec), ProofError> { - let columns_batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns)); - let alpha = verifier_state.sample(); + assert_eq!(n_columns + last_row.len(), evals_up_and_down.len()); + let last_row_selector = outer_selector_evals[(1 << univariate_skips) - 1] + * outer_sumcheck_challenge + .point + .iter() + .copied() + .product::(); + for (&last_row_value, down_col_eval) in last_row.iter().zip(&mut evals_up_and_down[n_columns..]) + { + *down_col_eval -= last_row_value * last_row_selector; + } - let poly_eq_batching_scalars = eval_eq(&columns_batching_scalars); + let batching_scalars = verifier_state.sample_vec(log2_ceil_usize(n_columns + last_row.len())); - let all_witness_up = &all_inner_sums[..n_columns]; - let all_witness_down = &all_inner_sums[n_columns..]; + let eval_eq_batching_scalars = eval_eq(&batching_scalars); + let batching_scalars_up = &eval_eq_batching_scalars[..n_columns]; + let batching_scalars_down = &eval_eq_batching_scalars[n_columns..]; let sub_evals = verifier_state.next_extension_scalars_vec(1 << univariate_skips)?; @@ -168,14 +112,9 @@ fn verify_structured_columns>>( sub_evals.iter().copied(), outer_selector_evals.iter().copied(), ) != dot_product::( - all_witness_up.iter().copied(), - poly_eq_batching_scalars.iter().copied(), - ) * alpha - + dot_product::( - all_witness_down.iter().copied(), - poly_eq_batching_scalars.iter().copied(), - ) - { + evals_up_and_down.iter().copied(), + eval_eq_batching_scalars.iter().copied(), + ) { return Err(ProofError::InvalidProof); } @@ -187,24 +126,31 @@ fn verify_structured_columns>>( return Err(ProofError::InvalidProof); } - let matrix_lde_point = [ - epsilons.0, - outer_sumcheck_challenge.point.to_vec(), - inner_sumcheck_stement.point.0.clone(), - ] - .concat(); - let up = matrix_up_lde(&matrix_lde_point); - let down = matrix_down_lde(&matrix_lde_point); - - let final_value = inner_sumcheck_stement.value / (up * alpha + down); + let matrix_up_sc_eval = + MultilinearPoint([epsilons.0.clone(), outer_sumcheck_challenge.point.0.clone()].concat()) + .eq_poly_outside(&inner_sumcheck_stement.point); + let matrix_down_sc_eval = next_mle( + &[ + epsilons.0, + outer_sumcheck_challenge.point.to_vec(), + inner_sumcheck_stement.point.0.clone(), + ] + .concat(), + ); let evaluations_remaining_to_verify = verifier_state.next_extension_scalars_vec(n_columns)?; - if final_value - != dot_product( - eval_eq(&columns_batching_scalars).into_iter(), - evaluations_remaining_to_verify.iter().copied(), - ) + let batched_col_up_sc_eval = dot_product::( + batching_scalars_up.iter().copied(), + evaluations_remaining_to_verify.iter().copied(), + ); + let batched_col_down_sc_eval = (0..columns_with_shift.len()) + .map(|i| evaluations_remaining_to_verify[columns_with_shift[i]] * batching_scalars_down[i]) + .sum::(); + + if inner_sumcheck_stement.value + != matrix_up_sc_eval * batched_col_up_sc_eval + + matrix_down_sc_eval * batched_col_down_sc_eval { return Err(ProofError::InvalidProof); } diff --git a/crates/lean_compiler/Cargo.toml b/crates/lean_compiler/Cargo.toml index edc4be6f..b394494d 100644 --- a/crates/lean_compiler/Cargo.toml +++ b/crates/lean_compiler/Cargo.toml @@ -10,22 +10,18 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +sub_protocols.workspace = true lookup.workspace = true lean_vm.workspace = true multilinear-toolkit.workspace = true diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index 1b26315f..b172b538 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -1,6 +1,6 @@ use crate::{F, a_simplify_lang::*, ir::*, lang::*, precompiles::*}; use lean_vm::*; -use p3_field::Field; +use multilinear_toolkit::prelude::*; use std::{ borrow::Borrow, collections::{BTreeMap, BTreeSet}, @@ -532,7 +532,7 @@ fn compile_lines( SimpleLine::FunctionRet { return_data } => { if compiler.func_name == "main" { - // pC -> ending_pc, fp -> 0 + // pc -> ending_pc, fp -> 0 let zero_value_offset = IntermediateValue::MemoryAfterFp { offset: compiler.stack_size.into(), }; diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 350b9db7..1dcc450c 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -1,6 +1,6 @@ use crate::{F, NONRESERVED_PROGRAM_INPUT_START, ZERO_VEC_PTR, ir::*, lang::*}; use lean_vm::*; -use p3_field::{PrimeCharacteristicRing, PrimeField32}; +use multilinear_toolkit::prelude::*; use std::collections::BTreeMap; use utils::ToUsize; @@ -317,7 +317,7 @@ fn compile_block( res, size, } => { - low_level_bytecode.push(Instruction::DotProductExtensionExtension { + low_level_bytecode.push(Instruction::DotProduct { arg0: arg0.try_into_mem_or_constant(compiler).unwrap(), arg1: arg1.try_into_mem_or_constant(compiler).unwrap(), res: res.try_into_mem_or_fp(compiler).unwrap(), diff --git a/crates/lean_compiler/src/ir/operation.rs b/crates/lean_compiler/src/ir/operation.rs index d829e5e0..da8660b6 100644 --- a/crates/lean_compiler/src/ir/operation.rs +++ b/crates/lean_compiler/src/ir/operation.rs @@ -1,7 +1,6 @@ use crate::F; use lean_vm::Operation; -use p3_field::PrimeCharacteristicRing; -use p3_field::PrimeField64; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; use utils::ToUsize; diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 3514c509..5690b8db 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -1,5 +1,5 @@ use lean_vm::*; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use p3_util::log2_ceil_usize; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; diff --git a/crates/lean_compiler/src/parser/parsers/literal.rs b/crates/lean_compiler/src/parser/parsers/literal.rs index e3a1d218..aac154af 100644 --- a/crates/lean_compiler/src/parser/parsers/literal.rs +++ b/crates/lean_compiler/src/parser/parsers/literal.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use utils::ToUsize; /// Parser for constant declarations. diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index 95ddc5e6..d8189e2a 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -10,22 +10,18 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +sub_protocols.workspace = true lookup.workspace = true lean_vm.workspace = true lean_compiler.workspace = true diff --git a/crates/lean_prover/src/common.rs b/crates/lean_prover/src/common.rs index e10eeeb4..92b268d5 100644 --- a/crates/lean_prover/src/common.rs +++ b/crates/lean_prover/src/common.rs @@ -1,10 +1,9 @@ use multilinear_toolkit::prelude::*; -use p3_field::{Algebra, BasedVectorSpace}; -use p3_field::{ExtensionField, PrimeCharacteristicRing}; use p3_koala_bear::{KOALABEAR_RC16_INTERNAL, KOALABEAR_RC24_INTERNAL}; use p3_util::log2_ceil_usize; -use packed_pcs::ColDims; -use poseidon_circuit::{PoseidonGKRLayers, default_cube_layers}; +use poseidon_circuit::{GKRPoseidonResult, PoseidonGKRLayers, default_cube_layers}; +use sub_protocols::{ColDims, committed_dims_extension_from_base}; +use vm_air::*; use crate::*; use lean_vm::*; @@ -61,7 +60,39 @@ pub fn get_base_dims( ColDims::padded(n_rows_table_dot_products, F::ZERO), // dot product: index b ColDims::padded(n_rows_table_dot_products, F::ZERO), // dot product: index res ], - vec![ColDims::padded(n_rows_table_dot_products, F::ZERO); DIMENSION], // dot product: computation + committed_dims_extension_from_base(n_rows_table_dot_products, EF::ZERO), // dot product: computation + ] + .concat() +} + +pub fn normal_lookup_into_memory_initial_statements( + exec_air_point: &MultilinearPoint, + exec_evals: &[EF], + dot_product_air_point: &MultilinearPoint, + dot_product_evals: &[EF], +) -> Vec>> { + [ + [ + COL_INDEX_MEM_VALUE_A, + COL_INDEX_MEM_VALUE_B, + COL_INDEX_MEM_VALUE_C, + ] + .into_iter() + .map(|index| vec![Evaluation::new(exec_air_point.clone(), exec_evals[index])]) + .collect::>(), + [ + DOT_PRODUCT_AIR_COL_VALUE_A, + DOT_PRODUCT_AIR_COL_VALUE_B, + DOT_PRODUCT_AIR_COL_VALUE_RES, + ] + .into_iter() + .map(|index| { + vec![Evaluation::new( + dot_product_air_point.clone(), + dot_product_evals[index], + )] + }) + .collect::>(), ] .concat() } @@ -147,312 +178,64 @@ pub fn add_memory_statements_for_dot_product_precompile( Ok(()) } -pub struct PrecompileFootprint { - pub global_challenge: EF, - pub fingerprint_challenge_powers: [EF; 5], -} - -const PRECOMP_INDEX_OPERAND_A: usize = 0; -const PRECOMP_INDEX_OPERAND_B: usize = 1; -const PRECOMP_INDEX_FLAG_A: usize = 2; -const PRECOMP_INDEX_FLAG_B: usize = 3; -const PRECOMP_INDEX_FLAG_C: usize = 4; -const PRECOMP_INDEX_AUX: usize = 5; -const PRECOMP_INDEX_POSEIDON_16: usize = 6; -const PRECOMP_INDEX_POSEIDON_24: usize = 7; -const PRECOMP_INDEX_DOT_PRODUCT: usize = 8; -const PRECOMP_INDEX_MULTILINEAR_EVAL: usize = 9; -const PRECOMP_INDEX_MEM_VALUE_A: usize = 10; -const PRECOMP_INDEX_MEM_VALUE_B: usize = 11; -const PRECOMP_INDEX_MEM_VALUE_C: usize = 12; -const PRECOMP_INDEX_FP: usize = 13; - -pub fn reorder_full_trace_for_precomp_foot_print(full_trace: Vec) -> Vec { - assert_eq!(full_trace.len(), N_TOTAL_COLUMNS); - vec![ - full_trace[COL_INDEX_OPERAND_A], - full_trace[COL_INDEX_OPERAND_B], - full_trace[COL_INDEX_FLAG_A], - full_trace[COL_INDEX_FLAG_B], - full_trace[COL_INDEX_FLAG_C], - full_trace[COL_INDEX_AUX], - full_trace[COL_INDEX_POSEIDON_16], - full_trace[COL_INDEX_POSEIDON_24], - full_trace[COL_INDEX_DOT_PRODUCT], - full_trace[COL_INDEX_MULTILINEAR_EVAL], - full_trace[COL_INDEX_MEM_VALUE_A], - full_trace[COL_INDEX_MEM_VALUE_B], - full_trace[COL_INDEX_MEM_VALUE_C], - full_trace[COL_INDEX_FP], +pub fn default_poseidon_indexes() -> Vec { + [ + vec![ + ZERO_VEC_PTR, + ZERO_VEC_PTR, + POSEIDON_16_NULL_HASH_PTR, + if POSEIDON_16_DEFAULT_COMPRESSION { + ZERO_VEC_PTR + } else { + POSEIDON_16_NULL_HASH_PTR + 1 + }, + ], + vec![ + ZERO_VEC_PTR, + ZERO_VEC_PTR, + ZERO_VEC_PTR, + POSEIDON_24_NULL_HASH_PTR, + ], ] + .concat() } -impl PrecompileFootprint { - fn air_eval< - PointF: PrimeCharacteristicRing + Copy, - ResultF: Algebra + Algebra + Copy, - >( - &self, - point: &[PointF], - mul_point_f_and_ef: impl Fn(PointF, EF) -> ResultF, - ) -> ResultF { - let nu_a = (ResultF::ONE - point[PRECOMP_INDEX_FLAG_A]) * point[PRECOMP_INDEX_MEM_VALUE_A] - + point[PRECOMP_INDEX_FLAG_A] * point[PRECOMP_INDEX_OPERAND_A]; - let nu_b = (ResultF::ONE - point[PRECOMP_INDEX_FLAG_B]) * point[PRECOMP_INDEX_MEM_VALUE_B] - + point[PRECOMP_INDEX_FLAG_B] * point[PRECOMP_INDEX_OPERAND_B]; - let nu_c = (ResultF::ONE - point[PRECOMP_INDEX_FLAG_C]) * point[PRECOMP_INDEX_MEM_VALUE_C] - + point[PRECOMP_INDEX_FLAG_C] * point[PRECOMP_INDEX_FP]; - - (nu_a * self.fingerprint_challenge_powers[1] - + nu_b * self.fingerprint_challenge_powers[2] - + nu_c * self.fingerprint_challenge_powers[3] - + mul_point_f_and_ef( - point[PRECOMP_INDEX_AUX], - self.fingerprint_challenge_powers[4], - ) - + PointF::from_usize(TABLE_INDEX_POSEIDONS_16)) - * point[PRECOMP_INDEX_POSEIDON_16] - + (nu_a * self.fingerprint_challenge_powers[1] - + nu_b * self.fingerprint_challenge_powers[2] - + nu_c * self.fingerprint_challenge_powers[3] - + PointF::from_usize(TABLE_INDEX_POSEIDONS_24)) - * point[PRECOMP_INDEX_POSEIDON_24] - + (nu_a * self.fingerprint_challenge_powers[1] - + nu_b * self.fingerprint_challenge_powers[2] - + nu_c * self.fingerprint_challenge_powers[3] - + mul_point_f_and_ef( - point[PRECOMP_INDEX_AUX], - self.fingerprint_challenge_powers[4], - ) - + PointF::from_usize(TABLE_INDEX_DOT_PRODUCTS)) - * point[PRECOMP_INDEX_DOT_PRODUCT] - + (nu_a * self.fingerprint_challenge_powers[1] - + nu_b * self.fingerprint_challenge_powers[2] - + nu_c * self.fingerprint_challenge_powers[3] - + mul_point_f_and_ef( - point[PRECOMP_INDEX_AUX], - self.fingerprint_challenge_powers[4], - ) - + PointF::from_usize(TABLE_INDEX_MULTILINEAR_EVAL)) - * point[PRECOMP_INDEX_MULTILINEAR_EVAL] - + self.global_challenge - } -} - -impl> SumcheckComputation for PrecompileFootprint -where - EF: ExtensionField, -{ - fn degree(&self) -> usize { - 3 - } - fn eval(&self, point: &[N], _: &[EF]) -> EF { - self.air_eval(point, |p, c| c * p) - } -} - -impl SumcheckComputationPacked for PrecompileFootprint { - fn degree(&self) -> usize { - 3 - } - - fn eval_packed_extension(&self, point: &[EFPacking], _: &[EF]) -> EFPacking { - self.air_eval(point, |p, c| p * c) - } - - fn eval_packed_base(&self, point: &[PFPacking], _: &[EF]) -> EFPacking { - self.air_eval(point, |p, c| EFPacking::::from(p) * c) - } -} - -pub struct DotProductFootprint { - pub global_challenge: EF, - pub fingerprint_challenge_powers: [EF; 5], -} - -impl DotProductFootprint { - fn air_eval< - PointF: PrimeCharacteristicRing + Copy, - ResultF: Algebra + Algebra + Copy, - >( - &self, - point: &[PointF], - mul_point_f_and_ef: impl Fn(PointF, EF) -> ResultF, - ) -> ResultF { - ResultF::from_usize(TABLE_INDEX_DOT_PRODUCTS) - + (mul_point_f_and_ef(point[2], self.fingerprint_challenge_powers[1]) - + mul_point_f_and_ef(point[3], self.fingerprint_challenge_powers[2]) - + mul_point_f_and_ef(point[4], self.fingerprint_challenge_powers[3]) - + mul_point_f_and_ef(point[1], self.fingerprint_challenge_powers[4])) - * point[0] - + self.global_challenge - } -} - -impl>> SumcheckComputation for DotProductFootprint -where - EF: ExtensionField, -{ - fn degree(&self) -> usize { - 2 - } - - fn eval(&self, point: &[N], _: &[EF]) -> EF { - self.air_eval(point, |p, c| c * p) - } -} - -impl SumcheckComputationPacked for DotProductFootprint { - fn degree(&self) -> usize { - 2 - } - - fn eval_packed_extension(&self, point: &[EFPacking], _: &[EF]) -> EFPacking { - self.air_eval(point, |p, c| p * c) - } - fn eval_packed_base(&self, point: &[PFPacking], _: &[EF]) -> EFPacking { - self.air_eval(point, |p, c| EFPacking::::from(p) * c) - } -} - -pub fn get_poseidon_lookup_statements( - (log_n_p16, log_n_p24): (usize, usize), - (p16_input_point, p16_input_evals): &(MultilinearPoint, Vec), - (p16_output_point, p16_output_evals): &(MultilinearPoint, Vec), - (p24_input_point, p24_input_evals): &(MultilinearPoint, Vec), - (p24_output_point, p24_output_evals): &(MultilinearPoint, Vec), - memory_folding_challenges: &MultilinearPoint, -) -> Vec> { - let p16_folded_eval_addr_a = (&p16_input_evals[..8]).evaluate(memory_folding_challenges); - let p16_folded_eval_addr_b = (&p16_input_evals[8..16]).evaluate(memory_folding_challenges); - let p16_folded_eval_addr_res_a = (&p16_output_evals[..8]).evaluate(memory_folding_challenges); - let p16_folded_eval_addr_res_b = (&p16_output_evals[8..16]).evaluate(memory_folding_challenges); - - let p24_folded_eval_addr_a = (&p24_input_evals[..8]).evaluate(memory_folding_challenges); - let p24_folded_eval_addr_b = (&p24_input_evals[8..16]).evaluate(memory_folding_challenges); - let p24_folded_eval_addr_c = (&p24_input_evals[16..24]).evaluate(memory_folding_challenges); - let p24_folded_eval_addr_res = (&p24_output_evals[16..24]).evaluate(memory_folding_challenges); - - let padding_p16 = EF::zero_vec(log_n_p16.max(log_n_p24) - log_n_p16); - let padding_p24 = EF::zero_vec(log_n_p16.max(log_n_p24) - log_n_p24); - +pub fn poseidon_lookup_statements( + p16_gkr: &GKRPoseidonResult, + p24_gkr: &GKRPoseidonResult, +) -> Vec>> { vec![ - Evaluation::new( - [ - vec![EF::ZERO; 3], - padding_p16.clone(), - p16_input_point.0.clone(), - ] - .concat(), - p16_folded_eval_addr_a, - ), - Evaluation::new( - [ - vec![EF::ZERO, EF::ZERO, EF::ONE], - padding_p16.clone(), - p16_input_point.0.clone(), - ] - .concat(), - p16_folded_eval_addr_b, - ), - Evaluation::new( - [ - vec![EF::ZERO, EF::ONE, EF::ZERO], - padding_p16.clone(), - p16_output_point.0.clone(), - ] - .concat(), - p16_folded_eval_addr_res_a, - ), - Evaluation::new( - [ - vec![EF::ZERO, EF::ONE, EF::ONE], - padding_p16.clone(), - p16_output_point.0.clone(), - ] - .concat(), - p16_folded_eval_addr_res_b, - ), - Evaluation::new( - [ - vec![EF::ONE, EF::ZERO, EF::ZERO], - padding_p24.clone(), - p24_input_point.0.clone(), - ] - .concat(), - p24_folded_eval_addr_a, - ), - Evaluation::new( - [ - vec![EF::ONE, EF::ZERO, EF::ONE], - padding_p24.clone(), - p24_input_point.0.clone(), - ] - .concat(), - p24_folded_eval_addr_b, - ), - Evaluation::new( - [ - vec![EF::ONE, EF::ONE, EF::ZERO], - padding_p24.clone(), - p24_input_point.0.clone(), - ] - .concat(), - p24_folded_eval_addr_c, - ), - Evaluation::new( - [ - vec![EF::ONE, EF::ONE, EF::ONE], - padding_p24.clone(), - p24_output_point.0.clone(), - ] - .concat(), - p24_folded_eval_addr_res, - ), + vec![MultiEvaluation::new( + p16_gkr.input_statements.point.clone(), + p16_gkr.input_statements.values[..VECTOR_LEN].to_vec(), + )], + vec![MultiEvaluation::new( + p16_gkr.input_statements.point.clone(), + p16_gkr.input_statements.values[VECTOR_LEN..].to_vec(), + )], + vec![MultiEvaluation::new( + p16_gkr.output_statements.point.clone(), + p16_gkr.output_statements.values[..VECTOR_LEN].to_vec(), + )], + vec![MultiEvaluation::new( + p16_gkr.output_statements.point.clone(), + p16_gkr.output_statements.values[VECTOR_LEN..].to_vec(), + )], + vec![MultiEvaluation::new( + p24_gkr.input_statements.point.clone(), + p24_gkr.input_statements.values[..VECTOR_LEN].to_vec(), + )], + vec![MultiEvaluation::new( + p24_gkr.input_statements.point.clone(), + p24_gkr.input_statements.values[VECTOR_LEN..VECTOR_LEN * 2].to_vec(), + )], + vec![MultiEvaluation::new( + p24_gkr.input_statements.point.clone(), + p24_gkr.input_statements.values[VECTOR_LEN * 2..].to_vec(), + )], + vec![MultiEvaluation::new( + p24_gkr.output_statements.point.clone(), + p24_gkr.output_statements.values[VECTOR_LEN * 2..].to_vec(), + )], ] } - -pub fn poseidon_lookup_correcting_factors( - log_n_p16: usize, - log_n_p24: usize, - index_lookup_point: &MultilinearPoint, -) -> (EF, EF) { - let correcting_factor = index_lookup_point[3..3 + log_n_p16.abs_diff(log_n_p24)] - .iter() - .map(|&x| EF::ONE - x) - .product::(); - if log_n_p16 > log_n_p24 { - (EF::ONE, correcting_factor) - } else { - (correcting_factor, EF::ONE) - } -} - -pub fn add_poseidon_lookup_statements_on_indexes( - log_n_p16: usize, - log_n_p24: usize, - index_lookup_point: &MultilinearPoint, - inner_values: &[EF], - p16_index_statements: [&mut Vec>; 3], // input_a, input_b, res_a - p24_index_statements: [&mut Vec>; 3], // input_a, input_b, res -) { - assert_eq!(inner_values.len(), 6); - let mut idx_point_right_p16 = MultilinearPoint(index_lookup_point[3..].to_vec()); - let mut idx_point_right_p24 = - MultilinearPoint(index_lookup_point[3 + log_n_p16.abs_diff(log_n_p24)..].to_vec()); - if log_n_p16 < log_n_p24 { - std::mem::swap(&mut idx_point_right_p16, &mut idx_point_right_p24); - } - for (i, stmt) in p16_index_statements.into_iter().enumerate() { - stmt.push(Evaluation::new( - idx_point_right_p16.clone(), - inner_values[i], - )); - } - for (i, stmt) in p24_index_statements.into_iter().enumerate() { - stmt.push(Evaluation::new( - idx_point_right_p24.clone(), - inner_values[i + 3], - )); - } -} diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index cbc14b9d..f54c8b88 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -12,8 +12,12 @@ pub mod prove_execution; pub mod verify_execution; const UNIVARIATE_SKIPS: usize = 3; +const TWO_POW_UNIVARIATE_SKIPS: usize = 1 << UNIVARIATE_SKIPS; const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 8; // TODO optimize +const DOT_PRODUCT_UNIVARIATE_SKIPS: usize = 1; +const TWO_POW_DOT_PRODUCT_UNIVARIATE_SKIPS: usize = 1 << DOT_PRODUCT_UNIVARIATE_SKIPS; + pub fn whir_config_builder() -> WhirConfigBuilder { WhirConfigBuilder { folding_factor: FoldingFactor::new(7, 4), @@ -25,8 +29,3 @@ pub fn whir_config_builder() -> WhirConfigBuilder { starting_log_inv_rate: 1, } } - -const TABLE_INDEX_POSEIDONS_16: usize = 1; // should be != 0 -const TABLE_INDEX_POSEIDONS_24: usize = 2; -const TABLE_INDEX_DOT_PRODUCTS: usize = 3; -const TABLE_INDEX_MULTILINEAR_EVAL: usize = 4; diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 24c92a66..3b1b7acf 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -1,19 +1,17 @@ +use std::array; + use crate::common::*; use crate::*; -use ::air::table::AirTable; +use ::air::AirTable; +use air::prove_air; use lean_vm::*; use lookup::prove_gkr_product; use lookup::{compute_pushforward, prove_logup_star}; use multilinear_toolkit::prelude::*; -use p3_field::ExtensionField; -use p3_field::Field; -use p3_field::PrimeCharacteristicRing; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use packed_pcs::*; use poseidon_circuit::{PoseidonGKRLayers, prove_poseidon_gkr}; +use sub_protocols::*; use tracing::info_span; -use utils::ToUsize; -use utils::dot_product_with_base; use utils::field_slice_as_base; use utils::{build_prover_state, padd_with_zero_to_next_power_of_two}; use vm_air::*; @@ -76,18 +74,6 @@ pub fn prove_execution( precompute_dft_twiddles::(1 << 24); - let mut exec_columns = full_trace[..N_INSTRUCTION_COLUMNS_IN_AIR] - .iter() - .map(Vec::as_slice) - .collect::>(); - exec_columns.extend( - full_trace[N_INSTRUCTION_COLUMNS..] - .iter() - .map(Vec::as_slice) - .collect::>(), - ); - let exec_table = AirTable::::new(VMAir); - let _validity_proof_span = info_span!("Validity proof generation").entered(); let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); @@ -97,19 +83,23 @@ pub fn prove_execution( generate_poseidon_witness_helper(&p16_gkr_layers, &poseidons_16, Some(n_compressions_16)); let p24_witness = generate_poseidon_witness_helper(&p24_gkr_layers, &poseidons_24, None); - let dot_product_table = AirTable::::new(DotProductAir); - let (dot_product_columns, dot_product_padding_len) = build_dot_product_columns(&dot_products, 1 << LOG_MIN_DOT_PRODUCT_ROWS); + let dot_product_col_index_a = + field_slice_as_base(&dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_A]).unwrap(); + let dot_product_col_index_b = + field_slice_as_base(&dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_B]).unwrap(); + let dot_product_col_index_res = + field_slice_as_base(&dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_RES]).unwrap(); let dot_product_flags: Vec> = field_slice_as_base(&dot_product_columns[DOT_PRODUCT_AIR_COL_START_FLAG]).unwrap(); let dot_product_lengths: Vec> = field_slice_as_base(&dot_product_columns[DOT_PRODUCT_AIR_COL_LEN]).unwrap(); let dot_product_computations: &[EF] = &dot_product_columns[DOT_PRODUCT_AIR_COL_COMPUTATION]; - let dot_product_computations_base = - transpose_slice_to_basis_coefficients::(dot_product_computations); + let dot_product_computation_ext_to_base_helper = + ExtensionCommitmentFromBaseProver::before_commitment(dot_product_computations); let n_rows_table_dot_products = dot_product_flags.len() - dot_product_padding_len; let log_n_rows_dot_product_table = log2_strict_usize(dot_product_flags.len()); @@ -153,8 +143,18 @@ pub fn prove_execution( ) .unwrap(); } - let p16_indexes = all_poseidon_16_indexes(&poseidons_16); - let p24_indexes = all_poseidon_24_indexes(&poseidons_24); + let [ + p16_indexes_input_a, + p16_indexes_input_b, + p16_indexes_output, + p16_indexes_output_shifted, // = if compressed { 0 } else { p16_indexes_output + 1 } + ] = all_poseidon_16_indexes(&poseidons_16); + let [ + p24_indexes_input_a, + p24_indexes_input_a_shifted, // = p24_indexes_input_a + 1 + p24_indexes_input_b, + p24_indexes_output, + ] = all_poseidon_24_indexes(&poseidons_24); let base_dims = get_base_dims( n_cycles, @@ -166,10 +166,6 @@ pub fn prove_execution( (&p16_gkr_layers, &p24_gkr_layers), ); - let dot_product_col_index_a = field_slice_as_base(&dot_product_columns[2]).unwrap(); - let dot_product_col_index_b = field_slice_as_base(&dot_product_columns[3]).unwrap(); - let dot_product_col_index_res = field_slice_as_base(&dot_product_columns[4]).unwrap(); - let base_pols = [ vec![ memory.as_slice(), @@ -179,8 +175,16 @@ pub fn prove_execution( full_trace[COL_INDEX_MEM_ADDRESS_B].as_slice(), full_trace[COL_INDEX_MEM_ADDRESS_C].as_slice(), ], - p16_indexes.iter().map(Vec::as_slice).collect::>(), - p24_indexes.iter().map(Vec::as_slice).collect::>(), + vec![ + &p16_indexes_input_a, + &p16_indexes_input_b, + &p16_indexes_output, + ], + vec![ + &p24_indexes_input_a, + &p24_indexes_input_b, + &p24_indexes_output, + ], p16_witness .committed_cubes .iter() @@ -198,7 +202,8 @@ pub fn prove_execution( dot_product_col_index_b.as_slice(), dot_product_col_index_res.as_slice(), ], - dot_product_computations_base + dot_product_computation_ext_to_base_helper + .sub_columns_to_commit .iter() .map(Vec::as_slice) .collect(), @@ -258,7 +263,10 @@ pub fn prove_execution( } let (grand_product_exec_res, grand_product_exec_statement) = - prove_gkr_product(&mut prover_state, &exec_column_for_grand_product); + prove_gkr_product::<_, TWO_POW_UNIVARIATE_SKIPS>( + &mut prover_state, + &exec_column_for_grand_product, + ); let p16_column_for_grand_product = poseidons_16 .par_iter() @@ -273,7 +281,7 @@ pub fn prove_execution( .collect::>(); let (grand_product_p16_res, grand_product_p16_statement) = - prove_gkr_product(&mut prover_state, &p16_column_for_grand_product); + prove_gkr_product::<_, 2>(&mut prover_state, &p16_column_for_grand_product); let p24_column_for_grand_product = poseidons_24 .par_iter() @@ -288,7 +296,7 @@ pub fn prove_execution( .collect::>(); let (grand_product_p24_res, grand_product_p24_statement) = - prove_gkr_product(&mut prover_state, &p24_column_for_grand_product); + prove_gkr_product::<_, 2>(&mut prover_state, &p24_column_for_grand_product); let dot_product_column_for_grand_product = (0..1 << log_n_rows_dot_product_table) .into_par_iter() @@ -324,7 +332,10 @@ pub fn prove_execution( .product::(); let (grand_product_dot_product_res, grand_product_dot_product_statement) = - prove_gkr_product(&mut prover_state, &dot_product_column_for_grand_product); + prove_gkr_product::<_, TWO_POW_DOT_PRODUCT_UNIVARIATE_SKIPS>( + &mut prover_state, + &dot_product_column_for_grand_product, + ); let corrected_prod_exec = grand_product_exec_res / grand_product_challenge_global.exp_u64( @@ -380,11 +391,11 @@ pub fn prove_execution( ); let p16_grand_product_evals_on_indexes_a = - p16_indexes[0].evaluate(&grand_product_p16_statement.point); + p16_indexes_input_a.evaluate(&grand_product_p16_statement.point); let p16_grand_product_evals_on_indexes_b = - p16_indexes[1].evaluate(&grand_product_p16_statement.point); + p16_indexes_input_b.evaluate(&grand_product_p16_statement.point); let p16_grand_product_evals_on_indexes_res = - p16_indexes[2].evaluate(&grand_product_p16_statement.point); + p16_indexes_output.evaluate(&grand_product_p16_statement.point); prover_state.add_extension_scalars(&[ p16_grand_product_evals_on_indexes_a, @@ -406,11 +417,11 @@ pub fn prove_execution( )]; let p24_grand_product_evals_on_indexes_a = - p24_indexes[0].evaluate(&grand_product_p24_statement.point); + p24_indexes_input_a.evaluate(&grand_product_p24_statement.point); let p24_grand_product_evals_on_indexes_b = - p24_indexes[1].evaluate(&grand_product_p24_statement.point); + p24_indexes_input_b.evaluate(&grand_product_p24_statement.point); let p24_grand_product_evals_on_indexes_res = - p24_indexes[2].evaluate(&grand_product_p24_statement.point); + p24_indexes_output.evaluate(&grand_product_p24_statement.point); prover_state.add_extension_scalars(&[ p24_grand_product_evals_on_indexes_a, p24_grand_product_evals_on_indexes_b, @@ -430,128 +441,39 @@ pub fn prove_execution( p24_grand_product_evals_on_indexes_res, )]; - let dot_product_footprint_computation = DotProductFootprint { + let exec_table = AirTable::::new(VMAir { global_challenge: grand_product_challenge_global, fingerprint_challenge_powers: powers_const(fingerprint_challenge), - }; - - let ( - grand_product_dot_product_sumcheck_point, - grand_product_dot_product_sumcheck_inner_evals, - _, - ) = info_span!("Grand product sumcheck for Dot Product").in_scope(|| { - sumcheck_prove( - 1, - MleGroupRef::Extension( - dot_product_columns[..5] - .iter() - .map(|c| c.as_slice()) - .collect::>(), - ), // we do not use packing here because it's slower in practice (this sumcheck is small) - &dot_product_footprint_computation, - &[], - Some((grand_product_dot_product_statement.point.0.clone(), None)), - false, + }); + let (exec_air_point, exec_evals_to_prove) = info_span!("Execution AIR proof").in_scope(|| { + prove_air( &mut prover_state, - grand_product_dot_product_statement.value, - None, + &exec_table, + UNIVARIATE_SKIPS, + &full_trace.iter().map(Vec::as_slice).collect::>(), + &execution_air_padding_row(bytecode.ending_pc), + Some(grand_product_exec_statement), ) }); - assert_eq!(grand_product_dot_product_sumcheck_inner_evals.len(), 5); - prover_state.add_extension_scalars(&grand_product_dot_product_sumcheck_inner_evals); - - let grand_product_dot_product_flag_statement = Evaluation::new( - grand_product_dot_product_sumcheck_point.clone(), - grand_product_dot_product_sumcheck_inner_evals[0], - ); - let grand_product_dot_product_len_statement = Evaluation::new( - grand_product_dot_product_sumcheck_point.clone(), - grand_product_dot_product_sumcheck_inner_evals[1], - ); - - let grand_product_dot_product_table_indexes_statement_index_a = Evaluation::new( - grand_product_dot_product_sumcheck_point.clone(), - grand_product_dot_product_sumcheck_inner_evals[2], - ); - let grand_product_dot_product_table_indexes_statement_index_b = Evaluation::new( - grand_product_dot_product_sumcheck_point.clone(), - grand_product_dot_product_sumcheck_inner_evals[3], - ); - let grand_product_dot_product_table_indexes_statement_index_res = Evaluation::new( - grand_product_dot_product_sumcheck_point.clone(), - grand_product_dot_product_sumcheck_inner_evals[4], - ); - - let precompile_foot_print_computation = PrecompileFootprint { + let dot_product_table = AirTable::::new(DotProductAir { global_challenge: grand_product_challenge_global, fingerprint_challenge_powers: powers_const(fingerprint_challenge), - }; - - let (grand_product_exec_sumcheck_point, mut grand_product_exec_sumcheck_inner_evals, _) = - info_span!("Grand product sumcheck for Execution").in_scope(|| { - sumcheck_prove( - 1, // TODO univariate skip - MleGroupRef::Base( - reorder_full_trace_for_precomp_foot_print( - full_trace.iter().collect::>(), - ) - .iter() - .map(|c| c.as_slice()) - .collect::>(), - ) - .pack(), - &precompile_foot_print_computation, - &[], - Some((grand_product_exec_statement.point.0.clone(), None)), - false, - &mut prover_state, - grand_product_exec_statement.value, - None, - ) - }); - - // TODO compute eq polynomial 1 time and then inner product with each column - for col in [ - COL_INDEX_OPERAND_C, - COL_INDEX_ADD, - COL_INDEX_MUL, - COL_INDEX_DEREF, - COL_INDEX_JUMP, - COL_INDEX_PC, - COL_INDEX_MEM_ADDRESS_A, - COL_INDEX_MEM_ADDRESS_B, - COL_INDEX_MEM_ADDRESS_C, - ] { - grand_product_exec_sumcheck_inner_evals.insert( - col, - full_trace[col].evaluate(&grand_product_exec_sumcheck_point), - ); - } - assert_eq!( - N_TOTAL_COLUMNS, - grand_product_exec_sumcheck_inner_evals.len() - ); - prover_state.add_extension_scalars(&grand_product_exec_sumcheck_inner_evals); - - let grand_product_exec_evals_on_each_column = - &grand_product_exec_sumcheck_inner_evals[..N_INSTRUCTION_COLUMNS]; - - let grand_product_fp_statement = Evaluation::new( - grand_product_exec_sumcheck_point.clone(), - grand_product_exec_sumcheck_inner_evals[COL_INDEX_FP], - ); - - let (exec_air_point, exec_evals_to_prove) = info_span!("Execution AIR proof") - .in_scope(|| exec_table.prove_base(&mut prover_state, UNIVARIATE_SKIPS, &exec_columns)); - + }); let dot_product_columns_ref = dot_product_columns .iter() .map(Vec::as_slice) .collect::>(); let (dot_product_air_point, dot_product_evals_to_prove) = info_span!("DotProduct AIR proof") .in_scope(|| { - dot_product_table.prove_extension(&mut prover_state, 1, &dot_product_columns_ref) + prove_air( + &mut prover_state, + &dot_product_table, + DOT_PRODUCT_UNIVARIATE_SKIPS, + &dot_product_columns_ref, + &dot_product_air_padding_row(), + Some(grand_product_dot_product_statement), + ) }); let random_point_p16 = MultilinearPoint(prover_state.sample_vec(log_n_p16)); @@ -562,17 +484,6 @@ pub fn prove_execution( UNIVARIATE_SKIPS, &p16_gkr_layers, ); - let p16_cubes_statements = p16_gkr - .cubes_statements - .1 - .iter() - .map(|&e| { - vec![Evaluation { - point: p16_gkr.cubes_statements.0.clone(), - value: e, - }] - }) - .collect::>(); let random_point_p24 = MultilinearPoint(prover_state.sample_vec(log_n_p24)); let p24_gkr = prove_poseidon_gkr( @@ -582,56 +493,28 @@ pub fn prove_execution( UNIVARIATE_SKIPS, &p24_gkr_layers, ); - let p24_cubes_statements = p24_gkr - .cubes_statements - .1 - .iter() - .map(|&e| { - vec![Evaluation { - point: p24_gkr.cubes_statements.0.clone(), - value: e, - }] - }) - .collect::>(); - - // Poseidons 16/24 memory addresses lookup - let poseidon_logup_star_alpha = prover_state.sample(); - let memory_folding_challenges = MultilinearPoint(prover_state.sample_vec(LOG_VECTOR_LEN)); - - let poseidon_lookup_statements = get_poseidon_lookup_statements( - (log_n_p16, log_n_p24), - &p16_gkr.input_statements, - &(random_point_p16.clone(), p16_gkr.output_values), - &p24_gkr.input_statements, - &(random_point_p24.clone(), p24_gkr.output_values), - &memory_folding_challenges, - ); - let all_poseidon_indexes = full_poseidon_indexes_poly(&poseidons_16, &poseidons_24); - - let poseidon_folded_memory = fold_multilinear_chunks(&memory, &memory_folding_challenges); - - let mut poseidon_poly_eq_point = EF::zero_vec(all_poseidon_indexes.len()); - for (i, statement) in poseidon_lookup_statements.iter().enumerate() { - compute_sparse_eval_eq::( - &statement.point, - &mut poseidon_poly_eq_point, - poseidon_logup_star_alpha.exp_u64(i as u64), - ); - } - - let poseidon_pushforward = compute_pushforward( - &all_poseidon_indexes, - poseidon_folded_memory.len(), - &poseidon_poly_eq_point, - ); + let poseidon_value_columns = vec![ + array::from_fn(|i| FPacking::::unpack_slice(&p16_witness.input_layer[i])), + array::from_fn(|i| FPacking::::unpack_slice(&p16_witness.input_layer[i + VECTOR_LEN])), + array::from_fn(|i| { + FPacking::::unpack_slice(&p16_witness.compression.as_ref().unwrap().2[i]) + }), + array::from_fn(|i| { + FPacking::::unpack_slice( + &p16_witness.compression.as_ref().unwrap().2[i + VECTOR_LEN], + ) + }), + array::from_fn(|i| FPacking::::unpack_slice(&p24_witness.input_layer[i])), + array::from_fn(|i| FPacking::::unpack_slice(&p24_witness.input_layer[i + VECTOR_LEN])), + array::from_fn(|i| { + FPacking::::unpack_slice(&p24_witness.input_layer[i + VECTOR_LEN * 2]) + }), + array::from_fn(|i| { + FPacking::::unpack_slice(&p24_witness.output_layer[i + VECTOR_LEN * 2]) + }), + ]; - let non_used_precompiles_evals = full_trace - [N_INSTRUCTION_COLUMNS_IN_AIR..N_INSTRUCTION_COLUMNS] - .iter() - .map(|col| col.evaluate(&exec_air_point)) - .collect::>(); - prover_state.add_extension_scalars(&non_used_precompiles_evals); let bytecode_compression_challenges = MultilinearPoint(prover_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); @@ -639,228 +522,80 @@ pub fn prove_execution( let bytecode_lookup_claim_1 = Evaluation::new( exec_air_point.clone(), - padd_with_zero_to_next_power_of_two( - &[ - (0..N_INSTRUCTION_COLUMNS_IN_AIR) - .map(|i| exec_evals_to_prove[i]) - .collect::>(), - non_used_precompiles_evals, - ] - .concat(), - ) - .evaluate(&bytecode_compression_challenges), - ); - let bytecode_lookup_point_2 = grand_product_exec_sumcheck_point.clone(); - let bytecode_lookup_claim_2 = Evaluation::new( - bytecode_lookup_point_2.clone(), - padd_with_zero_to_next_power_of_two(grand_product_exec_evals_on_each_column) + padd_with_zero_to_next_power_of_two(&exec_evals_to_prove[..N_INSTRUCTION_COLUMNS]) .evaluate(&bytecode_compression_challenges), ); - let alpha_bytecode_lookup = prover_state.sample(); - - let mut bytecode_poly_eq_point = eval_eq(&exec_air_point); - compute_eval_eq::, EF, true>( - &bytecode_lookup_point_2, - &mut bytecode_poly_eq_point, - alpha_bytecode_lookup, - ); + let bytecode_poly_eq_point = eval_eq(&exec_air_point); let bytecode_pushforward = compute_pushforward( &full_trace[COL_INDEX_PC], folded_bytecode.len(), &bytecode_poly_eq_point, ); - let dot_product_table_length = dot_product_columns[0].len(); - assert!(dot_product_table_length.is_power_of_two()); - let mut dot_product_indexes_spread = vec![F::zero_vec(dot_product_table_length * 4); DIMENSION]; - for i in 0..dot_product_table_length { - let index_a: F = dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_A][i] - .as_base() - .unwrap(); - let index_b: F = dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_B][i] - .as_base() - .unwrap(); - let index_res: F = dot_product_columns[DOT_PRODUCT_AIR_COL_INDEX_RES][i] - .as_base() - .unwrap(); - for (j, column) in dot_product_indexes_spread.iter_mut().enumerate() { - column[i] = index_a + F::from_usize(j); - column[i + dot_product_table_length] = index_b + F::from_usize(j); - column[i + 2 * dot_product_table_length] = index_res + F::from_usize(j); - } - } - let dot_product_values_spread = dot_product_indexes_spread - .iter() - .map(|slice| { - slice - .par_iter() - .map(|i| memory[i.to_usize()]) - .collect::>() - }) - .collect::>(); - - let dot_product_values_mixing_challenges = MultilinearPoint(prover_state.sample_vec(2)); - let dot_product_values_mixed = [ - dot_product_evals_to_prove[DOT_PRODUCT_AIR_COL_VALUE_A], - dot_product_evals_to_prove[DOT_PRODUCT_AIR_COL_VALUE_B], - dot_product_evals_to_prove[DOT_PRODUCT_AIR_COL_RES], - EF::ZERO, - ] - .evaluate(&dot_product_values_mixing_challenges); - - let dot_product_evals_spread = dot_product_values_spread - .iter() - .map(|slice| { - slice.evaluate(&MultilinearPoint( - [ - dot_product_values_mixing_challenges.0.clone(), - dot_product_air_point.0.clone(), - ] - .concat(), - )) - }) - .collect::>(); - assert_eq!( - dot_product_with_base(&dot_product_evals_spread), - dot_product_values_mixed - ); - prover_state.add_extension_scalars(&dot_product_evals_spread); - - let dot_product_values_batching_scalars = MultilinearPoint(prover_state.sample_vec(3)); - - let dot_product_values_batched_point = MultilinearPoint( - [ - dot_product_values_batching_scalars.0.clone(), - dot_product_values_mixing_challenges.0.clone(), - dot_product_air_point.0.clone(), - ] - .concat(), - ); - let dot_product_values_batched_eval = - padd_with_zero_to_next_power_of_two(&dot_product_evals_spread) - .evaluate(&dot_product_values_batching_scalars); - - let concatenated_dot_product_values_spread = - padd_with_zero_to_next_power_of_two(&dot_product_values_spread.concat()); - - let padded_dot_product_indexes_spread = - padd_with_zero_to_next_power_of_two(&dot_product_indexes_spread.concat()); - - assert!( - padded_dot_product_indexes_spread.len() <= 1 << log_n_cycles, - "Currently the number of dot products must be < num_cycles / 32 (TODO relax this)" - ); - - let unused_1 = evaluate_as_larger_multilinear_pol( - &concatenated_dot_product_values_spread, - &grand_product_exec_sumcheck_point, - ); - prover_state.add_extension_scalar(unused_1); - - let grand_product_mem_values_mixing_challenges = MultilinearPoint(prover_state.sample_vec(2)); - let base_memory_lookup_statement_1 = Evaluation::new( - [ - grand_product_mem_values_mixing_challenges.0.clone(), - grand_product_exec_sumcheck_point.0, - ] - .concat(), - [ - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_A], - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_B], - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_C], - unused_1, - ] - .evaluate(&grand_product_mem_values_mixing_challenges), - ); - - let unused_2 = evaluate_as_larger_multilinear_pol( - &concatenated_dot_product_values_spread, - &exec_air_point, - ); - prover_state.add_extension_scalar(unused_2); - let exec_air_mem_values_mixing_challenges = MultilinearPoint(prover_state.sample_vec(2)); - let base_memory_lookup_statement_2 = Evaluation::new( - [ - exec_air_mem_values_mixing_challenges.0.clone(), - exec_air_point.0.clone(), - ] - .concat(), - [ - exec_evals_to_prove[COL_INDEX_MEM_VALUE_A.index_in_air()], - exec_evals_to_prove[COL_INDEX_MEM_VALUE_B.index_in_air()], - exec_evals_to_prove[COL_INDEX_MEM_VALUE_C.index_in_air()], - unused_2, - ] - .evaluate(&exec_air_mem_values_mixing_challenges), - ); - - let unused_3a = evaluate_as_smaller_multilinear_pol( - &full_trace[COL_INDEX_MEM_VALUE_A], - &dot_product_values_batched_point, - ); - let unused_3b = evaluate_as_smaller_multilinear_pol( - &full_trace[COL_INDEX_MEM_VALUE_B], - &dot_product_values_batched_point, - ); - let unused_3c = evaluate_as_smaller_multilinear_pol( - &full_trace[COL_INDEX_MEM_VALUE_C], - &dot_product_values_batched_point, - ); - prover_state.add_extension_scalars(&[unused_3a, unused_3b, unused_3c]); - - let dot_product_air_mem_values_mixing_challenges = MultilinearPoint(prover_state.sample_vec(2)); - let base_memory_lookup_statement_3 = Evaluation::new( + let normal_lookup_into_memory = NormalPackedLookupProver::step_1( + &mut prover_state, + &memory, + vec![ + &full_trace[COL_INDEX_MEM_ADDRESS_A], + &full_trace[COL_INDEX_MEM_ADDRESS_B], + &full_trace[COL_INDEX_MEM_ADDRESS_C], + &dot_product_col_index_a, + &dot_product_col_index_b, + &dot_product_col_index_res, + ], [ - dot_product_air_mem_values_mixing_challenges.0.clone(), - EF::zero_vec(log_n_cycles - dot_product_values_batched_point.len()), - dot_product_values_batched_point.0.clone(), + vec![n_cycles; 3], + vec![n_rows_table_dot_products.max(1 << LOG_MIN_DOT_PRODUCT_ROWS); 3], ] .concat(), - [ - unused_3a, - unused_3b, - unused_3c, - dot_product_values_batched_eval, - ] - .evaluate(&dot_product_air_mem_values_mixing_challenges), + [vec![0; 3], vec![0; 3]].concat(), + vec![ + &full_trace[COL_INDEX_MEM_VALUE_A], + &full_trace[COL_INDEX_MEM_VALUE_B], + &full_trace[COL_INDEX_MEM_VALUE_C], + ], + vec![ + &dot_product_columns[DOT_PRODUCT_AIR_COL_VALUE_A], + &dot_product_columns[DOT_PRODUCT_AIR_COL_VALUE_B], + &dot_product_columns[DOT_PRODUCT_AIR_COL_VALUE_RES], + ], + normal_lookup_into_memory_initial_statements( + &exec_air_point, + &exec_evals_to_prove, + &dot_product_air_point, + &dot_product_evals_to_prove, + ), + LOG_SMALLEST_DECOMPOSITION_CHUNK, ); - // Main memory lookup - let base_memory_indexes = [ - full_trace[COL_INDEX_MEM_ADDRESS_A].clone(), - full_trace[COL_INDEX_MEM_ADDRESS_B].clone(), - full_trace[COL_INDEX_MEM_ADDRESS_C].clone(), + let vectorized_lookup_into_memory = VectorizedPackedLookupProver::<_, VECTOR_LEN>::step_1( + &mut prover_state, + &memory, + vec![ + &p16_indexes_input_a, + &p16_indexes_input_b, + &p16_indexes_output, + &p16_indexes_output_shifted, + &p24_indexes_input_a, + &p24_indexes_input_a_shifted, + &p24_indexes_input_b, + &p24_indexes_output, + ], [ - padded_dot_product_indexes_spread.clone(), - F::zero_vec((1 << log_n_cycles) - padded_dot_product_indexes_spread.len()), + vec![n_poseidons_16.max(1 << LOG_MIN_POSEIDONS_16); 4], + vec![n_poseidons_24.max(1 << LOG_MIN_POSEIDONS_24); 4], ] .concat(), - ] - .concat(); - - let memory_poly_eq_point_alpha = prover_state.sample(); - - let mut base_memory_poly_eq_point = eval_eq(&base_memory_lookup_statement_1.point); - compute_eval_eq::, EF, true>( - &base_memory_lookup_statement_2.point, - &mut base_memory_poly_eq_point, - memory_poly_eq_point_alpha, - ); - compute_eval_eq::, EF, true>( - &base_memory_lookup_statement_3.point, - &mut base_memory_poly_eq_point, - memory_poly_eq_point_alpha.square(), - ); - let base_memory_pushforward = compute_pushforward( - &base_memory_indexes, - memory.len(), - &base_memory_poly_eq_point, + default_poseidon_indexes(), + poseidon_value_columns, + poseidon_lookup_statements(&p16_gkr, &p24_gkr), + LOG_SMALLEST_DECOMPOSITION_CHUNK, ); // 2nd Commitment let extension_pols = vec![ - base_memory_pushforward.as_slice(), - poseidon_pushforward.as_slice(), + normal_lookup_into_memory.pushforward_to_commit(), + vectorized_lookup_into_memory.pushforward_to_commit(), bytecode_pushforward.as_slice(), ]; @@ -882,211 +617,93 @@ pub fn prove_execution( LOG_SMALLEST_DECOMPOSITION_CHUNK, ); - let base_memory_logup_star_statements = prove_logup_star( - &mut prover_state, - &MleRef::Base(&memory), - &base_memory_indexes, - base_memory_lookup_statement_1.value - + memory_poly_eq_point_alpha * base_memory_lookup_statement_2.value - + memory_poly_eq_point_alpha.square() * base_memory_lookup_statement_3.value, - &base_memory_poly_eq_point, - &base_memory_pushforward, - Some(non_zero_memory_size), - ); - let poseidon_logup_star_statements = prove_logup_star( - &mut prover_state, - &MleRef::Extension(&poseidon_folded_memory), - &all_poseidon_indexes, - poseidon_lookup_statements - .iter() - .enumerate() - .map(|(i, s)| s.value * poseidon_logup_star_alpha.exp_u64(i as u64)) - .sum(), - &poseidon_poly_eq_point, - &poseidon_pushforward, - Some(non_zero_memory_size.div_ceil(VECTOR_LEN)), - ); + let normal_lookup_into_memory_statements = + normal_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); + + let vectorized_lookup_statements = + vectorized_lookup_into_memory.step_2(&mut prover_state, non_zero_memory_size); let bytecode_logup_star_statements = prove_logup_star( &mut prover_state, &MleRef::Extension(&folded_bytecode), &full_trace[COL_INDEX_PC], - bytecode_lookup_claim_1.value + alpha_bytecode_lookup * bytecode_lookup_claim_2.value, + bytecode_lookup_claim_1.value, &bytecode_poly_eq_point, &bytecode_pushforward, Some(bytecode.instructions.len()), ); - let poseidon_lookup_memory_point = MultilinearPoint( - [ - poseidon_logup_star_statements.on_table.point.0.clone(), - memory_folding_challenges.0, - ] - .concat(), - ); - - memory_statements.push(base_memory_logup_star_statements.on_table.clone()); - memory_statements.push(Evaluation::new( - poseidon_lookup_memory_point.clone(), - poseidon_logup_star_statements.on_table.value, - )); + memory_statements.push(normal_lookup_into_memory_statements.on_table.clone()); + memory_statements.push(vectorized_lookup_statements.on_table.clone()); { // index opening for poseidon lookup - - let (correcting_factor_p16, correcting_factor_p24) = poseidon_lookup_correcting_factors( - log_n_p16, - log_n_p24, - &poseidon_logup_star_statements.on_indexes.point, - ); - let poseidon_index_evals = fold_multilinear_chunks( - &all_poseidon_indexes, - &MultilinearPoint(poseidon_logup_star_statements.on_indexes.point[3..].to_vec()), + p16_indexes_a_statements.extend(vectorized_lookup_statements.on_indexes[0].clone()); + p16_indexes_b_statements.extend(vectorized_lookup_statements.on_indexes[1].clone()); + p16_indexes_res_statements.extend(vectorized_lookup_statements.on_indexes[2].clone()); + // vectorized_lookup_statements.on_indexes[3] is proven via sumcheck below + p24_indexes_a_statements.extend(vectorized_lookup_statements.on_indexes[4].clone()); + p24_indexes_a_statements.extend( + vectorized_lookup_statements.on_indexes[5] + .iter() + .map(|eval| Evaluation::new(eval.point.clone(), eval.value - EF::ONE)), ); + p24_indexes_b_statements.extend(vectorized_lookup_statements.on_indexes[6].clone()); + p24_indexes_res_statements.extend(vectorized_lookup_statements.on_indexes[7].clone()); - let inner_values = [ - poseidon_index_evals[0] / correcting_factor_p16, - poseidon_index_evals[1] / correcting_factor_p16, - poseidon_index_evals[2] / correcting_factor_p16, - // skip 3 (16_output_b, proved via sumcheck) - poseidon_index_evals[4] / correcting_factor_p24, - // skip 5 (24_input_b) - poseidon_index_evals[6] / correcting_factor_p24, - poseidon_index_evals[7] / correcting_factor_p24, - ]; - prover_state.add_extension_scalars(&inner_values); - - let p16_value_index_res_b = poseidon_index_evals[3] / correcting_factor_p16; // prove this value via sumcheck: index_res_b = (index_res_a + 1) * (1 - compression) - let p16_one_minus_compression = p16_witness + let p16_one_minus_compression = &p16_witness .compression .as_ref() .unwrap() .1 .par_iter() - .map(|c| FPacking::::ONE - *c) - .collect::>(); - let p16_index_res_a_plus_one = FPacking::::pack_slice(&p16_indexes[2]) - .par_iter() - .map(|c| *c + F::ONE) + .map(|c| EFPacking::::ONE - *c) // TODO embedding overhead .collect::>(); - - // TODO there is a big inneficiency in impl SumcheckComputationPacked for ProductComputation + let p16_index_res_a_plus_one = pack_extension( + &p16_indexes_output + .par_iter() + .map(|c| EF::ONE + *c) // TODO embedding overhead + .collect::>(), + ); + let alpha = prover_state.sample(); + let mut poly_eq = EFPacking::::zero_vec(1 << (log_n_p16 - packing_log_width::())); + let mut sum = EF::ZERO; + for (statement, alpha_power) in vectorized_lookup_statements.on_indexes[3] + .iter() + .zip(alpha.powers()) + { + sum += statement.value * alpha_power; + compute_sparse_eval_eq_packed(&statement.point, &mut poly_eq, alpha_power); + } + // TODO there is a lot of embedding overhead in this sumcheck let (sc_point, sc_values, _) = sumcheck_prove( - 1, // TODO univariate skip - MleGroupRef::BasePacked(vec![&p16_one_minus_compression, &p16_index_res_a_plus_one]), - &ProductComputation, + 1, + MleGroupRef::ExtensionPacked(vec![ + &poly_eq, + &p16_one_minus_compression, + &p16_index_res_a_plus_one, + ]), + &CubeComputation, &[], - Some(( - poseidon_logup_star_statements.on_indexes.point[3..].to_vec(), - None, - )), + None, false, &mut prover_state, - p16_value_index_res_b, + sum, None, ); - prover_state.add_extension_scalar(sc_values[1]); - p16_indexes_res_statements.push(Evaluation::new(sc_point, sc_values[1] - EF::ONE)); - - add_poseidon_lookup_statements_on_indexes( - log_n_p16, - log_n_p24, - &poseidon_logup_star_statements.on_indexes.point, - &inner_values, - [ - &mut p16_indexes_a_statements, - &mut p16_indexes_b_statements, - &mut p16_indexes_res_statements, - ], - [ - &mut p24_indexes_a_statements, - &mut p24_indexes_b_statements, - &mut p24_indexes_res_statements, - ], - ); + prover_state.add_extension_scalar(sc_values[2]); + p16_indexes_res_statements.push(Evaluation::new(sc_point, sc_values[2] - EF::ONE)); } let (initial_pc_statement, final_pc_statement) = initial_and_final_pc_conditions(bytecode, log_n_cycles); - let dot_product_computation_column_evals = dot_product_computations_base - .par_iter() - .map(|slice| slice.evaluate(&dot_product_air_point)) - .collect::>(); - - prover_state.add_extension_scalars(&dot_product_computation_column_evals); - let dot_product_computation_column_statements = (0..DIMENSION) - .map(|i| { - vec![Evaluation::new( - dot_product_air_point.clone(), - dot_product_computation_column_evals[i], - )] - }) - .collect::>(); - - let mem_lookup_eval_indexes_partial_point = - MultilinearPoint(base_memory_logup_star_statements.on_indexes.point[2..].to_vec()); - let mem_lookup_eval_indexes_a = - full_trace[COL_INDEX_MEM_ADDRESS_A].evaluate(&mem_lookup_eval_indexes_partial_point); // validity is proven via PCS - let mem_lookup_eval_indexes_b = - full_trace[COL_INDEX_MEM_ADDRESS_B].evaluate(&mem_lookup_eval_indexes_partial_point); // validity is proven via PCS - let mem_lookup_eval_indexes_c = - full_trace[COL_INDEX_MEM_ADDRESS_C].evaluate(&mem_lookup_eval_indexes_partial_point); // validity is proven via PCS - assert_eq!(mem_lookup_eval_indexes_partial_point.len(), log_n_cycles); - assert_eq!( - log2_strict_usize(padded_dot_product_indexes_spread.len()), - log_n_rows_dot_product_table + 5 - ); - let index_diff = log_n_cycles - log2_strict_usize(padded_dot_product_indexes_spread.len()); - let mem_lookup_eval_spread_indexes_dot_product = padded_dot_product_indexes_spread.evaluate( - &MultilinearPoint(mem_lookup_eval_indexes_partial_point[index_diff..].to_vec()), - ); - - prover_state.add_extension_scalars(&[ - mem_lookup_eval_indexes_a, - mem_lookup_eval_indexes_b, - mem_lookup_eval_indexes_c, - mem_lookup_eval_spread_indexes_dot_product, - ]); - - let dot_product_logup_star_indexes_inner_point = - MultilinearPoint(mem_lookup_eval_indexes_partial_point.0[5 + index_diff..].to_vec()); - let dot_product_logup_star_indexes_inner_value_a = dot_product_columns - [DOT_PRODUCT_AIR_COL_INDEX_A] - .evaluate(&dot_product_logup_star_indexes_inner_point); - let dot_product_logup_star_indexes_inner_value_b = dot_product_columns - [DOT_PRODUCT_AIR_COL_INDEX_B] - .evaluate(&dot_product_logup_star_indexes_inner_point); - let dot_product_logup_star_indexes_inner_value_res = dot_product_columns - [DOT_PRODUCT_AIR_COL_INDEX_RES] - .evaluate(&dot_product_logup_star_indexes_inner_point); - - prover_state.add_extension_scalars(&[ - dot_product_logup_star_indexes_inner_value_a, - dot_product_logup_star_indexes_inner_value_b, - dot_product_logup_star_indexes_inner_value_res, - ]); - - let dot_product_logup_star_indexes_statement_a = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_a, - ); - let dot_product_logup_star_indexes_statement_b = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_b, - ); - let dot_product_logup_star_indexes_statement_res = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_res, - ); + let dot_product_computation_column_statements = dot_product_computation_ext_to_base_helper + .after_commitment(&mut prover_state, &dot_product_air_point); - let exec_air_statement = |col_index: usize| { - Evaluation::new( - exec_air_point.clone(), - exec_evals_to_prove[col_index.index_in_air()], - ) - }; + let exec_air_statement = + |col_index: usize| Evaluation::new(exec_air_point.clone(), exec_evals_to_prove[col_index]); let dot_product_air_statement = |col_index: usize| { Evaluation::new( dot_product_air_point.clone(), @@ -1104,28 +721,22 @@ pub fn prove_execution( initial_pc_statement, final_pc_statement, ], // pc - vec![exec_air_statement(COL_INDEX_FP), grand_product_fp_statement], // fp - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_A), - Evaluation::new( - mem_lookup_eval_indexes_partial_point.clone(), - mem_lookup_eval_indexes_a, - ), - ], // exec memory address A - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_B), - Evaluation::new( - mem_lookup_eval_indexes_partial_point.clone(), - mem_lookup_eval_indexes_b, - ), - ], // exec memory address B - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_C), - Evaluation::new( - mem_lookup_eval_indexes_partial_point, - mem_lookup_eval_indexes_c, - ), - ], // exec memory address C + vec![exec_air_statement(COL_INDEX_FP)], // fp + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_A)], + normal_lookup_into_memory_statements.on_indexes[0].clone(), + ] + .concat(), // exec memory address A + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_B)], + normal_lookup_into_memory_statements.on_indexes[1].clone(), + ] + .concat(), // exec memory address B + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_C)], + normal_lookup_into_memory_statements.on_indexes[2].clone(), + ] + .concat(), // exec memory address C p16_indexes_a_statements, p16_indexes_b_statements, p16_indexes_res_statements, @@ -1133,32 +744,26 @@ pub fn prove_execution( p24_indexes_b_statements, p24_indexes_res_statements, ], - p16_cubes_statements, - p24_cubes_statements, + encapsulate_vec(p16_gkr.cubes_statements.split()), + encapsulate_vec(p24_gkr.cubes_statements.split()), vec![ - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_START_FLAG), - grand_product_dot_product_flag_statement, - ], - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_LEN), - grand_product_dot_product_len_statement, - ], - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_A), - dot_product_logup_star_indexes_statement_a, - grand_product_dot_product_table_indexes_statement_index_a, - ], - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_B), - dot_product_logup_star_indexes_statement_b, - grand_product_dot_product_table_indexes_statement_index_b, - ], - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_RES), - dot_product_logup_star_indexes_statement_res, - grand_product_dot_product_table_indexes_statement_index_res, - ], + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_START_FLAG)], + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_LEN)], + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_A)], + normal_lookup_into_memory_statements.on_indexes[3].clone(), + ] + .concat(), + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_B)], + normal_lookup_into_memory_statements.on_indexes[4].clone(), + ] + .concat(), + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_RES)], + normal_lookup_into_memory_statements.on_indexes[4].clone(), + ] + .concat(), ], dot_product_computation_column_statements, ] @@ -1178,8 +783,8 @@ pub fn prove_execution( &extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK, &[ - base_memory_logup_star_statements.on_pushforward, - poseidon_logup_star_statements.on_pushforward, + normal_lookup_into_memory_statements.on_pushforward, + vectorized_lookup_statements.on_pushforward, bytecode_logup_star_statements.on_pushforward, ], &mut prover_state, diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 99eba7cf..eece77b6 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -1,18 +1,16 @@ use crate::common::*; use crate::*; -use ::air::table::AirTable; +use ::air::AirTable; +use air::verify_air; use lean_vm::*; use lookup::verify_gkr_product; use lookup::verify_logup_star; use multilinear_toolkit::prelude::*; -use p3_field::PrimeCharacteristicRing; -use p3_field::dot_product; use p3_util::{log2_ceil_usize, log2_strict_usize}; -use packed_pcs::*; use poseidon_circuit::PoseidonGKRLayers; use poseidon_circuit::verify_poseidon_gkr; +use sub_protocols::*; use utils::ToUsize; -use utils::dot_product_with_base; use utils::{build_challenger, padd_with_zero_to_next_power_of_two}; use vm_air::*; use whir_p3::WhirConfig; @@ -27,12 +25,9 @@ pub fn verify_execution( ) -> Result<(), ProofError> { let mut verifier_state = VerifierState::new(proof_data, build_challenger()); - let exec_table = AirTable::::new(VMAir); let p16_gkr_layers = PoseidonGKRLayers::<16, N_COMMITED_CUBES_P16>::build(Some(VECTOR_LEN)); let p24_gkr_layers = PoseidonGKRLayers::<24, N_COMMITED_CUBES_P24>::build(None); - let dot_product_table = AirTable::::new(DotProductAir); - let [ n_cycles, n_poseidons_16, @@ -121,13 +116,16 @@ pub fn verify_execution( let grand_product_challenge_global = verifier_state.sample(); let fingerprint_challenge = verifier_state.sample(); let (grand_product_exec_res, grand_product_exec_statement) = - verify_gkr_product(&mut verifier_state, log_n_cycles)?; + verify_gkr_product::<_, TWO_POW_UNIVARIATE_SKIPS>(&mut verifier_state, log_n_cycles)?; let (grand_product_p16_res, grand_product_p16_statement) = - verify_gkr_product(&mut verifier_state, log_n_p16)?; + verify_gkr_product::<_, 2>(&mut verifier_state, log_n_p16)?; let (grand_product_p24_res, grand_product_p24_statement) = - verify_gkr_product(&mut verifier_state, log_n_p24)?; + verify_gkr_product::<_, 2>(&mut verifier_state, log_n_p24)?; let (grand_product_dot_product_res, grand_product_dot_product_statement) = - verify_gkr_product(&mut verifier_state, table_dot_products_log_n_rows)?; + verify_gkr_product::<_, TWO_POW_DOT_PRODUCT_UNIVARIATE_SKIPS>( + &mut verifier_state, + table_dot_products_log_n_rows, + )?; let vm_multilinear_eval_grand_product_res = vm_multilinear_evals .iter() .map(|vm_multilinear_eval| { @@ -266,96 +264,34 @@ pub fn verify_execution( p24_grand_product_evals_on_indexes_res, )]; - // Grand product statements - let (grand_product_final_dot_product_eval, grand_product_dot_product_sumcheck_claim) = - sumcheck_verify(&mut verifier_state, table_dot_products_log_n_rows, 3)?; - if grand_product_final_dot_product_eval != grand_product_dot_product_statement.value { - return Err(ProofError::InvalidProof); - } - let grand_product_dot_product_sumcheck_inner_evals = - verifier_state.next_extension_scalars_vec(5)?; - - if grand_product_dot_product_sumcheck_claim.value - != grand_product_dot_product_sumcheck_claim - .point - .eq_poly_outside(&grand_product_dot_product_statement.point) - * { - DotProductFootprint { - global_challenge: grand_product_challenge_global, - fingerprint_challenge_powers: powers_const(fingerprint_challenge), - } - .eval(&grand_product_dot_product_sumcheck_inner_evals, &[]) - } - { - return Err(ProofError::InvalidProof); - } - - let grand_product_dot_product_flag_statement = Evaluation::new( - grand_product_dot_product_sumcheck_claim.point.clone(), - grand_product_dot_product_sumcheck_inner_evals[0], - ); - let grand_product_dot_product_len_statement = Evaluation::new( - grand_product_dot_product_sumcheck_claim.point.clone(), - grand_product_dot_product_sumcheck_inner_evals[1], - ); - let grand_product_dot_product_table_indexes_statement_index_a = Evaluation::new( - grand_product_dot_product_sumcheck_claim.point.clone(), - grand_product_dot_product_sumcheck_inner_evals[2], - ); - let grand_product_dot_product_table_indexes_statement_index_b = Evaluation::new( - grand_product_dot_product_sumcheck_claim.point.clone(), - grand_product_dot_product_sumcheck_inner_evals[3], - ); - let grand_product_dot_product_table_indexes_statement_index_res = Evaluation::new( - grand_product_dot_product_sumcheck_claim.point.clone(), - grand_product_dot_product_sumcheck_inner_evals[4], - ); - - let (grand_product_final_exec_eval, grand_product_exec_sumcheck_claim) = - sumcheck_verify(&mut verifier_state, log_n_cycles, 4)?; - if grand_product_final_exec_eval != grand_product_exec_statement.value { - return Err(ProofError::InvalidProof); - } - - let grand_product_exec_sumcheck_inner_evals = - verifier_state.next_extension_scalars_vec(N_TOTAL_COLUMNS)?; // TODO some of the values are unused - - let grand_product_exec_evals_on_each_column = - &grand_product_exec_sumcheck_inner_evals[..N_INSTRUCTION_COLUMNS]; - - if grand_product_exec_sumcheck_claim.value - != grand_product_exec_sumcheck_claim - .point - .eq_poly_outside(&grand_product_exec_statement.point) - * { - PrecompileFootprint { - global_challenge: grand_product_challenge_global, - fingerprint_challenge_powers: powers_const(fingerprint_challenge), - } - .eval( - &reorder_full_trace_for_precomp_foot_print( - grand_product_exec_sumcheck_inner_evals.clone(), - ), - &[], - ) - } - { - return Err(ProofError::InvalidProof); - } - - let grand_product_fp_statement = Evaluation::new( - grand_product_exec_sumcheck_claim.point.clone(), - grand_product_exec_sumcheck_inner_evals[COL_INDEX_FP], - ); - - let (exec_air_point, exec_evals_to_verify) = - exec_table.verify(&mut verifier_state, UNIVARIATE_SKIPS, log_n_cycles)?; + let exec_table = AirTable::::new(VMAir { + global_challenge: grand_product_challenge_global, + fingerprint_challenge_powers: powers_const(fingerprint_challenge), + }); + let (exec_air_point, exec_evals_to_verify) = verify_air( + &mut verifier_state, + &exec_table, + UNIVARIATE_SKIPS, + log_n_cycles, + &execution_air_padding_row(bytecode.ending_pc), + Some(grand_product_exec_statement), + )?; - let (dot_product_air_point, dot_product_evals_to_verify) = - dot_product_table.verify(&mut verifier_state, 1, table_dot_products_log_n_rows)?; + let dot_product_table = AirTable::::new(DotProductAir { + global_challenge: grand_product_challenge_global, + fingerprint_challenge_powers: powers_const(fingerprint_challenge), + }); + let (dot_product_air_point, dot_product_evals_to_verify) = verify_air( + &mut verifier_state, + &dot_product_table, + DOT_PRODUCT_UNIVARIATE_SKIPS, + table_dot_products_log_n_rows, + &dot_product_air_padding_row(), + Some(grand_product_dot_product_statement), + )?; let random_point_p16 = MultilinearPoint(verifier_state.sample_vec(log_n_p16)); - let gkr_16 = verify_poseidon_gkr( + let p16_gkr = verify_poseidon_gkr( &mut verifier_state, log_n_p16, &random_point_p16, @@ -363,20 +299,9 @@ pub fn verify_execution( UNIVARIATE_SKIPS, Some(n_compressions_16), ); - let p16_cubes_statements = gkr_16 - .cubes_statements - .1 - .iter() - .map(|&e| { - vec![Evaluation { - point: gkr_16.cubes_statements.0.clone(), - value: e, - }] - }) - .collect::>(); let random_point_p24 = MultilinearPoint(verifier_state.sample_vec(log_n_p24)); - let gkr_24 = verify_poseidon_gkr( + let p24_gkr = verify_poseidon_gkr( &mut verifier_state, log_n_p24, &random_point_p24, @@ -384,138 +309,55 @@ pub fn verify_execution( UNIVARIATE_SKIPS, None, ); - let p24_cubes_statements = gkr_24 - .cubes_statements - .1 - .iter() - .map(|&e| { - vec![Evaluation { - point: gkr_24.cubes_statements.0.clone(), - value: e, - }] - }) - .collect::>(); - - let poseidon_logup_star_alpha = verifier_state.sample(); - let memory_folding_challenges = MultilinearPoint(verifier_state.sample_vec(LOG_VECTOR_LEN)); - let non_used_precompiles_evals = verifier_state - .next_extension_scalars_vec(N_INSTRUCTION_COLUMNS - N_INSTRUCTION_COLUMNS_IN_AIR)?; let bytecode_compression_challenges = MultilinearPoint(verifier_state.sample_vec(log2_ceil_usize(N_INSTRUCTION_COLUMNS))); let bytecode_lookup_claim_1 = Evaluation::new( exec_air_point.clone(), - padd_with_zero_to_next_power_of_two( - &[ - (0..N_INSTRUCTION_COLUMNS_IN_AIR) - .map(|i| exec_evals_to_verify[i]) - .collect::>(), - non_used_precompiles_evals, - ] - .concat(), - ) - .evaluate(&bytecode_compression_challenges), - ); - - let bytecode_lookup_claim_2 = Evaluation::new( - grand_product_exec_sumcheck_claim.point.clone(), - padd_with_zero_to_next_power_of_two(grand_product_exec_evals_on_each_column) + padd_with_zero_to_next_power_of_two(&exec_evals_to_verify[..N_INSTRUCTION_COLUMNS]) .evaluate(&bytecode_compression_challenges), ); - let alpha_bytecode_lookup = verifier_state.sample(); - - let dot_product_values_mixing_challenges = MultilinearPoint(verifier_state.sample_vec(2)); - - let dot_product_evals_spread = verifier_state.next_extension_scalars_vec(DIMENSION)?; - - let dot_product_values_mixed = [ - dot_product_evals_to_verify[DOT_PRODUCT_AIR_COL_VALUE_A], - dot_product_evals_to_verify[DOT_PRODUCT_AIR_COL_VALUE_B], - dot_product_evals_to_verify[DOT_PRODUCT_AIR_COL_RES], - EF::ZERO, - ] - .evaluate(&dot_product_values_mixing_challenges); - - if dot_product_with_base(&dot_product_evals_spread) != dot_product_values_mixed { - return Err(ProofError::InvalidProof); - } - let dot_product_values_batching_scalars = MultilinearPoint(verifier_state.sample_vec(3)); - let dot_product_values_batched_point = MultilinearPoint( - [ - dot_product_values_batching_scalars.0.clone(), - dot_product_values_mixing_challenges.0.clone(), - dot_product_air_point.0.clone(), - ] - .concat(), - ); - let dot_product_values_batched_eval = - padd_with_zero_to_next_power_of_two(&dot_product_evals_spread) - .evaluate(&dot_product_values_batching_scalars); - - let unused_1 = verifier_state.next_extension_scalar()?; - let grand_product_mem_values_mixing_challenges = MultilinearPoint(verifier_state.sample_vec(2)); - let base_memory_lookup_statement_1 = Evaluation::new( - [ - grand_product_mem_values_mixing_challenges.0.clone(), - grand_product_exec_sumcheck_claim.point.0, - ] - .concat(), - [ - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_A], - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_B], - grand_product_exec_sumcheck_inner_evals[COL_INDEX_MEM_VALUE_C], - unused_1, - ] - .evaluate(&grand_product_mem_values_mixing_challenges), - ); - let unused_2 = verifier_state.next_extension_scalar()?; - let exec_air_mem_values_mixing_challenges = MultilinearPoint(verifier_state.sample_vec(2)); - let base_memory_lookup_statement_2 = Evaluation::new( + let normal_lookup_into_memory = NormalPackedLookupVerifier::step_1( + &mut verifier_state, + 3, [ - exec_air_mem_values_mixing_challenges.0.clone(), - exec_air_point.0.clone(), + vec![n_cycles; 3], + vec![n_rows_table_dot_products.max(1 << LOG_MIN_DOT_PRODUCT_ROWS); 3], ] .concat(), - [ - exec_evals_to_verify[COL_INDEX_MEM_VALUE_A.index_in_air()], - exec_evals_to_verify[COL_INDEX_MEM_VALUE_B.index_in_air()], - exec_evals_to_verify[COL_INDEX_MEM_VALUE_C.index_in_air()], - unused_2, - ] - .evaluate(&exec_air_mem_values_mixing_challenges), - ); - - let [unused_3a, unused_3b, unused_3c] = verifier_state.next_extension_scalars_const()?; + [vec![0; 3], vec![0; 3]].concat(), + normal_lookup_into_memory_initial_statements( + &exec_air_point, + &exec_evals_to_verify, + &dot_product_air_point, + &dot_product_evals_to_verify, + ), + LOG_SMALLEST_DECOMPOSITION_CHUNK, + &public_memory, // we need to pass the first few values of memory, public memory is enough + )?; - let dot_product_air_mem_values_mixing_challenges = - MultilinearPoint(verifier_state.sample_vec(2)); - let base_memory_lookup_statement_3 = Evaluation::new( + let vectorized_lookup_into_memory = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( + &mut verifier_state, [ - dot_product_air_mem_values_mixing_challenges.0.clone(), - EF::zero_vec(log_n_cycles - dot_product_values_batched_point.len()), - dot_product_values_batched_point.0.clone(), + vec![n_poseidons_16.max(1 << LOG_MIN_POSEIDONS_16); 4], + vec![n_poseidons_24.max(1 << LOG_MIN_POSEIDONS_24); 4], ] .concat(), - [ - unused_3a, - unused_3b, - unused_3c, - dot_product_values_batched_eval, - ] - .evaluate(&dot_product_air_mem_values_mixing_challenges), - ); - - let memory_poly_eq_point_alpha = verifier_state.sample(); + default_poseidon_indexes(), + poseidon_lookup_statements(&p16_gkr, &p24_gkr), + LOG_SMALLEST_DECOMPOSITION_CHUNK, + &public_memory, // we need to pass the first few values of memory, public memory is enough + )?; let extension_dims = vec![ - ColDims::padded(public_memory.len() + private_memory_len, EF::ZERO), // memory + ColDims::padded(public_memory.len() + private_memory_len, EF::ZERO), // memory pushwordard ColDims::padded( (public_memory.len() + private_memory_len).div_ceil(VECTOR_LEN), EF::ZERO, - ), // memory (folded) - ColDims::padded(bytecode.instructions.len(), EF::ZERO), + ), // memory (folded) pushwordard + ColDims::padded(bytecode.instructions.len(), EF::ZERO), // bytecode pushforward ]; let parsed_commitment_extension = packed_pcs_parse_commitment( @@ -530,44 +372,18 @@ pub fn verify_execution( ) .unwrap(); - let base_memory_logup_star_statements = verify_logup_star( - &mut verifier_state, - log_memory, - log_n_cycles + 2, - &[ - base_memory_lookup_statement_1, - base_memory_lookup_statement_2, - base_memory_lookup_statement_3, - ], - memory_poly_eq_point_alpha, - ) - .unwrap(); - - let poseidon_lookup_statements = get_poseidon_lookup_statements( - (log_n_p16, log_n_p24), - &gkr_16.input_statements, - &(random_point_p16.clone(), gkr_16.output_values), - &gkr_24.input_statements, - &(random_point_p24.clone(), gkr_24.output_values), - &memory_folding_challenges, - ); + let normal_lookup_statements = + normal_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; - let poseidon_lookup_log_length = 3 + log_n_p16.max(log_n_p24); - let poseidon_logup_star_statements = verify_logup_star( - &mut verifier_state, - log_memory - 3, // "-3" because it's folded memory - poseidon_lookup_log_length, - &poseidon_lookup_statements, - poseidon_logup_star_alpha, - ) - .unwrap(); + let vectorized_lookup_statements = + vectorized_lookup_into_memory.step_2(&mut verifier_state, log_memory)?; let bytecode_logup_star_statements = verify_logup_star( &mut verifier_state, log2_ceil_usize(bytecode.instructions.len()), log_n_cycles, - &[bytecode_lookup_claim_1, bytecode_lookup_claim_2], - alpha_bytecode_lookup, + &[bytecode_lookup_claim_1], + EF::ONE, ) .unwrap(); let folded_bytecode = fold_bytecode(bytecode, &bytecode_compression_challenges); @@ -577,30 +393,27 @@ pub fn verify_execution( return Err(ProofError::InvalidProof); } - let poseidon_lookup_memory_point = MultilinearPoint( - [ - poseidon_logup_star_statements.on_table.point.0.clone(), - memory_folding_challenges.0, - ] - .concat(), - ); - - memory_statements.push(base_memory_logup_star_statements.on_table.clone()); - memory_statements.push(Evaluation::new( - poseidon_lookup_memory_point.clone(), - poseidon_logup_star_statements.on_table.value, - )); + memory_statements.push(normal_lookup_statements.on_table.clone()); + memory_statements.push(vectorized_lookup_statements.on_table.clone()); { // index opening for poseidon lookup - let (correcting_factor_p16, correcting_factor_p24) = poseidon_lookup_correcting_factors( - log_n_p16, - log_n_p24, - &poseidon_logup_star_statements.on_indexes.point, + // index opening for poseidon lookup + p16_indexes_a_statements.extend(vectorized_lookup_statements.on_indexes[0].clone()); + p16_indexes_b_statements.extend(vectorized_lookup_statements.on_indexes[1].clone()); + p16_indexes_res_statements.extend(vectorized_lookup_statements.on_indexes[2].clone()); + // vectorized_lookup_statements.on_indexes[3] is proven via sumcheck below + p24_indexes_a_statements.extend(vectorized_lookup_statements.on_indexes[4].clone()); + p24_indexes_a_statements.extend( + vectorized_lookup_statements.on_indexes[5] + .iter() + .map(|eval| Evaluation::new(eval.point.clone(), eval.value - EF::ONE)), ); + p24_indexes_b_statements.extend(vectorized_lookup_statements.on_indexes[6].clone()); + p24_indexes_res_statements.extend(vectorized_lookup_statements.on_indexes[7].clone()); - let mut inner_values = verifier_state.next_extension_scalars_vec(6)?; + let alpha = verifier_state.sample(); let (p16_value_index_res_b, sc_eval) = sumcheck_verify_with_univariate_skip( &mut verifier_state, @@ -608,6 +421,18 @@ pub fn verify_execution( log_n_p16, 1, // TODO univariate skip )?; + let mut eq_poly_eval = EF::ZERO; + let mut p16_value_index_res_b_expected = EF::ZERO; + for (statement, alpha_power) in vectorized_lookup_statements.on_indexes[3] + .iter() + .zip(alpha.powers()) + { + p16_value_index_res_b_expected += statement.value * alpha_power; + eq_poly_eval += alpha_power * statement.point.eq_poly_outside(&sc_eval.point); + } + if p16_value_index_res_b_expected != p16_value_index_res_b { + return Err(ProofError::InvalidProof); + } let sc_res_index_value = verifier_state.next_extension_scalar()?; p16_indexes_res_statements.push(Evaluation::new( sc_eval.point.clone(), @@ -617,158 +442,27 @@ pub fn verify_execution( if sc_res_index_value * (EF::ONE - mle_of_zeros_then_ones((1 << log_n_p16) - n_compressions_16, &sc_eval.point)) - * sc_eval.point.eq_poly_outside(&MultilinearPoint( - poseidon_logup_star_statements.on_indexes.point[3..].to_vec(), - )) + * eq_poly_eval != sc_eval.value { return Err(ProofError::InvalidProof); } - - add_poseidon_lookup_statements_on_indexes( - log_n_p16, - log_n_p24, - &poseidon_logup_star_statements.on_indexes.point, - &inner_values, - [ - &mut p16_indexes_a_statements, - &mut p16_indexes_b_statements, - &mut p16_indexes_res_statements, - ], - [ - &mut p24_indexes_a_statements, - &mut p24_indexes_b_statements, - &mut p24_indexes_res_statements, - ], - ); - - inner_values.insert(3, p16_value_index_res_b); - inner_values.insert(5, inner_values[4] + EF::ONE); - - for v in &mut inner_values[..4] { - *v *= correcting_factor_p16; - } - for v in &mut inner_values[4..] { - *v *= correcting_factor_p24; - } - - if inner_values.evaluate(&MultilinearPoint( - poseidon_logup_star_statements.on_indexes.point[..3].to_vec(), - )) != poseidon_logup_star_statements.on_indexes.value - { - return Err(ProofError::InvalidProof); - } } let (initial_pc_statement, final_pc_statement) = initial_and_final_pc_conditions(bytecode, log_n_cycles); - let dot_product_computation_column_evals = - verifier_state.next_extension_scalars_const::()?; - if dot_product_with_base(&dot_product_computation_column_evals) - != dot_product_evals_to_verify[DOT_PRODUCT_AIR_COL_COMPUTATION] - { - return Err(ProofError::InvalidProof); - } - let dot_product_computation_column_statements = (0..DIMENSION) - .map(|i| { - vec![Evaluation::new( + let dot_product_computation_column_statements = + ExtensionCommitmentFromBaseVerifier::after_commitment( + &mut verifier_state, + &Evaluation::new( dot_product_air_point.clone(), - dot_product_computation_column_evals[i], - )] - }) - .collect::>(); - - let mem_lookup_eval_indexes_partial_point = - MultilinearPoint(base_memory_logup_star_statements.on_indexes.point[2..].to_vec()); - let [ - mem_lookup_eval_indexes_a, - mem_lookup_eval_indexes_b, - mem_lookup_eval_indexes_c, - mem_lookup_eval_spread_indexes_dot_product, - ] = verifier_state.next_extension_scalars_const()?; - - let index_diff = log_n_cycles - (table_dot_products_log_n_rows + 5); - - assert_eq!( - [ - mem_lookup_eval_indexes_a, - mem_lookup_eval_indexes_b, - mem_lookup_eval_indexes_c, - mem_lookup_eval_spread_indexes_dot_product - * mem_lookup_eval_indexes_partial_point[..index_diff] - .iter() - .map(|x| EF::ONE - *x) - .product::(), - ] - .evaluate(&MultilinearPoint( - base_memory_logup_star_statements.on_indexes.point[..2].to_vec(), - )), - base_memory_logup_star_statements.on_indexes.value - ); - - let dot_product_logup_star_indexes_inner_point = - MultilinearPoint(mem_lookup_eval_indexes_partial_point.0[5 + index_diff..].to_vec()); - - let [ - dot_product_logup_star_indexes_inner_value_a, - dot_product_logup_star_indexes_inner_value_b, - dot_product_logup_star_indexes_inner_value_res, - ] = verifier_state.next_extension_scalars_const()?; - - let dot_product_logup_star_indexes_statement_a = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_a, - ); - let dot_product_logup_star_indexes_statement_b = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_b, - ); - let dot_product_logup_star_indexes_statement_res = Evaluation::new( - dot_product_logup_star_indexes_inner_point.clone(), - dot_product_logup_star_indexes_inner_value_res, - ); - - { - let dot_product_logup_star_indexes_inner_value: EF = dot_product( - eval_eq(&mem_lookup_eval_indexes_partial_point.0[3 + index_diff..5 + index_diff]) - .into_iter(), - [ - dot_product_logup_star_indexes_inner_value_a, - dot_product_logup_star_indexes_inner_value_b, - dot_product_logup_star_indexes_inner_value_res, - EF::ZERO, - ] - .into_iter(), - ); - - let mut dot_product_indexes_inner_evals_incr = vec![EF::ZERO; 8]; - for (i, value) in dot_product_indexes_inner_evals_incr - .iter_mut() - .enumerate() - .take(DIMENSION) - { - *value = dot_product_logup_star_indexes_inner_value - + EF::from_usize(i) - * [F::ONE, F::ONE, F::ONE, F::ZERO].evaluate(&MultilinearPoint( - mem_lookup_eval_indexes_partial_point.0[3 + index_diff..5 + index_diff] - .to_vec(), - )); - } - if dot_product_indexes_inner_evals_incr.evaluate(&MultilinearPoint( - mem_lookup_eval_indexes_partial_point.0[index_diff..3 + index_diff].to_vec(), - )) != mem_lookup_eval_spread_indexes_dot_product - { - return Err(ProofError::InvalidProof); - } - } + dot_product_evals_to_verify[DOT_PRODUCT_AIR_COL_COMPUTATION], + ), + )?; - let exec_air_statement = |col_index: usize| { - Evaluation::new( - exec_air_point.clone(), - exec_evals_to_verify[col_index.index_in_air()], - ) - }; + let exec_air_statement = + |col_index: usize| Evaluation::new(exec_air_point.clone(), exec_evals_to_verify[col_index]); let dot_product_air_statement = |col_index: usize| { Evaluation::new( dot_product_air_point.clone(), @@ -788,28 +482,22 @@ pub fn verify_execution( initial_pc_statement, final_pc_statement, ], // pc - vec![exec_air_statement(COL_INDEX_FP), grand_product_fp_statement], // fp - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_A), - Evaluation::new( - mem_lookup_eval_indexes_partial_point.clone(), - mem_lookup_eval_indexes_a, - ), - ], // exec memory address A - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_B), - Evaluation::new( - mem_lookup_eval_indexes_partial_point.clone(), - mem_lookup_eval_indexes_b, - ), - ], // exec memory address B - vec![ - exec_air_statement(COL_INDEX_MEM_ADDRESS_C), - Evaluation::new( - mem_lookup_eval_indexes_partial_point, - mem_lookup_eval_indexes_c, - ), - ], // exec memory address C + vec![exec_air_statement(COL_INDEX_FP)], // fp + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_A)], + normal_lookup_statements.on_indexes[0].clone(), + ] + .concat(), // exec memory address A + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_B)], + normal_lookup_statements.on_indexes[1].clone(), + ] + .concat(), // exec memory address B + [ + vec![exec_air_statement(COL_INDEX_MEM_ADDRESS_C)], + normal_lookup_statements.on_indexes[2].clone(), + ] + .concat(), // exec memory address C p16_indexes_a_statements, p16_indexes_b_statements, p16_indexes_res_statements, @@ -817,32 +505,26 @@ pub fn verify_execution( p24_indexes_b_statements, p24_indexes_res_statements, ], - p16_cubes_statements, - p24_cubes_statements, + encapsulate_vec(p16_gkr.cubes_statements.split()), + encapsulate_vec(p24_gkr.cubes_statements.split()), vec![ - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_START_FLAG), - grand_product_dot_product_flag_statement, - ], // dot product: (start) flag - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_LEN), - grand_product_dot_product_len_statement, - ], // dot product: length - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_A), - dot_product_logup_star_indexes_statement_a, - grand_product_dot_product_table_indexes_statement_index_a, - ], // dot product: index a - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_B), - dot_product_logup_star_indexes_statement_b, - grand_product_dot_product_table_indexes_statement_index_b, - ], // dot product: index b - vec![ - dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_RES), - dot_product_logup_star_indexes_statement_res, - grand_product_dot_product_table_indexes_statement_index_res, - ], // dot product: index res + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_START_FLAG)], // dot product: (start) flag + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_LEN)], // dot product: length + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_A)], + normal_lookup_statements.on_indexes[3].clone(), + ] + .concat(), + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_B)], + normal_lookup_statements.on_indexes[4].clone(), + ] + .concat(), + [ + vec![dot_product_air_statement(DOT_PRODUCT_AIR_COL_INDEX_RES)], + normal_lookup_statements.on_indexes[4].clone(), + ] + .concat(), ], dot_product_computation_column_statements, ] @@ -855,8 +537,8 @@ pub fn verify_execution( &extension_dims, LOG_SMALLEST_DECOMPOSITION_CHUNK, &[ - base_memory_logup_star_statements.on_pushforward, - poseidon_logup_star_statements.on_pushforward, + normal_lookup_statements.on_pushforward, + vectorized_lookup_statements.on_pushforward, bytecode_logup_star_statements.on_pushforward, ], &mut verifier_state, diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs index d4a9f874..05ff54ae 100644 --- a/crates/lean_prover/tests/hash_chain.rs +++ b/crates/lean_prover/tests/hash_chain.rs @@ -1,11 +1,10 @@ -use std::time::Instant; - use lean_compiler::*; use lean_prover::{ prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder, }; use lean_vm::{F, execute_bytecode}; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; +use std::time::Instant; use xmss::iterate_hash; #[test] diff --git a/crates/lean_prover/tests/test_zkvm.rs b/crates/lean_prover/tests/test_zkvm.rs index 48360691..9733e38a 100644 --- a/crates/lean_prover/tests/test_zkvm.rs +++ b/crates/lean_prover/tests/test_zkvm.rs @@ -3,7 +3,7 @@ use lean_prover::{ prove_execution::prove_execution, verify_execution::verify_execution, whir_config_builder, }; use lean_vm::*; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; #[test] fn test_zk_vm_all_precompiles() { diff --git a/crates/lean_prover/vm_air/Cargo.toml b/crates/lean_prover/vm_air/Cargo.toml index 93066cd6..baa78472 100644 --- a/crates/lean_prover/vm_air/Cargo.toml +++ b/crates/lean_prover/vm_air/Cargo.toml @@ -10,24 +10,20 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +p3-uni-stark.workspace = true +sub_protocols.workspace = true lookup.workspace = true lean_vm.workspace = true lean_compiler.workspace = true -witness_generation.workspace = true multilinear-toolkit.workspace = true diff --git a/crates/lean_prover/vm_air/src/dot_product_air.rs b/crates/lean_prover/vm_air/src/dot_product_air.rs index e3a1dc93..e2ede9e2 100644 --- a/crates/lean_prover/vm_air/src/dot_product_air.rs +++ b/crates/lean_prover/vm_air/src/dot_product_air.rs @@ -1,9 +1,11 @@ -use std::borrow::Borrow; - -use lean_vm::{DIMENSION, EF, WitnessDotProduct}; +use lean_vm::{DIMENSION, EF, TABLE_INDEX_DOT_PRODUCTS}; +use multilinear_toolkit::prelude::*; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeCharacteristicRing; -use p3_matrix::Matrix; +use p3_uni_stark::SymbolicExpression; +use std::{ + any::TypeId, + mem::{transmute, transmute_copy}, +}; /* (DIMENSION = 5) @@ -28,61 +30,75 @@ pub const DOT_PRODUCT_AIR_COL_INDEX_B: usize = 3; pub const DOT_PRODUCT_AIR_COL_INDEX_RES: usize = 4; pub const DOT_PRODUCT_AIR_COL_VALUE_A: usize = 5; pub const DOT_PRODUCT_AIR_COL_VALUE_B: usize = 6; -pub const DOT_PRODUCT_AIR_COL_RES: usize = 7; +pub const DOT_PRODUCT_AIR_COL_VALUE_RES: usize = 7; pub const DOT_PRODUCT_AIR_COL_COMPUTATION: usize = 8; pub const DOT_PRODUCT_AIR_N_COLUMNS: usize = 9; #[derive(Debug)] -pub struct DotProductAir; +pub struct DotProductAir { + pub global_challenge: EF, + pub fingerprint_challenge_powers: [EF; 5], +} -impl BaseAir for DotProductAir { +impl BaseAir for DotProductAir { fn width(&self) -> usize { DOT_PRODUCT_AIR_N_COLUMNS } - fn structured(&self) -> bool { - true - } fn degree(&self) -> usize { 3 } + fn columns_with_shift(&self) -> Vec { + vec![ + DOT_PRODUCT_AIR_COL_START_FLAG, + DOT_PRODUCT_AIR_COL_LEN, + DOT_PRODUCT_AIR_COL_INDEX_A, + DOT_PRODUCT_AIR_COL_INDEX_B, + DOT_PRODUCT_AIR_COL_COMPUTATION, + ] + } } -impl Air for DotProductAir { +impl>> Air for DotProductAir +where + AB::Var: 'static, + AB::Expr: 'static, + AB::FinalOutput: 'static, +{ #[inline] fn eval(&self, builder: &mut AB) { let main = builder.main(); - let up = main.row_slice(0).unwrap(); - let up: &[AB::Var] = (*up).borrow(); - assert_eq!(up.len(), DOT_PRODUCT_AIR_N_COLUMNS); - let down = main.row_slice(1).unwrap(); - let down: &[AB::Var] = (*down).borrow(); - assert_eq!(down.len(), DOT_PRODUCT_AIR_N_COLUMNS); - - let [ - start_flag_up, - len_up, - index_a_up, - index_b_up, - _index_res_up, - value_a_up, - value_b_up, - res_up, - computation_up, - ] = up.to_vec().try_into().ok().unwrap(); - let [ - start_flag_down, - len_down, - index_a_down, - index_b_down, - _index_res_down, - _value_a_down, - _value_b_down, - _res_down, - computation_down, - ] = down.to_vec().try_into().ok().unwrap(); - - // TODO we could some some of the following computation in the base field + let up = &main[..DOT_PRODUCT_AIR_N_COLUMNS]; + let down = &main[DOT_PRODUCT_AIR_N_COLUMNS..]; + + let start_flag_up = up[DOT_PRODUCT_AIR_COL_START_FLAG].clone(); + let len_up = up[DOT_PRODUCT_AIR_COL_LEN].clone(); + let index_a_up = up[DOT_PRODUCT_AIR_COL_INDEX_A].clone(); + let index_b_up = up[DOT_PRODUCT_AIR_COL_INDEX_B].clone(); + let index_res_up = up[DOT_PRODUCT_AIR_COL_INDEX_RES].clone(); + let value_a_up = up[DOT_PRODUCT_AIR_COL_VALUE_A].clone(); + let value_b_up = up[DOT_PRODUCT_AIR_COL_VALUE_B].clone(); + let res_up = up[DOT_PRODUCT_AIR_COL_VALUE_RES].clone(); + let computation_up = up[DOT_PRODUCT_AIR_COL_COMPUTATION].clone(); + + let start_flag_down = down[0].clone(); + let len_down = down[1].clone(); + let index_a_down = down[2].clone(); + let index_b_down = down[3].clone(); + let computation_down = down[4].clone(); + + // TODO we could do most of the following computation in the base field + + builder.add_custom( as Air>::eval_custom( + self, + &[ + start_flag_up.clone().into(), + len_up.clone().into(), + index_a_up.clone().into(), + index_b_up.clone().into(), + index_res_up.clone().into(), + ], + )); builder.assert_bool(start_flag_down.clone()); @@ -104,87 +120,65 @@ impl Air for DotProductAir { builder.assert_zero(start_flag_up * (computation_up - res_up)); } -} -pub fn build_dot_product_columns( - witness: &[WitnessDotProduct], - min_n_rows: usize, -) -> (Vec>, usize) { - let ( - mut flag, - mut len, - mut index_a, - mut index_b, - mut index_res, - mut value_a, - mut value_b, - mut res, - mut computation, - ) = ( - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - ); - for dot_product in witness { - assert!(dot_product.len > 0); - - // computation - { - computation.extend(EF::zero_vec(dot_product.len)); - let new_size = computation.len(); - computation[new_size - 1] = - dot_product.slice_0[dot_product.len - 1] * dot_product.slice_1[dot_product.len - 1]; - for i in 0..dot_product.len - 1 { - computation[new_size - 2 - i] = computation[new_size - 1 - i] - + dot_product.slice_0[dot_product.len - 2 - i] - * dot_product.slice_1[dot_product.len - 2 - i]; - } + fn eval_custom(&self, inputs: &[::Expr]) -> ::FinalOutput { + let type_id_final_output = TypeId::of::<::FinalOutput>(); + let type_id_expr = TypeId::of::<::Expr>(); + // let type_id_f = TypeId::of::>(); + let type_id_ef = TypeId::of::(); + let type_id_f_packing = TypeId::of::>(); + let type_id_ef_packing = TypeId::of::>(); + + if type_id_expr == type_id_ef { + assert_eq!(type_id_final_output, type_id_ef); + let inputs = unsafe { transmute::<&[::Expr], &[EF]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| c * p); + unsafe { transmute_copy::::FinalOutput>(&res) } + } else if type_id_expr == type_id_ef_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[EFPacking]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| p * c); + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else if type_id_expr == type_id_f_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[PFPacking]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| EFPacking::::from(p) * c); + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else { + assert_eq!(type_id_expr, TypeId::of::>>()); + unsafe { transmute_copy(&SymbolicExpression::>::default()) } } + } +} - flag.push(EF::ONE); - flag.extend(EF::zero_vec(dot_product.len - 1)); - len.extend(((1..=dot_product.len).rev()).map(EF::from_usize)); - index_a.extend( - (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_0 + i * DIMENSION)), - ); - index_b.extend( - (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_1 + i * DIMENSION)), - ); - index_res.extend(vec![EF::from_usize(dot_product.addr_res); dot_product.len]); - value_a.extend(dot_product.slice_0.clone()); - value_b.extend(dot_product.slice_1.clone()); - res.extend(vec![dot_product.res; dot_product.len]); +impl DotProductAir { + fn gkr_virtual_column_eval< + PointF: PrimeCharacteristicRing + Copy, + ResultF: Algebra + Algebra + Copy, + >( + &self, + point: &[PointF], + mul_point_f_and_ef: impl Fn(PointF, EF) -> ResultF, + ) -> ResultF { + ResultF::from_usize(TABLE_INDEX_DOT_PRODUCTS) + + (mul_point_f_and_ef(point[2], self.fingerprint_challenge_powers[1]) + + mul_point_f_and_ef(point[3], self.fingerprint_challenge_powers[2]) + + mul_point_f_and_ef(point[4], self.fingerprint_challenge_powers[3]) + + mul_point_f_and_ef(point[1], self.fingerprint_challenge_powers[4])) + * point[0] + + self.global_challenge } +} - let padding_len = flag.len().next_power_of_two().max(min_n_rows) - flag.len(); - flag.extend(vec![EF::ONE; padding_len]); - len.extend(vec![EF::ONE; padding_len]); - index_a.extend(EF::zero_vec(padding_len)); - index_b.extend(EF::zero_vec(padding_len)); - index_res.extend(EF::zero_vec(padding_len)); - value_a.extend(EF::zero_vec(padding_len)); - value_b.extend(EF::zero_vec(padding_len)); - res.extend(EF::zero_vec(padding_len)); - computation.extend(EF::zero_vec(padding_len)); - - ( - vec![ - flag, - len, - index_a, - index_b, - index_res, - value_a, - value_b, - res, - computation, - ], - padding_len, - ) +pub fn dot_product_air_padding_row() -> Vec { + // only the shifted columns + vec![ + EF::ONE, // StartFlag + EF::ONE, // Len + EF::ZERO, // IndexA + EF::ZERO, // IndexB + EF::ZERO, // Computation + ] } diff --git a/crates/lean_prover/vm_air/src/execution_air.rs b/crates/lean_prover/vm_air/src/execution_air.rs index 11a9221a..0e8f9ee0 100644 --- a/crates/lean_prover/vm_air/src/execution_air.rs +++ b/crates/lean_prover/vm_air/src/execution_air.rs @@ -1,35 +1,76 @@ -use std::borrow::Borrow; +use std::{ + any::TypeId, + mem::{transmute, transmute_copy}, +}; +use multilinear_toolkit::prelude::*; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeCharacteristicRing; -use p3_matrix::Matrix; -use witness_generation::*; +use p3_uni_stark::SymbolicExpression; + +pub const N_INSTRUCTION_COLUMNS: usize = 13; +pub const N_COMMITTED_EXEC_COLUMNS: usize = 5; +pub const N_MEMORY_VALUE_COLUMNS: usize = 3; // virtual (lookup into memory, with logup*) +pub const N_EXEC_AIR_COLUMNS: usize = + N_INSTRUCTION_COLUMNS + N_COMMITTED_EXEC_COLUMNS + N_MEMORY_VALUE_COLUMNS; + +// Instruction columns +pub const COL_INDEX_OPERAND_A: usize = 0; +pub const COL_INDEX_OPERAND_B: usize = 1; +pub const COL_INDEX_OPERAND_C: usize = 2; +pub const COL_INDEX_FLAG_A: usize = 3; +pub const COL_INDEX_FLAG_B: usize = 4; +pub const COL_INDEX_FLAG_C: usize = 5; +pub const COL_INDEX_ADD: usize = 6; +pub const COL_INDEX_MUL: usize = 7; +pub const COL_INDEX_DEREF: usize = 8; +pub const COL_INDEX_JUMP: usize = 9; +pub const COL_INDEX_AUX: usize = 10; +pub const COL_INDEX_IS_PRECOMPILE: usize = 11; +pub const COL_INDEX_PRECOMPILE_INDEX: usize = 12; + +// Execution columns +pub const COL_INDEX_MEM_VALUE_A: usize = 13; // virtual with logup* +pub const COL_INDEX_MEM_VALUE_B: usize = 14; // virtual with logup* +pub const COL_INDEX_MEM_VALUE_C: usize = 15; // virtual with logup* +pub const COL_INDEX_PC: usize = 16; +pub const COL_INDEX_FP: usize = 17; +pub const COL_INDEX_MEM_ADDRESS_A: usize = 18; +pub const COL_INDEX_MEM_ADDRESS_B: usize = 19; +pub const COL_INDEX_MEM_ADDRESS_C: usize = 20; #[derive(Debug)] -pub struct VMAir; +pub struct VMAir { + // GKR grand product challenges + pub global_challenge: EF, + pub fingerprint_challenge_powers: [EF; 5], +} -impl BaseAir for VMAir { +impl BaseAir for VMAir { fn width(&self) -> usize { N_EXEC_AIR_COLUMNS } - fn structured(&self) -> bool { - true - } fn degree(&self) -> usize { 5 } + fn columns_with_shift(&self) -> Vec { + vec![COL_INDEX_PC, COL_INDEX_FP] + } } -impl Air for VMAir { +impl>> Air for VMAir +where + AB::Var: 'static, + AB::Expr: 'static, + AB::FinalOutput: 'static, +{ #[inline] fn eval(&self, builder: &mut AB) { let main = builder.main(); - let up = main.row_slice(0).unwrap(); - let up: &[AB::Var] = (*up).borrow(); - assert_eq!(up.len(), N_EXEC_AIR_COLUMNS); - let down = main.row_slice(1).unwrap(); - let down: &[AB::Var] = (*down).borrow(); - assert_eq!(down.len(), N_EXEC_AIR_COLUMNS); + let up = &main[..N_EXEC_AIR_COLUMNS]; + let down = &main[N_EXEC_AIR_COLUMNS..]; + + let next_pc = down[0].clone(); + let next_fp = down[1].clone(); let (operand_a, operand_b, operand_c) = ( up[COL_INDEX_OPERAND_A].clone(), @@ -46,58 +87,136 @@ impl Air for VMAir { let deref = up[COL_INDEX_DEREF].clone(); let jump = up[COL_INDEX_JUMP].clone(); let aux = up[COL_INDEX_AUX].clone(); + let is_precompile = up[COL_INDEX_IS_PRECOMPILE].clone(); + let precompile_index = up[COL_INDEX_PRECOMPILE_INDEX].clone(); let (value_a, value_b, value_c) = ( - up[COL_INDEX_MEM_VALUE_A.index_in_air()].clone(), - up[COL_INDEX_MEM_VALUE_B.index_in_air()].clone(), - up[COL_INDEX_MEM_VALUE_C.index_in_air()].clone(), - ); - let (pc, next_pc) = ( - up[COL_INDEX_PC.index_in_air()].clone(), - down[COL_INDEX_PC.index_in_air()].clone(), - ); - let (fp, next_fp) = ( - up[COL_INDEX_FP.index_in_air()].clone(), - down[COL_INDEX_FP.index_in_air()].clone(), + up[COL_INDEX_MEM_VALUE_A].clone(), + up[COL_INDEX_MEM_VALUE_B].clone(), + up[COL_INDEX_MEM_VALUE_C].clone(), ); + let pc = up[COL_INDEX_PC].clone(); + let fp = up[COL_INDEX_FP].clone(); let (addr_a, addr_b, addr_c) = ( - up[COL_INDEX_MEM_ADDRESS_A.index_in_air()].clone(), - up[COL_INDEX_MEM_ADDRESS_B.index_in_air()].clone(), - up[COL_INDEX_MEM_ADDRESS_C.index_in_air()].clone(), + up[COL_INDEX_MEM_ADDRESS_A].clone(), + up[COL_INDEX_MEM_ADDRESS_B].clone(), + up[COL_INDEX_MEM_ADDRESS_C].clone(), ); let flag_a_minus_one = flag_a.clone() - AB::F::ONE; let flag_b_minus_one = flag_b.clone() - AB::F::ONE; let flag_c_minus_one = flag_c.clone() - AB::F::ONE; - let nu_a = flag_a.clone() * operand_a.clone() + value_a.clone() * -flag_a_minus_one.clone(); - let nu_b = flag_b.clone() * operand_b.clone() + value_b * -flag_b_minus_one.clone(); - let nu_c = flag_c.clone() * fp.clone() + value_c.clone() * -flag_c_minus_one.clone(); + let nu_a = flag_a * operand_a.clone() + value_a.clone() * -flag_a_minus_one.clone(); + let nu_b = flag_b * operand_b.clone() + value_b.clone() * -flag_b_minus_one.clone(); + let nu_c = flag_c * fp.clone() + value_c.clone() * -flag_c_minus_one.clone(); - let fp_plus_operand_a = fp.clone() + operand_a; - let fp_plus_operand_b = fp.clone() + operand_b; + let fp_plus_operand_a = fp.clone() + operand_a.clone(); + let fp_plus_operand_b = fp.clone() + operand_b.clone(); let fp_plus_operand_c = fp.clone() + operand_c.clone(); - let pc_plus_one = pc.clone() + AB::F::ONE; + let pc_plus_one = pc + AB::F::ONE; let nu_a_minus_one = nu_a.clone() - AB::F::ONE; - builder.assert_zero(flag_a_minus_one * (addr_a - fp_plus_operand_a)); - builder.assert_zero(flag_b_minus_one * (addr_b - fp_plus_operand_b)); + builder.add_custom( as Air>::eval_custom( + self, + &[ + nu_a.clone(), + nu_b.clone(), + nu_c.clone(), + aux.clone().into(), + is_precompile.into(), + precompile_index.into(), + ], + )); + + builder.assert_zero(flag_a_minus_one * (addr_a.clone() - fp_plus_operand_a)); + builder.assert_zero(flag_b_minus_one * (addr_b.clone() - fp_plus_operand_b)); builder.assert_zero(flag_c_minus_one * (addr_c.clone() - fp_plus_operand_c)); builder.assert_zero(add * (nu_b.clone() - (nu_a.clone() + nu_c.clone()))); builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone())); - builder.assert_zero(deref.clone() * (addr_c - (value_a + operand_c))); + builder + .assert_zero(deref.clone() * (addr_c.clone() - (value_a.clone() + operand_c.clone()))); builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone())); - builder.assert_zero(deref * (aux - AB::F::ONE) * (value_c - fp.clone())); + builder.assert_zero( + deref.clone() * (aux.clone() - AB::F::ONE) * (value_c.clone() - fp.clone()), + ); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_pc.clone() - pc_plus_one.clone())); builder.assert_zero((jump.clone() - AB::F::ONE) * (next_fp.clone() - fp.clone())); builder.assert_zero(jump.clone() * nu_a.clone() * nu_a_minus_one.clone()); - builder.assert_zero(jump.clone() * nu_a.clone() * (next_pc.clone() - nu_b)); - builder.assert_zero(jump.clone() * nu_a.clone() * (next_fp.clone() - nu_c)); - builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_pc - pc_plus_one)); - builder.assert_zero(jump * nu_a_minus_one * (next_fp - fp)); + builder.assert_zero(jump.clone() * nu_a.clone() * (next_pc.clone() - nu_b.clone())); + builder.assert_zero(jump.clone() * nu_a.clone() * (next_fp.clone() - nu_c.clone())); + builder.assert_zero( + jump.clone() * nu_a_minus_one.clone() * (next_pc.clone() - pc_plus_one.clone()), + ); + builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_fp.clone() - fp.clone())); + } + + fn eval_custom(&self, inputs: &[::Expr]) -> ::FinalOutput { + let type_id_final_output = TypeId::of::<::FinalOutput>(); + let type_id_expr = TypeId::of::<::Expr>(); + // let type_id_f = TypeId::of::>(); + let type_id_ef = TypeId::of::(); + let type_id_f_packing = TypeId::of::>(); + let type_id_ef_packing = TypeId::of::>(); + + if type_id_expr == type_id_ef { + assert_eq!(type_id_final_output, type_id_ef); + let inputs = unsafe { transmute::<&[::Expr], &[EF]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| c * p); + unsafe { transmute_copy::::FinalOutput>(&res) } + } else if type_id_expr == type_id_ef_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[EFPacking]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| p * c); + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else if type_id_expr == type_id_f_packing { + assert_eq!(type_id_final_output, type_id_ef_packing); + let inputs = + unsafe { transmute::<&[::Expr], &[PFPacking]>(inputs) }; + let res = self.gkr_virtual_column_eval(inputs, |p, c| EFPacking::::from(p) * c); + unsafe { transmute_copy::, ::FinalOutput>(&res) } + } else { + assert_eq!(type_id_expr, TypeId::of::>>()); + unsafe { transmute_copy(&SymbolicExpression::>::default()) } + } } } + +impl VMAir { + fn gkr_virtual_column_eval< + PointF: PrimeCharacteristicRing + Copy, + ResultF: Algebra + Algebra + Copy, + >( + &self, + point: &[PointF], + mul_point_f_and_ef: impl Fn(PointF, EF) -> ResultF, + ) -> ResultF { + let nu_a = point[0]; + let nu_b = point[1]; + let nu_c = point[2]; + let aux = point[3]; + let is_precompile = point[4]; + let precompile_index = point[5]; + + let nu_a_mul_challenge_1 = mul_point_f_and_ef(nu_a, self.fingerprint_challenge_powers[1]); + let nu_b_mul_challenge_2 = mul_point_f_and_ef(nu_b, self.fingerprint_challenge_powers[2]); + let nu_c_mul_challenge_3 = mul_point_f_and_ef(nu_c, self.fingerprint_challenge_powers[3]); + + let nu_sums = nu_a_mul_challenge_1 + nu_b_mul_challenge_2 + nu_c_mul_challenge_3; + let aux_mul_challenge_4 = mul_point_f_and_ef(aux, self.fingerprint_challenge_powers[4]); + (nu_sums + aux_mul_challenge_4 + precompile_index) * is_precompile + self.global_challenge + } +} + +pub fn execution_air_padding_row(ending_pc: usize) -> Vec { + // only the shifted columns + vec![ + F::from_usize(ending_pc), // PC + F::ZERO, // FP + ] +} diff --git a/crates/lean_prover/witness_generation/Cargo.toml b/crates/lean_prover/witness_generation/Cargo.toml index b57d4dca..9f3f5431 100644 --- a/crates/lean_prover/witness_generation/Cargo.toml +++ b/crates/lean_prover/witness_generation/Cargo.toml @@ -10,22 +10,18 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +sub_protocols.workspace = true lookup.workspace = true lean_vm.workspace = true lean_compiler.workspace = true @@ -33,3 +29,4 @@ derive_more.workspace = true multilinear-toolkit.workspace = true poseidon_circuit.workspace = true p3-monty-31.workspace = true +vm_air.workspace = true diff --git a/crates/lean_prover/witness_generation/src/dot_product.rs b/crates/lean_prover/witness_generation/src/dot_product.rs new file mode 100644 index 00000000..38859ac3 --- /dev/null +++ b/crates/lean_prover/witness_generation/src/dot_product.rs @@ -0,0 +1,85 @@ +use lean_vm::{DIMENSION, EF, WitnessDotProduct}; +use multilinear_toolkit::prelude::*; + +pub fn build_dot_product_columns( + witness: &[WitnessDotProduct], + min_n_rows: usize, +) -> (Vec>, usize) { + let ( + mut flag, + mut len, + mut index_a, + mut index_b, + mut index_res, + mut value_a, + mut value_b, + mut res, + mut computation, + ) = ( + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ); + for dot_product in witness { + assert!(dot_product.len > 0); + + // computation + { + computation.extend(EF::zero_vec(dot_product.len)); + let new_size = computation.len(); + computation[new_size - 1] = + dot_product.slice_0[dot_product.len - 1] * dot_product.slice_1[dot_product.len - 1]; + for i in 0..dot_product.len - 1 { + computation[new_size - 2 - i] = computation[new_size - 1 - i] + + dot_product.slice_0[dot_product.len - 2 - i] + * dot_product.slice_1[dot_product.len - 2 - i]; + } + } + + flag.push(EF::ONE); + flag.extend(EF::zero_vec(dot_product.len - 1)); + len.extend(((1..=dot_product.len).rev()).map(EF::from_usize)); + index_a.extend( + (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_0 + i * DIMENSION)), + ); + index_b.extend( + (0..dot_product.len).map(|i| EF::from_usize(dot_product.addr_1 + i * DIMENSION)), + ); + index_res.extend(vec![EF::from_usize(dot_product.addr_res); dot_product.len]); + value_a.extend(dot_product.slice_0.clone()); + value_b.extend(dot_product.slice_1.clone()); + res.extend(vec![dot_product.res; dot_product.len]); + } + + let padding_len = flag.len().next_power_of_two().max(min_n_rows) - flag.len(); + flag.extend(vec![EF::ONE; padding_len]); + len.extend(vec![EF::ONE; padding_len]); + index_a.extend(EF::zero_vec(padding_len)); + index_b.extend(EF::zero_vec(padding_len)); + index_res.extend(EF::zero_vec(padding_len)); + value_a.extend(EF::zero_vec(padding_len)); + value_b.extend(EF::zero_vec(padding_len)); + res.extend(EF::zero_vec(padding_len)); + computation.extend(EF::zero_vec(padding_len)); + + ( + vec![ + flag, + len, + index_a, + index_b, + index_res, + value_a, + value_b, + res, + computation, + ], + padding_len, + ) +} diff --git a/crates/lean_prover/witness_generation/src/execution_trace.rs b/crates/lean_prover/witness_generation/src/execution_trace.rs index 84773333..f130bf55 100644 --- a/crates/lean_prover/witness_generation/src/execution_trace.rs +++ b/crates/lean_prover/witness_generation/src/execution_trace.rs @@ -1,20 +1,14 @@ -use std::array; - use crate::instruction_encoder::field_representation; -use crate::{ - COL_INDEX_FP, COL_INDEX_MEM_ADDRESS_A, COL_INDEX_MEM_ADDRESS_B, COL_INDEX_MEM_ADDRESS_C, - COL_INDEX_MEM_VALUE_A, COL_INDEX_MEM_VALUE_B, COL_INDEX_MEM_VALUE_C, COL_INDEX_PC, - LOG_MIN_DOT_PRODUCT_ROWS, LOG_MIN_POSEIDONS_16, LOG_MIN_POSEIDONS_24, N_TOTAL_COLUMNS, -}; +use crate::*; use lean_vm::*; -use p3_field::Field; -use p3_field::PrimeCharacteristicRing; -use rayon::prelude::*; +use multilinear_toolkit::prelude::*; +use std::array; use utils::{ToUsize, transposed_par_iter_mut}; +use vm_air::*; #[derive(Debug)] pub struct ExecutionTrace { - pub full_trace: [Vec; N_TOTAL_COLUMNS], + pub full_trace: [Vec; N_EXEC_AIR_COLUMNS], pub n_cycles: usize, // before padding with the repeated final instruction pub n_poseidons_16: usize, pub n_poseidons_24: usize, @@ -47,7 +41,7 @@ pub fn get_execution_trace( let n_cycles = execution_result.pcs.len(); let memory = &execution_result.memory; - let mut trace: [Vec; N_TOTAL_COLUMNS] = + let mut trace: [Vec; N_EXEC_AIR_COLUMNS] = array::from_fn(|_| F::zero_vec(n_cycles.next_power_of_two())); transposed_par_iter_mut(&mut trace) diff --git a/crates/lean_prover/witness_generation/src/instruction_encoder.rs b/crates/lean_prover/witness_generation/src/instruction_encoder.rs index 3c855d8d..dca85d9c 100644 --- a/crates/lean_prover/witness_generation/src/instruction_encoder.rs +++ b/crates/lean_prover/witness_generation/src/instruction_encoder.rs @@ -1,7 +1,6 @@ use lean_vm::*; -use p3_field::PrimeCharacteristicRing; - -use crate::*; +use multilinear_toolkit::prelude::*; +use vm_air::*; pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let mut fields = [F::ZERO; N_INSTRUCTION_COLUMNS]; @@ -69,25 +68,28 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { res, is_compression, } => { - fields[COL_INDEX_POSEIDON_16] = F::ONE; + fields[COL_INDEX_IS_PRECOMPILE] = F::ONE; + fields[COL_INDEX_PRECOMPILE_INDEX] = F::from_usize(TABLE_INDEX_POSEIDONS_16); set_nu_a(&mut fields, arg_a); set_nu_b(&mut fields, arg_b); set_nu_c(&mut fields, res); fields[COL_INDEX_AUX] = F::from_bool(*is_compression); // AUX = "is_compression" } Instruction::Poseidon2_24 { arg_a, arg_b, res } => { - fields[COL_INDEX_POSEIDON_24] = F::ONE; + fields[COL_INDEX_IS_PRECOMPILE] = F::ONE; + fields[COL_INDEX_PRECOMPILE_INDEX] = F::from_usize(TABLE_INDEX_POSEIDONS_24); set_nu_a(&mut fields, arg_a); set_nu_b(&mut fields, arg_b); set_nu_c(&mut fields, res); } - Instruction::DotProductExtensionExtension { + Instruction::DotProduct { arg0, arg1, res, size, } => { - fields[COL_INDEX_DOT_PRODUCT] = F::ONE; + fields[COL_INDEX_IS_PRECOMPILE] = F::ONE; + fields[COL_INDEX_PRECOMPILE_INDEX] = F::from_usize(TABLE_INDEX_DOT_PRODUCTS); set_nu_a(&mut fields, arg0); set_nu_b(&mut fields, arg1); set_nu_c(&mut fields, res); @@ -99,7 +101,8 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { res, n_vars, } => { - fields[COL_INDEX_MULTILINEAR_EVAL] = F::ONE; + fields[COL_INDEX_IS_PRECOMPILE] = F::ONE; + fields[COL_INDEX_PRECOMPILE_INDEX] = F::from_usize(TABLE_INDEX_MULTILINEAR_EVAL); set_nu_a(&mut fields, coeffs); set_nu_b(&mut fields, point); set_nu_c(&mut fields, res); diff --git a/crates/lean_prover/witness_generation/src/lib.rs b/crates/lean_prover/witness_generation/src/lib.rs index e792ecb3..1550d233 100644 --- a/crates/lean_prover/witness_generation/src/lib.rs +++ b/crates/lean_prover/witness_generation/src/lib.rs @@ -1,7 +1,5 @@ #![cfg_attr(not(test), allow(unused_crate_dependencies))] -use lean_compiler::PRECOMPILES; - mod execution_trace; mod instruction_encoder; @@ -11,56 +9,8 @@ pub use instruction_encoder::*; mod poseidon_tables; pub use poseidon_tables::*; -pub const N_INSTRUCTION_COLUMNS: usize = 15; -pub const N_COMMITTED_EXEC_COLUMNS: usize = 5; -pub const N_MEMORY_VALUE_COLUMNS: usize = 3; // virtual (lookup into memory, with logup*) -pub const N_EXEC_COLUMNS: usize = N_COMMITTED_EXEC_COLUMNS + N_MEMORY_VALUE_COLUMNS; -pub const N_INSTRUCTION_COLUMNS_IN_AIR: usize = N_INSTRUCTION_COLUMNS - PRECOMPILES.len(); -pub const N_EXEC_AIR_COLUMNS: usize = N_INSTRUCTION_COLUMNS_IN_AIR + N_EXEC_COLUMNS; -pub const N_TOTAL_COLUMNS: usize = N_INSTRUCTION_COLUMNS + N_EXEC_COLUMNS; - -// Instruction columns -pub const COL_INDEX_OPERAND_A: usize = 0; -pub const COL_INDEX_OPERAND_B: usize = 1; -pub const COL_INDEX_OPERAND_C: usize = 2; -pub const COL_INDEX_FLAG_A: usize = 3; -pub const COL_INDEX_FLAG_B: usize = 4; -pub const COL_INDEX_FLAG_C: usize = 5; -pub const COL_INDEX_ADD: usize = 6; -pub const COL_INDEX_MUL: usize = 7; -pub const COL_INDEX_DEREF: usize = 8; -pub const COL_INDEX_JUMP: usize = 9; -pub const COL_INDEX_AUX: usize = 10; -pub const COL_INDEX_POSEIDON_16: usize = 11; -pub const COL_INDEX_POSEIDON_24: usize = 12; -pub const COL_INDEX_DOT_PRODUCT: usize = 13; -pub const COL_INDEX_MULTILINEAR_EVAL: usize = 14; - -// Execution columns -pub const COL_INDEX_MEM_VALUE_A: usize = 15; // virtual with logup* -pub const COL_INDEX_MEM_VALUE_B: usize = 16; // virtual with logup* -pub const COL_INDEX_MEM_VALUE_C: usize = 17; // virtual with logup* -pub const COL_INDEX_PC: usize = 18; -pub const COL_INDEX_FP: usize = 19; -pub const COL_INDEX_MEM_ADDRESS_A: usize = 20; -pub const COL_INDEX_MEM_ADDRESS_B: usize = 21; -pub const COL_INDEX_MEM_ADDRESS_C: usize = 22; - -pub trait InAirColumnIndex { - fn index_in_air(self) -> usize; -} - -impl InAirColumnIndex for usize { - fn index_in_air(self) -> usize { - if self < N_INSTRUCTION_COLUMNS_IN_AIR { - self - } else { - assert!(self >= N_INSTRUCTION_COLUMNS); - assert!(self < N_INSTRUCTION_COLUMNS + N_EXEC_COLUMNS); - self - PRECOMPILES.len() - } - } -} +mod dot_product; +pub use dot_product::*; // Zero padding will be added to each at least, if this minimum is not reached // (ensuring AIR / GKR work fine, with SIMD, without too much edge cases) diff --git a/crates/lean_prover/witness_generation/src/poseidon_tables.rs b/crates/lean_prover/witness_generation/src/poseidon_tables.rs index 9223b4a1..c5d2c02a 100644 --- a/crates/lean_prover/witness_generation/src/poseidon_tables.rs +++ b/crates/lean_prover/witness_generation/src/poseidon_tables.rs @@ -2,87 +2,36 @@ use std::array; use lean_vm::{F, PoseidonWitnessTrait, WitnessPoseidon16, WitnessPoseidon24}; use multilinear_toolkit::prelude::*; -use p3_field::PrimeCharacteristicRing; use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; use p3_monty_31::InternalLayerBaseParameters; use poseidon_circuit::{PoseidonGKRLayers, PoseidonWitness, generate_poseidon_witness}; use tracing::instrument; -use utils::{padd_with_zero_to_next_power_of_two, transposed_par_iter_mut}; +use utils::transposed_par_iter_mut; -pub fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> [Vec; 3] { - [ - poseidons_16 - .par_iter() - .map(|p| F::from_usize(p.addr_input_a)) - .collect::>(), - poseidons_16 - .par_iter() - .map(|p| F::from_usize(p.addr_input_b)) - .collect::>(), - poseidons_16 - .par_iter() - .map(|p| F::from_usize(p.addr_output)) - .collect::>(), - ] -} - -pub fn all_poseidon_24_indexes(poseidons_24: &[WitnessPoseidon24]) -> [Vec; 3] { - [ - padd_with_zero_to_next_power_of_two( - &poseidons_24 - .iter() - .map(|p| F::from_usize(p.addr_input_a)) - .collect::>(), - ), - padd_with_zero_to_next_power_of_two( - &poseidons_24 - .iter() - .map(|p| F::from_usize(p.addr_input_b)) - .collect::>(), - ), - padd_with_zero_to_next_power_of_two( - &poseidons_24 - .iter() - .map(|p| F::from_usize(p.addr_output)) - .collect::>(), - ), - ] +pub fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> [Vec; 4] { + assert!(poseidons_16.len().is_power_of_two()); + #[rustfmt::skip] + let res = [ + poseidons_16.par_iter().map(|p| F::from_usize(p.addr_input_a)).collect::>(), + poseidons_16.par_iter().map(|p| F::from_usize(p.addr_input_b)).collect::>(), + poseidons_16.par_iter().map(|p| F::from_usize(p.addr_output)).collect::>(), + poseidons_16.par_iter().map(|p| { + F::from_usize((1 - p.is_compression as usize) * (p.addr_output + 1)) + }).collect::>(), + ]; + res } -pub fn full_poseidon_indexes_poly( - poseidons_16: &[WitnessPoseidon16], - poseidons_24: &[WitnessPoseidon24], -) -> Vec { - let max_n_poseidons = poseidons_16 - .len() - .max(poseidons_24.len()) - .next_power_of_two(); - let mut all_poseidon_indexes = F::zero_vec(8 * max_n_poseidons); +pub fn all_poseidon_24_indexes(poseidons_24: &[WitnessPoseidon24]) -> [Vec; 4] { + assert!(poseidons_24.len().is_power_of_two()); #[rustfmt::skip] - let chunks = [ - poseidons_16.par_iter().map(|p| p.addr_input_a).collect::>(), - poseidons_16.par_iter().map(|p| p.addr_input_b).collect::>(), - poseidons_16.par_iter().map(|p| p.addr_output).collect::>(), - poseidons_16.par_iter().map(|p| { - if p.is_compression { 0 } else { p.addr_output + 1 } - }) - .collect::>(), - poseidons_24.par_iter().map(|p| p.addr_input_a).collect::>(), - poseidons_24.par_iter().map(|p| p.addr_input_a + 1).collect::>(), - poseidons_24.par_iter().map(|p| p.addr_input_b).collect::>(), - poseidons_24.par_iter().map(|p| p.addr_output).collect::>() - ]; - - for (chunk_idx, addrs) in chunks.into_iter().enumerate() { - all_poseidon_indexes[chunk_idx * max_n_poseidons..] - .par_iter_mut() - .zip(addrs) - .for_each(|(slot, addr)| { - *slot = F::from_usize(addr); - }); - } - - all_poseidon_indexes + let res = [ + poseidons_24.par_iter().map(|p| F::from_usize(p.addr_input_a)).collect::>(), + poseidons_24.par_iter().map(|p| F::from_usize(p.addr_input_a + 1)).collect::>(), + poseidons_24.par_iter().map(|p| F::from_usize(p.addr_input_b)).collect::>(), + poseidons_24.par_iter().map(|p| F::from_usize(p.addr_output)).collect::>() + ]; + res } #[instrument(skip_all)] diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index fd2577dd..a1939641 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -11,22 +11,18 @@ colored.workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +sub_protocols.workspace = true lookup.workspace = true thiserror.workspace = true derive_more.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 7d6ca413..81e2fa38 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -38,3 +38,9 @@ pub const POSEIDON_24_NULL_HASH_PTR: usize = 5; /// Normal pointer to start of program input pub const NONRESERVED_PROGRAM_INPUT_START: usize = 6 * 8; + +/// Precompiles Indexes +pub const TABLE_INDEX_POSEIDONS_16: usize = 1; // should be != 0 +pub const TABLE_INDEX_POSEIDONS_24: usize = 2; +pub const TABLE_INDEX_DOT_PRODUCTS: usize = 3; +pub const TABLE_INDEX_MULTILINEAR_EVAL: usize = 4; diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index f575e836..d577d7b3 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -1,9 +1,7 @@ //! Memory management for the VM - use crate::core::{DIMENSION, EF, F, MAX_RUNNER_MEMORY_SIZE, VECTOR_LEN}; use crate::diagnostics::RunnerError; -use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; -use rayon::prelude::*; +use multilinear_toolkit::prelude::*; pub const MIN_LOG_MEMORY_SIZE: usize = 16; pub const MAX_LOG_MEMORY_SIZE: usize = 29; diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index d19064f7..41898992 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -12,7 +12,7 @@ use crate::witness::{ WitnessDotProduct, WitnessMultilinearEval, WitnessPoseidon16, WitnessPoseidon24, }; use crate::{CodeAddress, HintExecutionContext, SourceLineNumber}; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use std::collections::{BTreeMap, BTreeSet}; use utils::{poseidon16_permute, poseidon24_permute, pretty_integer}; use xmss::{Poseidon16History, Poseidon24History}; diff --git a/crates/lean_vm/src/execution/tests.rs b/crates/lean_vm/src/execution/tests.rs index 15513e7b..428299d7 100644 --- a/crates/lean_vm/src/execution/tests.rs +++ b/crates/lean_vm/src/execution/tests.rs @@ -1,6 +1,5 @@ use crate::*; -use p3_field::BasedVectorSpace; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; #[test] fn test_basic_memory_operations() { diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index ddadc5d0..68cccfc0 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -2,7 +2,7 @@ use crate::core::{F, LOG_VECTOR_LEN, Label, SourceLineNumber, VECTOR_LEN}; use crate::diagnostics::{MemoryObject, MemoryObjectType, MemoryProfile, RunnerError}; use crate::execution::{ExecutionHistory, Memory}; use crate::isa::operands::MemOrConstant; -use p3_field::{Field, PrimeCharacteristicRing}; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; use utils::{ToUsize, pretty_integer}; diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index 7810ff2d..548a17df 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -11,7 +11,6 @@ use crate::witness::{ WitnessPoseidon24, }; use multilinear_toolkit::prelude::*; -use p3_field::{BasedVectorSpace, PrimeCharacteristicRing, dot_product}; use p3_util::log2_ceil_usize; use std::fmt::{Display, Formatter}; use utils::{ToUsize, poseidon16_permute, poseidon24_permute}; @@ -76,7 +75,7 @@ pub enum Instruction { }, /// Dot product computation between extension field element vectors - DotProductExtensionExtension { + DotProduct { /// First vector pointer (normal pointer, size `size`) arg0: MemOrConstant, /// Second vector pointer (normal pointer, size `size`) @@ -311,7 +310,7 @@ impl Instruction { *ctx.pc += 1; Ok(()) } - Self::DotProductExtensionExtension { + Self::DotProduct { arg0, arg1, res, @@ -427,7 +426,7 @@ impl Display for Instruction { } => { write!(f, "{res} = m[m[fp + {shift_0}] + {shift_1}]") } - Self::DotProductExtensionExtension { + Self::DotProduct { arg0, arg1, res, diff --git a/crates/lean_vm/src/isa/operands/mem_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_constant.rs index bb318889..8bc3b29f 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_constant.rs @@ -1,7 +1,7 @@ use crate::core::F; use crate::diagnostics::RunnerError; use crate::execution::Memory; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; /// Represents a value that can be either a constant or memory location diff --git a/crates/lean_vm/src/isa/operands/mem_or_fp.rs b/crates/lean_vm/src/isa/operands/mem_or_fp.rs index 70bbbb8f..013f4194 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_fp.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_fp.rs @@ -3,7 +3,7 @@ use crate::core::F; use crate::diagnostics::RunnerError; use crate::execution::Memory; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; /// Represents a value from memory or the frame pointer itself diff --git a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs index bd6d714e..0cb63aae 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs @@ -1,7 +1,7 @@ use crate::core::F; use crate::diagnostics::RunnerError; use crate::execution::Memory; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; /// Memory, frame pointer, or constant operand diff --git a/crates/lean_vm/src/isa/operation.rs b/crates/lean_vm/src/isa/operation.rs index 74e89b7c..e5b880bb 100644 --- a/crates/lean_vm/src/isa/operation.rs +++ b/crates/lean_vm/src/isa/operation.rs @@ -1,7 +1,7 @@ //! VM operation definitions use crate::core::F; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use std::fmt::{Display, Formatter}; /// Basic arithmetic operations supported by the VM diff --git a/crates/lean_vm/src/lib.rs b/crates/lean_vm/src/lib.rs index cd4c54af..969a3945 100644 --- a/crates/lean_vm/src/lib.rs +++ b/crates/lean_vm/src/lib.rs @@ -1,10 +1,10 @@ //! Lean VM - A minimal virtual machine implementation -pub mod core; -pub mod diagnostics; -pub mod execution; -pub mod isa; -pub mod witness; +mod core; +mod diagnostics; +mod execution; +mod isa; +mod witness; pub use core::*; pub use diagnostics::*; diff --git a/crates/lean_vm/src/witness/dot_product.rs b/crates/lean_vm/src/witness/dot_product.rs index 486269a5..0c456274 100644 --- a/crates/lean_vm/src/witness/dot_product.rs +++ b/crates/lean_vm/src/witness/dot_product.rs @@ -1,7 +1,7 @@ //! Dot product witness for arithmetic operations between extension field elements use crate::core::{EF, F}; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; /// Witness data for dot the product precompile #[derive(Debug, Clone)] diff --git a/crates/lean_vm/src/witness/multilinear_eval.rs b/crates/lean_vm/src/witness/multilinear_eval.rs index 6f312a74..f93396aa 100644 --- a/crates/lean_vm/src/witness/multilinear_eval.rs +++ b/crates/lean_vm/src/witness/multilinear_eval.rs @@ -1,7 +1,7 @@ //! Multilinear polynomial evaluation witness use crate::core::{EF, F}; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; #[derive(Debug, Clone)] pub struct RowMultilinearEval { diff --git a/crates/lean_vm/src/witness/poseidon16.rs b/crates/lean_vm/src/witness/poseidon16.rs index aca584eb..f1fd9046 100644 --- a/crates/lean_vm/src/witness/poseidon16.rs +++ b/crates/lean_vm/src/witness/poseidon16.rs @@ -4,7 +4,7 @@ use crate::{ PoseidonWitnessTrait, core::{F, POSEIDON_16_NULL_HASH_PTR, ZERO_VEC_PTR}, }; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; pub const POSEIDON_16_DEFAULT_COMPRESSION: bool = true; diff --git a/crates/lean_vm/src/witness/poseidon24.rs b/crates/lean_vm/src/witness/poseidon24.rs index d87e86be..e26dd93d 100644 --- a/crates/lean_vm/src/witness/poseidon24.rs +++ b/crates/lean_vm/src/witness/poseidon24.rs @@ -4,7 +4,7 @@ use crate::{ PoseidonWitnessTrait, core::{F, POSEIDON_24_NULL_HASH_PTR, ZERO_VEC_PTR}, }; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; /// Witness data for Poseidon2 over 24 field elements #[derive(Debug, Clone)] diff --git a/crates/lean_vm/tests/test_lean_vm.rs b/crates/lean_vm/tests/test_lean_vm.rs index 3b6841cd..d36baac1 100644 --- a/crates/lean_vm/tests/test_lean_vm.rs +++ b/crates/lean_vm/tests/test_lean_vm.rs @@ -1,7 +1,6 @@ use lean_vm::error::ExecutionResult; use lean_vm::*; -use p3_field::BasedVectorSpace; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use p3_util::log2_ceil_usize; use std::collections::BTreeMap; use utils::ToUsize; @@ -176,7 +175,7 @@ fn build_test_case() -> (Bytecode, Vec) { offset: POSEIDON24_RES_OFFSET, }, }, - Instruction::DotProductExtensionExtension { + Instruction::DotProduct { arg0: MemOrConstant::Constant(f(DOT_ARG0_PTR as u64)), arg1: MemOrConstant::Constant(f(DOT_ARG1_PTR as u64)), res: MemOrFp::MemoryAfterFp { @@ -332,7 +331,7 @@ fn test_memory_operations() { #[test] fn test_operation_compute() { - use crate::isa::Operation; + use crate::Operation; let add = Operation::Add; let mul = Operation::Mul; diff --git a/crates/lookup/Cargo.toml b/crates/lookup/Cargo.toml index e8e323ba..12354a7f 100644 --- a/crates/lookup/Cargo.toml +++ b/crates/lookup/Cargo.toml @@ -8,10 +8,8 @@ workspace = true [dependencies] utils.workspace = true -p3-field.workspace = true p3-koala-bear.workspace = true rand.workspace = true -rayon.workspace = true whir-p3.workspace = true p3-challenger.workspace = true tracing.workspace = true diff --git a/crates/lookup/src/logup_star.rs b/crates/lookup/src/logup_star.rs index e282430e..77a99ea4 100644 --- a/crates/lookup/src/logup_star.rs +++ b/crates/lookup/src/logup_star.rs @@ -6,10 +6,8 @@ https://eprint.iacr.org/2025/946.pdf */ use multilinear_toolkit::prelude::*; -use p3_field::{ExtensionField, PrimeField64}; use utils::ToUsize; -use p3_field::PrimeCharacteristicRing; use tracing::{info_span, instrument}; use utils::{FSProver, FSVerifier}; @@ -233,7 +231,6 @@ mod tests { use std::time::Instant; use super::*; - use p3_field::PrimeCharacteristicRing; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use rand::{Rng, SeedableRng, rngs::StdRng}; use utils::{build_challenger, init_tracing}; diff --git a/crates/lookup/src/product_gkr.rs b/crates/lookup/src/product_gkr.rs index 89036c96..9d131e9c 100644 --- a/crates/lookup/src/product_gkr.rs +++ b/crates/lookup/src/product_gkr.rs @@ -8,11 +8,7 @@ with custom GKR */ use multilinear_toolkit::prelude::*; -use p3_field::PrimeCharacteristicRing; -use p3_field::{ExtensionField, PrimeField64}; use tracing::instrument; -use utils::left_ref; -use utils::right_ref; use utils::{FSProver, FSVerifier}; use crate::MIN_VARS_FOR_PACKING; @@ -27,25 +23,17 @@ A': [a0*a4, a1*a5, a2*a6, a3*a7] */ #[instrument(skip_all)] -pub fn prove_gkr_product( +pub fn prove_gkr_product( prover_state: &mut FSProver>, final_layer: &[EF], ) -> (EF, Evaluation) where EF: ExtensionField>, - PF: PrimeField64, { - assert!(log2_strict_usize(final_layer.len()) >= 1); - if final_layer.len() == 2 { - prover_state.add_extension_scalars(final_layer); - let product = final_layer[0] * final_layer[1]; - let point = MultilinearPoint(vec![prover_state.sample()]); - let claim = Evaluation { - point: point.clone(), - value: final_layer.evaluate(&point), - }; - return (product, claim); - } + assert!( + log2_strict_usize(final_layer.len()) > log2_strict_usize(N_GROUPS), + "TODO small case" + ); let final_layer: Mle<'_, EF> = if final_layer.len() >= 1 << MIN_VARS_FOR_PACKING { // TODO packing beforehand @@ -55,14 +43,16 @@ where }; let mut layers = vec![final_layer]; + layers.push(product_n_by_n::<_, N_GROUPS>(&layers.last().unwrap().by_ref()).into()); loop { if layers.last().unwrap().n_vars() == 1 { break; } - layers.push(product_2_by_2(&layers.last().unwrap().by_ref()).into()); + layers.push(product_n_by_n::<_, 2>(&layers.last().unwrap().by_ref()).into()); } - let last_layer = match layers.last().unwrap().by_ref() { + let last_layer = layers.pop().unwrap(); + let last_layer = match last_layer.by_ref() { MleRef::Extension(slice) => slice, _ => unreachable!(), }; @@ -76,69 +66,80 @@ where value: last_layer.evaluate(&point), }; - for layer in layers.iter().rev().skip(1) { - claim = match layer.by_ref() { - MleRef::Extension(slice) => prove_gkr_product_step(prover_state, slice, &claim), - MleRef::ExtensionPacked(slice) => { - prove_gkr_product_step_packed(prover_state, slice, &claim) - } - _ => unreachable!(), - } + for layer in layers[1..].iter().rev() { + claim = prove_gkr_product_step::<_, 2>(prover_state, &layer.by_ref(), &claim); } + claim = prove_gkr_product_step::<_, N_GROUPS>(prover_state, &layers[0].by_ref(), &claim); (product, claim) } -fn prove_gkr_product_step( +fn prove_gkr_product_step>, const N_GROUPS: usize>( + prover_state: &mut FSProver>, + up_layer: &MleRef<'_, EF>, + claim: &Evaluation, +) -> Evaluation { + match up_layer { + MleRef::Extension(slice) => { + prove_gkr_product_step_unpacked::<_, N_GROUPS>(prover_state, slice, claim) + } + MleRef::ExtensionPacked(slice) => { + prove_gkr_product_step_packed::<_, N_GROUPS>(prover_state, slice, claim) + } + _ => unreachable!(), + } +} + +fn prove_gkr_product_step_unpacked>, const N_GROUPS: usize>( prover_state: &mut FSProver>, up_layer: &[EF], claim: &Evaluation, -) -> Evaluation -where - EF: ExtensionField>, - PF: PrimeField64, -{ - assert_eq!(up_layer.len().ilog2() as usize - 1, claim.point.0.len()); - prove_gkr_product_step_core( +) -> Evaluation { + assert_eq!(up_layer.len(), N_GROUPS << claim.point.0.len()); + prove_gkr_product_step_core::<_, N_GROUPS>( prover_state, - MleGroupRef::Extension(vec![left_ref(up_layer), right_ref(up_layer)]), + MleGroupRef::Extension(split_at_many( + up_layer, + &(1..N_GROUPS) + .map(|i| i * up_layer.len() / N_GROUPS) + .collect::>(), + )), claim, ) } -fn prove_gkr_product_step_packed( +fn prove_gkr_product_step_packed>, const N_GROUPS: usize>( prover_state: &mut FSProver>, up_layer_packed: &[EFPacking], claim: &Evaluation, -) -> Evaluation -where - EF: ExtensionField>, - PF: PrimeField64, -{ +) -> Evaluation { assert_eq!( up_layer_packed.len() * packing_width::(), - 2 << claim.point.0.len() + N_GROUPS << claim.point.0.len() ); - prove_gkr_product_step_core( + prove_gkr_product_step_core::<_, N_GROUPS>( prover_state, - MleGroupRef::ExtensionPacked(vec![left_ref(up_layer_packed), right_ref(up_layer_packed)]), + MleGroupRef::ExtensionPacked(split_at_many( + up_layer_packed, + &(1..N_GROUPS) + .map(|i| i * up_layer_packed.len() / N_GROUPS) + .collect::>(), + )), claim, ) } -fn prove_gkr_product_step_core( +fn prove_gkr_product_step_core>, const N_GROUPS: usize>( prover_state: &mut FSProver>, up_layer: MleGroupRef<'_, EF>, claim: &Evaluation, -) -> Evaluation -where - EF: ExtensionField>, - PF: PrimeField64, -{ +) -> Evaluation { + let _: () = assert!(N_GROUPS.is_power_of_two()); + assert_eq!(up_layer.n_columns(), N_GROUPS); let (sc_point, inner_evals, _) = sumcheck_prove::( 1, up_layer, - &ProductComputation, + &MultiProductComputation:: {}, &[], Some((claim.point.0.clone(), None)), false, @@ -149,95 +150,104 @@ where prover_state.add_extension_scalars(&inner_evals); + let selectors = univariate_selectors::>(log2_strict_usize(N_GROUPS)); let mixing_challenge = prover_state.sample(); let mut next_point = sc_point; next_point.0.insert(0, mixing_challenge); - let next_claim = - inner_evals[0] * (EF::ONE - mixing_challenge) + inner_evals[1] * mixing_challenge; + let next_claim: EF = inner_evals + .iter() + .enumerate() + .map(|(i, &v)| v * selectors[i].evaluate(mixing_challenge)) + .sum(); Evaluation::new(next_point, next_claim) } -pub fn verify_gkr_product( +pub fn verify_gkr_product>, const N_GROUPS: usize>( verifier_state: &mut FSVerifier>, n_vars: usize, -) -> Result<(EF, Evaluation), ProofError> -where - EF: ExtensionField>, - PF: PrimeField64, -{ +) -> Result<(EF, Evaluation), ProofError> { + assert!(n_vars > log2_strict_usize(N_GROUPS), "TODO small case"); let [a, b] = verifier_state.next_extension_scalars_const()?; - if b == EF::ZERO { - return Err(ProofError::InvalidProof); - } let product = a * b; let point = MultilinearPoint(vec![verifier_state.sample()]); let value = [a, b].evaluate(&point); let mut claim = Evaluation { point, value }; - for i in 1..n_vars { - claim = verify_gkr_product_step(verifier_state, i, &claim)?; + for i in 1..n_vars - log2_strict_usize(N_GROUPS) { + claim = verify_gkr_product_step::<_, 2>(verifier_state, i, &claim)?; } + claim = verify_gkr_product_step::<_, N_GROUPS>( + verifier_state, + n_vars - log2_strict_usize(N_GROUPS), + &claim, + )?; Ok((product, claim)) } -fn verify_gkr_product_step( +fn verify_gkr_product_step>, const N_GROUPS: usize>( verifier_state: &mut FSVerifier>, current_layer_log_len: usize, claim: &Evaluation, -) -> Result, ProofError> -where - EF: ExtensionField>, - PF: PrimeField64, -{ - let (sc_eval, postponed) = sumcheck_verify(verifier_state, current_layer_log_len, 3) - .map_err(|_| ProofError::InvalidProof)?; +) -> Result, ProofError> { + let (sc_eval, postponed) = + sumcheck_verify(verifier_state, current_layer_log_len, 1 + N_GROUPS)?; if sc_eval != claim.value { return Err(ProofError::InvalidProof); } - let [eval_left, eval_right] = verifier_state.next_extension_scalars_const()?; + let inner_evals = verifier_state.next_extension_scalars_const::()?; - let postponed_target = claim.point.eq_poly_outside(&postponed.point) * eval_left * eval_right; + let postponed_target = + claim.point.eq_poly_outside(&postponed.point) * inner_evals.iter().copied().product::(); if postponed_target != postponed.value { return Err(ProofError::InvalidProof); } + let selectors = univariate_selectors::>(log2_strict_usize(N_GROUPS)); let mixing_challenge = verifier_state.sample(); let mut next_point = postponed.point; next_point.0.insert(0, mixing_challenge); - let next_claim = eval_left * (EF::ONE - mixing_challenge) + eval_right * mixing_challenge; + let next_claim: EF = inner_evals + .iter() + .enumerate() + .map(|(i, &v)| v * selectors[i].evaluate(mixing_challenge)) + .sum(); Ok(Evaluation::new(next_point, next_claim)) } -fn product_2_by_2>>(layer: &MleRef<'_, EF>) -> MleOwned { +fn product_n_by_n>, const N: usize>( + layer: &MleRef<'_, EF>, +) -> MleOwned { match layer { - MleRef::Extension(slice) => MleOwned::Extension(product_2_by_2_helper(slice)), + MleRef::Extension(slice) => MleOwned::Extension(product_n_by_n_helper::<_, N>(slice)), MleRef::ExtensionPacked(slice) => { if slice.len() >= 1 << MIN_VARS_FOR_PACKING { - MleOwned::ExtensionPacked(product_2_by_2_helper(slice)) + MleOwned::ExtensionPacked(product_n_by_n_helper::<_, N>(slice)) } else { - MleOwned::Extension(product_2_by_2_helper(&unpack_extension(slice))) + MleOwned::Extension(product_n_by_n_helper::<_, N>(&unpack_extension(slice))) } } _ => unreachable!(), } } -fn product_2_by_2_helper( +fn product_n_by_n_helper( layer: &[EF], ) -> Vec { - let n = layer.len(); - (0..n / 2) + assert!(layer.len().is_multiple_of(N)); + let size = layer.len(); + let size_div_n = size / N; + (0..size / N) .into_par_iter() - .map(|i| layer[i] * layer[n / 2 + i]) + .map(|i| (0..N).map(|j| layer[i + j * size_div_n]).product()) .collect() } @@ -254,12 +264,12 @@ mod tests { #[test] fn test_gkr_product() { - for log_n in 1..10 { - test_gkr_product_helper(log_n); - } + test_gkr_product_helper::<8>(20); + test_gkr_product_helper::<4>(7); + test_gkr_product_helper::<4>(3); } - fn test_gkr_product_helper(log_n: usize) { + fn test_gkr_product_helper(log_n: usize) { let n = 1 << log_n; let mut rng = StdRng::seed_from_u64(0); @@ -271,15 +281,25 @@ mod tests { let time = Instant::now(); - let (product_prover, claim_prover) = prove_gkr_product(&mut prover_state, &layer); + let (product_prover, claim_prover) = + prove_gkr_product::<_, N_GROUPS>(&mut prover_state, &layer); println!("GKR product took {:?}", time.elapsed()); let mut verifier_state = build_verifier_state(&prover_state); let (product_verifier, claim_verifier) = - verify_gkr_product::(&mut verifier_state, log_n).unwrap(); + verify_gkr_product::<_, N_GROUPS>(&mut verifier_state, log_n).unwrap(); assert_eq!(&claim_prover, &claim_verifier); - assert_eq!(layer.evaluate(&claim_verifier.point), claim_verifier.value); + let selectors = univariate_selectors::>(log2_strict_usize(N_GROUPS)); + assert_eq!( + evaluate_univariate_multilinear::<_, _, _, true>( + &layer, + &claim_verifier.point, + &selectors, + None + ), + claim_verifier.value + ); assert_eq_many!(product_verifier, product_prover, real_product); } } diff --git a/crates/lookup/src/quotient_gkr.rs b/crates/lookup/src/quotient_gkr.rs index 4e8918d5..862c8645 100644 --- a/crates/lookup/src/quotient_gkr.rs +++ b/crates/lookup/src/quotient_gkr.rs @@ -7,9 +7,6 @@ with custom GKR */ use multilinear_toolkit::prelude::*; -use p3_field::PackedFieldExtension; -use p3_field::PrimeCharacteristicRing; -use p3_field::{ExtensionField, PrimeField64, dot_product}; use tracing::instrument; use utils::{FSProver, FSVerifier}; diff --git a/crates/poseidon_circuit/Cargo.toml b/crates/poseidon_circuit/Cargo.toml index a995a9d9..74625f71 100644 --- a/crates/poseidon_circuit/Cargo.toml +++ b/crates/poseidon_circuit/Cargo.toml @@ -7,7 +7,6 @@ edition.workspace = true workspace = true [dependencies] -p3-field.workspace = true tracing.workspace = true utils.workspace = true # p3-util.workspace = true @@ -17,5 +16,5 @@ p3-poseidon2.workspace = true p3-monty-31.workspace = true rand.workspace = true whir-p3.workspace = true -packed_pcs.workspace = true +sub_protocols.workspace = true diff --git a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs b/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs index f806eed9..6d6403bc 100644 --- a/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs +++ b/crates/poseidon_circuit/src/gkr_layers/batch_partial_rounds.rs @@ -1,7 +1,6 @@ use std::array; use multilinear_toolkit::prelude::*; -use p3_field::ExtensionField; use p3_koala_bear::{ GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, }; diff --git a/crates/poseidon_circuit/src/gkr_layers/compression.rs b/crates/poseidon_circuit/src/gkr_layers/compression.rs index f556bae3..0e8d1bfd 100644 --- a/crates/poseidon_circuit/src/gkr_layers/compression.rs +++ b/crates/poseidon_circuit/src/gkr_layers/compression.rs @@ -1,5 +1,4 @@ use multilinear_toolkit::prelude::*; -use p3_field::ExtensionField; use crate::{EF, F}; diff --git a/crates/poseidon_circuit/src/gkr_layers/full_round.rs b/crates/poseidon_circuit/src/gkr_layers/full_round.rs index e75e8a2f..334adae3 100644 --- a/crates/poseidon_circuit/src/gkr_layers/full_round.rs +++ b/crates/poseidon_circuit/src/gkr_layers/full_round.rs @@ -1,5 +1,4 @@ use multilinear_toolkit::prelude::*; -use p3_field::ExtensionField; use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters}; use p3_monty_31::InternalLayerBaseParameters; diff --git a/crates/poseidon_circuit/src/gkr_layers/partial_round.rs b/crates/poseidon_circuit/src/gkr_layers/partial_round.rs index b5b2ca4c..8f6b51f8 100644 --- a/crates/poseidon_circuit/src/gkr_layers/partial_round.rs +++ b/crates/poseidon_circuit/src/gkr_layers/partial_round.rs @@ -1,5 +1,4 @@ use multilinear_toolkit::prelude::*; -use p3_field::ExtensionField; use crate::{EF, F}; diff --git a/crates/poseidon_circuit/src/lib.rs b/crates/poseidon_circuit/src/lib.rs index 1a864f97..089bfb4a 100644 --- a/crates/poseidon_circuit/src/lib.rs +++ b/crates/poseidon_circuit/src/lib.rs @@ -1,6 +1,6 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] -use multilinear_toolkit::prelude::MultilinearPoint; +use multilinear_toolkit::prelude::MultiEvaluation; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; mod prove; @@ -23,9 +23,10 @@ pub use gkr_layers::*; pub(crate) type F = KoalaBear; pub(crate) type EF = QuinticExtensionFieldKB; +/// remain to be proven #[derive(Debug, Clone)] pub struct GKRPoseidonResult { - pub output_values: Vec, // of length width - pub input_statements: (MultilinearPoint, Vec), // of length width, remain to be proven - pub cubes_statements: (MultilinearPoint, Vec), // of length n_committed_cubes, remain to be proven + pub output_statements: MultiEvaluation, // of length width + pub input_statements: MultiEvaluation, // of length width + pub cubes_statements: MultiEvaluation, // of length n_committed_cubes } diff --git a/crates/poseidon_circuit/src/prove.rs b/crates/poseidon_circuit/src/prove.rs index 1836b283..ccff4dcd 100644 --- a/crates/poseidon_circuit/src/prove.rs +++ b/crates/poseidon_circuit/src/prove.rs @@ -12,13 +12,14 @@ use tracing::{info_span, instrument}; pub fn prove_poseidon_gkr( prover_state: &mut FSProver>, witness: &PoseidonWitness, WIDTH, N_COMMITED_CUBES>, - mut point: Vec, + output_point: Vec, univariate_skips: usize, layers: &PoseidonGKRLayers, ) -> GKRPoseidonResult where KoalaBearInternalLayerParameters: InternalLayerBaseParameters, { + let mut point = output_point.clone(); let selectors = univariate_selectors::(univariate_skips); let (inv_mds_matrix, inv_light_matrix) = build_poseidon_inv_matrices::(); @@ -182,8 +183,9 @@ where ) }; + let output_statements = MultiEvaluation::new(output_point, output_claims); GKRPoseidonResult { - output_values: output_claims, + output_statements, input_statements, cubes_statements, } @@ -306,7 +308,7 @@ fn inner_evals_on_commited_columns( point: &[EF], univariate_skips: usize, columns: &[Vec>], -) -> (MultilinearPoint, Vec) { +) -> MultiEvaluation { let eq_mle = eval_eq_packed(&point[1..]); let inner_evals = columns .par_iter() @@ -333,5 +335,5 @@ fn inner_evals_on_commited_columns( values_to_prove .push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); } - (point_to_prove, values_to_prove) + MultiEvaluation::new(point_to_prove, values_to_prove) } diff --git a/crates/poseidon_circuit/src/tests.rs b/crates/poseidon_circuit/src/tests.rs index ee053ac6..47609e2e 100644 --- a/crates/poseidon_circuit/src/tests.rs +++ b/crates/poseidon_circuit/src/tests.rs @@ -3,12 +3,12 @@ use p3_koala_bear::{ KoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters, QuinticExtensionFieldKB, }; use p3_monty_31::InternalLayerBaseParameters; -use packed_pcs::{ +use rand::{Rng, SeedableRng, rngs::StdRng}; +use std::{array, time::Instant}; +use sub_protocols::{ ColDims, packed_pcs_commit, packed_pcs_global_statements_for_prover, packed_pcs_global_statements_for_verifier, packed_pcs_parse_commitment, }; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use std::{array, time::Instant}; use utils::{ build_prover_state, build_verifier_state, init_tracing, poseidon16_permute_mut, poseidon24_permute_mut, transposed_par_iter_mut, @@ -93,8 +93,7 @@ pub fn run_poseidon_benchmark< proof_size_gkr, output_layer, prover_duration, - output_values_prover, - claim_point, + output_statements_prover, ) = { // ---------------------------------------------------- PROVER ---------------------------------------------------- @@ -140,7 +139,7 @@ pub fn run_poseidon_benchmark< let claim_point = prover_state.sample_vec(log_n_poseidons); let GKRPoseidonResult { - output_values, + output_statements, input_statements, cubes_statements, } = prove_poseidon_gkr( @@ -150,13 +149,14 @@ pub fn run_poseidon_benchmark< UNIVARIATE_SKIPS, &layers, ); + assert_eq!(&output_statements.point.0, &claim_point); // PCS opening let mut pcs_statements = vec![]; - for (point_to_prove, evals_to_prove) in [input_statements, cubes_statements] { - for v in evals_to_prove { + for meval in [input_statements, cubes_statements] { + for v in meval.values { pcs_statements.push(vec![Evaluation { - point: point_to_prove.clone(), + point: meval.point.clone(), value: v, }]); } @@ -190,14 +190,13 @@ pub fn run_poseidon_benchmark< true => witness.compression.unwrap().2, }, prover_duration, - output_values, - claim_point, + output_statements, ) }; let verifier_time = Instant::now(); - let output_values_verifier = { + let output_statements_verifier = { // ---------------------------------------------------- VERIFIER ---------------------------------------------------- let parsed_pcs_commitment = packed_pcs_parse_commitment( @@ -211,7 +210,7 @@ pub fn run_poseidon_benchmark< let output_claim_point = verifier_state.sample_vec(log_n_poseidons); let GKRPoseidonResult { - output_values, + output_statements, input_statements, cubes_statements, } = verify_poseidon_gkr( @@ -222,13 +221,14 @@ pub fn run_poseidon_benchmark< UNIVARIATE_SKIPS, if compress { Some(n_compressions) } else { None }, ); + assert_eq!(&output_statements.point.0, &output_claim_point); // PCS verification let mut pcs_statements = vec![]; - for (point_to_verif, evals_to_verif) in [input_statements, cubes_statements] { - for v in evals_to_verif { + for meval in [input_statements, cubes_statements] { + for v in meval.values { pcs_statements.push(vec![Evaluation { - point: point_to_verif.clone(), + point: meval.point.clone(), value: v, }]); } @@ -250,7 +250,7 @@ pub fn run_poseidon_benchmark< global_statements, ) .unwrap(); - output_values + output_statements }; let verifier_duration = verifier_time.elapsed(); @@ -304,13 +304,13 @@ pub fn run_poseidon_benchmark< assert_eq!(PFPacking::::unpack_slice(layer), data_to_hash[i]); }); } - assert_eq!(output_values_verifier, output_values_prover); + assert_eq!(&output_statements_prover, &output_statements_verifier); assert_eq!( - output_values_verifier.as_slice(), + &output_statements_verifier.values, &output_layer .iter() .map(|layer| PFPacking::::unpack_slice(layer) - .evaluate(&MultilinearPoint(claim_point.clone()))) + .evaluate(&output_statements_verifier.point)) .collect::>() ); diff --git a/crates/poseidon_circuit/src/verify.rs b/crates/poseidon_circuit/src/verify.rs index 0c98808b..6b8521b1 100644 --- a/crates/poseidon_circuit/src/verify.rs +++ b/crates/poseidon_circuit/src/verify.rs @@ -172,8 +172,10 @@ where ) }; + let output_statements = + MultiEvaluation::new(MultilinearPoint(output_claim_point.to_vec()), output_claims); GKRPoseidonResult { - output_values: output_claims, + output_statements, input_statements, cubes_statements, } @@ -221,7 +223,7 @@ fn verify_inner_evals_on_commited_columns( point: &[EF], claimed_evals: &[EF], selectors: &[DensePolynomial], -) -> (MultilinearPoint, Vec) { +) -> MultiEvaluation { let univariate_skips = log2_strict_usize(selectors.len()); let inner_evals_inputs = verifier_state .next_extension_scalars_vec(claimed_evals.len() << univariate_skips) @@ -246,5 +248,5 @@ fn verify_inner_evals_on_commited_columns( values_to_verif .push(col_inner_evals.evaluate(&MultilinearPoint(pcs_batching_scalars_inputs.clone()))); } - (point_to_verif, values_to_verif) + MultiEvaluation::new(point_to_verif, values_to_verif) } diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index de4c1a4d..6c3a2e91 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -8,22 +8,18 @@ workspace = true [dependencies] utils.workspace = true -p3-field.workspace = true xmss.workspace = true rand.workspace = true p3-poseidon2.workspace = true p3-koala-bear.workspace = true p3-challenger.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-symmetric.workspace = true p3-util.workspace = true whir-p3.workspace = true -rayon.workspace = true tracing.workspace = true air.workspace = true -packed_pcs.workspace = true -p3-poseidon2-air.workspace = true +sub_protocols.workspace = true lookup.workspace = true lean_vm.workspace = true serde_json.workspace = true diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 833bc0e2..9d0971a0 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -1,13 +1,10 @@ -use std::time::Instant; - use lean_compiler::*; use lean_prover::whir_config_builder; use lean_prover::{prove_execution::prove_execution, verify_execution::verify_execution}; use lean_vm::*; -use p3_field::Field; -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use rand::{Rng, SeedableRng, rngs::StdRng}; -use rayon::prelude::*; +use std::time::Instant; use tracing::instrument; use xmss::{ PhonyXmssSecretKey, Poseidon16History, Poseidon24History, V, XmssPublicKey, XmssSignature, diff --git a/crates/packed_pcs/Cargo.toml b/crates/sub_protocols/Cargo.toml similarity index 80% rename from crates/packed_pcs/Cargo.toml rename to crates/sub_protocols/Cargo.toml index a4593b10..2f51c4ce 100644 --- a/crates/packed_pcs/Cargo.toml +++ b/crates/sub_protocols/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "packed_pcs" +name = "sub_protocols" version.workspace = true edition.workspace = true @@ -7,14 +7,15 @@ edition.workspace = true workspace = true [dependencies] -p3-field.workspace = true tracing.workspace = true utils.workspace = true whir-p3.workspace = true -rayon.workspace = true +derive_more.workspace = true p3-util.workspace = true multilinear-toolkit.workspace = true +lookup.workspace = true [dev-dependencies] p3-koala-bear.workspace = true rand.workspace = true + diff --git a/crates/sub_protocols/src/commit_extension_from_base.rs b/crates/sub_protocols/src/commit_extension_from_base.rs new file mode 100644 index 00000000..2f83bd1b --- /dev/null +++ b/crates/sub_protocols/src/commit_extension_from_base.rs @@ -0,0 +1,73 @@ +use crate::ColDims; +use multilinear_toolkit::prelude::*; +use utils::dot_product_with_base; +use utils::transpose_slice_to_basis_coefficients; + +/// Commit extension field columns with a PCS allowing to commit in the base field + +#[derive(Debug)] +pub struct ExtensionCommitmentFromBaseProver>> { + pub sub_columns_to_commit: Vec>>, +} + +pub fn committed_dims_extension_from_base>>( + non_zero_height: usize, + default_value: EF, +) -> Vec>> { + EF::as_basis_coefficients_slice(&default_value) + .iter() + .map(|&d| ColDims::padded(non_zero_height, d)) + .collect() +} + +impl>> ExtensionCommitmentFromBaseProver { + pub fn before_commitment(extension_column: &[EF]) -> Self { + let sub_columns_to_commit = + transpose_slice_to_basis_coefficients::, EF>(extension_column); + Self { + sub_columns_to_commit, + } + } + + pub fn after_commitment( + &self, + prover_state: &mut FSProver>, + evaluation_point: &MultilinearPoint, + ) -> Vec>> { + let sub_evals = self + .sub_columns_to_commit + .par_iter() + .map(|slice| slice.evaluate(evaluation_point)) + .collect::>(); + + prover_state.add_extension_scalars(&sub_evals); + + sub_evals + .iter() + .map(|&sub_value| vec![Evaluation::new(evaluation_point.clone(), sub_value)]) + .collect::>() + } +} + +#[derive(Debug)] +pub struct ExtensionCommitmentFromBaseVerifier {} + +impl ExtensionCommitmentFromBaseVerifier { + pub fn after_commitment>>( + verifier_state: &mut FSVerifier>, + claim: &Evaluation, + ) -> ProofResult>>> { + let sub_evals = verifier_state.next_extension_scalars_vec(EF::DIMENSION)?; + + let statements_remaning_to_verify = sub_evals + .iter() + .map(|&sub_value| vec![Evaluation::new(claim.point.clone(), sub_value)]) + .collect::>(); + + if dot_product_with_base(&sub_evals) != claim.value { + return Err(ProofError::InvalidProof); + } + + Ok(statements_remaning_to_verify) + } +} diff --git a/crates/sub_protocols/src/generic_packed_lookup.rs b/crates/sub_protocols/src/generic_packed_lookup.rs new file mode 100644 index 00000000..3b2325ef --- /dev/null +++ b/crates/sub_protocols/src/generic_packed_lookup.rs @@ -0,0 +1,375 @@ +use lookup::compute_pushforward; +use lookup::prove_logup_star; +use lookup::verify_logup_star; +use multilinear_toolkit::prelude::*; +use std::any::TypeId; +use utils::VecOrSlice; +use utils::{FSProver, assert_eq_many}; + +use crate::{ColDims, MultilinearChunks, packed_pcs_global_statements_for_prover}; + +#[derive(Debug)] +pub struct GenericPackedLookupProver<'a, TF: Field, EF: ExtensionField + ExtensionField>> +{ + // inputs + pub(crate) table: VecOrSlice<'a, TF>, + pub(crate) index_columns: Vec<&'a [PF]>, + + // outputs + pub(crate) n_cols_per_group: Vec, + pub(crate) chunks: MultilinearChunks, + pub(crate) packed_lookup_indexes: Vec>, + pub(crate) poly_eq_point: Vec, + pub(crate) pushforward: Vec, // to be committed + pub(crate) batched_value: EF, +} + +#[derive(Debug, PartialEq)] +pub struct PackedLookupStatements { + pub on_table: Evaluation, + pub on_pushforward: Vec>, + pub on_indexes: Vec>>, // contain sparse points (TODO take advantage of it) +} + +impl<'a, TF: Field, EF: ExtensionField + ExtensionField>> + GenericPackedLookupProver<'a, TF, EF> +where + PF: PrimeField64, +{ + pub fn pushforward_to_commit(&self) -> &[EF] { + &self.pushforward + } + + // before committing to the pushforward + #[allow(clippy::too_many_arguments)] + pub fn step_1( + prover_state: &mut FSProver>, + table: VecOrSlice<'a, TF>, // table[0] is assumed to be zero + index_columns: Vec<&'a [PF]>, + heights: Vec, + default_indexes: Vec, + value_columns: Vec>>, // value_columns[i][j] = (index_columns[i] + j)*table (using the notation of https://eprint.iacr.org/2025/946) + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + ) -> Self { + let table_ref = table.as_slice(); + assert!(table_ref[0].is_zero()); + assert!(table_ref.len().is_power_of_two()); + assert_eq_many!( + index_columns.len(), + heights.len(), + default_indexes.len(), + value_columns.len(), + statements.len() + ); + value_columns + .iter() + .zip(&statements) + .for_each(|(cols, evals)| { + assert_eq!(cols.len(), evals[0].num_values()); + }); + let n_groups = value_columns.len(); + let n_cols_per_group = value_columns + .iter() + .map(|cols| cols.len()) + .collect::>(); + + let flatened_value_columns = value_columns + .iter() + .flat_map(|cols| cols.iter().map(|col| col.as_slice())) + .collect::>(); + + let mut all_dims = vec![]; + for (i, (default_index, height)) in default_indexes.iter().zip(heights.iter()).enumerate() { + for col_index in 0..n_cols_per_group[i] { + all_dims.push(ColDims::padded( + *height, + table_ref[col_index + default_index], + )); + } + } + + let (_packed_lookup_values, chunks) = crate::compute_multilinear_chunks_and_apply( + &flatened_value_columns, + &all_dims, + log_smallest_decomposition_chunk, + ); + + let packed_statements = packed_pcs_global_statements_for_prover( + &flatened_value_columns, + &all_dims, + log_smallest_decomposition_chunk, + &expand_multi_evals(&statements), + prover_state, + ); + + let mut missing_shifted_index_cols = vec![vec![]; n_groups]; + for (i, index_col) in index_columns.iter().enumerate() { + for j in 1..n_cols_per_group[i] { + let shifted_col = index_col + .par_iter() + .map(|&x| x + PF::::from_usize(j)) + .collect::>>(); + missing_shifted_index_cols[i].push(shifted_col); + } + } + let mut all_index_cols_ref = vec![]; + for (i, index_col) in index_columns.iter().enumerate() { + all_index_cols_ref.push(*index_col); + for shifted_col in &missing_shifted_index_cols[i] { + all_index_cols_ref.push(shifted_col.as_slice()); + } + } + + let packed_lookup_indexes = chunks.apply(&all_index_cols_ref); + + let batching_scalar = prover_state.sample(); + + let mut poly_eq_point = EF::zero_vec(1 << chunks.packed_n_vars); + for (alpha_power, statement) in batching_scalar.powers().zip(&packed_statements) { + compute_sparse_eval_eq(&statement.point, &mut poly_eq_point, alpha_power); + } + let pushforward = + compute_pushforward(&packed_lookup_indexes, table_ref.len(), &poly_eq_point); + + let batched_value: EF = batching_scalar + .powers() + .zip(&packed_statements) + .map(|(alpha_power, statement)| alpha_power * statement.value) + .sum(); + + Self { + table, + index_columns, + n_cols_per_group, + batched_value, + packed_lookup_indexes, + poly_eq_point, + pushforward, + chunks, + } + } + + // after committing to the pushforward + pub fn step_2( + &self, + prover_state: &mut FSProver>, + non_zero_memory_size: usize, + ) -> PackedLookupStatements { + let table = if TypeId::of::() == TypeId::of::>() { + MleRef::Base(unsafe { std::mem::transmute::<&[TF], &[PF]>(self.table.as_slice()) }) + } else if TypeId::of::() == TypeId::of::() { + MleRef::Extension(unsafe { std::mem::transmute::<&[TF], &[EF]>(self.table.as_slice()) }) + } else { + panic!(); + }; + let logup_star_statements = prove_logup_star( + prover_state, + &table, + &self.packed_lookup_indexes, + self.batched_value, + &self.poly_eq_point, + &self.pushforward, + Some(non_zero_memory_size), + ); + + let mut value_on_packed_indexes = EF::ZERO; + let mut offset = 0; + let mut index_statements_to_prove = vec![]; + for (i, n_cols) in self.n_cols_per_group.iter().enumerate() { + let my_chunks = &self.chunks[offset..offset + n_cols]; + offset += n_cols; + + assert!(my_chunks.iter().all(|col_chunks| { + col_chunks.iter().zip(my_chunks[0].iter()).all(|(c1, c2)| { + c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars + }) + })); + let mut inner_statements = vec![]; + let mut inner_evals = vec![]; + for chunk in &my_chunks[0] { + let sparse_point = MultilinearPoint( + [ + chunk.bits_offset_in_original(), + logup_star_statements.on_indexes.point + [self.chunks.packed_n_vars - chunk.n_vars..] + .to_vec(), + ] + .concat(), + ); + let eval = self.index_columns[i].evaluate_sparse(&sparse_point); + inner_evals.push(eval); + inner_statements.push(Evaluation::new(sparse_point, eval)); + } + prover_state.add_extension_scalars(&inner_evals); + index_statements_to_prove.push(inner_statements); + + for (col_index, chunks_for_col) in my_chunks.iter().enumerate() { + for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { + let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; + value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) + * MultilinearPoint( + logup_star_statements.on_indexes.point[..missing_vars].to_vec(), + ) + .eq_poly_outside(&MultilinearPoint( + chunk.bits_offset_in_packed(self.chunks.packed_n_vars), + )); + } + } + } + // sanity check + assert_eq!( + value_on_packed_indexes, + logup_star_statements.on_indexes.value + ); + + PackedLookupStatements { + on_table: logup_star_statements.on_table, + on_pushforward: logup_star_statements.on_pushforward, + on_indexes: index_statements_to_prove, + } + } +} + +#[derive(Debug)] +pub struct GenericPackedLookupVerifier>> { + n_cols_per_group: Vec, + chunks: MultilinearChunks, + batching_scalar: EF, + packed_statements: Vec>, +} + +impl>> GenericPackedLookupVerifier +where + PF: PrimeField64, +{ + // before receiving the commitment to the pushforward + pub fn step_1>>( + verifier_state: &mut FSVerifier>, + heights: Vec, + default_indexes: Vec, + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + table_initial_values: &[TF], + ) -> ProofResult + where + EF: ExtensionField, + { + let n_cols_per_group = statements + .iter() + .map(|evals| evals[0].num_values()) + .collect::>(); + let mut all_dims = vec![]; + for (i, (default_index, height)) in default_indexes.iter().zip(heights.iter()).enumerate() { + for col_index in 0..n_cols_per_group[i] { + all_dims.push(ColDims::padded( + *height, + table_initial_values[col_index + default_index], + )); + } + } + + let packed_statements = crate::packed_pcs_global_statements_for_verifier( + &all_dims, + log_smallest_decomposition_chunk, + &expand_multi_evals(&statements), + verifier_state, + &Default::default(), + )?; + let chunks = MultilinearChunks::compute(&all_dims, log_smallest_decomposition_chunk); + + let batching_scalar = verifier_state.sample(); + + Ok(Self { + n_cols_per_group, + chunks, + batching_scalar, + packed_statements, + }) + } + + // after receiving the commitment to the pushforward + pub fn step_2( + &self, + verifier_state: &mut FSVerifier>, + log_memory_size: usize, + ) -> ProofResult> { + let logup_star_statements = verify_logup_star( + verifier_state, + log_memory_size, + self.chunks.packed_n_vars, + &self.packed_statements, + self.batching_scalar, + ) + .unwrap(); + + let mut value_on_packed_indexes = EF::ZERO; + let mut offset = 0; + let mut index_statements_to_verify = vec![]; + for n_cols in &self.n_cols_per_group { + let my_chunks = &self.chunks[offset..offset + n_cols]; + offset += n_cols; + + // sanity check + assert!(my_chunks.iter().all(|col_chunks| { + col_chunks.iter().zip(my_chunks[0].iter()).all(|(c1, c2)| { + c1.offset_in_original == c2.offset_in_original && c1.n_vars == c2.n_vars + }) + })); + let mut inner_statements = vec![]; + let inner_evals = verifier_state.next_extension_scalars_vec(my_chunks[0].len())?; + for (chunk, &eval) in my_chunks[0].iter().zip(&inner_evals) { + let sparse_point = MultilinearPoint( + [ + chunk.bits_offset_in_original(), + logup_star_statements.on_indexes.point + [self.chunks.packed_n_vars - chunk.n_vars..] + .to_vec(), + ] + .concat(), + ); + inner_statements.push(Evaluation::new(sparse_point, eval)); + } + index_statements_to_verify.push(inner_statements); + + for (col_index, chunks_for_col) in my_chunks.iter().enumerate() { + for (&inner_eval, chunk) in inner_evals.iter().zip(chunks_for_col) { + let missing_vars = self.chunks.packed_n_vars - chunk.n_vars; + value_on_packed_indexes += (inner_eval + PF::::from_usize(col_index)) + * MultilinearPoint( + logup_star_statements.on_indexes.point[..missing_vars].to_vec(), + ) + .eq_poly_outside(&MultilinearPoint( + chunk.bits_offset_in_packed(self.chunks.packed_n_vars), + )); + } + } + } + if value_on_packed_indexes != logup_star_statements.on_indexes.value { + return Err(ProofError::InvalidProof); + } + + Ok(PackedLookupStatements { + on_table: logup_star_statements.on_table, + on_pushforward: logup_star_statements.on_pushforward, + on_indexes: index_statements_to_verify, + }) + } +} + +fn expand_multi_evals( + statements: &[Vec>], +) -> Vec>> { + statements + .iter() + .flat_map(|multi_evals| { + let mut evals = vec![vec![]; multi_evals[0].num_values()]; + for meval in multi_evals { + for (i, &v) in meval.values.iter().enumerate() { + evals[i].push(Evaluation::new(meval.point.clone(), v)); + } + } + evals + }) + .collect::>() +} diff --git a/crates/sub_protocols/src/lib.rs b/crates/sub_protocols/src/lib.rs new file mode 100644 index 00000000..809f23af --- /dev/null +++ b/crates/sub_protocols/src/lib.rs @@ -0,0 +1,14 @@ +mod generic_packed_lookup; +pub use generic_packed_lookup::*; + +mod packed_pcs; +pub use packed_pcs::*; + +mod commit_extension_from_base; +pub use commit_extension_from_base::*; + +mod normal_packed_lookup; +pub use normal_packed_lookup::*; + +mod vectorized_packed_lookup; +pub use vectorized_packed_lookup::*; diff --git a/crates/sub_protocols/src/normal_packed_lookup.rs b/crates/sub_protocols/src/normal_packed_lookup.rs new file mode 100644 index 00000000..12a986c1 --- /dev/null +++ b/crates/sub_protocols/src/normal_packed_lookup.rs @@ -0,0 +1,194 @@ +use multilinear_toolkit::prelude::*; +use utils::FSProver; +use utils::VecOrSlice; +use utils::dot_product_with_base; +use utils::transpose_slice_to_basis_coefficients; + +use crate::GenericPackedLookupProver; +use crate::GenericPackedLookupVerifier; +use crate::PackedLookupStatements; + +#[derive(Debug)] +pub struct NormalPackedLookupProver<'a, EF: ExtensionField>> { + generic: GenericPackedLookupProver<'a, PF, EF>, +} + +impl<'a, EF: ExtensionField>> NormalPackedLookupProver<'a, EF> +where + PF: PrimeField64, +{ + pub fn pushforward_to_commit(&self) -> &[EF] { + self.generic.pushforward_to_commit() + } + + // before committing to the pushforward + #[allow(clippy::too_many_arguments)] + pub fn step_1( + prover_state: &mut FSProver>, + table: &'a [PF], // table[0] is assumed to be zero + index_columns: Vec<&'a [PF]>, + heights: Vec, + default_indexes: Vec, + value_columns_base: Vec<&'a [PF]>, + value_columns_extension: Vec<&'a [EF]>, + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + ) -> Self { + assert_eq!( + index_columns.len(), + value_columns_base.len() + value_columns_extension.len() + ); + let n_base_cols = value_columns_base.len(); + // let n_extension_cols = value_columns_extension.len(); + + let mut all_value_columns = vec![]; + for col_base in value_columns_base { + all_value_columns.push(vec![VecOrSlice::Slice(col_base)]); + } + for col_extension in &value_columns_extension { + all_value_columns.push( + transpose_slice_to_basis_coefficients::, EF>(col_extension) + .into_iter() + .map(VecOrSlice::Vec) + .collect(), + ); + } + + let mut multi_eval_statements = vec![]; + for eval_group in &statements[..n_base_cols] { + multi_eval_statements.push( + eval_group + .iter() + .map(|e| MultiEvaluation::new(e.point.clone(), vec![e.value])) + .collect(), + ); + } + + for (eval_group, extension_column_split) in statements[n_base_cols..] + .iter() + .zip(&all_value_columns[n_base_cols..]) + { + let mut multi_evals = vec![]; + for eval in eval_group { + let sub_evals = extension_column_split + .par_iter() + .map(|slice| slice.as_slice().evaluate(&eval.point)) + .collect::>(); + // sanity check: + assert_eq!(dot_product_with_base(&sub_evals), eval.value); + + prover_state.add_extension_scalars(&sub_evals); + multi_evals.push(MultiEvaluation::new(eval.point.clone(), sub_evals)); + } + + multi_eval_statements.push(multi_evals); + } + + let generic = GenericPackedLookupProver::step_1( + prover_state, + VecOrSlice::Slice(table), + index_columns, + heights, + default_indexes, + all_value_columns, + multi_eval_statements, + log_smallest_decomposition_chunk, + ); + + Self { generic } + } + + // after committing to the pushforward + pub fn step_2( + &self, + prover_state: &mut FSProver>, + non_zero_memory_size: usize, + ) -> PackedLookupStatements { + self.generic.step_2(prover_state, non_zero_memory_size) + } +} + +#[derive(Debug)] +pub struct NormalPackedLookupVerifier>> { + generic: GenericPackedLookupVerifier, +} + +impl>> NormalPackedLookupVerifier +where + PF: PrimeField64, +{ + // before receiving the commitment to the pushforward + pub fn step_1>>( + verifier_state: &mut FSVerifier>, + n_base_cols: usize, + heights: Vec, + default_indexes: Vec, + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + table_initial_values: &[TF], + ) -> ProofResult + where + EF: ExtensionField, + { + let mut multi_eval_statements = vec![]; + for eval_group in &statements[..n_base_cols] { + multi_eval_statements.push( + eval_group + .iter() + .map(|e| MultiEvaluation::new(e.point.clone(), vec![e.value])) + .collect(), + ); + } + + for eval_group in &statements[n_base_cols..] { + let mut multi_evals = vec![]; + for eval in eval_group { + let sub_evals = verifier_state + .next_extension_scalars_vec(>>::DIMENSION)?; + if dot_product_with_base(&sub_evals) != eval.value { + return Err(ProofError::InvalidProof); + } + multi_evals.push(MultiEvaluation::new(eval.point.clone(), sub_evals)); + } + + multi_eval_statements.push(multi_evals); + } + + let generic = GenericPackedLookupVerifier::step_1( + verifier_state, + heights, + default_indexes, + multi_eval_statements, + log_smallest_decomposition_chunk, + table_initial_values, + )?; + + Ok(Self { generic }) + } + + // after receiving the commitment to the pushforward + pub fn step_2( + &self, + verifier_state: &mut FSVerifier>, + log_memory_size: usize, + ) -> ProofResult> { + self.generic.step_2(verifier_state, log_memory_size) + } +} + +fn expand_multi_evals( + statements: &[Vec>], +) -> Vec>> { + statements + .iter() + .flat_map(|multi_evals| { + let mut evals = vec![vec![]; multi_evals[0].num_values()]; + for meval in multi_evals { + for (i, &v) in meval.values.iter().enumerate() { + evals[i].push(Evaluation::new(meval.point.clone(), v)); + } + } + evals + }) + .collect::>() +} diff --git a/crates/packed_pcs/src/lib.rs b/crates/sub_protocols/src/packed_pcs.rs similarity index 80% rename from crates/packed_pcs/src/lib.rs rename to crates/sub_protocols/src/packed_pcs.rs index 398ece97..f6968e8c 100644 --- a/crates/packed_pcs/src/lib.rs +++ b/crates/sub_protocols/src/packed_pcs.rs @@ -1,7 +1,6 @@ use std::{any::TypeId, cmp::Reverse, collections::BTreeMap}; use multilinear_toolkit::prelude::*; -use p3_field::{ExtensionField, Field, TwoAdicField}; use p3_util::{log2_ceil_usize, log2_strict_usize}; use tracing::instrument; use utils::{ @@ -11,40 +10,34 @@ use utils::{ use whir_p3::*; #[derive(Debug, Clone)] -struct Chunk { - original_poly_index: usize, - original_n_vars: usize, - n_vars: usize, - offset_in_original: usize, - public_data: bool, - offset_in_packed: Option, +pub struct Chunk { + pub original_poly_index: usize, + pub original_n_vars: usize, + pub n_vars: usize, + pub offset_in_original: usize, + pub public_data: bool, + pub offset_in_packed: Option, } impl Chunk { - fn bits_offset_in_original(&self) -> Vec { + pub fn bits_offset_in_original(&self) -> Vec { to_big_endian_in_field( self.offset_in_original >> self.n_vars, self.original_n_vars - self.n_vars, ) } -} - -impl Chunk { + pub fn bits_offset_in_packed(&self, packed_n_vars: usize) -> Vec { + to_big_endian_in_field( + self.offset_in_packed.unwrap() >> self.n_vars, + packed_n_vars - self.n_vars, + ) + } fn global_point_for_statement( &self, point: &[F], packed_n_vars: usize, ) -> MultilinearPoint { - MultilinearPoint( - [ - to_big_endian_in_field( - self.offset_in_packed.unwrap() >> self.n_vars, - packed_n_vars - self.n_vars, - ), - point.to_vec(), - ] - .concat(), - ) + MultilinearPoint([self.bits_offset_in_packed(packed_n_vars), point.to_vec()].concat()) } } @@ -140,44 +133,11 @@ fn split_in_chunks( } } -fn compute_chunks( - dims: &[ColDims], - log_smallest_decomposition_chunk: usize, -) -> (BTreeMap>, usize) { - let mut all_chunks = Vec::new(); - for (i, dim) in dims.iter().enumerate() { - all_chunks.extend(split_in_chunks(i, dim, log_smallest_decomposition_chunk)); - } - all_chunks.sort_by_key(|c| (Reverse(c.public_data), Reverse(c.n_vars))); - - let mut offset_in_packed = 0; - let mut chunks_decomposition: BTreeMap<_, Vec<_>> = BTreeMap::new(); - for chunk in &mut all_chunks { - if !chunk.public_data { - chunk.offset_in_packed = Some(offset_in_packed); - offset_in_packed += 1 << chunk.n_vars; - } - chunks_decomposition - .entry(chunk.original_poly_index) - .or_default() - .push(chunk.clone()); - } - let packed_n_vars = log2_ceil_usize( - all_chunks - .iter() - .filter(|c| !c.public_data) - .map(|c| 1 << c.n_vars) - .sum::(), - ); - (chunks_decomposition, packed_n_vars) -} - pub fn num_packed_vars_for_dims( dims: &[ColDims], log_smallest_decomposition_chunk: usize, ) -> usize { - let (_, packed_n_vars) = compute_chunks::(dims, log_smallest_decomposition_chunk); - packed_n_vars + MultilinearChunks::compute(dims, log_smallest_decomposition_chunk).packed_n_vars } #[derive(Debug)] @@ -186,18 +146,86 @@ pub struct MultiCommitmentWitness>> { pub packed_polynomial: MleOwned, } +#[derive(Debug, derive_more::Deref)] +pub struct MultilinearChunks { + #[deref] + pub chunks_decomposition: Vec>, + pub packed_n_vars: usize, +} + +impl MultilinearChunks { + pub fn compute(dims: &[ColDims], log_smallest_decomposition_chunk: usize) -> Self { + let mut all_chunks = Vec::new(); + for (i, dim) in dims.iter().enumerate() { + all_chunks.extend(split_in_chunks(i, dim, log_smallest_decomposition_chunk)); + } + all_chunks.sort_by_key(|c| (Reverse(c.public_data), Reverse(c.n_vars))); + + let mut offset_in_packed = 0; + let mut chunks_decomposition: BTreeMap<_, Vec<_>> = BTreeMap::new(); + for chunk in &mut all_chunks { + if !chunk.public_data { + chunk.offset_in_packed = Some(offset_in_packed); + offset_in_packed += 1 << chunk.n_vars; + } + chunks_decomposition + .entry(chunk.original_poly_index) + .or_default() + .push(chunk.clone()); + } + let packed_n_vars = log2_ceil_usize( + all_chunks + .iter() + .filter(|c| !c.public_data) + .map(|c| 1 << c.n_vars) + .sum::(), + ); + let chunks_decomposition = chunks_decomposition.values().cloned().collect::>(); + assert_eq!(chunks_decomposition.len(), dims.len()); + Self { + chunks_decomposition, + packed_n_vars, + } + } + + pub fn apply(&self, polynomials: &[&[F]]) -> Vec + where + F: Field, + { + let packed_polynomial = F::zero_vec(1 << self.packed_n_vars); // TODO avoid this huge cloning of all witness data + self.iter() + .flatten() + .collect::>() + .par_iter() + .filter(|chunk| !chunk.public_data) + .for_each(|chunk| { + let start = chunk.offset_in_packed.unwrap(); + let end = start + (1 << chunk.n_vars); + let original_poly = &polynomials[chunk.original_poly_index]; + unsafe { + let slice = std::slice::from_raw_parts_mut( + (packed_polynomial.as_ptr() as *mut F).add(start), + end - start, + ); + slice.copy_from_slice( + &original_poly[chunk.offset_in_original + ..chunk.offset_in_original + (1 << chunk.n_vars)], + ); + } + }); + + packed_polynomial + } +} + #[instrument(skip_all)] -pub fn packed_pcs_commit( - whir_config_builder: &WhirConfigBuilder, +pub fn compute_multilinear_chunks_and_apply( polynomials: &[&[F]], dims: &[ColDims], - prover_state: &mut FSProver>, log_smallest_decomposition_chunk: usize, -) -> MultiCommitmentWitness +) -> (Vec, MultilinearChunks) where - F: Field + TwoAdicField + ExtensionField>, - PF: TwoAdicField, - EF: ExtensionField + TwoAdicField + ExtensionField>, + F: Field, { assert_eq!(polynomials.len(), dims.len()); for (i, (poly, dim)) in polynomials.iter().zip(dims.iter()).enumerate() { @@ -210,14 +238,13 @@ where dim.n_vars ); } - let (chunks_decomposition, packed_n_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + let chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); { // logging let total_commited_data: usize = dims.iter().map(|d| d.committed_size).sum(); - let packed_commited_data: usize = chunks_decomposition - .values() + let packed_commited_data: usize = chunks + .iter() .flatten() .filter(|c| !c.public_data) .map(|c| 1 << c.n_vars) @@ -228,32 +255,34 @@ where total_commited_data, (total_commited_data as f64).log2(), (packed_commited_data as f64).log2(), - packed_n_vars + chunks.packed_n_vars ); } - let packed_polynomial = F::zero_vec(1 << packed_n_vars); // TODO avoid this huge cloning of all witness data - chunks_decomposition - .values() - .flatten() - .collect::>() - .par_iter() - .filter(|chunk| !chunk.public_data) - .for_each(|chunk| { - let start = chunk.offset_in_packed.unwrap(); - let end = start + (1 << chunk.n_vars); - let original_poly = &polynomials[chunk.original_poly_index]; - unsafe { - let slice = std::slice::from_raw_parts_mut( - (packed_polynomial.as_ptr() as *mut F).add(start), - end - start, - ); - slice.copy_from_slice( - &original_poly - [chunk.offset_in_original..chunk.offset_in_original + (1 << chunk.n_vars)], - ); - } - }); + let packed_polynomial = chunks.apply(polynomials); + + (packed_polynomial, chunks) +} + +#[instrument(skip_all)] +pub fn packed_pcs_commit( + whir_config_builder: &WhirConfigBuilder, + polynomials: &[&[F]], + dims: &[ColDims], + prover_state: &mut FSProver>, + log_smallest_decomposition_chunk: usize, +) -> MultiCommitmentWitness +where + F: Field + TwoAdicField + ExtensionField>, + PF: TwoAdicField, + EF: ExtensionField + TwoAdicField + ExtensionField>, +{ + let (packed_polynomial, _chunks_decomposition) = compute_multilinear_chunks_and_apply::( + polynomials, + dims, + log_smallest_decomposition_chunk, + ); + let packed_n_vars = log2_strict_usize(packed_polynomial.len()); let mle = if TypeId::of::() == TypeId::of::>() { MleOwned::Base(unsafe { std::mem::transmute::, Vec>>(packed_polynomial) }) @@ -291,8 +320,7 @@ pub fn packed_pcs_global_statements_for_prover< // - cache the "eq" poly, and then use dot product // - current packing is not optimal in the end: can lead to [16][4][2][2] (instead of [16][8]) - let (chunks_decomposition, packed_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); let statements_flattened = statements_per_polynomial .iter() @@ -310,9 +338,7 @@ pub fn packed_pcs_global_statements_for_prover< let dim = &dims[*poly_index]; let pol = polynomials[*poly_index]; - let chunks = chunks_decomposition - .get(poly_index) - .expect("missing chunk definition for polynomial"); + let chunks = &all_chunks[*poly_index]; assert!(!chunks.is_empty()); let mut sub_packed_statements = Vec::new(); let mut evals_to_send = Vec::new(); @@ -323,10 +349,16 @@ pub fn packed_pcs_global_statements_for_prover< statement.point.0.len(), "poly: {poly_index}" ); - assert!(chunks[0].offset_in_packed.unwrap() % (1 << chunks[0].n_vars) == 0); + assert!( + chunks[0] + .offset_in_packed + .unwrap() + .is_multiple_of(1 << chunks[0].n_vars) + ); sub_packed_statements.push(Evaluation::new( - chunks[0].global_point_for_statement(&statement.point, packed_vars), + chunks[0] + .global_point_for_statement(&statement.point, all_chunks.packed_n_vars), statement.value, )); } else { @@ -372,10 +404,13 @@ pub fn packed_pcs_global_statements_for_prover< MultilinearPoint(statement.point.0[missing_vars..].to_vec()); let sub_value = (&pol[chunk.offset_in_original ..chunk.offset_in_original + (1 << chunk.n_vars)]) - .evaluate(&sub_point); + .evaluate_sparse(&sub_point); // `evaluate_sparse` because sometime (typically due to packed lookup protocol, the original statement is already sparse) ( Some(Evaluation::new( - chunk.global_point_for_statement(&sub_point, packed_vars), + chunk.global_point_for_statement( + &sub_point, + all_chunks.packed_n_vars, + ), sub_value, )), sub_value, @@ -417,8 +452,8 @@ pub fn packed_pcs_global_statements_for_prover< let initial_sub_point = MultilinearPoint(statement.point.0[initial_missing_vars..].to_vec()); - let initial_packed_point = - chunks[0].global_point_for_statement(&initial_sub_point, packed_vars); + let initial_packed_point = chunks[0] + .global_point_for_statement(&initial_sub_point, all_chunks.packed_n_vars); sub_packed_statements .insert(0, Evaluation::new(initial_packed_point, initial_sub_value)); evals_to_send.insert(0, initial_sub_value); @@ -448,8 +483,9 @@ pub fn packed_pcs_parse_commitment< where PF: TwoAdicField, { - let (_, packed_n_vars) = compute_chunks::(dims, log_smallest_decomposition_chunk); - WhirConfig::new(whir_config_builder.clone(), packed_n_vars).parse_commitment(verifier_state) + let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); + WhirConfig::new(whir_config_builder.clone(), all_chunks.packed_n_vars) + .parse_commitment(verifier_state) } pub fn packed_pcs_global_statements_for_verifier< @@ -463,23 +499,26 @@ pub fn packed_pcs_global_statements_for_verifier< public_data: &BTreeMap>, // poly_index -> public data slice (power of 2) ) -> Result>, ProofError> { assert_eq!(dims.len(), statements_per_polynomial.len()); - let (chunks_decomposition, packed_n_vars) = - compute_chunks::(dims, log_smallest_decomposition_chunk); + let all_chunks = MultilinearChunks::compute(dims, log_smallest_decomposition_chunk); let mut packed_statements = Vec::new(); for (poly_index, statements) in statements_per_polynomial.iter().enumerate() { let dim = &dims[poly_index]; let has_public_data = dim.log_public_data_size.is_some(); - let chunks = chunks_decomposition - .get(&poly_index) - .expect("missing chunk definition for polynomial"); + let chunks = &all_chunks[poly_index]; assert!(!chunks.is_empty()); for statement in statements { if chunks.len() == 1 { assert!(!chunks[0].public_data, "TODO"); assert_eq!(chunks[0].n_vars, statement.point.0.len()); - assert!(chunks[0].offset_in_packed.unwrap() % (1 << chunks[0].n_vars) == 0); + assert!( + chunks[0] + .offset_in_packed + .unwrap() + .is_multiple_of(1 << chunks[0].n_vars) + ); packed_statements.push(Evaluation::new( - chunks[0].global_point_for_statement(&statement.point, packed_n_vars), + chunks[0] + .global_point_for_statement(&statement.point, all_chunks.packed_n_vars), statement.value, )); } else { @@ -517,7 +556,7 @@ pub fn packed_pcs_global_statements_for_verifier< let sub_point = MultilinearPoint(statement.point.0[missing_vars..].to_vec()); packed_statements.push(Evaluation::new( - chunk.global_point_for_statement(&sub_point, packed_n_vars), + chunk.global_point_for_statement(&sub_point, all_chunks.packed_n_vars), sub_value, )); } @@ -564,7 +603,6 @@ fn compute_multilinear_value_from_chunks>( #[cfg(test)] mod tests { - use p3_field::PrimeCharacteristicRing; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use p3_util::log2_strict_usize; use rand::{Rng, SeedableRng, rngs::StdRng}; diff --git a/crates/sub_protocols/src/vectorized_packed_lookup.rs b/crates/sub_protocols/src/vectorized_packed_lookup.rs new file mode 100644 index 00000000..bfc43a1e --- /dev/null +++ b/crates/sub_protocols/src/vectorized_packed_lookup.rs @@ -0,0 +1,175 @@ +use multilinear_toolkit::prelude::*; +use utils::FSProver; +use utils::VecOrSlice; +use utils::fold_multilinear_chunks; + +use crate::GenericPackedLookupProver; +use crate::GenericPackedLookupVerifier; +use crate::PackedLookupStatements; + +#[derive(Debug)] +pub struct VectorizedPackedLookupProver<'a, EF: ExtensionField>, const VECTOR_LEN: usize> { + generic: GenericPackedLookupProver<'a, EF, EF>, + folding_scalars: MultilinearPoint, +} + +impl<'a, EF: ExtensionField>, const VECTOR_LEN: usize> + VectorizedPackedLookupProver<'a, EF, VECTOR_LEN> +where + PF: PrimeField64, +{ + pub fn pushforward_to_commit(&self) -> &[EF] { + self.generic.pushforward_to_commit() + } + + // before committing to the pushforward + #[allow(clippy::too_many_arguments)] + pub fn step_1( + prover_state: &mut FSProver>, + table: &'a [PF], // table[0] is assumed to be zero + index_columns: Vec<&'a [PF]>, + heights: Vec, + default_indexes: Vec, + value_columns: Vec<[&'a [PF]; VECTOR_LEN]>, + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + ) -> Self { + let folding_scalars = + MultilinearPoint(prover_state.sample_vec(log2_strict_usize(VECTOR_LEN))); + let folded_table = fold_multilinear_chunks(table, &folding_scalars); + + let folding_poly_eq = eval_eq(&folding_scalars); + let folded_value_columns = value_columns + .par_iter() + .map(|cols| { + let n = cols[0].len(); + assert!(cols.iter().all(|c| c.len() == n)); + assert!(n.is_power_of_two()); + vec![VecOrSlice::Vec( + (0..n) + .into_par_iter() + .map(|i| { + folding_poly_eq + .iter() + .enumerate() + .map(|(j, &coeff)| coeff * cols[j][i]) + .sum::() + }) + .collect::>(), + )] + }) + .collect::>(); + + let generic = GenericPackedLookupProver::<'_, EF, EF>::step_1( + prover_state, + VecOrSlice::Vec(folded_table), + index_columns, + heights, + default_indexes, + folded_value_columns, + get_folded_statements(statements, &folding_scalars), + log_smallest_decomposition_chunk, + ); + + Self { + generic, + folding_scalars, + } + } + + // after committing to the pushforward + pub fn step_2( + &self, + prover_state: &mut FSProver>, + non_zero_memory_size: usize, + ) -> PackedLookupStatements { + let mut statements = self + .generic + .step_2(prover_state, non_zero_memory_size.div_ceil(VECTOR_LEN)); + statements + .on_table + .point + .extend(self.folding_scalars.0.clone()); + statements + } +} + +#[derive(Debug)] +pub struct VectorizedPackedLookupVerifier>, const VECTOR_LEN: usize> { + generic: GenericPackedLookupVerifier, + folding_scalars: MultilinearPoint, +} + +impl>, const VECTOR_LEN: usize> + VectorizedPackedLookupVerifier +where + PF: PrimeField64, +{ + // before receiving the commitment to the pushforward + pub fn step_1( + verifier_state: &mut FSVerifier>, + heights: Vec, + default_indexes: Vec, + statements: Vec>>, + log_smallest_decomposition_chunk: usize, + table_initial_values: &[PF], + ) -> ProofResult { + let folding_scalars = + MultilinearPoint(verifier_state.sample_vec(log2_strict_usize(VECTOR_LEN))); + let folded_table_initial_values = fold_multilinear_chunks( + &table_initial_values[..(table_initial_values.len() / VECTOR_LEN) * VECTOR_LEN], + &folding_scalars, + ); + + let generic = GenericPackedLookupVerifier::step_1::( + verifier_state, + heights, + default_indexes, + get_folded_statements(statements, &folding_scalars), + log_smallest_decomposition_chunk, + &folded_table_initial_values, + )?; + + Ok(Self { + generic, + folding_scalars, + }) + } + + // after receiving the commitment to the pushforward + pub fn step_2( + &self, + verifier_state: &mut FSVerifier>, + log_memory_size: usize, + ) -> ProofResult> { + let mut statements = self.generic.step_2( + verifier_state, + log_memory_size - log2_strict_usize(VECTOR_LEN), + )?; + statements + .on_table + .point + .extend(self.folding_scalars.0.clone()); + Ok(statements) + } +} + +fn get_folded_statements( + statements: Vec>>, + folding_scalars: &MultilinearPoint, +) -> Vec>> { + statements + .iter() + .map(|sub_statements| { + sub_statements + .iter() + .map(|meval| { + MultiEvaluation::new( + meval.point.clone(), + vec![meval.values.evaluate(folding_scalars)], + ) + }) + .collect::>() + }) + .collect::>() +} diff --git a/crates/sub_protocols/tests/test_generic_packed_lookup.rs b/crates/sub_protocols/tests/test_generic_packed_lookup.rs new file mode 100644 index 00000000..a0726538 --- /dev/null +++ b/crates/sub_protocols/tests/test_generic_packed_lookup.rs @@ -0,0 +1,126 @@ +use multilinear_toolkit::prelude::*; +use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; +use p3_util::log2_ceil_usize; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use sub_protocols::{GenericPackedLookupProver, GenericPackedLookupVerifier}; +use utils::{ToUsize, VecOrSlice, assert_eq_many, build_prover_state, build_verifier_state}; + +type F = KoalaBear; +type EF = QuinticExtensionFieldKB; +const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; + +#[test] +fn test_generic_packed_lookup() { + let non_zero_memory_size: usize = 37412; + let lookups_height_and_cols: Vec<(usize, usize)> = + vec![(4587, 1), (1234, 3), (9411, 1), (7890, 2)]; + let default_indexes = vec![7, 11, 0, 2]; + let n_statements = [1, 5, 2, 1]; + assert_eq_many!( + lookups_height_and_cols.len(), + default_indexes.len(), + n_statements.len() + ); + + let mut rng = StdRng::seed_from_u64(0); + let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); + for mem in memory.iter_mut().take(non_zero_memory_size).skip(1) { + *mem = rng.random(); + } + + let mut all_indexe_columns = vec![]; + let mut all_value_columns = vec![]; + let mut all_statements = vec![]; + for (i, (n_lines, n_cols)) in lookups_height_and_cols.iter().enumerate() { + let mut indexes = vec![F::from_usize(default_indexes[i]); n_lines.next_power_of_two()]; + for idx in indexes.iter_mut().take(*n_lines) { + *idx = F::from_usize(rng.random_range(0..non_zero_memory_size)); + } + all_indexe_columns.push(indexes); + let indexes = all_indexe_columns.last().unwrap(); + + let mut columns = vec![]; + for col_index in 0..*n_cols { + let mut col = F::zero_vec(n_lines.next_power_of_two()); + for i in 0..n_lines.next_power_of_two() { + col[i] = memory[indexes[i].to_usize() + col_index]; + } + columns.push(col); + } + let mut statements = vec![]; + for _ in 0..n_statements[i] { + let point = MultilinearPoint::::random(&mut rng, log2_ceil_usize(*n_lines)); + let values = columns + .iter() + .map(|col| col.evaluate(&point)) + .collect::>(); + statements.push(MultiEvaluation::new(point, values)); + } + all_statements.push(statements); + all_value_columns.push(columns); + } + + let mut prover_state = build_prover_state(); + + let packed_lookup_prover = GenericPackedLookupProver::step_1( + &mut prover_state, + VecOrSlice::Slice(&memory), + all_indexe_columns.iter().map(Vec::as_slice).collect(), + lookups_height_and_cols.iter().map(|(h, _)| *h).collect(), + default_indexes.clone(), + all_value_columns + .iter() + .map(|cols| cols.iter().map(|s| VecOrSlice::Slice(s)).collect()) + .collect(), + all_statements.clone(), + LOG_SMALLEST_DECOMPOSITION_CHUNK, + ); + + // phony commitment to pushforward + prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); + + let remaining_claims_to_prove = + packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + + let mut verifier_state = build_verifier_state(&prover_state); + + let packed_lookup_verifier = GenericPackedLookupVerifier::step_1( + &mut verifier_state, + lookups_height_and_cols.iter().map(|(h, _)| *h).collect(), + default_indexes, + all_statements, + LOG_SMALLEST_DECOMPOSITION_CHUNK, + &memory[..100], + ) + .unwrap(); + + // receive commitment to pushforward + let pushforward = verifier_state + .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two()) + .unwrap(); + + let remaining_claims_to_verify = packed_lookup_verifier + .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) + .unwrap(); + + assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); + + assert_eq!( + memory.evaluate(&remaining_claims_to_verify.on_table.point), + remaining_claims_to_verify.on_table.value + ); + for pusforward_statement in &remaining_claims_to_verify.on_pushforward { + assert_eq!( + pushforward.evaluate(&pusforward_statement.point), + pusforward_statement.value + ); + } + for (index_col, index_statements) in all_indexe_columns + .iter() + .zip(remaining_claims_to_verify.on_indexes.iter()) + { + for statement in index_statements { + assert_eq!(index_col.evaluate(&statement.point), statement.value); + } + } +} diff --git a/crates/sub_protocols/tests/test_normal_packed_lookup.rs b/crates/sub_protocols/tests/test_normal_packed_lookup.rs new file mode 100644 index 00000000..67fe952d --- /dev/null +++ b/crates/sub_protocols/tests/test_normal_packed_lookup.rs @@ -0,0 +1,158 @@ +use multilinear_toolkit::prelude::*; +use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; +use p3_util::log2_ceil_usize; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use sub_protocols::{NormalPackedLookupProver, NormalPackedLookupVerifier}; +use utils::{ToUsize, assert_eq_many, build_prover_state, build_verifier_state}; + +type F = KoalaBear; +type EF = QuinticExtensionFieldKB; +const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; + +#[test] +fn test_normal_packed_lookup() { + let non_zero_memory_size: usize = 37412; + let base_cols_heights: Vec = vec![785, 1022, 4751]; + let extension_cols_heights: Vec = vec![2088, 110]; + let default_indexes = vec![7, 11, 0, 2, 3]; + let n_statements = vec![1, 5, 2, 1, 8]; + assert_eq_many!( + base_cols_heights.len() + extension_cols_heights.len(), + default_indexes.len(), + n_statements.len() + ); + + let mut rng = StdRng::seed_from_u64(0); + let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); + for mem in memory.iter_mut().take(non_zero_memory_size).skip(1) { + *mem = rng.random(); + } + + let mut all_indexe_columns = vec![]; + for (i, height) in base_cols_heights.iter().enumerate() { + let mut indexes = vec![F::from_usize(default_indexes[i]); height.next_power_of_two()]; + for idx in indexes.iter_mut().take(*height) { + *idx = F::from_usize(rng.random_range(0..non_zero_memory_size)); + } + all_indexe_columns.push(indexes); + } + for (i, height) in extension_cols_heights.iter().enumerate() { + let mut indexes = vec![ + F::from_usize(default_indexes[i + base_cols_heights.len()]); + height.next_power_of_two() + ]; + for idx in indexes.iter_mut().take(*height) { + *idx = F::from_usize(rng.random_range( + 0..non_zero_memory_size - >>::DIMENSION, + )); + } + all_indexe_columns.push(indexes); + } + + let mut base_value_columns = vec![]; + for base_col in &all_indexe_columns[..base_cols_heights.len()] { + let mut values = vec![]; + for index in base_col { + values.push(memory[index.to_usize()]); + } + base_value_columns.push(values); + } + let mut extension_value_columns = vec![]; + for ext_col in &all_indexe_columns[base_cols_heights.len()..] { + let mut values = vec![]; + for index in ext_col { + values.push(QuinticExtensionFieldKB::from_basis_coefficients_fn(|i| { + memory[index.to_usize() + i] + })); + } + extension_value_columns.push(values); + } + + let mut all_statements = vec![]; + for (value_col_base, n_statements) in base_value_columns.iter().zip(&n_statements) { + let mut statements = vec![]; + for _ in 0..*n_statements { + let point = + MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_base.len())); + let value = value_col_base.evaluate(&point); + statements.push(Evaluation::new(point, value)); + } + all_statements.push(statements); + } + for (value_col_ext, n_statements) in extension_value_columns + .iter() + .zip(&n_statements[base_cols_heights.len()..]) + { + let mut statements = vec![]; + for _ in 0..*n_statements { + let point = + MultilinearPoint::::random(&mut rng, log2_strict_usize(value_col_ext.len())); + let value = value_col_ext.evaluate(&point); + statements.push(Evaluation::new(point, value)); + } + all_statements.push(statements); + } + + let mut prover_state = build_prover_state(); + + let packed_lookup_prover = NormalPackedLookupProver::step_1( + &mut prover_state, + &memory, + all_indexe_columns.iter().map(Vec::as_slice).collect(), + [base_cols_heights.clone(), extension_cols_heights.clone()].concat(), + default_indexes.clone(), + base_value_columns.iter().map(Vec::as_slice).collect(), + extension_value_columns.iter().map(Vec::as_slice).collect(), + all_statements.clone(), + LOG_SMALLEST_DECOMPOSITION_CHUNK, + ); + + // phony commitment to pushforward + prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); + + let remaining_claims_to_prove = + packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + + let mut verifier_state = build_verifier_state(&prover_state); + + let packed_lookup_verifier = NormalPackedLookupVerifier::step_1( + &mut verifier_state, + base_cols_heights.len(), + [base_cols_heights, extension_cols_heights].concat(), + default_indexes, + all_statements, + LOG_SMALLEST_DECOMPOSITION_CHUNK, + &memory[..100], + ) + .unwrap(); + + // receive commitment to pushforward + let pushforward = verifier_state + .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two()) + .unwrap(); + + let remaining_claims_to_verify = packed_lookup_verifier + .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) + .unwrap(); + + assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); + + assert_eq!( + memory.evaluate(&remaining_claims_to_verify.on_table.point), + remaining_claims_to_verify.on_table.value + ); + for pusforward_statement in &remaining_claims_to_verify.on_pushforward { + assert_eq!( + pushforward.evaluate(&pusforward_statement.point), + pusforward_statement.value + ); + } + for (index_col, index_statements) in all_indexe_columns + .iter() + .zip(remaining_claims_to_verify.on_indexes.iter()) + { + for statement in index_statements { + assert_eq!(index_col.evaluate(&statement.point), statement.value); + } + } +} diff --git a/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs new file mode 100644 index 00000000..9ac6a648 --- /dev/null +++ b/crates/sub_protocols/tests/test_vectorized_packed_lookup.rs @@ -0,0 +1,136 @@ +use std::array; + +use multilinear_toolkit::prelude::*; +use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; +use p3_util::log2_ceil_usize; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use sub_protocols::{VectorizedPackedLookupProver, VectorizedPackedLookupVerifier}; +use utils::{ToUsize, assert_eq_many, build_prover_state, build_verifier_state}; + +type F = KoalaBear; +type EF = QuinticExtensionFieldKB; +const LOG_SMALLEST_DECOMPOSITION_CHUNK: usize = 5; + +const VECTOR_LEN: usize = 8; + +#[test] +fn test_vectorized_packed_lookup() { + let non_zero_memory_size: usize = 37412; + let cols_heights: Vec = vec![785, 1022, 4751]; + let default_indexes = vec![7, 11, 0]; + let n_statements = vec![1, 5, 2]; + assert_eq_many!( + cols_heights.len(), + default_indexes.len(), + n_statements.len() + ); + + let mut rng = StdRng::seed_from_u64(0); + let mut memory = F::zero_vec(non_zero_memory_size.next_power_of_two()); + for mem in memory + .iter_mut() + .take(non_zero_memory_size) + .skip(VECTOR_LEN) + { + *mem = rng.random(); + } + + let mut all_indexe_columns = vec![]; + for (i, height) in cols_heights.iter().enumerate() { + let mut indexes = vec![F::from_usize(default_indexes[i]); height.next_power_of_two()]; + for idx in indexes.iter_mut().take(*height) { + *idx = F::from_usize(rng.random_range(0..non_zero_memory_size / VECTOR_LEN)); + } + all_indexe_columns.push(indexes); + } + + let mut all_value_columns = vec![]; + for index_col in &all_indexe_columns { + let mut values: [Vec; VECTOR_LEN] = Default::default(); + for index in index_col { + for i in 0..VECTOR_LEN { + values[i].push(memory[index.to_usize() * VECTOR_LEN + i]); + } + } + all_value_columns.push(values); + } + + let mut all_statements = vec![]; + for (value_cols, n_statements) in all_value_columns.iter().zip(&n_statements) { + let mut statements = vec![]; + for _ in 0..*n_statements { + let point = + MultilinearPoint::::random(&mut rng, log2_strict_usize(value_cols[0].len())); + let values = value_cols + .iter() + .map(|col| col.evaluate(&point)) + .collect::>(); + statements.push(MultiEvaluation::new(point, values)); + } + all_statements.push(statements); + } + + let mut prover_state = build_prover_state(); + + let packed_lookup_prover = VectorizedPackedLookupProver::step_1( + &mut prover_state, + &memory, + all_indexe_columns.iter().map(Vec::as_slice).collect(), + cols_heights.clone(), + default_indexes.clone(), + all_value_columns + .iter() + .map(|v| array::from_fn::<_, VECTOR_LEN, _>(|i| v[i].as_slice())) + .collect(), + all_statements.clone(), + LOG_SMALLEST_DECOMPOSITION_CHUNK, + ); + + // phony commitment to pushforward + prover_state.hint_extension_scalars(packed_lookup_prover.pushforward_to_commit()); + + let remaining_claims_to_prove = + packed_lookup_prover.step_2(&mut prover_state, non_zero_memory_size); + + let mut verifier_state = build_verifier_state(&prover_state); + + let packed_lookup_verifier = VectorizedPackedLookupVerifier::<_, VECTOR_LEN>::step_1( + &mut verifier_state, + cols_heights, + default_indexes, + all_statements, + LOG_SMALLEST_DECOMPOSITION_CHUNK, + &memory[..100], + ) + .unwrap(); + + // receive commitment to pushforward + let pushforward = verifier_state + .receive_hint_extension_scalars(non_zero_memory_size.next_power_of_two() / VECTOR_LEN) + .unwrap(); + + let remaining_claims_to_verify = packed_lookup_verifier + .step_2(&mut verifier_state, log2_ceil_usize(non_zero_memory_size)) + .unwrap(); + + assert_eq!(&remaining_claims_to_prove, &remaining_claims_to_verify); + + assert_eq!( + memory.evaluate(&remaining_claims_to_verify.on_table.point), + remaining_claims_to_verify.on_table.value + ); + for pusforward_statement in &remaining_claims_to_verify.on_pushforward { + assert_eq!( + pushforward.evaluate(&pusforward_statement.point), + pusforward_statement.value + ); + } + for (index_col, index_statements) in all_indexe_columns + .iter() + .zip(remaining_claims_to_verify.on_indexes.iter()) + { + for statement in index_statements { + assert_eq!(index_col.evaluate(&statement.point), statement.value); + } + } +} diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index a5eacd82..fa686f7a 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -7,16 +7,12 @@ edition.workspace = true workspace = true [dependencies] -p3-field.workspace = true -rayon.workspace = true p3-air.workspace = true -p3-matrix.workspace = true p3-challenger.workspace = true p3-koala-bear.workspace = true tracing-forest.workspace = true p3-symmetric.workspace = true p3-poseidon2.workspace = true -p3-poseidon2-air.workspace = true tracing-subscriber.workspace = true tracing.workspace = true p3-util.workspace = true diff --git a/crates/utils/src/constraints_checker.rs b/crates/utils/src/constraints_checker.rs index 1266998e..8d7ad5b3 100644 --- a/crates/utils/src/constraints_checker.rs +++ b/crates/utils/src/constraints_checker.rs @@ -1,8 +1,6 @@ use std::marker::PhantomData; use p3_air::AirBuilder; -use p3_field::ExtensionField; -use p3_matrix::dense::RowMajorMatrixView; use multilinear_toolkit::prelude::*; @@ -12,7 +10,7 @@ Debug purpose #[derive(Debug)] pub struct ConstraintChecker<'a, IF, EF> { - pub main: RowMajorMatrixView<'a, IF>, + pub main: &'a [IF], pub constraint_index: usize, pub errors: Vec, pub field: PhantomData, @@ -24,28 +22,13 @@ impl<'a, EF: ExtensionField> + ExtensionField, IF: ExtensionField; type Expr = IF; type Var = IF; - type M = RowMajorMatrixView<'a, IF>; + type FinalOutput = EF; #[inline] - fn main(&self) -> Self::M { + fn main(&self) -> &[IF] { self.main } - #[inline] - fn is_first_row(&self) -> Self::Expr { - unreachable!() - } - - #[inline] - fn is_last_row(&self) -> Self::Expr { - unreachable!() - } - - #[inline] - fn is_transition_window(&self, _: usize) -> Self::Expr { - unreachable!() - } - #[inline] fn assert_zero>(&mut self, x: I) { let x: IF = x.into(); @@ -55,8 +38,7 @@ impl<'a, EF: ExtensionField> + ExtensionField, IF: ExtensionField>(&mut self, _: [I; N]) { - unreachable!() + fn add_custom(&mut self, value: Self::FinalOutput) { + let _ = value; } } diff --git a/crates/utils/src/debug.rs b/crates/utils/src/debug.rs new file mode 100644 index 00000000..559ae38e --- /dev/null +++ b/crates/utils/src/debug.rs @@ -0,0 +1,9 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::Hash; +use std::hash::Hasher; + +pub fn debug_hash(value: A) -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() +} diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 367f2c5d..4960b565 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -20,3 +20,6 @@ pub use constraints_checker::*; mod poseidon2; pub use poseidon2::*; + +mod debug; +pub use debug::*; diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index 02d8d257..ec443243 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -1,6 +1,3 @@ -use p3_field::{BasedVectorSpace, ExtensionField, Field, dot_product}; -use rayon::prelude::*; - use multilinear_toolkit::prelude::*; use tracing::instrument; @@ -14,18 +11,6 @@ pub fn transmute_slice(slice: &[Before]) -> &[After] { unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const After, new_len) } } -pub fn left_ref(slice: &[A]) -> &[A] { - assert!(slice.len().is_multiple_of(2)); - let mid = slice.len() / 2; - &slice[..mid] -} - -pub fn right_ref(slice: &[A]) -> &[A] { - assert!(slice.len().is_multiple_of(2)); - let mid = slice.len() / 2; - &slice[mid..] -} - pub fn from_end(slice: &[A], n: usize) -> &[A] { assert!(n <= slice.len()); &slice[slice.len() - n..] @@ -152,3 +137,22 @@ pub fn transposed_par_iter_mut( .into_par_iter() .map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].0.add(i)) }) } + +#[derive(Debug)] +pub enum VecOrSlice<'a, T> { + Vec(Vec), + Slice(&'a [T]), +} + +impl<'a, T> VecOrSlice<'a, T> { + pub fn as_slice(&self) -> &[T] { + match self { + VecOrSlice::Vec(v) => v.as_slice(), + VecOrSlice::Slice(s) => s, + } + } +} + +pub fn encapsulate_vec(v: Vec) -> Vec> { + v.into_iter().map(|x| vec![x]).collect() +} diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 9dbf5038..c9c8ca53 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -1,7 +1,6 @@ use std::borrow::Borrow; use crate::from_end; -use p3_field::{ExtensionField, Field, dot_product}; use p3_util::log2_strict_usize; use multilinear_toolkit::prelude::*; @@ -71,14 +70,6 @@ pub fn multilinear_eval_constants_at_right(limit: usize, point: &[F]) // dst // } -pub fn add_multilinears_inplace(dst: &mut [F], src: &[F]) { - assert_eq!(dst.len(), src.len()); - - dst.par_iter_mut() - .zip(src.par_iter()) - .for_each(|(a, b)| *a += *b); -} - pub fn padd_with_zero_to_next_power_of_two(pol: &[F]) -> Vec { let next_power_of_two = pol.len().next_power_of_two(); let mut padded = pol.to_vec(); @@ -130,7 +121,6 @@ pub fn fold_multilinear_chunks>( #[cfg(test)] mod tests { - use p3_field::PrimeCharacteristicRing; use p3_koala_bear::{KoalaBear, QuinticExtensionFieldKB}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; diff --git a/crates/utils/src/poseidon2.rs b/crates/utils/src/poseidon2.rs index 4214c8e8..2944d937 100644 --- a/crates/utils/src/poseidon2.rs +++ b/crates/utils/src/poseidon2.rs @@ -9,8 +9,6 @@ use p3_koala_bear::KOALABEAR_RC24_INTERNAL; use p3_koala_bear::KoalaBear; use p3_koala_bear::Poseidon2KoalaBear; use p3_poseidon2::ExternalLayerConstants; -use p3_poseidon2_air::p16::RoundConstants16; -use p3_poseidon2_air::p24::RoundConstants24; use p3_symmetric::Permutation; pub type Poseidon16 = Poseidon2KoalaBear<16>; @@ -24,23 +22,16 @@ pub const QUARTER_FULL_ROUNDS_24: usize = 2; pub const HALF_FULL_ROUNDS_24: usize = 4; pub const PARTIAL_ROUNDS_24: usize = 23; -pub type MyRoundConstants16 = RoundConstants16; -pub type MyRoundConstants24 = RoundConstants24; - static POSEIDON16_INSTANCE: OnceLock = OnceLock::new(); #[inline(always)] pub(crate) fn get_poseidon16() -> &'static Poseidon16 { POSEIDON16_INSTANCE.get_or_init(|| { - let round_constants = build_poseidon16_constants(); let external_constants = ExternalLayerConstants::new( - round_constants.beginning_full_round_constants.to_vec(), - round_constants.ending_full_round_constants.to_vec(), + KOALABEAR_RC16_EXTERNAL_INITIAL.to_vec(), + KOALABEAR_RC16_EXTERNAL_FINAL.to_vec(), ); - Poseidon16::new( - external_constants, - round_constants.partial_round_constants.to_vec(), - ) + Poseidon16::new(external_constants, KOALABEAR_RC16_INTERNAL.to_vec()) }) } @@ -69,30 +60,10 @@ static POSEIDON24_INSTANCE: OnceLock = OnceLock::new(); #[inline(always)] pub(crate) fn get_poseidon24() -> &'static Poseidon24 { POSEIDON24_INSTANCE.get_or_init(|| { - let round_constants = build_poseidon24_constants(); let external_constants = ExternalLayerConstants::new( - round_constants.beginning_full_round_constants.to_vec(), - round_constants.ending_full_round_constants.to_vec(), + KOALABEAR_RC24_EXTERNAL_INITIAL.to_vec(), + KOALABEAR_RC24_EXTERNAL_FINAL.to_vec(), ); - Poseidon24::new( - external_constants, - round_constants.partial_round_constants.to_vec(), - ) + Poseidon24::new(external_constants, KOALABEAR_RC24_INTERNAL.to_vec()) }) } - -pub fn build_poseidon16_constants() -> MyRoundConstants16 { - RoundConstants16 { - beginning_full_round_constants: KOALABEAR_RC16_EXTERNAL_INITIAL, - partial_round_constants: KOALABEAR_RC16_INTERNAL, - ending_full_round_constants: KOALABEAR_RC16_EXTERNAL_FINAL, - } -} - -pub fn build_poseidon24_constants() -> MyRoundConstants24 { - RoundConstants24 { - beginning_full_round_constants: KOALABEAR_RC24_EXTERNAL_INITIAL, - partial_round_constants: KOALABEAR_RC24_INTERNAL, - ending_full_round_constants: KOALABEAR_RC24_EXTERNAL_FINAL, - } -} diff --git a/crates/utils/src/wrappers.rs b/crates/utils/src/wrappers.rs index d9adc483..154e2c69 100644 --- a/crates/utils/src/wrappers.rs +++ b/crates/utils/src/wrappers.rs @@ -1,7 +1,5 @@ use multilinear_toolkit::prelude::*; use p3_challenger::DuplexChallenger; -use p3_field::ExtensionField; -use p3_field::PrimeField64; use p3_koala_bear::KoalaBear; use crate::Poseidon16; diff --git a/crates/xmss/Cargo.toml b/crates/xmss/Cargo.toml index d8406099..29fec5eb 100644 --- a/crates/xmss/Cargo.toml +++ b/crates/xmss/Cargo.toml @@ -8,7 +8,7 @@ workspace = true [dependencies] p3-koala-bear.workspace = true -p3-field.workspace = true rand.workspace = true utils.workspace = true p3-util.workspace = true +multilinear-toolkit.workspace = true diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs index 43003363..2b5ad6bf 100644 --- a/crates/xmss/src/wots.rs +++ b/crates/xmss/src/wots.rs @@ -1,4 +1,4 @@ -use p3_field::PrimeCharacteristicRing; +use multilinear_toolkit::prelude::*; use p3_util::log2_strict_usize; use rand::{Rng, RngCore}; use utils::{ToUsize, to_little_endian_bits}; diff --git a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg b/docs/benchmark_graphs/graphs/recursive_whir_opening.svg index 17ad03cc..46021aea 100644 --- a/docs/benchmark_graphs/graphs/recursive_whir_opening.svg +++ b/docs/benchmark_graphs/graphs/recursive_whir_opening.svg @@ -1,12 +1,12 @@ - + - 2025-10-27T17:01:34.332924 + 2025-11-15T22:38:00.887118 image/svg+xml @@ -21,8 +21,8 @@ - - - +" clip-path="url(#pf5ac6a0352)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 181.671111 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 369.379028 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 557.086944 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 619.65625 34.3575 - + - + - + @@ -523,27 +523,142 @@ L 682.225556 34.3575 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - - + - + - + - - + + - + - + - + - + @@ -582,24 +697,24 @@ L 700.1025 142.244292 - - + + - + - - + - + - + @@ -607,19 +722,19 @@ L -2 0 - - + + - + - + - + - + @@ -627,19 +742,19 @@ L 700.1025 292.184322 - - + + - + - + - + - + @@ -647,19 +762,19 @@ L 700.1025 254.409985 - - + + - + - + - + - + @@ -667,46 +782,19 @@ L 700.1025 227.608666 - - + + - + - + - + - - - - + @@ -714,19 +802,19 @@ z - - + + - + - + - + - + @@ -734,19 +822,19 @@ L 700.1025 189.834329 - - + + - + - + - + - + @@ -754,19 +842,19 @@ L 700.1025 175.473193 - - + + - + - + - + - + @@ -774,19 +862,19 @@ L 700.1025 163.03301 - - + + - + - + - + - + @@ -794,19 +882,19 @@ L 700.1025 152.059992 - - + + - + - + - + - + @@ -814,28 +902,28 @@ L 700.1025 77.668637 - - + + - + - + - + - + - + - + - + + - - - - - - - - - - - - - - + + + + + + + + + + + + + + - - + + - - - - + + + + - - + + - - - @@ -1210,7 +1302,7 @@ L 700.1025 372.112477 L 700.1025 34.3575 " style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/> - + @@ -1440,16 +1532,16 @@ Q 571.44375 96.79875 573.84375 96.79875 z " style="fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter"/> - + - + - + @@ -1472,16 +1564,16 @@ z - + - + - + @@ -1513,13 +1605,13 @@ z - + - + @@ -1555,8 +1647,8 @@ z - - + + diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated.svg b/docs/benchmark_graphs/graphs/xmss_aggregated.svg index 9f4b11b2..c7580dd0 100644 --- a/docs/benchmark_graphs/graphs/xmss_aggregated.svg +++ b/docs/benchmark_graphs/graphs/xmss_aggregated.svg @@ -1,12 +1,12 @@ - + - 2025-11-02T01:33:52.153182 + 2025-11-15T22:38:01.014383 image/svg+xml @@ -21,8 +21,8 @@ - - - +" clip-path="url(#pca8fa1ea92)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 172.441389 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 363.993472 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 555.545556 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 619.39625 34.3575 - + - + - + @@ -523,45 +523,160 @@ L 683.246944 34.3575 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - - + - + - + - - + + - + - + - + - + @@ -569,19 +684,19 @@ L 701.49 310.702481 - - + + - + - + - + - + @@ -589,19 +704,19 @@ L 701.49 249.292485 - - + + - + - + - + - + @@ -609,19 +724,19 @@ L 701.49 187.882489 - - + + - + - + - + - + @@ -629,19 +744,19 @@ L 701.49 126.472494 - - + + - + - + - + - + @@ -650,25 +765,26 @@ L 701.49 65.062498 - - + + - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - + + - - - - - + + + + + - - + + - - - @@ -742,7 +861,7 @@ L 701.49 372.112477 L 701.49 34.3575 " style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/> - + @@ -1169,30 +1288,30 @@ z - - - + - + - + - + - - + - + - + - + - - + - + - + - + + diff --git a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg b/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg index 367a9dc8..fb5888b6 100644 --- a/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg +++ b/docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg @@ -1,12 +1,12 @@ - + - 2025-11-02T01:33:52.249035 + 2025-11-15T22:38:01.125000 image/svg+xml @@ -21,8 +21,8 @@ - - - +" clip-path="url(#p52f0bb76c4)" style="fill: none; stroke: #b0b0b0; stroke-opacity: 0.3; stroke-width: 0.8; stroke-linecap: square"/> - - + - + - + - + - + - + - + - + @@ -266,18 +266,18 @@ L 169.337639 34.3575 - + - + - + - + - + - + - + - + - + @@ -382,18 +382,18 @@ L 361.965972 34.3575 - + - + - + - + - + - + - + - + - + @@ -481,18 +481,18 @@ L 554.594306 34.3575 - + - + - + @@ -503,18 +503,18 @@ L 618.80375 34.3575 - + - + - + @@ -523,56 +523,64 @@ L 683.013194 34.3575 - - - + - + - - - - + - - - - - - - - + + + + + + + - + - + - + - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -607,19 +695,19 @@ z - - + + - + - + - + - + @@ -627,19 +715,19 @@ L 701.35875 278.860261 - - + + - + - + - + - + @@ -647,19 +735,19 @@ L 701.35875 232.234153 - - + + - + - + - + - + @@ -668,19 +756,19 @@ L 701.35875 185.608045 - - + + - + - + - + - + @@ -689,19 +777,19 @@ L 701.35875 138.981937 - - + + - + - + - + - + @@ -710,19 +798,19 @@ L 701.35875 92.355829 - - + + - + - + - + - + @@ -731,26 +819,27 @@ L 701.35875 45.729721 - - + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + - - + + - - - - - - + + + + + + - - + + - - - @@ -827,7 +919,7 @@ L 701.35875 372.112477 L 701.35875 34.3575 " style="fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square"/> - + @@ -1302,30 +1394,30 @@ z - - - + - + - + - + - - + - + - + - + - - + - + - + - + + diff --git a/docs/benchmark_graphs/main.py b/docs/benchmark_graphs/main.py index 085d25ee..9cc81efe 100644 --- a/docs/benchmark_graphs/main.py +++ b/docs/benchmark_graphs/main.py @@ -5,7 +5,7 @@ # uv run python main.py -N_DAYS_SHOWN = 70 +N_DAYS_SHOWN = 100 plt.rcParams.update({ 'font.size': 12, # Base font size @@ -149,6 +149,7 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege ('2025-10-13', 0.521, None), ('2025-10-18', 0.411, 0.320), ('2025-10-27', 0.425, 0.330), + ('2025-11-15', 0.417, 0.330), ], target=0.1, target_label="Target (0.1 s)", @@ -177,6 +178,7 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege ('2025-10-18', 255, 465), ('2025-10-27', 314, 555), ('2025-11-02', 350, 660), + ('2025-11-15', 380, 720), ], target=1000, target_label="Target (1000 XMSS/s)", @@ -205,6 +207,7 @@ def create_duration_graph(data, target=None, target_label=None, title="", y_lege ('2025-10-27', (610_000 / 157) / 314, (1_250_000 / 157) / 555), ('2025-10-29', (650_000 / 157) / 314, (1_300_000 / 157) / 555), ('2025-11-02', (650_000 / 157) / 350, (1_300_000 / 157) / 660), + ('2025-11-15', (650_000 / 157) / 380, (1_300_000 / 157) / 720), ], target=2, target_label="Target (2x)",