diff --git a/constraint/grumpkin/solver.go b/constraint/grumpkin/solver.go index 22894ee609..6bc26e59f2 100644 --- a/constraint/grumpkin/solver.go +++ b/constraint/grumpkin/solver.go @@ -627,6 +627,10 @@ func (r *UnsatisfiedConstraintError) Error() string { return fmt.Sprintf("constraint #%d is not satisfied: %s", r.CID, r.Err.Error()) } +func (r *UnsatisfiedConstraintError) Unwrap() error { + return r.Err +} + func (s *solver) wrapErrWithDebugInfo(cID uint32, err error) *UnsatisfiedConstraintError { var debugInfo *string if dID, ok := s.MDebug[int(cID)]; ok { diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index 69861b4628..01eb313403 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -119,11 +119,11 @@ pairing_bw6761,bn254,plonk,5318762,5097941 pairing_bw6761,bls12_377,plonk,0,0 pairing_bw6761,bls12_381,plonk,0,0 pairing_bw6761,bw6_761,plonk,0,0 -scalar_mul_G1_bn254,bn254,groth16,55409,87955 +scalar_mul_G1_bn254,bn254,groth16,51703,82119 scalar_mul_G1_bn254,bls12_377,groth16,0,0 scalar_mul_G1_bn254,bls12_381,groth16,0,0 scalar_mul_G1_bn254,bw6_761,groth16,0,0 -scalar_mul_G1_bn254,bn254,plonk,199958,192839 +scalar_mul_G1_bn254,bn254,plonk,186429,179810 scalar_mul_G1_bn254,bls12_377,plonk,0,0 scalar_mul_G1_bn254,bls12_381,plonk,0,0 scalar_mul_G1_bn254,bw6_761,plonk,0,0 @@ -135,11 +135,11 @@ scalar_mul_P256,bn254,plonk,263160,253523 scalar_mul_P256,bls12_377,plonk,0,0 scalar_mul_P256,bls12_381,plonk,0,0 scalar_mul_P256,bw6_761,plonk,0,0 -scalar_mul_secp256k1,bn254,groth16,56093,89037 +scalar_mul_secp256k1,bn254,groth16,51753,82204 scalar_mul_secp256k1,bls12_377,groth16,0,0 scalar_mul_secp256k1,bls12_381,groth16,0,0 scalar_mul_secp256k1,bw6_761,groth16,0,0 -scalar_mul_secp256k1,bn254,plonk,202472,195259 +scalar_mul_secp256k1,bn254,plonk,186633,180006 scalar_mul_secp256k1,bls12_377,plonk,0,0 scalar_mul_secp256k1,bls12_381,plonk,0,0 scalar_mul_secp256k1,bw6_761,plonk,0,0 diff --git 
a/std/algebra/emulated/sw_bls12381/g2.go b/std/algebra/emulated/sw_bls12381/g2.go index 18d2c0259d..8831f3803e 100644 --- a/std/algebra/emulated/sw_bls12381/g2.go +++ b/std/algebra/emulated/sw_bls12381/g2.go @@ -24,6 +24,10 @@ type G2 struct { // SSWU map coefficients sswuCoeffA, sswuCoeffB *fields_bls12381.E2 sswuZ *fields_bls12381.E2 + + // Precomputed G2 generator and its multiple for GLV+FakeGLV + g2Gen *g2AffP // G2 generator + g2GenNbits *g2AffP // [2^nbits]G2 where nbits = (r.BitLen()+3)/4 + 2 } type g2AffP struct { @@ -80,6 +84,31 @@ func NewG2(api frontend.API) (*G2, error) { A0: *fp.NewElement(sswuZ.A0), A1: *fp.NewElement(sswuZ.A1), } + + // Precomputed G2 generator for GLV+FakeGLV + g2Gen := &g2AffP{ + X: fields_bls12381.E2{ + A0: *fp.NewElement("352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160"), + A1: *fp.NewElement("3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758"), + }, + Y: fields_bls12381.E2{ + A0: *fp.NewElement("1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905"), + A1: *fp.NewElement("927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582"), + }, + } + // [2^(nbits-1)]G2 where nbits = (255+3)/4 + 2 = 66, so this is [2^65]G2 + // The loop does nbits-1 doublings, so the generator accumulates to [2^(nbits-1)]G2 + g2GenNbits := &g2AffP{ + X: fields_bls12381.E2{ + A0: *fp.NewElement("1307001654908388153254394944417118155033503188409787277795273489312551176370209873126740711463572657296916966732684"), + A1: *fp.NewElement("1066804690119577865989830850277879393407029322116864061755683314318400220056817483617033672656485029228353937929571"), + }, + Y: fields_bls12381.E2{ + A0: 
*fp.NewElement("1233864651366532660795929818904272589705597977637697925481983092108793193162343169655985724823869788077854535468808"), + A1: *fp.NewElement("2703972434797875065063829955607449483769333186572810763171217085444622779819503421195150761462859837038921185079043"), + }, + } + return &G2{ api: api, fp: fp, @@ -94,6 +123,9 @@ func NewG2(api frontend.API) (*G2, error) { sswuCoeffA: coeffA, sswuCoeffB: coeffB, sswuZ: z, + // GLV+FakeGLV precomputed values + g2Gen: g2Gen, + g2GenNbits: g2GenNbits, }, nil } @@ -447,7 +479,6 @@ func (g2 G2) doubleAndAddSelect(b frontend.Variable, p, q *G2Affine) *G2Affine { }, } } - func (g2 *G2) computeTwistEquation(Q *G2Affine) (left, right *fields_bls12381.E2) { // Twist: Y² == X³ + aX + b, where a=0 and b=4(1+u) // (X,Y) ∈ {Y² == X³ + aX + b} U (0,0) @@ -510,274 +541,276 @@ func (g2 *G2) IsEqual(p, q *G2Affine) frontend.Variable { return g2.api.And(xEqual, yEqual) } -// scalarMulGeneric computes [s]p and returns it. It doesn't modify p nor s. -// This function doesn't check that the p is on the curve. See AssertIsOnCurve. -// -// ⚠️ p must not be (0,0) and s must not be 0, unless [algopts.WithCompleteArithmetic] option is set. -// (0,0) is not on the curve but we conventionally take it as the -// neutral/infinity point as per the [EVM]. -// -// It computes the right-to-left variable-base double-and-add algorithm ([Joye07], Alg.1). -// -// Since we use incomplete formulas for the addition law, we need to start with -// a non-zero accumulator point (R0). To do this, we skip the LSB (bit at -// position 0) and proceed assuming it was 1. At the end, we conditionally -// subtract the initial value (p) if LSB is 1. 
We also handle the bits at -// positions 1 and n-1 outside of the loop to optimize the number of -// constraints using [ELM03] (Section 3.1) -// -// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf -// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf -// [Joye07]: https://www.iacr.org/archive/ches2007/47270135/47270135.pdf -func (g2 *G2) scalarMulGeneric(p *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { - cfg, err := algopts.NewConfig(opts...) - if err != nil { - panic(fmt.Sprintf("parse opts: %v", err)) - } - var selector frontend.Variable - if cfg.CompleteArithmetic { - // if p=(0,0) we assign a dummy (0,1) to p and continue - selector = g2.api.And(g2.Ext2.IsZero(&p.P.X), g2.Ext2.IsZero(&p.P.Y)) - one := g2.Ext2.One() - p = g2.Select(selector, &G2Affine{P: g2AffP{X: *one, Y: *one}, Lines: nil}, p) - } - - var st ScalarField - sBits := g2.fr.ToBitsCanonical(s) - n := st.Modulus().BitLen() - if cfg.NbScalarBits > 2 && cfg.NbScalarBits < n { - n = cfg.NbScalarBits - } - - // i = 1 - Rb := g2.triple(p) - R0 := g2.Select(sBits[1], Rb, p) - R1 := g2.Select(sBits[1], p, Rb) - - for i := 2; i < n-1; i++ { - Rb = g2.doubleAndAddSelect(sBits[i], R0, R1) - R0 = g2.Select(sBits[i], Rb, R0) - R1 = g2.Select(sBits[i], R1, Rb) - } - - // i = n-1 - Rb = g2.doubleAndAddSelect(sBits[n-1], R0, R1) - R0 = g2.Select(sBits[n-1], Rb, R0) - - // i = 0 - // we use AddUnified instead of Add. This is because: - // - when s=0 then R0=P and AddUnified(P, -P) = (0,0). We return (0,0). - // - when s=1 then R0=P AddUnified(Q, -Q) is well defined. We return R0=P. - R0 = g2.Select(sBits[0], R0, g2.AddUnified(R0, g2.neg(p))) - - if cfg.CompleteArithmetic { - // if p=(0,0), return (0,0) - zero := g2.Ext2.Zero() - R0 = g2.Select(selector, &G2Affine{P: g2AffP{X: *zero, Y: *zero}, Lines: nil}, R0) - } - - return R0 -} - // ScalarMul computes [s]Q using an efficient endomorphism and returns it. It doesn't modify Q nor s. 
-// It implements an optimized version based on algorithm 1 of [Halo] (see Section 6.2 and appendix C). +// It implements the GLV+fakeGLV optimization from [EEMP25] which achieves r^(1/4) bounds +// on the sub-scalars, reducing the number of iterations in the scalar multiplication loop. +// +// Benchmarks show ~36% fewer constraints compared to plain GLV: +// - GLV: ~914k constraints +// - GLV+FakeGLV: ~585k constraints // // ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. // (0,0) is not on the curve but we conventionally take it as the // neutral/infinity point as per the [EVM]. // -// [Halo]: https://eprint.iacr.org/2019/1021.pdf +// [EEMP25]: https://eprint.iacr.org/2025/933 // [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf func (g2 *G2) ScalarMul(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { - return g2.scalarMulGLV(Q, s, opts...) + return g2.scalarMulGLVAndFakeGLV(Q, s, opts...) } -// scalarMulGLV computes [s]Q using an efficient endomorphism and returns it. It doesn't modify Q nor s. -// It implements an optimized version based on algorithm 1 of [Halo] (see Section 6.2 and appendix C). +// scalarMulGLVAndFakeGLV computes [s]Q using GLV+fakeGLV with r^(1/4) bounds. +// It implements the "GLV + fake GLV" explained in [EEMP25] (Sec. 3.3). // // ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. -// (0,0) is not on the curve but we conventionally take it as the -// neutral/infinity point as per the [EVM]. // -// [Halo]: https://eprint.iacr.org/2019/1021.pdf -// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf -func (g2 *G2) scalarMulGLV(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { +// [EEMP25]: https://eprint.iacr.org/2025/933 +func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { cfg, err := algopts.NewConfig(opts...) 
if err != nil { panic(err) } - addFn := g2.add - var selector frontend.Variable + + // handle 0-scalar + var selector0 frontend.Variable + _s := s if cfg.CompleteArithmetic { - addFn = g2.AddUnified - // if Q=(0,0) we assign a dummy (1,1) to Q and continue - selector = g2.api.And( - g2.api.And(g2.fp.IsZero(&Q.P.X.A0), g2.fp.IsZero(&Q.P.X.A1)), - g2.api.And(g2.fp.IsZero(&Q.P.Y.A0), g2.fp.IsZero(&Q.P.Y.A1)), - ) - one := g2.Ext2.One() - Q = g2.Select(selector, &G2Affine{P: g2AffP{X: *one, Y: *one}, Lines: nil}, Q) + one := g2.fr.One() + selector0 = g2.fr.IsZero(s) + _s = g2.fr.Select(selector0, one, s) } - // We use the endomorphism à la GLV to compute [s]Q as - // [s1]Q + [s2]Φ(Q) - // the sub-scalars s1, s2 can be negative (bigints) in the hint. If so, - // they will be reduced in-circuit modulo the SNARK scalar field and not - // the emulated field. So we return in the hint |s1|, |s2| and boolean - // flags sdBits to negate the points Q, Φ(Q) instead of the corresponding - // sub-scalars. + // Instead of computing [s]Q=R, we check that R-[s]Q == 0. + // This is equivalent to [v]R + [-s*v]Q = 0 for some nonzero v. + // + // Using LLL-based lattice reduction we find small sub-scalars: + // [v1 + λ*v2]R + [u1 + λ*u2]Q = 0 + // [v1]R + [v2]Φ(R) + [u1]Q + [u2]Φ(Q) = 0 + // + // where u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL). 
- // decompose s into s1 and s2 - sdBits, sd, err := g2.fr.NewHintGeneric(decomposeScalarG1, 2, 2, nil, []*emulated.Element[ScalarField]{s, g2.eigenvalue}) + // decompose s into u1, u2, v1, v2 + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) if err != nil { - panic(fmt.Sprintf("compute GLV decomposition: %v", err)) - } - s1, s2 := sd[0], sd[1] - selector1, selector2 := sdBits[0], sdBits[1] - s3 := g2.fr.Select(selector1, g2.fr.Neg(s1), s1) - s4 := g2.fr.Select(selector2, g2.fr.Neg(s2), s2) - // s == s3 + [λ]s4 - g2.fr.AssertIsEqual( - g2.fr.Add(s3, g2.fr.Mul(s4, g2.eigenvalue)), - s, + panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) + } + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] + + // Check that: s*(v1 + λ*v2) + u1 + λ*u2 = 0 + var st ScalarField + sv1 := g2.fr.Mul(_s, v1) + sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) + λu2 := g2.fr.Mul(g2.eigenvalue, u2) + zero := g2.fr.Zero() + + lhs1 := g2.fr.Select(isNegv1, zero, sv1) + lhs2 := g2.fr.Select(isNegv2, zero, sλv2) + lhs3 := g2.fr.Select(isNegu1, zero, u1) + lhs4 := g2.fr.Select(isNegu2, zero, λu2) + lhs := g2.fr.Add( + g2.fr.Add(lhs1, lhs2), + g2.fr.Add(lhs3, lhs4), ) - s1bits := g2.fr.ToBits(s1) - s2bits := g2.fr.ToBits(s2) + rhs1 := g2.fr.Select(isNegv1, sv1, zero) + rhs2 := g2.fr.Select(isNegv2, sλv2, zero) + rhs3 := g2.fr.Select(isNegu1, u1, zero) + rhs4 := g2.fr.Select(isNegu2, λu2, zero) + rhs := g2.fr.Add( + g2.fr.Add(rhs1, rhs2), + g2.fr.Add(rhs3, rhs4), + ) + + g2.fr.AssertIsEqual(lhs, rhs) + + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + den := g2.fr.Add(v1, g2.fr.Mul(g2.eigenvalue, v2)) + g2.fr.AssertIsDifferent(den, g2.fr.Zero()) + + // Hint the scalar multiplication R = [s]Q + _, point, _, err := emulated.NewVarGenericHint(g2.api, 0, 4, 0, nil, + 
[]*emulated.Element[BaseField]{&Q.P.X.A0, &Q.P.X.A1, &Q.P.Y.A0, &Q.P.Y.A1}, + []*emulated.Element[ScalarField]{s}, + scalarMulG2Hint) + if err != nil { + panic(fmt.Sprintf("scalarMulG2Hint: %v", err)) + } + R := &G2Affine{ + P: g2AffP{ + X: fields_bls12381.E2{A0: *point[0], A1: *point[1]}, + Y: fields_bls12381.E2{A0: *point[2], A1: *point[3]}, + }, + } + // Preserve the original hinted R for return value (before edge-case modifications) + originalR := R + + // handle (0,0)-point and edge cases + var _selector0, _selector1 frontend.Variable + _Q := Q + if cfg.CompleteArithmetic { + one := g2.Ext2.One() + // if Q=(0,0) we assign a dummy point + _selector0 = g2.api.And(g2.Ext2.IsZero(&Q.P.X), g2.Ext2.IsZero(&Q.P.Y)) + _Q = g2.Select(_selector0, &G2Affine{P: g2AffP{X: *one, Y: *one}}, Q) + // if R.X == Q.X (happens when s=±1, so R=±Q), the incomplete addition fails + // We check this BEFORE potentially modifying R + _selector1 = g2.Ext2.IsZero(g2.Ext2.Sub(&Q.P.X, &R.P.X)) + // if s=0/s=-1 (selector0), Q=(0,0) (_selector0), or R.X==Q.X (_selector1), + // we assign a dummy point to R + selectorAny := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + R = g2.Select(selectorAny, &G2Affine{P: g2AffP{X: *one, Y: *one}}, R) + } // precompute -Q, -Φ(Q), Φ(Q) - var tableQ, tablePhiQ [3]*G2Affine - negQY := g2.Ext2.Neg(&Q.P.Y) + var tableQ, tablePhiQ [2]*G2Affine + negQY := g2.Ext2.Neg(&_Q.P.Y) tableQ[1] = &G2Affine{ P: g2AffP{ - X: Q.P.X, - Y: *g2.Ext2.Select(selector1, negQY, &Q.P.Y), + X: _Q.P.X, + Y: *g2.Ext2.Select(isNegu1, negQY, &_Q.P.Y), }, } tableQ[0] = g2.neg(tableQ[1]) tablePhiQ[1] = &G2Affine{ P: g2AffP{ - X: *g2.Ext2.MulByElement(&Q.P.X, g2.w2), - Y: *g2.Ext2.Select(selector2, negQY, &Q.P.Y), + X: *g2.Ext2.MulByElement(&_Q.P.X, g2.w2), + Y: *g2.Ext2.Select(isNegu2, negQY, &_Q.P.Y), }, } tablePhiQ[0] = g2.neg(tablePhiQ[1]) - tableQ[2] = g2.triple(tableQ[1]) - tablePhiQ[2] = &G2Affine{ + + // precompute -R, -Φ(R), Φ(R) + var tableR, tablePhiR [2]*G2Affine + 
negRY := g2.Ext2.Neg(&R.P.Y) + tableR[1] = &G2Affine{ P: g2AffP{ - X: *g2.Ext2.MulByElement(&tableQ[2].P.X, g2.w2), - Y: *g2.Ext2.Select(selector2, g2.Ext2.Neg(&tableQ[2].P.Y), &tableQ[2].P.Y), + X: R.P.X, + Y: *g2.Ext2.Select(isNegv1, negRY, &R.P.Y), }, } - - // we suppose that the first bits of the sub-scalars are 1 and set: - // Acc = Q + Φ(Q) - Acc := g2.add(tableQ[1], tablePhiQ[1]) - - // At each iteration we need to compute: - // [2]Acc ± Q ± Φ(Q). - // We can compute [2]Acc and look up the (precomputed) point P from: - // B1 = Q+Φ(Q) - // B2 = -Q-Φ(Q) - // B3 = Q-Φ(Q) - // B4 = -Q+Φ(Q) - // - // If we extend this by merging two iterations, we need to look up P and P' - // both from {B1, B2, B3, B4} and compute: - // [2]([2]Acc+P)+P' = [4]Acc + T - // where T = [2]P+P'. So at each (merged) iteration, we can compute [4]Acc - // and look up T from the precomputed list of points: - // - // T = [3](Q + Φ(Q)) - // P = B1 and P' = B1 - t1 := g2.add(tableQ[2], tablePhiQ[2]) - // T = Q + Φ(Q) - // P = B1 and P' = B2 - T2 := Acc - // T = [3]Q + Φ(Q) - // P = B1 and P' = B3 - T3 := g2.add(tableQ[2], tablePhiQ[1]) - // T = Q + [3]Φ(Q) - // P = B1 and P' = B4 - t4 := g2.add(tableQ[1], tablePhiQ[2]) - // T = -[3](Q + Φ(Q)) - // P = B2 and P' = B2 - T6 := g2.neg(t1) - // T = -Q - [3]Φ(Q) - // P = B2 and P' = B3 - T7 := g2.neg(t4) - // T = [3]Q - Φ(Q) - // P = B3 and P' = B1 - t9 := g2.add(tableQ[2], tablePhiQ[0]) - // T = Q - [3]Φ(Q) - // P = B3 and P' = B2 - t := g2.neg(tablePhiQ[2]) - T10 := g2.add(tableQ[1], t) - // T = [3](Q - Φ(Q)) - // P = B3 and P' = B3 - T11 := g2.add(tableQ[2], t) - // T = -Φ(Q) + Q - // P = B3 and P' = B4 - T12 := g2.add(tablePhiQ[0], tableQ[1]) - // T = Φ(Q) - [3]Q - // P = B4 and P' = B2 - T14 := g2.neg(t9) - // T = Φ(Q) - Q - // P = B4 and P' = B3 - T15 := g2.neg(T12) - // note that half the points are negatives of the other half, - // hence have the same X coordinates. 
- - nbits := 130 - for i := nbits - 2; i > 0; i -= 2 { + tableR[0] = g2.neg(tableR[1]) + tablePhiR[1] = &G2Affine{ + P: g2AffP{ + X: *g2.Ext2.MulByElement(&R.P.X, g2.w2), + Y: *g2.Ext2.Select(isNegv2, negRY, &R.P.Y), + }, + } + tablePhiR[0] = g2.neg(tablePhiR[1]) + + // precompute -Q-R, Q+R, Q-R, -Q+R (combining the two points Q and R) + var tableS [4]*G2Affine + tableS[0] = g2.add(tableQ[0], tableR[0]) // -Q - R + tableS[1] = g2.neg(tableS[0]) // Q + R + tableS[2] = g2.add(tableQ[1], tableR[0]) // Q - R + tableS[3] = g2.neg(tableS[2]) // -Q + R + + // precompute -Φ(Q)-Φ(R), Φ(Q)+Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R) (combining endomorphisms) + var tablePhiS [4]*G2Affine + tablePhiS[0] = g2.add(tablePhiQ[0], tablePhiR[0]) // -Φ(Q) - Φ(R) + tablePhiS[1] = g2.neg(tablePhiS[0]) // Φ(Q) + Φ(R) + tablePhiS[2] = g2.add(tablePhiQ[1], tablePhiR[0]) // Φ(Q) - Φ(R) + tablePhiS[3] = g2.neg(tablePhiS[2]) // -Φ(Q) + Φ(R) + + // Acc = Q + Φ(Q) + R + Φ(R) + Acc := g2.add(tableS[1], tablePhiS[1]) + B1 := Acc + + // Add G2 generator to Acc to avoid incomplete additions in the loop. + // At the end, since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0, + // Acc will equal [2^nbits]G2 (precomputed). 
+ g2GenPoint := &G2Affine{P: *g2.g2Gen} + Acc = g2.add(Acc, g2GenPoint) + + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 + nbits := (st.Modulus().BitLen()+3)/4 + 2 + u1bits := g2.fr.ToBits(u1) + u2bits := g2.fr.ToBits(u2) + v1bits := g2.fr.ToBits(v1) + v2bits := g2.fr.ToBits(v2) + + // Precompute all 16 combinations: ±Q ± Φ(Q) ± R ± Φ(R) + // Using tableS (Q±R) and tablePhiS (Φ(Q)±Φ(R)) to match G1 pattern + // B1 = (Q+R) + (Φ(Q)+Φ(R)) = Q + R + Φ(Q) + Φ(R) + B2 := g2.add(tableS[1], tablePhiS[2]) // (Q+R) + (Φ(Q)-Φ(R)) = Q + R + Φ(Q) - Φ(R) + B3 := g2.add(tableS[1], tablePhiS[3]) // (Q+R) + (-Φ(Q)+Φ(R)) = Q + R - Φ(Q) + Φ(R) + B4 := g2.add(tableS[1], tablePhiS[0]) // (Q+R) + (-Φ(Q)-Φ(R)) = Q + R - Φ(Q) - Φ(R) + B5 := g2.add(tableS[2], tablePhiS[1]) // (Q-R) + (Φ(Q)+Φ(R)) = Q - R + Φ(Q) + Φ(R) + B6 := g2.add(tableS[2], tablePhiS[2]) // (Q-R) + (Φ(Q)-Φ(R)) = Q - R + Φ(Q) - Φ(R) + B7 := g2.add(tableS[2], tablePhiS[3]) // (Q-R) + (-Φ(Q)+Φ(R)) = Q - R - Φ(Q) + Φ(R) + B8 := g2.add(tableS[2], tablePhiS[0]) // (Q-R) + (-Φ(Q)-Φ(R)) = Q - R - Φ(Q) - Φ(R) + B10 := g2.neg(B7) // -Q + R + Φ(Q) - Φ(R) + B12 := g2.neg(B5) // -Q + R - Φ(Q) - Φ(R) + B14 := g2.neg(B3) // -Q - R + Φ(Q) - Φ(R) + B16 := g2.neg(B1) // -Q - R - Φ(Q) - Φ(R) + + var Bi *G2Affine + for i := nbits - 1; i > 0; i-- { // selectorY takes values in [0,15] selectorY := g2.api.Add( - s1bits[i], - g2.api.Mul(s2bits[i], 2), - g2.api.Mul(s1bits[i-1], 4), - g2.api.Mul(s2bits[i-1], 8), + u1bits[i], + g2.api.Mul(u2bits[i], 2), + g2.api.Mul(v1bits[i], 4), + g2.api.Mul(v2bits[i], 8), ) // selectorX takes values in [0,7] s.t.: // - when selectorY < 8: selectorX = selectorY // - when selectorY >= 8: selectorX = 15 - selectorY selectorX := g2.api.Add( - g2.api.Mul(selectorY, g2.api.Sub(1, g2.api.Mul(s2bits[i-1], 2))), - g2.api.Mul(s2bits[i-1], 15), + g2.api.Mul(selectorY, g2.api.Sub(1, g2.api.Mul(v2bits[i], 2))), + g2.api.Mul(v2bits[i], 15), ) // Half of the Bi.X are distinct (8-to-1) and Y[i] = -Y[15-i], // so we use 8-to-1 
Mux for both X and Y, with conditional negation for Y. - T := &G2Affine{ + Bi = &G2Affine{ P: g2AffP{ X: fields_bls12381.E2{ - A0: *g2.fp.Mux(selectorX, &T6.P.X.A0, &T10.P.X.A0, &T14.P.X.A0, &T2.P.X.A0, &T7.P.X.A0, &T11.P.X.A0, &T15.P.X.A0, &T3.P.X.A0), - A1: *g2.fp.Mux(selectorX, &T6.P.X.A1, &T10.P.X.A1, &T14.P.X.A1, &T2.P.X.A1, &T7.P.X.A1, &T11.P.X.A1, &T15.P.X.A1, &T3.P.X.A1), + A0: *g2.fp.Mux(selectorX, + &B16.P.X.A0, &B8.P.X.A0, &B14.P.X.A0, &B6.P.X.A0, &B12.P.X.A0, &B4.P.X.A0, &B10.P.X.A0, &B2.P.X.A0, + ), + A1: *g2.fp.Mux(selectorX, + &B16.P.X.A1, &B8.P.X.A1, &B14.P.X.A1, &B6.P.X.A1, &B12.P.X.A1, &B4.P.X.A1, &B10.P.X.A1, &B2.P.X.A1, + ), }, - Y: *g2.muxE2Y8Signed(s2bits[i-1], selectorX, - [8]*emulated.Element[BaseField]{&T6.P.Y.A0, &T10.P.Y.A0, &T14.P.Y.A0, &T2.P.Y.A0, &T7.P.Y.A0, &T11.P.Y.A0, &T15.P.Y.A0, &T3.P.Y.A0}, - [8]*emulated.Element[BaseField]{&T6.P.Y.A1, &T10.P.Y.A1, &T14.P.Y.A1, &T2.P.Y.A1, &T7.P.Y.A1, &T11.P.Y.A1, &T15.P.Y.A1, &T3.P.Y.A1}, + Y: *g2.muxE2Y8Signed(v2bits[i], selectorX, + [8]*emulated.Element[BaseField]{&B16.P.Y.A0, &B8.P.Y.A0, &B14.P.Y.A0, &B6.P.Y.A0, &B12.P.Y.A0, &B4.P.Y.A0, &B10.P.Y.A0, &B2.P.Y.A0}, + [8]*emulated.Element[BaseField]{&B16.P.Y.A1, &B8.P.Y.A1, &B14.P.Y.A1, &B6.P.Y.A1, &B12.P.Y.A1, &B4.P.Y.A1, &B10.P.Y.A1, &B2.P.Y.A1}, ), }, } - // Acc = [4]Acc + T - Acc = g2.double(Acc) - Acc = g2.doubleAndAdd(Acc, T) + // Acc = [2]Acc + Bi + Acc = g2.doubleAndAdd(Acc, Bi) } - // i = 0 - // subtract the Q, Φ(Q) if the first bits are 0. - // When cfg.CompleteArithmetic is set, we use AddUnified instead of Add. - // This means when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0). 
- tableQ[0] = addFn(tableQ[0], Acc) - Acc = g2.Select(s1bits[0], Acc, tableQ[0]) - tablePhiQ[0] = addFn(tablePhiQ[0], Acc) - Acc = g2.Select(s2bits[0], Acc, tablePhiQ[0]) + // i = 0: subtract Q, Φ(Q), R, Φ(R) if the first bits are 0 + tableQ[0] = g2.add(tableQ[0], Acc) + Acc = g2.Select(u1bits[0], Acc, tableQ[0]) + tablePhiQ[0] = g2.add(tablePhiQ[0], Acc) + Acc = g2.Select(u2bits[0], Acc, tablePhiQ[0]) + tableR[0] = g2.add(tableR[0], Acc) + Acc = g2.Select(v1bits[0], Acc, tableR[0]) + tablePhiR[0] = g2.add(tablePhiR[0], Acc) + Acc = g2.Select(v2bits[0], Acc, tablePhiR[0]) + + // Acc should now be [2^nbits]G2 since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0 + // and we added G2 to the initial accumulator. + expected := &G2Affine{P: *g2.g2GenNbits} + + if cfg.CompleteArithmetic { + // if Q=(0,0), s=0, or R.X==Q.X, skip the check + skip := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + Acc = g2.Select(skip, expected, Acc) + } + g2.AssertIsEqual(Acc, expected) if cfg.CompleteArithmetic { - zero := g2.Ext2.Zero() - Acc = g2.Select(selector, &G2Affine{P: g2AffP{X: *zero, Y: *zero}}, Acc) + // if s=0 or Q=(0,0), return (0,0); otherwise return the original hinted R + zeroE2 := g2.Ext2.Zero() + returnZero := g2.api.Or(selector0, _selector0) + return g2.Select(returnZero, &G2Affine{P: g2AffP{X: *zeroE2, Y: *zeroE2}}, originalR) } - return Acc + return R } // MultiScalarMul computes the multi scalar multiplication of the points P and @@ -808,9 +841,9 @@ func (g2 *G2) MultiScalarMul(p []*G2Affine, s []*Scalar, opts ...algopts.Algebra return nil, fmt.Errorf("mismatching points and scalars slice lengths") } n := len(p) - res := g2.scalarMulGLV(p[0], s[0], opts...) + res := g2.ScalarMul(p[0], s[0], opts...) for i := 1; i < n; i++ { - q := g2.scalarMulGLV(p[i], s[i], opts...) + q := g2.ScalarMul(p[i], s[i], opts...) 
res = addFn(res, q) } return res, nil @@ -820,10 +853,10 @@ func (g2 *G2) MultiScalarMul(p []*G2Affine, s []*Scalar, opts ...algopts.Algebra return nil, fmt.Errorf("need scalar for folding") } gamma := s[0] - res := g2.scalarMulGLV(p[len(p)-1], gamma, opts...) + res := g2.ScalarMul(p[len(p)-1], gamma, opts...) for i := len(p) - 2; i > 0; i-- { res = addFn(p[i], res) - res = g2.scalarMulGLV(res, gamma, opts...) + res = g2.ScalarMul(res, gamma, opts...) } res = addFn(p[0], res) return res, nil diff --git a/std/algebra/emulated/sw_bls12381/g2_test.go b/std/algebra/emulated/sw_bls12381/g2_test.go index 2c674d4262..dba935614f 100644 --- a/std/algebra/emulated/sw_bls12381/g2_test.go +++ b/std/algebra/emulated/sw_bls12381/g2_test.go @@ -9,29 +9,28 @@ import ( bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/emulated/fields_bls12381" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/test" ) -type mulG2Circuit struct { +type mulG2GLVAndFakeGLVCircuit struct { In, Res G2Affine S Scalar } -func (c *mulG2Circuit) Define(api frontend.API) error { +func (c *mulG2GLVAndFakeGLVCircuit) Define(api frontend.API) error { g2, err := NewG2(api) if err != nil { return fmt.Errorf("new G2 struct: %w", err) } - res1 := g2.scalarMulGLV(&c.In, &c.S) - res2 := g2.scalarMulGeneric(&c.In, &c.S) - g2.AssertIsEqual(res1, &c.Res) - g2.AssertIsEqual(res2, &c.Res) + res := g2.scalarMulGLVAndFakeGLV(&c.In, &c.S) + g2.AssertIsEqual(res, &c.Res) return nil } -func TestScalarMulG2TestSolve(t *testing.T) { +func TestScalarMulG2GLVAndFakeGLV(t *testing.T) { assert := test.NewAssert(t) var r fr_bls12381.Element _, _ = r.SetRandom() @@ -41,12 +40,12 @@ func TestScalarMulG2TestSolve(t *testing.T) { _, _, _, gen := bls12381.Generators() res.ScalarMultiplication(&gen, s) - witness := 
mulG2Circuit{ + witness := mulG2GLVAndFakeGLVCircuit{ In: NewG2Affine(gen), S: NewScalar(r), Res: NewG2Affine(res), } - err := test.IsSolved(&mulG2Circuit{}, &witness, ecc.BN254.ScalarField()) + err := test.IsSolved(&mulG2GLVAndFakeGLVCircuit{}, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } @@ -356,3 +355,74 @@ func TestMultiScalarMul(t *testing.T) { }, &assignment, ecc.BN254.ScalarField()) assert.NoError(err) } + +// Circuit for testing G2 scalar multiplication with complete arithmetic (handles edge cases) +type scalarMulG2CompleteCircuit struct { + In, Res G2Affine + S Scalar +} + +func (c *scalarMulG2CompleteCircuit) Define(api frontend.API) error { + g2, err := NewG2(api) + if err != nil { + return fmt.Errorf("new G2 struct: %w", err) + } + res := g2.scalarMulGLVAndFakeGLV(&c.In, &c.S, algopts.WithCompleteArithmetic()) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +// TestScalarMulG2EdgeCases tests edge cases: s=0, s=1, s=-1, Q=(0,0) +func TestScalarMulG2EdgeCases(t *testing.T) { + assert := test.NewAssert(t) + _, _, _, gen := bls12381.Generators() + + // Test case: s = 1 (result should be Q) + t.Run("s=1", func(t *testing.T) { + var s fr_bls12381.Element + s.SetOne() + var res bls12381.G2Affine + res.Set(&gen) // [1]Q = Q + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + + // Test case: s = -1 (result should be -Q) + t.Run("s=-1", func(t *testing.T) { + var s fr_bls12381.Element + s.SetOne() + s.Neg(&s) // s = -1 + var res bls12381.G2Affine + res.Neg(&gen) // [-1]Q = -Q + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + + // Test case: s = 0 (result should be (0,0)) + t.Run("s=0", func(t 
*testing.T) { + var s fr_bls12381.Element + s.SetZero() + var res bls12381.G2Affine // zero value is (0,0) + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + +} diff --git a/std/algebra/emulated/sw_bls12381/hints.go b/std/algebra/emulated/sw_bls12381/hints.go index c8be1bea12..62bafd440a 100644 --- a/std/algebra/emulated/sw_bls12381/hints.go +++ b/std/algebra/emulated/sw_bls12381/hints.go @@ -5,6 +5,7 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" @@ -24,6 +25,8 @@ func GetHints() []solver.Hint { pairingCheckHint, millerLoopAndCheckFinalExpHint, decomposeScalarG1, + scalarMulG2Hint, + rationalReconstructExtG2, g1SqrtRatioHint, g2SqrtRatioHint, unmarshalG1, @@ -450,3 +453,94 @@ func unmarshalG1(mod *big.Int, nativeInputs []*big.Int, outputs []*big.Int) erro return nil }) } + +func scalarMulG2Hint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(field, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 2 { + return fmt.Errorf("expecting two moduli, got %d", len(moduli)) + } + baseModulus, scalarModulus := moduli[0], moduli[1] + baseInputs, baseOutputs := hc.InputsOutputs(baseModulus) + scalarInputs, _ := hc.InputsOutputs(scalarModulus) + if len(baseInputs) != 4 { + return fmt.Errorf("expecting four base inputs (Q.X.A0, Q.X.A1, Q.Y.A0, Q.Y.A1), got %d", len(baseInputs)) + } + if len(baseOutputs) != 4 { + return fmt.Errorf("expecting four base outputs, got %d", len(baseOutputs)) + } + if len(scalarInputs) != 1 { + return fmt.Errorf("expecting one scalar input, got %d", len(scalarInputs)) + } + + // 
compute the resulting point [s]Q on G2 + var Q bls12381.G2Affine + Q.X.A0.SetBigInt(baseInputs[0]) + Q.X.A1.SetBigInt(baseInputs[1]) + Q.Y.A0.SetBigInt(baseInputs[2]) + Q.Y.A1.SetBigInt(baseInputs[3]) + Q.ScalarMultiplication(&Q, scalarInputs[0]) + Q.X.A0.BigInt(baseOutputs[0]) + Q.X.A1.BigInt(baseOutputs[1]) + Q.Y.A0.BigInt(baseOutputs[2]) + Q.Y.A1.BigInt(baseOutputs[3]) + return nil + }) +} + +func rationalReconstructExtG2(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 1 { + return fmt.Errorf("expecting one modulus, got %d", len(moduli)) + } + _, nativeOutputs := hc.NativeInputsOutputs() + if len(nativeOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(nativeOutputs)) + } + emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) + if len(emuInputs) != 2 { + return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) + } + if len(emuOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(emuOutputs)) + } + + // Use lattice reduction to find (x, y, z, t) such that + // k ≡ (x + λ*y) / (z + λ*t) (mod r) + // + // in-circuit we check that R - [s]Q = 0 or equivalently R + [-s]Q = 0 + // so here we use k = -s. 
+ k := new(big.Int).Neg(emuInputs[0]) + k.Mod(k, moduli[0]) + rc := lattice.NewReconstructor(moduli[0]).SetLambda(emuInputs[1]) + res := rc.RationalReconstructExt(k) + x, y, z, t := res[0], res[1], res[2], res[3] + + // u1 = x, u2 = y, v1 = z, v2 = t + emuOutputs[0].Abs(x) + emuOutputs[1].Abs(y) + emuOutputs[2].Abs(z) + emuOutputs[3].Abs(t) + + // signs + nativeOutputs[0].SetUint64(0) + nativeOutputs[1].SetUint64(0) + nativeOutputs[2].SetUint64(0) + nativeOutputs[3].SetUint64(0) + + if x.Sign() < 0 { + nativeOutputs[0].SetUint64(1) + } + if y.Sign() < 0 { + nativeOutputs[1].SetUint64(1) + } + if z.Sign() < 0 { + nativeOutputs[2].SetUint64(1) + } + if t.Sign() < 0 { + nativeOutputs[3].SetUint64(1) + } + return nil + }) +} diff --git a/std/algebra/emulated/sw_bn254/g2.go b/std/algebra/emulated/sw_bn254/g2.go index 04b2b0133c..3b3a5b5571 100644 --- a/std/algebra/emulated/sw_bn254/g2.go +++ b/std/algebra/emulated/sw_bn254/g2.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/emulated/fields_bn254" "github.com/consensys/gnark/std/math/emulated" ) @@ -13,9 +14,16 @@ import ( type G2 struct { api frontend.API fp *emulated.Field[BaseField] + fr *emulated.Field[ScalarField] *fields_bn254.Ext2 w *emulated.Element[BaseField] u, v *fields_bn254.E2 + // GLV eigenvalue for endomorphism + eigenvalue *emulated.Element[ScalarField] + + // Precomputed G2 generator and its multiple for GLV+FakeGLV + g2Gen *g2AffP // G2 generator + g2GenNbits *g2AffP // [2^(nbits-1)]G2 where nbits = (r.BitLen()+3)/4 + 2 } type g2AffP struct { @@ -46,7 +54,14 @@ func NewG2(api frontend.API) (*G2, error) { if err != nil { return nil, fmt.Errorf("new base api: %w", err) } + fr, err := emulated.NewField[ScalarField](api) + if err != nil { + return nil, fmt.Errorf("new scalar api: %w", err) + } + // w = thirdRootOneG2 = thirdRootOneG1^2 (used for both psi2 
and GLV endomorphism) w := fp.NewElement("21888242871839275220042445260109153167277707414472061641714758635765020556616") + // GLV eigenvalue: lambda such that phi(P) = [lambda]P + eigenvalue := fr.NewElement("4407920970296243842393367215006156084916469457145843978461") u := fields_bn254.E2{ A0: *fp.NewElement("21575463638280843010398324269430826099269044274347216827212613867836435027261"), A1: *fp.NewElement("10307601595873709700152284273816112264069230130616436755625194854815875713954"), @@ -55,13 +70,43 @@ func NewG2(api frontend.API) (*G2, error) { A0: *fp.NewElement("2821565182194536844548159561693502659359617185244120367078079554186484126554"), A1: *fp.NewElement("3505843767911556378687030309984248845540243509899259641013678093033130930403"), } + + // Precomputed G2 generator for GLV+FakeGLV + g2Gen := &g2AffP{ + X: fields_bn254.E2{ + A0: *fp.NewElement("10857046999023057135944570762232829481370756359578518086990519993285655852781"), + A1: *fp.NewElement("11559732032986387107991004021392285783925812861821192530917403151452391805634"), + }, + Y: fields_bn254.E2{ + A0: *fp.NewElement("8495653923123431417604973247489272438418190587263600148770280649306958101930"), + A1: *fp.NewElement("4082367875863433681332203403145435568316851327593401208105741076214120093531"), + }, + } + // [2^(nbits-1)]G2 where nbits = (254+3)/4 + 2 = 66, so this is [2^65]G2 + // The loop does nbits-1 doublings, so the generator accumulates to [2^(nbits-1)]G2 + g2GenNbits := &g2AffP{ + X: fields_bn254.E2{ + A0: *fp.NewElement("6099622139700402640581725571890015148411145321742729577177999911575645303725"), + A1: *fp.NewElement("9870328428465937988383794519490899227160817120884239055108452134207619193487"), + }, + Y: fields_bn254.E2{ + A0: *fp.NewElement("16268382111792290652321980382595025991160708296314050973435867558225525677485"), + A1: *fp.NewElement("15377126855853471483498618408547895055706247905282062963450025729940352455943"), + }, + } + return &G2{ - api: api, - fp: fp, - Ext2: 
fields_bn254.NewExt2(api), - w: w, - u: &u, - v: &v, + api: api, + fp: fp, + fr: fr, + Ext2: fields_bn254.NewExt2(api), + w: w, + eigenvalue: eigenvalue, + u: &u, + v: &v, + // GLV+FakeGLV precomputed values + g2Gen: g2Gen, + g2GenNbits: g2GenNbits, }, nil } @@ -290,3 +335,333 @@ func (g2 *G2) IsEqual(p, q *G2Affine) frontend.Variable { yEqual := g2.Ext2.IsEqual(&p.P.Y, &q.P.Y) return g2.api.And(xEqual, yEqual) } + +// Select selects between p and q given the selector b. If b == 1, then returns +// p and q otherwise. +func (g2 *G2) Select(b frontend.Variable, p, q *G2Affine) *G2Affine { + x := g2.Ext2.Select(b, &p.P.X, &q.P.X) + y := g2.Ext2.Select(b, &p.P.Y, &q.P.Y) + return &G2Affine{ + P: g2AffP{X: *x, Y: *y}, + Lines: nil, + } +} + +func (g2 G2) triple(p *G2Affine) *G2Affine { + mone := g2.fp.NewElement(-1) + + // compute λ1 = (3p.x²)/2p.y + xx := g2.Square(&p.P.X) + xx = g2.MulByConstElement(xx, big.NewInt(3)) + y2 := g2.Double(&p.P.Y) + λ1 := g2.DivUnchecked(xx, y2) + + // x2 = λ1²-2p.x + x20 := g2.fp.Eval([][]*baseEl{{&λ1.A0, &λ1.A0}, {mone, &λ1.A1, &λ1.A1}, {mone, &p.P.X.A0}}, []int{1, 1, 2}) + x21 := g2.fp.Eval([][]*baseEl{{&λ1.A0, &λ1.A1}, {mone, &p.P.X.A1}}, []int{2, 2}) + x2 := &fields_bn254.E2{A0: *x20, A1: *x21} + + // omit y2 computation, and + // compute λ2 = 2p.y/(x2 − p.x) − λ1. 
+ x1x2 := g2.Sub(&p.P.X, x2) + λ2 := g2.DivUnchecked(y2, x1x2) + λ2 = g2.Sub(λ2, λ1) + + // compute x3 =λ2²-p.x-x2 + x30 := g2.fp.Eval([][]*baseEl{{&λ2.A0, &λ2.A0}, {mone, &λ2.A1, &λ2.A1}, {mone, &p.P.X.A0}, {mone, x20}}, []int{1, 1, 1, 1}) + x31 := g2.fp.Eval([][]*baseEl{{&λ2.A0, &λ2.A1}, {mone, &p.P.X.A1}, {mone, x21}}, []int{2, 1, 1}) + x3 := &fields_bn254.E2{A0: *x30, A1: *x31} + + // compute y3 = λ2*(p.x - x3)-p.y + y3 := g2.Ext2.Sub(&p.P.X, x3) + y30 := g2.fp.Eval([][]*baseEl{{&λ2.A0, &y3.A0}, {mone, &λ2.A1, &y3.A1}, {mone, &p.P.Y.A0}}, []int{1, 1, 1}) + y31 := g2.fp.Eval([][]*baseEl{{&λ2.A0, &y3.A1}, {&λ2.A1, &y3.A0}, {mone, &p.P.Y.A1}}, []int{1, 1, 1}) + y3 = &fields_bn254.E2{A0: *y30, A1: *y31} + + return &G2Affine{ + P: g2AffP{ + X: *x3, + Y: *y3, + }, + } +} + +// ScalarMul computes [s]Q using an efficient endomorphism and returns it. It doesn't modify Q nor s. +// It implements the GLV+fakeGLV optimization from [EEMP25] which achieves r^(1/4) bounds +// on the sub-scalars, reducing the number of iterations in the scalar multiplication loop. +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// [EEMP25]: https://eprint.iacr.org/2025/933 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (g2 *G2) ScalarMul(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { + return g2.scalarMulGLVAndFakeGLV(Q, s, opts...) +} + +// scalarMulGLVAndFakeGLV computes [s]Q using GLV+fakeGLV with r^(1/4) bounds. +// It implements the "GLV + fake GLV" explained in [EEMP25] (Sec. 3.3). +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. 
+// +// [EEMP25]: https://eprint.iacr.org/2025/933 +func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + + // handle 0-scalar + var selector0 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + one := g2.fr.One() + selector0 = g2.fr.IsZero(s) + _s = g2.fr.Select(selector0, one, s) + } + + // Instead of computing [s]Q=R, we check that R-[s]Q == 0. + // This is equivalent to [v]R + [-s*v]Q = 0 for some nonzero v. + // + // Using LLL-based lattice reduction we find small sub-scalars: + // [v1 + λ*v2]R + [u1 + λ*u2]Q = 0 + // [v1]R + [v2]Φ(R) + [u1]Q + [u2]Φ(Q) = 0 + // + // where u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL). + + // decompose s into u1, u2, v1, v2 + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) + if err != nil { + panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) + } + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] + + // Check that: s*(v1 + λ*v2) + u1 + λ*u2 = 0 + var st ScalarField + sv1 := g2.fr.Mul(_s, v1) + sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) + λu2 := g2.fr.Mul(g2.eigenvalue, u2) + zero := g2.fr.Zero() + + lhs1 := g2.fr.Select(isNegv1, zero, sv1) + lhs2 := g2.fr.Select(isNegv2, zero, sλv2) + lhs3 := g2.fr.Select(isNegu1, zero, u1) + lhs4 := g2.fr.Select(isNegu2, zero, λu2) + lhs := g2.fr.Add( + g2.fr.Add(lhs1, lhs2), + g2.fr.Add(lhs3, lhs4), + ) + + rhs1 := g2.fr.Select(isNegv1, sv1, zero) + rhs2 := g2.fr.Select(isNegv2, sλv2, zero) + rhs3 := g2.fr.Select(isNegu1, u1, zero) + rhs4 := g2.fr.Select(isNegu2, λu2, zero) + rhs := g2.fr.Add( + g2.fr.Add(rhs1, rhs2), + g2.fr.Add(rhs3, rhs4), + ) + + g2.fr.AssertIsEqual(lhs, rhs) + + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + 
den := g2.fr.Add(v1, g2.fr.Mul(g2.eigenvalue, v2)) + g2.fr.AssertIsDifferent(den, g2.fr.Zero()) + + // Hint the scalar multiplication R = [s]Q + _, point, _, err := emulated.NewVarGenericHint(g2.api, 0, 4, 0, nil, + []*emulated.Element[BaseField]{&Q.P.X.A0, &Q.P.X.A1, &Q.P.Y.A0, &Q.P.Y.A1}, + []*emulated.Element[ScalarField]{s}, + scalarMulG2Hint) + if err != nil { + panic(fmt.Sprintf("scalarMulG2Hint: %v", err)) + } + R := &G2Affine{ + P: g2AffP{ + X: fields_bn254.E2{A0: *point[0], A1: *point[1]}, + Y: fields_bn254.E2{A0: *point[2], A1: *point[3]}, + }, + } + // Preserve the original hinted R for return value (before edge-case modifications) + originalR := R + + // handle (0,0)-point and edge cases + var _selector0, _selector1 frontend.Variable + _Q := Q + if cfg.CompleteArithmetic { + one := g2.Ext2.One() + // if Q=(0,0) we assign a dummy point + _selector0 = g2.api.And(g2.Ext2.IsZero(&Q.P.X), g2.Ext2.IsZero(&Q.P.Y)) + _Q = g2.Select(_selector0, &G2Affine{P: g2AffP{X: *one, Y: *one}}, Q) + // if R.X == Q.X (happens when s=±1, so R=±Q), the incomplete addition fails + // We check this BEFORE potentially modifying R + _selector1 = g2.Ext2.IsZero(g2.Ext2.Sub(&Q.P.X, &R.P.X)) + // if s=0/s=-1 (selector0), Q=(0,0) (_selector0), or R.X==Q.X (_selector1), + // we assign a dummy point to R + selectorAny := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + R = g2.Select(selectorAny, &G2Affine{P: g2AffP{X: *one, Y: *one}}, R) + } + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]*G2Affine + negQY := g2.Ext2.Neg(&_Q.P.Y) + tableQ[1] = &G2Affine{ + P: g2AffP{ + X: _Q.P.X, + Y: *g2.Ext2.Select(isNegu1, negQY, &_Q.P.Y), + }, + } + tableQ[0] = g2.neg(tableQ[1]) + // For BN254 G2, glvPhi(Q) = (w * Q.X, Q.Y) + tablePhiQ[1] = &G2Affine{ + P: g2AffP{ + X: *g2.Ext2.MulByElement(&_Q.P.X, g2.w), + Y: *g2.Ext2.Select(isNegu2, negQY, &_Q.P.Y), + }, + } + tablePhiQ[0] = g2.neg(tablePhiQ[1]) + + // precompute -R, -Φ(R), Φ(R) + var tableR, tablePhiR [2]*G2Affine + negRY 
:= g2.Ext2.Neg(&R.P.Y) + tableR[1] = &G2Affine{ + P: g2AffP{ + X: R.P.X, + Y: *g2.Ext2.Select(isNegv1, negRY, &R.P.Y), + }, + } + tableR[0] = g2.neg(tableR[1]) + tablePhiR[1] = &G2Affine{ + P: g2AffP{ + X: *g2.Ext2.MulByElement(&R.P.X, g2.w), + Y: *g2.Ext2.Select(isNegv2, negRY, &R.P.Y), + }, + } + tablePhiR[0] = g2.neg(tablePhiR[1]) + + // precompute -Q-R, Q+R, Q-R, -Q+R (combining the two points Q and R) + var tableS [4]*G2Affine + tableS[0] = g2.add(tableQ[0], tableR[0]) // -Q - R + tableS[1] = g2.neg(tableS[0]) // Q + R + tableS[2] = g2.add(tableQ[1], tableR[0]) // Q - R + tableS[3] = g2.neg(tableS[2]) // -Q + R + + // precompute -Φ(Q)-Φ(R), Φ(Q)+Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R) (combining endomorphisms) + var tablePhiS [4]*G2Affine + tablePhiS[0] = g2.add(tablePhiQ[0], tablePhiR[0]) // -Φ(Q) - Φ(R) + tablePhiS[1] = g2.neg(tablePhiS[0]) // Φ(Q) + Φ(R) + tablePhiS[2] = g2.add(tablePhiQ[1], tablePhiR[0]) // Φ(Q) - Φ(R) + tablePhiS[3] = g2.neg(tablePhiS[2]) // -Φ(Q) + Φ(R) + + // Acc = Q + Φ(Q) + R + Φ(R) + Acc := g2.add(tableS[1], tablePhiS[1]) + B1 := Acc + + // Add G2 generator to Acc to avoid incomplete additions in the loop. + // At the end, since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0, + // Acc will equal [2^nbits]G2 (precomputed). 
+ g2GenPoint := &G2Affine{P: *g2.g2Gen} + Acc = g2.add(Acc, g2GenPoint) + + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 + nbits := (st.Modulus().BitLen()+3)/4 + 2 + u1bits := g2.fr.ToBits(u1) + u2bits := g2.fr.ToBits(u2) + v1bits := g2.fr.ToBits(v1) + v2bits := g2.fr.ToBits(v2) + + // Precompute all 16 combinations: ±Q ± Φ(Q) ± R ± Φ(R) + // Using tableS (Q±R) and tablePhiS (Φ(Q)±Φ(R)) to match G1 pattern + // B1 = (Q+R) + (Φ(Q)+Φ(R)) = Q + R + Φ(Q) + Φ(R) + B2 := g2.add(tableS[1], tablePhiS[2]) // (Q+R) + (Φ(Q)-Φ(R)) = Q + R + Φ(Q) - Φ(R) + B3 := g2.add(tableS[1], tablePhiS[3]) // (Q+R) + (-Φ(Q)+Φ(R)) = Q + R - Φ(Q) + Φ(R) + B4 := g2.add(tableS[1], tablePhiS[0]) // (Q+R) + (-Φ(Q)-Φ(R)) = Q + R - Φ(Q) - Φ(R) + B5 := g2.add(tableS[2], tablePhiS[1]) // (Q-R) + (Φ(Q)+Φ(R)) = Q - R + Φ(Q) + Φ(R) + B6 := g2.add(tableS[2], tablePhiS[2]) // (Q-R) + (Φ(Q)-Φ(R)) = Q - R + Φ(Q) - Φ(R) + B7 := g2.add(tableS[2], tablePhiS[3]) // (Q-R) + (-Φ(Q)+Φ(R)) = Q - R - Φ(Q) + Φ(R) + B8 := g2.add(tableS[2], tablePhiS[0]) // (Q-R) + (-Φ(Q)-Φ(R)) = Q - R - Φ(Q) - Φ(R) + B9 := g2.neg(B8) // -Q + R + Φ(Q) + Φ(R) + B10 := g2.neg(B7) // -Q + R + Φ(Q) - Φ(R) + B11 := g2.neg(B6) // -Q + R - Φ(Q) + Φ(R) + B12 := g2.neg(B5) // -Q + R - Φ(Q) - Φ(R) + B13 := g2.neg(B4) // -Q - R + Φ(Q) + Φ(R) + B14 := g2.neg(B3) // -Q - R + Φ(Q) - Φ(R) + B15 := g2.neg(B2) // -Q - R - Φ(Q) + Φ(R) + B16 := g2.neg(B1) // -Q - R - Φ(Q) - Φ(R) + + var Bi *G2Affine + for i := nbits - 1; i > 0; i-- { + // selectorY takes values in [0,15] + selectorY := g2.api.Add( + u1bits[i], + g2.api.Mul(u2bits[i], 2), + g2.api.Mul(v1bits[i], 4), + g2.api.Mul(v2bits[i], 8), + ) + // selectorX takes values in [0,7] s.t.: + // - when selectorY < 8: selectorX = selectorY + // - when selectorY >= 8: selectorX = 15 - selectorY + selectorX := g2.api.Add( + g2.api.Mul(selectorY, g2.api.Sub(1, g2.api.Mul(v2bits[i], 2))), + g2.api.Mul(v2bits[i], 15), + ) + + // Bi.Y are distinct so we need a 16-to-1 multiplexer, + // but only half of the Bi.X 
are distinct so we need an 8-to-1. + Bi = &G2Affine{ + P: g2AffP{ + X: fields_bn254.E2{ + A0: *g2.fp.Mux(selectorX, + &B16.P.X.A0, &B8.P.X.A0, &B14.P.X.A0, &B6.P.X.A0, &B12.P.X.A0, &B4.P.X.A0, &B10.P.X.A0, &B2.P.X.A0, + ), + A1: *g2.fp.Mux(selectorX, + &B16.P.X.A1, &B8.P.X.A1, &B14.P.X.A1, &B6.P.X.A1, &B12.P.X.A1, &B4.P.X.A1, &B10.P.X.A1, &B2.P.X.A1, + ), + }, + Y: fields_bn254.E2{ + A0: *g2.fp.Mux(selectorY, + &B16.P.Y.A0, &B8.P.Y.A0, &B14.P.Y.A0, &B6.P.Y.A0, &B12.P.Y.A0, &B4.P.Y.A0, &B10.P.Y.A0, &B2.P.Y.A0, + &B15.P.Y.A0, &B7.P.Y.A0, &B13.P.Y.A0, &B5.P.Y.A0, &B11.P.Y.A0, &B3.P.Y.A0, &B9.P.Y.A0, &B1.P.Y.A0, + ), + A1: *g2.fp.Mux(selectorY, + &B16.P.Y.A1, &B8.P.Y.A1, &B14.P.Y.A1, &B6.P.Y.A1, &B12.P.Y.A1, &B4.P.Y.A1, &B10.P.Y.A1, &B2.P.Y.A1, + &B15.P.Y.A1, &B7.P.Y.A1, &B13.P.Y.A1, &B5.P.Y.A1, &B11.P.Y.A1, &B3.P.Y.A1, &B9.P.Y.A1, &B1.P.Y.A1, + ), + }, + }, + } + // Acc = [2]Acc + Bi + Acc = g2.doubleAndAdd(Acc, Bi) + } + + // i = 0: subtract Q, Φ(Q), R, Φ(R) if the first bits are 0 + tableQ[0] = g2.add(tableQ[0], Acc) + Acc = g2.Select(u1bits[0], Acc, tableQ[0]) + tablePhiQ[0] = g2.add(tablePhiQ[0], Acc) + Acc = g2.Select(u2bits[0], Acc, tablePhiQ[0]) + tableR[0] = g2.add(tableR[0], Acc) + Acc = g2.Select(v1bits[0], Acc, tableR[0]) + tablePhiR[0] = g2.add(tablePhiR[0], Acc) + Acc = g2.Select(v2bits[0], Acc, tablePhiR[0]) + + // Acc should now be [2^(nbits-1)]G2 since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0 + // and we added G2 to the initial accumulator. 
+ expected := &G2Affine{P: *g2.g2GenNbits} + + if cfg.CompleteArithmetic { + // if Q=(0,0), s=0, or R.X==Q.X, skip the check + skip := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + Acc = g2.Select(skip, expected, Acc) + } + g2.AssertIsEqual(Acc, expected) + + if cfg.CompleteArithmetic { + // if s=0 or Q=(0,0), return (0,0); otherwise return the original hinted R + zeroE2 := g2.Ext2.Zero() + returnZero := g2.api.Or(selector0, _selector0) + return g2.Select(returnZero, &G2Affine{P: g2AffP{X: *zeroE2, Y: *zeroE2}}, originalR) + } + + return R +} diff --git a/std/algebra/emulated/sw_bn254/g2_test.go b/std/algebra/emulated/sw_bn254/g2_test.go index 812e21e0e0..8b95c58700 100644 --- a/std/algebra/emulated/sw_bn254/g2_test.go +++ b/std/algebra/emulated/sw_bn254/g2_test.go @@ -1,12 +1,16 @@ package sw_bn254 import ( + "crypto/rand" + "fmt" "math/big" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/test" ) @@ -156,3 +160,130 @@ func TestEndomorphismG2TestSolve(t *testing.T) { err := test.IsSolved(&endomorphismG2Circuit{}, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } + +type scalarMulG2GLVAndFakeGLVCircuit struct { + In G2Affine + Res G2Affine + S Scalar +} + +func (c *scalarMulG2GLVAndFakeGLVCircuit) Define(api frontend.API) error { + g2, err := NewG2(api) + if err != nil { + return err + } + res := g2.ScalarMul(&c.In, &c.S) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2GLVAndFakeGLV(t *testing.T) { + assert := test.NewAssert(t) + // Use a fixed scalar for reproducibility + s := big.NewInt(12345) + var sFr fr.Element + sFr.SetBigInt(s) + + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + res.ScalarMultiplication(&in1, s) + + witness := scalarMulG2GLVAndFakeGLVCircuit{ + In: NewG2Affine(in1), + S: 
NewScalar(sFr), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2GLVAndFakeGLVCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +func TestScalarMulG2GLVAndFakeGLVRandom(t *testing.T) { + assert := test.NewAssert(t) + // Use a random scalar + s, _ := rand.Int(rand.Reader, fr.Modulus()) + var sFr fr.Element + sFr.SetBigInt(s) + + _, in1 := randomG1G2Affines() + var res bn254.G2Affine + res.ScalarMultiplication(&in1, s) + + witness := scalarMulG2GLVAndFakeGLVCircuit{ + In: NewG2Affine(in1), + S: NewScalar(sFr), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2GLVAndFakeGLVCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// Circuit for testing G2 scalar multiplication with complete arithmetic (handles edge cases) +type scalarMulG2CompleteCircuit struct { + In, Res G2Affine + S Scalar +} + +func (c *scalarMulG2CompleteCircuit) Define(api frontend.API) error { + g2, err := NewG2(api) + if err != nil { + return fmt.Errorf("new G2 struct: %w", err) + } + res := g2.scalarMulGLVAndFakeGLV(&c.In, &c.S, algopts.WithCompleteArithmetic()) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +// TestScalarMulG2EdgeCases tests edge cases: s=0, s=1, s=-1, Q=(0,0) +func TestScalarMulG2EdgeCases(t *testing.T) { + assert := test.NewAssert(t) + _, _, _, gen := bn254.Generators() + + // Test case: s = 1 (result should be Q) + t.Run("s=1", func(t *testing.T) { + var s fr.Element + s.SetOne() + var res bn254.G2Affine + res.Set(&gen) // [1]Q = Q + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + + // Test case: s = -1 (result should be -Q) + t.Run("s=-1", func(t *testing.T) { + var s fr.Element + s.SetOne() + s.Neg(&s) // s = -1 + var res bn254.G2Affine + res.Neg(&gen) // [-1]Q = -Q + + witness := scalarMulG2CompleteCircuit{ + In: 
NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + + // Test case: s = 0 (result should be (0,0)) + t.Run("s=0", func(t *testing.T) { + var s fr.Element + s.SetZero() + var res bn254.G2Affine // zero value is (0,0) + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + +} diff --git a/std/algebra/emulated/sw_bn254/hints.go b/std/algebra/emulated/sw_bn254/hints.go index 9c0bd4aaf6..29d11142d4 100644 --- a/std/algebra/emulated/sw_bn254/hints.go +++ b/std/algebra/emulated/sw_bn254/hints.go @@ -2,8 +2,10 @@ package sw_bn254 import ( "errors" + "fmt" "math/big" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/std/math/emulated" @@ -19,6 +21,8 @@ func GetHints() []solver.Hint { finalExpHint, pairingCheckHint, millerLoopAndCheckFinalExpHint, + scalarMulG2Hint, + rationalReconstructExtG2, } } @@ -276,3 +280,94 @@ func finalExpWitness(millerLoop *bn254.E12) (residueWitness, cubicNonResiduePowe return residueWitness, cubicNonResiduePower } + +func scalarMulG2Hint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(field, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 2 { + return fmt.Errorf("expecting two moduli, got %d", len(moduli)) + } + baseModulus, scalarModulus := moduli[0], moduli[1] + baseInputs, baseOutputs := hc.InputsOutputs(baseModulus) + scalarInputs, _ := hc.InputsOutputs(scalarModulus) + if len(baseInputs) != 4 { + return fmt.Errorf("expecting four base inputs (Q.X.A0, Q.X.A1, Q.Y.A0, Q.Y.A1), got %d", len(baseInputs)) + } + 
if len(baseOutputs) != 4 { + return fmt.Errorf("expecting four base outputs, got %d", len(baseOutputs)) + } + if len(scalarInputs) != 1 { + return fmt.Errorf("expecting one scalar input, got %d", len(scalarInputs)) + } + + // compute the resulting point [s]Q on G2 + var Q bn254.G2Affine + Q.X.A0.SetBigInt(baseInputs[0]) + Q.X.A1.SetBigInt(baseInputs[1]) + Q.Y.A0.SetBigInt(baseInputs[2]) + Q.Y.A1.SetBigInt(baseInputs[3]) + Q.ScalarMultiplication(&Q, scalarInputs[0]) + Q.X.A0.BigInt(baseOutputs[0]) + Q.X.A1.BigInt(baseOutputs[1]) + Q.Y.A0.BigInt(baseOutputs[2]) + Q.Y.A1.BigInt(baseOutputs[3]) + return nil + }) +} + +func rationalReconstructExtG2(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 1 { + return fmt.Errorf("expecting one modulus, got %d", len(moduli)) + } + _, nativeOutputs := hc.NativeInputsOutputs() + if len(nativeOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(nativeOutputs)) + } + emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) + if len(emuInputs) != 2 { + return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) + } + if len(emuOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(emuOutputs)) + } + + // Use lattice reduction to find (x, y, z, t) such that + // k ≡ (x + λ*y) / (z + λ*t) (mod r) + // + // in-circuit we check that R - [s]Q = 0 or equivalently R + [-s]Q = 0 + // so here we use k = -s. 
+ k := new(big.Int).Neg(emuInputs[0]) + k.Mod(k, moduli[0]) + rc := lattice.NewReconstructor(moduli[0]).SetLambda(emuInputs[1]) + res := rc.RationalReconstructExt(k) + x, y, z, t := res[0], res[1], res[2], res[3] + + // u1 = x, u2 = y, v1 = z, v2 = t + emuOutputs[0].Abs(x) + emuOutputs[1].Abs(y) + emuOutputs[2].Abs(z) + emuOutputs[3].Abs(t) + + // signs + nativeOutputs[0].SetUint64(0) + nativeOutputs[1].SetUint64(0) + nativeOutputs[2].SetUint64(0) + nativeOutputs[3].SetUint64(0) + + if x.Sign() < 0 { + nativeOutputs[0].SetUint64(1) + } + if y.Sign() < 0 { + nativeOutputs[1].SetUint64(1) + } + if z.Sign() < 0 { + nativeOutputs[2].SetUint64(1) + } + if t.Sign() < 0 { + nativeOutputs[3].SetUint64(1) + } + return nil + }) +} diff --git a/std/algebra/emulated/sw_bw6761/g2.go b/std/algebra/emulated/sw_bw6761/g2.go index a83e0b307c..17e42b0108 100644 --- a/std/algebra/emulated/sw_bw6761/g2.go +++ b/std/algebra/emulated/sw_bw6761/g2.go @@ -6,6 +6,7 @@ import ( bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" ) @@ -59,8 +60,16 @@ func NewG2AffineFixedPlaceholder() G2Affine { } type G2 struct { + api frontend.API curveF *emulated.Field[BaseField] + fr *emulated.Field[ScalarField] w *emulated.Element[BaseField] + // GLV eigenvalue for endomorphism + eigenvalue *emulated.Element[ScalarField] + + // Precomputed G2 generator and its multiple for GLV+FakeGLV + g2Gen *g2AffP // G2 generator + g2GenNbits *g2AffP // [2^(nbits-1)]G2 where nbits = (r.BitLen()+3)/4 + 2 } func NewG2(api frontend.API) (*G2, error) { @@ -68,10 +77,36 @@ func NewG2(api frontend.API) (*G2, error) { if err != nil { return nil, fmt.Errorf("new base api: %w", err) } + fr, err := emulated.NewField[ScalarField](api) + if err != nil { + return nil, fmt.Errorf("new scalar api: %w", err) + } + // w = thirdRootOneG2 
= thirdRootOneG1^2 (used for GLV endomorphism) w := ba.NewElement("4922464560225523242118178942575080391082002530232324381063048548642823052024664478336818169867474395270858391911405337707247735739826664939444490469542109391530482826728203582549674992333383150446779312029624171857054392282775648") + // GLV eigenvalue: lambda such that phi(P) = [lambda]P + eigenvalue := fr.NewElement("80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410945") + + // Precomputed G2 generator for GLV+FakeGLV + g2Gen := &g2AffP{ + X: *ba.NewElement("6445332910596979336035888152774071626898886139774101364933948236926875073754470830732273879639675437155036544153105017729592600560631678554299562762294743927912429096636156401171909259073181112518725201388196280039960074422214428"), + Y: *ba.NewElement("562923658089539719386922163444547387757586534741080263946953401595155211934630598999300396317104182598044793758153214972605680357108252243146746187917218885078195819486220416605630144001533548163105316661692978285266378674355041"), + } + // [2^(nbits-1)]G2 where nbits = (377+3)/4 + 2 = 97, so this is [2^96]G2 + // The loop does nbits-1 doublings, so the generator accumulates to [2^(nbits-1)]G2 + g2GenNbits := &g2AffP{ + X: *ba.NewElement("3095984673093732516312387265169694060996602327701627003095800025572039633257324043941471095859774515229409057356532230556857309141882262691503434703676863345821048055421798431014967860961114720963410640620563703233324706890355614"), + Y: *ba.NewElement("6717446314608317454056612988521276523143603352262745009529835803932138303462642316467740443074785130100608444461459148229179290796669940701932233012187852232981798195344309857014515889020782044489099447799956729215609170567055537"), + } + return &G2{ - curveF: ba, - w: w, + api: api, + curveF: ba, + fr: fr, + w: w, + eigenvalue: eigenvalue, + // GLV+FakeGLV precomputed values + g2Gen: g2Gen, + g2GenNbits: g2GenNbits, }, nil } @@ -236,3 +271,283 @@ func (g2 *G2) 
AssertIsEqual(p, q *G2Affine) { g2.curveF.AssertIsEqual(&p.P.X, &q.P.X) g2.curveF.AssertIsEqual(&p.P.Y, &q.P.Y) } + +// Select selects between p and q given the selector b. If b == 1, then returns +// p and q otherwise. +func (g2 *G2) Select(b frontend.Variable, p, q *G2Affine) *G2Affine { + x := g2.curveF.Select(b, &p.P.X, &q.P.X) + y := g2.curveF.Select(b, &p.P.Y, &q.P.Y) + return &G2Affine{ + P: g2AffP{X: *x, Y: *y}, + Lines: nil, + } +} + +// ScalarMul computes [s]Q using an efficient endomorphism and returns it. It doesn't modify Q nor s. +// It implements the GLV+fakeGLV optimization from [EEMP25] which achieves r^(1/4) bounds +// on the sub-scalars, reducing the number of iterations in the scalar multiplication loop. +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +// (0,0) is not on the curve but we conventionally take it as the +// neutral/infinity point as per the [EVM]. +// +// [EEMP25]: https://eprint.iacr.org/2025/933 +// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf +func (g2 *G2) ScalarMul(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { + return g2.scalarMulGLVAndFakeGLV(Q, s, opts...) +} + +// scalarMulGLVAndFakeGLV computes [s]Q using GLV+fakeGLV with r^(1/4) bounds. +// It implements the "GLV + fake GLV" explained in [EEMP25] (Sec. 3.3). +// +// ⚠️ The scalar s must be nonzero and the point Q different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +// +// [EEMP25]: https://eprint.iacr.org/2025/933 +func (g2 *G2) scalarMulGLVAndFakeGLV(Q *G2Affine, s *Scalar, opts ...algopts.AlgebraOption) *G2Affine { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + + // handle 0-scalar + var selector0 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + one := g2.fr.One() + selector0 = g2.fr.IsZero(s) + _s = g2.fr.Select(selector0, one, s) + } + + // Instead of computing [s]Q=R, we check that R-[s]Q == 0. 
+ // This is equivalent to [v]R + [-s*v]Q = 0 for some nonzero v. + // + // Using LLL-based lattice reduction we find small sub-scalars: + // [v1 + λ*v2]R + [u1 + λ*u2]Q = 0 + // [v1]R + [v2]Φ(R) + [u1]Q + [u2]Φ(Q) = 0 + // + // where u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL). + + // decompose s into u1, u2, v1, v2 + signs, sd, err := g2.fr.NewHintGeneric(rationalReconstructExtG2, 4, 4, nil, []*emulated.Element[ScalarField]{_s, g2.eigenvalue}) + if err != nil { + panic(fmt.Sprintf("rationalReconstructExtG2 hint: %v", err)) + } + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] + + // Check that: s*(v1 + λ*v2) + u1 + λ*u2 = 0 + var st ScalarField + sv1 := g2.fr.Mul(_s, v1) + sλv2 := g2.fr.Mul(_s, g2.fr.Mul(g2.eigenvalue, v2)) + λu2 := g2.fr.Mul(g2.eigenvalue, u2) + zero := g2.fr.Zero() + + lhs1 := g2.fr.Select(isNegv1, zero, sv1) + lhs2 := g2.fr.Select(isNegv2, zero, sλv2) + lhs3 := g2.fr.Select(isNegu1, zero, u1) + lhs4 := g2.fr.Select(isNegu2, zero, λu2) + lhs := g2.fr.Add( + g2.fr.Add(lhs1, lhs2), + g2.fr.Add(lhs3, lhs4), + ) + + rhs1 := g2.fr.Select(isNegv1, sv1, zero) + rhs2 := g2.fr.Select(isNegv2, sλv2, zero) + rhs3 := g2.fr.Select(isNegu1, u1, zero) + rhs4 := g2.fr.Select(isNegu2, λu2, zero) + rhs := g2.fr.Add( + g2.fr.Add(rhs1, rhs2), + g2.fr.Add(rhs3, rhs4), + ) + + g2.fr.AssertIsEqual(lhs, rhs) + + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + den := g2.fr.Add(v1, g2.fr.Mul(g2.eigenvalue, v2)) + g2.fr.AssertIsDifferent(den, g2.fr.Zero()) + + // Hint the scalar multiplication R = [s]Q + _, point, _, err := emulated.NewVarGenericHint(g2.api, 0, 2, 0, nil, + []*emulated.Element[BaseField]{&Q.P.X, &Q.P.Y}, + []*emulated.Element[ScalarField]{s}, + scalarMulG2Hint) + if err != nil { + panic(fmt.Sprintf("scalarMulG2Hint: %v", err)) + } + R := &G2Affine{ + P: g2AffP{ + X: *point[0], + Y: *point[1], + }, + } + // Preserve the 
original hinted R for return value (before edge-case modifications) + originalR := R + + // handle (0,0)-point and edge cases + var _selector0, _selector1 frontend.Variable + _Q := Q + if cfg.CompleteArithmetic { + one := g2.curveF.One() + // if Q=(0,0) we assign a dummy point + _selector0 = g2.api.And(g2.curveF.IsZero(&Q.P.X), g2.curveF.IsZero(&Q.P.Y)) + _Q = g2.Select(_selector0, &G2Affine{P: g2AffP{X: *one, Y: *one}}, Q) + // if R.X == Q.X (happens when s=±1, so R=±Q), the incomplete addition fails + // We check this BEFORE potentially modifying R + _selector1 = g2.curveF.IsZero(g2.curveF.Sub(&Q.P.X, &R.P.X)) + // if s=0/s=-1 (selector0), Q=(0,0) (_selector0), or R.X==Q.X (_selector1), + // we assign a dummy point to R + selectorAny := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + R = g2.Select(selectorAny, &G2Affine{P: g2AffP{X: *one, Y: *one}}, R) + } + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]*G2Affine + negQY := g2.curveF.Neg(&_Q.P.Y) + tableQ[1] = &G2Affine{ + P: g2AffP{ + X: _Q.P.X, + Y: *g2.curveF.Select(isNegu1, negQY, &_Q.P.Y), + }, + } + tableQ[0] = g2.neg(tableQ[1]) + // For BW6-761 G2, phi(Q) = (w * Q.X, Q.Y) + tablePhiQ[1] = &G2Affine{ + P: g2AffP{ + X: *g2.curveF.Mul(&_Q.P.X, g2.w), + Y: *g2.curveF.Select(isNegu2, negQY, &_Q.P.Y), + }, + } + tablePhiQ[0] = g2.neg(tablePhiQ[1]) + + // precompute -R, -Φ(R), Φ(R) + var tableR, tablePhiR [2]*G2Affine + negRY := g2.curveF.Neg(&R.P.Y) + tableR[1] = &G2Affine{ + P: g2AffP{ + X: R.P.X, + Y: *g2.curveF.Select(isNegv1, negRY, &R.P.Y), + }, + } + tableR[0] = g2.neg(tableR[1]) + tablePhiR[1] = &G2Affine{ + P: g2AffP{ + X: *g2.curveF.Mul(&R.P.X, g2.w), + Y: *g2.curveF.Select(isNegv2, negRY, &R.P.Y), + }, + } + tablePhiR[0] = g2.neg(tablePhiR[1]) + + // precompute -Q-R, Q+R, Q-R, -Q+R (combining the two points Q and R) + var tableS [4]*G2Affine + tableS[0] = g2.add(tableQ[0], tableR[0]) // -Q - R + tableS[1] = g2.neg(tableS[0]) // Q + R + tableS[2] = g2.add(tableQ[1], tableR[0]) // 
Q - R + tableS[3] = g2.neg(tableS[2]) // -Q + R + + // precompute -Φ(Q)-Φ(R), Φ(Q)+Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R) (combining endomorphisms) + var tablePhiS [4]*G2Affine + tablePhiS[0] = g2.add(tablePhiQ[0], tablePhiR[0]) // -Φ(Q) - Φ(R) + tablePhiS[1] = g2.neg(tablePhiS[0]) // Φ(Q) + Φ(R) + tablePhiS[2] = g2.add(tablePhiQ[1], tablePhiR[0]) // Φ(Q) - Φ(R) + tablePhiS[3] = g2.neg(tablePhiS[2]) // -Φ(Q) + Φ(R) + + // Acc = Q + Φ(Q) + R + Φ(R) + Acc := g2.add(tableS[1], tablePhiS[1]) + B1 := Acc + + // Add G2 generator to Acc to avoid incomplete additions in the loop. + // At the end, since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0, + // Acc will equal [2^nbits]G2 (precomputed). + g2GenPoint := &G2Affine{P: *g2.g2Gen} + Acc = g2.add(Acc, g2GenPoint) + + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 + nbits := (st.Modulus().BitLen()+3)/4 + 2 + u1bits := g2.fr.ToBits(u1) + u2bits := g2.fr.ToBits(u2) + v1bits := g2.fr.ToBits(v1) + v2bits := g2.fr.ToBits(v2) + + // Precompute all 16 combinations: ±Q ± Φ(Q) ± R ± Φ(R) + // Using tableS (Q±R) and tablePhiS (Φ(Q)±Φ(R)) to match G1 pattern + // B1 = (Q+R) + (Φ(Q)+Φ(R)) = Q + R + Φ(Q) + Φ(R) + B2 := g2.add(tableS[1], tablePhiS[2]) // (Q+R) + (Φ(Q)-Φ(R)) = Q + R + Φ(Q) - Φ(R) + B3 := g2.add(tableS[1], tablePhiS[3]) // (Q+R) + (-Φ(Q)+Φ(R)) = Q + R - Φ(Q) + Φ(R) + B4 := g2.add(tableS[1], tablePhiS[0]) // (Q+R) + (-Φ(Q)-Φ(R)) = Q + R - Φ(Q) - Φ(R) + B5 := g2.add(tableS[2], tablePhiS[1]) // (Q-R) + (Φ(Q)+Φ(R)) = Q - R + Φ(Q) + Φ(R) + B6 := g2.add(tableS[2], tablePhiS[2]) // (Q-R) + (Φ(Q)-Φ(R)) = Q - R + Φ(Q) - Φ(R) + B7 := g2.add(tableS[2], tablePhiS[3]) // (Q-R) + (-Φ(Q)+Φ(R)) = Q - R - Φ(Q) + Φ(R) + B8 := g2.add(tableS[2], tablePhiS[0]) // (Q-R) + (-Φ(Q)-Φ(R)) = Q - R - Φ(Q) - Φ(R) + B9 := g2.neg(B8) // -Q + R + Φ(Q) + Φ(R) + B10 := g2.neg(B7) // -Q + R + Φ(Q) - Φ(R) + B11 := g2.neg(B6) // -Q + R - Φ(Q) + Φ(R) + B12 := g2.neg(B5) // -Q + R - Φ(Q) - Φ(R) + B13 := g2.neg(B4) // -Q - R + Φ(Q) + Φ(R) + B14 := g2.neg(B3) // -Q - R + Φ(Q) - 
Φ(R) + B15 := g2.neg(B2) // -Q - R - Φ(Q) + Φ(R) + B16 := g2.neg(B1) // -Q - R - Φ(Q) - Φ(R) + + var Bi *G2Affine + for i := nbits - 1; i > 0; i-- { + // selectorY takes values in [0,15] + selectorY := g2.api.Add( + u1bits[i], + g2.api.Mul(u2bits[i], 2), + g2.api.Mul(v1bits[i], 4), + g2.api.Mul(v2bits[i], 8), + ) + // selectorX takes values in [0,7] s.t.: + // - when selectorY < 8: selectorX = selectorY + // - when selectorY >= 8: selectorX = 15 - selectorY + selectorX := g2.api.Add( + g2.api.Mul(selectorY, g2.api.Sub(1, g2.api.Mul(v2bits[i], 2))), + g2.api.Mul(v2bits[i], 15), + ) + + // Bi.Y are distinct so we need a 16-to-1 multiplexer, + // but only half of the Bi.X are distinct so we need an 8-to-1. + Bi = &G2Affine{ + P: g2AffP{ + X: *g2.curveF.Mux(selectorX, + &B16.P.X, &B8.P.X, &B14.P.X, &B6.P.X, &B12.P.X, &B4.P.X, &B10.P.X, &B2.P.X, + ), + Y: *g2.curveF.Mux(selectorY, + &B16.P.Y, &B8.P.Y, &B14.P.Y, &B6.P.Y, &B12.P.Y, &B4.P.Y, &B10.P.Y, &B2.P.Y, + &B15.P.Y, &B7.P.Y, &B13.P.Y, &B5.P.Y, &B11.P.Y, &B3.P.Y, &B9.P.Y, &B1.P.Y, + ), + }, + } + // Acc = [2]Acc + Bi + Acc = g2.doubleAndAdd(Acc, Bi) + } + + // i = 0: subtract Q, Φ(Q), R, Φ(R) if the first bits are 0 + tableQ[0] = g2.add(tableQ[0], Acc) + Acc = g2.Select(u1bits[0], Acc, tableQ[0]) + tablePhiQ[0] = g2.add(tablePhiQ[0], Acc) + Acc = g2.Select(u2bits[0], Acc, tablePhiQ[0]) + tableR[0] = g2.add(tableR[0], Acc) + Acc = g2.Select(v1bits[0], Acc, tableR[0]) + tablePhiR[0] = g2.add(tablePhiR[0], Acc) + Acc = g2.Select(v2bits[0], Acc, tablePhiR[0]) + + // Acc should now be [2^(nbits-1)]G2 since [u1]Q + [u2]Φ(Q) + [v1]R + [v2]Φ(R) = 0 + // and we added G2 to the initial accumulator. 
+ expected := &G2Affine{P: *g2.g2GenNbits} + + if cfg.CompleteArithmetic { + // if Q=(0,0), s=0, or R.X==Q.X, skip the check + skip := g2.api.Or(g2.api.Or(selector0, _selector0), _selector1) + Acc = g2.Select(skip, expected, Acc) + } + g2.AssertIsEqual(Acc, expected) + + if cfg.CompleteArithmetic { + // if s=0 or Q=(0,0), return (0,0); otherwise return the original hinted R + zeroEl := g2.curveF.Zero() + returnZero := g2.api.Or(selector0, _selector0) + return g2.Select(returnZero, &G2Affine{P: g2AffP{X: *zeroEl, Y: *zeroEl}}, originalR) + } + + return R +} diff --git a/std/algebra/emulated/sw_bw6761/g2_test.go b/std/algebra/emulated/sw_bw6761/g2_test.go new file mode 100644 index 0000000000..66f54b520d --- /dev/null +++ b/std/algebra/emulated/sw_bw6761/g2_test.go @@ -0,0 +1,142 @@ +package sw_bw6761 + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/algopts" + "github.com/consensys/gnark/test" +) + +type scalarMulG2GLVAndFakeGLVCircuit struct { + In G2Affine + Res G2Affine + S Scalar +} + +func (c *scalarMulG2GLVAndFakeGLVCircuit) Define(api frontend.API) error { + g2, err := NewG2(api) + if err != nil { + return err + } + res := g2.ScalarMul(&c.In, &c.S) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +func TestScalarMulG2GLVAndFakeGLV(t *testing.T) { + assert := test.NewAssert(t) + // Use a fixed scalar for reproducibility + s := big.NewInt(12345) + var sFr fr.Element + sFr.SetBigInt(s) + + _, in1 := randomG1G2Affines() + var res bw6761.G2Affine + res.ScalarMultiplication(&in1, s) + + witness := scalarMulG2GLVAndFakeGLVCircuit{ + In: NewG2Affine(in1), + S: NewScalar(sFr), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2GLVAndFakeGLVCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + 
+func TestScalarMulG2GLVAndFakeGLVRandom(t *testing.T) { + assert := test.NewAssert(t) + // Use a random scalar + s, _ := rand.Int(rand.Reader, fr.Modulus()) + var sFr fr.Element + sFr.SetBigInt(s) + + _, in1 := randomG1G2Affines() + var res bw6761.G2Affine + res.ScalarMultiplication(&in1, s) + + witness := scalarMulG2GLVAndFakeGLVCircuit{ + In: NewG2Affine(in1), + S: NewScalar(sFr), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2GLVAndFakeGLVCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +// Circuit for testing G2 scalar multiplication with complete arithmetic (handles edge cases) +type scalarMulG2CompleteCircuit struct { + In, Res G2Affine + S Scalar +} + +func (c *scalarMulG2CompleteCircuit) Define(api frontend.API) error { + g2, err := NewG2(api) + if err != nil { + return fmt.Errorf("new G2 struct: %w", err) + } + res := g2.scalarMulGLVAndFakeGLV(&c.In, &c.S, algopts.WithCompleteArithmetic()) + g2.AssertIsEqual(res, &c.Res) + return nil +} + +// TestScalarMulG2EdgeCases tests edge cases: s=0, s=1, s=-1, Q=(0,0) +func TestScalarMulG2EdgeCases(t *testing.T) { + assert := test.NewAssert(t) + _, _, _, gen := bw6761.Generators() + + // Test case: s = 1 (result should be Q) + t.Run("s=1", func(t *testing.T) { + var s fr.Element + s.SetOne() + var res bw6761.G2Affine + res.Set(&gen) // [1]Q = Q + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + + // Test case: s = -1 (result should be -Q) + t.Run("s=-1", func(t *testing.T) { + var s fr.Element + s.SetOne() + s.Neg(&s) // s = -1 + var res bw6761.G2Affine + res.Neg(&gen) // [-1]Q = -Q + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + 
assert.NoError(err) + }) + + // Test case: s = 0 (result should be (0,0)) + t.Run("s=0", func(t *testing.T) { + var s fr.Element + s.SetZero() + var res bw6761.G2Affine // zero value is (0,0) + + witness := scalarMulG2CompleteCircuit{ + In: NewG2Affine(gen), + S: NewScalar(s), + Res: NewG2Affine(res), + } + err := test.IsSolved(&scalarMulG2CompleteCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + }) + +} diff --git a/std/algebra/emulated/sw_bw6761/hints.go b/std/algebra/emulated/sw_bw6761/hints.go index 7966519344..d9c777b673 100644 --- a/std/algebra/emulated/sw_bw6761/hints.go +++ b/std/algebra/emulated/sw_bw6761/hints.go @@ -1,8 +1,11 @@ package sw_bw6761 import ( + "fmt" "math/big" + "github.com/consensys/gnark-crypto/algebra/lattice" + "github.com/consensys/gnark-crypto/ecc" bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/std/math/emulated" @@ -17,6 +20,9 @@ func GetHints() []solver.Hint { return []solver.Hint{ finalExpHint, pairingCheckHint, + decomposeScalarG1, + scalarMulG2Hint, + rationalReconstructExtG2, } } @@ -109,3 +115,133 @@ func finalExpWitness(millerLoop *bw6761.E6, mInv *big.Int) (residueWitness bw676 return residueWitness } + +func decomposeScalarG1(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 1 { + return fmt.Errorf("expecting one moduli, got %d", len(moduli)) + } + _, nativeOutputs := hc.NativeInputsOutputs() + if len(nativeOutputs) != 2 { + return fmt.Errorf("expecting two outputs, got %d", len(nativeOutputs)) + } + emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) + if len(emuInputs) != 2 { + return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) + } + if len(emuOutputs) != 2 { + return fmt.Errorf("expecting two outputs, got %d", len(emuOutputs)) + } + + glvBasis := 
new(ecc.Lattice) + ecc.PrecomputeLattice(moduli[0], emuInputs[1], glvBasis) + sp := ecc.SplitScalar(emuInputs[0], glvBasis) + emuOutputs[0].Set(&sp[0]) + emuOutputs[1].Set(&sp[1]) + nativeOutputs[0].SetUint64(0) + nativeOutputs[1].SetUint64(0) + // we need the absolute values for the in-circuit computations, + // otherwise the negative values will be reduced modulo the SNARK scalar + // field and not the emulated field. + // output0 = |s0| mod r + // output1 = |s1| mod r + if emuOutputs[0].Sign() == -1 { + emuOutputs[0].Neg(emuOutputs[0]) + nativeOutputs[0].SetUint64(1) + } + if emuOutputs[1].Sign() == -1 { + emuOutputs[1].Neg(emuOutputs[1]) + nativeOutputs[1].SetUint64(1) + } + + return nil + }) +} + +func scalarMulG2Hint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(field, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() + if len(moduli) != 2 { + return fmt.Errorf("expecting two moduli, got %d", len(moduli)) + } + baseModulus, scalarModulus := moduli[0], moduli[1] + baseInputs, baseOutputs := hc.InputsOutputs(baseModulus) + scalarInputs, _ := hc.InputsOutputs(scalarModulus) + if len(baseInputs) != 2 { + return fmt.Errorf("expecting two base inputs (Q.X, Q.Y), got %d", len(baseInputs)) + } + if len(baseOutputs) != 2 { + return fmt.Errorf("expecting two base outputs, got %d", len(baseOutputs)) + } + if len(scalarInputs) != 1 { + return fmt.Errorf("expecting one scalar input, got %d", len(scalarInputs)) + } + + // compute the resulting point [s]Q on G2 + var Q bw6761.G2Affine + Q.X.SetBigInt(baseInputs[0]) + Q.Y.SetBigInt(baseInputs[1]) + Q.ScalarMultiplication(&Q, scalarInputs[0]) + Q.X.BigInt(baseOutputs[0]) + Q.Y.BigInt(baseOutputs[1]) + return nil + }) +} + +func rationalReconstructExtG2(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { + moduli := hc.EmulatedModuli() 
+ if len(moduli) != 1 { + return fmt.Errorf("expecting one modulus, got %d", len(moduli)) + } + _, nativeOutputs := hc.NativeInputsOutputs() + if len(nativeOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(nativeOutputs)) + } + emuInputs, emuOutputs := hc.InputsOutputs(moduli[0]) + if len(emuInputs) != 2 { + return fmt.Errorf("expecting two inputs, got %d", len(emuInputs)) + } + if len(emuOutputs) != 4 { + return fmt.Errorf("expecting four outputs, got %d", len(emuOutputs)) + } + + // Use lattice reduction to find (x, y, z, t) such that + // k ≡ (x + λ*y) / (z + λ*t) (mod r) + // + // in-circuit we check that R - [s]Q = 0 or equivalently R + [-s]Q = 0 + // so here we use k = -s. + k := new(big.Int).Neg(emuInputs[0]) + k.Mod(k, moduli[0]) + rc := lattice.NewReconstructor(moduli[0]).SetLambda(emuInputs[1]) + res := rc.RationalReconstructExt(k) + x, y, z, t := res[0], res[1], res[2], res[3] + + // u1 = x, u2 = y, v1 = z, v2 = t + emuOutputs[0].Abs(x) + emuOutputs[1].Abs(y) + emuOutputs[2].Abs(z) + emuOutputs[3].Abs(t) + + // signs + nativeOutputs[0].SetUint64(0) + nativeOutputs[1].SetUint64(0) + nativeOutputs[2].SetUint64(0) + nativeOutputs[3].SetUint64(0) + + if x.Sign() < 0 { + nativeOutputs[0].SetUint64(1) + } + if y.Sign() < 0 { + nativeOutputs[1].SetUint64(1) + } + if z.Sign() < 0 { + nativeOutputs[2].SetUint64(1) + } + if t.Sign() < 0 { + nativeOutputs[3].SetUint64(1) + } + return nil + }) +} diff --git a/std/algebra/emulated/sw_emulated/hints.go b/std/algebra/emulated/sw_emulated/hints.go index b6b42cffe3..a219886d37 100644 --- a/std/algebra/emulated/sw_emulated/hints.go +++ b/std/algebra/emulated/sw_emulated/hints.go @@ -6,7 +6,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark-crypto/algebra/eisenstein" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" bls12381_fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" @@ -32,8 
+32,8 @@ func GetHints() []solver.Hint { return []solver.Hint{ decomposeScalarG1, scalarMulHint, - halfGCD, - halfGCDEisenstein, + rationalReconstruct, + rationalReconstructExt, } } @@ -160,7 +160,7 @@ func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error }) } -func halfGCD(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +func rationalReconstruct(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { moduli := hc.EmulatedModuli() if len(moduli) != 1 { @@ -177,25 +177,38 @@ func halfGCD(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(emuOutputs) != 2 { return fmt.Errorf("expecting two outputs, got %d", len(emuOutputs)) } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[0], glvBasis) - emuOutputs[0].Set(&glvBasis.V1[0]) - emuOutputs[1].Set(&glvBasis.V1[1]) + // Use lattice reduction to find (x, z) such that s ≡ x/z (mod r), + // i.e., x - s*z ≡ 0 (mod r), or equivalently x + s*(-z) ≡ 0 (mod r). + // The circuit checks: s1 + s*_s2 ≡ 0 (mod r) + // So we need s1 = x and _s2 = -z. + rc := lattice.NewReconstructor(moduli[0]) + res := rc.RationalReconstruct(emuInputs[0]) + x, z := res[0], res[1] + + // Ensure x is non-negative (the circuit bit-decomposes s1 assuming it's small positive). + // If x < 0, flip signs: (x, z) -> (-x, -z), which preserves s = x/z. + if x.Sign() < 0 { + x.Neg(x) + z.Neg(z) + } + + emuOutputs[0].Set(x) + emuOutputs[1].Abs(z) + // we need the absolute values for the in-circuit computations, // otherwise the negative values will be reduced modulo the SNARK scalar // field and not the emulated field. - // output0 = |s0| mod r - // output1 = |s1| mod r + // The sign indicates whether to negate s2 in circuit to get -z. 
+ // sign = 1 when z > 0 (so -z < 0, and we need to negate |z| to get -z) nativeOutputs[0].SetUint64(0) - if emuOutputs[1].Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) - nativeOutputs[0].SetUint64(1) // we return the sign of the second subscalar + if z.Sign() > 0 { + nativeOutputs[0].SetUint64(1) } return nil }) } -func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +func rationalReconstructExt(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { moduli := hc.EmulatedModuli() if len(moduli) != 1 { @@ -213,47 +226,54 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro return fmt.Errorf("expecting four outputs, got %d", len(emuOutputs)) } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[1], glvBasis) - r := eisenstein.ComplexNumber{ - A0: glvBasis.V1[0], - A1: glvBasis.V1[1], - } - sp := ecc.SplitScalar(emuInputs[0], glvBasis) + // Use lattice reduction to find (x, y, z, t) such that + // k ≡ (x + λ*y) / (z + λ*t) (mod r) + // // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 - // so here we return -s instead of s. - s := eisenstein.ComplexNumber{ - A0: sp[0], - A1: sp[1], - } - s.Neg(&s) + // so here we use k = -s. 
+ // + // With k = -s: + // -s ≡ (x + λ*y) / (z + λ*t) (mod r) + // s ≡ -(x + λ*y) / (z + λ*t) = (-x - λ*y) / (z + λ*t) (mod r) + // + // The circuit checks: s*(v1 + λ*v2) + u1 + λ*u2 ≡ 0 (mod r) + // Rearranging: s ≡ -(u1 + λ*u2) / (v1 + λ*v2) (mod r) + // + // Matching: (-x - λ*y) = -(u1 + λ*u2) + // So: u1 = x, u2 = y, v1 = z, v2 = t + k := new(big.Int).Neg(emuInputs[0]) + k.Mod(k, moduli[0]) + rc := lattice.NewReconstructor(moduli[0]).SetLambda(emuInputs[1]) + res := rc.RationalReconstructExt(k) + x, y, z, t := res[0], res[1], res[2], res[3] + + // u1 = x, u2 = y, v1 = z, v2 = t + // We return absolute values and track signs + emuOutputs[0].Abs(x) // |u1| = |x| + emuOutputs[1].Abs(y) // |u2| = |y| + emuOutputs[2].Abs(z) // |v1| = |z| + emuOutputs[3].Abs(t) // |v2| = |t| - res := eisenstein.HalfGCD(&r, &s) - // values - emuOutputs[0].Set(&res[0].A0) - emuOutputs[1].Set(&res[0].A1) - emuOutputs[2].Set(&res[1].A0) - emuOutputs[3].Set(&res[1].A1) // signs - nativeOutputs[0].SetUint64(0) - nativeOutputs[1].SetUint64(0) - nativeOutputs[2].SetUint64(0) - nativeOutputs[3].SetUint64(0) + nativeOutputs[0].SetUint64(0) // isNegu1 + nativeOutputs[1].SetUint64(0) // isNegu2 + nativeOutputs[2].SetUint64(0) // isNegv1 + nativeOutputs[3].SetUint64(0) // isNegv2 - if res[0].A0.Sign() == -1 { - emuOutputs[0].Neg(emuOutputs[0]) + // u1 = x is negative when x < 0 + if x.Sign() < 0 { nativeOutputs[0].SetUint64(1) } - if res[0].A1.Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) + // u2 = y is negative when y < 0 + if y.Sign() < 0 { nativeOutputs[1].SetUint64(1) } - if res[1].A0.Sign() == -1 { - emuOutputs[2].Neg(emuOutputs[2]) + // v1 = z is negative when z < 0 + if z.Sign() < 0 { nativeOutputs[2].SetUint64(1) } - if res[1].A1.Sign() == -1 { - emuOutputs[3].Neg(emuOutputs[3]) + // v2 = t is negative when t < 0 + if t.Sign() < 0 { nativeOutputs[3].SetUint64(1) } return nil diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 
cc469a1650..2ba0b7589f 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -846,11 +846,27 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El // jointScalarMulFakeGLV computes [s1]p1 + [s2]p2. It doesn't modify p1, p2 nor s1, s2. // +// For non-GLV curves, using two separate ScalarMul calls with the 2D half-GCD +// decomposition (r^(1/2) sub-scalars) is more efficient than a 3D lattice approach +// which produces r^(2/3) sub-scalars. Constraint comparison for P-256: +// - Two ScalarMul + Add: ~152k constraints (r^(1/2) ≈ 128 bits) +// - 3D Lattice: ~245k constraints (r^(2/3) ≈ 171 bits) +// - Shamir's trick: ~221k constraints (256 bits) +// // ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set. func (c *Curve[B, S]) jointScalarMulFakeGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { - sm1 := c.scalarMulFakeGLV(p1, s1, opts...) - sm2 := c.scalarMulFakeGLV(p2, s2, opts...) - return c.AddUnified(sm1, sm2) + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + // Two separate ScalarMul with fakeGLV (r^(1/2) decomposition) + Add is most efficient. + // Each ScalarMul uses half-GCD to decompose the scalar into ~128-bit sub-scalars. + r1 := c.ScalarMul(p1, s1, opts...) + r2 := c.ScalarMul(p2, s2, opts...) + if cfg.CompleteArithmetic { + return c.AddUnified(r1, r2) + } + return c.Add(r1, r2) } // jointScalarMulGenericUnsafe computes [s1]p1 + [s2]p2 using Shamir's trick and returns it. It doesn't modify p1, p2 nor s1, s2. 
@@ -901,6 +917,8 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated panic(fmt.Sprintf("parse opts: %v", err)) } if cfg.CompleteArithmetic { + // Fall back to two GLV+FakeGLV scalar muls combined with unified addition + // This handles edge cases: zero scalars, zero points res1 := c.scalarMulGLVAndFakeGLV(p1, s1, opts...) res2 := c.scalarMulGLVAndFakeGLV(p2, s2, opts...) return c.AddUnified(res1, res2) } @@ -1229,18 +1247,20 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] panic(err) } - var selector1 frontend.Variable + // Handle edge cases for complete arithmetic: s=0, Q=(0,0) + var selector0 frontend.Variable _s := s if cfg.CompleteArithmetic { - selector1 = c.scalarApi.IsZero(s) - _s = c.scalarApi.Select(selector1, c.scalarApi.One(), s) + one := c.scalarApi.One() + selector0 = c.scalarApi.IsZero(s) + _s = c.scalarApi.Select(selector0, one, s) } // First we find the sub-salars s1, s2 s.t. s1 + s2*s = 0 mod r and s1, s2 < sqrt(r). // we also output the sign in case s2 is negative. In that case we compute _s2 = -s2 mod r.
- sign, sd, err := c.scalarApi.NewHintGeneric(halfGCD, 1, 2, nil, []*emulated.Element[S]{_s}) + sign, sd, err := c.scalarApi.NewHintGeneric(rationalReconstruct, 1, 2, nil, []*emulated.Element[S]{_s}) if err != nil { - panic(fmt.Sprintf("halfGCD hint: %v", err)) + panic(fmt.Sprintf("rationalReconstruct hint: %v", err)) } s1, s2 := sd[0], sd[1] _s2 := c.scalarApi.Select(sign[0], c.scalarApi.Neg(s2), s2) @@ -1262,17 +1282,29 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] } r0, r1 := R[0], R[1] - var selector2 frontend.Variable - one := c.baseApi.One() - dummy := &AffinePoint[B]{X: *one, Y: *one} - addFn := c.Add + // Handle Q=(0,0), s=0/s=-1, s=±1 (where R=±Q), and s=±3 (where R=±[3]Q) + // for complete arithmetic + var selector1, selector2, selector3 frontend.Variable + _Q := Q if cfg.CompleteArithmetic { - addFn = c.AddUnified - // if Q=(0,0) we assign a dummy (1,1) to Q and R and continue - selector2 = c.api.And(c.baseApi.IsZero(&Q.X), c.baseApi.IsZero(&Q.Y)) - Q = c.Select(selector2, dummy, Q) - r0 = c.baseApi.Select(selector2, c.baseApi.Zero(), r0) - r1 = c.baseApi.Select(selector2, &dummy.Y, r1) + // Use different dummy points for _Q and R to avoid _Q == ±R + dummyQ := c.Generator() + dummyR := &c.GeneratorMultiples()[3] // 8*G, different from G + // selector1: Q=(0,0) + selector1 = c.api.And(c.baseApi.IsZero(&Q.X), c.baseApi.IsZero(&Q.Y)) + _Q = c.Select(selector1, dummyQ, Q) + // selector2: R.X == Q.X (happens when s=±1, so R=±Q and Add would fail) + selector2 = c.baseApi.IsZero(c.baseApi.Sub(&Q.X, r0)) + // selector3: R.X == [3]Q.X (happens when s=±3, so R=±[3]Q and + // tableQ[2]±tableR[1] would be a doubling or point-at-infinity) + tripleQ := c.triple(_Q) + selector3 = c.baseApi.IsZero(c.baseApi.Sub(&tripleQ.X, r0)) + // When s=0/s=-1 (selector0), Q=(0,0) (selector1), R.X==Q.X (selector2), + // or R.X==[3]Q.X (selector3), the incomplete addition formula fails. + // Use dummy for R in these cases. 
+ selectorAny := c.api.Or(c.api.Or(c.api.Or(selector0, selector1), selector2), selector3) + r0 = c.baseApi.Select(selectorAny, &dummyR.X, r0) + r1 = c.baseApi.Select(selectorAny, &dummyR.Y, r1) } var st S @@ -1288,26 +1320,33 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] // tableR[1] = R or -R if s2 is negative // tableR[2] = [3]R or [-3]R if s2 is negative var tableQ, tableR [3]*AffinePoint[B] - tableQ[1] = Q - tableQ[0] = c.Neg(Q) + tableQ[1] = _Q + tableQ[0] = c.Neg(_Q) tableQ[2] = c.triple(tableQ[1]) tableR[1] = &AffinePoint[B]{ X: *r0, Y: *c.baseApi.Select(sign[0], c.baseApi.Neg(r1), r1), } tableR[0] = c.Neg(tableR[1]) - if cfg.CompleteArithmetic { - tableR[2] = c.AddUnified(tableR[1], tableR[1]) - tableR[2] = c.AddUnified(tableR[2], tableR[1]) - } else { - tableR[2] = c.triple(tableR[1]) - } + tableR[2] = c.triple(tableR[1]) // We should start the accumulator by the infinity point, but since affine // formulae are incomplete we suppose that the first bits of the // sub-scalars s1 and s2 are 1, and set: // Acc = Q + R - Acc := addFn(tableQ[1], tableR[1]) + // For complete arithmetic, we add a bias point G to avoid Acc == ±Bi during the loop. + // T2 = Q + R (without bias, used in table construction) + T2 := c.Add(tableQ[1], tableR[1]) + Acc := T2 + var t2EqNegG frontend.Variable + if cfg.CompleteArithmetic { + g := c.Generator() + // Guard: if T2 == -G (i.e. T2.X == G.X), the incomplete Add would fail. + // In that case, replace T2 with a safe dummy before adding G. + t2EqNegG = c.baseApi.IsZero(c.baseApi.Sub(&T2.X, &g.X)) + safeT2 := c.Select(t2EqNegG, &c.GeneratorMultiples()[1], T2) + Acc = c.Add(safeT2, g) + } // At each iteration we need to compute: // [2]Acc ± Q ± R. 
@@ -1325,16 +1364,13 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] // // T = [3](Q + R) // P = B1 and P' = B1 - t1 := addFn(tableQ[2], tableR[2]) - // T = Q + R - // P = B1 and P' = B2 - T2 := Acc + t1 := c.Add(tableQ[2], tableR[2]) // T = [3]Q + R // P = B1 and P' = B3 - T3 := addFn(tableQ[2], tableR[1]) + T3 := c.Add(tableQ[2], tableR[1]) // T = Q + [3]R // P = B1 and P' = B4 - t4 := addFn(tableQ[1], tableR[2]) + t4 := c.Add(tableQ[1], tableR[2]) // T = -Q - R // P = B2 and P' = B1 T5 := c.Neg(T2) @@ -1347,17 +1383,17 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] // T = -[3]Q - R // T = [3]Q - R // P = B3 and P' = B1 - t9 := addFn(tableQ[2], tableR[0]) + t9 := c.Add(tableQ[2], tableR[0]) // T = Q - [3]R // P = B3 and P' = B2 t := c.Neg(tableR[2]) - T10 := addFn(tableQ[1], t) + T10 := c.Add(tableQ[1], t) // T = [3](Q - R) // P = B3 and P' = B3 - T11 := addFn(tableQ[2], t) + T11 := c.Add(tableQ[2], t) // T = -R + Q // P = B3 and P' = B4 - T12 := addFn(tableR[0], tableQ[1]) + T12 := c.Add(tableR[0], tableQ[1]) // T = R - [3]Q // P = B4 and P' = B2 T14 := c.Neg(t9) @@ -1376,8 +1412,9 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] } // We don't use doubleAndAdd here as it would involve edge cases // when bits are 00 (T==-Acc) or 11 (T==Acc). - Acc = c.doubleGeneric(Acc, cfg.CompleteArithmetic) - Acc = addFn(Acc, T) + // With bias point, we can use regular double and Add. 
+ Acc = c.double(Acc) + Acc = c.Add(Acc, T) } else { // when nbits is odd we start the main loop at normally nbits - 1 nbits++ @@ -1408,8 +1445,9 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] ), } // Acc = [4]Acc + T - Acc = c.doubleGeneric(Acc, cfg.CompleteArithmetic) - Acc = c.doubleAndAddGeneric(Acc, T, cfg.CompleteArithmetic) + // Bias point protects us from incomplete additions + Acc = c.double(Acc) + Acc = c.doubleAndAdd(Acc, T) } // i = 2 @@ -1441,24 +1479,28 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] } // to avoid incomplete additions we add [3]R to the precomputed T before computing [4]Acc+T // Acc = [4]Acc + T + [3]R - T = addFn(T, tableR[2]) - Acc = c.doubleGeneric(Acc, cfg.CompleteArithmetic) - Acc = c.doubleAndAddGeneric(Acc, T, cfg.CompleteArithmetic) + T = c.Add(T, tableR[2]) + Acc = c.double(Acc) + Acc = c.doubleAndAdd(Acc, T) // i = 0 - // subtract Q and R if the first bits are 0. - // When cfg.CompleteArithmetic is set, we use AddUnified instead of Add. - // This means when s=0 then Acc=(0,0) because AddUnified(Q, -Q) = (0,0). 
- tableQ[0] = addFn(tableQ[0], Acc) + // subtract Q and R if the first bits are 0 + tableQ[0] = c.Add(tableQ[0], Acc) Acc = c.Select(s1bits[0], Acc, tableQ[0]) - tableR[0] = addFn(tableR[0], Acc) + tableR[0] = c.Add(tableR[0], Acc) Acc = c.Select(s2bits[0], Acc, tableR[0]) + // For complete arithmetic, subtract the bias [2^nbits]G and handle edge cases if cfg.CompleteArithmetic { - Acc = c.Select(c.api.Or(selector1, selector2), tableR[2], Acc) + gm := c.GeneratorMultiples()[nbits-1] + Acc = c.Add(Acc, c.Neg(&gm)) + // If s=0, Q=(0,0), R.X==Q.X, R.X==[3]Q.X, or T2==-G (bias collision), + // use the precomputed [3]R as a fallback + selectorEdge := c.api.Or(c.api.Or(c.api.Or(selector0, selector1), c.api.Or(selector2, selector3)), t2EqNegG) + Acc = c.Select(selectorEdge, tableR[2], Acc) } // we added [3]R at the last iteration so the result should be - // Acc = [s1]Q + [s2]R + [3]R + // Acc = [s1]Q + [s2]R + [3]R (+ [2^nbits]G - [2^nbits]G for complete arithmetic) // = [s1]Q + [s2*s]Q + [3]R // = [s1+s2*s]Q + [3]R // = [0]Q + [3]R @@ -1491,16 +1533,12 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem panic(err) } - // handle 0-scalar and (-1)-scalar cases + // handle 0-scalar case var selector0 frontend.Variable _s := s if cfg.CompleteArithmetic { one := c.scalarApi.One() - selector0 = c.api.Or( - c.scalarApi.IsZero(s), - c.scalarApi.IsZero( - c.scalarApi.Add(s, one)), - ) + selector0 = c.scalarApi.IsZero(s) _s = c.scalarApi.Select(selector0, one, s) } @@ -1508,33 +1546,26 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem // Checking Q - [s]P = 0 is equivalent to [v]Q + [-s*v]P = 0 for some nonzero v. // // The GLV curves supported in gnark have j-invariant 0, which means the eigenvalue - // of the GLV endomorphism is a primitive cube root of unity. If we write - // v, s and r as Eisenstein integers we can express the check as: + // of the GLV endomorphism is a primitive cube root of unity λ. 
Using this we can + // express the check as: // // [v1 + λ*v2]Q + [u1 + λ*u2]P = 0 // [v1]Q + [v2]phi(Q) + [u1]P + [u2]phi(P) = 0 // - // where (v1 + λ*v2)*(s1 + λ*s2) = u1 + λu2 mod (r1 + λ*r2) - // and u1, u2, v1, v2 < r^{1/4} (up to a constant factor). + // where (v1 + λ*v2)*s = u1 + λ*u2 mod r + // and u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL lattice reduction). // - // This can be done as follows: - // 1. decompose s into s1 + λ*s2 mod r s.t. s1, s2 < sqrt(r) (hinted classical GLV decomposition). - // 2. decompose r into r1 + λ*r2 s.t. r1, r2 < sqrt(r) (hardcoded half-GCD of λ mod r). - // 3. find u1, u2, v1, v2 < c*r^{1/4} s.t. (v1 + λ*v2)*(s1 + λ*s2) = (u1 + λ*u2) mod (r1 + λ*r2). - // This can be done through a hinted half-GCD in the number field - // K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of - // integers i.e. w is a primitive cube root of unity, f(w)=w^2+w+1=0. + // We use LLL-based lattice reduction to find small u1, u2, v1, v2 satisfying + // s ≡ -(u1 + λ*u2) / (v1 + λ*v2) (mod r). // // The hint returns u1, u2, v1, v2. - // In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) mod r - // + // In-circuit we check that (v1 + λ*v2)*s + u1 + λ*u2 = 0 mod r // - // Eisenstein integers real and imaginary parts can be negative. So we - // return the absolute value in the hint and negate the corresponding - // points here when needed. - signs, sd, err := c.scalarApi.NewHintGeneric(halfGCDEisenstein, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}) + // The sub-scalars can be negative. So we return the absolute value in the + // hint and negate the corresponding points here when needed. 
+ signs, sd, err := c.scalarApi.NewHintGeneric(rationalReconstructExt, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}) if err != nil { - panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) + panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) } u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] @@ -1567,6 +1598,10 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem c.scalarApi.AssertIsEqual(lhs, rhs) + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + den := c.scalarApi.Add(v1, c.scalarApi.Mul(c.eigenvalue, v2)) + c.scalarApi.AssertIsDifferent(den, c.scalarApi.Zero()) + // Next we compute the hinted scalar mul Q = [s]P // P coordinates are in Fp and the scalar s in Fr // we decompose Q.X, Q.Y, s into limbs and recompose them in the hint. @@ -1643,10 +1678,9 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem g := c.Generator() Acc = c.Add(Acc, g) - // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). - // We prove that the factor is log_(3/sqrt(3)))(r). - // so we need to add 9 bits to r^{1/4}.nbits(). - nbits := st.Modulus().BitLen()>>2 + 9 + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 (proven bound from LLL lattice reduction). + // We need ceil(r.BitLen()/4) + 2 bits to account for the constant factor. + nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := c.scalarApi.ToBits(u1) u2bits := c.scalarApi.ToBits(u2) v1bits := c.scalarApi.ToBits(v1) @@ -1706,7 +1740,10 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem ), } // Acc = [2]Acc + Bi - Acc = c.doubleAndAdd(Acc, Bi) + // When P=(0,0) with CompleteArithmetic, table entries are identity-like + // causing Acc.X == Bi.X collisions, so we use unified addition. + // Otherwise, the bias point G prevents collisions and incomplete addition is safe. 
+ Acc = c.doubleAndAddGeneric(Acc, Bi, cfg.CompleteArithmetic) } // i = 0 diff --git a/std/algebra/emulated/sw_emulated/point_test.go b/std/algebra/emulated/sw_emulated/point_test.go index 047f7de197..9bd13f36d1 100644 --- a/std/algebra/emulated/sw_emulated/point_test.go +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -2177,6 +2177,54 @@ func TestScalarMulFakeGLVEdgeCasesEdgeCases(t *testing.T) { } err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) assert.NoError(err) + + // -1 * P == -P + negPy := new(big.Int).Sub(p256.Params().P, py) + witness4 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](big.NewInt(-1)), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](negPy), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) + + // 3 * P == [3]P + threePx, threePy := p256.ScalarMult(px, py, big.NewInt(3).Bytes()) //nolint:staticcheck // compatibility test only + witness5 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](big.NewInt(3)), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](threePx), + Y: emulated.ValueOf[emulated.P256Fp](threePy), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) + + // -3 * P == [-3]P + negThreePy := new(big.Int).Sub(p256.Params().P, threePy) + witness6 := ScalarMulFakeGLVEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](big.NewInt(-3)), + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: 
emulated.ValueOf[emulated.P256Fp](py), + }, + R: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](threePx), + Y: emulated.ValueOf[emulated.P256Fp](negThreePy), + }, + } + err = test.IsSolved(&circuit, &witness6, testCurve.ScalarField()) + assert.NoError(err) } func TestScalarMulFakeGLVEdgeCasesEdgeCases2(t *testing.T) { @@ -2233,6 +2281,22 @@ func TestScalarMulFakeGLVEdgeCasesEdgeCases2(t *testing.T) { } err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) assert.NoError(err) + + // -1 * P == -P + negPy := new(big.Int).Sub(p384.Params().P, py) + witness4 := ScalarMulFakeGLVEdgeCasesTest[emulated.P384Fp, emulated.P384Fr]{ + S: emulated.ValueOf[emulated.P384Fr](big.NewInt(-1)), + P: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](py), + }, + R: AffinePoint[emulated.P384Fp]{ + X: emulated.ValueOf[emulated.P384Fp](px), + Y: emulated.ValueOf[emulated.P384Fp](negPy), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) } func TestScalarMulFakeGLVEdgeCasesEdgeCases3(t *testing.T) { @@ -2291,6 +2355,23 @@ func TestScalarMulFakeGLVEdgeCasesEdgeCases3(t *testing.T) { } err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) assert.NoError(err) + + // -1 * P == -P + var negG stark_curve.G1Affine + negG.Neg(&g) + witness4 := ScalarMulFakeGLVEdgeCasesTest[emulated.STARKCurveFp, emulated.STARKCurveFr]{ + S: emulated.ValueOf[emulated.STARKCurveFr](big.NewInt(-1)), + P: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](g.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](g.Y), + }, + R: AffinePoint[emulated.STARKCurveFp]{ + X: emulated.ValueOf[emulated.STARKCurveFp](negG.X), + Y: emulated.ValueOf[emulated.STARKCurveFp](negG.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) } type ScalarMulGLVAndFakeGLVTest[T, S emulated.FieldParams] struct { @@ 
-2545,3 +2626,530 @@ func TestScalarMulGLVAndFakeGLVEdgeCasesEdgeCases2(t *testing.T) { err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) assert.NoError(err) } + +// JointScalarMulBaseCompleteTest tests JointScalarMulBase with complete arithmetic (for P256) +type JointScalarMulBaseCompleteTest[B, S emulated.FieldParams] struct { + P AffinePoint[B] + S1 emulated.Element[S] + S2 emulated.Element[S] + Res AffinePoint[B] +} + +func (c *JointScalarMulBaseCompleteTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.JointScalarMulBase(&c.P, &c.S1, &c.S2, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// JointScalarMulBaseUnsafeTest tests JointScalarMulBase without complete arithmetic +type JointScalarMulBaseUnsafeTest[B, S emulated.FieldParams] struct { + P AffinePoint[B] + S1 emulated.Element[S] + S2 emulated.Element[S] + Res AffinePoint[B] +} + +func (c *JointScalarMulBaseUnsafeTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.JointScalarMulBase(&c.P, &c.S1, &c.S2) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// ScalarMulFakeGLVCompleteTest tests ScalarMul with complete arithmetic for non-GLV curves +type ScalarMulFakeGLVCompleteTest[B, S emulated.FieldParams] struct { + P AffinePoint[B] + S emulated.Element[S] + Res AffinePoint[B] +} + +func (c *ScalarMulFakeGLVCompleteTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// ScalarMulGLVCompleteTest tests ScalarMul with complete arithmetic for GLV curves +type ScalarMulGLVCompleteTest[B, S emulated.FieldParams] struct { + P AffinePoint[B] + S emulated.Element[S] + Res AffinePoint[B] +} + +func 
(c *ScalarMulGLVCompleteTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.ScalarMul(&c.P, &c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// JointScalarMulGLVCompleteTest tests JointScalarMulBase with complete arithmetic for GLV curves +type JointScalarMulGLVCompleteTest[B, S emulated.FieldParams] struct { + P AffinePoint[B] + S1 emulated.Element[S] + S2 emulated.Element[S] + Res AffinePoint[B] +} + +func (c *JointScalarMulGLVCompleteTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.JointScalarMulBase(&c.P, &c.S1, &c.S2, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// ScalarMulBaseCompleteTest tests ScalarMulBase with complete arithmetic +type ScalarMulBaseCompleteTest[B, S emulated.FieldParams] struct { + S emulated.Element[S] + Res AffinePoint[B] +} + +func (c *ScalarMulBaseCompleteTest[B, S]) Define(api frontend.API) error { + cr, err := New[B, S](api, GetCurveParams[B]()) + if err != nil { + return err + } + res := cr.ScalarMulBase(&c.S, algopts.WithCompleteArithmetic()) + cr.AssertIsEqual(res, &c.Res) + return nil +} + +// TestScalarMulBaseComplete tests ScalarMulBase with complete arithmetic for GLV curves +func TestScalarMulBaseComplete(t *testing.T) { + assert := test.NewAssert(t) + + // secp256k1 (GLV curve) + t.Run("secp256k1", func(t *testing.T) { + _, g := secp256k1.Generators() + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + + var res secp256k1.G1Affine + res.ScalarMultiplication(&g, s) + + circuit := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](s), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: 
emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) + }) + + // P-256 (non-GLV curve) + t.Run("P256", func(t *testing.T) { + p256 := elliptic.P256() + s, _ := rand.Int(rand.Reader, p256.Params().N) + px, py := p256.ScalarBaseMult(s.Bytes()) //nolint:staticcheck // test needs low-level EC ops + + circuit := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{} + witness := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](s), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) + }) +} + +// TestScalarMulBaseEdgeCases tests ScalarMulBase edge cases with complete arithmetic +func TestScalarMulBaseEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + + // secp256k1 (GLV curve) + t.Run("secp256k1", func(t *testing.T) { + _, g := secp256k1.Generators() + var infinity secp256k1.G1Affine + + circuit := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + + // Test: [0]*G = infinity + witness0 := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](0), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](infinity.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness0, testCurve.ScalarField()) + assert.NoError(err) + + // Test: [1]*G = G + witness1 := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](1), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](g.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y), + }, + } + err = 
test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // Test: [r-1]*G = -G + rMinus1 := new(big.Int).Sub(fr_secp.Modulus(), big.NewInt(1)) + var negG secp256k1.G1Affine + negG.Neg(&g) + witnessRm1 := ScalarMulBaseCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + S: emulated.ValueOf[emulated.Secp256k1Fr](rMinus1), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](negG.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](negG.Y), + }, + } + err = test.IsSolved(&circuit, &witnessRm1, testCurve.ScalarField()) + assert.NoError(err) + }) + + // P-256 (non-GLV curve) + t.Run("P256", func(t *testing.T) { + p256 := elliptic.P256() + gx := p256.Params().Gx + gy := p256.Params().Gy + + circuit := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{} + + // Test: [0]*G = infinity + witness0 := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](0), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err := test.IsSolved(&circuit, &witness0, testCurve.ScalarField()) + assert.NoError(err) + + // Test: [1]*G = G + witness1 := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](1), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](gx), + Y: emulated.ValueOf[emulated.P256Fp](gy), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // Test: [r-1]*G = -G + rMinus1 := new(big.Int).Sub(p256.Params().N, big.NewInt(1)) + px, py := p256.ScalarBaseMult(rMinus1.Bytes()) //nolint:staticcheck // test needs low-level EC ops + witnessRm1 := ScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + S: emulated.ValueOf[emulated.P256Fr](rMinus1), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: 
emulated.ValueOf[emulated.P256Fp](py), + }, + } + err = test.IsSolved(&circuit, &witnessRm1, testCurve.ScalarField()) + assert.NoError(err) + }) +} + +// TestJointScalarMulBaseComplete tests JointScalarMulBase with complete arithmetic +func TestJointScalarMulBaseComplete(t *testing.T) { + assert := test.NewAssert(t) + + // secp256k1 (GLV curve) + t.Run("secp256k1", func(t *testing.T) { + _, g := secp256k1.Generators() + var r1, r2 fr_secp.Element + _, _ = r1.SetRandom() + _, _ = r2.SetRandom() + s1 := new(big.Int) + s2 := new(big.Int) + r1.BigInt(s1) + r2.BigInt(s2) + + // P = random point + var p secp256k1.G1Affine + p.ScalarMultiplication(&g, s2) + + // Circuit computes: [c.S2]*G + [c.S1]*P (due to JointScalarMulBase(p, s2, s1) signature) + // So with witness S1=s1, S2=s2, result = [s2]*G + [s1]*P + var res1, res2, res secp256k1.G1Affine + res1.ScalarMultiplication(&g, s2) + res2.ScalarMultiplication(&p, s1) + res.Add(&res1, &res2) + + circuit := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + witness := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](s1), + S2: emulated.ValueOf[emulated.Secp256k1Fr](s2), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) + }) + + // P-256 (non-GLV curve) + t.Run("P256", func(t *testing.T) { + p256 := elliptic.P256() + s1, _ := rand.Int(rand.Reader, p256.Params().N) + s2, _ := rand.Int(rand.Reader, p256.Params().N) + + // P = random point + px, py := p256.ScalarBaseMult(s1.Bytes()) //nolint:staticcheck // test needs low-level EC ops + + // Circuit computes: [c.S2]*G + [c.S1]*P (due to 
JointScalarMulBase(p, s2, s1) signature) + // So with witness S1=s1, S2=s2, result = [s2]*G + [s1]*P + tmp1x, tmp1y := p256.ScalarBaseMult(s2.Bytes()) //nolint:staticcheck // test needs low-level EC ops + tmp2x, tmp2y := p256.ScalarMult(px, py, s1.Bytes()) //nolint:staticcheck // test needs low-level EC ops + resx, resy := p256.Add(tmp1x, tmp1y, tmp2x, tmp2y) //nolint:staticcheck // test needs low-level EC ops + + circuit := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{} + witness := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](s2), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](resx), + Y: emulated.ValueOf[emulated.P256Fp](resy), + }, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) + }) +} + +// TestJointScalarMulBaseEdgeCases tests JointScalarMulBase edge cases with complete arithmetic +func TestJointScalarMulBaseEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + + // secp256k1 (GLV curve) + // Circuit computes: [S2]*G + [S1]*P (due to JointScalarMulBase(p, s2, s1) signature) + t.Run("secp256k1", func(t *testing.T) { + _, g := secp256k1.Generators() + var infinity secp256k1.G1Affine + var r fr_secp.Element + _, _ = r.SetRandom() + s := new(big.Int) + r.BigInt(s) + + // P = [s]*G (a random point) + var p, res secp256k1.G1Affine + p.ScalarMultiplication(&g, s) + + circuit := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{} + + // Test: S1=0, S2=0 => [0]*G + [0]*P = infinity + witness0 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: 
emulated.ValueOf[emulated.Secp256k1Fr](0), + S2: emulated.ValueOf[emulated.Secp256k1Fr](0), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](infinity.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](infinity.Y), + }, + } + err := test.IsSolved(&circuit, &witness0, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=0, S2=s => [s]*G + [0]*P = [s]*G + res.ScalarMultiplication(&g, s) + witness1 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](0), + S2: emulated.ValueOf[emulated.Secp256k1Fr](s), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=s, S2=0 => [0]*G + [s]*P = [s]*P + res.ScalarMultiplication(&p, s) + witness2 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](s), + S2: emulated.ValueOf[emulated.Secp256k1Fr](0), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // Test: P is infinity, S1=0, S2=s => [s]*G + [0]*infinity = [s]*G + res.ScalarMultiplication(&g, s) + witness3 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](infinity.X), + Y: 
emulated.ValueOf[emulated.Secp256k1Fp](infinity.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](0), + S2: emulated.ValueOf[emulated.Secp256k1Fr](s), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=-1, S2=1 => [1]*G + [-1]*P = G - P + var negP secp256k1.G1Affine + negP.Neg(&p) + res.Add(&g, &negP) + witness4 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](big.NewInt(-1)), + S2: emulated.ValueOf[emulated.Secp256k1Fr](1), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=1, S2=-1 => [-1]*G + [1]*P = P - G + var negG secp256k1.G1Affine + negG.Neg(&g) + res.Add(&p, &negG) + witness5 := JointScalarMulGLVCompleteTest[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{ + P: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](p.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](p.Y), + }, + S1: emulated.ValueOf[emulated.Secp256k1Fr](1), + S2: emulated.ValueOf[emulated.Secp256k1Fr](big.NewInt(-1)), + Res: AffinePoint[emulated.Secp256k1Fp]{ + X: emulated.ValueOf[emulated.Secp256k1Fp](res.X), + Y: emulated.ValueOf[emulated.Secp256k1Fp](res.Y), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) + }) + + // P-256 (non-GLV curve) + // Circuit computes: [S2]*G + [S1]*P (due to JointScalarMulBase(p, s2, s1) signature) + t.Run("P256", func(t *testing.T) { + p256 := elliptic.P256() 
+ s, _ := rand.Int(rand.Reader, p256.Params().N) + + // P = [s]*G (a random point) + px, py := p256.ScalarBaseMult(s.Bytes()) //nolint:staticcheck // test needs low-level EC ops + + circuit := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{} + + // Test: S1=0, S2=0 => [0]*G + [0]*P = infinity + witness0 := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + S1: emulated.ValueOf[emulated.P256Fr](0), + S2: emulated.ValueOf[emulated.P256Fr](0), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err := test.IsSolved(&circuit, &witness0, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=0, S2=s => [s]*G + [0]*P = [s]*G + resx, resy := p256.ScalarBaseMult(s.Bytes()) //nolint:staticcheck // test needs low-level EC ops + witness1 := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + S1: emulated.ValueOf[emulated.P256Fr](0), + S2: emulated.ValueOf[emulated.P256Fr](s), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](resx), + Y: emulated.ValueOf[emulated.P256Fp](resy), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=s, S2=0 => [0]*G + [s]*P = [s]*P + resx, resy = p256.ScalarMult(px, py, s.Bytes()) //nolint:staticcheck // test needs low-level EC ops + witness2 := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + S1: emulated.ValueOf[emulated.P256Fr](s), + S2: emulated.ValueOf[emulated.P256Fr](0), + Res: AffinePoint[emulated.P256Fp]{ + X: 
emulated.ValueOf[emulated.P256Fp](resx), + Y: emulated.ValueOf[emulated.P256Fp](resy), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // Test: S1=-1, S2=1 => [1]*G + [-1]*P = G - P + gx, gy := p256.Params().Gx, p256.Params().Gy + negPy := new(big.Int).Sub(p256.Params().P, py) + resx, resy = p256.Add(gx, gy, px, negPy) //nolint:staticcheck // test needs low-level EC ops + witness3 := JointScalarMulBaseCompleteTest[emulated.P256Fp, emulated.P256Fr]{ + P: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](px), + Y: emulated.ValueOf[emulated.P256Fp](py), + }, + S1: emulated.ValueOf[emulated.P256Fr](big.NewInt(-1)), + S2: emulated.ValueOf[emulated.P256Fr](1), + Res: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](resx), + Y: emulated.ValueOf[emulated.P256Fp](resy), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + }) +} diff --git a/std/algebra/native/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go index b678682e0f..0f6a75c26f 100644 --- a/std/algebra/native/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -13,6 +13,8 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" ) // G1Affine point in affine coords @@ -455,16 +457,182 @@ func (p *G1Affine) jointScalarMul(api frontend.API, Q, R G1Affine, s, t frontend panic(err) } if cfg.CompleteArithmetic { - var tmp G1Affine - p.ScalarMul(api, Q, s, opts...) - tmp.ScalarMul(api, R, t, opts...) - p.AddUnified(api, tmp) + p.jointScalarMulComplete(api, Q, R, s, t) } else { p.jointScalarMulUnsafe(api, Q, R, s, t) } return p } +// jointScalarMulComplete computes [s]Q + [t]R using a hint and Shamir's trick verification. +// It handles edge cases: Q=(0,0), R=(0,0), s=0, t=0. 
+func (p *G1Affine) jointScalarMulComplete(api frontend.API, Q, R G1Affine, s, t frontend.Variable) *G1Affine { + cc := getInnerCurveConfig(api.Compiler().Field()) + + // handle zero scalars and zero points + sIsZero := api.IsZero(s) + tIsZero := api.IsZero(t) + QIsZero := api.And(api.IsZero(Q.X), api.IsZero(Q.Y)) + RIsZero := api.And(api.IsZero(R.X), api.IsZero(R.Y)) + + // sContribZero = s=0 OR Q=(0,0) + // tContribZero = t=0 OR R=(0,0) + sContribZero := api.Or(sIsZero, QIsZero) + tContribZero := api.Or(tIsZero, RIsZero) + // when s contribution is zero, set s=1 to avoid issues with scalar decomposition + _s := api.Select(sContribZero, 1, s) + // when t contribution is zero, set t=1 to avoid issues with scalar decomposition + _t := api.Select(tContribZero, 1, t) + + // Use on-curve generator points as dummies for soundness. + // Off-curve dummies would make the loop produce garbage for edge cases, + // preventing verification of the hint result. + // With on-curve dummies, the loop computes a valid (but shifted) result + // that we can adjust for at the end. 
+ _, _, g1aff, _ := bls12377.Generators() + var g1Triple bls12377.G1Affine + g1Triple.Double(&g1aff) + g1Triple.Add(&g1Triple, &g1aff) + dummyQ := G1Affine{ + X: g1aff.X.BigInt(new(big.Int)), + Y: g1aff.Y.BigInt(new(big.Int)), + } + dummyR := G1Affine{ + X: g1Triple.X.BigInt(new(big.Int)), + Y: g1Triple.Y.BigInt(new(big.Int)), + } + + // when Q contribution is zero, assign dummyQ + _Q := Q + _Q.Select(api, sContribZero, dummyQ, Q) + // when R contribution is zero, assign dummyR + _R := R + _R.Select(api, tContribZero, dummyR, R) + + // Get the result from hint - handles all edge cases correctly + point, err := api.Compiler().NewHint(jointScalarMulG1Hint, 2, Q.X, Q.Y, R.X, R.Y, s, t) + if err != nil { + panic(err) + } + result := G1Affine{X: point[0], Y: point[1]} + + sd, err := api.Compiler().NewHint(decomposeScalarG1Simple, 2, _s) + if err != nil { + panic(err) + } + s1, s2 := sd[0], sd[1] + + td, err := api.Compiler().NewHint(decomposeScalarG1Simple, 2, _t) + if err != nil { + panic(err) + } + t1, t2 := td[0], td[1] + + api.AssertIsEqual(api.Add(s1, api.Mul(s2, cc.lambda)), _s) + api.AssertIsEqual(api.Add(t1, api.Mul(t2, cc.lambda)), _t) + + nbits := cc.lambda.BitLen() + + s1bits := api.ToBinary(s1, nbits) + s2bits := api.ToBinary(s2, nbits) + t1bits := api.ToBinary(t1, nbits) + t2bits := api.ToBinary(t2, nbits) + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]G1Affine + tableQ[1] = _Q + tableQ[0].Neg(api, _Q) + cc.phi1(api, &tablePhiQ[1], &_Q) + tablePhiQ[0].Neg(api, tablePhiQ[1]) + // precompute -R, -Φ(R), Φ(R) + var tableR, tablePhiR [2]G1Affine + tableR[1] = _R + tableR[0].Neg(api, _R) + cc.phi1(api, &tablePhiR[1], &_R) + tablePhiR[0].Neg(api, tablePhiR[1]) + // precompute Q+R, -Q-R, Q-R, -Q+R, Φ(Q)+Φ(R), -Φ(Q)-Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R) + var tableS, tablePhiS [4]G1Affine + tableS[0] = tableQ[0] + tableS[0].AddUnified(api, tableR[0]) + tableS[1].Neg(api, tableS[0]) + tableS[2] = _Q + tableS[2].AddUnified(api, tableR[0]) + tableS[3].Neg(api, 
tableS[2]) + cc.phi1(api, &tablePhiS[0], &tableS[0]) + cc.phi1(api, &tablePhiS[1], &tableS[1]) + cc.phi1(api, &tablePhiS[2], &tableS[2]) + cc.phi1(api, &tablePhiS[3], &tableS[3]) + + // suppose first bit is 1 and set: + // Acc = Q + R + Φ(Q) + Φ(R) = -Φ²(Q+R) + var Acc G1Affine + cc.phi2Neg(api, &Acc, &tableS[1]) + + // We add the point H=(0,1) on BLS12-377 of order 2 to avoid incomplete + // additions in the loop by forcing Acc to be different than the stored B. + // Since the loop size N=nbits-1 is even, [2^N]H = (0,1). + H := G1Affine{X: 0, Y: 1} + Acc.AddUnified(api, H) + + // Acc = [2]Acc ± Q ± R ± Φ(Q) ± Φ(R) + // We use Double + AddUnified instead of DoubleAndAdd/AddAssign to handle + // the case Q=±R where table entries may be the identity point (0,0). + var B G1Affine + for i := nbits - 1; i > 0; i-- { + B.X = api.Select(api.Xor(s1bits[i], t1bits[i]), tableS[2].X, tableS[0].X) + B.Y = api.Lookup2(s1bits[i], t1bits[i], tableS[0].Y, tableS[2].Y, tableS[3].Y, tableS[1].Y) + Acc.Double(api, Acc) + Acc.AddUnified(api, B) + B.X = api.Select(api.Xor(s2bits[i], t2bits[i]), tablePhiS[2].X, tablePhiS[0].X) + B.Y = api.Lookup2(s2bits[i], t2bits[i], tablePhiS[0].Y, tablePhiS[2].Y, tablePhiS[3].Y, tablePhiS[1].Y) + Acc.AddUnified(api, B) + } + + // i = 0 + // subtract the initial point from the accumulator when first bit was 0 + // use AddUnified for complete arithmetic at i=0 + tableQ[0].AddUnified(api, Acc) + Acc.Select(api, s1bits[0], Acc, tableQ[0]) + tablePhiQ[0].AddUnified(api, Acc) + Acc.Select(api, s2bits[0], Acc, tablePhiQ[0]) + tableR[0].AddUnified(api, Acc) + Acc.Select(api, t1bits[0], Acc, tableR[0]) + tablePhiR[0].AddUnified(api, Acc) + Acc.Select(api, t2bits[0], Acc, tablePhiR[0]) + + // subtract [2^N]H = (0,1) since we added H at the beginning + Acc.AddUnified(api, G1Affine{X: 0, Y: -1}) + + // Acc now equals [_s]*_Q + [_t]*_R where: + // - Common case: _s=s, _Q=Q, _t=t, _R=R => Acc = [s]*Q + [t]*R = result + // - sContribZero: _s=1, _Q=dummyQ, _t=t, 
_R=R => Acc = dummyQ + [t]*R + // - tContribZero: _s=s, _Q=Q, _t=1, _R=dummyR => Acc = [s]*Q + dummyR + // - Both zero: _s=1, _Q=dummyQ, _t=1, _R=dummyR => Acc = dummyQ + dummyR + // + // For edge cases, subtract the dummy contributions to recover the true result. + // AddUnified handles (0,0) as identity, so when the adjustment is (0,0) it's a no-op. + var negDummyQ, negDummyR G1Affine + negDummyQ.Neg(api, dummyQ) + negDummyR.Neg(api, dummyR) + + var adjQ G1Affine + adjQ.X = api.Select(sContribZero, negDummyQ.X, 0) + adjQ.Y = api.Select(sContribZero, negDummyQ.Y, 0) + Acc.AddUnified(api, adjQ) + + var adjR G1Affine + adjR.X = api.Select(tContribZero, negDummyR.X, 0) + adjR.Y = api.Select(tContribZero, negDummyR.Y, 0) + Acc.AddUnified(api, adjR) + + Acc.AssertIsEqual(api, result) + + p.X = result.X + p.Y = result.Y + + return p +} + // P = [s]Q + [t]R using Shamir's trick func (p *G1Affine) jointScalarMulUnsafe(api frontend.API, Q, R G1Affine, s, t frontend.Variable) *G1Affine { cc := getInnerCurveConfig(api.Compiler().Field()) @@ -660,9 +828,6 @@ func (p *G1Affine) scalarBitsMul(api frontend.API, Q G1Affine, s1bits, s2bits [] return p } -// fake-GLV -// -// N.B.: this method is more expensive than classical GLV, but it is useful for testing purposes. func (p *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s frontend.Variable, opts ...algopts.AlgebraOption) *G1Affine { cfg, err := algopts.NewConfig(opts...) if err != nil { @@ -682,68 +847,81 @@ func (p *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte // Checking Q - [s]P = 0 is equivalent to [v]Q + [-s*v]P = 0 for some nonzero v. // // The GLV curves supported in gnark have j-invariant 0, which means the eigenvalue - // of the GLV endomorphism is a primitive cube root of unity. If we write - // v, s and r as Eisenstein integers we can express the check as: + // of the GLV endomorphism is a primitive cube root of unity λ. 
Using this we can + // express the check as: // // [v1 + λ*v2]Q + [u1 + λ*u2]P = 0 // [v1]Q + [v2]phi(Q) + [u1]P + [u2]phi(P) = 0 // - // where (v1 + λ*v2)*(s1 + λ*s2) = u1 + λu2 mod (r1 + λ*r2) - // and u1, u2, v1, v2 < r^{1/4} (up to a constant factor). - // - // This can be done as follows: - // 1. decompose s into s1 + λ*s2 mod r s.t. s1, s2 < sqrt(r) (hinted classical GLV decomposition). - // 2. decompose r into r1 + λ*r2 s.t. r1, r2 < sqrt(r) (hardcoded half-GCD of λ mod r). - // 3. find u1, u2, v1, v2 < c*r^{1/4} s.t. (v1 + λ*v2)*(s1 + λ*s2) = (u1 + λ*u2) mod (r1 + λ*r2). - // This can be done through a hinted half-GCD in the number field - // K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of - // integers i.e. w is a primitive cube root of unity, f(w)=w^2+w+1=0. + // where (v1 + λ*v2)*s = u1 + λ*u2 mod r + // and u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL lattice reduction). // - // The hint returns u1, u2, v1, v2 and the quotient q. - // In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) + r*q + // We use LLL-based lattice reduction to find small u1, u2, v1, v2 satisfying + // s ≡ -(u1 + λ*u2) / (v1 + λ*v2) (mod r). // - // N.B.: this check may overflow. But we don't use this method anywhere but for testing purposes. + // The hint returns u1, u2, v1, v2. + // In-circuit we check that (v1 + λ*v2)*s + u1 + λ*u2 = 0 mod r + // using emulated arithmetic to avoid overflow in the native field. // - // Eisenstein integers real and imaginary parts can be negative. So we - // return the absolute value in the hint and negate the corresponding - // points here when needed. - sd, err := api.NewHint(halfGCDEisenstein, 10, _s, cc.lambda) + // The sub-scalars can be negative. So we return the absolute value in the + // hint and negate the corresponding points here when needed. 
+ sd, err := api.NewHint(rationalReconstructExt, 8, _s, cc.lambda) if err != nil { - panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) + panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) } - u1, u2, v1, v2, q := sd[0], sd[1], sd[2], sd[3], sd[4] - isNegu1, isNegu2, isNegv1, isNegv2, isNegq := sd[5], sd[6], sd[7], sd[8], sd[9] + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + isNegu1, isNegu2, isNegv1, isNegv2 := sd[4], sd[5], sd[6], sd[7] // We need to check that: - // s*(v1 + λ*v2) + u1 + λ*u2 - r * q = 0 - sv1 := api.Mul(_s, v1) - sλv2 := api.Mul(_s, api.Mul(cc.lambda, v2)) - λu2 := api.Mul(cc.lambda, u2) - rq := api.Mul(cc.fr, q) - - lhs1 := api.Select(isNegv1, 0, sv1) - lhs2 := api.Select(isNegv2, 0, sλv2) - lhs3 := api.Select(isNegu1, 0, u1) - lhs4 := api.Select(isNegu2, 0, λu2) - lhs5 := api.Select(isNegq, rq, 0) - lhs := api.Add( - api.Add(lhs1, lhs2), - api.Add(lhs3, lhs4), + // s*(v1 + λ*v2) + u1 + λ*u2 = 0 mod r + // + // We use emulated arithmetic over the BLS12-377 scalar field to avoid overflow. + // The native field (BW6-761 scalar field) is ~377 bits, but the products + // s*λ*v2 can exceed 400 bits, causing overflow in native arithmetic. + scalarApi, err := emulated.NewField[emparams.BLS12377Fr](api) + if err != nil { + panic(fmt.Sprintf("failed to create scalar field: %v", err)) + } + + // Convert to emulated elements + _sEmu := scalarApi.FromBits(api.ToBinary(_s, cc.fr.BitLen())...) + u1Emu := scalarApi.FromBits(api.ToBinary(u1, (cc.fr.BitLen()+3)/4+2)...) + u2Emu := scalarApi.FromBits(api.ToBinary(u2, (cc.fr.BitLen()+3)/4+2)...) + v1Emu := scalarApi.FromBits(api.ToBinary(v1, (cc.fr.BitLen()+3)/4+2)...) + v2Emu := scalarApi.FromBits(api.ToBinary(v2, (cc.fr.BitLen()+3)/4+2)...) 
+ lambdaEmu := scalarApi.NewElement(cc.lambda) + zero := scalarApi.Zero() + + // Compute s*v1, s*λ*v2, λ*u2 in emulated arithmetic + sv1Emu := scalarApi.Mul(_sEmu, v1Emu) + λv2Emu := scalarApi.Mul(lambdaEmu, v2Emu) + sλv2Emu := scalarApi.Mul(_sEmu, λv2Emu) + λu2Emu := scalarApi.Mul(lambdaEmu, u2Emu) + + // Handle signs: positive terms go to lhs, negative terms go to rhs + lhs1Emu := scalarApi.Select(isNegv1, zero, sv1Emu) + lhs2Emu := scalarApi.Select(isNegv2, zero, sλv2Emu) + lhs3Emu := scalarApi.Select(isNegu1, zero, u1Emu) + lhs4Emu := scalarApi.Select(isNegu2, zero, λu2Emu) + lhsEmu := scalarApi.Add( + scalarApi.Add(lhs1Emu, lhs2Emu), + scalarApi.Add(lhs3Emu, lhs4Emu), ) - lhs = api.Add(lhs, lhs5) - - rhs1 := api.Select(isNegv1, sv1, 0) - rhs2 := api.Select(isNegv2, sλv2, 0) - rhs3 := api.Select(isNegu1, u1, 0) - rhs4 := api.Select(isNegu2, λu2, 0) - rhs5 := api.Select(isNegq, 0, rq) - rhs := api.Add( - api.Add(rhs1, rhs2), - api.Add(rhs3, rhs4), + + rhs1Emu := scalarApi.Select(isNegv1, sv1Emu, zero) + rhs2Emu := scalarApi.Select(isNegv2, sλv2Emu, zero) + rhs3Emu := scalarApi.Select(isNegu1, u1Emu, zero) + rhs4Emu := scalarApi.Select(isNegu2, λu2Emu, zero) + rhsEmu := scalarApi.Add( + scalarApi.Add(rhs1Emu, rhs2Emu), + scalarApi.Add(rhs3Emu, rhs4Emu), ) - rhs = api.Add(rhs, rhs5) - api.AssertIsEqual(lhs, rhs) + scalarApi.AssertIsEqual(lhsEmu, rhsEmu) + + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + denEmu := scalarApi.Add(v1Emu, scalarApi.Mul(lambdaEmu, v2Emu)) + scalarApi.AssertIsDifferent(denEmu, zero) // Next we compute the hinted scalar mul Q = [s]P point, err := api.NewHint(scalarMulGLVG1Hint, 2, P.X, P.Y, s) @@ -820,12 +998,12 @@ func (p *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte // Since the loop size N=nbits-1 is odd the result at the end should be // [2^N]H = H = (0,1). 
H := G1Affine{X: 0, Y: 1} - Acc.AddAssign(api, H) + Acc.AddUnified(api, H) - // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). - // We prove that the factor is log_(3/sqrt(3)))(r). - // so we need to add 9 bits to r^{1/4}.nbits(). - nbits := cc.lambda.BitLen()>>1 + 9 // 72 + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 (proven bound from LLL lattice reduction). + // We need ceil(r.BitLen()/4) + 2 bits to account for the constant factor. + // For BLS12-377, r.BitLen() = 253, so nbits = 64 + 2 = 66. + nbits := (cc.fr.BitLen()+3)/4 + 2 u1bits := api.ToBinary(u1, nbits) u2bits := api.ToBinary(u2, nbits) v1bits := api.ToBinary(v1, nbits) diff --git a/std/algebra/native/sw_bls12377/g1_test.go b/std/algebra/native/sw_bls12377/g1_test.go index 677cd1119c..62955f1d85 100644 --- a/std/algebra/native/sw_bls12377/g1_test.go +++ b/std/algebra/native/sw_bls12377/g1_test.go @@ -674,6 +674,48 @@ func TestJointScalarMulG1EdgeCases(t *testing.T) { assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) } +type g1JointScalarMulOppositePoints struct { + A, NegA G1Affine + C G1Affine `gnark:",public"` + R, S frontend.Variable +} + +func (circuit *g1JointScalarMulOppositePoints) Define(api frontend.API) error { + expected := G1Affine{} + expected.jointScalarMul(api, circuit.A, circuit.NegA, circuit.R, circuit.S, algopts.WithCompleteArithmetic()) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestJointScalarMulG1OppositePoints(t *testing.T) { + _a := randomPointG1() + negAJac := _a + var a, negA, c bls12377.G1Affine + a.FromJacobian(&_a) + negAJac.Neg(&negAJac) + negA.FromJacobian(&negAJac) + + var circuit, witness g1JointScalarMulOppositePoints + var r, s fr.Element + _, _ = r.SetRandom() + _, _ = s.SetRandom() + witness.R = r.String() + witness.S = s.String() + witness.A.Assign(&a) + witness.NegA.Assign(&negA) + + var ar, as big.Int + var ra, sa, sum bls12377.G1Jac + ra.ScalarMultiplication(&_a, r.BigInt(&ar)) + 
sa.ScalarMultiplication(&negAJac, s.BigInt(&as)) + sum.Set(&ra).AddAssign(&sa) + c.FromJacobian(&sum) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) +} + type g1JointScalarMul struct { A, B G1Affine C G1Affine `gnark:",public"` diff --git a/std/algebra/native/sw_bls12377/g2.go b/std/algebra/native/sw_bls12377/g2.go index f3aa7db373..3a989a5cbc 100644 --- a/std/algebra/native/sw_bls12377/g2.go +++ b/std/algebra/native/sw_bls12377/g2.go @@ -4,6 +4,7 @@ package sw_bls12377 import ( + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc" @@ -12,6 +13,8 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/native/fields_bls12377" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" ) type g2AffP struct { @@ -113,6 +116,20 @@ func (p *g2AffP) Select(api frontend.API, b frontend.Variable, p1, p2 g2AffP) *g return p } +// Lookup2 performs a 2-bit lookup between p1, p2, p3, p4 based on bits b1 and b2. +// Returns: +// - p1 if b1=0 and b2=0, +// - p2 if b1=1 and b2=0, +// - p3 if b1=0 and b2=1, +// - p4 if b1=1 and b2=1. +func (p *g2AffP) Lookup2(api frontend.API, b1, b2 frontend.Variable, p1, p2, p3, p4 g2AffP) *g2AffP { + + p.X.Lookup2(api, b1, b2, p1.X, p2.X, p3.X, p4.X) + p.Y.Lookup2(api, b1, b2, p1.Y, p2.Y, p3.Y, p4.Y) + + return p +} + // Double compute 2*p1, assign the result to p and return it // Only for curve with j invariant 0 (a=0). func (p *g2AffP) Double(api frontend.API, p1 g2AffP) *g2AffP { @@ -547,3 +564,262 @@ func (p *g2AffP) psi(api frontend.API, q *g2AffP) *g2AffP { return p } + +// scalarMulGLVAndFakeGLV computes [s]P using GLV+fakeGLV with r^(1/4) bounds.
+// It implements the "GLV + fake GLV" optimization which achieves tighter bounds +// on the sub-scalars, reducing the number of iterations in the scalar multiplication loop. +// +// ⚠️ The scalar s must be nonzero and the point P different from (0,0) unless [algopts.WithCompleteArithmetic] is set. +func (p *g2AffP) scalarMulGLVAndFakeGLV(api frontend.API, P g2AffP, s frontend.Variable, opts ...algopts.AlgebraOption) *g2AffP { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(err) + } + cc := getInnerCurveConfig(api.Compiler().Field()) + + // handle zero-scalar + var selector0 frontend.Variable + _s := s + if cfg.CompleteArithmetic { + selector0 = api.IsZero(s) + _s = api.Select(selector0, 1, s) + } + + // Instead of computing [s]P=Q, we check that Q-[s]P == 0. + // Checking Q - [s]P = 0 is equivalent to [v]Q + [-s*v]P = 0 for some nonzero v. + // + // The GLV curves supported in gnark have j-invariant 0, which means the eigenvalue + // of the GLV endomorphism is a primitive cube root of unity λ. Using this we can + // express the check as: + // + // [v1 + λ*v2]Q + [u1 + λ*u2]P = 0 + // [v1]Q + [v2]phi(Q) + [u1]P + [u2]phi(P) = 0 + // + // where (v1 + λ*v2)*s = u1 + λ*u2 mod r + // and u1, u2, v1, v2 < c*r^{1/4} with c ≈ 1.25 (proven bound from LLL lattice reduction). + // + // The sub-scalars can be negative. So we return the absolute value in the + // hint and negate the corresponding points here when needed. + sd, err := api.NewHint(rationalReconstructExt, 8, _s, cc.lambda) + if err != nil { + panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) + } + u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] + isNegu1, isNegu2, isNegv1, isNegv2 := sd[4], sd[5], sd[6], sd[7] + + // We need to check that: + // s*(v1 + λ*v2) + u1 + λ*u2 = 0 mod r + // + // We use emulated arithmetic over the BLS12-377 scalar field to avoid overflow. 
+ // The native field (BW6-761 scalar field) is ~377 bits, but the products + // s*λ*v2 can exceed 400 bits, causing overflow in native arithmetic. + scalarApi, err := emulated.NewField[emparams.BLS12377Fr](api) + if err != nil { + panic(fmt.Sprintf("failed to create scalar field: %v", err)) + } + + // Convert to emulated elements + _sEmu := scalarApi.FromBits(api.ToBinary(_s, cc.fr.BitLen())...) + u1Emu := scalarApi.FromBits(api.ToBinary(u1, (cc.fr.BitLen()+3)/4+2)...) + u2Emu := scalarApi.FromBits(api.ToBinary(u2, (cc.fr.BitLen()+3)/4+2)...) + v1Emu := scalarApi.FromBits(api.ToBinary(v1, (cc.fr.BitLen()+3)/4+2)...) + v2Emu := scalarApi.FromBits(api.ToBinary(v2, (cc.fr.BitLen()+3)/4+2)...) + lambdaEmu := scalarApi.NewElement(cc.lambda) + zeroEmu := scalarApi.Zero() + + // Compute s*v1, s*λ*v2, λ*u2 in emulated arithmetic + sv1Emu := scalarApi.Mul(_sEmu, v1Emu) + λv2Emu := scalarApi.Mul(lambdaEmu, v2Emu) + sλv2Emu := scalarApi.Mul(_sEmu, λv2Emu) + λu2Emu := scalarApi.Mul(lambdaEmu, u2Emu) + + // Handle signs: positive terms go to lhs, negative terms go to rhs + lhs1Emu := scalarApi.Select(isNegv1, zeroEmu, sv1Emu) + lhs2Emu := scalarApi.Select(isNegv2, zeroEmu, sλv2Emu) + lhs3Emu := scalarApi.Select(isNegu1, zeroEmu, u1Emu) + lhs4Emu := scalarApi.Select(isNegu2, zeroEmu, λu2Emu) + lhsEmu := scalarApi.Add( + scalarApi.Add(lhs1Emu, lhs2Emu), + scalarApi.Add(lhs3Emu, lhs4Emu), + ) + + rhs1Emu := scalarApi.Select(isNegv1, sv1Emu, zeroEmu) + rhs2Emu := scalarApi.Select(isNegv2, sλv2Emu, zeroEmu) + rhs3Emu := scalarApi.Select(isNegu1, u1Emu, zeroEmu) + rhs4Emu := scalarApi.Select(isNegu2, λu2Emu, zeroEmu) + rhsEmu := scalarApi.Add( + scalarApi.Add(rhs1Emu, rhs2Emu), + scalarApi.Add(rhs3Emu, rhs4Emu), + ) + + scalarApi.AssertIsEqual(lhsEmu, rhsEmu) + + // Ensure the denominator v1 + λ*v2 is non-zero to prevent trivial decomposition + denEmu := scalarApi.Add(v1Emu, scalarApi.Mul(lambdaEmu, v2Emu)) + scalarApi.AssertIsDifferent(denEmu, zeroEmu) + + // Next we compute the 
hinted scalar mul Q = [s]P + point, err := api.NewHint(scalarMulGLVG2Hint, 4, P.X.A0, P.X.A1, P.Y.A0, P.Y.A1, s) + if err != nil { + panic(fmt.Sprintf("scalar mul hint: %v", err)) + } + Q := g2AffP{ + X: fields_bls12377.E2{A0: point[0], A1: point[1]}, + Y: fields_bls12377.E2{A0: point[2], A1: point[3]}, + } + + // handle (0,0)-point + var _selector0, selectorQ0 frontend.Variable + _P := P + one := fields_bls12377.E2{A0: 1, A1: 0} + zero := fields_bls12377.E2{A0: 0, A1: 0} + if cfg.CompleteArithmetic { + // if P=(0,0) we assign a dummy point to P and continue + _selector0 = api.And(P.X.IsZero(api), P.Y.IsZero(api)) + two := fields_bls12377.E2{A0: 2, A1: 0} + _P.Select(api, _selector0, g2AffP{X: two, Y: one}, P) + // if Q=(0,0) (either because s=0 or P=(0,0)) we assign a dummy point to Q + selectorQ0 = api.And(Q.X.IsZero(api), Q.Y.IsZero(api)) + Q.Select(api, selectorQ0, g2AffP{X: one, Y: one}, Q) + } + + // precompute -P, -Φ(P), Φ(P) + var tableP, tablePhiP [2]g2AffP + var negPY fields_bls12377.E2 + negPY.Neg(api, _P.Y) + tableP[1] = g2AffP{ + X: _P.X, + Y: fields_bls12377.E2{ + A0: api.Select(isNegu1, negPY.A0, _P.Y.A0), + A1: api.Select(isNegu1, negPY.A1, _P.Y.A1), + }, + } + tableP[0].Neg(api, tableP[1]) + var phiPX fields_bls12377.E2 + phiPX.MulByFp(api, _P.X, cc.thirdRootOne2) + tablePhiP[1] = g2AffP{ + X: phiPX, + Y: fields_bls12377.E2{ + A0: api.Select(isNegu2, negPY.A0, _P.Y.A0), + A1: api.Select(isNegu2, negPY.A1, _P.Y.A1), + }, + } + tablePhiP[0].Neg(api, tablePhiP[1]) + + // precompute -Q, -Φ(Q), Φ(Q) + var tableQ, tablePhiQ [2]g2AffP + var negQY fields_bls12377.E2 + negQY.Neg(api, Q.Y) + tableQ[1] = g2AffP{ + X: Q.X, + Y: fields_bls12377.E2{ + A0: api.Select(isNegv1, negQY.A0, Q.Y.A0), + A1: api.Select(isNegv1, negQY.A1, Q.Y.A1), + }, + } + tableQ[0].Neg(api, tableQ[1]) + var phiQX fields_bls12377.E2 + phiQX.MulByFp(api, Q.X, cc.thirdRootOne2) + tablePhiQ[1] = g2AffP{ + X: phiQX, + Y: fields_bls12377.E2{ + A0: api.Select(isNegv2, negQY.A0, Q.Y.A0), + A1: 
api.Select(isNegv2, negQY.A1, Q.Y.A1), + }, + } + tablePhiQ[0].Neg(api, tablePhiQ[1]) + + // precompute -P-Q, P+Q, P-Q, -P+Q, -Φ(P)-Φ(Q), Φ(P)+Φ(Q), Φ(P)-Φ(Q), -Φ(P)+Φ(Q) + // We use AddUnified for table precomputation to handle edge cases like s=1 where Q=P + // and the points might be equal (requiring doubling instead of addition). + var tableS, tablePhiS [4]g2AffP + tableS[0] = tableP[0] + tableS[0].AddUnified(api, tableQ[0]) + tableS[1].Neg(api, tableS[0]) + tableS[2] = tableP[1] + tableS[2].AddUnified(api, tableQ[0]) + tableS[3].Neg(api, tableS[2]) + tablePhiS[0] = tablePhiP[0] + tablePhiS[0].AddUnified(api, tablePhiQ[0]) + tablePhiS[1].Neg(api, tablePhiS[0]) + tablePhiS[2] = tablePhiP[1] + tablePhiS[2].AddUnified(api, tablePhiQ[0]) + tablePhiS[3].Neg(api, tablePhiS[2]) + + // we suppose that the first bits of the sub-scalars are 1 and set: + // Acc = P + Q + Φ(P) + Φ(Q) + Acc := tableS[1] + Acc.AddAssign(api, tablePhiS[1]) + // When doing doubleAndAdd(Acc, B) as (Acc+B)+Acc it might happen that + // Acc==B or -B. So we add the G2 generator to it to avoid incomplete + // additions in the loop by forcing Acc to be different than the stored B. + // At the end, since [u1]P + [u2]Φ(P) + [v1]Q + [v2]Φ(Q) = 0, + // Acc will equal [2^(nbits-1)]G2 (precomputed). + points := getTwistPoints() + G2Gen := g2AffP{ + X: fields_bls12377.E2{A0: points.G2x[0], A1: points.G2x[1]}, + Y: fields_bls12377.E2{A0: points.G2y[0], A1: points.G2y[1]}, + } + Acc.AddAssign(api, G2Gen) + + // u1, u2, v1, v2 < c*r^{1/4} where c ≈ 1.25 (proven bound from LLL lattice reduction). + // We need ceil(r.BitLen()/4) + 2 bits to account for the constant factor. + // For BLS12-377, r.BitLen() = 253, so nbits = 64 + 2 = 66. 
+ nbits := (cc.fr.BitLen()+3)/4 + 2 + u1bits := api.ToBinary(u1, nbits) + u2bits := api.ToBinary(u2, nbits) + v1bits := api.ToBinary(v1, nbits) + v2bits := api.ToBinary(v2, nbits) + + var B g2AffP + for i := nbits - 1; i > 0; i-- { + B.X.Select(api, api.Xor(u1bits[i], v1bits[i]), tableS[2].X, tableS[0].X) + B.Y.Lookup2(api, u1bits[i], v1bits[i], tableS[0].Y, tableS[2].Y, tableS[3].Y, tableS[1].Y) + Acc.DoubleAndAdd(api, &Acc, &B) + B.X.Select(api, api.Xor(u2bits[i], v2bits[i]), tablePhiS[2].X, tablePhiS[0].X) + B.Y.Lookup2(api, u2bits[i], v2bits[i], tablePhiS[0].Y, tablePhiS[2].Y, tablePhiS[3].Y, tablePhiS[1].Y) + Acc.AddAssign(api, B) + } + + // i = 0 + // subtract the P, Q, Φ(P), Φ(Q) if the first bits are 0 + tableP[0].AddAssign(api, Acc) + Acc.Select(api, u1bits[0], Acc, tableP[0]) + tablePhiP[0].AddAssign(api, Acc) + Acc.Select(api, u2bits[0], Acc, tablePhiP[0]) + tableQ[0].AddAssign(api, Acc) + Acc.Select(api, v1bits[0], Acc, tableQ[0]) + tablePhiQ[0].AddAssign(api, Acc) + Acc.Select(api, v2bits[0], Acc, tablePhiQ[0]) + + // Acc should be now equal to [2^(nbits-1)]G2 since we added G2 at the beginning + // and [u1]P + [u2]Φ(P) + [v1]Q + [v2]Φ(Q) = 0. + // The loop does nbits-1 doublings, so the generator accumulates to [2^(nbits-1)]G2. + // G2m[i] = [2^i]G2, so we need G2m[nbits-1] = [2^(nbits-1)]G2. 
+ expected := g2AffP{ + X: fields_bls12377.E2{ + A0: points.G2m[nbits-1][0], + A1: points.G2m[nbits-1][1], + }, + Y: fields_bls12377.E2{ + A0: points.G2m[nbits-1][2], + A1: points.G2m[nbits-1][3], + }, + } + if cfg.CompleteArithmetic { + // if P=(0,0) or s=0 (which makes Q=(0,0)), set Acc to expected to pass the check + skipCheck := api.Or(selector0, _selector0) + Acc.Select(api, skipCheck, expected, Acc) + } + Acc.AssertIsEqual(api, expected) + + if cfg.CompleteArithmetic { + // Return (0,0) when s=0 or P=(0,0) + Q.Select(api, api.Or(selector0, _selector0), g2AffP{X: zero, Y: zero}, Q) + } + + p.X = Q.X + p.Y = Q.Y + + return p +} diff --git a/std/algebra/native/sw_bls12377/g2_test.go b/std/algebra/native/sw_bls12377/g2_test.go index caf89b0c45..fa2c095ec6 100644 --- a/std/algebra/native/sw_bls12377/g2_test.go +++ b/std/algebra/native/sw_bls12377/g2_test.go @@ -370,3 +370,84 @@ func randomPointG2() bls12377.G2Jac { p2.ScalarMultiplication(&p2, r1.BigInt(&b)) return p2 } + +// ------------------------------------------------------------------------------------------------- +// GLV and Fake GLV scalar multiplication tests + +type g2ScalarMulGLVAndFakeGLV struct { + A g2AffP + C g2AffP `gnark:",public"` + R frontend.Variable +} + +func (circuit *g2ScalarMulGLVAndFakeGLV) Define(api frontend.API) error { + expected := g2AffP{} + expected.scalarMulGLVAndFakeGLV(api, circuit.A, circuit.R) + expected.AssertIsEqual(api, circuit.C) + return nil +} + +func TestScalarMulG2GLVAndFakeGLV(t *testing.T) { + // sample random point + _a := randomPointG2() + var a, c bls12377.G2Affine + a.FromJacobian(&_a) + + // create the cs + var circuit, witness g2ScalarMulGLVAndFakeGLV + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // assign the inputs + witness.A.Assign(&a) + // compute the result + var br big.Int + _a.ScalarMultiplication(&_a, r.BigInt(&br)) + c.FromJacobian(&_a) + witness.C.Assign(&c) + + assert := test.NewAssert(t) + assert.CheckCircuit(&circuit, 
test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) +} + +type g2ScalarMulGLVAndFakeGLVEdgeCases struct { + A g2AffP + R frontend.Variable + Zero frontend.Variable +} + +func (circuit *g2ScalarMulGLVAndFakeGLVEdgeCases) Define(api frontend.API) error { + // Note: The GLVAndFakeGLV algorithm assumes P ≠ Q where Q = [s]P. + // This means s=1 is not supported as it would make Q = P. + // The s=1 case should be handled separately (it's trivial: [1]P = P). + expected1, expected2, expected3 := g2AffP{}, g2AffP{}, g2AffP{} + zero := fields_bls12377.E2{A0: 0, A1: 0} + infinity := g2AffP{X: zero, Y: zero} + expected1.scalarMulGLVAndFakeGLV(api, circuit.A, circuit.Zero, algopts.WithCompleteArithmetic()) + expected2.scalarMulGLVAndFakeGLV(api, infinity, circuit.R, algopts.WithCompleteArithmetic()) + expected3.scalarMulGLVAndFakeGLV(api, infinity, circuit.Zero, algopts.WithCompleteArithmetic()) + expected1.AssertIsEqual(api, infinity) + expected2.AssertIsEqual(api, infinity) + expected3.AssertIsEqual(api, infinity) + return nil +} + +func TestScalarMulG2GLVAndFakeGLVEdgeCases(t *testing.T) { + // sample random point + _a := randomPointG2() + var a bls12377.G2Affine + a.FromJacobian(&_a) + + // create the cs + var circuit, witness g2ScalarMulGLVAndFakeGLVEdgeCases + var r fr.Element + _, _ = r.SetRandom() + witness.R = r.String() + // assign the inputs + witness.A.Assign(&a) + + witness.Zero = 0 + + assert := test.NewAssert(t) + assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761)) +} diff --git a/std/algebra/native/sw_bls12377/hints.go b/std/algebra/native/sw_bls12377/hints.go index 9764c2b71e..2b8860eaea 100644 --- a/std/algebra/native/sw_bls12377/hints.go +++ b/std/algebra/native/sw_bls12377/hints.go @@ -4,7 +4,7 @@ import ( "errors" "math/big" - "github.com/consensys/gnark-crypto/algebra/eisenstein" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" bls12377 
"github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/constraint/solver" @@ -16,7 +16,9 @@ func GetHints() []solver.Hint { decomposeScalarG1Simple, decomposeScalarG2, scalarMulGLVG1Hint, - halfGCDEisenstein, + scalarMulGLVG2Hint, + jointScalarMulG1Hint, + rationalReconstructExt, pairingCheckHint, pairingCheckTorusHint, } @@ -304,68 +306,133 @@ func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big. return nil } -func halfGCDEisenstein(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { +func jointScalarMulG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 6 { + return errors.New("expecting six inputs") + } + if len(outputs) != 2 { + return errors.New("expecting two outputs") + } + + // compute the resulting point [s]Q + [t]R + var Q, R, result bls12377.G1Affine + Q.X.SetBigInt(inputs[0]) + Q.Y.SetBigInt(inputs[1]) + R.X.SetBigInt(inputs[2]) + R.Y.SetBigInt(inputs[3]) + + // handle infinity cases + QIsInfinity := Q.X.IsZero() && Q.Y.IsZero() + RIsInfinity := R.X.IsZero() && R.Y.IsZero() + sIsZero := inputs[4].Sign() == 0 + tIsZero := inputs[5].Sign() == 0 + + switch { + case (QIsInfinity || sIsZero) && (RIsInfinity || tIsZero): + // both contributions are zero + outputs[0].SetInt64(0) + outputs[1].SetInt64(0) + case QIsInfinity || sIsZero: + // only R contributes + R.ScalarMultiplication(&R, inputs[5]) + R.X.BigInt(outputs[0]) + R.Y.BigInt(outputs[1]) + case RIsInfinity || tIsZero: + // only Q contributes + Q.ScalarMultiplication(&Q, inputs[4]) + Q.X.BigInt(outputs[0]) + Q.Y.BigInt(outputs[1]) + default: + // both contribute + Q.ScalarMultiplication(&Q, inputs[4]) + R.ScalarMultiplication(&R, inputs[5]) + result.Add(&Q, &R) + result.X.BigInt(outputs[0]) + result.Y.BigInt(outputs[1]) + } + return nil +} + +func scalarMulGLVG2Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 5 { + return errors.New("expecting five 
inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + + // compute the resulting point [s]Q on G2 + var Q bls12377.G2Affine + Q.X.A0.SetBigInt(inputs[0]) + Q.X.A1.SetBigInt(inputs[1]) + Q.Y.A0.SetBigInt(inputs[2]) + Q.Y.A1.SetBigInt(inputs[3]) + Q.ScalarMultiplication(&Q, inputs[4]) + Q.X.A0.BigInt(outputs[0]) + Q.X.A1.BigInt(outputs[1]) + Q.Y.A0.BigInt(outputs[2]) + Q.Y.A1.BigInt(outputs[3]) + return nil +} + +func rationalReconstructExt(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(inputs) != 2 { - return errors.New("expecting two input") + return errors.New("expecting two inputs") } - if len(outputs) != 10 { - return errors.New("expecting ten outputs") + if len(outputs) != 8 { + return errors.New("expecting eight outputs") } cc := getInnerCurveConfig(scalarField) - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) - r := eisenstein.ComplexNumber{ - A0: glvBasis.V1[0], - A1: glvBasis.V1[1], - } - sp := ecc.SplitScalar(inputs[0], glvBasis) + + // Use lattice reduction to find (x, y, z, t) such that + // k ≡ (x + λ*y) / (z + λ*t) (mod r) + // // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 - // so here we return -s instead of s. - s := eisenstein.ComplexNumber{ - A0: sp[0], - A1: sp[1], - } - s.Neg(&s) - res := eisenstein.HalfGCD(&r, &s) - outputs[0].Set(&res[0].A0) - outputs[1].Set(&res[0].A1) - outputs[2].Set(&res[1].A0) - outputs[3].Set(&res[1].A1) - outputs[4].Mul(&res[1].A1, inputs[1]). - Add(outputs[4], &res[1].A0). - Mul(outputs[4], inputs[0]). - Add(outputs[4], &res[0].A0) - s.A0.Mul(&res[0].A1, inputs[1]) - outputs[4].Add(outputs[4], &s.A0). - Div(outputs[4], cc.fr) + // so here we use k = -s. 
+ // + // With k = -s: + // -s ≡ (x + λ*y) / (z + λ*t) (mod r) + // s ≡ -(x + λ*y) / (z + λ*t) = (-x - λ*y) / (z + λ*t) (mod r) + // + // The circuit checks: s*(v1 + λ*v2) + u1 + λ*u2 ≡ 0 (mod r) + // Rearranging: s ≡ -(u1 + λ*u2) / (v1 + λ*v2) (mod r) + // + // Matching: (-x - λ*y) = -(u1 + λ*u2) + // So: u1 = x, u2 = y, v1 = z, v2 = t + k := new(big.Int).Neg(inputs[0]) + k.Mod(k, cc.fr) + rc := lattice.NewReconstructor(cc.fr).SetLambda(inputs[1]) + res := rc.RationalReconstructExt(k) + x, y, z, t := res[0], res[1], res[2], res[3] + + // u1 = x, u2 = y, v1 = z, v2 = t + outputs[0].Abs(x) // |u1| = |x| + outputs[1].Abs(y) // |u2| = |y| + outputs[2].Abs(z) // |v1| = |z| + outputs[3].Abs(t) // |v2| = |t| // set the signs - outputs[5].SetUint64(0) - outputs[6].SetUint64(0) - outputs[7].SetUint64(0) - outputs[8].SetUint64(0) - outputs[9].SetUint64(0) - - if outputs[0].Sign() == -1 { - outputs[0].Neg(outputs[0]) + outputs[4].SetUint64(0) // isNegu1 + outputs[5].SetUint64(0) // isNegu2 + outputs[6].SetUint64(0) // isNegv1 + outputs[7].SetUint64(0) // isNegv2 + + // u1 = x is negative when x < 0 + if x.Sign() < 0 { + outputs[4].SetUint64(1) + } + // u2 = y is negative when y < 0 + if y.Sign() < 0 { outputs[5].SetUint64(1) } - if outputs[1].Sign() == -1 { - outputs[1].Neg(outputs[1]) + // v1 = z is negative when z < 0 + if z.Sign() < 0 { outputs[6].SetUint64(1) } - if outputs[2].Sign() == -1 { - outputs[2].Neg(outputs[2]) + // v2 = t is negative when t < 0 + if t.Sign() < 0 { outputs[7].SetUint64(1) } - if outputs[3].Sign() == -1 { - outputs[3].Neg(outputs[3]) - outputs[8].SetUint64(1) - } - if outputs[4].Sign() == -1 { - outputs[4].Neg(outputs[4]) - outputs[9].SetUint64(1) - } return nil } diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index 9349e276e5..aa1e627c32 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -49,8 +49,20 @@ func (c *curve) ScalarMul(p1 
Point, scalar frontend.Variable) Point { p.scalarMul(c.api, &p1, scalar, c.params, c.endo) return p } + +// DoubleBaseScalarMul computes s1*p1 + s2*p2 and returns the result. +// It uses the most efficient implementation available: +// - For curves with GLV endomorphism (Bandersnatch): 6-MSM with r^(1/3) bounds +// - For curves without endomorphism: hinted LogUp with √r bounds +// +// ⚠️ The scalars s1 and s2 must be nonzero and the points p1, p2 must not be +// the identity point (0,1). These optimized implementations do not handle edge cases. func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { var p Point - p.doubleBaseScalarMul(c.api, &p1, &p2, s1, s2, c.params) + if c.endo != nil { + p.doubleBaseScalarMul6MSMLogUp(c.api, &p1, &p2, s1, s2, c.params, c.endo) + } else { + p.doubleBaseScalarMul3MSMLogUp(c.api, &p1, &p2, s1, s2, c.params) + } return p } diff --git a/std/algebra/native/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go index 65fd31dea0..a3f7d7c31d 100644 --- a/std/algebra/native/twistededwards/curve_test.go +++ b/std/algebra/native/twistededwards/curve_test.go @@ -15,7 +15,9 @@ import ( tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" "github.com/consensys/gnark-crypto/ecc/twistededwards" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/test" ) @@ -341,3 +343,34 @@ func (p *CurveParams) randomScalar() *big.Int { r, _ := rand.Int(rand.Reader, p.Order) return r } + +// Benchmarks for constraint counting + +type scalarMulCircuit struct { + curveID twistededwards.ID + P Point + S frontend.Variable + R Point +} + +func (circuit *scalarMulCircuit) Define(api frontend.API) error { + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + res := curve.ScalarMul(circuit.P, circuit.S) + 
api.AssertIsEqual(res.X, circuit.R.X) + api.AssertIsEqual(res.Y, circuit.R.Y) + return nil +} + +func BenchmarkScalarMulTwistedEdwards(b *testing.B) { + var circuit scalarMulCircuit + circuit.curveID = twistededwards.BN254 + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) + } + ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) + b.Log("constraints:", ccs.GetNbConstraints()) +} diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index b83564f213..08580f34f3 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -3,8 +3,8 @@ package twistededwards import ( "errors" "math/big" - "sync" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" edbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch" @@ -16,9 +16,10 @@ import ( func GetHints() []solver.Hint { return []solver.Hint{ - halfGCD, + rationalReconstruct, scalarMulHint, - decomposeScalar, + doubleBaseScalarMulHint, + multiRationalReconstructExtHint, } } @@ -26,53 +27,14 @@ func init() { solver.RegisterHint(GetHints()...) 
} -type glvParams struct { - lambda, order big.Int - glvBasis ecc.Lattice -} - -func decomposeScalar(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error { - // the efficient endomorphism exists on Bandersnatch only - if scalarField.Cmp(ecc.BLS12_381.ScalarField()) != 0 { - return errors.New("no efficient endomorphism is available on this curve") - } - var glv glvParams - var init sync.Once - init.Do(func() { - glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) - glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) - ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) - }) - - // sp[0] is always negative because, in SplitScalar(), we always round above - // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. - // Thus taking -sp[0] here and negating the point in ScalarMul(). - // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) - // and not the Bandersnatch prime order (Order) and the result will be incorrect. - // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) - // and hence we can't reduce optimally the number of constraints. - sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) - res[0].Neg(&(sp[0])) - res[1].Set(&(sp[1])) - - // figure out how many times we have overflowed - res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) - res[2].Sub(res[2], inputs[0]) - res[2].Div(res[2], &glv.order) - - return nil -} - -func halfGCD(mod *big.Int, inputs, outputs []*big.Int) error { +func rationalReconstruct(mod *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return errors.New("expecting two inputs") } if len(outputs) != 4 { return errors.New("expecting four outputs") } - // using PrecomputeLattice for scalar decomposition is a hack and it doesn't - // work in case the scalar is zero. 
override it for now to avoid division by - // zero until a long-term solution is found. + // Handle zero scalar case if inputs[0].Sign() == 0 { outputs[0].SetUint64(0) outputs[1].SetUint64(0) @@ -80,23 +42,39 @@ func halfGCD(mod *big.Int, inputs, outputs []*big.Int) error { outputs[3].SetUint64(0) return nil } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(inputs[1], inputs[0], glvBasis) - outputs[0].Set(&glvBasis.V1[0]) - outputs[1].Set(&glvBasis.V1[1]) - // figure out how many times we have overflowed - // s2 * s + s1 = k*r - outputs[3].Mul(outputs[1], inputs[0]). - Add(outputs[3], outputs[0]). - Div(outputs[3], inputs[1]) + // Use lattice reduction to find (x, z) such that s ≡ x/z (mod r), + // i.e., x - s*z ≡ 0 (mod r), or equivalently x + s*(-z) ≡ 0 (mod r). + // The circuit checks: s1 + s*_s2 ≡ 0 (mod r) + // So we need s1 = x and _s2 = -z. + rc := lattice.NewReconstructor(inputs[1]) + res := rc.RationalReconstruct(inputs[0]) + x, z := res[0], res[1] + // Ensure x is non-negative (the circuit bit-decomposes s1 assuming it's small positive). + // If x < 0, flip signs: (x, z) -> (-x, -z), which preserves s = x/z. + if x.Sign() < 0 { + x.Neg(x) + z.Neg(z) + } + + outputs[0].Set(x) + outputs[1].Abs(z) + + // The sign indicates whether to negate s2 in circuit to get -z. + // sign = 1 when z > 0 (so -z < 0, and we need to negate |z| to get -z) outputs[2].SetUint64(0) - if outputs[1].Sign() == -1 { - outputs[1].Neg(outputs[1]) + if z.Sign() > 0 { outputs[2].SetUint64(1) } + // Compute overflow: k = (x - s*z) / r + // The constraint is x - s*z ≡ 0 (mod r), so x - s*z = k*r for some integer k + // We need to keep the sign of k for the circuit to work correctly. 
+ outputs[3].Mul(z, inputs[0]) // s*z + outputs[3].Sub(x, outputs[3]) // x - s*z + outputs[3].Div(outputs[3], inputs[1]) // k = (x - s*z) / r + return nil } @@ -151,3 +129,193 @@ func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error } return nil } + +// doubleBaseScalarMulHint computes [s1]P1 and [s2]P2 for the hinted double-base scalar multiplication +// inputs: P1.X, P1.Y, s1, P2.X, P2.Y, s2, order +// outputs: Q1.X, Q1.Y, Q2.X, Q2.Y where Q1=[s1]P1 and Q2=[s2]P2 +func doubleBaseScalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 7 { + return errors.New("expecting seven inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + // compute [s1]P1 and [s2]P2 + if field.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + order, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + if inputs[6].Cmp(order) == 0 { + var P1, P2 bandersnatch.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + var P1, P2 jubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } + } else if field.Cmp(ecc.BN254.ScalarField()) == 0 { + var P1, P2 babyjubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + 
P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + var P1, P2 edbls12377.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BW6_761.ScalarField()) == 0 { + var P1, P2 edbw6761.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + return errors.New("doubleBaseScalarMulHint: unknown curve") + } + return nil +} + +// multiRationalReconstructExtHint decomposes two scalars k1, k2 using MultiRationalReconstructExt +// for curves with a GLV endomorphism (Bandersnatch). +// inputs: k1, k2, order, lambda +// outputs [0..11]: |x1|, |y1|, |x2|, |y2|, |z|, |t|, signX1, signY1, signX2, signY2, signZ, signT +// outputs [12..19]: d, kd, n1, kn1, n2, kn2, k_1, k_2 (decomposition verification values) +// +// where k1 ≡ (x1 + λ*y1)/(z + λ*t) (mod order) and k2 ≡ (x2 + λ*y2)/(z + λ*t) (mod order) +// +// The circuit verifies: +// 1. [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) = [z]R + [t]φ(R) (group equation) +// 2. 
k1*(z+λt) ≡ x1+λy1 (mod r) and k2*(z+λt) ≡ x2+λy2 (mod r) (decomposition) +// +// where R = [k1]P + [k2]Q (hinted separately) +func multiRationalReconstructExtHint(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 4 { + return errors.New("expecting four inputs: k1, k2, order, lambda") + } + if len(outputs) != 20 { + return errors.New("expecting 20 outputs") + } + + k1, k2, order, lambda := inputs[0], inputs[1], inputs[2], inputs[3] + + // Handle zero scalar cases + if k1.Sign() == 0 && k2.Sign() == 0 { + for i := 0; i < 20; i++ { + outputs[i].SetUint64(0) + } + return nil + } + + // Use MultiRationalReconstructExt to find (x1, y1, x2, y2, z, t) with shared denominator + // k1 ≡ (x1 + λ*y1)/(z + λ*t) (mod order) + // k2 ≡ (x2 + λ*y2)/(z + λ*t) (mod order) + rc := lattice.NewReconstructor(order).SetLambda(lambda) + res := rc.MultiRationalReconstructExt(k1, k2) + x1, y1, x2, y2, z, t := res[0], res[1], res[2], res[3], res[4], res[5] + + // Store absolute values + outputs[0].Abs(x1) + outputs[1].Abs(y1) + outputs[2].Abs(x2) + outputs[3].Abs(y2) + outputs[4].Abs(z) + outputs[5].Abs(t) + + // Store signs (1 if negative, 0 if non-negative) + setSign := func(out *big.Int, val *big.Int) { + if val.Sign() < 0 { + out.SetUint64(1) + } else { + out.SetUint64(0) + } + } + setSign(outputs[6], x1) + setSign(outputs[7], y1) + setSign(outputs[8], x2) + setSign(outputs[9], y2) + setSign(outputs[10], z) + setSign(outputs[11], t) + + // Compute decomposition verification values. 
+ // We verify k_i*(z + λ*t) ≡ x_i + λ*y_i (mod r) by splitting into: + // (a) d = (z + λ*t) mod r, kd = (z + λ*t - d) / r + // (b) n_i = (x_i + λ*y_i) mod r, kn_i = (x_i + λ*y_i - n_i) / r + // (c) k_i*(z+λ*t) mod r check: k_i*d - n_i = k_i_overflow * r + + // d = (z + λ*t) mod r + zPlusLambdaT := new(big.Int).Mul(lambda, t) + zPlusLambdaT.Add(zPlusLambdaT, z) + d := new(big.Int).Mod(zPlusLambdaT, order) + kd := new(big.Int).Sub(zPlusLambdaT, d) + kd.Div(kd, order) + + // n1 = (x1 + λ*y1) mod r + x1PlusLambdaY1 := new(big.Int).Mul(lambda, y1) + x1PlusLambdaY1.Add(x1PlusLambdaY1, x1) + n1 := new(big.Int).Mod(x1PlusLambdaY1, order) + kn1 := new(big.Int).Sub(x1PlusLambdaY1, n1) + kn1.Div(kn1, order) + + // n2 = (x2 + λ*y2) mod r + x2PlusLambdaY2 := new(big.Int).Mul(lambda, y2) + x2PlusLambdaY2.Add(x2PlusLambdaY2, x2) + n2 := new(big.Int).Mod(x2PlusLambdaY2, order) + kn2 := new(big.Int).Sub(x2PlusLambdaY2, n2) + kn2.Div(kn2, order) + + // k_1 = (k1*d - n1) / r + k1d := new(big.Int).Mul(k1, d) + k1Overflow := new(big.Int).Sub(k1d, n1) + k1Overflow.Div(k1Overflow, order) + + // k_2 = (k2*d - n2) / r + k2d := new(big.Int).Mul(k2, d) + k2Overflow := new(big.Int).Sub(k2d, n2) + k2Overflow.Div(k2Overflow, order) + + outputs[12].Set(d) + outputs[13].Set(kd) + outputs[14].Set(n1) + outputs[15].Set(kn1) + outputs[16].Set(n2) + outputs[17].Set(kn2) + outputs[18].Set(k1Overflow) + outputs[19].Set(k2Overflow) + + return nil +} diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index fde04e192c..65a3d5fe19 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -3,7 +3,10 @@ package twistededwards -import "github.com/consensys/gnark/frontend" +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/lookup/logderivlookup" +) // neg computes the negative of a point in SNARK coordinates func (p *Point) neg(api frontend.API, p1 *Point) *Point { @@ -176,62 
+179,18 @@ func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoP g = api.Mul(g, endo.Endo[0]) h := api.Sub(yy, endo.Endo[0]) - p.X = api.DivUnchecked(f, xy) + // When the input is the identity (0,1), xy=0 and f=0, so f/xy is 0/0. + // φ(identity) = identity, so p.X should be 0 in that case. + // We avoid DivUnchecked(0,0) by selecting xy=1 when x=0 (f is also 0, + // so 0/1=0 gives the correct result). + isIdentity := api.IsZero(p1.X) + safeXY := api.Select(isIdentity, 1, xy) + p.X = api.DivUnchecked(f, safeXY) p.Y = api.DivUnchecked(g, h) return p } -// scalarMulGLV computes the scalar multiplication of a point on a twisted -// Edwards curve à la GLV. -// p1: base point (as snark point) -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { - // the hints allow to decompose the scalar s into s1 and s2 such that - // s1 + λ * s2 == s mod Order, - // with λ s.t. λ² = -2 mod Order. - sd, err := api.NewHint(decomposeScalar, 3, scalar) - if err != nil { - // err is non-nil only for invalid number of inputs - panic(err) - } - - s1, s2 := sd[0], sd[1] - - // -s1 + λ * s2 == s + k*Order - api.AssertIsEqual(api.Sub(api.Mul(s2, endo.Lambda), s1), api.Add(scalar, api.Mul(curve.Order, sd[2]))) - - // Normally s1 and s2 are of the max size sqrt(Order) = 128 - // But in a circuit, we force s1 to be negative by rounding always above. - // This changes the size bounds to 2*sqrt(Order) = 129. 
- n := 129 - - b1 := api.ToBinary(s1, n) - b2 := api.ToBinary(s2, n) - - var res, _p1, p2, p3, tmp Point - _p1.neg(api, p1) - p2.phi(api, p1, curve, endo) - p3.add(api, &_p1, &p2, curve) - - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) - - for i := n - 2; i >= 0; i-- { - res.double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) - res.add(api, &res, &tmp, curve) - } - - p.X = res.X - p.Y = res.Y - - return p -} - // scalarMulFakeGLV computes the scalar multiplication of a point on a twisted // Edwards curve following https://hackmd.io/@yelhousni/Hy-aWld50 // @@ -245,7 +204,7 @@ func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variab func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams) *Point { // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + s * s2 == 0 mod Order, - s, err := api.NewHint(halfGCD, 4, scalar, curve.Order) + s, err := api.NewHint(rationalReconstruct, 4, scalar, curve.Order) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -292,3 +251,427 @@ func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Va return p } + +// doubleBaseScalarMul3MSMLogUp computes s1*P1+s2*P2 via two independent rationalReconstruct decompositions +// (u_i + s_i*v_i ≡ 0 mod Order, with ~r^(1/2)-bit u_i and v_i — not a shared-denominator decomposition). +// Verifies: [u1]P1 + [v1]Q1 + [u2]P2 + [v2]Q2 = O +// where Q1 = [s1]P1 and Q2 = [s2]P2 (hinted); returns Q1 + Q2. +// Uses LogDerivLookup for the 4-point multi-scalar multiplication (16-entry table). 
+func (p *Point) doubleBaseScalarMul3MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams) *Point { + // Get hinted results Q1 = [s1]P1 and Q2 = [s2]P2 + q, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var Q1, Q2 Point + Q1.X, Q1.Y = q[0], q[1] + Q2.X, Q2.Y = q[2], q[3] + + // Decompose s1 into (u1, v1) such that u1 + s1*v1 ≡ 0 (mod Order) + h1, err := api.NewHint(rationalReconstruct, 4, s1, curve.Order) + if err != nil { + panic(err) + } + u1, v1, bit1, k1 := h1[0], h1[1], h1[2], h1[3] + + // Verify: u1 + s1*v1 == k1*Order (with sign handling) + _v1s1 := api.Mul(v1, s1) + _k1r := api.Mul(k1, curve.Order) + lhs1 := api.Select(bit1, u1, api.Add(u1, _v1s1)) + rhs1 := api.Select(bit1, api.Add(_k1r, _v1s1), _k1r) + api.AssertIsEqual(lhs1, rhs1) + // Ensure denominator v1 is non-zero to prevent trivial decomposition. + // When s1=0 the hint legitimately returns v1=0, so we only check when s1≠0. + // This is safe because [0]*P = identity regardless of the hint output. 
+ s1IsZero := api.IsZero(s1) + _v1NonZero := api.Select(s1IsZero, 1, v1) + api.AssertIsDifferent(_v1NonZero, 0) + + // Decompose s2 into (u2, v2) such that u2 + s2*v2 ≡ 0 (mod Order) + h2, err := api.NewHint(rationalReconstruct, 4, s2, curve.Order) + if err != nil { + panic(err) + } + u2, v2, bit2, k2 := h2[0], h2[1], h2[2], h2[3] + + // Verify: u2 + s2*v2 == k2*Order (with sign handling) + _v2s2 := api.Mul(v2, s2) + _k2r := api.Mul(k2, curve.Order) + lhs2 := api.Select(bit2, u2, api.Add(u2, _v2s2)) + rhs2 := api.Select(bit2, api.Add(_k2r, _v2s2), _k2r) + api.AssertIsEqual(lhs2, rhs2) + // Ensure denominator v2 is non-zero to prevent trivial decomposition + s2IsZero := api.IsZero(s2) + _v2NonZero := api.Select(s2IsZero, 1, v2) + api.AssertIsDifferent(_v2NonZero, 0) + + // Apply sign to Q1 and Q2 based on decomposition + var _Q1, _Q2 Point + _Q1.X = api.Select(bit1, api.Neg(Q1.X), Q1.X) + _Q1.Y = Q1.Y + _Q2.X = api.Select(bit2, api.Neg(Q2.X), Q2.X) + _Q2.Y = Q2.Y + + // Build the 16-entry table for 4-MSM: P1, _Q1, P2, _Q2 + var table [16]Point + + // Precompute pair sums + var P1Q1, P2Q2, P1P2, P1Q2, Q1P2, Q1Q2 Point + P1Q1.add(api, p1, &_Q1, curve) + P2Q2.add(api, p2, &_Q2, curve) + P1P2.add(api, p1, p2, curve) + P1Q2.add(api, p1, &_Q2, curve) + Q1P2.add(api, &_Q1, p2, curve) + Q1Q2.add(api, &_Q1, &_Q2, curve) + + // Precompute triple sums + var P1Q1P2, P1Q1Q2, P1P2Q2, Q1P2Q2 Point + P1Q1P2.add(api, &P1Q1, p2, curve) + P1Q1Q2.add(api, &P1Q1, &_Q2, curve) + P1P2Q2.add(api, &P1P2, &_Q2, curve) + Q1P2Q2.add(api, &Q1P2, &_Q2, curve) + + // Precompute quad sum + var P1Q1P2Q2 Point + P1Q1P2Q2.add(api, &P1Q1P2, &_Q2, curve) + + // Build table: index i = b0 + 2*b1 + 4*b2 + 8*b3 + table[0] = Point{X: 0, Y: 1} + table[1] = *p1 + table[2] = _Q1 + table[3] = P1Q1 + table[4] = *p2 + table[5] = P1P2 + table[6] = Q1P2 + table[7] = P1Q1P2 + table[8] = _Q2 + table[9] = P1Q2 + table[10] = Q1Q2 + table[11] = P1Q1Q2 + table[12] = P2Q2 + table[13] = P1P2Q2 + table[14] = Q1P2Q2 + 
table[15] = P1Q1P2Q2 + + // Create LogDerivLookup tables + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 16; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + n := (curve.Order.BitLen() + 1) / 2 + b1 := api.ToBinary(u1, n) + b2 := api.ToBinary(v1, n) + b3 := api.ToBinary(u2, n) + b4 := api.ToBinary(v2, n) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + // index = b1[i] + 2*b2[i] + 4*b3[i] + 8*b4[i] + indices[i] = api.Add( + b1[i], + api.Mul(b2[i], 2), + api.Mul(b3[i], 4), + api.Mul(b4[i], 8), + ) + } + + // Batch lookup + resX := tableX.Lookup(indices...) + resY := tableY.Lookup(indices...) + + // Initialize accumulator with first entry + var res Point + res.X = resX[n-1] + res.Y = resY[n-1] + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + var tmp Point + tmp.X = resX[i] + tmp.Y = resY[i] + res.add(api, &res, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + + // Return Q1 + Q2 + p.add(api, &Q1, &Q2, curve) + + return p +} + +// doubleBaseScalarMul6MSMLogUp computes s1*P1+s2*P2 using MultiRationalReconstructExt (true 6-MSM). +// This decomposes both scalars with a shared denominator in Z[λ], giving ~r^(1/3)-bit scalars. +// Verifies: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) = [z]R + [t]φ(R) +// where R = [s1]P + [s2]Q (hinted). +// Only works for curves with efficient endomorphism (e.g., Bandersnatch). +// Uses LogDerivLookup for the 64-entry table (6 points). 
+func (p *Point) doubleBaseScalarMul6MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // Get hinted result R = [s1]P + [s2]Q + qHint, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var R Point + // We need Q1 + Q2 = R + var Q1, Q2 Point + Q1.X, Q1.Y = qHint[0], qHint[1] + Q2.X, Q2.Y = qHint[2], qHint[3] + R.add(api, &Q1, &Q2, curve) + + // Decompose (s1, s2) using MultiRationalReconstructExt + // Returns |x1|, |y1|, |x2|, |y2|, |z|, |t|, signs, and decomposition verification values + h, err := api.NewHint(multiRationalReconstructExtHint, 20, s1, s2, curve.Order, endo.Lambda) + if err != nil { + panic(err) + } + absX1, absY1, absX2, absY2, absZ, absT := h[0], h[1], h[2], h[3], h[4], h[5] + signX1, signY1, signX2, signY2, signZ, signT := h[6], h[7], h[8], h[9], h[10], h[11] + d, kd, n1, kn1, n2, kn2, k1Over, k2Over := h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19] + + // Verify the decomposition: k_i*(z + λ*t) ≡ x_i + λ*y_i (mod r) + // We split this into intermediate steps to avoid native field overflow: + // (a) z + λ*t ≡ d (mod r): z_signed + λ*t_signed = d + kd*r (mod p) + // (b) x_i + λ*y_i ≡ n_i (mod r): x_i_signed + λ*y_i_signed = n_i + kn_i*r (mod p) + // (c) k_i*d ≡ n_i (mod r): k_i*d = n_i + k_i_overflow*r (mod p) + { + r := curve.Order + lambda := endo.Lambda + + // Signed values (negative = p-val in the native field) + zVal := api.Select(signZ, api.Sub(0, absZ), absZ) + tVal := api.Select(signT, api.Sub(0, absT), absT) + x1Val := api.Select(signX1, api.Sub(0, absX1), absX1) + y1Val := api.Select(signY1, api.Sub(0, absY1), absY1) + x2Val := api.Select(signX2, api.Sub(0, absX2), absX2) + y2Val := api.Select(signY2, api.Sub(0, absY2), absY2) + + // Range check d, n1, n2 (must be < 2^orderBits to bound overflow) + orderBits := r.BitLen() + api.ToBinary(d, orderBits) + api.ToBinary(n1, orderBits) + 
api.ToBinary(n2, orderBits) + + // (a) z + λ*t = d + kd*r (mod p) + api.AssertIsEqual( + api.Add(zVal, api.Mul(lambda, tVal)), + api.Add(d, api.Mul(kd, r)), + ) + + // (b) x1 + λ*y1 = n1 + kn1*r (mod p) + api.AssertIsEqual( + api.Add(x1Val, api.Mul(lambda, y1Val)), + api.Add(n1, api.Mul(kn1, r)), + ) + + // (b) x2 + λ*y2 = n2 + kn2*r (mod p) + api.AssertIsEqual( + api.Add(x2Val, api.Mul(lambda, y2Val)), + api.Add(n2, api.Mul(kn2, r)), + ) + + // (c) s1*d = n1 + k1Over*r (mod p), proving s1*d ≡ n1 (mod r) + api.AssertIsEqual( + api.Mul(s1, d), + api.Add(n1, api.Mul(k1Over, r)), + ) + + // (c) s2*d = n2 + k2Over*r (mod p), proving s2*d ≡ n2 (mod r) + api.AssertIsEqual( + api.Mul(s2, d), + api.Add(n2, api.Mul(k2Over, r)), + ) + + // Ensure shared denominator d = (z + λ*t) mod r is non-zero + // to prevent trivial decomposition leaving R unconstrained. + // When both scalars are zero the hint legitimately returns d=0. + bothZero := api.And(api.IsZero(s1), api.IsZero(s2)) + _dNonZero := api.Select(bothZero, 1, d) + api.AssertIsDifferent(_dNonZero, 0) + } + + // Compute φ(P1), φ(P2), φ(R) + var phiP1, phiP2, phiR Point + phiP1.phi(api, p1, curve, endo) + phiP2.phi(api, p2, curve, endo) + phiR.phi(api, &R, curve, endo) + + // Apply signs to create signed points for the 6-MSM + // The verification is: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) - [z]R - [t]φ(R) = O + // With signs: we negate the point when the sign is 1 + var sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR Point + + // For P1: if signX1 == 1, use -P1, else use P1 + sP1.X = api.Select(signX1, api.Neg(p1.X), p1.X) + sP1.Y = p1.Y + + // For φ(P1): if signY1 == 1, use -φ(P1), else use φ(P1) + sPhiP1.X = api.Select(signY1, api.Neg(phiP1.X), phiP1.X) + sPhiP1.Y = phiP1.Y + + // For P2: if signX2 == 1, use -P2, else use P2 + sP2.X = api.Select(signX2, api.Neg(p2.X), p2.X) + sP2.Y = p2.Y + + // For φ(P2): if signY2 == 1, use -φ(P2), else use φ(P2) + sPhiP2.X = api.Select(signY2, api.Neg(phiP2.X), phiP2.X) + sPhiP2.Y = phiP2.Y + + // 
For R: we subtract [z]R, so if signZ == 0 (z positive), use -R; if signZ == 1 (z negative), use R + sR.X = api.Select(signZ, R.X, api.Neg(R.X)) + sR.Y = R.Y + + // For φ(R): similarly for t + sPhiR.X = api.Select(signT, phiR.X, api.Neg(phiR.X)) + sPhiR.Y = phiR.Y + + // Build 64-entry table for 6-MSM + // Index = b0 + 2*b1 + 4*b2 + 8*b3 + 16*b4 + 32*b5 + // Points: sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR + var table [64]Point + + // Precompute all 64 combinations + // table[i] = (i&1)*sP1 + ((i>>1)&1)*sPhiP1 + ((i>>2)&1)*sP2 + ((i>>3)&1)*sPhiP2 + ((i>>4)&1)*sR + ((i>>5)&1)*sPhiR + + // Start with identity + table[0] = Point{X: 0, Y: 1} + + // Single points + table[1] = sP1 + table[2] = sPhiP1 + table[4] = sP2 + table[8] = sPhiP2 + table[16] = sR + table[32] = sPhiR + + // 2-combinations + table[3].add(api, &sP1, &sPhiP1, curve) + table[5].add(api, &sP1, &sP2, curve) + table[6].add(api, &sPhiP1, &sP2, curve) + table[9].add(api, &sP1, &sPhiP2, curve) + table[10].add(api, &sPhiP1, &sPhiP2, curve) + table[12].add(api, &sP2, &sPhiP2, curve) + table[17].add(api, &sP1, &sR, curve) + table[18].add(api, &sPhiP1, &sR, curve) + table[20].add(api, &sP2, &sR, curve) + table[24].add(api, &sPhiP2, &sR, curve) + table[33].add(api, &sP1, &sPhiR, curve) + table[34].add(api, &sPhiP1, &sPhiR, curve) + table[36].add(api, &sP2, &sPhiR, curve) + table[40].add(api, &sPhiP2, &sPhiR, curve) + table[48].add(api, &sR, &sPhiR, curve) + + // 3-combinations (build from 2-combinations) + table[7].add(api, &table[3], &sP2, curve) // sP1 + sPhiP1 + sP2 + table[11].add(api, &table[3], &sPhiP2, curve) // sP1 + sPhiP1 + sPhiP2 + table[13].add(api, &table[5], &sPhiP2, curve) // sP1 + sP2 + sPhiP2 + table[14].add(api, &table[6], &sPhiP2, curve) // sPhiP1 + sP2 + sPhiP2 + table[19].add(api, &table[3], &sR, curve) // sP1 + sPhiP1 + sR + table[21].add(api, &table[5], &sR, curve) // sP1 + sP2 + sR + table[22].add(api, &table[6], &sR, curve) // sPhiP1 + sP2 + sR + table[25].add(api, &table[9], &sR, curve) // sP1 
+ sPhiP2 + sR + table[26].add(api, &table[10], &sR, curve) // sPhiP1 + sPhiP2 + sR + table[28].add(api, &table[12], &sR, curve) // sP2 + sPhiP2 + sR + table[35].add(api, &table[3], &sPhiR, curve) // sP1 + sPhiP1 + sPhiR + table[37].add(api, &table[5], &sPhiR, curve) // sP1 + sP2 + sPhiR + table[38].add(api, &table[6], &sPhiR, curve) // sPhiP1 + sP2 + sPhiR + table[41].add(api, &table[9], &sPhiR, curve) // sP1 + sPhiP2 + sPhiR + table[42].add(api, &table[10], &sPhiR, curve) // sPhiP1 + sPhiP2 + sPhiR + table[44].add(api, &table[12], &sPhiR, curve) // sP2 + sPhiP2 + sPhiR + table[49].add(api, &table[17], &sPhiR, curve) // sP1 + sR + sPhiR + table[50].add(api, &table[18], &sPhiR, curve) // sPhiP1 + sR + sPhiR + table[52].add(api, &table[20], &sPhiR, curve) // sP2 + sR + sPhiR + table[56].add(api, &table[24], &sPhiR, curve) // sPhiP2 + sR + sPhiR + + // 4-combinations + table[15].add(api, &table[7], &sPhiP2, curve) // sP1 + sPhiP1 + sP2 + sPhiP2 + table[23].add(api, &table[7], &sR, curve) // sP1 + sPhiP1 + sP2 + sR + table[27].add(api, &table[11], &sR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + table[29].add(api, &table[13], &sR, curve) // sP1 + sP2 + sPhiP2 + sR + table[30].add(api, &table[14], &sR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + table[39].add(api, &table[7], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sPhiR + table[43].add(api, &table[11], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sPhiR + table[45].add(api, &table[13], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sPhiR + table[46].add(api, &table[14], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sPhiR + table[51].add(api, &table[19], &sPhiR, curve) // sP1 + sPhiP1 + sR + sPhiR + table[53].add(api, &table[21], &sPhiR, curve) // sP1 + sP2 + sR + sPhiR + table[54].add(api, &table[22], &sPhiR, curve) // sPhiP1 + sP2 + sR + sPhiR + table[57].add(api, &table[25], &sPhiR, curve) // sP1 + sPhiP2 + sR + sPhiR + table[58].add(api, &table[26], &sPhiR, curve) // sPhiP1 + sPhiP2 + sR + sPhiR + table[60].add(api, &table[28], &sPhiR, curve) 
// sP2 + sPhiP2 + sR + sPhiR + + // 5-combinations + table[31].add(api, &table[15], &sR, curve) // all except sPhiR + table[47].add(api, &table[15], &sPhiR, curve) // all except sR + table[55].add(api, &table[23], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sR + sPhiR + table[59].add(api, &table[27], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + sPhiR + table[61].add(api, &table[29], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sR + sPhiR + table[62].add(api, &table[30], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + sPhiR + + // 6-combination (all points) + table[63].add(api, &table[31], &sPhiR, curve) + + // Use LogDerivLookup for the 64-entry table + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 64; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + // Scalar bit length: ~r^(1/3) ≈ 85 bits for 254-bit order + n := (curve.Order.BitLen() + 2) / 3 + + bX1 := api.ToBinary(absX1, n) + bY1 := api.ToBinary(absY1, n) + bX2 := api.ToBinary(absX2, n) + bY2 := api.ToBinary(absY2, n) + bZ := api.ToBinary(absZ, n) + bT := api.ToBinary(absT, n) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + indices[i] = api.Add( + bX1[i], + api.Mul(bY1[i], 2), + api.Mul(bX2[i], 4), + api.Mul(bY2[i], 8), + api.Mul(bZ[i], 16), + api.Mul(bT[i], 32), + ) + } + + // Batch lookup + lookupX := tableX.Lookup(indices...) + lookupY := tableY.Lookup(indices...) + + // Initialize accumulator with last entry + var acc Point + acc.X = lookupX[n-1] + acc.Y = lookupY[n-1] + + for i := n - 2; i >= 0; i-- { + acc.double(api, &acc, curve) + var tmp Point + tmp.X = lookupX[i] + tmp.Y = lookupY[i] + acc.add(api, &acc, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(acc.X, 0) + api.AssertIsEqual(acc.Y, 1) + + // Return R (the hinted result) + p.X = R.X + p.Y = R.Y + + return p +}