Skip to content

Commit 9c9a2c4

Browse files
authored
Feat/fst order pred part1 & part2 (#454)
Implement support for first order predicates in the backend. Now a statement template can have a predicate hash or a wildcard. ## predicate <-> predicate hash constraints To build the custom predicate table we need to calculate the custom predicate batch id, which uses the serialization of the statement templates before normalization. This serialization uses the predicate hash when the template uses a predicate (instead of a wildcard). Then in normalization we recalculate the predicate hash if it was a Batch Self. This means that the relation between hash and predicate must be checked before and after normalization when the template is not using a wildcard. How this is achieved: - Before normalization: the constructor of StatementTmplTarget forces that if we keep a predicate, it's hash must be equal to the pred_hash when the template has a predicate (and not a wildcard) - After normalization: the predicate hash is calculated in the normalization and replaced in the case of the template using a predicate and it being a BatchSelf. If it was a predicate but not batch self, the old value was used which was constrained via the constructor. See `CircuitBuilder::add_virtual_statement_tmpl` and `normalize_st_tmpl_circuit` ## Wildcard predicate resolution It is done via `make_predicate_from_template_circuit` and is fairly simple as it's contains similar logic to `make_statement_arg_from_template_circuit` but simpler.
1 parent 1724e7b commit 9c9a2c4

File tree

11 files changed

+569
-240
lines changed

11 files changed

+569
-240
lines changed

src/backends/plonky2/circuits/common.rs

Lines changed: 173 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ use crate::{
3232
},
3333
middleware::{
3434
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, NativeOperation,
35-
NativePredicate, OperationType, Params, Predicate, PredicatePrefix, RawValue, StatementArg,
36-
StatementTmpl, StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F,
37-
HASH_SIZE, STATEMENT_ARG_F_LEN, VALUE_SIZE,
35+
NativePredicate, OperationType, Params, Predicate, PredicateOrWildcard,
36+
PredicateOrWildcardPrefix, PredicatePrefix, RawValue, StatementArg, StatementTmpl,
37+
StatementTmplArg, StatementTmplArgPrefix, ToFields, Value, EMPTY_VALUE, F, HASH_SIZE,
38+
STATEMENT_ARG_F_LEN, VALUE_SIZE,
3839
},
3940
};
4041

@@ -46,6 +47,22 @@ pub struct ValueTarget {
4647
pub elements: [Target; VALUE_SIZE],
4748
}
4849

50+
impl From<ValueTarget> for HashOutTarget {
51+
fn from(v: ValueTarget) -> HashOutTarget {
52+
HashOutTarget {
53+
elements: v.elements,
54+
}
55+
}
56+
}
57+
58+
impl From<HashOutTarget> for ValueTarget {
59+
fn from(h: HashOutTarget) -> ValueTarget {
60+
ValueTarget {
61+
elements: h.elements,
62+
}
63+
}
64+
}
65+
4966
impl ValueTarget {
5067
pub fn zero(builder: &mut CircuitBuilder) -> Self {
5168
Self {
@@ -524,18 +541,112 @@ impl StatementTmplArgTarget {
524541
}
525542
}
526543

544+
#[derive(Clone, Serialize, Deserialize)]
545+
pub struct PredicateHashOrWildcardTarget {
546+
/// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
547+
pub elements: [Target; Params::pred_hash_or_wc_size()],
548+
}
549+
550+
impl PredicateHashOrWildcardTarget {
551+
pub fn new(prefix: Target, data: ValueTarget) -> Self {
552+
let v = data.elements;
553+
Self {
554+
elements: [prefix, v[0], v[1], v[2], v[3]],
555+
}
556+
}
557+
pub fn new_pred_hash(builder: &mut CircuitBuilder, pred_hash: HashOutTarget) -> Self {
558+
Self::new(
559+
builder.constant(F::from(PredicateOrWildcardPrefix::Predicate)),
560+
ValueTarget::from(pred_hash),
561+
)
562+
}
563+
pub fn is_pred(&self, builder: &mut CircuitBuilder) -> BoolTarget {
564+
let prefix_pred = builder.constant(F::from(PredicateOrWildcardPrefix::Predicate));
565+
builder.is_equal(self.elements[0], prefix_pred)
566+
}
567+
pub fn data(&self) -> ValueTarget {
568+
ValueTarget {
569+
elements: self.elements[1..].try_into().expect("4 elements"),
570+
}
571+
}
572+
pub fn pred_hash(&self) -> HashOutTarget {
573+
HashOutTarget::from(self.data())
574+
}
575+
pub fn wc_index(&self) -> Target {
576+
self.elements[1]
577+
}
578+
pub fn set_targets_raw(
579+
&self,
580+
pw: &mut PartialWitness<F>,
581+
prefix: PredicateOrWildcardPrefix,
582+
data: RawValue,
583+
) -> Result<()> {
584+
pw.set_target(self.elements[0], F::from(prefix))?;
585+
pw.set_target_arr(&self.elements[1..], &data.0)?;
586+
Ok(())
587+
}
588+
pub fn set_targets(
589+
&self,
590+
pw: &mut PartialWitness<F>,
591+
params: &Params,
592+
pred: &PredicateOrWildcard,
593+
) -> Result<()> {
594+
match pred {
595+
PredicateOrWildcard::Predicate(pred) => {
596+
self.set_targets_raw(
597+
pw,
598+
PredicateOrWildcardPrefix::Predicate,
599+
RawValue::from(pred.hash(params)),
600+
)?;
601+
}
602+
PredicateOrWildcard::Wildcard(wc) => {
603+
self.set_targets_raw(
604+
pw,
605+
PredicateOrWildcardPrefix::Wildcard,
606+
RawValue([F::from_canonical_usize(wc.index), F::ZERO, F::ZERO, F::ZERO]),
607+
)?;
608+
}
609+
}
610+
Ok(())
611+
}
612+
}
613+
614+
impl Flattenable for PredicateHashOrWildcardTarget {
615+
fn flatten(&self) -> Vec<Target> {
616+
self.elements.to_vec()
617+
}
618+
fn from_flattened(_params: &Params, vs: &[Target]) -> Self {
619+
Self {
620+
elements: vs.try_into().expect("5 elements"),
621+
}
622+
}
623+
fn size(_params: &Params) -> usize {
624+
Params::pred_hash_or_wc_size()
625+
}
626+
}
627+
527628
#[derive(Clone, Serialize, Deserialize)]
528629
pub struct StatementTmplTarget {
630+
/// The preimage of the predicate_hash. This predicate is needed only to build the custom
631+
/// predicate table because it needs to normalize statement templates with predicates that
632+
/// refer to self into content-addressed predicates (using the batch id and index). The
633+
/// predicate type is inspected to do this normalization. After the table is built we only use
634+
/// the predicate hash for equality checks.
529635
pred: Option<PredicateTarget>,
530-
pred_hash: HashOutTarget,
636+
/// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
637+
/// and the template uses a predicate and not a wildcard.
638+
pred_hash_or_wc: PredicateHashOrWildcardTarget,
531639
pub args: Vec<StatementTmplArgTarget>,
532640
}
533641

534642
impl StatementTmplTarget {
535-
pub fn new(pred_hash: HashOutTarget, args: Vec<StatementTmplArgTarget>) -> Self {
643+
pub fn new(
644+
pred_hash_or_wc: PredicateHashOrWildcardTarget,
645+
args: Vec<StatementTmplArgTarget>,
646+
) -> Self {
536647
Self {
537648
pred: None,
538-
pred_hash,
649+
pred_hash_or_wc,
539650
args,
540651
}
541652
}
@@ -546,9 +657,22 @@ impl StatementTmplTarget {
546657
st_tmpl: &StatementTmpl,
547658
) -> Result<()> {
548659
if let Some(pred) = &self.pred {
549-
pred.set_targets(pw, params, &st_tmpl.pred)?;
660+
match &st_tmpl.pred_or_wc {
661+
PredicateOrWildcard::Predicate(p) => {
662+
// We store a predicate (not a wildcard) and we have it available. In this
663+
// case the hash will be calculated by constraints later on and we should not
664+
// rely on the original data.
665+
pred.set_targets(pw, params, p)?
666+
}
667+
PredicateOrWildcard::Wildcard(_wc) => {
668+
// Fill in with a recognizable constant for better debugging; this value is
669+
// not supposed to be used.
670+
pw.set_target_arr(&pred.elements, &[F(0xdead); Params::predicate_size()])?
671+
}
672+
}
550673
}
551-
pw.set_hash_target(self.pred_hash, HashOut::from(st_tmpl.pred.hash(params)))?;
674+
self.pred_hash_or_wc
675+
.set_targets(pw, params, &st_tmpl.pred_or_wc)?;
552676
let arg_pad = StatementTmplArg::None;
553677
for (i, arg) in st_tmpl
554678
.args
@@ -564,8 +688,8 @@ impl StatementTmplTarget {
564688
pub fn pred(&self) -> Option<&PredicateTarget> {
565689
self.pred.as_ref()
566690
}
567-
pub fn pred_hash(&self) -> &HashOutTarget {
568-
&self.pred_hash
691+
pub fn pred_hash_or_wc(&self) -> &PredicateHashOrWildcardTarget {
692+
&self.pred_hash_or_wc
569693
}
570694
}
571695

@@ -603,6 +727,8 @@ impl CustomPredicateTarget {
603727
}
604728
}
605729

730+
/// This type is used to build the custom predicate table, which exposes the custom predicates with
731+
/// normalized statement templates indexed by batch_id and custom_predicate_index.
606732
#[derive(Clone, Serialize, Deserialize)]
607733
pub struct CustomPredicateBatchTarget {
608734
pub predicates: Vec<CustomPredicateTarget>,
@@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget {
660786
.clone()
661787
.into_iter()
662788
.map(|st_tmpl| {
663-
let pred = match st_tmpl.pred {
664-
Predicate::BatchSelf(i) => Predicate::Custom(CustomPredicateRef {
665-
batch: batch.clone(),
666-
index: i,
667-
}),
668-
p => p,
789+
let pred_or_wc = match st_tmpl.pred_or_wc {
790+
PredicateOrWildcard::Predicate(Predicate::BatchSelf(i)) => {
791+
PredicateOrWildcard::Predicate(Predicate::Custom(CustomPredicateRef {
792+
batch: batch.clone(),
793+
index: i,
794+
}))
795+
}
796+
x => x.clone(),
669797
};
670798
StatementTmpl {
671-
pred,
799+
pred_or_wc,
672800
args: st_tmpl.args,
673801
}
674802
})
@@ -724,15 +852,15 @@ pub struct CustomPredicateVerifyEntryTarget {
724852
}
725853

726854
impl CustomPredicateVerifyEntryTarget {
727-
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder, with_pred: bool) -> Self {
855+
pub fn new_virtual(params: &Params, builder: &mut CircuitBuilder) -> Self {
728856
let custom_predicate_table_len =
729857
params.max_custom_predicate_batches * params.max_custom_batch_size;
730858
CustomPredicateVerifyEntryTarget {
731859
custom_predicate_table_index: IndexTarget::new_virtual(
732860
custom_predicate_table_len,
733861
builder,
734862
),
735-
custom_predicate: builder.add_virtual_custom_predicate_entry(params, with_pred),
863+
custom_predicate: builder.add_virtual_custom_predicate_entry(params),
736864
args: (0..params.max_custom_predicate_wildcards)
737865
.map(|_| builder.add_virtual_value())
738866
.collect(),
@@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget {
10621190

10631191
impl Flattenable for StatementTmplTarget {
10641192
fn flatten(&self) -> Vec<Target> {
1065-
self.pred_hash
1193+
self.pred_hash_or_wc
10661194
.flatten()
10671195
.into_iter()
10681196
.chain(self.args.iter().flat_map(|sta| sta.flatten()))
@@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget {
10711199

10721200
fn from_flattened(params: &Params, v: &[Target]) -> Self {
10731201
assert_eq!(v.len(), Self::size(params));
1074-
let pred_hash_end = HASH_SIZE;
1075-
let pred_hash = HashOutTarget::from_flattened(params, &v[..pred_hash_end]);
1202+
let pred_hash_or_wc_end = Params::pred_hash_or_wc_size();
1203+
let pred_hash_or_wc =
1204+
PredicateHashOrWildcardTarget::from_flattened(params, &v[..pred_hash_or_wc_end]);
10761205
let sta_size = Params::statement_tmpl_arg_size();
10771206
let args = (0..params.max_statement_args)
10781207
.map(|i| {
1079-
let sta_v = &v[pred_hash_end + sta_size * i..pred_hash_end + sta_size * (i + 1)];
1208+
let sta_v = &v
1209+
[pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * (i + 1)];
10801210
StatementTmplArgTarget::from_flattened(params, sta_v)
10811211
})
10821212
.collect();
10831213
Self {
10841214
pred: None,
1085-
pred_hash,
1215+
pred_hash_or_wc,
10861216
args,
10871217
}
10881218
}
10891219

10901220
fn size(params: &Params) -> usize {
1091-
HASH_SIZE + params.max_statement_args * StatementTmplArgTarget::size(params)
1221+
Params::pred_hash_or_wc_size()
1222+
+ params.max_statement_args * StatementTmplArgTarget::size(params)
10921223
}
10931224
}
10941225

@@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
11681299
params: &Params,
11691300
with_pred: bool,
11701301
) -> CustomPredicateBatchTarget;
1171-
fn add_virtual_custom_predicate_entry(
1172-
&mut self,
1173-
params: &Params,
1174-
with_pred: bool,
1175-
) -> CustomPredicateEntryTarget;
1302+
fn add_virtual_custom_predicate_entry(&mut self, params: &Params)
1303+
-> CustomPredicateEntryTarget;
11761304
fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget;
11771305
fn select_statement_arg(
11781306
&mut self,
@@ -1320,24 +1448,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
13201448
}
13211449
}
13221450

1323-
/// If `with_pred = true` a predicate is included and its hash constrained.
1451+
/// If `with_pred = true` a predicate is included.
13241452
/// If `with_pred = false` only the predicate hash is included.
1453+
/// The pred_hash is constrained to be hash(pred) conditionally on the template using a
1454+
/// predicate and not a wildcard.
13251455
fn add_virtual_statement_tmpl(
13261456
&mut self,
13271457
params: &Params,
13281458
with_pred: bool,
13291459
) -> StatementTmplTarget {
1330-
let (pred, pred_hash) = if with_pred {
1460+
let pred_hash_or_wc =
1461+
PredicateHashOrWildcardTarget::new(self.add_virtual_target(), self.add_virtual_value());
1462+
let pred = if with_pred {
13311463
let pred = self.add_virtual_predicate();
13321464
let pred_hash = pred.hash(self);
1333-
(Some(pred), pred_hash)
1465+
let is_pred = pred_hash_or_wc.is_pred(self);
1466+
let data = pred_hash_or_wc.data();
1467+
for i in 0..VALUE_SIZE {
1468+
self.conditional_assert_eq(is_pred.target, data.elements[i], pred_hash.elements[i]);
1469+
}
1470+
Some(pred)
13341471
} else {
1335-
let pred_hash = self.add_virtual_hash();
1336-
(None, pred_hash)
1472+
None
13371473
};
13381474
StatementTmplTarget {
13391475
pred,
1340-
pred_hash,
1476+
pred_hash_or_wc,
13411477
args: (0..params.max_statement_args)
13421478
.map(|_| self.add_virtual_statement_tmpl_arg())
13431479
.collect(),
@@ -1377,12 +1513,11 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
13771513
fn add_virtual_custom_predicate_entry(
13781514
&mut self,
13791515
params: &Params,
1380-
with_pred: bool,
13811516
) -> CustomPredicateEntryTarget {
13821517
CustomPredicateEntryTarget {
13831518
id: self.add_virtual_hash(),
13841519
index: self.add_virtual_target(),
1385-
predicate: self.add_virtual_custom_predicate(params, with_pred),
1520+
predicate: self.add_virtual_custom_predicate(params, false),
13861521
}
13871522
}
13881523

0 commit comments

Comments
 (0)