Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/backends/plonky2/circuits/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ impl CustomPredicateEntryTarget {
conjunction: predicate.conjunction,
statements,
args_len: predicate.args_len,
wildcard_names: predicate.wildcard_names.clone(),
};
self.predicate.set_targets(pw, params, &predicate)?;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/backends/plonky2/mainpod/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl fmt::Display for Statement {
if i != 0 {
write!(f, " ")?;
}
write!(f, "{}", arg)?;
arg.fmt(f)?;
}
}
Ok(())
Expand Down
17 changes: 4 additions & 13 deletions src/examples/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn eth_friend_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredica
],
)?;

println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
println!("a.0. {}", builder.predicates.last().unwrap());
Ok(builder.finish())
}

Expand Down Expand Up @@ -75,10 +75,7 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
.arg(literal(0)),
],
)?;
println!(
"b.0. eth_dos_distance_base = {}",
builder.predicates.last().unwrap()
);
println!("b.0. {}", builder.predicates.last().unwrap());

let eth_dos_distance = Predicate::BatchSelf(2);

Expand Down Expand Up @@ -115,10 +112,7 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
],
)?;

println!(
"b.1. eth_dos_distance_ind = {}",
builder.predicates.last().unwrap()
);
println!("b.1. {}", builder.predicates.last().unwrap());

let _eth_dos_distance = builder.predicate_or(
"eth_dos_distance",
Expand All @@ -136,10 +130,7 @@ pub fn eth_dos_batch(params: &Params, mock: bool) -> Result<Arc<CustomPredicateB
],
)?;

println!(
"b.2. eth_dos_distance = {}",
builder.predicates.last().unwrap()
);
println!("b.2. {}", builder.predicates.last().unwrap());

Ok(builder.finish())
}
4 changes: 4 additions & 0 deletions src/frontend/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ impl CustomPredicateBatchBuilder {
conjunction,
statements,
args.len(),
args.iter()
.chain(priv_args.iter())
.map(|s| s.to_string())
.collect(),
)?;
self.predicates.push(custom_predicate);
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
Expand Down
20 changes: 19 additions & 1 deletion src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ mod tests {
StatementTmplArg::Literal(value.into())
}

fn names(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect()
}

#[test]
fn test_e2e_simple_predicate() -> Result<(), LangError> {
let input = r#"
Expand Down Expand Up @@ -86,6 +90,7 @@ mod tests {
"is_equal".to_string(),
expected_statements,
2, // args_len (PodA, PodB)
names(&["PodA", "PodB"]),
)?;
let expected_batch =
CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), vec![expected_predicate]);
Expand Down Expand Up @@ -180,6 +185,7 @@ mod tests {
"uses_private".to_string(),
expected_statements,
1, // args_len (A)
names(&["A", "Temp"]),
)?;
let expected_batch =
CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), vec![expected_predicate]);
Expand Down Expand Up @@ -226,6 +232,7 @@ mod tests {
"my_pred".to_string(),
expected_pred_statements,
2, // args_len (X, Y)
names(&["X", "Y"]),
)?;
let expected_batch =
CustomPredicateBatch::new(&params, "PodlogBatch".to_string(), vec![expected_predicate]);
Expand Down Expand Up @@ -516,7 +523,7 @@ mod tests {
eth_friend(?intermed_key, ?dst_key)
)

eth_dos_distance(src_key, dst_key, distance_key, private: intermed_key, shorter_distance_key) = OR(
eth_dos_distance(src_key, dst_key, distance_key) = OR(
eth_dos_distance_base(?src_key, ?dst_key, ?distance_key)
eth_dos_distance_ind(?src_key, ?dst_key, ?distance_key)
)
Expand Down Expand Up @@ -566,6 +573,7 @@ mod tests {
true, // AND
expected_friend_stmts,
2, // public_args_len: src_key, dst_key
names(&["src_key", "dst_key", "attestation_pod"]),
)?;

// eth_dos_distance_base (Index 1)
Expand All @@ -588,6 +596,7 @@ mod tests {
true, // AND
expected_base_stmts,
3, // public_args_len
names(&["src_key", "dst_key", "distance_key"]),
)?;

// eth_dos_distance_ind (Index 2)
Expand Down Expand Up @@ -630,6 +639,14 @@ mod tests {
true, // AND
expected_ind_stmts,
3, // public_args_len
names(&[
"src_key",
"dst_key",
"distance_key",
"one_key",
"shorter_distance_key",
"intermed_key",
]),
)?;

// eth_dos_distance (Index 3)
Expand Down Expand Up @@ -659,6 +676,7 @@ mod tests {
false, // OR
expected_dist_stmts,
3, // public_args_len
names(&["src_key", "dst_key", "distance_key"]),
)?;

let expected_batch = CustomPredicateBatch::new(
Expand Down
87 changes: 64 additions & 23 deletions src/middleware/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ impl Wildcard {

impl fmt::Display for Wildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "*{}[{}]", self.index, self.name)
if f.alternate() {
write!(f, "?{}:{}", self.index, self.name)
} else {
write!(f, "?{}", self.name)
}
}
}

Expand All @@ -43,8 +47,8 @@ pub enum KeyOrWildcard {
impl fmt::Display for KeyOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Key(k) => write!(f, "{}", k),
Self::Wildcard(wc) => write!(f, "{}", wc),
Self::Key(k) => k.fmt(f),
Self::Wildcard(wc) => wc.fmt(f),
}
}
}
Expand Down Expand Up @@ -75,7 +79,7 @@ impl fmt::Display for SelfOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::SELF => write!(f, "SELF"),
Self::Wildcard(wc) => write!(f, "{}", wc),
Self::Wildcard(wc) => wc.fmt(f),
}
}
}
Expand Down Expand Up @@ -166,9 +170,14 @@ impl fmt::Display for StatementTmplArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::None => write!(f, "none"),
Self::Literal(v) => write!(f, "{}", v),
Self::AnchoredKey(pod_id, key) => write!(f, "({}, {})", pod_id, key),
Self::WildcardLiteral(v) => write!(f, "{}", v),
Self::Literal(v) => v.fmt(f),
Self::AnchoredKey(pod_id, key) => {
pod_id.fmt(f)?;
write!(f, "[")?;
key.fmt(f)?;
write!(f, "]")
}
Self::WildcardLiteral(v) => v.fmt(f),
}
}
}
Expand All @@ -191,12 +200,13 @@ impl StatementTmpl {

impl fmt::Display for StatementTmpl {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}(", self.pred)?;
self.pred.fmt(f)?;
write!(f, "(")?;
for (i, arg) in self.args.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
arg.fmt(f)?;
}
writeln!(f)
}
Expand Down Expand Up @@ -240,6 +250,9 @@ pub struct CustomPredicate {
pub(crate) conjunction: bool,
pub(crate) statements: Vec<StatementTmpl>,
pub(crate) args_len: usize,
/// Names of the wildcards, the first `args_len` entries correspond to the `args_len` arguments
/// of the custom predicate.
pub(crate) wildcard_names: Vec<String>,
// TODO: Add private args length?
// TODO: Add args type information?
}
Expand All @@ -254,30 +267,34 @@ impl CustomPredicate {
args: vec![],
}],
args_len: 0,
wildcard_names: vec![],
}
}
pub fn and(
params: &Params,
name: String,
statements: Vec<StatementTmpl>,
args_len: usize,
wildcard_names: Vec<String>,
) -> Result<Self> {
Self::new(params, name, true, statements, args_len)
Self::new(params, name, true, statements, args_len, wildcard_names)
}
pub fn or(
params: &Params,
name: String,
statements: Vec<StatementTmpl>,
args_len: usize,
wildcard_names: Vec<String>,
) -> Result<Self> {
Self::new(params, name, false, statements, args_len)
Self::new(params, name, false, statements, args_len, wildcard_names)
}
pub fn new(
params: &Params,
name: String,
conjunction: bool,
statements: Vec<StatementTmpl>,
args_len: usize,
wildcard_names: Vec<String>,
) -> Result<Self> {
if statements.len() > params.max_custom_predicate_arity {
return Err(Error::max_length(
Expand All @@ -293,12 +310,20 @@ impl CustomPredicate {
params.max_statement_args,
));
}
if wildcard_names.len() > params.max_custom_predicate_wildcards {
return Err(Error::max_length(
"custom_predicate_wildcards.len".to_string(),
wildcard_names.len(),
params.max_custom_predicate_wildcards,
));
}

Ok(Self {
name,
conjunction,
statements,
args_len,
wildcard_names,
})
}
pub fn pad_statement_tmpl(&self) -> StatementTmpl {
Expand Down Expand Up @@ -346,25 +371,33 @@ impl ToFields for CustomPredicate {

impl fmt::Display for CustomPredicate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?;
write!(f, "{}(", self.name)?;
for (i, name) in self.wildcard_names.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
if i == self.args_len {
write!(f, "private: ")?;
}
if f.alternate() {
write!(f, "{}:", i)?;
}
write!(f, "{}", name)?;
}
writeln!(f, ") = {}(", if self.conjunction { "AND" } else { "OR" })?;
for st in &self.statements {
write!(f, " {}(", st.pred)?;
write!(f, " ")?;
st.pred.fmt(f)?;
write!(f, "(")?;
for (i, arg) in st.args.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
}
writeln!(f, "),")?;
}
write!(f, ">(")?;
for i in 0..self.args_len {
if i != 0 {
write!(f, ", ")?;
arg.fmt(f)?;
}
write!(f, "*{}", i)?;
writeln!(f, ")")?;
}
writeln!(f, ")")?;
write!(f, ")")?;
Ok(())
}
}
Expand Down Expand Up @@ -467,6 +500,9 @@ mod tests {
index: i,
}
}
fn names(names: &[&str]) -> Vec<String> {
names.iter().map(|s| s.to_string()).collect()
}

type STA = StatementTmplArg;
type KOW = KeyOrWildcard;
Expand Down Expand Up @@ -507,6 +543,7 @@ mod tests {
),
],
2,
names(&["1", "2", "3", "4", "5"]),
)?],
);

Expand Down Expand Up @@ -570,6 +607,7 @@ mod tests {
),
],
4,
names(&["1", "2", "3", "4"]),
)?;

let eth_friend_batch =
Expand All @@ -596,6 +634,7 @@ mod tests {
),
],
6,
names(&["0", "1", "2", "3", "4", "5"]),
)?;

// 1
Expand Down Expand Up @@ -640,6 +679,7 @@ mod tests {
),
],
6,
names(&["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"]),
)?;

// 2
Expand Down Expand Up @@ -671,6 +711,7 @@ mod tests {
),
],
6,
names(&["0", "1", "2", "3", "4", "5"]),
)?;

let eth_dos_distance_batch = CustomPredicateBatch::new(
Expand Down
Loading
Loading