diff --git a/src/backends/plonky2/primitives/ec/curve.rs b/src/backends/plonky2/primitives/ec/curve.rs index 89f9b091..38bc5211 100644 --- a/src/backends/plonky2/primitives/ec/curve.rs +++ b/src/backends/plonky2/primitives/ec/curve.rs @@ -325,9 +325,19 @@ type FieldTarget = OEFTarget<5, QuinticExtension>; pub struct PointTarget { pub x: FieldTarget, pub u: FieldTarget, + pub(super) checked_on_curve: bool, + pub(super) checked_in_subgroup: bool, } impl PointTarget { + pub fn new_unsafe(x: FieldTarget, u: FieldTarget) -> Self { + Self { + x, + u, + checked_on_curve: false, + checked_in_subgroup: false, + } + } pub fn to_value(&self, builder: &mut CircuitBuilder) -> ValueTarget { let hash = builder.hash_n_to_hash_no_pad::( self.x @@ -383,8 +393,12 @@ where ) -> plonky2::util::serialization::IoResult<()> { dst.write_target_array(&self.orig.x.components)?; dst.write_target_array(&self.orig.u.components)?; + dst.write_bool(self.orig.checked_on_curve)?; + dst.write_bool(self.orig.checked_in_subgroup)?; dst.write_target_array(&self.sqrt.x.components)?; - dst.write_target_array(&self.sqrt.u.components) + dst.write_target_array(&self.sqrt.u.components)?; + dst.write_bool(self.sqrt.checked_on_curve)?; + dst.write_bool(self.sqrt.checked_in_subgroup) } fn deserialize( @@ -397,16 +411,21 @@ where let orig = PointTarget { x: FieldTarget::new(src.read_target_array()?), u: FieldTarget::new(src.read_target_array()?), + checked_on_curve: src.read_bool()?, + checked_in_subgroup: src.read_bool()?, }; let sqrt = PointTarget { x: FieldTarget::new(src.read_target_array()?), u: FieldTarget::new(src.read_target_array()?), + checked_on_curve: src.read_bool()?, + checked_in_subgroup: src.read_bool()?, }; Ok(Self { orig, sqrt }) } } pub trait CircuitBuilderElliptic { + fn add_virtual_point_target_unsafe(&mut self) -> PointTarget; fn add_virtual_point_target(&mut self) -> PointTarget; fn identity_point(&mut self) -> PointTarget; fn constant_point(&mut self, p: Point) -> PointTarget; @@ -430,17 +449,17 @@ pub trait CircuitBuilderElliptic { /// Check that two points are equal. This assumes that the points are /// already known to be in the subgroup. fn connect_point(&mut self, p1: &PointTarget, p2: &PointTarget); - fn check_point_on_curve(&mut self, p: &PointTarget); - fn check_point_in_subgroup(&mut self, p: &PointTarget); + fn check_point_on_curve(&mut self, p: &mut PointTarget); + fn check_point_in_subgroup(&mut self, p: &mut PointTarget); } impl CircuitBuilderElliptic for CircuitBuilder { + fn add_virtual_point_target_unsafe(&mut self) -> PointTarget { + PointTarget::new_unsafe(self.add_virtual_nnf_target(), self.add_virtual_nnf_target()) + } fn add_virtual_point_target(&mut self) -> PointTarget { - let p = PointTarget { - x: self.add_virtual_nnf_target(), - u: self.add_virtual_nnf_target(), - }; - self.check_point_in_subgroup(&p); + let mut p = self.add_virtual_point_target_unsafe(); + self.check_point_in_subgroup(&mut p); p } @@ -449,14 +468,18 @@ impl CircuitBuilderElliptic for CircuitBuilder { } fn constant_point(&mut self, p: Point) -> PointTarget { - assert!(p.is_in_subgroup()); - PointTarget { - x: self.nnf_constant(&p.x), - u: self.nnf_constant(&p.u), - } + assert!(p.is_in_subgroup(), "Given point should be in EC subgroup."); + let mut p_target = + PointTarget::new_unsafe(self.nnf_constant(&p.x), self.nnf_constant(&p.u)); + self.check_point_in_subgroup(&mut p_target); + p_target } fn add_point(&mut self, p1: &PointTarget, p2: &PointTarget) -> PointTarget { + assert!( + p1.checked_on_curve && p2.checked_on_curve, + "EC addition formula requires that both points lie on the curve." + ); let mut inputs = Vec::with_capacity(20); inputs.extend_from_slice(&p1.x.components); inputs.extend_from_slice(&p1.u.components); @@ -474,37 +497,13 @@ impl CircuitBuilderElliptic for CircuitBuilder { let t = self.nnf_add_scalar_times_generator_power(b1, 1, &t); let xq = self.nnf_div(&x, &z); let uq = self.nnf_div(&u, &t); - PointTarget { x: xq, u: uq } - /* - let t1 = self.nnf_mul(&p1.x, &p2.x); - let t3 = self.nnf_mul(&p1.u, &p2.u); - let t5 = self.nnf_add(&p1.x, &p2.x); - let t6 = self.nnf_add(&p1.u, &p2.u); - let b1 = self.constant(GoldilocksField::from_canonical_u32(Point::B1_U32)); - let t7 = self.nnf_add_scalar_times_generator_power(b1, 1, &t1); - let t9_1 = self.nnf_mul_generator(&t5); - let t9_2 = self.nnf_mul_scalar(b1, &t9_1); - let t9_3 = self.nnf_add(&t9_2, &t7); - let t9_4 = self.nnf_add(&t9_3, &t9_3); - let t9 = self.nnf_mul(&t3, &t9_4); - let one = self.one(); - let t10_1 = self.nnf_add(&t3, &t3); - let t10_2 = self.nnf_add_scalar_times_generator_power(one, 0, &t10_1); - let t10_3 = self.nnf_add(&t5, &t7); - let t10 = self.nnf_mul(&t10_2, &t10_3); - let x_1 = self.nnf_sub(&t10, &t7); - let x_2 = self.nnf_mul_generator(&x_1); - let x = self.nnf_mul_scalar(b1, &x_2); - let z = self.nnf_sub(&t7, &t9); - let neg_one = self.neg_one(); - let u_1 = self.nnf_mul_scalar(neg_one, &t1); - let u_2 = self.nnf_add_scalar_times_generator_power(b1, 1, &u_1); - let u = self.nnf_mul(&t6, &u_2); - let t = self.nnf_add(&t7, &t9); - let xq = self.nnf_div(&x, &z); - let uq = self.nnf_div(&u, &t); - PointTarget { x: xq, u: uq } - */ + // If p1 and p2 lie in the subgroup, then so does p1 + p2. + PointTarget { + x: xq, + u: uq, + checked_on_curve: true, + checked_in_subgroup: p1.checked_in_subgroup && p2.checked_in_subgroup, + } } fn double_point(&mut self, p: &PointTarget) -> PointTarget { @@ -566,10 +565,16 @@ impl CircuitBuilderElliptic for CircuitBuilder { PointTarget { x: self.nnf_if(b, &p_true.x, &p_false.x), u: self.nnf_if(b, &p_true.u, &p_false.u), + checked_on_curve: p_true.checked_on_curve && p_false.checked_on_curve, + checked_in_subgroup: p_true.checked_in_subgroup && p_false.checked_in_subgroup, } } fn connect_point(&mut self, p1: &PointTarget, p2: &PointTarget) { + assert!( + p1.checked_in_subgroup && p2.checked_in_subgroup, + "Connected points must lie in the EC subgroup." + ); // The elements of the subgroup have distinct u-coordinates. So it // is not necessary to connect the x-coordinates. // Explanation: If a point has u-coordinate lambda: @@ -582,7 +587,7 @@ impl CircuitBuilderElliptic for CircuitBuilder { self.nnf_connect(&p1.u, &p2.u); } - fn check_point_on_curve(&mut self, p: &PointTarget) { + fn check_point_on_curve(&mut self, p: &mut PointTarget) { let t1 = self.nnf_mul(&p.u, &p.u); let two = self.two(); let t2 = self.nnf_add_scalar_times_generator_power(two, 0, &p.x); @@ -591,16 +596,14 @@ impl CircuitBuilderElliptic for CircuitBuilder { let t4 = self.nnf_add_scalar_times_generator_power(b1, 1, &t3); let t5 = self.nnf_mul(&t1, &t4); self.nnf_connect(&p.x, &t5); + p.checked_on_curve = true; } - fn check_point_in_subgroup(&mut self, p: &PointTarget) { + fn check_point_in_subgroup(&mut self, p: &mut PointTarget) { // In order to be in the subgroup, the point needs to be a multiple // of two. - let sqrt = PointTarget { - x: self.add_virtual_nnf_target(), - u: self.add_virtual_nnf_target(), - }; - self.check_point_on_curve(&sqrt); + let mut sqrt = self.add_virtual_point_target_unsafe(); + self.check_point_on_curve(&mut sqrt); let doubled = self.double_point(&sqrt); // connect_point assumes that the point is already known to be in the // subgroup, so connect the coordinates instead @@ -610,6 +613,8 @@ impl CircuitBuilderElliptic for CircuitBuilder { orig: p.clone(), sqrt, }); + p.checked_on_curve = true; + p.checked_in_subgroup = true; } }