Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPT review #9

Open
wants to merge 21 commits into
base: mpt-refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion zkevm-circuits/src/circuit_tools/constraint_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl<F: Field> ConstraintBuilder<F> {
meta: &mut ConstraintSystem<F>,
lookup_names: &[S],
) {
for lookup_name in lookup_names.iter() {
for lookup_name in lookup_names.iter() { // "keccek", "fixed", "s_parent", ...
let lookups = self
.lookups
.iter()
Expand All @@ -186,7 +186,9 @@ impl<F: Field> ConstraintBuilder<F> {
.collect::<Vec<_>>();
for lookup in lookups.iter() {
meta.lookup_any(lookup.description, |_meta| {
// 拿对应的表
let table = self.get_lookup_table_values(lookup_name);
// 拿要查的值
let mut values: Vec<_> = lookup
.values
.iter()
Expand Down
140 changes: 110 additions & 30 deletions zkevm-circuits/src/mpt_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub struct MPTConfig<F> {
pub(crate) managed_columns: Vec<Column<Advice>>,
pub(crate) memory: Memory<F>,
keccak_table: KeccakTable,
fixed_table: [Column<Fixed>; 3],
fixed_table: [Column<Fixed>; 5],
state_machine: StateMachineConfig<F>,
pub(crate) is_start: Column<Advice>,
pub(crate) is_branch: Column<Advice>,
Expand All @@ -127,6 +127,10 @@ pub enum FixedTableTag {
RangeKeyLen256,
/// For checking there are 0s after the RLP stream ends
RangeKeyLen16,
/// For checking RLP
RLP,
/// For distinguishing odd key part in extension
ExtOddKey,
}

impl_expr!(FixedTableTag);
Expand Down Expand Up @@ -164,7 +168,7 @@ impl<F: Field> MPTConfig<F> {

let main = MainCols::new(meta);

let fixed_table: [Column<Fixed>; 3] = (0..3)
let fixed_table: [Column<Fixed>; 5] = (0..5)
.map(|_| meta.fixed_column())
.collect::<Vec<_>>()
.try_into()
Expand Down Expand Up @@ -280,7 +284,7 @@ impl<F: Field> MPTConfig<F> {
cb.base
.generate_lookups(meta, &vec!["fixed".to_string(), "keccak".to_string()]);
} else if disable_lookups == 4 {
cb.base.generate_lookups(meta, &vec!["keccak".to_string()]);
cb.base.generate_lookups(meta, &vec!["fixed".to_string()]);
}

println!("num lookups: {}", meta.lookups().len());
Expand Down Expand Up @@ -870,6 +874,73 @@ impl<F: Field> MPTConfig<F> {
offset += 1;
}

// Rlp prefixes table [rlp_tag, byte, is_string, is_short, is_verylong]
for ind in 0..=127 {
Comment on lines +957 to +958
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you do this a bit more like in the witness generation with a match statement for example. Will be quite a bit shorter, less duplicated code (like setting the tag to the same value for each case) and that also uses the RLP constants like RLP_SHORT etc... so fewer magic values.

// short string
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => ind.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => true.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => true.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => false.scalar())?;
offset += 1;
}
for ind in 128..=183 {
// long string
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => ind.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => true.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => false.scalar())?;
offset += 1;
}
for ind in 184..=191 {
// very long string
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => ind.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => true.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => true.scalar())?;
offset += 1;
}
for ind in 192..=247 {
// short list
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => ind.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => true.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => false.scalar())?;
offset += 1;
}
// 248
// long list
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => 248i32.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => false.scalar())?;
offset += 1;
// 249
// very long list
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::RLP.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => 249i32.scalar())?;
assignf!(region, (self.fixed_table[2], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[3], offset) => false.scalar())?;
assignf!(region, (self.fixed_table[4], offset) => true.scalar())?;
offset += 1;

// KEY_PREFIX_ODD: u8 = 0b0001_0000;
// first byte of ext key >> 4 == 1
for ind in 0..256 {
assignf!(region, (self.fixed_table[0], offset) => FixedTableTag::ExtOddKey.scalar())?;
assignf!(region, (self.fixed_table[1], offset) => ind.scalar())?;
if 16 <= ind && ind < 32 {
assignf!(region, (self.fixed_table[2], offset) => true.scalar())?;
} else {
assignf!(region, (self.fixed_table[2], offset) => false.scalar())?;
}
offset += 1;
}

Ok(())
},
)
Expand Down Expand Up @@ -934,11 +1005,14 @@ mod tests {
dev::MockProver,
halo2curves::{bn256::Fr, FieldExt},
};

use std::fs;
use std::{fs, env::VarError};

#[test]
fn test_mpt() {
let only_run = var("ONLY_RUN")
.and_then(|idx| idx.parse::<usize>().map_err(|e|VarError::NotPresent)
).ok();
println!("{:?}", only_run);
// for debugging:
let path = "src/mpt_circuit/tests";
// let path = "tests";
Expand All @@ -954,31 +1028,37 @@ mod tests {
})
.enumerate()
.for_each(|(idx, f)| {
let path = f.path();
let mut parts = path.to_str().unwrap().split('-');
parts.next();
let file = std::fs::File::open(path.clone());
let reader = std::io::BufReader::new(file.unwrap());
let w: Vec<Vec<u8>> = serde_json::from_reader(reader).unwrap();

let count = w.iter().filter(|r| r[r.len() - 1] != 5).count() * 2;
let randomness: Fr = 123456789.scalar();
let instance: Vec<Vec<Fr>> = (1..HASH_WIDTH + 1)
.map(|exp| vec![randomness.pow(&[exp as u64, 0, 0, 0]); count])
.collect();

let circuit = MPTCircuit::<Fr> {
witness: w.clone(),
randomness,
};

println!("{} {:?}", idx, path);
// let prover = MockProver::run(9, &circuit, vec![pub_root]).unwrap();
let num_rows = w.len() * 2;
let prover = MockProver::run(14 /* 9 */, &circuit, instance).unwrap();
assert_eq!(prover.verify_at_rows(0..num_rows, 0..num_rows,), Ok(()));
//assert_eq!(prover.verify_par(), Ok(()));
//prover.assert_satisfied();
let mut run = true;
if let Some(i) = only_run {
if idx != i {run = false;}
}
if run {
let path = f.path();
let mut parts = path.to_str().unwrap().split('-');
parts.next();
let file = std::fs::File::open(path.clone());
let reader = std::io::BufReader::new(file.unwrap());
let w: Vec<Vec<u8>> = serde_json::from_reader(reader).unwrap();

let count = w.iter().filter(|r| r[r.len() - 1] != 5).count() * 2;
let randomness: Fr = 123456789.scalar();
let instance: Vec<Vec<Fr>> = (1..HASH_WIDTH + 1)
.map(|exp| vec![randomness.pow(&[exp as u64, 0, 0, 0]); count])
.collect();

let circuit = MPTCircuit::<Fr> {
witness: w.clone(),
randomness,
};

println!("{} {:?}", idx, path);
// let prover = MockProver::run(9, &circuit, vec![pub_root]).unwrap();
let num_rows = w.len() * 2;
let prover = MockProver::run(14 /* 9 */, &circuit, instance).unwrap();
assert_eq!(prover.verify_at_rows(0..num_rows, 0..num_rows,), Ok(()));
//assert_eq!(prover.verify_par(), Ok(()));
//prover.assert_satisfied();
}
});
}
}
9 changes: 8 additions & 1 deletion zkevm-circuits/src/mpt_circuit/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,15 @@ impl<F: Field> ExtensionGadget<F> {

config.rlp_key = ListKeyGadget::construct(&mut cb.base, &key_bytes[0]);
// TODO(Brecht): add lookup constraint
// let is_key_part_odd = key_bytes[0][rlp_key.key_value.num_rlp_bytes()] >> 4 == 1;
config.is_key_part_odd = cb.base.query_cell();

let odd_flag = matchx! {
config.rlp_key.key_value.is_short() => key_bytes[0][0].expr(),
config.rlp_key.key_value.is_long() => key_bytes[0][1].expr(),
config.rlp_key.key_value.is_very_long() => key_bytes[0][2].expr(),
};
require!((FixedTableTag::ExtOddKey.expr(), odd_flag, config.is_key_part_odd.expr()) => @"fixed");

// We need to check that the nibbles we stored in s are between 0 and 15.
cb.set_range_s(FixedTableTag::RangeKeyLen16.expr());

Expand Down
63 changes: 50 additions & 13 deletions zkevm-circuits/src/mpt_circuit/rlp_gadgets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ use crate::{
cell_manager::Cell,
constraint_builder::{ConstraintBuilder, RLCable, RLCableValue},
},
matchr, matchw,
mpt_circuit::param::{RLP_LIST_LONG, RLP_LIST_SHORT, RLP_SHORT},
matchr, matchw,
mpt_circuit::{
FixedTableTag,
param::{RLP_LIST_LONG, RLP_LIST_SHORT, RLP_SHORT}
},
util::Expr,
};
use eth_types::Field;
use gadgets::util::{not, Scalar};
use halo2_proofs::{
circuit::Region,
circuit::{Region, self},
plonk::{Error, Expression},
};

Expand All @@ -37,12 +40,26 @@ pub(crate) struct RLPListWitness {

impl<F: Field> RLPListGadget<F> {
pub(crate) fn construct(cb: &mut ConstraintBuilder<F>, bytes: &[Expression<F>]) -> Self {
// TODO(Brecht): add lookup/LtGadget
let is_short = cb.query_cell();
let is_long = cb.query_cell();
let is_very_long = cb.query_cell();
let is_string = cb.query_cell();

circuit!([meta, cb], {
require!(vec![
FixedTableTag::RLP.expr(),
bytes[0].clone(),
is_string.expr(),
is_short.expr(),
is_very_long.expr()] => @"fixed"
Comment on lines +52 to +54
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_long seems to be missing? I think that's actually smart to only do 2 out of 3 in the lookup table because it saves a cell and you can deduce one of the three values e.g. is_long is 1 - is_short - is_very_long, but I don't see that being done so not sure what's up? :)

);
});

RLPListGadget {
is_short: cb.query_cell(),
is_long: cb.query_cell(),
is_very_long: cb.query_cell(),
is_string: cb.query_cell(),
is_short,
is_long,
is_very_long,
is_string,
bytes: bytes.to_vec(),
}
}
Expand All @@ -61,8 +78,11 @@ impl<F: Field> RLPListGadget<F> {
let mut is_very_long = false;
let mut is_string = false;
match bytes[0] {
// 192 - 247
RLP_LIST_SHORT..=RLP_LIST_LONG => is_short = true,
// 248
RLP_LIST_LONG_1 => is_long = true,
// 249
RLP_LIST_LONG_2 => is_very_long = true,
_ => is_string = true,
}
Expand Down Expand Up @@ -261,12 +281,26 @@ pub(crate) struct RLPValueWitness {

impl<F: Field> RLPValueGadget<F> {
pub(crate) fn construct(cb: &mut ConstraintBuilder<F>, bytes: &[Expression<F>]) -> Self {
// TODO(Brecht): add lookup
let is_short = cb.query_cell();
let is_long = cb.query_cell();
let is_very_long = cb.query_cell();
let is_list = cb.query_cell();

circuit!([meta, cb], {
require!(vec![
FixedTableTag::RLP.expr(),
bytes[0].clone(),
not::expr(is_list.expr()),
is_short.expr(),
is_very_long.expr()] => @"fixed"
);
});

RLPValueGadget {
is_short: cb.query_cell(),
is_long: cb.query_cell(),
is_very_long: cb.query_cell(),
is_list: cb.query_cell(),
is_short,
is_long,
is_very_long,
is_list,
bytes: bytes.to_vec(),
}
}
Expand All @@ -286,8 +320,11 @@ impl<F: Field> RLPValueGadget<F> {
let mut is_very_long = false;
let mut is_list = false;
match bytes[0] {
// 0 - 127
0..=RLP_SHORT_INCLUSIVE => is_short = true,
// 128 - 183
RLP_SHORT..=RLP_LONG => is_long = true,
// 189 - 191
RLP_LONG_EXCLUSIVE..=RLP_VALUE_MAX => is_very_long = true,
_ => is_list = true,
}
Expand Down