Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/backends/mock_main/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ impl TryFrom<Statement> for middleware::Statement {

impl From<middleware::Statement> for Statement {
fn from(s: middleware::Statement) -> Self {
Statement(s.code(), s.args().into_iter().map(|arg| arg).collect())
match s.code() {
middleware::Predicate::Native(c) => {
Statement(c, s.args().into_iter().map(|arg| arg).collect())
}
// TODO: Custom statements
_ => todo!(),
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/frontend/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) ->
#[cfg(test)]
mod tests {
use super::*;
use crate::middleware::PodType;
use crate::middleware::{CustomPredicateRef, PodType};

#[test]
fn test_custom_pred() {
Expand Down Expand Up @@ -204,7 +204,7 @@ mod tests {
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
let eth_friend = builder.finish();
// This batch only has 1 predicate, so we pick it already for convenience
let eth_friend = Predicate::Custom(eth_friend, 0);
let eth_friend = Predicate::Custom(CustomPredicateRef(eth_friend, 0));

// next chunk builds:
// eth_dos_distance_base(src_or, src_key, dst_or, dst_key, distance_or, distance_key) = and<
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::Result;
use itertools::Itertools;
use std::collections::HashMap;
use std::convert::From;
use std::fmt;
use std::{fmt, hash as h};

use crate::middleware::{
self,
Expand All @@ -22,15 +22,15 @@ pub use operation::*;
pub use statement::*;

/// This type is just for presentation purposes.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
#[derive(Clone, Debug, Default, h::Hash, PartialEq, Eq)]
pub enum PodClass {
#[default]
Signed,
Main,
}

// An Origin, which represents a reference to an ancestor POD.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
#[derive(Clone, Debug, PartialEq, Eq, h::Hash, Default)]
pub struct Origin(pub PodClass, pub PodId);

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -166,7 +166,7 @@ impl SignedPod {
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, h::Hash)]
pub struct AnchoredKey(pub Origin, pub String);

impl From<AnchoredKey> for middleware::AnchoredKey {
Expand Down
89 changes: 78 additions & 11 deletions src/middleware/custom.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
use std::fmt;
use std::sync::Arc;
use std::{fmt, hash as h, iter::zip};

use super::{hash_str, Hash, NativePredicate, ToFields, Value, F};
use anyhow::{anyhow, Result};

use super::{
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
F,
};

// BEGIN Custom 1b

#[derive(Debug)]
#[derive(Clone, Debug, PartialEq, Eq, h::Hash)]
pub enum HashOrWildcard {
Hash(Hash),
Wildcard(usize),
}

impl HashOrWildcard {
/// Matches a hash or wildcard against a value, returning a pair
/// representing a wildcard binding (if any) or an error if no
/// match is possible.
pub fn match_against(&self, v: &Value) -> Result<Option<(usize, Value)>> {
match self {
HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None),
HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))),
_ => Err(anyhow!("Failed to match {} against {}.", self, v)),
}
}
}

impl fmt::Display for HashOrWildcard {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand All @@ -20,13 +38,32 @@ impl fmt::Display for HashOrWildcard {
}
}

#[derive(Debug)]
#[derive(Clone, Debug, PartialEq, Eq, h::Hash)]
pub enum StatementTmplArg {
None,
Literal(Value),
Key(HashOrWildcard, HashOrWildcard),
}

impl StatementTmplArg {
/// Matches a statement template argument against a statement
/// argument, returning a wildcard correspondence in the case of
/// one or more wildcard matches, nothing in the case of a
/// literal/hash match, and an error otherwise.
pub fn match_against(&self, s_arg: &StatementArg) -> Result<Vec<(usize, Value)>> {
match (self, s_arg) {
(Self::None, StatementArg::None) => Ok(vec![]),
(Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]),
(Self::Key(tmpl_o, tmpl_k), StatementArg::Key(AnchoredKey(PodId(o), k))) => {
let o_corr = tmpl_o.match_against(&o.clone().into())?;
let k_corr = tmpl_k.match_against(&k.clone().into())?;
Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect())
}
_ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)),
}
}
}

impl fmt::Display for StatementTmplArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand All @@ -50,10 +87,37 @@ impl fmt::Display for StatementTmplArg {
// END

/// Statement Template for a Custom Predicate
#[derive(Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StatementTmpl(pub Predicate, pub Vec<StatementTmplArg>);

#[derive(Debug)]
impl StatementTmpl {
pub fn pred(&self) -> &Predicate {
&self.0
}
pub fn args(&self) -> &[StatementTmplArg] {
&self.1
}
/// Matches a statement template against a statement, returning
/// the variable bindings as an association list. Returns an error
/// if there is type or argument mismatch.
pub fn match_against(&self, s: &Statement) -> Result<Vec<(usize, Value)>> {
type P = Predicate;
if matches!(self, Self(P::BatchSelf(_), _)) {
Err(anyhow!(
"Cannot check self-referencing statement templates."
))
} else if self.pred() != &s.code() {
Err(anyhow!("Type mismatch between {:?} and {}.", self, s))
} else {
zip(self.args(), s.args())
.map(|(t_arg, s_arg)| t_arg.match_against(&s_arg))
.collect::<Result<Vec<_>>>()
.map(|v| v.concat())
}
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicate {
/// true for "and", false for "or"
pub conjunction: bool,
Expand Down Expand Up @@ -96,7 +160,7 @@ impl fmt::Display for CustomPredicate {
}
}

#[derive(Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicateBatch {
pub name: String,
pub predicates: Vec<CustomPredicate>,
Expand All @@ -109,11 +173,14 @@ impl CustomPredicateBatch {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CustomPredicateRef(pub Arc<CustomPredicateBatch>, pub usize);

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Predicate {
Native(NativePredicate),
BatchSelf(usize),
Custom(Arc<CustomPredicateBatch>, usize),
Custom(CustomPredicateRef),
}

impl From<NativePredicate> for Predicate {
Expand All @@ -127,7 +194,7 @@ impl ToFields for Predicate {
match self {
Self::Native(p) => p.to_fields(),
Self::BatchSelf(i) => Value::from(i as i64).to_fields(),
Self::Custom(_pb, _i) => todo!(), // TODO
Self::Custom(_) => todo!(), // TODO
}
}
}
Expand All @@ -137,7 +204,7 @@ impl fmt::Display for Predicate {
match self {
Self::Native(p) => write!(f, "{:?}", p),
Self::BatchSelf(i) => write!(f, "self.{}", i),
Self::Custom(pb, i) => write!(f, "{}.{}", pb.name, i),
Self::Custom(CustomPredicateRef(pb, i)) => write!(f, "{}.{}", pb.name, i),
}
}
}
9 changes: 8 additions & 1 deletion src/middleware/operation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};

use super::Statement;
use super::{CustomPredicateRef, Statement};
use crate::middleware::{AnchoredKey, SELF};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -42,6 +42,7 @@ pub enum Operation {
SumOf(Statement, Statement, Statement),
ProductOf(Statement, Statement, Statement),
MaxOf(Statement, Statement, Statement),
Custom(CustomPredicateRef, Vec<Statement>),
}

impl Operation {
Expand All @@ -64,6 +65,7 @@ impl Operation {
Self::SumOf(_, _, _) => SumOf,
Self::ProductOf(_, _, _) => ProductOf,
Self::MaxOf(_, _, _) => MaxOf,
Self::Custom(_, _) => todo!(),
}
}

Expand All @@ -85,6 +87,7 @@ impl Operation {
Self::SumOf(s1, s2, s3) => vec![s1, s2, s3],
Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3],
Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3],
Self::Custom(_, args) => args,
}
}
/// Forms operation from op-code and arguments.
Expand Down Expand Up @@ -171,6 +174,10 @@ impl Operation {
let v3: i64 = v3.clone().try_into()?;
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
}
(
Self::Custom(CustomPredicateRef(cpb, i), _args),
Custom(CustomPredicateRef(s_cpb, s_i), _s_args),
) if cpb == s_cpb && i == s_i => todo!(),
_ => Err(anyhow!(
"Invalid deduction: {:?} ⇏ {:#}",
self,
Expand Down
37 changes: 20 additions & 17 deletions src/middleware/statement.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use anyhow::{anyhow, Result};
use plonky2::field::types::Field;
use std::fmt;
use std::{collections::HashMap, fmt};
use strum_macros::FromRepr;

use super::{AnchoredKey, ToFields, Value, F};
use super::{AnchoredKey, CustomPredicateRef, Hash, Predicate, ToFields, Value, F};

pub const KEY_SIGNER: &str = "_signer";
pub const KEY_TYPE: &str = "_type";
pub const STATEMENT_ARG_F_LEN: usize = 8;

#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq, Hash)]
pub enum NativePredicate {
None = 0,
ValueOf = 1,
Expand All @@ -30,9 +30,8 @@ impl ToFields for NativePredicate {
}
}

// TODO: Incorporate custom statements into this enum.
/// Type encapsulating statements with their associated arguments.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Statement {
None,
ValueOf(AnchoredKey, Value),
Expand All @@ -45,25 +44,28 @@ pub enum Statement {
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
Custom(CustomPredicateRef, Vec<AnchoredKey>),
}

impl Statement {
pub fn is_none(&self) -> bool {
self == &Self::None
}
pub fn code(&self) -> NativePredicate {
pub fn code(&self) -> Predicate {
use Predicate::*;
match self {
Self::None => NativePredicate::None,
Self::ValueOf(_, _) => NativePredicate::ValueOf,
Self::Equal(_, _) => NativePredicate::Equal,
Self::NotEqual(_, _) => NativePredicate::NotEqual,
Self::Gt(_, _) => NativePredicate::Gt,
Self::Lt(_, _) => NativePredicate::Lt,
Self::Contains(_, _) => NativePredicate::Contains,
Self::NotContains(_, _) => NativePredicate::NotContains,
Self::SumOf(_, _, _) => NativePredicate::SumOf,
Self::ProductOf(_, _, _) => NativePredicate::ProductOf,
Self::MaxOf(_, _, _) => NativePredicate::MaxOf,
Self::None => Native(NativePredicate::None),
Self::ValueOf(_, _) => Native(NativePredicate::ValueOf),
Self::Equal(_, _) => Native(NativePredicate::Equal),
Self::NotEqual(_, _) => Native(NativePredicate::NotEqual),
Self::Gt(_, _) => Native(NativePredicate::Gt),
Self::Lt(_, _) => Native(NativePredicate::Lt),
Self::Contains(_, _) => Native(NativePredicate::Contains),
Self::NotContains(_, _) => Native(NativePredicate::NotContains),
Self::SumOf(_, _, _) => Native(NativePredicate::SumOf),
Self::ProductOf(_, _, _) => Native(NativePredicate::ProductOf),
Self::MaxOf(_, _, _) => Native(NativePredicate::MaxOf),
Self::Custom(cpr, _) => Custom(cpr.clone()),
}
}
pub fn args(&self) -> Vec<StatementArg> {
Expand All @@ -80,6 +82,7 @@ impl Statement {
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Key(h))),
}
}
}
Expand Down
Loading