Skip to content

Commit d335b37

Browse files
committed
Code review
1 parent 591202c commit d335b37

File tree

2 files changed

+117
-83
lines changed

2 files changed

+117
-83
lines changed

src/backends/plonky2/circuits/common.rs

Lines changed: 92 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,13 @@ pub struct ValueTarget {
2525
}
2626

2727
impl ValueTarget {
28-
pub fn zero<F: RichField + Extendable<D>, const D: usize>(
29-
builder: &mut CircuitBuilder<F, D>,
30-
) -> Self {
28+
pub fn zero(builder: &mut CircuitBuilder<F, D>) -> Self {
3129
Self {
3230
elements: [builder.zero(); VALUE_SIZE],
3331
}
3432
}
3533

36-
pub fn one<F: RichField + Extendable<D>, const D: usize>(
37-
builder: &mut CircuitBuilder<F, D>,
38-
) -> Self {
34+
pub fn one(builder: &mut CircuitBuilder<F, D>) -> Self {
3935
Self {
4036
elements: array::from_fn(|i| {
4137
if i == 0 {
@@ -81,27 +77,6 @@ impl StatementTarget {
8177
.collect(),
8278
}
8379
}
84-
pub fn to_flattened(&self) -> Vec<Target> {
85-
self.predicate
86-
.iter()
87-
.chain(self.args.iter().flatten())
88-
.cloned()
89-
.collect()
90-
}
91-
92-
pub fn from_flattened(v: Vec<Target>) -> Self {
93-
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
94-
assert_eq!(
95-
v.len(),
96-
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
97-
);
98-
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
99-
let args = (0..num_args)
100-
.map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]))
101-
.collect();
102-
103-
Self { predicate, args }
104-
}
10580

10681
pub fn set_targets(
10782
&self,
@@ -165,8 +140,42 @@ impl OperationTarget {
165140
builder: &mut CircuitBuilder<F, D>,
166141
t: NativeOperation,
167142
) -> BoolTarget {
143+
let one = builder.one();
144+
let op_is_native = builder.is_equal(self.op_type[0], one);
168145
let op_code = builder.constant(F::from_canonical_u64(t as u64));
169-
builder.is_equal(self.op_type[1], op_code)
146+
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
147+
builder.and(op_is_native, op_code_matches)
148+
}
149+
}
150+
151+
/// Trait for target structs that may be converted to and from vectors
152+
/// of targets.
153+
pub trait Flattenable {
154+
fn flatten(&self) -> Vec<Target>;
155+
fn from_flattened(vs: &[Target]) -> Self;
156+
}
157+
158+
impl Flattenable for StatementTarget {
159+
fn flatten(&self) -> Vec<Target> {
160+
self.predicate
161+
.iter()
162+
.chain(self.args.iter().flatten())
163+
.cloned()
164+
.collect()
165+
}
166+
167+
fn from_flattened(v: &[Target]) -> Self {
168+
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
169+
assert_eq!(
170+
v.len(),
171+
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
172+
);
173+
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
174+
let args = (0..num_args)
175+
.map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]))
176+
.collect();
177+
178+
Self { predicate, args }
170179
}
171180
}
172181

@@ -183,23 +192,24 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
183192

184193
// Convenience methods for checking values.
185194
/// Checks whether `xs` is right-padded with 0s so as to represent a `Value`.
186-
fn is_value(&mut self, xs: &[Target]) -> BoolTarget;
195+
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget;
187196
/// Checks whether `x < y` if `b` is true. This involves checking
188197
/// that `x` and `y` each consist of two `u32` limbs.
189198
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
190199

191-
// Convenience methods for randomly accessing vector elements and rows of matrices.
192-
fn vector_ref(&mut self, v: &[Target], i: Target) -> Target;
193-
fn matrix_row_ref(&mut self, m: &[Vec<Target>], i: Target) -> Vec<Target>;
200+
// Convenience methods for accessing and connecting elements of
201+
// (vectors of) flattenables.
202+
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T;
203+
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T;
204+
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
205+
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
194206

195207
// Convenience methods for Boolean into-iters.
196208
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
197209
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
198210
}
199211

200-
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
201-
for CircuitBuilder<F, D>
202-
{
212+
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
203213
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) {
204214
assert_eq!(xs.len(), ys.len());
205215
for (x, y) in xs.iter().zip(ys.iter()) {
@@ -262,7 +272,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
262272
})
263273
}
264274

265-
fn is_value(&mut self, xs: &[Target]) -> BoolTarget {
275+
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget {
266276
let zeros = iter::repeat(self.zero())
267277
.take(STATEMENT_ARG_F_LEN - VALUE_SIZE)
268278
.collect::<Vec<_>>();
@@ -306,24 +316,54 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
306316
assert_limb_lt(self, lhs, rhs);
307317
}
308318

309-
// TODO: Revisit this when we need more than 64 statements.
310-
fn vector_ref(&mut self, v: &[Target], i: Target) -> Target {
311-
self.random_access(i, v.to_vec())
319+
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T {
320+
// TODO: Revisit this when we need more than 64 statements.
321+
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
322+
assert!(v.len() <= 64);
323+
builder.random_access(i, v.to_vec())
324+
};
325+
let matrix_row_ref = |builder: &mut CircuitBuilder<F, D>, m: &[Vec<Target>], i| {
326+
let num_rows = m.len();
327+
let num_columns = m
328+
.get(0)
329+
.map(|row| {
330+
let row_len = row.len();
331+
assert!(m.iter().all(|row| row.len() == row_len));
332+
row_len
333+
})
334+
.unwrap_or(0);
335+
(0..num_columns)
336+
.map(|j| {
337+
vector_ref(
338+
builder,
339+
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
340+
i,
341+
)
342+
})
343+
.collect::<Vec<_>>()
344+
};
345+
346+
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
347+
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i))
312348
}
313349

314-
fn matrix_row_ref(&mut self, m: &[Vec<Target>], i: Target) -> Vec<Target> {
315-
let num_rows = m.len();
316-
let num_columns = m
317-
.get(0)
318-
.map(|row| {
319-
let row_len = row.len();
320-
assert!(m.iter().all(|row| row.len() == row_len));
321-
row_len
322-
})
323-
.unwrap_or(0);
324-
(0..num_columns)
325-
.map(|j| self.vector_ref(&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(), i))
326-
.collect()
350+
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T {
351+
let flattened_x = x.flatten();
352+
let flattened_y = y.flatten();
353+
354+
T::from_flattened(
355+
&iter::zip(flattened_x, flattened_y)
356+
.map(|(x, y)| self.select(b, x, y))
357+
.collect::<Vec<_>>(),
358+
)
359+
}
360+
361+
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) {
362+
self.connect_slice(&xs.flatten(), &ys.flatten())
363+
}
364+
365+
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget {
366+
self.is_equal_slice(&xs.flatten(), &ys.flatten())
327367
}
328368

329369
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {

src/backends/plonky2/circuits/mainpod.rs

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ use crate::middleware::{
2828
StatementArg, ToFields, KEY_TYPE, SELF,
2929
};
3030

31+
use super::common::Flattenable;
32+
3133
//
3234
// SignedPod verification
3335
//
@@ -147,17 +149,7 @@ impl OperationVerifyGate {
147149
op.args
148150
.iter()
149151
.flatten()
150-
.map(|&i| {
151-
StatementTarget::from_flattened(
152-
builder.matrix_row_ref(
153-
&prev_statements
154-
.iter()
155-
.map(|st_targ| st_targ.to_flattened())
156-
.collect::<Vec<_>>(),
157-
i,
158-
),
159-
)
160-
})
152+
.map(|&i| builder.vec_ref(prev_statements, i))
161153
.collect::<Vec<_>>()
162154
};
163155

@@ -180,8 +172,8 @@ impl OperationVerifyGate {
180172
} else {
181173
vec![
182174
self.eval_copy(builder, st, op, &resolved_op_args)?,
183-
self.eval_eq(builder, st, op, &resolved_op_args),
184-
self.eval_lt(builder, st, op, &resolved_op_args),
175+
self.eval_eq_from_entries(builder, st, op, &resolved_op_args),
176+
self.eval_lt_from_entries(builder, st, op, &resolved_op_args),
185177
]
186178
},
187179
]
@@ -194,7 +186,7 @@ impl OperationVerifyGate {
194186
Ok(OperationVerifyTarget {})
195187
}
196188

197-
fn eval_eq(
189+
fn eval_eq_from_entries(
198190
&self,
199191
builder: &mut CircuitBuilder<F, D>,
200192
st: &StatementTarget,
@@ -215,7 +207,10 @@ impl OperationVerifyGate {
215207
// `STATEMENT_ARG_F_LEN - VALUE_SIZE` slots of each being 0.
216208
let arg1_value = resolved_op_args[0].args[1];
217209
let arg2_value = resolved_op_args[1].args[1];
218-
let op_arg_range_checks = [builder.is_value(&arg1_value), builder.is_value(&arg2_value)];
210+
let op_arg_range_checks = [
211+
builder.statement_arg_is_value(&arg1_value),
212+
builder.statement_arg_is_value(&arg2_value),
213+
];
219214
let op_arg_range_ok = builder.all(op_arg_range_checks);
220215
let op_args_eq =
221216
builder.is_equal_slice(&arg1_value[..VALUE_SIZE], &arg2_value[..VALUE_SIZE]);
@@ -228,7 +223,7 @@ impl OperationVerifyGate {
228223
NativePredicate::Equal,
229224
&[arg1_key, arg2_key],
230225
);
231-
let st_ok = builder.is_equal_slice(&st.to_flattened(), &expected_statement.to_flattened());
226+
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
232227

233228
builder.all([
234229
op_code_ok,
@@ -239,7 +234,7 @@ impl OperationVerifyGate {
239234
])
240235
}
241236

242-
fn eval_lt(
237+
fn eval_lt_from_entries(
243238
&self,
244239
builder: &mut CircuitBuilder<F, D>,
245240
st: &StatementTarget,
@@ -263,7 +258,7 @@ impl OperationVerifyGate {
263258
let arg2_value = resolved_op_args[1].args[1];
264259
let op_arg_range_checks = [&arg1_value, &arg2_value]
265260
.into_iter()
266-
.map(|x| builder.is_value(x))
261+
.map(|x| builder.statement_arg_is_value(x))
267262
.collect::<Vec<_>>();
268263
let op_arg_range_ok = builder.all(op_arg_range_checks);
269264
builder.assert_less_if(
@@ -280,7 +275,7 @@ impl OperationVerifyGate {
280275
NativePredicate::Lt,
281276
&[arg1_key, arg2_key],
282277
);
283-
let st_ok = builder.is_equal_slice(&st.to_flattened(), &expected_statement.to_flattened());
278+
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
284279

285280
builder.all([op_code_ok, op_arg_types_ok, op_arg_range_ok, st_ok])
286281
}
@@ -293,9 +288,9 @@ impl OperationVerifyGate {
293288
) -> BoolTarget {
294289
let op_code_ok = op.has_native_type(builder, NativeOperation::None);
295290

296-
let expected_statement_flattened =
297-
builder.constants(&Statement::None.to_fields(&self.params));
298-
let st_ok = builder.is_equal_slice(&st.to_flattened(), &expected_statement_flattened);
291+
let expected_statement =
292+
StatementTarget::new_native(builder, &self.params, NativePredicate::None, &[]);
293+
let st_ok = builder.is_equal_flattenable(st, &expected_statement);
299294

300295
builder.all([op_code_ok, st_ok])
301296
}
@@ -342,8 +337,8 @@ impl OperationVerifyGate {
342337
) -> Result<BoolTarget> {
343338
let op_code_ok = op.has_native_type(builder, NativeOperation::CopyStatement);
344339

345-
let expected_statement_flattened = &resolved_op_args[0].to_flattened();
346-
let st_ok = builder.is_equal_slice(&st.to_flattened(), expected_statement_flattened);
340+
let expected_statement = &resolved_op_args[0];
341+
let st_ok = builder.is_equal_flattenable(st, expected_statement);
347342

348343
Ok(builder.all([op_code_ok, st_ok]))
349344
}
@@ -414,14 +409,13 @@ impl MainPodVerifyGate {
414409
// TODO: Store this hash in a global static with lazy init so that we don't have to
415410
// compute it every time.
416411
let key_type = hash_str(KEY_TYPE);
417-
let expected_type_statement_flattened = builder.constants(
418-
&Statement::ValueOf(AnchoredKey(SELF, key_type), Value::from(PodType::MockMain))
419-
.to_fields(params),
420-
);
421-
builder.connect_slice(
422-
&type_statement.to_flattened(),
423-
&expected_type_statement_flattened,
412+
let expected_type_statement = StatementTarget::from_flattened(
413+
&builder.constants(
414+
&Statement::ValueOf(AnchoredKey(SELF, key_type), Value::from(PodType::MockMain))
415+
.to_fields(params),
416+
),
424417
);
418+
builder.connect_flattenable(type_statement, &expected_type_statement);
425419

426420
// 5. Verify input statements
427421
let mut op_verifications = Vec::new();

0 commit comments

Comments
 (0)