Skip to content

Commit 30f26a9

Browse files
authored
chore(backend): implement some circuit op logic (#165)
* Initial circuit op work * Fix copy op * Add more ops * Fixes * Code review
1 parent 3b2860b commit 30f26a9

File tree

2 files changed

+465
-78
lines changed

2 files changed

+465
-78
lines changed

src/backends/plonky2/circuits/common.rs

Lines changed: 231 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
//! Common functionality to build Pod circuits with plonky2
22
3+
use crate::backends::plonky2::basetypes::D;
34
use crate::backends::plonky2::mock::mainpod::Statement;
45
use crate::backends::plonky2::mock::mainpod::{Operation, OperationArg};
5-
use crate::middleware::{Params, StatementArg, ToFields, Value, F, HASH_SIZE, VALUE_SIZE};
6+
use crate::middleware::{
7+
NativeOperation, NativePredicate, Params, Predicate, StatementArg, ToFields, Value, F,
8+
HASH_SIZE, VALUE_SIZE,
9+
};
610
use crate::middleware::{OPERATION_ARG_F_LEN, STATEMENT_ARG_F_LEN};
711
use anyhow::Result;
812
use plonky2::field::extension::Extendable;
@@ -11,26 +15,67 @@ use plonky2::hash::hash_types::RichField;
1115
use plonky2::iop::target::{BoolTarget, Target};
1216
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
1317
use plonky2::plonk::circuit_builder::CircuitBuilder;
14-
use std::iter;
18+
use std::{array, iter};
19+
20+
pub const CODE_SIZE: usize = HASH_SIZE + 2;
1521

1622
#[derive(Copy, Clone)]
1723
pub struct ValueTarget {
1824
pub elements: [Target; VALUE_SIZE],
1925
}
2026

27+
impl ValueTarget {
28+
pub fn zero(builder: &mut CircuitBuilder<F, D>) -> Self {
29+
Self {
30+
elements: [builder.zero(); VALUE_SIZE],
31+
}
32+
}
33+
34+
pub fn one(builder: &mut CircuitBuilder<F, D>) -> Self {
35+
Self {
36+
elements: array::from_fn(|i| {
37+
if i == 0 {
38+
builder.one()
39+
} else {
40+
builder.zero()
41+
}
42+
}),
43+
}
44+
}
45+
46+
pub fn from_slice(xs: &[Target]) -> Self {
47+
assert_eq!(xs.len(), VALUE_SIZE);
48+
Self {
49+
elements: array::from_fn(|i| xs[i]),
50+
}
51+
}
52+
}
53+
2154
#[derive(Clone)]
2255
pub struct StatementTarget {
2356
pub predicate: [Target; Params::predicate_size()],
2457
pub args: Vec<[Target; STATEMENT_ARG_F_LEN]>,
2558
}
2659

2760
impl StatementTarget {
28-
pub fn to_flattened(&self) -> Vec<Target> {
29-
self.predicate
30-
.iter()
31-
.chain(self.args.iter().flatten())
32-
.cloned()
33-
.collect()
61+
pub fn new_native(
62+
builder: &mut CircuitBuilder<F, D>,
63+
params: &Params,
64+
predicate: NativePredicate,
65+
args: &[[Target; STATEMENT_ARG_F_LEN]],
66+
) -> Self {
67+
let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params));
68+
Self {
69+
predicate: array::from_fn(|i| predicate_vec[i]),
70+
args: args
71+
.iter()
72+
.map(|arg| *arg)
73+
.chain(
74+
iter::repeat([builder.zero(); STATEMENT_ARG_F_LEN])
75+
.take(params.max_statement_args - args.len()),
76+
)
77+
.collect(),
78+
}
3479
}
3580

3681
pub fn set_targets(
@@ -51,6 +96,16 @@ impl StatementTarget {
5196
}
5297
Ok(())
5398
}
99+
100+
pub fn has_native_type(
101+
&self,
102+
builder: &mut CircuitBuilder<F, D>,
103+
params: &Params,
104+
t: NativePredicate,
105+
) -> BoolTarget {
106+
let st_code = builder.constants(&Predicate::Native(t).to_fields(params));
107+
builder.is_equal_slice(&self.predicate, &st_code)
108+
}
54109
}
55110

56111
// TODO: Implement Operation::to_field to determine the size of each element
@@ -79,6 +134,49 @@ impl OperationTarget {
79134
}
80135
Ok(())
81136
}
137+
138+
pub fn has_native_type(
139+
&self,
140+
builder: &mut CircuitBuilder<F, D>,
141+
t: NativeOperation,
142+
) -> BoolTarget {
143+
let one = builder.one();
144+
let op_is_native = builder.is_equal(self.op_type[0], one);
145+
let op_code = builder.constant(F::from_canonical_u64(t as u64));
146+
let op_code_matches = builder.is_equal(self.op_type[1], op_code);
147+
builder.and(op_is_native, op_code_matches)
148+
}
149+
}
150+
151+
/// Trait for target structs that may be converted to and from vectors
152+
/// of targets.
153+
pub trait Flattenable {
154+
fn flatten(&self) -> Vec<Target>;
155+
fn from_flattened(vs: &[Target]) -> Self;
156+
}
157+
158+
impl Flattenable for StatementTarget {
159+
fn flatten(&self) -> Vec<Target> {
160+
self.predicate
161+
.iter()
162+
.chain(self.args.iter().flatten())
163+
.cloned()
164+
.collect()
165+
}
166+
167+
fn from_flattened(v: &[Target]) -> Self {
168+
let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN;
169+
assert_eq!(
170+
v.len(),
171+
Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN
172+
);
173+
let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]);
174+
let args = (0..num_args)
175+
.map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j]))
176+
.collect();
177+
178+
Self { predicate, args }
179+
}
82180
}
83181

84182
pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
@@ -91,11 +189,27 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
91189
fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget;
92190
fn constant_value(&mut self, v: Value) -> ValueTarget;
93191
fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget;
192+
193+
// Convenience methods for checking values.
194+
/// Checks whether `xs` is right-padded with 0s so as to represent a `Value`.
195+
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget;
196+
/// Checks whether `x < y` if `b` is true. This involves checking
197+
/// that `x` and `y` each consist of two `u32` limbs.
198+
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget);
199+
200+
// Convenience methods for accessing and connecting elements of
201+
// (vectors of) flattenables.
202+
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T;
203+
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T;
204+
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T);
205+
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget;
206+
207+
// Convenience methods for Boolean into-iters.
208+
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
209+
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget;
94210
}
95211

96-
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
97-
for CircuitBuilder<F, D>
98-
{
212+
impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> {
99213
fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) {
100214
assert_eq!(xs.len(), ys.len());
101215
for (x, y) in xs.iter().zip(ys.iter()) {
@@ -157,4 +271,110 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderPod<F, D>
157271
self.and(ok, is_eq)
158272
})
159273
}
274+
275+
fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget {
276+
let zeros = iter::repeat(self.zero())
277+
.take(STATEMENT_ARG_F_LEN - VALUE_SIZE)
278+
.collect::<Vec<_>>();
279+
self.is_equal_slice(&xs[VALUE_SIZE..], &zeros)
280+
}
281+
282+
fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) {
283+
const NUM_BITS: usize = 32;
284+
285+
// Lt assertion with 32-bit range check.
286+
let assert_limb_lt = |builder: &mut Self, x, y| {
287+
// Check that targets fit within `NUM_BITS` bits.
288+
builder.range_check(x, NUM_BITS);
289+
builder.range_check(y, NUM_BITS);
290+
// Check that `y-1-x` fits within `NUM_BITS` bits.
291+
let one = builder.one();
292+
let y_minus_one = builder.sub(y, one);
293+
let expr = builder.sub(y_minus_one, x);
294+
builder.range_check(expr, NUM_BITS);
295+
};
296+
297+
// If b is false, replace `x` and `y` with dummy values.
298+
let zero = ValueTarget::zero(self);
299+
let one = ValueTarget::one(self);
300+
let x = self.select_value(b, x, zero);
301+
let y = self.select_value(b, y, one);
302+
303+
// `x` and `y` should only have two limbs each.
304+
x.elements
305+
.into_iter()
306+
.skip(2)
307+
.for_each(|l| self.assert_zero(l));
308+
y.elements
309+
.into_iter()
310+
.skip(2)
311+
.for_each(|l| self.assert_zero(l));
312+
313+
let big_limbs_eq = self.is_equal(x.elements[1], y.elements[1]);
314+
let lhs = self.select(big_limbs_eq, x.elements[0], x.elements[1]);
315+
let rhs = self.select(big_limbs_eq, y.elements[0], y.elements[1]);
316+
assert_limb_lt(self, lhs, rhs);
317+
}
318+
319+
fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T {
320+
// TODO: Revisit this when we need more than 64 statements.
321+
let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| {
322+
assert!(v.len() <= 64);
323+
builder.random_access(i, v.to_vec())
324+
};
325+
let matrix_row_ref = |builder: &mut CircuitBuilder<F, D>, m: &[Vec<Target>], i| {
326+
let num_rows = m.len();
327+
let num_columns = m
328+
.get(0)
329+
.map(|row| {
330+
let row_len = row.len();
331+
assert!(m.iter().all(|row| row.len() == row_len));
332+
row_len
333+
})
334+
.unwrap_or(0);
335+
(0..num_columns)
336+
.map(|j| {
337+
vector_ref(
338+
builder,
339+
&(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(),
340+
i,
341+
)
342+
})
343+
.collect::<Vec<_>>()
344+
};
345+
346+
let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>();
347+
T::from_flattened(&matrix_row_ref(self, &flattened_ts, i))
348+
}
349+
350+
fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T {
351+
let flattened_x = x.flatten();
352+
let flattened_y = y.flatten();
353+
354+
T::from_flattened(
355+
&iter::zip(flattened_x, flattened_y)
356+
.map(|(x, y)| self.select(b, x, y))
357+
.collect::<Vec<_>>(),
358+
)
359+
}
360+
361+
fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) {
362+
self.connect_slice(&xs.flatten(), &ys.flatten())
363+
}
364+
365+
fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget {
366+
self.is_equal_slice(&xs.flatten(), &ys.flatten())
367+
}
368+
369+
fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
370+
xs.into_iter()
371+
.reduce(|a, b| self.and(a, b))
372+
.unwrap_or(self._true())
373+
}
374+
375+
fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget {
376+
xs.into_iter()
377+
.reduce(|a, b| self.or(a, b))
378+
.unwrap_or(self._false())
379+
}
160380
}

0 commit comments

Comments
 (0)