|
| 1 | +//! Common functionality to build Pod circuits with plonky2 |
| 2 | +
|
| 3 | +use crate::backends::plonky2::basetypes::D; |
| 4 | +use crate::backends::plonky2::mock::mainpod::Statement; |
| 5 | +use crate::backends::plonky2::mock::mainpod::{Operation, OperationArg}; |
| 6 | +use crate::middleware::{ |
| 7 | + NativeOperation, NativePredicate, Params, Predicate, StatementArg, ToFields, Value, F, |
| 8 | + HASH_SIZE, VALUE_SIZE, |
| 9 | +}; |
| 10 | +use crate::middleware::{OPERATION_ARG_F_LEN, STATEMENT_ARG_F_LEN}; |
| 11 | +use anyhow::Result; |
| 12 | +use plonky2::field::extension::Extendable; |
| 13 | +use plonky2::field::types::{Field, PrimeField64}; |
| 14 | +use plonky2::hash::hash_types::RichField; |
| 15 | +use plonky2::iop::target::{BoolTarget, Target}; |
| 16 | +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; |
| 17 | +use plonky2::plonk::circuit_builder::CircuitBuilder; |
| 18 | +use std::{array, iter}; |
| 19 | + |
| 20 | +pub const CODE_SIZE: usize = HASH_SIZE + 2; |
| 21 | + |
| 22 | +#[derive(Copy, Clone)] |
| 23 | +pub struct ValueTarget { |
| 24 | + pub elements: [Target; VALUE_SIZE], |
| 25 | +} |
| 26 | + |
| 27 | +impl ValueTarget { |
| 28 | + pub fn zero(builder: &mut CircuitBuilder<F, D>) -> Self { |
| 29 | + Self { |
| 30 | + elements: [builder.zero(); VALUE_SIZE], |
| 31 | + } |
| 32 | + } |
| 33 | + |
| 34 | + pub fn one(builder: &mut CircuitBuilder<F, D>) -> Self { |
| 35 | + Self { |
| 36 | + elements: array::from_fn(|i| { |
| 37 | + if i == 0 { |
| 38 | + builder.one() |
| 39 | + } else { |
| 40 | + builder.zero() |
| 41 | + } |
| 42 | + }), |
| 43 | + } |
| 44 | + } |
| 45 | + |
| 46 | + pub fn from_slice(xs: &[Target]) -> Self { |
| 47 | + assert_eq!(xs.len(), VALUE_SIZE); |
| 48 | + Self { |
| 49 | + elements: array::from_fn(|i| xs[i]), |
| 50 | + } |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +#[derive(Clone)] |
| 55 | +pub struct StatementTarget { |
| 56 | + pub predicate: [Target; Params::predicate_size()], |
| 57 | + pub args: Vec<[Target; STATEMENT_ARG_F_LEN]>, |
| 58 | +} |
| 59 | + |
| 60 | +impl StatementTarget { |
| 61 | + pub fn new_native( |
| 62 | + builder: &mut CircuitBuilder<F, D>, |
| 63 | + params: &Params, |
| 64 | + predicate: NativePredicate, |
| 65 | + args: &[[Target; STATEMENT_ARG_F_LEN]], |
| 66 | + ) -> Self { |
| 67 | + let predicate_vec = builder.constants(&Predicate::Native(predicate).to_fields(params)); |
| 68 | + Self { |
| 69 | + predicate: array::from_fn(|i| predicate_vec[i]), |
| 70 | + args: args |
| 71 | + .iter() |
| 72 | + .map(|arg| *arg) |
| 73 | + .chain( |
| 74 | + iter::repeat([builder.zero(); STATEMENT_ARG_F_LEN]) |
| 75 | + .take(params.max_statement_args - args.len()), |
| 76 | + ) |
| 77 | + .collect(), |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + pub fn set_targets( |
| 82 | + &self, |
| 83 | + pw: &mut PartialWitness<F>, |
| 84 | + params: &Params, |
| 85 | + st: &Statement, |
| 86 | + ) -> Result<()> { |
| 87 | + pw.set_target_arr(&self.predicate, &st.predicate().to_fields(params))?; |
| 88 | + for (i, arg) in st |
| 89 | + .args() |
| 90 | + .iter() |
| 91 | + .chain(iter::repeat(&StatementArg::None)) |
| 92 | + .take(params.max_statement_args) |
| 93 | + .enumerate() |
| 94 | + { |
| 95 | + pw.set_target_arr(&self.args[i], &arg.to_fields(params))?; |
| 96 | + } |
| 97 | + Ok(()) |
| 98 | + } |
| 99 | + |
| 100 | + pub fn has_native_type( |
| 101 | + &self, |
| 102 | + builder: &mut CircuitBuilder<F, D>, |
| 103 | + params: &Params, |
| 104 | + t: NativePredicate, |
| 105 | + ) -> BoolTarget { |
| 106 | + let st_code = builder.constants(&Predicate::Native(t).to_fields(params)); |
| 107 | + builder.is_equal_slice(&self.predicate, &st_code) |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +// TODO: Implement Operation::to_field to determine the size of each element |
| 112 | +#[derive(Clone)] |
| 113 | +pub struct OperationTarget { |
| 114 | + pub op_type: [Target; Params::operation_type_size()], |
| 115 | + pub args: Vec<[Target; OPERATION_ARG_F_LEN]>, |
| 116 | +} |
| 117 | + |
| 118 | +impl OperationTarget { |
| 119 | + pub fn set_targets( |
| 120 | + &self, |
| 121 | + pw: &mut PartialWitness<F>, |
| 122 | + params: &Params, |
| 123 | + op: &Operation, |
| 124 | + ) -> Result<()> { |
| 125 | + pw.set_target_arr(&self.op_type, &op.op_type().to_fields(params))?; |
| 126 | + for (i, arg) in op |
| 127 | + .args() |
| 128 | + .iter() |
| 129 | + .chain(iter::repeat(&OperationArg::None)) |
| 130 | + .take(params.max_operation_args) |
| 131 | + .enumerate() |
| 132 | + { |
| 133 | + pw.set_target_arr(&self.args[i], &arg.to_fields(params))?; |
| 134 | + } |
| 135 | + Ok(()) |
| 136 | + } |
| 137 | + |
| 138 | + pub fn has_native_type( |
| 139 | + &self, |
| 140 | + builder: &mut CircuitBuilder<F, D>, |
| 141 | + t: NativeOperation, |
| 142 | + ) -> BoolTarget { |
| 143 | + let one = builder.one(); |
| 144 | + let op_is_native = builder.is_equal(self.op_type[0], one); |
| 145 | + let op_code = builder.constant(F::from_canonical_u64(t as u64)); |
| 146 | + let op_code_matches = builder.is_equal(self.op_type[1], op_code); |
| 147 | + builder.and(op_is_native, op_code_matches) |
| 148 | + } |
| 149 | +} |
| 150 | + |
| 151 | +/// Trait for target structs that may be converted to and from vectors |
| 152 | +/// of targets. |
| 153 | +pub trait Flattenable { |
| 154 | + fn flatten(&self) -> Vec<Target>; |
| 155 | + fn from_flattened(vs: &[Target]) -> Self; |
| 156 | +} |
| 157 | + |
| 158 | +impl Flattenable for StatementTarget { |
| 159 | + fn flatten(&self) -> Vec<Target> { |
| 160 | + self.predicate |
| 161 | + .iter() |
| 162 | + .chain(self.args.iter().flatten()) |
| 163 | + .cloned() |
| 164 | + .collect() |
| 165 | + } |
| 166 | + |
| 167 | + fn from_flattened(v: &[Target]) -> Self { |
| 168 | + let num_args = (v.len() - Params::predicate_size()) / STATEMENT_ARG_F_LEN; |
| 169 | + assert_eq!( |
| 170 | + v.len(), |
| 171 | + Params::predicate_size() + num_args * STATEMENT_ARG_F_LEN |
| 172 | + ); |
| 173 | + let predicate: [Target; Params::predicate_size()] = array::from_fn(|i| v[i]); |
| 174 | + let args = (0..num_args) |
| 175 | + .map(|i| array::from_fn(|j| v[Params::predicate_size() + i * STATEMENT_ARG_F_LEN + j])) |
| 176 | + .collect(); |
| 177 | + |
| 178 | + Self { predicate, args } |
| 179 | + } |
| 180 | +} |
| 181 | + |
| 182 | +pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> { |
| 183 | + fn connect_values(&mut self, x: ValueTarget, y: ValueTarget); |
| 184 | + fn connect_slice(&mut self, xs: &[Target], ys: &[Target]); |
| 185 | + fn add_virtual_value(&mut self) -> ValueTarget; |
| 186 | + fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget; |
| 187 | + fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget; |
| 188 | + fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget; |
| 189 | + fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget; |
| 190 | + fn constant_value(&mut self, v: Value) -> ValueTarget; |
| 191 | + fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget; |
| 192 | + |
| 193 | + // Convenience methods for checking values. |
| 194 | + /// Checks whether `xs` is right-padded with 0s so as to represent a `Value`. |
| 195 | + fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget; |
| 196 | + /// Checks whether `x < y` if `b` is true. This involves checking |
| 197 | + /// that `x` and `y` each consist of two `u32` limbs. |
| 198 | + fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget); |
| 199 | + |
| 200 | + // Convenience methods for accessing and connecting elements of |
| 201 | + // (vectors of) flattenables. |
| 202 | + fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T; |
| 203 | + fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T; |
| 204 | + fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T); |
| 205 | + fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget; |
| 206 | + |
| 207 | + // Convenience methods for Boolean into-iters. |
| 208 | + fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget; |
| 209 | + fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget; |
| 210 | +} |
| 211 | + |
| 212 | +impl CircuitBuilderPod<F, D> for CircuitBuilder<F, D> { |
| 213 | + fn connect_slice(&mut self, xs: &[Target], ys: &[Target]) { |
| 214 | + assert_eq!(xs.len(), ys.len()); |
| 215 | + for (x, y) in xs.iter().zip(ys.iter()) { |
| 216 | + self.connect(*x, *y); |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + fn connect_values(&mut self, x: ValueTarget, y: ValueTarget) { |
| 221 | + self.connect_slice(&x.elements, &y.elements); |
| 222 | + } |
| 223 | + |
| 224 | + fn add_virtual_value(&mut self) -> ValueTarget { |
| 225 | + ValueTarget { |
| 226 | + elements: self.add_virtual_target_arr(), |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + fn add_virtual_statement(&mut self, params: &Params) -> StatementTarget { |
| 231 | + StatementTarget { |
| 232 | + predicate: self.add_virtual_target_arr(), |
| 233 | + args: (0..params.max_statement_args) |
| 234 | + .map(|_| self.add_virtual_target_arr()) |
| 235 | + .collect(), |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + fn add_virtual_operation(&mut self, params: &Params) -> OperationTarget { |
| 240 | + OperationTarget { |
| 241 | + op_type: self.add_virtual_target_arr(), |
| 242 | + args: (0..params.max_operation_args) |
| 243 | + .map(|_| self.add_virtual_target_arr()) |
| 244 | + .collect(), |
| 245 | + } |
| 246 | + } |
| 247 | + |
| 248 | + fn select_value(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) -> ValueTarget { |
| 249 | + ValueTarget { |
| 250 | + elements: std::array::from_fn(|i| self.select(b, x.elements[i], y.elements[i])), |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + fn select_bool(&mut self, b: BoolTarget, x: BoolTarget, y: BoolTarget) -> BoolTarget { |
| 255 | + BoolTarget::new_unsafe(self.select(b, x.target, y.target)) |
| 256 | + } |
| 257 | + |
| 258 | + fn constant_value(&mut self, v: Value) -> ValueTarget { |
| 259 | + ValueTarget { |
| 260 | + elements: std::array::from_fn(|i| { |
| 261 | + self.constant(F::from_noncanonical_u64(v.0[i].to_noncanonical_u64())) |
| 262 | + }), |
| 263 | + } |
| 264 | + } |
| 265 | + |
| 266 | + fn is_equal_slice(&mut self, xs: &[Target], ys: &[Target]) -> BoolTarget { |
| 267 | + assert_eq!(xs.len(), ys.len()); |
| 268 | + let init = self._true(); |
| 269 | + xs.iter().zip(ys.iter()).fold(init, |ok, (x, y)| { |
| 270 | + let is_eq = self.is_equal(*x, *y); |
| 271 | + self.and(ok, is_eq) |
| 272 | + }) |
| 273 | + } |
| 274 | + |
| 275 | + fn statement_arg_is_value(&mut self, xs: &[Target]) -> BoolTarget { |
| 276 | + let zeros = iter::repeat(self.zero()) |
| 277 | + .take(STATEMENT_ARG_F_LEN - VALUE_SIZE) |
| 278 | + .collect::<Vec<_>>(); |
| 279 | + self.is_equal_slice(&xs[VALUE_SIZE..], &zeros) |
| 280 | + } |
| 281 | + |
| 282 | + fn assert_less_if(&mut self, b: BoolTarget, x: ValueTarget, y: ValueTarget) { |
| 283 | + const NUM_BITS: usize = 32; |
| 284 | + |
| 285 | + // Lt assertion with 32-bit range check. |
| 286 | + let assert_limb_lt = |builder: &mut Self, x, y| { |
| 287 | + // Check that targets fit within `NUM_BITS` bits. |
| 288 | + builder.range_check(x, NUM_BITS); |
| 289 | + builder.range_check(y, NUM_BITS); |
| 290 | + // Check that `y-1-x` fits within `NUM_BITS` bits. |
| 291 | + let one = builder.one(); |
| 292 | + let y_minus_one = builder.sub(y, one); |
| 293 | + let expr = builder.sub(y_minus_one, x); |
| 294 | + builder.range_check(expr, NUM_BITS); |
| 295 | + }; |
| 296 | + |
| 297 | + // If b is false, replace `x` and `y` with dummy values. |
| 298 | + let zero = ValueTarget::zero(self); |
| 299 | + let one = ValueTarget::one(self); |
| 300 | + let x = self.select_value(b, x, zero); |
| 301 | + let y = self.select_value(b, y, one); |
| 302 | + |
| 303 | + // `x` and `y` should only have two limbs each. |
| 304 | + x.elements |
| 305 | + .into_iter() |
| 306 | + .skip(2) |
| 307 | + .for_each(|l| self.assert_zero(l)); |
| 308 | + y.elements |
| 309 | + .into_iter() |
| 310 | + .skip(2) |
| 311 | + .for_each(|l| self.assert_zero(l)); |
| 312 | + |
| 313 | + let big_limbs_eq = self.is_equal(x.elements[1], y.elements[1]); |
| 314 | + let lhs = self.select(big_limbs_eq, x.elements[0], x.elements[1]); |
| 315 | + let rhs = self.select(big_limbs_eq, y.elements[0], y.elements[1]); |
| 316 | + assert_limb_lt(self, lhs, rhs); |
| 317 | + } |
| 318 | + |
| 319 | + fn vec_ref<T: Flattenable>(&mut self, ts: &[T], i: Target) -> T { |
| 320 | + // TODO: Revisit this when we need more than 64 statements. |
| 321 | + let vector_ref = |builder: &mut CircuitBuilder<F, D>, v: &[Target], i| { |
| 322 | + assert!(v.len() <= 64); |
| 323 | + builder.random_access(i, v.to_vec()) |
| 324 | + }; |
| 325 | + let matrix_row_ref = |builder: &mut CircuitBuilder<F, D>, m: &[Vec<Target>], i| { |
| 326 | + let num_rows = m.len(); |
| 327 | + let num_columns = m |
| 328 | + .get(0) |
| 329 | + .map(|row| { |
| 330 | + let row_len = row.len(); |
| 331 | + assert!(m.iter().all(|row| row.len() == row_len)); |
| 332 | + row_len |
| 333 | + }) |
| 334 | + .unwrap_or(0); |
| 335 | + (0..num_columns) |
| 336 | + .map(|j| { |
| 337 | + vector_ref( |
| 338 | + builder, |
| 339 | + &(0..num_rows).map(|i| m[i][j]).collect::<Vec<_>>(), |
| 340 | + i, |
| 341 | + ) |
| 342 | + }) |
| 343 | + .collect::<Vec<_>>() |
| 344 | + }; |
| 345 | + |
| 346 | + let flattened_ts = ts.iter().map(|t| t.flatten()).collect::<Vec<_>>(); |
| 347 | + T::from_flattened(&matrix_row_ref(self, &flattened_ts, i)) |
| 348 | + } |
| 349 | + |
| 350 | + fn select_flattenable<T: Flattenable>(&mut self, b: BoolTarget, x: &T, y: &T) -> T { |
| 351 | + let flattened_x = x.flatten(); |
| 352 | + let flattened_y = y.flatten(); |
| 353 | + |
| 354 | + T::from_flattened( |
| 355 | + &iter::zip(flattened_x, flattened_y) |
| 356 | + .map(|(x, y)| self.select(b, x, y)) |
| 357 | + .collect::<Vec<_>>(), |
| 358 | + ) |
| 359 | + } |
| 360 | + |
| 361 | + fn connect_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) { |
| 362 | + self.connect_slice(&xs.flatten(), &ys.flatten()) |
| 363 | + } |
| 364 | + |
| 365 | + fn is_equal_flattenable<T: Flattenable>(&mut self, xs: &T, ys: &T) -> BoolTarget { |
| 366 | + self.is_equal_slice(&xs.flatten(), &ys.flatten()) |
| 367 | + } |
| 368 | + |
| 369 | + fn all(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget { |
| 370 | + xs.into_iter() |
| 371 | + .reduce(|a, b| self.and(a, b)) |
| 372 | + .unwrap_or(self._true()) |
| 373 | + } |
| 374 | + |
| 375 | + fn any(&mut self, xs: impl IntoIterator<Item = BoolTarget>) -> BoolTarget { |
| 376 | + xs.into_iter() |
| 377 | + .reduce(|a, b| self.or(a, b)) |
| 378 | + .unwrap_or(self._false()) |
| 379 | + } |
| 380 | +} |
0 commit comments