diff --git a/ecc/bls12-377/twistededwards/eddsa/eddsa.go b/ecc/bls12-377/twistededwards/eddsa/eddsa.go index c427be304..9ddb67807 100644 --- a/ecc/bls12-377/twistededwards/eddsa/eddsa.go +++ b/ecc/bls12-377/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go b/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go index 1e8e1a1d5..bf8396a7d 100644 --- a/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BLS12_377.New() + hFunc := ghash.MIMC_BLS12_377.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BLS12_377.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BLS12_377.New() + hFunc := ghash.MIMC_BLS12_377.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bls12-377/twistededwards/point.go b/ecc/bls12-377/twistededwards/point.go index 7f616c56f..7c2e4084c 100644 --- a/ecc/bls12-377/twistededwards/point.go +++ b/ecc/bls12-377/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bls12-377/twistededwards/point_test.go b/ecc/bls12-377/twistededwards/point_test.go index ac9825a62..0b815e04d 100644 --- a/ecc/bls12-377/twistededwards/point_test.go +++ b/ecc/bls12-377/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bls12-381/bandersnatch/eddsa/eddsa.go b/ecc/bls12-381/bandersnatch/eddsa/eddsa.go index f66c95608..af9894c97 100644 --- a/ecc/bls12-381/bandersnatch/eddsa/eddsa.go +++ b/ecc/bls12-381/bandersnatch/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go b/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go index b46ba80b2..e46889942 100644 --- a/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go +++ b/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BLS12_381.New() + hFunc := ghash.MIMC_BLS12_381.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BLS12_381.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BLS12_381.New() + hFunc := ghash.MIMC_BLS12_381.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bls12-381/bandersnatch/point.go b/ecc/bls12-381/bandersnatch/point.go index 4373eb2d4..17b14efff 100644 --- a/ecc/bls12-381/bandersnatch/point.go +++ b/ecc/bls12-381/bandersnatch/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -319,6 +337,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -350,6 +393,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bls12-381/bandersnatch/point_test.go b/ecc/bls12-381/bandersnatch/point_test.go index 4a7cf881a..98ba0b22f 100644 --- a/ecc/bls12-381/bandersnatch/point_test.go +++ b/ecc/bls12-381/bandersnatch/point_test.go @@ -925,6 +925,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -1004,6 +1066,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bls12-381/twistededwards/eddsa/eddsa.go b/ecc/bls12-381/twistededwards/eddsa/eddsa.go index f66c95608..af9894c97 100644 --- a/ecc/bls12-381/twistededwards/eddsa/eddsa.go +++ b/ecc/bls12-381/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go b/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go index b46ba80b2..e46889942 100644 --- a/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BLS12_381.New() + hFunc := ghash.MIMC_BLS12_381.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BLS12_381.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BLS12_381.New() + hFunc := ghash.MIMC_BLS12_381.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bls12-381/twistededwards/point.go b/ecc/bls12-381/twistededwards/point.go index 8706a32e2..2bb3c589c 100644 --- a/ecc/bls12-381/twistededwards/point.go +++ b/ecc/bls12-381/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bls12-381/twistededwards/point_test.go b/ecc/bls12-381/twistededwards/point_test.go index 1db80527e..8d74f4f9a 100644 --- a/ecc/bls12-381/twistededwards/point_test.go +++ b/ecc/bls12-381/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bls24-315/twistededwards/eddsa/eddsa.go b/ecc/bls24-315/twistededwards/eddsa/eddsa.go index ca9808e5b..f19422cee 100644 --- a/ecc/bls24-315/twistededwards/eddsa/eddsa.go +++ b/ecc/bls24-315/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go b/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go index 13ef9a711..2b24fe7f8 100644 --- a/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls24-315/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BLS24_315.New() + hFunc := ghash.MIMC_BLS24_315.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BLS24_315.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BLS24_315.New() + hFunc := ghash.MIMC_BLS24_315.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bls24-315/twistededwards/point.go b/ecc/bls24-315/twistededwards/point.go index 621f96e74..678b14cfe 100644 --- a/ecc/bls24-315/twistededwards/point.go +++ b/ecc/bls24-315/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bls24-315/twistededwards/point_test.go b/ecc/bls24-315/twistededwards/point_test.go index da9ee08ec..80bf32c22 100644 --- a/ecc/bls24-315/twistededwards/point_test.go +++ b/ecc/bls24-315/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bls24-317/twistededwards/eddsa/eddsa.go b/ecc/bls24-317/twistededwards/eddsa/eddsa.go index a07cc1e27..772dac974 100644 --- a/ecc/bls24-317/twistededwards/eddsa/eddsa.go +++ b/ecc/bls24-317/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go b/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go index dfc689e52..0ba826a31 100644 --- a/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls24-317/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BLS24_317.New() + hFunc := ghash.MIMC_BLS24_317.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BLS24_317.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BLS24_317.New() + hFunc := ghash.MIMC_BLS24_317.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bls24-317/twistededwards/point.go b/ecc/bls24-317/twistededwards/point.go index 98ab773dc..14963e7fd 100644 --- a/ecc/bls24-317/twistededwards/point.go +++ b/ecc/bls24-317/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bls24-317/twistededwards/point_test.go b/ecc/bls24-317/twistededwards/point_test.go index 132516a82..799c27aba 100644 --- a/ecc/bls24-317/twistededwards/point_test.go +++ b/ecc/bls24-317/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bn254/twistededwards/eddsa/eddsa.go b/ecc/bn254/twistededwards/eddsa/eddsa.go index 42e83a432..05cb2b662 100644 --- a/ecc/bn254/twistededwards/eddsa/eddsa.go +++ b/ecc/bn254/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // hash(h) = private_key || random_source, on 32 bytes each @@ -77,9 +75,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -123,6 +119,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -131,10 +128,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -218,10 +216,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bn254/twistededwards/eddsa/eddsa_test.go b/ecc/bn254/twistededwards/eddsa/eddsa_test.go index c35e29bec..4902c472d 100644 --- a/ecc/bn254/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bn254/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bn254/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,184 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +363,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BN254.New() + hFunc := ghash.MIMC_BN254.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +403,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +436,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BN254.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BN254.New() + hFunc := ghash.MIMC_BN254.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bn254/twistededwards/point.go b/ecc/bn254/twistededwards/point.go index 6b1dd412f..d056006a4 100644 --- a/ecc/bn254/twistededwards/point.go +++ b/ecc/bn254/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bn254/twistededwards/point_test.go b/ecc/bn254/twistededwards/point_test.go index 83cef096e..81cbe9946 100644 --- a/ecc/bn254/twistededwards/point_test.go +++ b/ecc/bn254/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bw6-633/twistededwards/eddsa/eddsa.go b/ecc/bw6-633/twistededwards/eddsa/eddsa.go index c8d7bde96..4e66b39d8 100644 --- a/ecc/bw6-633/twistededwards/eddsa/eddsa.go +++ b/ecc/bw6-633/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // The source of randomness and the secret scalar must come @@ -86,9 +84,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h1[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -132,6 +128,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -140,10 +137,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -227,10 +225,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go b/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go index 1f52d1e7b..b64a890bd 100644 --- a/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,186 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h1 := blake2b.Sum512(seed[:]) + + h2 := blake2b.Sum512(h1[:]) + for i := range 32 { + priv.randSrc[i] = h2[i] + } + + h1[0] &= 0xF8 + h1[sizeFr-1] &= 0x7F + h1[sizeFr-1] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h1[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +365,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BW6_633.New() + hFunc := ghash.MIMC_BW6_633.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +405,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +438,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BW6_633.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BW6_633.New() + hFunc := ghash.MIMC_BW6_633.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bw6-633/twistededwards/point.go b/ecc/bw6-633/twistededwards/point.go index c3187c323..fabea040c 100644 --- a/ecc/bw6-633/twistededwards/point.go +++ b/ecc/bw6-633/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bw6-633/twistededwards/point_test.go b/ecc/bw6-633/twistededwards/point_test.go index dc58b8a18..85f083aca 100644 --- a/ecc/bw6-633/twistededwards/point_test.go +++ b/ecc/bw6-633/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/ecc/bw6-761/twistededwards/eddsa/eddsa.go b/ecc/bw6-761/twistededwards/eddsa/eddsa.go index 3fce96a37..5b5cbbd5e 100644 --- a/ecc/bw6-761/twistededwards/eddsa/eddsa.go +++ b/ecc/bw6-761/twistededwards/eddsa/eddsa.go @@ -50,8 +50,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey // The source of randomness and the secret scalar must come @@ -86,9 +84,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = h1[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -132,6 +128,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -140,10 +137,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -227,10 +225,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go b/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go index 50a6527c7..5fa6f371e 100644 --- a/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go @@ -6,19 +6,23 @@ package eddsa import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" + "golang.org/x/crypto/blake2b" ) func Example() { @@ -48,6 +52,186 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h1 := blake2b.Sum512(seed[:]) + + h2 := blake2b.Sum512(h1[:]) + for i := range 32 { + priv.randSrc[i] = h2[i] + } + + h1[0] &= 0xF8 + h1[sizeFr-1] &= 0x7F + h1[sizeFr-1] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h1[j] + } + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -181,7 +365,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_BW6_761.New() + hFunc := ghash.MIMC_BW6_761.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -221,8 +405,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -256,12 +438,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_BW6_761.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_BW6_761.New() + hFunc := ghash.MIMC_BW6_761.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/ecc/bw6-761/twistededwards/point.go b/ecc/bw6-761/twistededwards/point.go index abe72f936..0e592cb71 100644 --- a/ecc/bw6-761/twistededwards/point.go +++ b/ecc/bw6-761/twistededwards/point.go @@ -10,6 +10,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" ) @@ -39,6 +40,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -131,6 +141,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -261,6 +279,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -292,6 +335,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/ecc/bw6-761/twistededwards/point_test.go b/ecc/bw6-761/twistededwards/point_test.go index 0ace00144..9b0d6486c 100644 --- a/ecc/bw6-761/twistededwards/point_test.go +++ b/ecc/bw6-761/twistededwards/point_test.go @@ -856,6 +856,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -935,6 +997,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int diff --git a/internal/generator/edwards/eddsa/template/eddsa.go.tmpl b/internal/generator/edwards/eddsa/template/eddsa.go.tmpl index 74363e179..413a86880 100644 --- a/internal/generator/edwards/eddsa/template/eddsa.go.tmpl +++ b/internal/generator/edwards/eddsa/template/eddsa.go.tmpl @@ -44,8 +44,6 @@ type Signature struct { // GenerateKey generates a public and private key pair. func GenerateKey(r io.Reader) (*PrivateKey, error) { - c := twistededwards.GetEdwardsCurve() - var pub PublicKey var priv PrivateKey @@ -104,9 +102,7 @@ func GenerateKey(r io.Reader) (*PrivateKey, error) { priv.scalar[i] = {{$h}}[j] } - var bScalar big.Int - bScalar.SetBytes(priv.scalar[:]) - pub.A.ScalarMultiplication(&c.Base, &bScalar) + pub.A.ScalarMultiplicationBase(&priv.scalar) priv.PublicKey = pub @@ -151,6 +147,7 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // blindingFactorBigInt must be the same size as the private key, // blindingFactorBigInt = h(randomness_source||message)[:sizeFr] var blindingFactorBigInt big.Int + var blindingFactorScalar [sizeFr]byte // randSrc = privKey.randSrc || msg (-> message = MSB message .. LSB message) randSrc := make([]byte, 32+len(message)) @@ -159,10 +156,11 @@ func (privKey *PrivateKey) Sign(message []byte, hFunc hash.Hash) ([]byte, error) // randBytes = H(randSrc) blindingFactorBytes := blake2b.Sum512(randSrc[:]) // TODO ensures that the hash used to build the key and the one used here is the same - blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + copy(blindingFactorScalar[:], blindingFactorBytes[:sizeFr]) + blindingFactorBigInt.SetBytes(blindingFactorScalar[:]) // compute R = randScalar*Base - res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + res.R.ScalarMultiplicationBase(&blindingFactorScalar) if !res.R.IsOnCurve() { return nil, errNotOnCurve } @@ -246,10 +244,9 @@ func (pub *PublicKey) Verify(sigBin, message []byte, hFunc hash.Hash) (bool, err // lhs = cofactor*S*Base var lhs twistededwards.PointAffine - var bCofactor, bs big.Int + var bCofactor big.Int curveParams.Cofactor.BigInt(&bCofactor) - bs.SetBytes(sig.S[:]) - lhs.ScalarMultiplication(&curveParams.Base, &bs). + lhs.ScalarMultiplicationBase(&sig.S). ScalarMultiplication(&lhs, &bCofactor) if !lhs.IsOnCurve() { diff --git a/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl b/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl index 8d04cf6f0..e56cae96d 100644 --- a/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl +++ b/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl @@ -1,20 +1,23 @@ import ( + "bytes" "crypto/sha256" + "io" "math/big" "math/rand" "testing" crand "crypto/rand" + stdhash "hash" "fmt" - "github.com/consensys/gnark-crypto/hash" + ghash "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/twistededwards" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr/mimc" + "golang.org/x/crypto/blake2b" ) - func Example() { // instantiate hash function hFunc := mimc.NewMiMC() @@ -42,6 +45,206 @@ func Example() { // Output: 1. valid signature } +func generateKeyReference(r io.Reader) (*PrivateKey, error) { + c := twistededwards.GetEdwardsCurve() + + var pub PublicKey + var priv PrivateKey + + {{- if or (eq .Name "bw6-761") (eq .Name "bw6-633")}} + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h1 := blake2b.Sum512(seed[:]) + + h2 := blake2b.Sum512(h1[:]) + for i := range 32 { + priv.randSrc[i] = h2[i] + } + + h1[0] &= 0xF8 + h1[sizeFr-1] &= 0x7F + h1[sizeFr-1] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h1[j] + } + {{- else }} + seed := make([]byte, 32) + _, err := r.Read(seed) + if err != nil { + return nil, err + } + h := blake2b.Sum512(seed[:]) + for i := range 32 { + priv.randSrc[i] = h[i+32] + } + + h[0] &= 0xF8 + h[31] &= 0x7F + h[31] |= 0x40 + for i, j := 0, sizeFr-1; i < sizeFr; i, j = i+1, j-1 { + priv.scalar[i] = h[j] + } + {{- end }} + + var bScalar big.Int + bScalar.SetBytes(priv.scalar[:]) + pub.A.ScalarMultiplication(&c.Base, &bScalar) + + priv.PublicKey = pub + + return &priv, nil +} + +func signReference(privKey *PrivateKey, message []byte, hFunc stdhash.Hash) ([]byte, error) { + if hFunc == nil { + return nil, errHashNeeded + } + + curveParams := twistededwards.GetEdwardsCurve() + + var res Signature + var blindingFactorBigInt big.Int + + randSrc := make([]byte, 32+len(message)) + copy(randSrc, privKey.randSrc[:]) + copy(randSrc[32:], message) + + blindingFactorBytes := blake2b.Sum512(randSrc[:]) + blindingFactorBigInt.SetBytes(blindingFactorBytes[:sizeFr]) + + res.R.ScalarMultiplication(&curveParams.Base, &blindingFactorBigInt) + if !res.R.IsOnCurve() { + return nil, errNotOnCurve + } + + hFunc.Reset() + + resRX := res.R.X.Bytes() + resRY := res.R.Y.Bytes() + resAX := privKey.PublicKey.A.X.Bytes() + resAY := privKey.PublicKey.A.Y.Bytes() + toWrite := [][]byte{resRX[:], resRY[:], resAX[:], resAY[:], message} + for _, chunk := range toWrite { + if _, err := hFunc.Write(chunk); err != nil { + return nil, err + } + } + + var hramInt big.Int + hramBin := hFunc.Sum(nil) + hramInt.SetBytes(hramBin) + + var bscalar, bs big.Int + bscalar.SetBytes(privKey.scalar[:]) + bs.Mul(&hramInt, &bscalar). + Add(&bs, &blindingFactorBigInt). + Mod(&bs, &curveParams.Order) + sb := bs.Bytes() + if len(sb) < sizeFr { + offset := make([]byte, sizeFr-len(sb)) + sb = append(offset, sb...) + } + copy(res.S[:], sb[:]) + + return res.Bytes(), nil +} + +func TestGenerateKeyMatchesGenericReference(t *testing.T) { + got, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + want, err := generateKeyReference(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + if got.scalar != want.scalar { + t.Fatal("scalar mismatch against generic reference") + } + if got.randSrc != want.randSrc { + t.Fatal("randSrc mismatch against generic reference") + } + if !got.PublicKey.A.Equal(&want.PublicKey.A) { + t.Fatal("public key mismatch against generic reference") + } +} + +func TestSignMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + + message := []byte("message") + got, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + want, err := signReference(privKey, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, want) { + t.Fatal("signature mismatch against generic reference") + } +} + +func TestVerifyFixedBaseMatchesGenericReference(t *testing.T) { + privKey, err := GenerateKey(rand.New(rand.NewSource(0))) //#nosec G404 deterministic test seed + if err != nil { + t.Fatal(err) + } + pubKey := privKey.PublicKey + + message := []byte("message") + signature, err := privKey.Sign(message, sha256.New()) + if err != nil { + t.Fatal(err) + } + + var sig Signature + if _, err := sig.SetBytes(signature); err != nil { + t.Fatal(err) + } + + curveParams := twistededwards.GetEdwardsCurve() + var cofactor, scalar big.Int + curveParams.Cofactor.BigInt(&cofactor) + scalar.SetBytes(sig.S[:]) + + var lhsFixed, lhsGeneric twistededwards.PointAffine + lhsFixed.ScalarMultiplicationBase(&sig.S). + ScalarMultiplication(&lhsFixed, &cofactor) + lhsGeneric.ScalarMultiplication(&curveParams.Base, &scalar). + ScalarMultiplication(&lhsGeneric, &cofactor) + if !lhsFixed.Equal(&lhsGeneric) { + t.Fatal("[S]Base mismatch against generic reference") + } + + ok, err := pubKey.Verify(signature, message, sha256.New()) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Verify correct signature should return true") + } + + ok, err = pubKey.Verify(signature, []byte("wrong_message"), sha256.New()) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("Verify wrong signature should be false") + } +} + func TestNonMalleability(t *testing.T) { // buffer too big @@ -175,7 +378,7 @@ func TestEddsaMIMC(t *testing.T) { t.Fatal(nil) } pubKey := privKey.PublicKey - hFunc := hash.MIMC_{{ .EnumID }}.New() + hFunc := ghash.MIMC_{{ .EnumID }}.New() var frMsg fr.Element frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") @@ -215,8 +418,6 @@ func TestEddsaSHA256(t *testing.T) { hFunc := sha256.New() // create eddsa obj and sign a message - // create eddsa obj and sign a message - privKey, err := GenerateKey(r) pubKey := privKey.PublicKey if err != nil { @@ -250,12 +451,45 @@ func TestEddsaSHA256(t *testing.T) { // benchmarks +func BenchmarkGenerateKey(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + + b.ResetTimer() + for range b.N { + if _, err := GenerateKey(r); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign(b *testing.B) { + + r := rand.New(rand.NewSource(0)) //#nosec G404 deterministic benchmark seed + privKey, err := GenerateKey(r) + if err != nil { + b.Fatal(err) + } + + var frMsg fr.Element + frMsg.SetString("44717650746155748460101257525078853138837311576962212923649547644148297035978") + msgBin := frMsg.Bytes() + hFunc := ghash.MIMC_{{ .EnumID }}.New() + + b.ResetTimer() + for range b.N { + if _, err := privKey.Sign(msgBin[:], hFunc); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkVerify(b *testing.B) { src := rand.NewSource(0) r := rand.New(src) //#nosec G404 weak rng is fine here - hFunc := hash.MIMC_{{ .EnumID }}.New() + hFunc := ghash.MIMC_{{ .EnumID }}.New() // create eddsa obj and sign a message privKey, err := GenerateKey(r) diff --git a/internal/generator/edwards/template/point.go.tmpl b/internal/generator/edwards/template/point.go.tmpl index 52f203f73..a34fdbae5 100644 --- a/internal/generator/edwards/template/point.go.tmpl +++ b/internal/generator/edwards/template/point.go.tmpl @@ -3,6 +3,7 @@ import ( "io" "math/big" "math/bits" + "sync" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" ) @@ -32,6 +33,15 @@ const ( // size in byte of a compressed point (point.Y --> fr.Element) sizePointCompressed = fr.Bytes + + fixedBaseWindowSize = 4 + fixedBaseWindowEntries = 1 << fixedBaseWindowSize + fixedBaseWindowCount = fr.Bytes * 2 +) + +var ( + fixedBaseTableOnce sync.Once + fixedBaseTable [fixedBaseWindowCount][fixedBaseWindowEntries]PointAffine ) // Bytes returns the compressed point as a byte array @@ -124,6 +134,14 @@ func (p *PointAffine) Set(p1 *PointAffine) *PointAffine { return p } +// selectPoint is a constant-time conditional move. +// If c=0, p = p0. Else p = p1. +func (p *PointAffine) selectPoint(c int, p0, p1 *PointAffine) *PointAffine { + p.X.Select(c, &p0.X, &p1.X) + p.Y.Select(c, &p0.Y, &p1.Y) + return p +} + // Equal returns true if p=p1 false otherwise func (p *PointAffine) Equal(p1 *PointAffine) bool { return p.X.Equal(&p1.X) && p.Y.Equal(&p1.Y) @@ -317,6 +335,31 @@ func (p *PointAffine) ScalarMultiplication(p1 *PointAffine, scalar *big.Int) *Po return p.scalarMulWindowed(p1, scalar) } +// ScalarMultiplicationBase computes [scalar]Base in affine coordinates. +// scalar is interpreted as a fixed-length big-endian unsigned integer. +func (p *PointAffine) ScalarMultiplicationBase(scalar *[fr.Bytes]byte) *PointAffine { + fixedBaseTableOnce.Do(initFixedBaseTable) + + var resExtended PointExtended + resExtended.setInfinity() + + for i := range fixedBaseWindowCount { + digit := fixedBaseNibble(scalar, i) + + var selected PointAffine + selected.setInfinity() + for j := range fixedBaseWindowEntries { + match := subtle.ConstantTimeByteEq(digit, byte(j)) + selected.selectPoint(match, &selected, &fixedBaseTable[i][j]) + } + + resExtended.MixedAdd(&resExtended, &selected) + } + + p.FromExtended(&resExtended) + return p +} + // scalarMulWindowed scalar multiplication of a point // p1 in affine coordinates with a scalar in big.Int // using the windowed double-and-add method. @@ -348,6 +391,32 @@ func (p *PointAffine) scalarMulWindowed(p1 *PointAffine, scalar *big.Int) *Point return p } +func fixedBaseNibble(scalar *[fr.Bytes]byte, i int) byte { + b := scalar[fr.Bytes-1-(i>>1)] + if i&1 == 0 { + return b & (fixedBaseWindowEntries - 1) + } + return b >> fixedBaseWindowSize +} + +func initFixedBaseTable() { + initOnce.Do(initCurveParams) + + var base PointAffine + base.Set(&curveParams.Base) + + for i := range fixedBaseWindowCount { + fixedBaseTable[i][0].setInfinity() + fixedBaseTable[i][1].Set(&base) + for j := 2; j < fixedBaseWindowEntries; j++ { + fixedBaseTable[i][j].Add(&fixedBaseTable[i][j-1], &base) + } + for range fixedBaseWindowSize { + base.Double(&base) + } + } +} + // setInfinity sets p to O (0:1) func (p *PointAffine) setInfinity() *PointAffine { p.X.SetZero() diff --git a/internal/generator/edwards/template/tests/point.go.tmpl b/internal/generator/edwards/template/tests/point.go.tmpl index b5af58a89..048253239 100644 --- a/internal/generator/edwards/template/tests/point.go.tmpl +++ b/internal/generator/edwards/template/tests/point.go.tmpl @@ -934,6 +934,68 @@ func GenBigInt() gopter.Gen { } } +func scalarBytesFromBigInt(s *big.Int) [fr.Bytes]byte { + var res [fr.Bytes]byte + s.FillBytes(res[:]) + return res +} + +func TestScalarMultiplicationBase(t *testing.T) { + t.Parallel() + + params := GetEdwardsCurve() + + cases := make([][fr.Bytes]byte, 0, 8+nbFuzz) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(0))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(1))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(2))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(15))) + cases = append(cases, scalarBytesFromBigInt(big.NewInt(16))) + + var orderMinusOne big.Int + orderMinusOne.Sub(¶ms.Order, big.NewInt(1)) + cases = append(cases, scalarBytesFromBigInt(&orderMinusOne)) + + var allFF [fr.Bytes]byte + for i := range allFF { + allFF[i] = 0xff + } + cases = append(cases, allFF) + + var highestNibble [fr.Bytes]byte + highestNibble[0] = 0xf0 + cases = append(cases, highestNibble) + + var lowestNibble [fr.Bytes]byte + lowestNibble[fr.Bytes-1] = 0x0f + cases = append(cases, lowestNibble) + + loops := nbFuzz + if testing.Short() { + loops = nbFuzzShort + } + for range loops { + var scalarBytes [fr.Bytes]byte + if _, err := rand.Read(scalarBytes[:]); err != nil { + t.Fatal(err) + } + cases = append(cases, scalarBytes) + } + + for _, scalarBytes := range cases { + var scalar big.Int + scalar.SetBytes(scalarBytes[:]) + + var fixed, generic PointAffine + fixed.ScalarMultiplicationBase(&scalarBytes) + generic.ScalarMultiplication(¶ms.Base, &scalar) + + if !fixed.Equal(&generic) { + t.Fatalf("fixed-base mismatch for scalar %x", scalarBytes) + } + } +} + // ------------------------------------------------------------ // benches @@ -1013,6 +1075,33 @@ func BenchmarkScalarMulProjective(b *testing.B) { } } +func BenchmarkScalarMulAffineBase(b *testing.B) { + params := GetEdwardsCurve() + + var scalar big.Int + scalar.SetString("52435875175126190479447705081859658376581184513", 10) + scalar.Add(&scalar, ¶ms.Order) + scalarBytes := scalarBytesFromBigInt(&scalar) + + b.Run("generic", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplication(¶ms.Base, &scalar) + } + }) + + b.Run("fixed", func(b *testing.B) { + var point PointAffine + + b.ResetTimer() + for range b.N { + point.ScalarMultiplicationBase(&scalarBytes) + } + }) +} + func BenchmarkNeg(b *testing.B) { params := GetEdwardsCurve() var s big.Int