diff --git a/curves/bls/bls12377/g1.go b/curves/bls/bls12377/g1.go new file mode 100644 index 0000000..a16d357 --- /dev/null +++ b/curves/bls/bls12377/g1.go @@ -0,0 +1,52 @@ +package bls12377 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG1 struct { +} + +// String returns the string for the group +func (g1 *groupG1) String() string { + return "BLS12-377 G1" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g1 *groupG1) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G1 +func (g1 *groupG1) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g1 *groupG1) PointLen() int { + return fp.Bytes +} + +// CreatePoint creates a new point +func (g1 *groupG1) CreatePoint() crypto.Point { + return NewPointG1() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g1 *groupG1) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG1() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG1 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g1 *groupG1) IsInterfaceNil() bool { + return g1 == nil +} diff --git a/curves/bls/bls12377/g1_test.go b/curves/bls/bls12377/g1_test.go new file mode 100644 index 0000000..792cda7 --- /dev/null +++ b/curves/bls/bls12377/g1_test.go @@ -0,0 +1,160 @@ +package bls12377 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/require" + "math/big" + "testing" +) + +func TestGroupG1_String(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + str := grG1.String() + require.Equal(t, "BLS12-377 G1", str) +} + +func TestGroupG1_ScalarLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.ScalarLen() + require.Equal(t, fr.Bytes, x) +} + +func TestGroupG1_PointLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.PointLen() + require.Equal(t, 48, x) +} + +func TestGroupG1_CreatePoint(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + g1Gen, _, _, _ := gnark.Generators() + point := &PointG1{ + G1: &g1Gen, + } + + x := grG1.CreatePoint() + require.NotNil(t, x) + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381Point.IsOnCurve()) + require.Equal(t, point.G1, bls12381Point) +} + +func TestGroupG1_CreateScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + sc := grG1.CreateScalar() + require.NotNil(t, sc) + + bls12381Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) +} + +func TestGroupG1_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.IsOnCurve()) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bls12381PointG1.Equal(computedG1)) +} + +func TestGroupG1_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(0) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsZero()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.Z.IsZero()) + require.True(t, bls12381PointG1.IsOnCurve()) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bls12381PointG1.Equal(computedG1)) + +} + +func TestGroupG1_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(1) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + baseG1 := NewPointG1().G1 + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.Equal(baseG1)) +} + +func TestGroupG1_CreatePointForScalarNil(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + pG1 := grG1.CreatePointForScalar(nil) + require.Equal(t, nil, pG1) +} + +func TestGroupG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG1 *groupG1 + + require.True(t, grG1.IsInterfaceNil()) + grG1 = &groupG1{} + require.False(t, grG1.IsInterfaceNil()) +} diff --git a/curves/bls/bls12377/g2.go b/curves/bls/bls12377/g2.go new file mode 100644 index 0000000..f903707 --- /dev/null +++ b/curves/bls/bls12377/g2.go @@ -0,0 +1,52 @@ +package bls12377 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG2 struct { +} + +// String returns the string for the group +func (g2 *groupG2) String() string { + return "BLS12-377 G2" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g2 *groupG2) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G2 +func (g2 *groupG2) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g2 *groupG2) PointLen() int { + return fp.Bytes * 2 +} + +// CreatePoint creates a new point +func (g2 *groupG2) CreatePoint() crypto.Point { + return NewPointG2() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g2 *groupG2) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG2() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG2 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g2 *groupG2) IsInterfaceNil() bool { + return g2 == nil +} diff --git a/curves/bls/bls12377/g2_test.go b/curves/bls/bls12377/g2_test.go new file mode 100644 index 0000000..f70366f --- /dev/null +++ b/curves/bls/bls12377/g2_test.go @@ -0,0 +1,150 @@ +package bls12377 + +import ( + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/require" +) + +func TestGroupG2_String(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + str := grG2.String() + require.Equal(t, str, "BLS12-377 G2") +} + +func TestGroupG2_ScalarLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupG2_PointLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.PointLen() + require.Equal(t, 96, x) +} + +func TestGroupG2_CreatePoint(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + x := grG2.CreatePoint() + require.NotNil(t, x) + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.Equal(t, point.G2, bls12381Point) +} + +func TestGroupG2_CreateScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + sc := grG2.CreateScalar() + require.NotNil(t, sc) + + bls12381Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) +} + +func TestGroupG2_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bls12381PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(0) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsZero()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bls12381PointG2.Z.IsZero()) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bls12381PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(1) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bls12381PointG2.Equal(bG2)) +} + +func TestGroupG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG2 *groupG2 + + require.True(t, check.IfNil(grG2)) + grG2 = &groupG2{} + require.False(t, check.IfNil(grG2)) +} diff --git a/curves/bls/bls12377/gt.go b/curves/bls/bls12377/gt.go new file mode 100644 index 0000000..282cd9b --- /dev/null +++ b/curves/bls/bls12377/gt.go @@ -0,0 +1,45 @@ +package bls12377 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupGT struct { +} + +// String returns the string for the group +func (gt *groupGT) String() string { + return "BLS12-377 GT" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (gt *groupGT) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar +func (gt *groupGT) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (gt *groupGT) PointLen() int { + return fp.Bytes * 12 +} + +// CreatePoint creates a new point +func (gt *groupGT) CreatePoint() crypto.Point { + return NewPointGT() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (gt *groupGT) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + panic("not supported") +} + +// IsInterfaceNil returns true if there is no value under the interface +func (gt *groupGT) IsInterfaceNil() bool { + return gt == nil +} diff --git a/curves/bls/bls12377/gt_test.go b/curves/bls/bls12377/gt_test.go new file mode 100644 index 0000000..5148023 --- /dev/null +++ b/curves/bls/bls12377/gt_test.go @@ -0,0 +1,95 @@ +package bls12377 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestGroupGT_String(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + str := grGT.String() + require.Equal(t, str, "BLS12-377 GT") +} + +func TestGroupGT_ScalarLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupGT_PointLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.PointLen() + require.Equal(t, 48*12, x) +} + +func TestGroupGT_CreatePoint(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.CreatePoint() + require.NotNil(t, x) + + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + // points created on GT are initialized with PointZero + require.True(t, bls12381Point.IsZero()) +} + +func TestGroupGT_CreateScalar(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + sc := grGT.CreateScalar() + require.NotNil(t, sc) + + mclScalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, mclScalar.IsZero()) + require.False(t, mclScalar.IsOne()) +} + +func TestGroupGT_CreatePointForScalar(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should panic as currently not supported") + } + }() + + grGT := &groupGT{} + + scalar := grGT.CreateScalar() + mclScalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, mclScalar.IsZero()) + require.False(t, mclScalar.IsOne()) + + _ = grGT.CreatePointForScalar(scalar) +} + +func TestGroupGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grGT *groupGT + + require.True(t, grGT.IsInterfaceNil()) + grGT = &groupGT{} + require.False(t, grGT.IsInterfaceNil()) +} diff --git a/curves/bls/bls12377/pointG1.go b/curves/bls/bls12377/pointG1.go new file mode 100644 index 0000000..25db1af --- /dev/null +++ b/curves/bls/bls12377/pointG1.go @@ -0,0 +1,206 @@ +package bls12377 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG1 - +type PointG1 struct { + G1 *gnark.G1Jac +} + +func NewPointG1() *PointG1 { + point := &PointG1{ + G1: &gnark.G1Jac{}, + } + + g1Gen, _, _, _ := gnark.Generators() + point.G1 = &g1Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG1) Equal(p crypto.Point) (bool, error) { + if p == nil { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG1) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G1.Equal(po2.G1), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG1) Clone() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Set(po.G1) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG1) Null() crypto.Point { + p := &PointG1{ + G1: &gnark.G1Jac{}, + } + + p.G1.Z.SetZero() + p.G1.X.SetOne() + p.G1.Y.SetOne() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG1) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return crypto.ErrInvalidParam + } + + po.G1.Set(po1.G1) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG1) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.AddAssign(po1.G1) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG1) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.SubAssign(po1.G1) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG1) Neg() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Neg(po.G1) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG1) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG1) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG1) GetUnderlyingObj() interface{} { + return po.G1 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG1) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G1Affine{} + affinePoint.FromJacobian(po.G1) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG1) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G1Affine{} + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G1 = po.G1.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG1) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12377/pointG1_test.go b/curves/bls/bls12377/pointG1_test.go new file mode 100644 index 0000000..e9b5b97 --- /dev/null +++ b/curves/bls/bls12377/pointG1_test.go @@ -0,0 +1,310 @@ +package bls12377 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewPointG1(t *testing.T) { + g1Gen, _, _, _ := gnark.Generators() + bG1 := &g1Gen + + pG1 := NewPointG1() + require.NotNil(t, pG1) + + mclPointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bG1.Equal(mclPointG1)) +} + +func TestPointG1_Equal(t *testing.T) { + p1G1 := NewPointG1() + p2G1 := NewPointG1() + + eq, err := p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G1.Mul(scalar) + require.Nil(t, err) + p1G1 = p1Modified.(*PointG1) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.False(t, eq) + + grG1 := &groupG1{} + sc1G1 := grG1.CreateScalar() + p1 := grG1.CreatePointForScalar(sc1G1) + p2 := grG1.CreatePointForScalar(sc1G1) + + var ok bool + p1G1, ok = p1.(*PointG1) + require.True(t, ok) + + p2G1, ok = p2.(*PointG1) + require.True(t, ok) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG1 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG1_Clone(t *testing.T) { + p1 := NewPointG1() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_Null(t *testing.T) { + p1 := NewPointG1() + + point := p1.Null() + bls12381Point, ok := point.(*PointG1) + require.True(t, ok) + require.True(t, bls12381Point.G1.X.IsOne()) + require.True(t, bls12381Point.G1.Y.IsOne()) + require.True(t, bls12381Point.G1.Z.IsZero()) + + bls12381PointNeg := &gnark.G1Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G1) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G1.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG1_Set(t *testing.T) { + p1 := NewPointG1() + p2 := NewPointG1() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG1) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_AddOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG1_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2, err := pointG1.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2 := &mock.PointMock{} + point3, err := pointG1.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_SubOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_Neg(t *testing.T) { + point1 := NewPointG1() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG1_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG1.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG1, res) + + grG1 := &groupG1{} + point2 := grG1.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG1_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG1_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG1_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG1().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG1() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG1 + + require.True(t, check.IfNil(point)) + point = NewPointG1() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bls/bls12377/pointG2.go b/curves/bls/bls12377/pointG2.go new file mode 100644 index 0000000..b95e055 --- /dev/null +++ b/curves/bls/bls12377/pointG2.go @@ -0,0 +1,207 @@ +package bls12377 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG2 - +type PointG2 struct { + G2 *gnark.G2Jac +} + +// NewPointG2 creates a new point on G2 initialized with base point +func NewPointG2() *PointG2 { + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG2) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG2) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G2.Equal(po2.G2), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG2) Clone() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Set(po.G2) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG2) Null() crypto.Point { + p := &PointG2{ + G2: &gnark.G2Jac{}, + } + + p.G2.Z.SetOne() + p.G2.X.SetZero() + p.G2.Y.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG2) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return crypto.ErrInvalidParam + } + + po.G2.Set(po1.G2) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG2) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.AddAssign(po1.G2) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG2) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.SubAssign(po1.G2) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG2) Neg() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Neg(po.G2) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG2) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG2) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG2) GetUnderlyingObj() interface{} { + return po.G2 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG2) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G2Affine{} + affinePoint.FromJacobian(po.G2) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG2) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G2Affine{} + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G2 = po.G2.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG2) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12377/pointG2_test.go b/curves/bls/bls12377/pointG2_test.go new file mode 100644 index 0000000..6098eb0 --- /dev/null +++ b/curves/bls/bls12377/pointG2_test.go @@ -0,0 +1,311 @@ +package bls12377 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewPointG2(t *testing.T) { + _, g2Gen, _, _ := gnark.Generators() + bG2 := &g2Gen + + pG2 := NewPointG2() + require.NotNil(t, pG2) + + mclPointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bG2.Equal(mclPointG2)) +} + +func TestPointG2_Equal(t *testing.T) { + p1G2 := NewPointG2() + p2G2 := NewPointG2() + + // new points should be initialized with base point so should be equal + eq, err := p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G2.Mul(scalar) + require.Nil(t, err) + p1G2 = p1Modified.(*PointG2) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.False(t, eq) + + grG2 := &groupG2{} + sc1G2 := grG2.CreateScalar() + p1 := grG2.CreatePointForScalar(sc1G2) + p2 := grG2.CreatePointForScalar(sc1G2) + + var ok bool + p1G2, ok = p1.(*PointG2) + require.True(t, ok) + + p2G2, ok = p2.(*PointG2) + require.True(t, ok) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG2 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG2_Clone(t *testing.T) { + p1 := NewPointG2() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_Null(t *testing.T) { + p1 := NewPointG2() + + point := p1.Null() + bls12381Point, ok := point.(*PointG2) + require.True(t, ok) + require.True(t, bls12381Point.G2.X.IsZero()) + require.True(t, bls12381Point.G2.Y.IsZero()) + require.True(t, bls12381Point.G2.Z.IsOne()) + + bls12381PointNeg := &gnark.G2Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G2) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G2.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG2_Set(t *testing.T) { + p1 := NewPointG2() + p2 := NewPointG2() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG2) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_AddOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG2_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2, err := pointG2.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2 := &mock.PointMock{} + point3, err := pointG2.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_SubOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_Neg(t *testing.T) { + point1 := NewPointG2() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG2_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG2, res) + + grG2 := &groupG2{} + point2 := grG2.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG2_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG2_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG2_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG2().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG2() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG2 + + require.True(t, check.IfNil(point)) + point = NewPointG2() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bls/bls12377/pointGT.go b/curves/bls/bls12377/pointGT.go new file mode 100644 index 0000000..54709bd --- /dev/null +++ b/curves/bls/bls12377/pointGT.go @@ -0,0 +1,221 @@ +package bls12377 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointGT - +type PointGT struct { + *gnark.GT +} + +// NewPointGT creates a new point on GT initialized with identity +func NewPointGT() *PointGT { + point := &PointGT{ + GT: &gnark.GT{}, + } + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointGT) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointGT) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.GT.Equal(po2.GT), nil +} + +// Clone returns a clone of the receiver. +func (po *PointGT) Clone() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + po2.GT = po2.GT.Set(po.GT) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointGT) Null() crypto.Point { + p := NewPointGT() + p.GT.C0.B0.SetZero() + p.GT.C0.B1.SetZero() + p.GT.C0.B2.SetZero() + + p.GT.C1.B0.SetZero() + p.GT.C1.B1.SetZero() + p.GT.C1.B2.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointGT) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return crypto.ErrInvalidParam + } + + po.GT.Set(po1.GT) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointGT) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Add(po2.GT, po1.GT) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointGT) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Sub(po2.GT, po1.GT) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointGT) Neg() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + // Multiplicative in GT, we can use inverse + po2.GT = po2.GT.Inverse(po.GT) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointGT) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.GT = po2.GT.Exp(*po.GT, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointGT) Pick() (crypto.Point, error) { + var p1, p2 crypto.Point + var err error + + p1, err = NewPointG1().Pick() + if err != nil { + return nil, err + } + + p2, err = NewPointG2().Pick() + if err != nil { + return nil, err + } + + poG1 := p1.(*PointG1) + poG2 := p2.(*PointG2) + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + g1Affine := &gnark.G1Affine{} + g1Affine.FromJacobian(poG1.G1) + + g2Affine := &gnark.G2Affine{} + g2Affine.FromJacobian(poG2.G2) + + paired, err := gnark.Pair([]gnark.G1Affine{*g1Affine}, []gnark.G2Affine{*g2Affine}) + if err != nil { + return nil, err + } + + po2.GT = &paired + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointGT) GetUnderlyingObj() interface{} { + return po.GT +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointGT) MarshalBinary() ([]byte, error) { + return po.GT.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointGT) UnmarshalBinary(point []byte) error { + return po.GT.Unmarshal(point) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointGT) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12377/pointGT_test.go b/curves/bls/bls12377/pointGT_test.go new file mode 100644 index 0000000..08ea857 --- /dev/null +++ b/curves/bls/bls12377/pointGT_test.go @@ -0,0 +1,275 @@ +package bls12377 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewPointGT(t *testing.T) { + pGT := NewPointGT() + require.NotNil(t, pGT) + + mclPointGT, ok := pGT.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + require.True(t, mclPointGT.IsZero()) +} + +func TestPointGT_Equal(t *testing.T) { + p1GT := NewPointGT() + p2GT := NewPointGT() + + eq, err := p1GT.Equal(p2GT) + require.Nil(t, err) + require.True(t, eq) + + p2GT.SetOne() + eq, err = p1GT.Equal(p2GT) + require.Nil(t, err) + require.False(t, eq) +} + +func TestPointGT_CloneNilShouldPanic(t *testing.T) { + var p1 *PointGT + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointGT_Clone(t *testing.T) { + p1 := NewPointGT() + _, err := p1.GT.SetRandom() + require.Nil(t, err) + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_Null(t *testing.T) { + p1 := NewPointGT() + + point := p1.Null() + bls12381Point, ok := point.(*PointGT) + require.True(t, ok) + require.True(t, bls12381Point.IsZero()) + bls12381PointNeg := &gnark.GT{} + // TODO + + // neutral identity point should be equal to it's negation + require.True(t, bls12381Point.GT.Equal(bls12381PointNeg)) +} + +func TestPointGT_Set(t *testing.T) { + p1 := NewPointGT() + p2 := NewPointGT() + + p2.GT.SetOne() + + err := p1.Set(p2) + require.Nil(t, err) + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_AddOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointGT_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, err := pointGT.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2 := &mock.PointMock{} + point3, err := pointGT.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_SubOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_Neg(t *testing.T) { + point1 := NewPointGT() + point, _ := point1.Pick() + + point2 := point.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point, point3) +} + +func TestPointGT_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, _ := pointGT.Pick() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := point2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, point2, res) +} + +func TestPointGT_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointGT_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointGT_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointGT_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointGT().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointGT() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointGT + + require.True(t, point.IsInterfaceNil()) + point = NewPointGT() + require.False(t, point.IsInterfaceNil()) +} diff --git a/curves/bls/bls12377/scalar.go b/curves/bls/bls12377/scalar.go new file mode 100644 index 0000000..2274391 --- /dev/null +++ b/curves/bls/bls12377/scalar.go @@ -0,0 +1,270 @@ +package bls12377 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type Scalar struct { + Scalar *fr.Element +} + +func NewScalar() *Scalar { + scalar := &Scalar{Scalar: &fr.Element{}} + scalar.setRandom() + + for scalar.Scalar.IsOne() || scalar.Scalar.IsZero() { + scalar.setRandom() + } + + return scalar +} + +// Equal tests if receiver is equal with the scalarInt s given as parameter. +// Both scalars need to be derived from the same Group +func (sc *Scalar) Equal(s crypto.Scalar) (bool, error) { + if check.IfNil(s) { + return false, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return false, crypto.ErrInvalidParam + } + + areEqual := sc.Scalar.Equal(s2.Scalar) + + return areEqual, nil +} + +// Set sets the receiver to Scalar s given as parameter +func (sc *Scalar) Set(s crypto.Scalar) error { + if check.IfNil(s) { + return crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return crypto.ErrInvalidParam + } + + return sc.Scalar.SetBytesCanonical(s2.Scalar.Marshal()) +} + +// Clone creates a new Scalar with same value as receiver +func (sc *Scalar) Clone() crypto.Scalar { + scalar := &Scalar{ + Scalar: &fr.Element{}, + } + + scalar.Scalar.SetBytes(sc.Scalar.Marshal()) + + return scalar +} + +// SetInt64 sets the receiver to a small integer value v given as parameter +func (sc *Scalar) SetInt64(v int64) { + sc.Scalar.SetInt64(v) +} + +// Zero returns the additive identity (0) +func (sc *Scalar) Zero() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetZero() + + return &s +} + +// Add returns the modular sum of receiver with scalar s given as parameter +func (sc *Scalar) Add(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Add(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Sub returns the modular difference between receiver and scalar s given as parameter +func (sc *Scalar) Sub(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Sub(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Neg returns the modular negation of receiver +func (sc *Scalar) Neg() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + + s.Scalar.Neg(sc.Scalar) + + return &s +} + +// One returns the multiplicative identity (1) +func (sc *Scalar) One() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetOne() + + return &s +} + +// Mul returns the modular product of receiver with scalar s given as parameter +func (sc *Scalar) Mul(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Mul(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Div returns the modular division between receiver and scalar s given as parameter +func (sc *Scalar) Div(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Div(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Inv returns the modular inverse of scalar s given as parameter +func (sc *Scalar) Inv(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Inverse(s2.Scalar) + + return &s1, nil +} + +// Pick returns a fresh random or pseudo-random scalar +// For the mock set X to the original scalar.X *2 +func (sc *Scalar) Pick() (crypto.Scalar, error) { + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + _, err := s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + + for s1.Scalar.IsOne() || s1.Scalar.IsZero() { + _, err = s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + } + + return &s1, nil +} + +// SetBytes sets the scalar from a byte-slice, +// reducing if necessary to the appropriate modulus. +func (sc *Scalar) SetBytes(s []byte) (crypto.Scalar, error) { + if len(s) == 0 { + return nil, crypto.ErrNilParam + } + + s1 := sc.Clone() + s2, ok := s1.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + err := s2.Scalar.SetBytesCanonical(s) + if err != nil { + return nil, err + } + + return s1, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (sc *Scalar) GetUnderlyingObj() interface{} { + return sc.Scalar +} + +// MarshalBinary transforms the Scalar into a byte array +func (sc *Scalar) MarshalBinary() ([]byte, error) { + return sc.Scalar.Marshal(), nil +} + +// UnmarshalBinary recreates the Scalar from a byte array +func (sc *Scalar) UnmarshalBinary(val []byte) error { + return sc.Scalar.SetBytesCanonical(val) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sc *Scalar) IsInterfaceNil() bool { + return sc == nil +} + +func (sc *Scalar) setRandom() { + _, err := sc.Scalar.SetRandom() + if err != nil { + panic("BLS12377 cannot read from rand to create a new scalar") + } +} diff --git a/curves/bls/bls12377/scalar_test.go b/curves/bls/bls12377/scalar_test.go new file mode 100644 index 0000000..e79a514 --- /dev/null +++ b/curves/bls/bls12377/scalar_test.go @@ -0,0 +1,406 @@ +package bls12377 + +import ( + "testing" + + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/require" +) + +func TestBLSScalar_EqualInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + eq, err := scalar1.Equal(scalar2) + + require.False(t, eq) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_EqualTrue(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_EqualFalse(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.False(t, eq) +} + +func TestBLSScalar_SetNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + err := scalar.Set(nil) + + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_SetInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + err := scalar1.Set(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_SetOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + err := scalar1.Set(scalar2) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_Clone(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := scalar1.Clone() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_SetInt64(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := NewScalar() + scalar1.SetInt64(int64(555555555)) + scalar2.SetInt64(int64(444444444)) + + diff, _ := scalar1.Sub(scalar2) + scalar3 := NewScalar() + scalar3.SetInt64(int64(111111111)) + + eq, err := diff.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_Zero(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := NewScalar() + scalar2.SetInt64(0) + + eq, err := scalar2.Equal(scalar1) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + sum, err := scalar.Add(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, sum) +} + +func TestBLSScalar_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + sum, err := scalar1.Add(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, sum) +} + +func TestBLSScalar_AddOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + sum, err := scalar1.Add(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(2) + eq, err := scalar3.Equal(sum) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBLSScalar_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + diff, err := scalar.Sub(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, diff) +} + +func TestBLSScalar_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + diff, err := scalar1.Sub(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, diff) +} + +func TestBLSScalar_SubOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := NewScalar().One() + diff, err := scalar1.Sub(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(3) + eq, err := scalar3.Equal(diff) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBLSScalar_Neg(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := scalar1.Neg() + scalar3 := NewScalar() + scalar3.SetInt64(-4) + eq, err := scalar2.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_One(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(1) + scalar2 := NewScalar().One() + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Mul(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Mul(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_MulOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar1.Mul(scalar2) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBLSScalar_DivNilParamShouldEr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Div(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_DivInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Div(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_DivOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar2.Div(scalar1) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBLSScalar_InvNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Inv(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_InvInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := &mock.ScalarMock{} + scalar3, err := scalar1.Inv(scalar2) + + require.Nil(t, scalar3) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_InvOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2, err := scalar1.Inv(scalar1) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.NotNil(t, scalar2) + require.False(t, eq) + + one := NewScalar().One() + scalar3, err := scalar1.Inv(one) + require.Nil(t, err) + eq, _ = one.Equal(scalar3) + + require.True(t, eq) +} + +func TestBLSScalar_PickOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Pick() + require.Nil(t, err) + require.NotNil(t, scalar1, scalar2) + + eq, _ := scalar1.Equal(scalar2) + + require.False(t, eq) +} + +func TestBLSScalar_SetBytesNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.SetBytes(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_SetBytesOK(t *testing.T) { + t.Parallel() + + val := int64(555555555) + scalar1 := NewScalar().One() + + sc2 := NewScalar() + sc2.SetInt64(val) + buf, _ := sc2.MarshalBinary() + + scalar2, err := scalar1.SetBytes(buf) + require.Nil(t, err) + require.NotEqual(t, scalar1, scalar2) + + scalar3 := NewScalar() + scalar3.SetInt64(val) + + eq, _ := scalar3.Equal(scalar2) + require.True(t, eq) +} + +func TestBLSScalar_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + x := scalar1.GetUnderlyingObj() + + require.NotNil(t, x) +} + +func TestBLSScalar_MarshalBinary(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + + scalarBytes, err := scalar1.MarshalBinary() + + require.Nil(t, err) + require.NotNil(t, scalarBytes) +} + +func TestBLSScalar_UnmarshalBinary(t *testing.T) { + scalar1, _ := NewScalar().Pick() + scalarBytes, err := scalar1.MarshalBinary() + require.Nil(t, err) + scalar2 := NewScalar().Zero() + err = scalar2.UnmarshalBinary(scalarBytes) + require.Nil(t, err) + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} diff --git a/curves/bls/bls12377/suiteBLS12_377.go b/curves/bls/bls12377/suiteBLS12_377.go new file mode 100644 index 0000000..575f001 --- /dev/null +++ b/curves/bls/bls12377/suiteBLS12_377.go @@ -0,0 +1,129 @@ +package bls12377 + +import ( + "crypto/cipher" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("curves/bls12377") + +// SuiteBLS12377 provides an implementation of the Suite interface for BLS12-377 +type SuiteBLS12377 struct { + G1 *groupG1 + G2 *groupG2 + GT *groupGT + strSuite string +} + +// NewSuiteBLS12 returns a wrapper over a BLS12 curve. +func NewSuiteBLS12() *SuiteBLS12377 { + return &SuiteBLS12377{ + G1: &groupG1{}, + G2: &groupG2{}, + GT: &groupGT{}, + strSuite: "BLS12-377 suite", + } +} + +// RandomStream returns a cipher.Stream that returns a key stream +// from crypto/rand. +func (s *SuiteBLS12377) RandomStream() cipher.Stream { + return nil +} + +// CreatePoint creates a new point +func (s *SuiteBLS12377) CreatePoint() crypto.Point { + return s.G2.CreatePoint() +} + +// String returns the string for the group +func (s *SuiteBLS12377) String() string { + return s.strSuite +} + +// ScalarLen returns the maximum length of scalars in bytes +func (s *SuiteBLS12377) ScalarLen() int { + return s.G2.ScalarLen() +} + +// CreateScalar creates a new Scalar +func (s *SuiteBLS12377) CreateScalar() crypto.Scalar { + return s.G2.CreateScalar() +} + +// CreatePointForScalar creates a new point corresponding to the given scalar +func (s *SuiteBLS12377) CreatePointForScalar(scalar crypto.Scalar) (crypto.Point, error) { + if check.IfNil(scalar) { + return nil, crypto.ErrNilPrivateKeyScalar + } + sc, ok := scalar.GetUnderlyingObj().(*fr.Element) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + if sc.IsZero() { + return nil, crypto.ErrInvalidPrivateKey + } + + point := s.G2.CreatePointForScalar(scalar) + + return point, nil +} + +// PointLen returns the max length of point in nb of bytes +func (s *SuiteBLS12377) PointLen() int { + return s.G2.PointLen() +} + +// CreateKeyPair returns a pair of private public BLS keys. +// The private key is a scalarInt, while the public key is a Point on G2 curve +func (s *SuiteBLS12377) CreateKeyPair() (crypto.Scalar, crypto.Point) { + var sc crypto.Scalar + var err error + + sc = s.G2.CreateScalar() + sc, err = sc.Pick() + if err != nil { + log.Error("SuiteBLS12 CreateKeyPair", "error", err.Error()) + return nil, nil + } + + p := s.G2.CreatePointForScalar(sc) + + return sc, p +} + +// GetUnderlyingSuite returns the underlying suite +func (s *SuiteBLS12377) GetUnderlyingSuite() interface{} { + return s +} + +// CheckPointValid returns error if the point is not valid (zero is also not valid), otherwise nil +func (s *SuiteBLS12377) CheckPointValid(pointBytes []byte) error { + if len(pointBytes) != s.PointLen() { + return crypto.ErrInvalidParam + } + + point := s.G2.CreatePoint() + err := point.UnmarshalBinary(pointBytes) + if err != nil { + return err + } + + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + if !ok || !pG2.IsOnCurve() || !pG2.IsInSubGroup() { + return crypto.ErrInvalidPoint + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *SuiteBLS12377) IsInterfaceNil() bool { + return s == nil +} diff --git a/curves/bls/bls12377/suiteBLS12_377_test.go b/curves/bls/bls12377/suiteBLS12_377_test.go new file mode 100644 index 0000000..7ae1f2c --- /dev/null +++ b/curves/bls/bls12377/suiteBLS12_377_test.go @@ -0,0 +1,181 @@ +package bls12377 + +import ( + "encoding/hex" + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSuiteBLS12(t *testing.T) { + suite := NewSuiteBLS12() + + assert.NotNil(t, suite) +} + +func TestSuiteBLS12_RandomStream(t *testing.T) { + suite := NewSuiteBLS12() + stream := suite.RandomStream() + require.Nil(t, stream) +} + +func TestSuiteBLS12_CreatePoint(t *testing.T) { + suite := NewSuiteBLS12() + + point1 := suite.CreatePoint() + point2 := suite.CreatePoint() + + assert.NotNil(t, point1) + assert.NotNil(t, point2) + assert.False(t, point1 == point2) +} + +func TestSuiteBLS12_String(t *testing.T) { + suite := NewSuiteBLS12() + + str := suite.String() + assert.Equal(t, "BLS12-377 suite", str) +} + +func TestSuiteBLS12_ScalarLen(t *testing.T) { + suite := NewSuiteBLS12() + + length := suite.ScalarLen() + assert.Equal(t, 32, length) +} + +func TestSuiteBLS12_CreateScalar(t *testing.T) { + suite := NewSuiteBLS12() + + scalar := suite.CreateScalar() + assert.NotNil(t, scalar) +} + +func TestSuiteBLS12_CreatePointForScalar(t *testing.T) { + suite := NewSuiteBLS12() + scalar := NewScalar() + + point, err := suite.CreatePointForScalar(scalar) + require.Nil(t, err) + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + blsScalar, _ := scalar.GetUnderlyingObj().(*fr.Element) + blsScalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, pG2.Equal(computedG2)) +} +func TestSuiteBLS12_PointLen(t *testing.T) { + suite := NewSuiteBLS12() + + pointLength := suite.PointLen() + + // G2 point length is 128 bytes + assert.Equal(t, 96, pointLength) +} + +func TestSuiteBLS12_CreateKey(t *testing.T) { + suite := NewSuiteBLS12() + private, public := suite.CreateKeyPair() + assert.NotNil(t, private) + assert.NotNil(t, public) +} + +func TestSuiteBLS12_GetUnderlyingSuite(t *testing.T) { + suite := NewSuiteBLS12() + + obj := suite.GetUnderlyingSuite() + + assert.NotNil(t, obj) +} + +func TestSuiteBLS12_CheckPointValidOK(t *testing.T) { + t.Skip() + // valid point: "a0ea6040e700403170dc5a51b1b140d5532777ee6651cecbe7223ece0799c9de5cf89984bff76fe6b26bfefa6ea16a" + + // "fe018480be71c785fec89630a2a3841d01c565f071203e50317ea501f557db6b9b71889f52bb53540274e3e48f7c005196" + + validPointHexStr := "33e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7" + + "e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8" + suite := NewSuiteBLS12() + + validPointBytes, err := hex.DecodeString(validPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(validPointBytes) + require.Nil(t, err) +} + +func TestSuiteBLS12_CheckPointValidShortHexStringShouldErr(t *testing.T) { + shortPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBLS12() + + shortPointBytes, err := hex.DecodeString(shortPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(shortPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBLS12_CheckPointValidLongHexStrShouldErr(t *testing.T) { + longPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBLS12() + + longPointBytes, err := hex.DecodeString(longPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(longPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBLS12_CheckPointValidInvalidPointHexStrShouldErr(t *testing.T) { + invalidPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5caaaaaaaaaaa" + oneHexCharCorruptedPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74dca0a" + suite := NewSuiteBLS12() + + invalidPointBytes, err := hex.DecodeString(invalidPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(invalidPointBytes) + require.NotNil(t, err) + + oneHexCharCorruptedPointBytes, err := hex.DecodeString(oneHexCharCorruptedPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(oneHexCharCorruptedPointBytes) + require.NotNil(t, err) +} + +func TestSuiteBLS12_CheckPointValidZeroHexStrShouldErr(t *testing.T) { + t.Skip() + + zeroPointHexStr := "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + suite := NewSuiteBLS12() + + zeroPointBytes, err := hex.DecodeString(zeroPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(zeroPointBytes) + require.Equal(t, crypto.ErrInvalidPoint, err) +} + +func TestSuiteBLS12_IsInterfaceNil(t *testing.T) { + t.Parallel() + var suite *SuiteBLS12377 + + require.True(t, check.IfNil(suite)) + suite = NewSuiteBLS12() + require.False(t, check.IfNil(suite)) +} diff --git a/curves/bls/bls12381/g1.go b/curves/bls/bls12381/g1.go new file mode 100644 index 0000000..9e3b18c --- /dev/null +++ b/curves/bls/bls12381/g1.go @@ -0,0 +1,52 @@ +package bls12381 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG1 struct { +} + +// String returns the string for the group +func (g1 *groupG1) String() string { + return "BLS12-381 G1" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g1 *groupG1) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G1 +func (g1 *groupG1) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g1 *groupG1) PointLen() int { + return fp.Bytes +} + +// CreatePoint creates a new point +func (g1 *groupG1) CreatePoint() crypto.Point { + return NewPointG1() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g1 *groupG1) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG1() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG1 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g1 *groupG1) IsInterfaceNil() bool { + return g1 == nil +} diff --git a/curves/bls/bls12381/g1_test.go b/curves/bls/bls12381/g1_test.go new file mode 100644 index 0000000..3411409 --- /dev/null +++ b/curves/bls/bls12381/g1_test.go @@ -0,0 +1,161 @@ +package bls12381 + +import ( + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/require" +) + +func TestGroupG1_String(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + str := grG1.String() + require.Equal(t, "BLS12-381 G1", str) +} + +func TestGroupG1_ScalarLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.ScalarLen() + require.Equal(t, fr.Bytes, x) +} + +func TestGroupG1_PointLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.PointLen() + require.Equal(t, 48, x) +} + +func TestGroupG1_CreatePoint(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + g1Gen, _, _, _ := gnark.Generators() + point := &PointG1{ + G1: &g1Gen, + } + + x := grG1.CreatePoint() + require.NotNil(t, x) + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381Point.IsOnCurve()) + require.Equal(t, point.G1, bls12381Point) +} + +func TestGroupG1_CreateScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + sc := grG1.CreateScalar() + require.NotNil(t, sc) + + bls12381Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) +} + +func TestGroupG1_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.IsOnCurve()) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bls12381PointG1.Equal(computedG1)) +} + +func TestGroupG1_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(0) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsZero()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.Z.IsZero()) + require.True(t, bls12381PointG1.IsOnCurve()) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bls12381PointG1.Equal(computedG1)) + +} + +func TestGroupG1_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(1) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + baseG1 := NewPointG1().G1 + bls12381PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bls12381PointG1.Equal(baseG1)) +} + +func TestGroupG1_CreatePointForScalarNil(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + pG1 := grG1.CreatePointForScalar(nil) + require.Equal(t, nil, pG1) +} + +func TestGroupG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG1 *groupG1 + + require.True(t, grG1.IsInterfaceNil()) + grG1 = &groupG1{} + require.False(t, grG1.IsInterfaceNil()) +} diff --git a/curves/bls/bls12381/g2.go b/curves/bls/bls12381/g2.go new file mode 100644 index 0000000..60a75d7 --- /dev/null +++ b/curves/bls/bls12381/g2.go @@ -0,0 +1,52 @@ +package bls12381 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG2 struct { +} + +// String returns the string for the group +func (g2 *groupG2) String() string { + return "BLS12-381 G2" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g2 *groupG2) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G2 +func (g2 *groupG2) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g2 *groupG2) PointLen() int { + return fp.Bytes * 2 +} + +// CreatePoint creates a new point +func (g2 *groupG2) CreatePoint() crypto.Point { + return NewPointG2() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g2 *groupG2) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG2() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG2 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g2 *groupG2) IsInterfaceNil() bool { + return g2 == nil +} diff --git a/curves/bls/bls12381/g2_test.go b/curves/bls/bls12381/g2_test.go new file mode 100644 index 0000000..fd636eb --- /dev/null +++ b/curves/bls/bls12381/g2_test.go @@ -0,0 +1,151 @@ +package bls12381 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGroupG2_String(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + str := grG2.String() + require.Equal(t, str, "BLS12-381 G2") +} + +func TestGroupG2_ScalarLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupG2_PointLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.PointLen() + require.Equal(t, 96, x) +} + +func TestGroupG2_CreatePoint(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + + x := grG2.CreatePoint() + require.NotNil(t, x) + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.Equal(t, point.G2, bls12381Point) +} + +func TestGroupG2_CreateScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + sc := grG2.CreateScalar() + require.NotNil(t, sc) + + bls12381Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) +} + +func TestGroupG2_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bls12381Scalar.IsZero()) + require.False(t, bls12381Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bls12381PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(0) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsZero()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bls12381PointG2.Z.IsZero()) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bls12381Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bls12381PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(1) + bls12381Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bls12381Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + bls12381PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bls12381PointG2.Equal(bG2)) +} + +func TestGroupG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG2 *groupG2 + + require.True(t, check.IfNil(grG2)) + grG2 = &groupG2{} + require.False(t, check.IfNil(grG2)) +} diff --git a/curves/bls/bls12381/gt.go b/curves/bls/bls12381/gt.go new file mode 100644 index 0000000..54919d8 --- /dev/null +++ b/curves/bls/bls12381/gt.go @@ -0,0 +1,45 @@ +package bls12381 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupGT struct { +} + +// String returns the string for the group +func (gt *groupGT) String() string { + return "BLS12-381 GT" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (gt *groupGT) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar +func (gt *groupGT) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (gt *groupGT) PointLen() int { + return fp.Bytes * 12 +} + +// CreatePoint creates a new point +func (gt *groupGT) CreatePoint() crypto.Point { + return NewPointGT() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (gt *groupGT) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + panic("not supported") +} + +// IsInterfaceNil returns true if there is no value under the interface +func (gt *groupGT) IsInterfaceNil() bool { + return gt == nil +} diff --git a/curves/bls/bls12381/gt_test.go b/curves/bls/bls12381/gt_test.go new file mode 100644 index 0000000..8d7562f --- /dev/null +++ b/curves/bls/bls12381/gt_test.go @@ -0,0 +1,96 @@ +package bls12381 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGroupGT_String(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + str := grGT.String() + require.Equal(t, str, "BLS12-381 GT") +} + +func TestGroupGT_ScalarLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupGT_PointLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.PointLen() + require.Equal(t, 48*12, x) +} + +func TestGroupGT_CreatePoint(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.CreatePoint() + require.NotNil(t, x) + + bls12381Point, ok := x.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + // points created on GT are initialized with PointZero + require.True(t, bls12381Point.IsZero()) +} + +func TestGroupGT_CreateScalar(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + sc := grGT.CreateScalar() + require.NotNil(t, sc) + + bn254Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) +} + +func TestGroupGT_CreatePointForScalar(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should panic as currently not supported") + } + }() + + grGT := &groupGT{} + + scalar := grGT.CreateScalar() + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) + + _ = grGT.CreatePointForScalar(scalar) +} + +func TestGroupGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grGT *groupGT + + require.True(t, grGT.IsInterfaceNil()) + grGT = &groupGT{} + require.False(t, grGT.IsInterfaceNil()) +} diff --git a/curves/bls/bls12381/interop/flags.go b/curves/bls/bls12381/interop/flags.go new file mode 100644 index 0000000..f0b601c --- /dev/null +++ b/curves/bls/bls12381/interop/flags.go @@ -0,0 +1,48 @@ +package interop + +import ( + "errors" + "slices" +) + +const ( + g2CompressedSize = 96 + compressedMask = 0x80 + yOddMask = 0x20 +) + +// PointBytesFromMcl adds the compression flag to a point from MCL since in non-eth mode it is not used, +// but we always know we are using the compressed version. Also, MCL sets the Y odd flag as the MSB +// of the last byte which we need to use for compression, so we add that as the 3rd MSB in that +// byte as Ganrk expects it +func PointBytesFromMcl(rawPoint []byte) ([]byte, error) { + if len(rawPoint) != g2CompressedSize { + return nil, errors.New("interop: raw MCL point must be 96 bytes") + } + + isYodd := (rawPoint[g2CompressedSize-1] >> 7) != 0 + // fmt.Println("MCL Point is odd", isYodd) + be := reverseBytes(rawPoint) + + markCompressed(be) + if isYodd { + markOdd(be) + } + + return be, nil +} + +func reverseBytes(in []byte) []byte { + out := append([]byte(nil), in...) + slices.Reverse(out) + + return out +} + +func markCompressed(buf []byte) { + buf[0] |= compressedMask +} + +func markOdd(buf []byte) { + buf[0] |= yOddMask +} diff --git a/curves/bls/bls12381/interop/flags_test.go b/curves/bls/bls12381/interop/flags_test.go new file mode 100644 index 0000000..0263a74 --- /dev/null +++ b/curves/bls/bls12381/interop/flags_test.go @@ -0,0 +1,46 @@ +package interop_test + +import ( + "testing" + + "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12381/interop" + "github.com/stretchr/testify/require" +) + +const g2CompressedSize = 96 + +// reverse returns a new slice with elements of in reversed. +func reverse(in []byte) []byte { + out := make([]byte, len(in)) + for i := range in { + out[i] = in[len(in)-1-i] + } + return out +} + +func TestPointBytesFromMcl_Success(t *testing.T) { + raw := make([]byte, g2CompressedSize) + for i := range raw { + raw[i] = byte(i) + } + + out, err := interop.PointBytesFromMcl(raw) + require.Nil(t, err) + + expected := reverse(raw) + expected[0] |= 0x80 + + require.Equal(t, expected, out) +} + +func TestPointBytesFromMcl_BadLength(t *testing.T) { + tooShort := make([]byte, 10) + if _, err := interop.PointBytesFromMcl(tooShort); err == nil { + t.Error("PointBytesFromMcl should error on short input") + } + + tooLong := make([]byte, g2CompressedSize+1) + if _, err := interop.PointBytesFromMcl(tooLong); err == nil { + t.Error("PointBytesFromMcl should error on long input") + } +} diff --git a/curves/bls/bls12381/pointG1.go b/curves/bls/bls12381/pointG1.go new file mode 100644 index 0000000..2473bba --- /dev/null +++ b/curves/bls/bls12381/pointG1.go @@ -0,0 +1,206 @@ +package bls12381 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG1 - +type PointG1 struct { + G1 *gnark.G1Jac +} + +func NewPointG1() *PointG1 { + point := &PointG1{ + G1: &gnark.G1Jac{}, + } + + g1Gen, _, _, _ := gnark.Generators() + point.G1 = &g1Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG1) Equal(p crypto.Point) (bool, error) { + if p == nil { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG1) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G1.Equal(po2.G1), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG1) Clone() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Set(po.G1) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG1) Null() crypto.Point { + p := &PointG1{ + G1: &gnark.G1Jac{}, + } + + p.G1.Z.SetZero() + p.G1.X.SetOne() + p.G1.Y.SetOne() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG1) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return crypto.ErrInvalidParam + } + + po.G1.Set(po1.G1) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG1) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.AddAssign(po1.G1) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG1) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.SubAssign(po1.G1) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG1) Neg() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Neg(po.G1) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG1) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG1) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG1) GetUnderlyingObj() interface{} { + return po.G1 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG1) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G1Affine{} + affinePoint.FromJacobian(po.G1) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG1) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G1Affine{} + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G1 = po.G1.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG1) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12381/pointG1_test.go b/curves/bls/bls12381/pointG1_test.go new file mode 100644 index 0000000..6770a42 --- /dev/null +++ b/curves/bls/bls12381/pointG1_test.go @@ -0,0 +1,310 @@ +package bls12381 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewPointG1(t *testing.T) { + g1Gen, _, _, _ := gnark.Generators() + bG1 := &g1Gen + + pG1 := NewPointG1() + require.NotNil(t, pG1) + + mclPointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bG1.Equal(mclPointG1)) +} + +func TestPointG1_Equal(t *testing.T) { + p1G1 := NewPointG1() + p2G1 := NewPointG1() + + eq, err := p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G1.Mul(scalar) + require.Nil(t, err) + p1G1 = p1Modified.(*PointG1) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.False(t, eq) + + grG1 := &groupG1{} + sc1G1 := grG1.CreateScalar() + p1 := grG1.CreatePointForScalar(sc1G1) + p2 := grG1.CreatePointForScalar(sc1G1) + + var ok bool + p1G1, ok = p1.(*PointG1) + require.True(t, ok) + + p2G1, ok = p2.(*PointG1) + require.True(t, ok) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG1 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG1_Clone(t *testing.T) { + p1 := NewPointG1() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_Null(t *testing.T) { + p1 := NewPointG1() + + point := p1.Null() + bls12381Point, ok := point.(*PointG1) + require.True(t, ok) + require.True(t, bls12381Point.G1.X.IsOne()) + require.True(t, bls12381Point.G1.Y.IsOne()) + require.True(t, bls12381Point.G1.Z.IsZero()) + + bls12381PointNeg := &gnark.G1Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G1) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G1.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG1_Set(t *testing.T) { + p1 := NewPointG1() + p2 := NewPointG1() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG1) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_AddOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG1_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2, err := pointG1.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2 := &mock.PointMock{} + point3, err := pointG1.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_SubOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_Neg(t *testing.T) { + point1 := NewPointG1() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG1_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG1.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG1, res) + + grG1 := &groupG1{} + point2 := grG1.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG1_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG1_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG1_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG1().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG1() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG1 + + require.True(t, check.IfNil(point)) + point = NewPointG1() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bls/bls12381/pointG2.go b/curves/bls/bls12381/pointG2.go new file mode 100644 index 0000000..2f21afd --- /dev/null +++ b/curves/bls/bls12381/pointG2.go @@ -0,0 +1,208 @@ +package bls12381 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG2 - +type PointG2 struct { + G2 *gnark.G2Jac +} + +// NewPointG2 creates a new point on G2 initialized with base point +func NewPointG2() *PointG2 { + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG2) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG2) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G2.Equal(po2.G2), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG2) Clone() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Set(po.G2) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG2) Null() crypto.Point { + p := &PointG2{ + G2: &gnark.G2Jac{}, + } + + p.G2.Z.SetOne() + p.G2.X.SetZero() + p.G2.Y.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG2) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return crypto.ErrInvalidParam + } + + po.G2.Set(po1.G2) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG2) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.AddAssign(po1.G2) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG2) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.SubAssign(po1.G2) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG2) Neg() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Neg(po.G2) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG2) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG2) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG2) GetUnderlyingObj() interface{} { + return po.G2 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG2) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G2Affine{} + affinePoint.FromJacobian(po.G2) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG2) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G2Affine{} + affinePoint.Y.LexicographicallyLargest() + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G2 = po.G2.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG2) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12381/pointG2_test.go b/curves/bls/bls12381/pointG2_test.go new file mode 100644 index 0000000..de3cdd1 --- /dev/null +++ b/curves/bls/bls12381/pointG2_test.go @@ -0,0 +1,312 @@ +package bls12381 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "testing" +) + +func TestNewPointG2(t *testing.T) { + _, g2Gen, _, _ := gnark.Generators() + bG2 := &g2Gen + + pG2 := NewPointG2() + require.NotNil(t, pG2) + + mclPointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bG2.Equal(mclPointG2)) +} + +func TestPointG2_Equal(t *testing.T) { + p1G2 := NewPointG2() + p2G2 := NewPointG2() + + // new points should be initialized with base point so should be equal + eq, err := p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G2.Mul(scalar) + require.Nil(t, err) + p1G2 = p1Modified.(*PointG2) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.False(t, eq) + + grG2 := &groupG2{} + sc1G2 := grG2.CreateScalar() + p1 := grG2.CreatePointForScalar(sc1G2) + p2 := grG2.CreatePointForScalar(sc1G2) + + var ok bool + p1G2, ok = p1.(*PointG2) + require.True(t, ok) + + p2G2, ok = p2.(*PointG2) + require.True(t, ok) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG2 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG2_Clone(t *testing.T) { + p1 := NewPointG2() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_Null(t *testing.T) { + p1 := NewPointG2() + + point := p1.Null() + bls12381Point, ok := point.(*PointG2) + require.True(t, ok) + require.True(t, bls12381Point.G2.X.IsZero()) + require.True(t, bls12381Point.G2.Y.IsZero()) + require.True(t, bls12381Point.G2.Z.IsOne()) + + bls12381PointNeg := &gnark.G2Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G2) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G2.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG2_Set(t *testing.T) { + p1 := NewPointG2() + p2 := NewPointG2() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG2) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_AddOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG2_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2, err := pointG2.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2 := &mock.PointMock{} + point3, err := pointG2.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_SubOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_Neg(t *testing.T) { + point1 := NewPointG2() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG2_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG2, res) + + grG2 := &groupG2{} + point2 := grG2.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG2_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG2_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG2_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG2().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG2() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG2 + + require.True(t, check.IfNil(point)) + point = NewPointG2() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bls/bls12381/pointGT.go b/curves/bls/bls12381/pointGT.go new file mode 100644 index 0000000..b279bac --- /dev/null +++ b/curves/bls/bls12381/pointGT.go @@ -0,0 +1,221 @@ +package bls12381 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointGT - +type PointGT struct { + *gnark.GT +} + +// NewPointGT creates a new point on GT initialized with identity +func NewPointGT() *PointGT { + point := &PointGT{ + GT: &gnark.GT{}, + } + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointGT) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointGT) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.GT.Equal(po2.GT), nil +} + +// Clone returns a clone of the receiver. +func (po *PointGT) Clone() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + po2.GT = po2.GT.Set(po.GT) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointGT) Null() crypto.Point { + p := NewPointGT() + p.GT.C0.B0.SetZero() + p.GT.C0.B1.SetZero() + p.GT.C0.B2.SetZero() + + p.GT.C1.B0.SetZero() + p.GT.C1.B1.SetZero() + p.GT.C1.B2.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointGT) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return crypto.ErrInvalidParam + } + + po.GT.Set(po1.GT) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointGT) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Add(po2.GT, po1.GT) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointGT) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Sub(po2.GT, po1.GT) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointGT) Neg() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + // Multiplicative in GT, we can use inverse + po2.GT = po2.GT.Inverse(po.GT) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointGT) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.GT = po2.GT.Exp(*po.GT, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointGT) Pick() (crypto.Point, error) { + var p1, p2 crypto.Point + var err error + + p1, err = NewPointG1().Pick() + if err != nil { + return nil, err + } + + p2, err = NewPointG2().Pick() + if err != nil { + return nil, err + } + + poG1 := p1.(*PointG1) + poG2 := p2.(*PointG2) + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + g1Affine := &gnark.G1Affine{} + g1Affine.FromJacobian(poG1.G1) + + g2Affine := &gnark.G2Affine{} + g2Affine.FromJacobian(poG2.G2) + + paired, err := gnark.Pair([]gnark.G1Affine{*g1Affine}, []gnark.G2Affine{*g2Affine}) + if err != nil { + return nil, err + } + + po2.GT = &paired + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointGT) GetUnderlyingObj() interface{} { + return po.GT +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointGT) MarshalBinary() ([]byte, error) { + return po.GT.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointGT) UnmarshalBinary(point []byte) error { + return po.GT.Unmarshal(point) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointGT) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bls/bls12381/pointGT_test.go b/curves/bls/bls12381/pointGT_test.go new file mode 100644 index 0000000..f30f3ea --- /dev/null +++ b/curves/bls/bls12381/pointGT_test.go @@ -0,0 +1,276 @@ +package bls12381 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "testing" +) + +func TestNewPointGT(t *testing.T) { + pGT := NewPointGT() + require.NotNil(t, pGT) + + bls12381PointGT, ok := pGT.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + require.True(t, bls12381PointGT.IsZero()) +} + +func TestPointGT_Equal(t *testing.T) { + p1GT := NewPointGT() + p2GT := NewPointGT() + + eq, err := p1GT.Equal(p2GT) + require.Nil(t, err) + require.True(t, eq) + + p2GT.SetOne() + eq, err = p1GT.Equal(p2GT) + require.Nil(t, err) + require.False(t, eq) +} + +func TestPointGT_CloneNilShouldPanic(t *testing.T) { + var p1 *PointGT + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointGT_Clone(t *testing.T) { + p1 := NewPointGT() + _, err := p1.GT.SetRandom() + require.Nil(t, err) + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_Null(t *testing.T) { + p1 := NewPointGT() + + point := p1.Null() + bls12381Point, ok := point.(*PointGT) + require.True(t, ok) + require.True(t, bls12381Point.IsZero()) + bls12381PointNeg := &gnark.GT{} + // TODO + + // neutral identity point should be equal to it's negation + require.True(t, bls12381Point.GT.Equal(bls12381PointNeg)) +} + +func TestPointGT_Set(t *testing.T) { + p1 := NewPointGT() + p2 := NewPointGT() + + p2.GT.SetOne() + + err := p1.Set(p2) + require.Nil(t, err) + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_AddOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointGT_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, err := pointGT.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2 := &mock.PointMock{} + point3, err := pointGT.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_SubOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_Neg(t *testing.T) { + point1 := NewPointGT() + point, _ := point1.Pick() + + point2 := point.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point, point3) +} + +func TestPointGT_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, _ := pointGT.Pick() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := point2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, point2, res) +} + +func TestPointGT_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointGT_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointGT_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointGT_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointGT().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointGT() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointGT + + require.True(t, point.IsInterfaceNil()) + point = NewPointGT() + require.False(t, point.IsInterfaceNil()) +} diff --git a/curves/bls/bls12381/scalar.go b/curves/bls/bls12381/scalar.go new file mode 100644 index 0000000..9c85c07 --- /dev/null +++ b/curves/bls/bls12381/scalar.go @@ -0,0 +1,271 @@ +package bls12381 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// Scalar - +type Scalar struct { + Scalar *fr.Element +} + +func NewScalar() *Scalar { + scalar := &Scalar{Scalar: &fr.Element{}} + scalar.setRandom() + + for scalar.Scalar.IsOne() || scalar.Scalar.IsZero() { + scalar.setRandom() + } + + return scalar +} + +// Equal tests if receiver is equal with the scalarInt s given as parameter. +// Both scalars need to be derived from the same Group +func (sc *Scalar) Equal(s crypto.Scalar) (bool, error) { + if check.IfNil(s) { + return false, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return false, crypto.ErrInvalidParam + } + + areEqual := sc.Scalar.Equal(s2.Scalar) + + return areEqual, nil +} + +// Set sets the receiver to Scalar s given as parameter +func (sc *Scalar) Set(s crypto.Scalar) error { + if check.IfNil(s) { + return crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return crypto.ErrInvalidParam + } + + return sc.Scalar.SetBytesCanonical(s2.Scalar.Marshal()) +} + +// Clone creates a new Scalar with same value as receiver +func (sc *Scalar) Clone() crypto.Scalar { + scalar := &Scalar{ + Scalar: &fr.Element{}, + } + + scalar.Scalar.SetBytes(sc.Scalar.Marshal()) + + return scalar +} + +// SetInt64 sets the receiver to a small integer value v given as parameter +func (sc *Scalar) SetInt64(v int64) { + sc.Scalar.SetInt64(v) +} + +// Zero returns the additive identity (0) +func (sc *Scalar) Zero() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetZero() + + return &s +} + +// Add returns the modular sum of receiver with scalar s given as parameter +func (sc *Scalar) Add(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Add(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Sub returns the modular difference between receiver and scalar s given as parameter +func (sc *Scalar) Sub(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Sub(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Neg returns the modular negation of receiver +func (sc *Scalar) Neg() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + + s.Scalar.Neg(sc.Scalar) + + return &s +} + +// One returns the multiplicative identity (1) +func (sc *Scalar) One() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetOne() + + return &s +} + +// Mul returns the modular product of receiver with scalar s given as parameter +func (sc *Scalar) Mul(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Mul(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Div returns the modular division between receiver and scalar s given as parameter +func (sc *Scalar) Div(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Div(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Inv returns the modular inverse of scalar s given as parameter +func (sc *Scalar) Inv(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Inverse(s2.Scalar) + + return &s1, nil +} + +// Pick returns a fresh random or pseudo-random scalar +// For the mock set X to the original scalar.X *2 +func (sc *Scalar) Pick() (crypto.Scalar, error) { + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + _, err := s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + + for s1.Scalar.IsOne() || s1.Scalar.IsZero() { + _, err = s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + } + + return &s1, nil +} + +// SetBytes sets the scalar from a byte-slice, +// reducing if necessary to the appropriate modulus. +func (sc *Scalar) SetBytes(s []byte) (crypto.Scalar, error) { + if len(s) == 0 { + return nil, crypto.ErrNilParam + } + + s1 := sc.Clone() + s2, ok := s1.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + err := s2.Scalar.SetBytesCanonical(s) + if err != nil { + return nil, err + } + + return s1, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (sc *Scalar) GetUnderlyingObj() interface{} { + return sc.Scalar +} + +// MarshalBinary transforms the Scalar into a byte array +func (sc *Scalar) MarshalBinary() ([]byte, error) { + return sc.Scalar.Marshal(), nil +} + +// UnmarshalBinary recreates the Scalar from a byte array +func (sc *Scalar) UnmarshalBinary(val []byte) error { + return sc.Scalar.SetBytesCanonical(val) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sc *Scalar) IsInterfaceNil() bool { + return sc == nil +} + +func (sc *Scalar) setRandom() { + _, err := sc.Scalar.SetRandom() + if err != nil { + panic("BLS12381 cannot read from rand to create a new scalar") + } +} diff --git a/curves/bls/bls12381/scalar_test.go b/curves/bls/bls12381/scalar_test.go new file mode 100644 index 0000000..483f0ed --- /dev/null +++ b/curves/bls/bls12381/scalar_test.go @@ -0,0 +1,406 @@ +package bls12381 + +import ( + "testing" + + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/require" +) + +func TestBLSScalar_EqualInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + eq, err := scalar1.Equal(scalar2) + + require.False(t, eq) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_EqualTrue(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_EqualFalse(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.False(t, eq) +} + +func TestBLSScalar_SetNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + err := scalar.Set(nil) + + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_SetInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + err := scalar1.Set(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_SetOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + err := scalar1.Set(scalar2) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_Clone(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := scalar1.Clone() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_SetInt64(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := NewScalar() + scalar1.SetInt64(int64(555555555)) + scalar2.SetInt64(int64(444444444)) + + diff, _ := scalar1.Sub(scalar2) + scalar3 := NewScalar() + scalar3.SetInt64(int64(111111111)) + + eq, err := diff.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_Zero(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := NewScalar() + scalar2.SetInt64(0) + + eq, err := scalar2.Equal(scalar1) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + sum, err := scalar.Add(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, sum) +} + +func TestBLSScalar_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + sum, err := scalar1.Add(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, sum) +} + +func TestBLSScalar_AddOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + sum, err := scalar1.Add(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(2) + eq, err := scalar3.Equal(sum) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBLSScalar_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + diff, err := scalar.Sub(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, diff) +} + +func TestBLSScalar_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + diff, err := scalar1.Sub(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, diff) +} + +func TestBLSScalar_SubOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := NewScalar().One() + diff, err := scalar1.Sub(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(3) + eq, err := scalar3.Equal(diff) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBLSScalar_Neg(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := scalar1.Neg() + scalar3 := NewScalar() + scalar3.SetInt64(-4) + eq, err := scalar2.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_One(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(1) + scalar2 := NewScalar().One() + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBLSScalar_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Mul(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Mul(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_MulOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar1.Mul(scalar2) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBLSScalar_DivNilParamShouldEr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Div(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_DivInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Div(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBLSScalar_DivOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar2.Div(scalar1) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBLSScalar_InvNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Inv(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_InvInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := &mock.ScalarMock{} + scalar3, err := scalar1.Inv(scalar2) + + require.Nil(t, scalar3) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBLSScalar_InvOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2, err := scalar1.Inv(scalar1) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.NotNil(t, scalar2) + require.False(t, eq) + + one := NewScalar().One() + scalar3, err := scalar1.Inv(one) + require.Nil(t, err) + eq, _ = one.Equal(scalar3) + + require.True(t, eq) +} + +func TestBLSScalar_PickOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Pick() + require.Nil(t, err) + require.NotNil(t, scalar1, scalar2) + + eq, _ := scalar1.Equal(scalar2) + + require.False(t, eq) +} + +func TestBLSScalar_SetBytesNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.SetBytes(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBLSScalar_SetBytesOK(t *testing.T) { + t.Parallel() + + val := int64(555555555) + scalar1 := NewScalar().One() + + sc2 := NewScalar() + sc2.SetInt64(val) + buf, _ := sc2.MarshalBinary() + + scalar2, err := scalar1.SetBytes(buf) + require.Nil(t, err) + require.NotEqual(t, scalar1, scalar2) + + scalar3 := NewScalar() + scalar3.SetInt64(val) + + eq, _ := scalar3.Equal(scalar2) + require.True(t, eq) +} + +func TestBLSScalar_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + x := scalar1.GetUnderlyingObj() + + require.NotNil(t, x) +} + +func TestBLSScalar_MarshalBinary(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + + scalarBytes, err := scalar1.MarshalBinary() + + require.Nil(t, err) + require.NotNil(t, scalarBytes) +} + +func TestBLSScalar_UnmarshalBinary(t *testing.T) { + scalar1, _ := NewScalar().Pick() + scalarBytes, err := scalar1.MarshalBinary() + require.Nil(t, err) + scalar2 := NewScalar().Zero() + err = scalar2.UnmarshalBinary(scalarBytes) + require.Nil(t, err) + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} diff --git a/curves/bls/bls12381/suiteBLS12_381.go b/curves/bls/bls12381/suiteBLS12_381.go new file mode 100644 index 0000000..42c903f --- /dev/null +++ b/curves/bls/bls12381/suiteBLS12_381.go @@ -0,0 +1,129 @@ +package bls12381 + +import ( + "crypto/cipher" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("curves/bls12381") + +// SuiteBLS12 provides an implementation of the Suite interface for BLS12-381 +type SuiteBLS12 struct { + G1 *groupG1 + G2 *groupG2 + GT *groupGT + strSuite string +} + +// NewSuiteBLS12 returns a wrapper over a BLS12 curve. +func NewSuiteBLS12() *SuiteBLS12 { + return &SuiteBLS12{ + G1: &groupG1{}, + G2: &groupG2{}, + GT: &groupGT{}, + strSuite: "BLS12-381 suite", + } +} + +// RandomStream returns a cipher.Stream that returns a key stream +// from crypto/rand. +func (s *SuiteBLS12) RandomStream() cipher.Stream { + return nil +} + +// CreatePoint creates a new point +func (s *SuiteBLS12) CreatePoint() crypto.Point { + return s.G2.CreatePoint() +} + +// String returns the string for the group +func (s *SuiteBLS12) String() string { + return s.strSuite +} + +// ScalarLen returns the maximum length of scalars in bytes +func (s *SuiteBLS12) ScalarLen() int { + return s.G2.ScalarLen() +} + +// CreateScalar creates a new Scalar +func (s *SuiteBLS12) CreateScalar() crypto.Scalar { + return s.G2.CreateScalar() +} + +// CreatePointForScalar creates a new point corresponding to the given scalar +func (s *SuiteBLS12) CreatePointForScalar(scalar crypto.Scalar) (crypto.Point, error) { + if check.IfNil(scalar) { + return nil, crypto.ErrNilPrivateKeyScalar + } + sc, ok := scalar.GetUnderlyingObj().(*fr.Element) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + if sc.IsZero() { + return nil, crypto.ErrInvalidPrivateKey + } + + point := s.G2.CreatePointForScalar(scalar) + + return point, nil +} + +// PointLen returns the max length of point in nb of bytes +func (s *SuiteBLS12) PointLen() int { + return s.G2.PointLen() +} + +// CreateKeyPair returns a pair of private public BLS keys. +// The private key is a scalarInt, while the public key is a Point on G2 curve +func (s *SuiteBLS12) CreateKeyPair() (crypto.Scalar, crypto.Point) { + var sc crypto.Scalar + var err error + + sc = s.G2.CreateScalar() + sc, err = sc.Pick() + if err != nil { + log.Error("SuiteBLS12 CreateKeyPair", "error", err.Error()) + return nil, nil + } + + p := s.G2.CreatePointForScalar(sc) + + return sc, p +} + +// GetUnderlyingSuite returns the underlying suite +func (s *SuiteBLS12) GetUnderlyingSuite() interface{} { + return s +} + +// CheckPointValid returns error if the point is not valid (zero is also not valid), otherwise nil +func (s *SuiteBLS12) CheckPointValid(pointBytes []byte) error { + if len(pointBytes) != s.PointLen() { + return crypto.ErrInvalidParam + } + + point := s.G2.CreatePoint() + err := point.UnmarshalBinary(pointBytes) + if err != nil { + return err + } + + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + if !ok || !pG2.IsOnCurve() || !pG2.IsInSubGroup() { + return crypto.ErrInvalidPoint + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *SuiteBLS12) IsInterfaceNil() bool { + return s == nil +} diff --git a/curves/bls/bls12381/suiteBLS12_381_test.go b/curves/bls/bls12381/suiteBLS12_381_test.go new file mode 100644 index 0000000..482feb4 --- /dev/null +++ b/curves/bls/bls12381/suiteBLS12_381_test.go @@ -0,0 +1,199 @@ +package bls12381 + +import ( + "encoding/hex" + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSuiteBLS12(t *testing.T) { + suite := NewSuiteBLS12() + + assert.NotNil(t, suite) +} + +func TestSuiteBLS12_RandomStream(t *testing.T) { + suite := NewSuiteBLS12() + stream := suite.RandomStream() + require.Nil(t, stream) +} + +func TestSuiteBLS12_CreatePoint(t *testing.T) { + suite := NewSuiteBLS12() + + point1 := suite.CreatePoint() + point2 := suite.CreatePoint() + + assert.NotNil(t, point1) + assert.NotNil(t, point2) + assert.False(t, point1 == point2) +} + +func TestSuiteBLS12_String(t *testing.T) { + suite := NewSuiteBLS12() + + str := suite.String() + assert.Equal(t, "BLS12-381 suite", str) +} + +func TestSuiteBLS12_ScalarLen(t *testing.T) { + suite := NewSuiteBLS12() + + length := suite.ScalarLen() + assert.Equal(t, 32, length) +} + +func TestSuiteBLS12_CreateScalar(t *testing.T) { + suite := NewSuiteBLS12() + + scalar := suite.CreateScalar() + assert.NotNil(t, scalar) +} + +func TestSuiteBLS12_CreatePointForScalar(t *testing.T) { + suite := NewSuiteBLS12() + scalar := NewScalar() + + point, err := suite.CreatePointForScalar(scalar) + require.Nil(t, err) + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + blsScalar, _ := scalar.GetUnderlyingObj().(*fr.Element) + blsScalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, pG2.Equal(computedG2)) +} + +func TestSuiteBLS12_PointLen(t *testing.T) { + suite := NewSuiteBLS12() + + pointLength := suite.PointLen() + + // G2 point length is 128 bytes + assert.Equal(t, 96, pointLength) +} + +func TestSuiteBLS12_CreateKey(t *testing.T) { + suite := NewSuiteBLS12() + private, public := suite.CreateKeyPair() + assert.NotNil(t, private) + assert.NotNil(t, public) +} + +func TestSuiteBLS12_GetUnderlyingSuite(t *testing.T) { + suite := NewSuiteBLS12() + + obj := suite.GetUnderlyingSuite() + + assert.NotNil(t, obj) +} + +func TestSuiteBLS12_CheckPointValidOK(t *testing.T) { + t.Skip() + // valid point: "93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b" + + // "7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8" + + validPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74dca01" + + suite := NewSuiteBLS12() + + validPointBytes, err := hex.DecodeString(validPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(validPointBytes) + require.Nil(t, err) +} + +func TestSuiteBLS12_PointFromMclCurveShouldBeBackwardsCompatible(t *testing.T) { + t.Skip() + + mclSuite := mcl.NewSuiteBLS12() + _, pk := mclSuite.CreateKeyPair() + + pointBytes, err := pk.MarshalBinary() + require.Nil(t, err) + require.Equal(t, len(pointBytes), 96) + + bls12381Suite := NewSuiteBLS12() + err = bls12381Suite.CheckPointValid(pointBytes) + require.Nil(t, err) +} + +func TestSuiteBLS12_CheckPointValidShortHexStringShouldErr(t *testing.T) { + shortPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBLS12() + + shortPointBytes, err := hex.DecodeString(shortPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(shortPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBLS12_CheckPointValidLongHexStrShouldErr(t *testing.T) { + longPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBLS12() + + longPointBytes, err := hex.DecodeString(longPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(longPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBLS12_CheckPointValidInvalidPointHexStrShouldErr(t *testing.T) { + invalidPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5caaaaaaaaaaa" + oneHexCharCorruptedPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74dca0a" + suite := NewSuiteBLS12() + + invalidPointBytes, err := hex.DecodeString(invalidPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(invalidPointBytes) + require.NotNil(t, err) + + oneHexCharCorruptedPointBytes, err := hex.DecodeString(oneHexCharCorruptedPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(oneHexCharCorruptedPointBytes) + require.NotNil(t, err) +} + +func TestSuiteBLS12_CheckPointValidZeroHexStrShouldErr(t *testing.T) { + t.Skip() + + zeroPointHexStr := "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + suite := NewSuiteBLS12() + + zeroPointBytes, err := hex.DecodeString(zeroPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(zeroPointBytes) + require.Equal(t, crypto.ErrInvalidPoint, err) +} + +func TestSuiteBLS12_IsInterfaceNil(t *testing.T) { + t.Parallel() + var suite *SuiteBLS12 + + require.True(t, check.IfNil(suite)) + suite = NewSuiteBLS12() + require.False(t, check.IfNil(suite)) +} diff --git a/curves/bn254/g1.go b/curves/bn254/g1.go new file mode 100644 index 0000000..9ebfb30 --- /dev/null +++ b/curves/bn254/g1.go @@ -0,0 +1,52 @@ +package bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG1 struct { +} + +// String returns the string for the group +func (g1 *groupG1) String() string { + return "BN254 G1" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g1 *groupG1) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G1 +func (g1 *groupG1) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g1 *groupG1) PointLen() int { + return fp.Bytes +} + +// CreatePoint creates a new point +func (g1 *groupG1) CreatePoint() crypto.Point { + return NewPointG1() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g1 *groupG1) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG1() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG1 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g1 *groupG1) IsInterfaceNil() bool { + return g1 == nil +} diff --git a/curves/bn254/g1_test.go b/curves/bn254/g1_test.go new file mode 100644 index 0000000..effb0cf --- /dev/null +++ b/curves/bn254/g1_test.go @@ -0,0 +1,158 @@ +package bn254 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGroupG1_String(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + str := grG1.String() + require.Equal(t, str, "BN254 G1") +} + +func TestGroupG1_ScalarLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupG1_PointLen(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + x := grG1.PointLen() + require.Equal(t, 32, x) +} + +func TestGroupG1_CreatePoint(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + point := &PointG1{ + G1: &gnark.G1Jac{}, + } + + g1Gen, _, _, _ := gnark.Generators() + point.G1 = &g1Gen + + x := grG1.CreatePoint() + require.NotNil(t, x) + bn254Point, ok := x.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.Equal(t, point.G1, bn254Point) +} + +func TestGroupG1_CreateScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + sc := grG1.CreateScalar() + require.NotNil(t, sc) + + mclScalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, mclScalar.IsZero()) + require.False(t, mclScalar.IsOne()) +} + +func TestGroupG1_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bn254PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bn254PointG1.IsOnCurve()) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bn254Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bn254PointG1.Equal(computedG1)) +} + +func TestGroupG1_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(0) + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bn254Scalar.IsZero()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + bn254PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + + bG1 := NewPointG1().G1 + var scalarBigInt big.Int + bn254Scalar.BigInt(&scalarBigInt) + computedG1 := bG1.ScalarMultiplication(bG1, &scalarBigInt) + + require.True(t, bn254PointG1.Equal(computedG1)) +} + +func TestGroupG1_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + + scalar := grG1.CreateScalar() + scalar.SetInt64(1) + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bn254Scalar.IsOne()) + + pG1 := grG1.CreatePointForScalar(scalar) + require.NotNil(t, pG1) + + baseG1 := NewPointG1().G1 + bn254PointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bn254PointG1.Equal(baseG1)) +} + +func TestGroupG1_CreatePointForScalarNil(t *testing.T) { + t.Parallel() + + grG1 := &groupG1{} + pG1 := grG1.CreatePointForScalar(nil) + require.Equal(t, nil, pG1) +} + +func TestGroupG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG1 *groupG1 + + require.True(t, grG1.IsInterfaceNil()) + grG1 = &groupG1{} + require.False(t, grG1.IsInterfaceNil()) +} diff --git a/curves/bn254/g2.go b/curves/bn254/g2.go new file mode 100644 index 0000000..19c4a90 --- /dev/null +++ b/curves/bn254/g2.go @@ -0,0 +1,52 @@ +package bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupG2 struct { +} + +// String returns the string for the group +func (g2 *groupG2) String() string { + return "BN254 G2" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (g2 *groupG2) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar initialized with base point on G2 +func (g2 *groupG2) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (g2 *groupG2) PointLen() int { + return fp.Bytes * 2 +} + +// CreatePoint creates a new point +func (g2 *groupG2) CreatePoint() crypto.Point { + return NewPointG2() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (g2 *groupG2) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + var p crypto.Point + var err error + p = NewPointG2() + p, err = p.Mul(scalar) + if err != nil { + log.Error("groupG2 CreatePointForScalar", "error", err.Error()) + } + return p +} + +// IsInterfaceNil returns true if there is no value under the interface +func (g2 *groupG2) IsInterfaceNil() bool { + return g2 == nil +} diff --git a/curves/bn254/g2_test.go b/curves/bn254/g2_test.go new file mode 100644 index 0000000..6856c6f --- /dev/null +++ b/curves/bn254/g2_test.go @@ -0,0 +1,151 @@ +package bn254 + +import ( + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/require" +) + +func TestGroupG2_String(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + str := grG2.String() + require.Equal(t, str, "BN254 G2") +} + +func TestGroupG2_ScalarLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupG2_PointLen(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + x := grG2.PointLen() + require.Equal(t, 64, x) +} + +func TestGroupG2_CreatePoint(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + + x := grG2.CreatePoint() + require.NotNil(t, x) + bn254Point, ok := x.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.Equal(t, point.G2, bn254Point) +} + +func TestGroupG2_CreateScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + sc := grG2.CreateScalar() + require.NotNil(t, sc) + + bn254Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) +} + +func TestGroupG2_CreatePointForScalar(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bn254PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bn254Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bn254PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarZero(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(0) + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bn254Scalar.IsZero()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bn254PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bn254PointG2.Z.IsZero()) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bn254Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, bn254PointG2.Equal(computedG2)) +} + +func TestGroupG2_CreatePointForScalarOne(t *testing.T) { + t.Parallel() + + grG2 := &groupG2{} + + scalar := grG2.CreateScalar() + scalar.SetInt64(1) + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.True(t, bn254Scalar.IsOne()) + + pG2 := grG2.CreatePointForScalar(scalar) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + bn254PointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bn254PointG2.Equal(bG2)) +} + +func TestGroupG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grG2 *groupG2 + + require.True(t, check.IfNil(grG2)) + grG2 = &groupG2{} + require.False(t, check.IfNil(grG2)) +} diff --git a/curves/bn254/gt.go b/curves/bn254/gt.go new file mode 100644 index 0000000..c755665 --- /dev/null +++ b/curves/bn254/gt.go @@ -0,0 +1,45 @@ +package bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +type groupGT struct { +} + +// String returns the string for the group +func (gt *groupGT) String() string { + return "BN254 GT" +} + +// ScalarLen returns the maximum length of scalars in bytes +func (gt *groupGT) ScalarLen() int { + return fr.Bytes +} + +// CreateScalar creates a new Scalar +func (gt *groupGT) CreateScalar() crypto.Scalar { + return NewScalar() +} + +// PointLen returns the max length of point in nb of bytes +func (gt *groupGT) PointLen() int { + return fp.Bytes * 12 +} + +// CreatePoint creates a new point +func (gt *groupGT) CreatePoint() crypto.Point { + return NewPointGT() +} + +// CreatePointForScalar creates a new point corresponding to the given scalarInt +func (gt *groupGT) CreatePointForScalar(scalar crypto.Scalar) crypto.Point { + panic("not supported") +} + +// IsInterfaceNil returns true if there is no value under the interface +func (gt *groupGT) IsInterfaceNil() bool { + return gt == nil +} diff --git a/curves/bn254/gt_test.go b/curves/bn254/gt_test.go new file mode 100644 index 0000000..bf0ba1c --- /dev/null +++ b/curves/bn254/gt_test.go @@ -0,0 +1,96 @@ +package bn254 + +import ( + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGroupGT_String(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + str := grGT.String() + require.Equal(t, str, "BN254 GT") +} + +func TestGroupGT_ScalarLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.ScalarLen() + require.Equal(t, 32, x) +} + +func TestGroupGT_PointLen(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.PointLen() + require.Equal(t, 32*12, x) +} + +func TestGroupGT_CreatePoint(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + x := grGT.CreatePoint() + require.NotNil(t, x) + + bn254Point, ok := x.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + // points created on GT are initialized with PointZero + require.True(t, bn254Point.IsZero()) +} + +func TestGroupGT_CreateScalar(t *testing.T) { + t.Parallel() + + grGT := &groupGT{} + + sc := grGT.CreateScalar() + require.NotNil(t, sc) + + bn254Scalar, ok := sc.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) +} + +func TestGroupGT_CreatePointForScalar(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should panic as currently not supported") + } + }() + + grGT := &groupGT{} + + scalar := grGT.CreateScalar() + bn254Scalar, ok := scalar.GetUnderlyingObj().(*fr.Element) + require.True(t, ok) + require.False(t, bn254Scalar.IsZero()) + require.False(t, bn254Scalar.IsOne()) + + _ = grGT.CreatePointForScalar(scalar) +} + +func TestGroupGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var grGT *groupGT + + require.True(t, grGT.IsInterfaceNil()) + grGT = &groupGT{} + require.False(t, grGT.IsInterfaceNil()) +} diff --git a/curves/bn254/pointG1.go b/curves/bn254/pointG1.go new file mode 100644 index 0000000..9fc17ff --- /dev/null +++ b/curves/bn254/pointG1.go @@ -0,0 +1,206 @@ +package bn254 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG1 - +type PointG1 struct { + G1 *gnark.G1Jac +} + +func NewPointG1() *PointG1 { + point := &PointG1{ + G1: &gnark.G1Jac{}, + } + + g1Gen, _, _, _ := gnark.Generators() + point.G1 = &g1Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG1) Equal(p crypto.Point) (bool, error) { + if p == nil { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG1) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G1.Equal(po2.G1), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG1) Clone() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Set(po.G1) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG1) Null() crypto.Point { + p := &PointG1{ + G1: &gnark.G1Jac{}, + } + + p.G1.Z.SetZero() + p.G1.X.SetOne() + p.G1.Y.SetOne() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG1) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return crypto.ErrInvalidParam + } + + po.G1.Set(po1.G1) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG1) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.AddAssign(po1.G1) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG1) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG1) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G1 = po2.G1.SubAssign(po1.G1) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG1) Neg() crypto.Point { + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.Neg(po.G1) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG1) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG1) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG1{ + G1: &gnark.G1Jac{}, + } + + po2.G1 = po2.G1.ScalarMultiplication(po.G1, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG1) GetUnderlyingObj() interface{} { + return po.G1 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG1) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G1Affine{} + affinePoint.FromJacobian(po.G1) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG1) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G1Affine{} + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G1 = po.G1.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG1) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bn254/pointG1_test.go b/curves/bn254/pointG1_test.go new file mode 100644 index 0000000..0db1860 --- /dev/null +++ b/curves/bn254/pointG1_test.go @@ -0,0 +1,311 @@ +package bn254 + +import ( + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPointG1(t *testing.T) { + g1Gen, _, _, _ := gnark.Generators() + bG1 := &g1Gen + + pG1 := NewPointG1() + require.NotNil(t, pG1) + + mclPointG1, ok := pG1.GetUnderlyingObj().(*gnark.G1Jac) + require.True(t, ok) + require.True(t, bG1.Equal(mclPointG1)) +} + +func TestPointG1_Equal(t *testing.T) { + p1G1 := NewPointG1() + p2G1 := NewPointG1() + + eq, err := p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G1.Mul(scalar) + require.Nil(t, err) + p1G1 = p1Modified.(*PointG1) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.False(t, eq) + + grG1 := &groupG1{} + sc1G1 := grG1.CreateScalar() + p1 := grG1.CreatePointForScalar(sc1G1) + p2 := grG1.CreatePointForScalar(sc1G1) + + var ok bool + p1G1, ok = p1.(*PointG1) + require.True(t, ok) + + p2G1, ok = p2.(*PointG1) + require.True(t, ok) + + eq, err = p1G1.Equal(p2G1) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG1 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG1_Clone(t *testing.T) { + p1 := NewPointG1() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_Null(t *testing.T) { + p1 := NewPointG1() + + point := p1.Null() + bls12381Point, ok := point.(*PointG1) + require.True(t, ok) + require.True(t, bls12381Point.G1.X.IsOne()) + require.True(t, bls12381Point.G1.Y.IsOne()) + require.True(t, bls12381Point.G1.Z.IsZero()) + + bls12381PointNeg := &gnark.G1Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G1) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G1.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG1_Set(t *testing.T) { + p1 := NewPointG1() + p2 := NewPointG1() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG1) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_AddOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG1_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2, err := pointG1.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG1_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point2 := &mock.PointMock{} + point3, err := pointG1.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG1_SubOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + point1, err := pointG1.Pick() + require.Nil(t, err) + + point2, err := pointG1.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_Neg(t *testing.T) { + point1 := NewPointG1() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG1_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG1() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG1_MulOK(t *testing.T) { + t.Parallel() + + pointG1 := NewPointG1() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG1.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG1, res) + + grG1 := &groupG1{} + point2 := grG1.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG1_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG1_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG1_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG1() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG1_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG1().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG1() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG1_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG1 + + require.True(t, check.IfNil(point)) + point = NewPointG1() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bn254/pointG2.go b/curves/bn254/pointG2.go new file mode 100644 index 0000000..0b9bd9b --- /dev/null +++ b/curves/bn254/pointG2.go @@ -0,0 +1,207 @@ +package bn254 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointG2 - +type PointG2 struct { + G2 *gnark.G2Jac +} + +// NewPointG2 creates a new point on G2 initialized with base point +func NewPointG2() *PointG2 { + point := &PointG2{ + G2: &gnark.G2Jac{}, + } + + _, g2Gen, _, _ := gnark.Generators() + point.G2 = &g2Gen + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointG2) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointG2) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.G2.Equal(po2.G2), nil +} + +// Clone returns a clone of the receiver. +func (po *PointG2) Clone() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Set(po.G2) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointG2) Null() crypto.Point { + p := &PointG2{ + G2: &gnark.G2Jac{}, + } + + p.G2.Z.SetOne() + p.G2.X.SetZero() + p.G2.Y.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointG2) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return crypto.ErrInvalidParam + } + + po.G2.Set(po1.G2) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointG2) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.AddAssign(po1.G2) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointG2) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointG2) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.G2 = po2.G2.SubAssign(po1.G2) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointG2) Neg() crypto.Point { + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.Neg(po.G2) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointG2) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointG2) Pick() (crypto.Point, error) { + scalar := NewScalar() + + po2 := PointG2{ + G2: &gnark.G2Jac{}, + } + + po2.G2 = po2.G2.ScalarMultiplication(po.G2, scalar.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointG2) GetUnderlyingObj() interface{} { + return po.G2 +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointG2) MarshalBinary() ([]byte, error) { + affinePoint := &gnark.G2Affine{} + affinePoint.FromJacobian(po.G2) + + return affinePoint.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointG2) UnmarshalBinary(point []byte) error { + affinePoint := &gnark.G2Affine{} + err := affinePoint.Unmarshal(point) + if err != nil { + return err + } + + po.G2 = po.G2.FromAffine(affinePoint) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointG2) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bn254/pointG2_test.go b/curves/bn254/pointG2_test.go new file mode 100644 index 0000000..e9e0585 --- /dev/null +++ b/curves/bn254/pointG2_test.go @@ -0,0 +1,312 @@ +package bn254 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "testing" +) + +func TestNewPointG2(t *testing.T) { + _, g2Gen, _, _ := gnark.Generators() + bG2 := &g2Gen + + pG2 := NewPointG2() + require.NotNil(t, pG2) + + mclPointG2, ok := pG2.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.True(t, bG2.Equal(mclPointG2)) +} + +func TestPointG2_Equal(t *testing.T) { + p1G2 := NewPointG2() + p2G2 := NewPointG2() + + // new points should be initialized with base point so should be equal + eq, err := p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) + + // Make p1G1 different by multiplying it by 2 + scalar := NewScalar() + scalar.SetInt64(2) + p1Modified, err := p1G2.Mul(scalar) + require.Nil(t, err) + p1G2 = p1Modified.(*PointG2) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.False(t, eq) + + grG2 := &groupG2{} + sc1G2 := grG2.CreateScalar() + p1 := grG2.CreatePointForScalar(sc1G2) + p2 := grG2.CreatePointForScalar(sc1G2) + + var ok bool + p1G2, ok = p1.(*PointG2) + require.True(t, ok) + + p2G2, ok = p2.(*PointG2) + require.True(t, ok) + + eq, err = p1G2.Equal(p2G2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_CloneNilShouldPanic(t *testing.T) { + var p1 *PointG2 + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointG2_Clone(t *testing.T) { + p1 := NewPointG2() + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_Null(t *testing.T) { + p1 := NewPointG2() + + point := p1.Null() + bls12381Point, ok := point.(*PointG2) + require.True(t, ok) + require.True(t, bls12381Point.G2.X.IsZero()) + require.True(t, bls12381Point.G2.Y.IsZero()) + require.True(t, bls12381Point.G2.Z.IsOne()) + + bls12381PointNeg := &gnark.G2Jac{} + bls12381PointNeg = bls12381PointNeg.Neg(bls12381Point.G2) + + // neutral identity point should be equal to it's negation + ok = bls12381Point.G2.Equal(bls12381PointNeg) + require.True(t, ok) +} + +func TestPointG2_Set(t *testing.T) { + p1 := NewPointG2() + p2 := NewPointG2() + + scalar := NewScalar() + scalar.SetInt64(2) + p2Modified, err := p2.Mul(scalar) + require.Nil(t, err) + p2 = p2Modified.(*PointG2) + + err = p1.Set(p2) + require.Nil(t, err) + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_AddOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointG2_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2, err := pointG2.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointG2_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point2 := &mock.PointMock{} + point3, err := pointG2.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointG2_SubOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + point1, err := pointG2.Pick() + require.Nil(t, err) + + point2, err := pointG2.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_Neg(t *testing.T) { + point1 := NewPointG2() + + point2 := point1.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point1, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point1, point3) +} + +func TestPointG2_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointG2() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointG2_MulOK(t *testing.T) { + t.Parallel() + + pointG2 := NewPointG2() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := pointG2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, pointG2, res) + + grG2 := &groupG2{} + point2 := grG2.CreatePointForScalar(scalar) + eq, err := res.Equal(point2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointG2_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointG2_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointG2_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointG2() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointG2_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointG2().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointG2() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointG2_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointG2 + + require.True(t, check.IfNil(point)) + point = NewPointG2() + require.False(t, check.IfNil(point)) +} diff --git a/curves/bn254/pointGT.go b/curves/bn254/pointGT.go new file mode 100644 index 0000000..4d522bd --- /dev/null +++ b/curves/bn254/pointGT.go @@ -0,0 +1,221 @@ +package bn254 + +import ( + "math/big" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// PointGT - +type PointGT struct { + *gnark.GT +} + +// NewPointGT creates a new point on GT initialized with identity +func NewPointGT() *PointGT { + point := &PointGT{ + GT: &gnark.GT{}, + } + + return point +} + +// Equal tests if receiver is equal with the Point p given as parameter. +// Both Points need to be derived from the same Group +func (po *PointGT) Equal(p crypto.Point) (bool, error) { + if check.IfNil(p) { + return false, crypto.ErrNilParam + } + + po2, ok := p.(*PointGT) + if !ok { + return false, crypto.ErrInvalidParam + } + + return po.GT.Equal(po2.GT), nil +} + +// Clone returns a clone of the receiver. +func (po *PointGT) Clone() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + po2.GT = po2.GT.Set(po.GT) + + return &po2 +} + +// Null returns the neutral identity element. +func (po *PointGT) Null() crypto.Point { + p := NewPointGT() + p.GT.C0.B0.SetZero() + p.GT.C0.B1.SetZero() + p.GT.C0.B2.SetZero() + + p.GT.C1.B0.SetZero() + p.GT.C1.B1.SetZero() + p.GT.C1.B2.SetZero() + + return p +} + +// Set sets the receiver equal to another Point p. +func (po *PointGT) Set(p crypto.Point) error { + if check.IfNil(p) { + return crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return crypto.ErrInvalidParam + } + + po.GT.Set(po1.GT) + + return nil +} + +// Add returns the result of adding receiver with Point p given as parameter, +// so that their scalars add homomorphically +func (po *PointGT) Add(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Add(po2.GT, po1.GT) + + return &po2, nil +} + +// Sub returns the result of subtracting from receiver the Point p given as parameter, +// so that their scalars subtract homomorphically +func (po *PointGT) Sub(p crypto.Point) (crypto.Point, error) { + if check.IfNil(p) { + return nil, crypto.ErrNilParam + } + + po1, ok := p.(*PointGT) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + err := po2.Set(po) + if err != nil { + return nil, err + } + + po2.GT = po2.GT.Sub(po2.GT, po1.GT) + + return &po2, nil +} + +// Neg returns the negation of receiver +func (po *PointGT) Neg() crypto.Point { + po2 := PointGT{ + GT: &gnark.GT{}, + } + + // Multiplicative in GT, we can use inverse + po2.GT = po2.GT.Inverse(po.GT) + + return &po2 +} + +// Mul returns the result of multiplying receiver by the scalarInt s. +func (po *PointGT) Mul(s crypto.Scalar) (crypto.Point, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + s1, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + po2.GT = po2.GT.Exp(*po.GT, s1.Scalar.BigInt(&big.Int{})) + + return &po2, nil +} + +// Pick returns a new random or pseudo-random Point. +func (po *PointGT) Pick() (crypto.Point, error) { + var p1, p2 crypto.Point + var err error + + p1, err = NewPointG1().Pick() + if err != nil { + return nil, err + } + + p2, err = NewPointG2().Pick() + if err != nil { + return nil, err + } + + poG1 := p1.(*PointG1) + poG2 := p2.(*PointG2) + + po2 := PointGT{ + GT: &gnark.GT{}, + } + + g1Affine := &gnark.G1Affine{} + g1Affine.FromJacobian(poG1.G1) + + g2Affine := &gnark.G2Affine{} + g2Affine.FromJacobian(poG2.G2) + + paired, err := gnark.Pair([]gnark.G1Affine{*g1Affine}, []gnark.G2Affine{*g2Affine}) + if err != nil { + return nil, err + } + + po2.GT = &paired + + return &po2, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (po *PointGT) GetUnderlyingObj() interface{} { + return po.GT +} + +// MarshalBinary converts the point into its byte array representation +func (po *PointGT) MarshalBinary() ([]byte, error) { + return po.GT.Marshal(), nil +} + +// UnmarshalBinary reconstructs a point from its byte array representation +func (po *PointGT) UnmarshalBinary(point []byte) error { + return po.GT.Unmarshal(point) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (po *PointGT) IsInterfaceNil() bool { + return po == nil +} diff --git a/curves/bn254/pointGT_test.go b/curves/bn254/pointGT_test.go new file mode 100644 index 0000000..dc5cfc3 --- /dev/null +++ b/curves/bn254/pointGT_test.go @@ -0,0 +1,276 @@ +package bn254 + +import ( + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "testing" +) + +func TestNewPointGT(t *testing.T) { + pGT := NewPointGT() + require.NotNil(t, pGT) + + bn254PointGT, ok := pGT.GetUnderlyingObj().(*gnark.GT) + require.True(t, ok) + require.True(t, bn254PointGT.IsZero()) +} + +func TestPointGT_Equal(t *testing.T) { + p1GT := NewPointGT() + p2GT := NewPointGT() + + eq, err := p1GT.Equal(p2GT) + require.Nil(t, err) + require.True(t, eq) + + p2GT.SetOne() + eq, err = p1GT.Equal(p2GT) + require.Nil(t, err) + require.False(t, eq) +} + +func TestPointGT_CloneNilShouldPanic(t *testing.T) { + var p1 *PointGT + + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "should have panicked") + } + }() + + _ = p1.Clone() +} + +func TestPointGT_Clone(t *testing.T) { + p1 := NewPointGT() + _, err := p1.GT.SetRandom() + require.Nil(t, err) + p2 := p1.Clone() + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_Null(t *testing.T) { + p1 := NewPointGT() + + point := p1.Null() + bls12381Point, ok := point.(*PointGT) + require.True(t, ok) + require.True(t, bls12381Point.IsZero()) + bls12381PointNeg := &gnark.GT{} + // TODO + + // neutral identity point should be equal to it's negation + require.True(t, bls12381Point.GT.Equal(bls12381PointNeg)) +} + +func TestPointGT_Set(t *testing.T) { + p1 := NewPointGT() + p2 := NewPointGT() + + p2.GT.SetOne() + + err := p1.Set(p2) + require.Nil(t, err) + + eq, err := p1.Equal(p2) + require.Nil(t, err) + require.True(t, eq) +} + +func TestPointGT_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2, err := point.Add(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + point2 := &mock.PointMock{} + point3, err := point.Add(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_AddOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, err := point1.Add(point2) + require.Nil(t, err) + + p, err := sum.Sub(point2) + require.Nil(t, err) + + eq1, _ := point1.Equal(sum) + eq2, _ := point2.Equal(sum) + eq3, _ := point1.Equal(p) + + assert.False(t, eq1) + assert.False(t, eq2) + assert.True(t, eq3) +} + +func TestPointGT_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, err := pointGT.Sub(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, point2) +} + +func TestPointGT_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2 := &mock.PointMock{} + point3, err := pointGT.Sub(point2) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, point3) +} + +func TestPointGT_SubOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point1, err := pointGT.Pick() + require.Nil(t, err) + + point2, err := pointGT.Pick() + require.Nil(t, err) + + sum, _ := point1.Add(point2) + point3, err := sum.Sub(point2) + assert.Nil(t, err) + + eq, err := point3.Equal(point1) + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_Neg(t *testing.T) { + point1 := NewPointGT() + point, _ := point1.Pick() + + point2 := point.Neg() + point3 := point2.Neg() + + assert.NotEqual(t, point, point2) + assert.NotEqual(t, point2, point3) + assert.Equal(t, point, point3) +} + +func TestPointGT_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + res, err := point.Mul(nil) + + assert.Equal(t, crypto.ErrNilParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + point := NewPointGT() + scalar := &mock.ScalarMock{} + res, err := point.Mul(scalar) + + assert.Equal(t, crypto.ErrInvalidParam, err) + assert.Nil(t, res) +} + +func TestPointGT_MulOK(t *testing.T) { + t.Parallel() + + pointGT := NewPointGT() + point2, _ := pointGT.Pick() + s := NewScalar() + scalar, err := s.Pick() + require.Nil(t, err) + + res, err := point2.Mul(scalar) + + require.Nil(t, err) + require.NotNil(t, res) + require.NotEqual(t, point2, res) +} + +func TestPointGT_PickOK(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + point2, err1 := point1.Pick() + eq, err2 := point1.Equal(point2) + + assert.Nil(t, err1) + assert.Nil(t, err2) + assert.False(t, eq) +} + +func TestPointGT_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + p := point1.GetUnderlyingObj() + + assert.NotNil(t, p) +} + +func TestPointGT_MarshalBinary(t *testing.T) { + t.Parallel() + + point1 := NewPointGT() + pointBytes, err := point1.MarshalBinary() + + assert.Nil(t, err) + assert.NotNil(t, pointBytes) +} + +func TestPointGT_UnmarshalBinary(t *testing.T) { + t.Parallel() + + point1, _ := NewPointGT().Pick() + pointBytes, _ := point1.MarshalBinary() + + point2 := NewPointGT() + err := point2.UnmarshalBinary(pointBytes) + eq, _ := point1.Equal(point2) + + assert.Nil(t, err) + assert.True(t, eq) +} + +func TestPointGT_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var point *PointGT + + require.True(t, point.IsInterfaceNil()) + point = NewPointGT() + require.False(t, point.IsInterfaceNil()) +} diff --git a/curves/bn254/scalar.go b/curves/bn254/scalar.go new file mode 100644 index 0000000..49a6d3b --- /dev/null +++ b/curves/bn254/scalar.go @@ -0,0 +1,271 @@ +package bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" +) + +// Scalar - +type Scalar struct { + Scalar *fr.Element +} + +func NewScalar() *Scalar { + scalar := &Scalar{Scalar: &fr.Element{}} + scalar.setRandom() + + for scalar.Scalar.IsOne() || scalar.Scalar.IsZero() { + scalar.setRandom() + } + + return scalar +} + +// Equal tests if receiver is equal with the scalarInt s given as parameter. +// Both scalars need to be derived from the same Group +func (sc *Scalar) Equal(s crypto.Scalar) (bool, error) { + if check.IfNil(s) { + return false, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return false, crypto.ErrInvalidParam + } + + areEqual := sc.Scalar.Equal(s2.Scalar) + + return areEqual, nil +} + +// Set sets the receiver to Scalar s given as parameter +func (sc *Scalar) Set(s crypto.Scalar) error { + if check.IfNil(s) { + return crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return crypto.ErrInvalidParam + } + + return sc.Scalar.SetBytesCanonical(s2.Scalar.Marshal()) +} + +// Clone creates a new Scalar with same value as receiver +func (sc *Scalar) Clone() crypto.Scalar { + scalar := &Scalar{ + Scalar: &fr.Element{}, + } + + scalar.Scalar.SetBytes(sc.Scalar.Marshal()) + + return scalar +} + +// SetInt64 sets the receiver to a small integer value v given as parameter +func (sc *Scalar) SetInt64(v int64) { + sc.Scalar.SetInt64(v) +} + +// Zero returns the additive identity (0) +func (sc *Scalar) Zero() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetZero() + + return &s +} + +// Add returns the modular sum of receiver with scalar s given as parameter +func (sc *Scalar) Add(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Add(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Sub returns the modular difference between receiver and scalar s given as parameter +func (sc *Scalar) Sub(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := &Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Sub(sc.Scalar, s2.Scalar) + + return s1, nil +} + +// Neg returns the modular negation of receiver +func (sc *Scalar) Neg() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + + s.Scalar.Neg(sc.Scalar) + + return &s +} + +// One returns the multiplicative identity (1) +func (sc *Scalar) One() crypto.Scalar { + s := Scalar{ + Scalar: &fr.Element{}, + } + s.Scalar.SetOne() + + return &s +} + +// Mul returns the modular product of receiver with scalar s given as parameter +func (sc *Scalar) Mul(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Mul(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Div returns the modular division between receiver and scalar s given as parameter +func (sc *Scalar) Div(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Div(sc.Scalar, s2.Scalar) + + return &s1, nil +} + +// Inv returns the modular inverse of scalar s given as parameter +func (sc *Scalar) Inv(s crypto.Scalar) (crypto.Scalar, error) { + if check.IfNil(s) { + return nil, crypto.ErrNilParam + } + + s2, ok := s.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidParam + } + + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + s1.Scalar.Inverse(s2.Scalar) + + return &s1, nil +} + +// Pick returns a fresh random or pseudo-random scalar +// For the mock set X to the original scalar.X *2 +func (sc *Scalar) Pick() (crypto.Scalar, error) { + s1 := Scalar{ + Scalar: &fr.Element{}, + } + + _, err := s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + + for s1.Scalar.IsOne() || s1.Scalar.IsZero() { + _, err = s1.Scalar.SetRandom() + if err != nil { + return nil, err + } + } + + return &s1, nil +} + +// SetBytes sets the scalar from a byte-slice, +// reducing if necessary to the appropriate modulus. +func (sc *Scalar) SetBytes(s []byte) (crypto.Scalar, error) { + if len(s) == 0 { + return nil, crypto.ErrNilParam + } + + s1 := sc.Clone() + s2, ok := s1.(*Scalar) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + err := s2.Scalar.SetBytesCanonical(s) + if err != nil { + return nil, err + } + + return s1, nil +} + +// GetUnderlyingObj returns the object the implementation wraps +func (sc *Scalar) GetUnderlyingObj() interface{} { + return sc.Scalar +} + +// MarshalBinary transforms the Scalar into a byte array +func (sc *Scalar) MarshalBinary() ([]byte, error) { + return sc.Scalar.Marshal(), nil +} + +// UnmarshalBinary recreates the Scalar from a byte array +func (sc *Scalar) UnmarshalBinary(val []byte) error { + return sc.Scalar.SetBytesCanonical(val) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sc *Scalar) IsInterfaceNil() bool { + return sc == nil +} + +func (sc *Scalar) setRandom() { + _, err := sc.Scalar.SetRandom() + if err != nil { + panic("BN254 cannot read from rand to create a new scalar") + } +} diff --git a/curves/bn254/scalar_test.go b/curves/bn254/scalar_test.go new file mode 100644 index 0000000..40e2dc1 --- /dev/null +++ b/curves/bn254/scalar_test.go @@ -0,0 +1,406 @@ +package bn254 + +import ( + "testing" + + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/mock" + "github.com/stretchr/testify/require" +) + +func TestBNScalar_EqualInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + eq, err := scalar1.Equal(scalar2) + + require.False(t, eq) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBNScalar_EqualTrue(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_EqualFalse(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.False(t, eq) +} + +func TestBNScalar_SetNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + err := scalar.Set(nil) + + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBNScalar_SetInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + err := scalar1.Set(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBNScalar_SetOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().Zero() + err := scalar1.Set(scalar2) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_Clone(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := scalar1.Clone() + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_SetInt64(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := NewScalar() + scalar1.SetInt64(int64(555555555)) + scalar2.SetInt64(int64(444444444)) + + diff, _ := scalar1.Sub(scalar2) + scalar3 := NewScalar() + scalar3.SetInt64(int64(111111111)) + + eq, err := diff.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_Zero(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := NewScalar() + scalar2.SetInt64(0) + + eq, err := scalar2.Equal(scalar1) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_AddNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + sum, err := scalar.Add(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, sum) +} + +func TestBNScalar_AddInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + sum, err := scalar1.Add(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, sum) +} + +func TestBNScalar_AddOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar().One() + sum, err := scalar1.Add(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(2) + eq, err := scalar3.Equal(sum) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBNScalar_SubNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().Zero() + diff, err := scalar.Sub(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, diff) +} + +func TestBNScalar_SubInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().Zero() + scalar2 := &mock.ScalarMock{} + diff, err := scalar1.Sub(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, diff) +} + +func TestBNScalar_SubOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := NewScalar().One() + diff, err := scalar1.Sub(scalar2) + require.Nil(t, err) + scalar3 := NewScalar() + scalar3.SetInt64(3) + eq, err := scalar3.Equal(diff) + + require.True(t, eq) + require.Nil(t, err) +} + +func TestBNScalar_Neg(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2 := scalar1.Neg() + scalar3 := NewScalar() + scalar3.SetInt64(-4) + eq, err := scalar2.Equal(scalar3) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_One(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(1) + scalar2 := NewScalar().One() + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} + +func TestBNScalar_MulNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Mul(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBNScalar_MulInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Mul(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBNScalar_MulOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar1.Mul(scalar2) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBNScalar_DivNilParamShouldEr(t *testing.T) { + t.Parallel() + + scalar := NewScalar().One() + res, err := scalar.Div(nil) + + require.Equal(t, crypto.ErrNilParam, err) + require.Nil(t, res) +} + +func TestBNScalar_DivInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := &mock.ScalarMock{} + res, err := scalar1.Div(scalar2) + + require.Equal(t, crypto.ErrInvalidParam, err) + require.Nil(t, res) +} + +func TestBNScalar_DivOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + scalar2 := NewScalar() + scalar2.SetInt64(4) + res, err := scalar2.Div(scalar1) + + require.Nil(t, err) + + eq, _ := res.Equal(scalar2) + + require.True(t, eq) +} + +func TestBNScalar_InvNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Inv(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBNScalar_InvInvalidParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2 := &mock.ScalarMock{} + scalar3, err := scalar1.Inv(scalar2) + + require.Nil(t, scalar3) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestBNScalar_InvOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar1.SetInt64(4) + scalar2, err := scalar1.Inv(scalar1) + eq, _ := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.NotNil(t, scalar2) + require.False(t, eq) + + one := NewScalar().One() + scalar3, err := scalar1.Inv(one) + require.Nil(t, err) + eq, _ = one.Equal(scalar3) + + require.True(t, eq) +} + +func TestBNScalar_PickOK(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.Pick() + require.Nil(t, err) + require.NotNil(t, scalar1, scalar2) + + eq, _ := scalar1.Equal(scalar2) + + require.False(t, eq) +} + +func TestBNScalar_SetBytesNilParamShouldErr(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar() + scalar2, err := scalar1.SetBytes(nil) + + require.Nil(t, scalar2) + require.Equal(t, crypto.ErrNilParam, err) +} + +func TestBNScalar_SetBytesOK(t *testing.T) { + t.Parallel() + + val := int64(555555555) + scalar1 := NewScalar().One() + + sc2 := NewScalar() + sc2.SetInt64(val) + buf, _ := sc2.MarshalBinary() + + scalar2, err := scalar1.SetBytes(buf) + require.Nil(t, err) + require.NotEqual(t, scalar1, scalar2) + + scalar3 := NewScalar() + scalar3.SetInt64(val) + + eq, _ := scalar3.Equal(scalar2) + require.True(t, eq) +} + +func TestBNScalar_GetUnderlyingObj(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + x := scalar1.GetUnderlyingObj() + + require.NotNil(t, x) +} + +func TestBNScalar_MarshalBinary(t *testing.T) { + t.Parallel() + + scalar1 := NewScalar().One() + + scalarBytes, err := scalar1.MarshalBinary() + + require.Nil(t, err) + require.NotNil(t, scalarBytes) +} + +func TestBNScalar_UnmarshalBinary(t *testing.T) { + scalar1, _ := NewScalar().Pick() + scalarBytes, err := scalar1.MarshalBinary() + require.Nil(t, err) + scalar2 := NewScalar().Zero() + err = scalar2.UnmarshalBinary(scalarBytes) + require.Nil(t, err) + + eq, err := scalar1.Equal(scalar2) + + require.Nil(t, err) + require.True(t, eq) +} diff --git a/curves/bn254/suiteBN254.go b/curves/bn254/suiteBN254.go new file mode 100644 index 0000000..a9f623c --- /dev/null +++ b/curves/bn254/suiteBN254.go @@ -0,0 +1,129 @@ +package bn254 + +import ( + "crypto/cipher" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("curves/bn254") + +// SuiteBN254 provides an implementation of the Suite interface for BN254 +type SuiteBN254 struct { + G1 *groupG1 + G2 *groupG2 + GT *groupGT + strSuite string +} + +// NewSuiteBN254 returns a wrapper over a BN254 curve. +func NewSuiteBN254() *SuiteBN254 { + return &SuiteBN254{ + G1: &groupG1{}, + G2: &groupG2{}, + GT: &groupGT{}, + strSuite: "BN254 suite", + } +} + +// RandomStream returns a cipher.Stream that returns a key stream +// from crypto/rand. +func (s *SuiteBN254) RandomStream() cipher.Stream { + return nil +} + +// CreatePoint creates a new point +func (s *SuiteBN254) CreatePoint() crypto.Point { + return s.G2.CreatePoint() +} + +// String returns the string for the group +func (s *SuiteBN254) String() string { + return s.strSuite +} + +// ScalarLen returns the maximum length of scalars in bytes +func (s *SuiteBN254) ScalarLen() int { + return s.G2.ScalarLen() +} + +// CreateScalar creates a new Scalar +func (s *SuiteBN254) CreateScalar() crypto.Scalar { + return s.G2.CreateScalar() +} + +// CreatePointForScalar creates a new point corresponding to the given scalar +func (s *SuiteBN254) CreatePointForScalar(scalar crypto.Scalar) (crypto.Point, error) { + if check.IfNil(scalar) { + return nil, crypto.ErrNilPrivateKeyScalar + } + sc, ok := scalar.GetUnderlyingObj().(*fr.Element) + if !ok { + return nil, crypto.ErrInvalidScalar + } + + if sc.IsZero() { + return nil, crypto.ErrInvalidPrivateKey + } + + point := s.G2.CreatePointForScalar(scalar) + + return point, nil +} + +// PointLen returns the max length of point in nb of bytes +func (s *SuiteBN254) PointLen() int { + return s.G2.PointLen() +} + +// CreateKeyPair returns a pair of private public BN254 keys. +// The private key is a scalarInt, while the public key is a Point on G2 curve +func (s *SuiteBN254) CreateKeyPair() (crypto.Scalar, crypto.Point) { + var sc crypto.Scalar + var err error + + sc = s.G2.CreateScalar() + sc, err = sc.Pick() + if err != nil { + log.Error("SuiteBN254 CreateKeyPair", "error", err.Error()) + return nil, nil + } + + p := s.G2.CreatePointForScalar(sc) + + return sc, p +} + +// GetUnderlyingSuite returns the underlying suite +func (s *SuiteBN254) GetUnderlyingSuite() interface{} { + return s +} + +// CheckPointValid returns error if the point is not valid (zero is also not valid), otherwise nil +func (s *SuiteBN254) CheckPointValid(pointBytes []byte) error { + if len(pointBytes) != s.PointLen() { + return crypto.ErrInvalidParam + } + + point := s.G2.CreatePoint() + err := point.UnmarshalBinary(pointBytes) + if err != nil { + return err + } + + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + if !ok || !pG2.IsOnCurve() || !pG2.IsInSubGroup() { + return crypto.ErrInvalidPoint + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *SuiteBN254) IsInterfaceNil() bool { + return s == nil +} diff --git a/curves/bn254/suiteBN254_test.go b/curves/bn254/suiteBN254_test.go new file mode 100644 index 0000000..4685cd0 --- /dev/null +++ b/curves/bn254/suiteBN254_test.go @@ -0,0 +1,180 @@ +package bn254 + +import ( + "encoding/hex" + "math/big" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/multiversx/mx-chain-core-go/core/check" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSuiteBN254(t *testing.T) { + suite := NewSuiteBN254() + + assert.NotNil(t, suite) +} + +func TestSuiteBN254_RandomStream(t *testing.T) { + suite := NewSuiteBN254() + stream := suite.RandomStream() + require.Nil(t, stream) +} + +func TestSuiteBN254_CreatePoint(t *testing.T) { + suite := NewSuiteBN254() + + point1 := suite.CreatePoint() + point2 := suite.CreatePoint() + + assert.NotNil(t, point1) + assert.NotNil(t, point2) + assert.False(t, point1 == point2) +} + +func TestSuiteBN254_String(t *testing.T) { + suite := NewSuiteBN254() + + str := suite.String() + assert.Equal(t, "BN254 suite", str) +} + +func TestSuiteBN254_ScalarLen(t *testing.T) { + suite := NewSuiteBN254() + + length := suite.ScalarLen() + assert.Equal(t, 32, length) +} + +func TestSuiteBN254_CreateScalar(t *testing.T) { + suite := NewSuiteBN254() + + scalar := suite.CreateScalar() + assert.NotNil(t, scalar) +} + +func TestSuiteBN254_CreatePointForScalar(t *testing.T) { + suite := NewSuiteBN254() + scalar := NewScalar() + + point, err := suite.CreatePointForScalar(scalar) + require.Nil(t, err) + pG2, ok := point.GetUnderlyingObj().(*gnark.G2Jac) + require.True(t, ok) + require.NotNil(t, pG2) + + bG2 := NewPointG2().G2 + var scalarBigInt big.Int + bn254Scalar, _ := scalar.GetUnderlyingObj().(*fr.Element) + bn254Scalar.BigInt(&scalarBigInt) + computedG2 := bG2.ScalarMultiplication(bG2, &scalarBigInt) + + require.True(t, pG2.Equal(computedG2)) +} + +func TestSuiteBN254_PointLen(t *testing.T) { + suite := NewSuiteBN254() + + pointLength := suite.PointLen() + + assert.Equal(t, 64, pointLength) +} + +func TestSuiteBN254_CreateKey(t *testing.T) { + suite := NewSuiteBN254() + private, public := suite.CreateKeyPair() + assert.NotNil(t, private) + assert.NotNil(t, public) +} + +func TestSuiteBN254_GetUnderlyingSuite(t *testing.T) { + suite := NewSuiteBN254() + + obj := suite.GetUnderlyingSuite() + + assert.NotNil(t, obj) +} + +func TestSuiteBN254_CheckPointValidOK(t *testing.T) { + t.Skip() + + validPointHexStr := "998e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c21800deef121f1e76426a00665e5c44" + + "79674322d4f75edadd46debd5cd992f6ed" + + suite := NewSuiteBN254() + + validPointBytes, err := hex.DecodeString(validPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(validPointBytes) + require.Nil(t, err) +} + +func TestSuiteBN254_CheckPointValidShortHexStringShouldErr(t *testing.T) { + shortPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBN254() + + shortPointBytes, err := hex.DecodeString(shortPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(shortPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBN254_CheckPointValidLongHexStrShouldErr(t *testing.T) { + longPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74d" + + suite := NewSuiteBN254() + + longPointBytes, err := hex.DecodeString(longPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(longPointBytes) + require.Equal(t, crypto.ErrInvalidParam, err) +} + +func TestSuiteBN254_CheckPointValidInvalidPointHexStrShouldErr(t *testing.T) { + invalidPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5caaaaaaaaaaa" + oneHexCharCorruptedPointHexStr := "368723d835fca6bc0c17a270e51b731f69f9fe482ed88e8c3d879f228291d48057aa12d0de8476b4a111e945399253" + + "15d2d3fd1b85e29e465b8814b713cbf833115f4562e28dcf58e960751f0581578ca1819c8790aa5a5300c5c317b74dca0a" + suite := NewSuiteBN254() + + invalidPointBytes, err := hex.DecodeString(invalidPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(invalidPointBytes) + require.NotNil(t, err) + + oneHexCharCorruptedPointBytes, err := hex.DecodeString(oneHexCharCorruptedPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(oneHexCharCorruptedPointBytes) + require.NotNil(t, err) +} + +func TestSuiteBN254_CheckPointValidZeroHexStrShouldErr(t *testing.T) { + t.Skip() + + zeroPointHexStr := "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + + suite := NewSuiteBN254() + + zeroPointBytes, err := hex.DecodeString(zeroPointHexStr) + require.Nil(t, err) + err = suite.CheckPointValid(zeroPointBytes) + require.Equal(t, crypto.ErrInvalidPoint, err) +} + +func TestSuiteBN254_IsInterfaceNil(t *testing.T) { + t.Parallel() + var suite *SuiteBN254 + + require.True(t, check.IfNil(suite)) + suite = NewSuiteBN254() + require.False(t, check.IfNil(suite)) +} diff --git a/go.mod b/go.mod index 3567bc0..10d1ef6 100644 --- a/go.mod +++ b/go.mod @@ -4,23 +4,39 @@ go 1.23 require ( filippo.io/edwards25519 v1.0.0 + github.com/consensys/gnark v0.12.0 + github.com/consensys/gnark-crypto v0.17.0 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 github.com/herumi/bls-go-binary v1.28.2 github.com/multiversx/mx-chain-core-go v1.4.0 github.com/multiversx/mx-chain-logger-go v1.1.0 - github.com/stretchr/testify v1.8.0 - golang.org/x/crypto v0.3.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.33.0 ) require ( + github.com/bits-and-blooms/bitset v1.20.0 // indirect + github.com/blang/semver/v4 v4.0.0 // indirect + github.com/consensys/bavard v0.1.29 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/denisbrodbeck/machineid v1.0.1 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect + github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect github.com/pelletier/go-toml v1.9.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sys v0.2.0 // indirect + github.com/ronanh/intcomp v1.1.0 // indirect + github.com/rs/zerolog v1.33.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.30.0 // indirect google.golang.org/protobuf v1.26.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index 3efa91a..ad855b9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,16 @@ filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/bits-and-blooms/bitset v1.20.0 h1:2F+rfL86jE2d/bmw7OhqUg2Sj/1rURkBn3MdfoPyRVU= +github.com/bits-and-blooms/bitset v1.20.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh1k= +github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/gnark v0.12.0 h1:XgQ1kh2R6fHuf5fBYl+i7TxR+QTbGQuZaaqqkk5nLO0= +github.com/consensys/gnark v0.12.0/go.mod h1:WDvuIQ8qrRvWT9NhTrib84WeLVBSGhSTrbQBXs1yR5w= +github.com/consensys/gnark-crypto v0.17.0 h1:vKDhZMOrySbpZDCvGMOELrHFv/A9mJ7+9I8HEfRZSkI= +github.com/consensys/gnark-crypto v0.17.0/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0= @@ -9,17 +19,41 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ= github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/herumi/bls-go-binary v1.28.2 h1:F0AezsC0M1a9aZjk7g0l2hMb1F56Xtpfku97pDndNZE= github.com/herumi/bls-go-binary v1.28.2/go.mod h1:O4Vp1AfR4raRGwFeQpr9X/PQtncEicMoOe6BQt1oX0Y= +github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b h1:AvQTK7l0PTHODD06PVQX1Tn2o29sRIaKIDOvTJmKurY= +github.com/ingonyama-zk/icicle/v3 v3.1.1-0.20241118092657-fccdb2f0921b/go.mod h1:e0JHb27/P6WorCJS3YolbY5XffS4PGBuoW38OthLkDs= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leanovate/gopter v0.2.11 h1:vRjThO1EKPb/1NsDXuDrzldR28RLkBflWYcU9CvzWu4= +github.com/leanovate/gopter v0.2.11/go.mod h1:aK3tzZP/C+p1m3SPRE4SYZFGP7jjkuSI4f7Xvpt0S9c= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= +github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= +github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/multiversx/mx-chain-core-go v1.4.0 h1:p6FbfCzvMXF54kpS0B5mrjNWYpq4SEQqo0UvrMF7YVY= @@ -28,20 +62,27 @@ github.com/multiversx/mx-chain-logger-go v1.1.0 h1:97x84A6L4RfCa6YOx1HpAFxZp1cf/ github.com/multiversx/mx-chain-logger-go v1.1.0/go.mod h1:K9XgiohLwOsNACETMNL0LItJMREuEvTH6NsoXWXWg7g= github.com/pelletier/go-toml v1.9.3 h1:zeC5b1GviRUyKYd6OJPvBU/mcVDVoL1OhT17FCt5dSQ= github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/ronanh/intcomp v1.1.0 h1:i54kxmpmSoOZFcWPMWryuakN0vLxLswASsGa07zkvLU= +github.com/ronanh/intcomp v1.1.0/go.mod h1:7FOLy3P3Zj3er/kVrU/pl+Ql7JFZj7bwliMGketo0IU= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= -golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -51,11 +92,16 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -65,13 +111,16 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU= +rsc.io/tmplfunc v0.0.3/go.mod h1:AG3sTPzElb1Io3Yg4voV9AGZJuleGAwaVRxL9M49PhA= diff --git a/integrationTests/groth16Verifier_test.go b/integrationTests/groth16Verifier_test.go new file mode 100644 index 0000000..b8cbedc --- /dev/null +++ b/integrationTests/groth16Verifier_test.go @@ -0,0 +1,55 @@ +package integrationTests + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + gnarkgroth16 "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/examples/exponentiate" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/multiversx/mx-chain-crypto-go/zk/groth16" + "github.com/stretchr/testify/require" +) + +func TestGroth16Verifier(t *testing.T) { + css, err := frontend.Compile(ecc.BLS12_381.ScalarField(), r1cs.NewBuilder, &exponentiate.Circuit{}) + require.Nil(t, err) + + // Setup on the prover side + pk, vk, err := gnarkgroth16.Setup(css) + require.Nil(t, err) + + homework := &exponentiate.Circuit{ + X: 2, + Y: 16, + E: 4, + } + + witness, err := frontend.NewWitness(homework, ecc.BLS12_381.ScalarField()) + require.Nil(t, err) + + proof, err := gnarkgroth16.Prove(css, pk, witness) + require.Nil(t, err) + + var serializedProof bytes.Buffer + _, err = proof.WriteTo(&serializedProof) + require.Nil(t, err) + + var serializedVK bytes.Buffer + _, err = vk.WriteTo(&serializedVK) + require.Nil(t, err) + + // There are two ways to generate the public witness - either from the prover full witness, either recreate + // using the circuit with only the public inputs into it + pubW, err := witness.Public() + require.Nil(t, err) + pubWBytes, err := pubW.MarshalBinary() + require.Nil(t, err) + // Now a tx can do: verify@proof_bytes@pub_witness_bytes; the curve_id and vk should be in the contract state + + verified, err := groth16.VerifyGroth16(uint16(ecc.BLS12_381), serializedProof.Bytes(), serializedVK.Bytes(), pubWBytes) + require.True(t, verified) + require.Nil(t, err) +} diff --git a/integrationTests/mclGnarkConversions_test.go b/integrationTests/mclGnarkConversions_test.go new file mode 100644 index 0000000..15a29eb --- /dev/null +++ b/integrationTests/mclGnarkConversions_test.go @@ -0,0 +1,141 @@ +package integrationTests + +import ( + "fmt" + "testing" + + gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/herumi/bls-go-binary/bls" + "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12381" + blsInterop "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12381/interop" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + mclInterop "github.com/multiversx/mx-chain-crypto-go/signing/mcl/interop" + "github.com/stretchr/testify/require" +) + +func TestFromMCLToGnark(t *testing.T) { + mclSuite := mcl.NewSuiteBLS12() + bls12381Suite := bls12381.NewSuiteBLS12() + + for i := 0; i < 500; i++ { + _, pk := mclSuite.CreateKeyPair() + + pointBytes, _ := pk.MarshalBinary() + convertedPoint, err := blsInterop.PointBytesFromMcl(pointBytes) + require.Nil(t, err) + + err = bls12381Suite.CheckPointValid(convertedPoint) + + require.Nil(t, err) + } +} + +func TestFromGnarkToMCL(t *testing.T) { + gnarkSuite := bls12381.NewSuiteBLS12() + mclSuite := mcl.NewSuiteBLS12() + + for i := 0; i < 500; i++ { + _, pk := gnarkSuite.CreateKeyPair() + + pointBytes, _ := pk.MarshalBinary() + convertedPoint, err := mclInterop.PointBytesFromGnark(pointBytes) + require.Nil(t, err) + + err = mclSuite.CheckPointValid(convertedPoint) + + require.Nil(t, err) + } +} + +func printBytesAsBinary(data []byte) { + for i, b := range data { + // %08b formats b as binary, padded to 8 digits with leading zeros + fmt.Printf("%08b", b) + if i < len(data)-1 { + fmt.Print(" ") + } + } + fmt.Println() +} + +func TestPointsBackAndForth(t *testing.T) { + mclSuite := mcl.NewSuiteBLS12() + _, pk1 := mclSuite.CreateKeyPair() + pointBytes1, _ := pk1.MarshalBinary() + + convertedPointBytes1, _ := blsInterop.PointBytesFromMcl(pointBytes1) + gnarkPoint1 := bls12381.NewPointG2() + _ = gnarkPoint1.UnmarshalBinary(convertedPointBytes1) + + actualGnarkBytes, _ := gnarkPoint1.MarshalBinary() + convertBackGnarkForMCL, _ := mclInterop.PointBytesFromGnark(actualGnarkBytes) + + lastPoint := mcl.NewPointG2() + _ = lastPoint.UnmarshalBinary(convertBackGnarkForMCL) + lastPointBytes, _ := lastPoint.MarshalBinary() + + require.Equal(t, pointBytes1, lastPointBytes) +} + +func TestSameOperationsDifferentSuitesShouldBeEqual(t *testing.T) { + t.Skip("TODO: Find the difference in curve operations") + + mclSuite := mcl.NewSuiteBLS12() + _, pk1 := mclSuite.CreateKeyPair() + _, pk2 := mclSuite.CreateKeyPair() + mclResult, err := pk1.Add(pk2) + + fmt.Println("starting point") + spb, _ := mclResult.MarshalBinary() + printBytesAsBinary(spb) + + require.Nil(t, err) + + pointBytes1, _ := pk1.MarshalBinary() + convertedPointBytes1, err := blsInterop.PointBytesFromMcl(pointBytes1) + require.Nil(t, err) + + convertedPoint1 := bls12381.NewPointG2() + err = convertedPoint1.UnmarshalBinary(convertedPointBytes1) + require.Nil(t, err) + + pointBytes2, _ := pk2.MarshalBinary() + convertedPointBytes2, err := blsInterop.PointBytesFromMcl(pointBytes2) + require.Nil(t, err) + + convertedPoint2 := bls12381.NewPointG2() + err = convertedPoint2.UnmarshalBinary(convertedPointBytes2) + require.Nil(t, err) + + gnarkResult, err := convertedPoint1.Add(convertedPoint2) + require.Nil(t, err) + gnarkResult1, _ := gnarkResult.(*bls12381.PointG2) + gnarkResult1.G2.ClearCofactor(gnarkResult1.G2) + + gnarkResultBytes, err := gnarkResult1.MarshalBinary() + require.Nil(t, err) + convertedPointBytes, err := mclInterop.PointBytesFromGnark(gnarkResultBytes) + require.Nil(t, err) + convertedPoint := mcl.NewPointG2() + + err = convertedPoint.UnmarshalBinary(convertedPointBytes) + require.Nil(t, err) + + fmt.Println("resulting point") + printBytesAsBinary(convertedPointBytes) + equal, err := mclResult.Equal(convertedPoint) + require.Nil(t, err) + require.True(t, equal) +} + +// TODO: Remove this test, printed generators just for reference and proofs they are not the same +func TestMclGnarkGenerators(t *testing.T) { + pubKey1 := &bls.PublicKey{} + bls.GetGeneratorOfPublicKey(pubKey1) + b := pubKey1.GetHexString() + fmt.Println(b) + + _, g2, _, _ := gnark.Generators() + fmt.Println(g2.String()) + +} diff --git a/integrationTests/plonkVerifier_test.go b/integrationTests/plonkVerifier_test.go new file mode 100644 index 0000000..34e371b --- /dev/null +++ b/integrationTests/plonkVerifier_test.go @@ -0,0 +1,60 @@ +package integrationTests + +import ( + "bytes" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + gnarkplonk "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/examples/exponentiate" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test/unsafekzg" + "github.com/multiversx/mx-chain-crypto-go/zk/plonk" + "github.com/stretchr/testify/require" +) + +func TestPlonkVerifier(t *testing.T) { + css, err := frontend.Compile(ecc.BLS12_381.ScalarField(), scs.NewBuilder, &exponentiate.Circuit{}) + require.Nil(t, err) + + srs, srsLagrange, err := unsafekzg.NewSRS(css) + require.Nil(t, err) + + // Setup on the prover side + pk, vk, err := gnarkplonk.Setup(css, srs, srsLagrange) + require.Nil(t, err) + + homework := &exponentiate.Circuit{ + X: 2, + Y: 16, + + E: 4, + } + + witness, err := frontend.NewWitness(homework, ecc.BLS12_381.ScalarField()) + require.Nil(t, err) + + proof, err := gnarkplonk.Prove(css, pk, witness) + require.Nil(t, err) + + var serializedProof bytes.Buffer + _, err = proof.WriteTo(&serializedProof) + require.Nil(t, err) + + var serializedVK bytes.Buffer + _, err = vk.WriteTo(&serializedVK) + require.Nil(t, err) + + // There are two ways to generate the public witness - either from the prover full witness, either recreate + // using the circuit with only the public inputs into it + pubW, err := witness.Public() + require.Nil(t, err) + pubWBytes, err := pubW.MarshalBinary() + require.Nil(t, err) + + // Now a tx can do: verify@proof_bytes@pub_witness_bytes; the curve_id and vk should be in the contract state + verified, err := plonk.VerifyPlonk(uint16(ecc.BLS12_381), serializedProof.Bytes(), serializedVK.Bytes(), pubWBytes) + require.True(t, verified) + require.Nil(t, err) +} diff --git a/signing/mcl/interop/flags.go b/signing/mcl/interop/flags.go new file mode 100644 index 0000000..ee6c650 --- /dev/null +++ b/signing/mcl/interop/flags.go @@ -0,0 +1,44 @@ +package interop + +import ( + "errors" + "slices" + + "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12381" +) + +const ( + g2CompressedSize = 96 + g2UnCompressedSize = 192 + yOddMask = 0x80 +) + +// PointBytesFromGnark converts a point from BLS to MCL format +func PointBytesFromGnark(rawPoint []byte) ([]byte, error) { + if len(rawPoint) != g2UnCompressedSize { + return nil, errors.New("interop: raw BLS point must be 192 bytes") + } + + // TODO: Find a cleaner way to test Y sign without re-assembling the gnark point + gnarkPoint := bls12381.NewPointG2() + _ = gnarkPoint.UnmarshalBinary(rawPoint) + isYodd := gnarkPoint.G2.Y.LexicographicallyLargest() + + X := rawPoint[:g2CompressedSize] + X = reverseBytes(X) + + if isYodd { + X[g2CompressedSize-1] |= yOddMask + } else { + X[g2CompressedSize-1] &= 0x7f + } + + return X, nil +} + +func reverseBytes(in []byte) []byte { + out := append([]byte(nil), in...) + slices.Reverse(out) + + return out +} diff --git a/zk/groth16/verify.go b/zk/groth16/verify.go new file mode 100644 index 0000000..1ad8b95 --- /dev/null +++ b/zk/groth16/verify.go @@ -0,0 +1,50 @@ +package groth16 + +import ( + "bytes" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/multiversx/mx-chain-crypto-go/zk/lowLevelFeatures" +) + +// VerifyGroth16 verifies the groth16 proof from the given input and curve id +func VerifyGroth16(curveID uint16, proofBytes, vkBytes, pubWitnessBytes []byte) (bool, error) { + if len(proofBytes) == 0 || len(vkBytes) == 0 || len(pubWitnessBytes) == 0 { + return false, lowLevelFeatures.ErrNilOrEmptyInput + } + + _, ok := lowLevelFeatures.SupportedCurvesRegistry[ecc.ID(curveID)] + if !ok { + return false, lowLevelFeatures.ErrInvalidCurve + } + + vk := groth16.NewVerifyingKey(ecc.ID(curveID)) + _, err := vk.ReadFrom(bytes.NewReader(vkBytes)) + if err != nil { + return false, err + } + + proof := groth16.NewProof(ecc.ID(curveID)) + _, err = proof.ReadFrom(bytes.NewReader(proofBytes)) + if err != nil { + return false, err + } + + pubWitness, err := witness.New(ecc.ID(curveID).ScalarField()) + if err != nil { + return false, err + } + err = pubWitness.UnmarshalBinary(pubWitnessBytes) + if err != nil { + return false, err + } + + err = groth16.Verify(proof, vk, pubWitness) + if err != nil { + return false, nil + } + + return true, nil +} diff --git a/zk/groth16/verify_test.go b/zk/groth16/verify_test.go new file mode 100644 index 0000000..bf5f51d --- /dev/null +++ b/zk/groth16/verify_test.go @@ -0,0 +1,141 @@ +package groth16 + +import ( + "bytes" + "github.com/multiversx/mx-chain-crypto-go/zk/lowLevelFeatures" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/stretchr/testify/require" +) + +// CubicCircuit defines a simple circuit +type CubicCircuit struct { + X frontend.Variable `gnark:"x"` + Y frontend.Variable `gnark:"y,public"` +} + +// Define defines the circuit constraints +func (circuit *CubicCircuit) Define(api frontend.API) error { + x3 := api.Mul(circuit.X, circuit.X, circuit.X) + api.AssertIsEqual(circuit.Y, api.Add(x3, circuit.X, 5)) + return nil +} + +func TestVerifyGroth16(t *testing.T) { + // 1. Compile the circuit + var circuit CubicCircuit + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + require.NoError(t, err) + + // 2. Run the setup + pk, vk, err := groth16.Setup(ccs) + require.NoError(t, err) + + // 3. Create a valid witness + assignment := CubicCircuit{X: 3, Y: 35} + w, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + publicW, err := w.Public() + require.NoError(t, err) + + // 4. Generate a proof + proof, err := groth16.Prove(ccs, pk, w) + require.NoError(t, err) + + // 5. Convert vk, proof, and public witness to bytes + var vkBuf bytes.Buffer + _, err = vk.WriteTo(&vkBuf) + require.NoError(t, err) + vkBytes := vkBuf.Bytes() + + var proofBuf bytes.Buffer + _, err = proof.WriteTo(&proofBuf) + require.NoError(t, err) + proofBytes := proofBuf.Bytes() + + pubWitnessBytes, err := publicW.MarshalBinary() + require.NoError(t, err) + + // 6. Test success case + ok, err := VerifyGroth16(uint16(ecc.BN254), proofBytes, vkBytes, pubWitnessBytes) + require.NoError(t, err) + require.True(t, ok) + + // 7. Test failure cases + // Invalid proof + invalidProofBytes := append([]byte(nil), proofBytes...) + invalidProofBytes[0] ^= 0x01 + ok, _ = VerifyGroth16(uint16(ecc.BN254), invalidProofBytes, vkBytes, pubWitnessBytes) + require.False(t, ok) + + // Invalid public witness + invalidAssignment := CubicCircuit{X: 3, Y: 36} // Y is incorrect + invalidW, err := frontend.NewWitness(&invalidAssignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + invalidPublicW, err := invalidW.Public() + require.NoError(t, err) + invalidPubWitnessBytes, err := invalidPublicW.MarshalBinary() + require.NoError(t, err) + ok, err = VerifyGroth16(uint16(ecc.BN254), proofBytes, vkBytes, invalidPubWitnessBytes) + require.Nil(t, err) + require.False(t, ok) + + // Invalid vk + invalidVkBytes := append([]byte(nil), vkBytes...) + invalidVkBytes[0] ^= 0x01 + ok, _ = VerifyGroth16(uint16(ecc.BN254), proofBytes, invalidVkBytes, pubWitnessBytes) + require.False(t, ok) + + // test error cases from my previous fix + _, err = VerifyGroth16(uint16(ecc.BN254), []byte("invalid"), vkBytes, pubWitnessBytes) + require.Error(t, err) + + // create a dummy witness for a different curve + invalidWitness, err := witness.New(ecc.BLS12_381.ScalarField()) + require.NoError(t, err) + invalidWitnessBytes, err := invalidWitness.MarshalBinary() + require.NoError(t, err) + ok, err = VerifyGroth16(uint16(ecc.BN254), proofBytes, vkBytes, invalidWitnessBytes) + require.Nil(t, err) + require.False(t, ok) + + // test invalid curve + _, err = VerifyGroth16(uint16(ecc.UNKNOWN), proofBytes, vkBytes, invalidWitnessBytes) + require.Equal(t, lowLevelFeatures.ErrInvalidCurve, err) + _, err = VerifyGroth16(42, proofBytes, vkBytes, invalidWitnessBytes) + require.Equal(t, lowLevelFeatures.ErrInvalidCurve, err) +} + +func TestVerifyGroth16_NilOrEmptyInput(t *testing.T) { + // Invalid proof + verified, err := VerifyGroth16(uint16(ecc.BN254), nil, []byte("vk"), []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyGroth16(uint16(ecc.BN254), []byte{}, []byte("vk"), []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + // Invalid vk + verified, err = VerifyGroth16(uint16(ecc.BN254), []byte("proof"), nil, []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyGroth16(uint16(ecc.BN254), []byte("proof"), []byte{}, []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + // Invalid public witness + verified, err = VerifyGroth16(uint16(ecc.BN254), []byte("proof"), []byte("vk"), nil) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyGroth16(uint16(ecc.BN254), []byte("proof"), []byte("vk"), []byte{}) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) +} diff --git a/zk/lowLevelFeatures/constants.go b/zk/lowLevelFeatures/constants.go new file mode 100644 index 0000000..965a449 --- /dev/null +++ b/zk/lowLevelFeatures/constants.go @@ -0,0 +1,23 @@ +package lowLevelFeatures + +import ( + "github.com/consensys/gnark-crypto/ecc" +) + +type ID = ecc.ID + +const ( + Unknown = ecc.UNKNOWN + BN254 = ecc.BN254 + BLS12_377 = ecc.BLS12_377 + BLS12_381 = ecc.BLS12_381 +) + +// GroupID defines the given group +type GroupID uint16 + +const ( + UnknownGroup GroupID = iota + G1 + G2 +) diff --git a/zk/lowLevelFeatures/ec.go b/zk/lowLevelFeatures/ec.go new file mode 100644 index 0000000..e544d19 --- /dev/null +++ b/zk/lowLevelFeatures/ec.go @@ -0,0 +1,71 @@ +package lowLevelFeatures + +// PointAdd performs addition on two points of a specified curve +func PointAdd(curveID ID, group GroupID, point1Bytes, point2Bytes []byte) ([]byte, error) { + if len(point1Bytes) == 0 || len(point2Bytes) == 0 { + return nil, ErrNilOrEmptyInput + } + + handler, ok := EcRegistry[ECParams{curveID, group}] + if !ok { + return nil, ErrInvalidCurve + } + + return handler.Add(point1Bytes, point2Bytes) +} + +// ScalarMul performs scalar multiplication on the specified curve +func ScalarMul(curveID ID, group GroupID, point, scalar []byte) ([]byte, error) { + if len(point) == 0 || len(scalar) == 0 { + return nil, ErrNilOrEmptyInput + } + + handler, ok := EcRegistry[ECParams{curveID, group}] + if !ok { + return nil, ErrInvalidCurve + } + + return handler.Mul(point, scalar) +} + +// MultiExp performs multi exponent on the specified curve +func MultiExp(curveID ID, group GroupID, points [][]byte, scalars [][]byte) ([]byte, error) { + if len(points) == 0 || len(scalars) == 0 { + return nil, ErrNilOrEmptyInput + } + + handler, ok := EcRegistry[ECParams{curveID, group}] + if !ok { + return nil, ErrInvalidCurve + } + + return handler.MultiExp(points, scalars) +} + +// MapToCurve performs map to curve operation on the specified curve +func MapToCurve(curveID ID, group GroupID, element []byte) ([]byte, error) { + if len(element) == 0 { + return nil, ErrNilOrEmptyInput + } + + handler, ok := EcRegistry[ECParams{curveID, group}] + if !ok { + return nil, ErrInvalidCurve + } + + return handler.MapToCurve(element) +} + +// PairingCheck performs pairing check operation on the specified curve +func PairingCheck(curveID ID, pointsG1, pointsG2 [][]byte) (bool, error) { + if len(pointsG1) == 0 || len(pointsG2) == 0 { + return false, ErrNilOrEmptyInput + } + + handler, ok := PairingRegistry[curveID] + if !ok { + return false, ErrInvalidCurve + } + + return handler.PairingCheck(pointsG1, pointsG2) +} diff --git a/zk/lowLevelFeatures/ec_test.go b/zk/lowLevelFeatures/ec_test.go new file mode 100644 index 0000000..d51e14d --- /dev/null +++ b/zk/lowLevelFeatures/ec_test.go @@ -0,0 +1,218 @@ +package lowLevelFeatures + +import ( + "github.com/consensys/gnark-crypto/ecc" + bls12377_gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + bls12377_fr "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + bls12381_gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381_fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + bls12381_fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + bn254_gnark "github.com/consensys/gnark-crypto/ecc/bn254" + bn254_fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/require" + "math/big" + "testing" +) + +func TestPointAdd(t *testing.T) { + _, _, p1, _ := bls12381_gnark.Generators() + p1Bytes := p1.Marshal() + p2 := p1 + p2Bytes := p2.Marshal() + + var expected bls12381_gnark.G1Affine + expected.Add(&p1, &p2) + + res, err := PointAdd(BLS12_381, G1, p1Bytes, p2Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + // Test invalid curve + _, err = PointAdd(Unknown, G1, p1Bytes, p2Bytes) + require.ErrorIs(t, err, ErrInvalidCurve) + + // Test invalid curve + _, err = PointAdd(BLS12_381, G1, nil, nil) + require.NotNil(t, err) +} + +func TestScalarMul(t *testing.T) { + _, _, p1, _ := bls12381_gnark.Generators() + p1Bytes := p1.Marshal() + scalar, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bls12381_gnark.G1Affine + expected.ScalarMultiplication(&p1, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := ScalarMul(BLS12_381, G1, p1Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + // Test invalid curve + _, err = ScalarMul(Unknown, G1, p1Bytes, scalarBytes) + require.ErrorIs(t, err, ErrInvalidCurve) +} + +func TestMultiExp(t *testing.T) { + _, _, p1, _ := bls12381_gnark.Generators() + points := []bls12381_gnark.G1Affine{p1, p1, p1} + pointsBytes := [][]byte{p1.Marshal(), p1.Marshal(), p1.Marshal()} + scalars, scalarsBytes := generateBLS12381Scalars(t, 3) + + var expected bls12381_gnark.G1Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := MultiExp(BLS12_381, G1, pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + // Test invalid curve + _, err = MultiExp(Unknown, G1, pointsBytes, scalarsBytes) + require.ErrorIs(t, err, ErrInvalidCurve) +} + +func TestMapToCurve(t *testing.T) { + var fp bls12381_fp.Element + _, err := fp.SetRandom() + require.NoError(t, err) + element := fp.Marshal() + + _, err = MapToCurve(BLS12_381, G1, element) + require.NoError(t, err) + + // Test invalid curve + _, err = MapToCurve(Unknown, G1, element) + require.ErrorIs(t, err, ErrInvalidCurve) +} + +func TestPairingCheck(t *testing.T) { + t.Run("BLS12_381", func(t *testing.T) { + _, _, p1, p2 := bls12381_gnark.Generators() + a, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bls12381_gnark.G1Affine + var aG2 bls12381_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + ok, err := PairingCheck(BLS12_381, pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("BLS12_377", func(t *testing.T) { + _, _, p1, p2 := bls12377_gnark.Generators() + a, err := new(bls12377_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bls12377_gnark.G1Affine + var aG2 bls12377_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + ok, err := PairingCheck(BLS12_377, pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("BN254", func(t *testing.T) { + _, _, p1, p2 := bn254_gnark.Generators() + a, err := new(bn254_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bn254_gnark.G1Affine + var aG2 bn254_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + ok, err := PairingCheck(BN254, pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("InvalidCurve", func(t *testing.T) { + _, err := PairingCheck(Unknown, [][]byte{[]byte("p1")}, [][]byte{[]byte("p2")}) + require.ErrorIs(t, err, ErrInvalidCurve) + }) +} + +func TestNilOrEmptyInputs(t *testing.T) { + t.Run("PointAdd", func(t *testing.T) { + _, err := PointAdd(BLS12_381, G1, nil, []byte("point2")) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PointAdd(BLS12_381, G1, []byte{}, []byte("point2")) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PointAdd(BLS12_381, G1, []byte("point1"), nil) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PointAdd(BLS12_381, G1, []byte("point1"), []byte{}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + }) + + t.Run("ScalarMul", func(t *testing.T) { + _, err := ScalarMul(BLS12_381, G1, nil, []byte("scalar")) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = ScalarMul(BLS12_381, G1, []byte{}, []byte("scalar")) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = ScalarMul(BLS12_381, G1, []byte("point"), nil) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = ScalarMul(BLS12_381, G1, []byte("point"), []byte{}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + }) + + t.Run("MultiExp", func(t *testing.T) { + _, err := MultiExp(BLS12_381, G1, nil, [][]byte{[]byte("scalar")}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = MultiExp(BLS12_381, G1, [][]byte{}, [][]byte{[]byte("scalar")}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = MultiExp(BLS12_381, G1, [][]byte{[]byte("point")}, nil) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = MultiExp(BLS12_381, G1, [][]byte{[]byte("point")}, [][]byte{}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + }) + + t.Run("MapToCurve", func(t *testing.T) { + _, err := MapToCurve(BLS12_381, G1, nil) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = MapToCurve(BLS12_381, G1, []byte{}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + }) + + t.Run("PairingCheck", func(t *testing.T) { + _, err := PairingCheck(BLS12_381, nil, [][]byte{[]byte("point2")}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PairingCheck(BLS12_381, [][]byte{}, [][]byte{[]byte("point2")}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PairingCheck(BLS12_381, [][]byte{[]byte("point1")}, nil) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + + _, err = PairingCheck(BLS12_381, [][]byte{[]byte("point1")}, [][]byte{}) + require.ErrorIs(t, err, ErrNilOrEmptyInput) + }) +} diff --git a/zk/lowLevelFeatures/errors.go b/zk/lowLevelFeatures/errors.go new file mode 100644 index 0000000..beeafa8 --- /dev/null +++ b/zk/lowLevelFeatures/errors.go @@ -0,0 +1,21 @@ +package lowLevelFeatures + +import "errors" + +// ErrInvalidCurve signals invalid curveID error +var ErrInvalidCurve = errors.New("invalid curveID provided") + +// ErrInvalidPoints signals invalid points error +var ErrInvalidPoints = errors.New("invalid points provided") + +// ErrPairingPointsLenShouldMatch signals pairing mismatch error +var ErrPairingPointsLenShouldMatch = errors.New("the number of G1 and G2 points should match for pairing") + +// ErrPointsAndScalarsShouldMatch signals points and scalars mismatch error +var ErrPointsAndScalarsShouldMatch = errors.New("the number of points and scalars provided should match") + +// ErrInvalidFpElement signals invalid field element error +var ErrInvalidFpElement = errors.New("invalid field element") + +// ErrNilOrEmptyInput signals nil or empty input error +var ErrNilOrEmptyInput = errors.New("nil or empty input provided") diff --git a/zk/lowLevelFeatures/helpers_test.go b/zk/lowLevelFeatures/helpers_test.go new file mode 100644 index 0000000..e5c7c4e --- /dev/null +++ b/zk/lowLevelFeatures/helpers_test.go @@ -0,0 +1,49 @@ +package lowLevelFeatures + +import ( + "testing" + + bls12377_fr "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + bls12381_fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + bn254_fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/require" +) + +func generateBLS12381Scalars(t *testing.T, n int) ([]bls12381_fr.Element, [][]byte) { + t.Helper() + scalars := make([]bls12381_fr.Element, n) + scalarsBytes := make([][]byte, n) + for i := 0; i < n; i++ { + s, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + scalars[i] = *s + scalarsBytes[i] = s.Marshal() + } + return scalars, scalarsBytes +} + +func generateBLS12377Scalars(t *testing.T, n int) ([]bls12377_fr.Element, [][]byte) { + t.Helper() + scalars := make([]bls12377_fr.Element, n) + scalarsBytes := make([][]byte, n) + for i := 0; i < n; i++ { + s, err := new(bls12377_fr.Element).SetRandom() + require.NoError(t, err) + scalars[i] = *s + scalarsBytes[i] = s.Marshal() + } + return scalars, scalarsBytes +} + +func generateBN254Scalars(t *testing.T, n int) ([]bn254_fr.Element, [][]byte) { + t.Helper() + scalars := make([]bn254_fr.Element, n) + scalarsBytes := make([][]byte, n) + for i := 0; i < n; i++ { + s, err := new(bn254_fr.Element).SetRandom() + require.NoError(t, err) + scalars[i] = *s + scalarsBytes[i] = s.Marshal() + } + return scalars, scalarsBytes +} diff --git a/zk/lowLevelFeatures/operations.go b/zk/lowLevelFeatures/operations.go new file mode 100644 index 0000000..2021b81 --- /dev/null +++ b/zk/lowLevelFeatures/operations.go @@ -0,0 +1,833 @@ +package lowLevelFeatures + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + bls123772 "github.com/consensys/gnark-crypto/ecc/bls12-377" + bls12377fp "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + bls12377fr "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + bls123812 "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + bls12381fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + bn2542 "github.com/consensys/gnark-crypto/ecc/bn254" + bn254fp "github.com/consensys/gnark-crypto/ecc/bn254/fp" + bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12377" + "github.com/multiversx/mx-chain-crypto-go/curves/bls/bls12381" + "github.com/multiversx/mx-chain-crypto-go/curves/bn254" +) + +// ECParams defines the elliptic curve params structure +type ECParams struct { + Curve ID + Group GroupID +} + +// String stringify the structure +func (ecp *ECParams) String() string { + return fmt.Sprintf("%d_%d", ecp.Curve, ecp.Group) +} + +// ECGroup defines the interface for elliptic curves +type ECGroup interface { + Add([]byte, []byte) ([]byte, error) + Mul([]byte, []byte) ([]byte, error) + MultiExp([][]byte, [][]byte) ([]byte, error) + MapToCurve([]byte) ([]byte, error) +} + +// PairingGroup defines the interface for pairing groups +type PairingGroup interface { + PairingCheck([][]byte, [][]byte) (bool, error) +} + +type bls12381G1 struct{} + +func (b12g1 *bls12381G1) unmarshalPointsG1(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bls12381.NewPointG1() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add 2 points on the given curve +func (b12g1 *bls12381G1) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := b12g1.unmarshalPointsG1(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and the scalar +func (b12g1 *bls12381G1) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := b12g1.unmarshalPointsG1(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bls12381.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the scalars with the point +func (b12g1 *bls12381G1) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bls123812.G1Affine, len(points)) + underlyingS := make([]bls12381fr.Element, len(scalars)) + for i := range points { + p := bls123812.G1Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bls12381fr.Element).SetBytes(scalars[i]) + } + + r := new(bls123812.G1Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve creates a point from the given element +func (b12g1 *bls12381G1) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 48 { + return nil, ErrInvalidFpElement + } + + fpEl, err := bls12381fp.BigEndian.Element((*[48]byte)(element)) + if err != nil { + return nil, err + } + + point := bls123812.MapToG1(fpEl) + return point.Marshal(), nil +} + +type bls12381G2 struct{} + +func (b12g2 *bls12381G2) unmarshalPointsG2(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bls12381.NewPointG2() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add adds the 2 points together in the given curve +func (b12g2 *bls12381G2) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := b12g2.unmarshalPointsG2(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and the scalar +func (b12g2 *bls12381G2) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := b12g2.unmarshalPointsG2(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bls12381.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the points and the scalars +func (b12g2 *bls12381G2) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bls123812.G2Affine, len(points)) + underlyingS := make([]bls12381fr.Element, len(scalars)) + for i := range points { + p := bls123812.G2Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bls12381fr.Element).SetBytes(scalars[i]) + } + + r := new(bls123812.G2Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve maps the given element to the curve +func (b12g2 *bls12381G2) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 96 { + return nil, ErrInvalidFpElement + } + + fpEl0, err := bls12381fp.BigEndian.Element((*[48]byte)(element[:48])) + if err != nil { + return nil, err + } + fpEl1, err := bls12381fp.BigEndian.Element((*[48]byte)(element[48:])) + if err != nil { + return nil, err + } + + point := bls123812.MapToG2(bls123812.E2{A0: fpEl0, A1: fpEl1}) + return point.Marshal(), nil +} + +type bls12377G1 struct{} + +func (b12g1 *bls12377G1) unmarshalPointsG1(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bls12377.NewPointG1() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add the 2 points together on the curve +func (b12g1 *bls12377G1) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := b12g1.unmarshalPointsG1(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and scalar on the given curve +func (b12g1 *bls12377G1) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := b12g1.unmarshalPointsG1(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bls12377.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the points and the scalars on the given curve +func (b12g1 *bls12377G1) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bls123772.G1Affine, len(points)) + underlyingS := make([]bls12377fr.Element, len(scalars)) + for i := range points { + p := bls123772.G1Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bls12377fr.Element).SetBytes(scalars[i]) + } + + r := new(bls123772.G1Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve creates a mapping to the given element +func (b12g1 *bls12377G1) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 48 { + return nil, ErrInvalidFpElement + } + + fpEl, err := bls12377fp.BigEndian.Element((*[48]byte)(element)) + if err != nil { + return nil, err + } + + point := bls123772.MapToG1(fpEl) + return point.Marshal(), nil +} + +type bls12377G2 struct{} + +func (b12g2 *bls12377G2) unmarshalPointsG2(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bls12377.NewPointG2() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add the 2 points together on the curve +func (b12g2 *bls12377G2) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := b12g2.unmarshalPointsG2(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and scalar on the given curve +func (b12g2 *bls12377G2) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := b12g2.unmarshalPointsG2(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bls12377.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the points and the scalars on the given curve +func (b12g2 *bls12377G2) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bls123772.G2Affine, len(points)) + underlyingS := make([]bls12377fr.Element, len(scalars)) + for i := range points { + p := bls123772.G2Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bls12377fr.Element).SetBytes(scalars[i]) + } + + r := new(bls123772.G2Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve creates a mapping to the given element +func (b12g2 *bls12377G2) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 96 { + return nil, ErrInvalidFpElement + } + + fpEl0, err := bls12377fp.BigEndian.Element((*[48]byte)(element[:48])) + if err != nil { + return nil, err + } + fpEl1, err := bls12377fp.BigEndian.Element((*[48]byte)(element[48:])) + if err != nil { + return nil, err + } + + point := bls123772.MapToG2(bls123772.E2{A0: fpEl0, A1: fpEl1}) + return point.Marshal(), nil +} + +type bn254G1 struct{} + +func (bng1 *bn254G1) unmarshalPointsG1(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bn254.NewPointG1() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add the 2 points together on the curve +func (bng1 *bn254G1) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := bng1.unmarshalPointsG1(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and scalar on the given curve +func (bng1 *bn254G1) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := bng1.unmarshalPointsG1(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bn254.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the points and the scalars on the given curve +func (bng1 *bn254G1) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bn2542.G1Affine, len(points)) + underlyingS := make([]bn254fr.Element, len(scalars)) + for i := range points { + p := bn2542.G1Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bn254fr.Element).SetBytes(scalars[i]) + } + + r := new(bn2542.G1Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve creates a mapping to the given element +func (bng1 *bn254G1) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 32 { + return nil, ErrInvalidFpElement + } + + fpEl, err := bn254fp.BigEndian.Element((*[32]byte)(element)) + if err != nil { + return nil, err + } + + point := bn2542.MapToG1(fpEl) + return point.Marshal(), nil +} + +type bn254G2 struct{} + +func (bng2 *bn254G2) unmarshalPointsG2(points ...[]byte) ([]crypto.Point, error) { + uPoints := make([]crypto.Point, len(points)) + for i, p := range points { + uPoints[i] = bn254.NewPointG2() + err := uPoints[i].UnmarshalBinary(p) + if err != nil { + return nil, err + } + } + + return uPoints, nil +} + +// Add the 2 points together on the curve +func (bng2 *bn254G2) Add(p1, p2 []byte) ([]byte, error) { + pointsSlice, err := bng2.unmarshalPointsG2(p1, p2) + if err != nil { + return nil, err + } + if len(pointsSlice) != 2 { + return nil, ErrInvalidPoints + } + + res, err := pointsSlice[0].Add(pointsSlice[1]) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// Mul multiplies the point and scalar on the given curve +func (bng2 *bn254G2) Mul(point, scalar []byte) ([]byte, error) { + pointsSlice, err := bng2.unmarshalPointsG2(point) + if err != nil { + return nil, err + } + if len(pointsSlice) != 1 { + return nil, ErrInvalidPoints + } + + sc := bn254.NewScalar() + err = sc.UnmarshalBinary(scalar) + if err != nil { + return nil, err + } + + res, err := pointsSlice[0].Mul(sc) + if err != nil { + return nil, err + } + resBytes, err := res.MarshalBinary() + if err != nil { + return nil, err + } + + return resBytes, nil +} + +// MultiExp multiplies the points and the scalars on the given curve +func (bng2 *bn254G2) MultiExp(points, scalars [][]byte) ([]byte, error) { + if len(points) != len(scalars) { + return nil, ErrPointsAndScalarsShouldMatch + } + + underlyingP := make([]bn2542.G2Affine, len(points)) + underlyingS := make([]bn254fr.Element, len(scalars)) + for i := range points { + p := bn2542.G2Affine{} + err := p.Unmarshal(points[i]) + if err != nil { + return nil, err + } + + underlyingP[i] = p + underlyingS[i] = *new(bn254fr.Element).SetBytes(scalars[i]) + } + + r := new(bn2542.G2Affine) + r, err := r.MultiExp(underlyingP, underlyingS, ecc.MultiExpConfig{}) + if err != nil { + return nil, err + } + + return r.Marshal(), nil +} + +// MapToCurve creates a mapping to the given element +func (bng2 *bn254G2) MapToCurve(element []byte) ([]byte, error) { + if len(element) != 64 { + return nil, ErrInvalidFpElement + } + + fpEl0, err := bn254fp.BigEndian.Element((*[32]byte)(element[:32])) + if err != nil { + return nil, err + } + fpEl1, err := bn254fp.BigEndian.Element((*[32]byte)(element[32:])) + if err != nil { + return nil, err + } + + point := bn2542.MapToG2(bn2542.E2{A0: fpEl0, A1: fpEl1}) + return point.Marshal(), nil +} + +type bls12381Pairing struct{} + +// PairingCheck checks whether the given points are matching on the curve +func (b12381 *bls12381Pairing) PairingCheck(pointsG1, pointsG2 [][]byte) (bool, error) { + if len(pointsG1) != len(pointsG2) { + return false, ErrPairingPointsLenShouldMatch + } + g1Points := make([]bls123812.G1Affine, len(pointsG1)) + g2Points := make([]bls123812.G2Affine, len(pointsG2)) + + for i := range pointsG1 { + pg1 := bls123812.G1Affine{} + err := pg1.Unmarshal(pointsG1[i]) + if err != nil { + return false, err + } + g1Points[i] = pg1 + + pg2 := bls123812.G2Affine{} + err = pg2.Unmarshal(pointsG2[i]) + if err != nil { + return false, err + } + g2Points[i] = pg2 + } + + ok, err := bls123812.PairingCheck(g1Points, g2Points) + if err != nil { + return false, err + } + + return ok, nil +} + +type bls12377Pairing struct{} + +// PairingCheck checks whether the given points are matching on the curve +func (b12377 *bls12377Pairing) PairingCheck(pointsG1, pointsG2 [][]byte) (bool, error) { + if len(pointsG1) != len(pointsG2) { + return false, ErrPairingPointsLenShouldMatch + } + g1Points := make([]bls123772.G1Affine, len(pointsG1)) + g2Points := make([]bls123772.G2Affine, len(pointsG2)) + + for i := range pointsG1 { + pg1 := bls123772.G1Affine{} + err := pg1.Unmarshal(pointsG1[i]) + if err != nil { + return false, err + } + g1Points[i] = pg1 + + pg2 := bls123772.G2Affine{} + err = pg2.Unmarshal(pointsG2[i]) + if err != nil { + return false, err + } + g2Points[i] = pg2 + } + + ok, err := bls123772.PairingCheck(g1Points, g2Points) + if err != nil { + return false, err + } + + return ok, nil +} + +type bn254Pairing struct{} + +// PairingCheck checks whether the given points are matching on the curve +func (bn254 *bn254Pairing) PairingCheck(pointsG1, pointsG2 [][]byte) (bool, error) { + if len(pointsG1) != len(pointsG2) { + return false, ErrPairingPointsLenShouldMatch + } + g1Points := make([]bn2542.G1Affine, len(pointsG1)) + g2Points := make([]bn2542.G2Affine, len(pointsG2)) + + for i := range pointsG1 { + pg1 := bn2542.G1Affine{} + err := pg1.Unmarshal(pointsG1[i]) + if err != nil { + return false, err + } + g1Points[i] = pg1 + + pg2 := bn2542.G2Affine{} + err = pg2.Unmarshal(pointsG2[i]) + if err != nil { + return false, err + } + g2Points[i] = pg2 + } + + ok, err := bn2542.PairingCheck(g1Points, g2Points) + if err != nil { + return false, err + } + + return ok, nil +} + +// EcRegistry is the map for the set of ECParams and ECGroup +var EcRegistry = map[ECParams]ECGroup{ + {BLS12_381, G1}: &bls12381G1{}, + {BLS12_381, G2}: &bls12381G2{}, + {BLS12_377, G1}: &bls12377G1{}, + {BLS12_377, G2}: &bls12377G2{}, + {BN254, G1}: &bn254G1{}, + {BN254, G2}: &bn254G2{}, +} + +// PairingRegistry return the map of pairing types +var PairingRegistry = map[ID]PairingGroup{ + BLS12_381: &bls12381Pairing{}, + BLS12_377: &bls12377Pairing{}, + BN254: &bn254Pairing{}, +} + +// SupportedCurvesRegistry returns the map of accepted CurveIDs +var SupportedCurvesRegistry = map[ID]struct{}{ + BLS12_381: {}, + BLS12_377: {}, + BN254: {}, +} diff --git a/zk/lowLevelFeatures/operations_test.go b/zk/lowLevelFeatures/operations_test.go new file mode 100644 index 0000000..e54d842 --- /dev/null +++ b/zk/lowLevelFeatures/operations_test.go @@ -0,0 +1,759 @@ +package lowLevelFeatures + +import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bls12377_gnark "github.com/consensys/gnark-crypto/ecc/bls12-377" + bls12377_fp "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + bls12377_fr "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + bls12381_gnark "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls12381_fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + bls12381_fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + bn254_gnark "github.com/consensys/gnark-crypto/ecc/bn254" + bn254_fp "github.com/consensys/gnark-crypto/ecc/bn254/fp" + bn254_fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/require" +) + +func TestBLS12_381(t *testing.T) { + t.Run("G1", func(t *testing.T) { + g1 := &bls12381G1{} + _, _, p1, _ := bls12381_gnark.Generators() + p1Bytes := p1.Marshal() + + t.Run("Add", func(t *testing.T) { + p2 := p1 + p2Bytes := p2.Marshal() + var expected bls12381_gnark.G1Affine + expected.Add(&p1, &p2) + + res, err := g1.Add(p1Bytes, p2Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Add([]byte("invalid"), p2Bytes) + require.Error(t, err) + _, err = g1.Add(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12381_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Add(invalidPointBytes, p2Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bls12381_gnark.G1Affine + expected.ScalarMultiplication(&p1, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g1.Mul(p1Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g1.Mul(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12381_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bls12381_gnark.G1Affine{p1, p1, p1} + pointsBytes := [][]byte{p1.Marshal(), p1.Marshal(), p1.Marshal()} + scalars, scalarsBytes := generateBLS12381Scalars(t, 3) + + var expected bls12381_gnark.G1Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g1.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g1.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fp bls12381_fp.Element + _, err := fp.SetRandom() + require.NoError(t, err) + element := fp.Marshal() + + expected := bls12381_gnark.MapToG1(fp) + + res, err := g1.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MapToCurve(element[:47]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 48) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g1.MapToCurve(invalidElement) + require.Error(t, err) + }) + }) + + t.Run("G2", func(t *testing.T) { + g2 := &bls12381G2{} + _, _, _, p2 := bls12381_gnark.Generators() + p2Bytes := p2.Marshal() + + t.Run("Add", func(t *testing.T) { + p1 := p2 + p1Bytes := p1.Marshal() + var expected bls12381_gnark.G2Affine + expected.Add(&p2, &p1) + + res, err := g2.Add(p2Bytes, p1Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Add([]byte("invalid"), p1Bytes) + require.Error(t, err) + _, err = g2.Add(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12381_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Add(invalidPointBytes, p1Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bls12381_gnark.G2Affine + expected.ScalarMultiplication(&p2, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g2.Mul(p2Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g2.Mul(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12381_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bls12381_gnark.G2Affine{p2, p2, p2} + pointsBytes := [][]byte{p2.Marshal(), p2.Marshal(), p2.Marshal()} + scalars, scalarsBytes := generateBLS12381Scalars(t, 3) + + var expected bls12381_gnark.G2Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g2.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g2.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fpE2 bls12381_gnark.E2 + _, err := fpE2.SetRandom() + require.NoError(t, err) + element := append(fpE2.A0.Marshal(), fpE2.A1.Marshal()...) + + expected := bls12381_gnark.MapToG2(fpE2) + + res, err := g2.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MapToCurve(element[:95]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 96) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement) + require.Error(t, err) + + invalidElement2 := make([]byte, 96) + var fp bls12381_fp.Element + _, err = fp.SetRandom() + require.NoError(t, err) + copy(invalidElement2, fp.Marshal()) + for i := 48; i < 96; i++ { + invalidElement2[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement2) + require.Error(t, err) + }) + }) + + t.Run("Pairing", func(t *testing.T) { + pairing := &bls12381Pairing{} + _, _, p1, p2 := bls12381_gnark.Generators() + + t.Run("PairingCheck", func(t *testing.T) { + a, err := new(bls12381_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bls12381_gnark.G1Affine + var aG2 bls12381_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + res, err := pairing.PairingCheck(pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, res) + + _, err = pairing.PairingCheck(pointsG1, pointsG2[:1]) + require.ErrorIs(t, err, ErrPairingPointsLenShouldMatch) + + invalidPointsG1 := append([][]byte(nil), pointsG1...) + invalidPointsG1[0] = []byte("invalid") + _, err = pairing.PairingCheck(invalidPointsG1, pointsG2) + require.Error(t, err) + + invalidPointsG2 := append([][]byte(nil), pointsG2...) + invalidPointsG2[0] = []byte("invalid") + _, err = pairing.PairingCheck(pointsG1, invalidPointsG2) + require.Error(t, err) + }) + }) +} + +func TestBLS12_377(t *testing.T) { + t.Run("G1", func(t *testing.T) { + g1 := &bls12377G1{} + _, _, p1, _ := bls12377_gnark.Generators() + p1Bytes := p1.Marshal() + + t.Run("Add", func(t *testing.T) { + p2 := p1 + p2Bytes := p2.Marshal() + var expected bls12377_gnark.G1Affine + expected.Add(&p1, &p2) + + res, err := g1.Add(p1Bytes, p2Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Add([]byte("invalid"), p2Bytes) + require.Error(t, err) + _, err = g1.Add(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12377_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Add(invalidPointBytes, p2Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bls12377_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bls12377_gnark.G1Affine + expected.ScalarMultiplication(&p1, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g1.Mul(p1Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g1.Mul(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12377_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bls12377_gnark.G1Affine{p1, p1, p1} + pointsBytes := [][]byte{p1.Marshal(), p1.Marshal(), p1.Marshal()} + scalars, scalarsBytes := generateBLS12377Scalars(t, 3) + + var expected bls12377_gnark.G1Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g1.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g1.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fp bls12377_fp.Element + _, err := fp.SetRandom() + require.NoError(t, err) + element := fp.Marshal() + + expected := bls12377_gnark.MapToG1(fp) + + res, err := g1.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MapToCurve(element[:47]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 48) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g1.MapToCurve(invalidElement) + require.Error(t, err) + }) + }) + + t.Run("G2", func(t *testing.T) { + g2 := &bls12377G2{} + _, _, _, p2 := bls12377_gnark.Generators() + p2Bytes := p2.Marshal() + + t.Run("Add", func(t *testing.T) { + p1 := p2 + p1Bytes := p1.Marshal() + var expected bls12377_gnark.G2Affine + expected.Add(&p2, &p1) + + res, err := g2.Add(p2Bytes, p1Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Add([]byte("invalid"), p1Bytes) + require.Error(t, err) + _, err = g2.Add(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12377_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Add(invalidPointBytes, p1Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bls12377_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bls12377_gnark.G2Affine + expected.ScalarMultiplication(&p2, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g2.Mul(p2Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g2.Mul(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bls12377_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bls12377_gnark.G2Affine{p2, p2, p2} + pointsBytes := [][]byte{p2.Marshal(), p2.Marshal(), p2.Marshal()} + scalars, scalarsBytes := generateBLS12377Scalars(t, 3) + + var expected bls12377_gnark.G2Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g2.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g2.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fpE2 bls12377_gnark.E2 + _, err := fpE2.SetRandom() + require.NoError(t, err) + element := append(fpE2.A0.Marshal(), fpE2.A1.Marshal()...) + + expected := bls12377_gnark.MapToG2(fpE2) + + res, err := g2.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MapToCurve(element[:95]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 96) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement) + require.Error(t, err) + + invalidElement2 := make([]byte, 96) + var fp bls12377_fp.Element + _, err = fp.SetRandom() + require.NoError(t, err) + copy(invalidElement2, fp.Marshal()) + for i := 48; i < 96; i++ { + invalidElement2[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement2) + require.Error(t, err) + }) + }) + + t.Run("Pairing", func(t *testing.T) { + pairing := &bls12377Pairing{} + _, _, p1, p2 := bls12377_gnark.Generators() + + t.Run("PairingCheck", func(t *testing.T) { + a, err := new(bls12377_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bls12377_gnark.G1Affine + var aG2 bls12377_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + res, err := pairing.PairingCheck(pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, res) + + _, err = pairing.PairingCheck(pointsG1, pointsG2[:1]) + require.ErrorIs(t, err, ErrPairingPointsLenShouldMatch) + + invalidPointsG1 := append([][]byte(nil), pointsG1...) + invalidPointsG1[0] = []byte("invalid") + _, err = pairing.PairingCheck(invalidPointsG1, pointsG2) + require.Error(t, err) + + invalidPointsG2 := append([][]byte(nil), pointsG2...) + invalidPointsG2[0] = []byte("invalid") + _, err = pairing.PairingCheck(pointsG1, invalidPointsG2) + require.Error(t, err) + }) + }) +} + +func TestBN254(t *testing.T) { + t.Run("G1", func(t *testing.T) { + g1 := &bn254G1{} + _, _, p1, _ := bn254_gnark.Generators() + p1Bytes := p1.Marshal() + + t.Run("Add", func(t *testing.T) { + p2 := p1 + p2Bytes := p2.Marshal() + var expected bn254_gnark.G1Affine + expected.Add(&p1, &p2) + + res, err := g1.Add(p1Bytes, p2Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Add([]byte("invalid"), p2Bytes) + require.Error(t, err) + _, err = g1.Add(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bn254_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Add(invalidPointBytes, p2Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bn254_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bn254_gnark.G1Affine + expected.ScalarMultiplication(&p1, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g1.Mul(p1Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g1.Mul(p1Bytes, []byte("invalid")) + require.Error(t, err) + + var p bn254_gnark.G1Affine + p.X.SetOne() + p.Y.SetOne() + invalidPointBytes := p.Marshal() + _, err = g1.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bn254_gnark.G1Affine{p1, p1, p1} + pointsBytes := [][]byte{p1.Marshal(), p1.Marshal(), p1.Marshal()} + scalars, scalarsBytes := generateBN254Scalars(t, 3) + + var expected bn254_gnark.G1Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g1.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g1.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fp bn254_fp.Element + _, err := fp.SetRandom() + require.NoError(t, err) + element := fp.Marshal() + + expected := bn254_gnark.MapToG1(fp) + + res, err := g1.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g1.MapToCurve(element[:31]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 32) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g1.MapToCurve(invalidElement) + require.Error(t, err) + }) + }) + + t.Run("G2", func(t *testing.T) { + g2 := &bn254G2{} + _, _, _, p2 := bn254_gnark.Generators() + p2Bytes := p2.Marshal() + + t.Run("Add", func(t *testing.T) { + p1 := p2 + p1Bytes := p1.Marshal() + var expected bn254_gnark.G2Affine + expected.Add(&p2, &p1) + + res, err := g2.Add(p2Bytes, p1Bytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Add([]byte("invalid"), p1Bytes) + require.Error(t, err) + _, err = g2.Add(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bn254_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Add(invalidPointBytes, p1Bytes) + require.Error(t, err) + }) + + t.Run("Mul", func(t *testing.T) { + scalar, err := new(bn254_fr.Element).SetRandom() + require.NoError(t, err) + scalarBytes := scalar.Marshal() + var expected bn254_gnark.G2Affine + expected.ScalarMultiplication(&p2, new(big.Int).SetBytes(scalar.Marshal())) + + res, err := g2.Mul(p2Bytes, scalarBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.Mul([]byte("invalid"), scalarBytes) + require.Error(t, err) + _, err = g2.Mul(p2Bytes, []byte("invalid")) + require.Error(t, err) + + var p bn254_gnark.G2Affine + p.X.A0.SetOne() + p.Y.A0.SetOne() + invalidPointBytes := p.Marshal() + _, err = g2.Mul(invalidPointBytes, scalarBytes) + require.Error(t, err) + }) + + t.Run("MultiExp", func(t *testing.T) { + points := []bn254_gnark.G2Affine{p2, p2, p2} + pointsBytes := [][]byte{p2.Marshal(), p2.Marshal(), p2.Marshal()} + scalars, scalarsBytes := generateBN254Scalars(t, 3) + + var expected bn254_gnark.G2Affine + _, err := expected.MultiExp(points, scalars, ecc.MultiExpConfig{}) + require.NoError(t, err) + + res, err := g2.MultiExp(pointsBytes, scalarsBytes) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MultiExp(pointsBytes, scalarsBytes[:2]) + require.ErrorIs(t, err, ErrPointsAndScalarsShouldMatch) + + invalidPointsBytes := append([][]byte(nil), pointsBytes...) + invalidPointsBytes[0] = []byte("invalid") + _, err = g2.MultiExp(invalidPointsBytes, scalarsBytes) + require.Error(t, err) + }) + + t.Run("MapToCurve", func(t *testing.T) { + var fpE2 bn254_gnark.E2 + _, err := fpE2.SetRandom() + require.NoError(t, err) + element := append(fpE2.A0.Marshal(), fpE2.A1.Marshal()...) + + expected := bn254_gnark.MapToG2(fpE2) + + res, err := g2.MapToCurve(element) + require.NoError(t, err) + require.Equal(t, expected.Marshal(), res) + + _, err = g2.MapToCurve(element[:63]) + require.ErrorIs(t, err, ErrInvalidFpElement) + + invalidElement := make([]byte, 64) + for i := range invalidElement { + invalidElement[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement) + require.Error(t, err) + + invalidElement2 := make([]byte, 64) + var fp bn254_fp.Element + _, err = fp.SetRandom() + require.NoError(t, err) + copy(invalidElement2, fp.Marshal()) + for i := 32; i < 64; i++ { + invalidElement2[i] = 0xff + } + _, err = g2.MapToCurve(invalidElement2) + require.Error(t, err) + }) + }) + + t.Run("Pairing", func(t *testing.T) { + pairing := &bn254Pairing{} + _, _, p1, p2 := bn254_gnark.Generators() + + t.Run("PairingCheck", func(t *testing.T) { + a, err := new(bn254_fr.Element).SetRandom() + require.NoError(t, err) + + var aG1, negAG1 bn254_gnark.G1Affine + var aG2 bn254_gnark.G2Affine + aG1.ScalarMultiplication(&p1, new(big.Int).SetBytes(a.Marshal())) + aG2.ScalarMultiplication(&p2, new(big.Int).SetBytes(a.Marshal())) + negAG1.Neg(&aG1) + + pointsG1 := [][]byte{negAG1.Marshal(), p1.Marshal()} + pointsG2 := [][]byte{p2.Marshal(), aG2.Marshal()} + + res, err := pairing.PairingCheck(pointsG1, pointsG2) + require.NoError(t, err) + require.True(t, res) + + _, err = pairing.PairingCheck(pointsG1, pointsG2[:1]) + require.ErrorIs(t, err, ErrPairingPointsLenShouldMatch) + + invalidPointsG1 := append([][]byte(nil), pointsG1...) + invalidPointsG1[0] = []byte("invalid") + _, err = pairing.PairingCheck(invalidPointsG1, pointsG2) + require.Error(t, err) + + invalidPointsG2 := append([][]byte(nil), pointsG2...) + invalidPointsG2[0] = []byte("invalid") + _, err = pairing.PairingCheck(pointsG1, invalidPointsG2) + require.Error(t, err) + }) + }) +} + +func TestECParams_String(t *testing.T) { + params := ECParams{Curve: BN254, Group: G1} + require.Equal(t, fmt.Sprintf("%d_%d", BN254, G1), params.String()) +} diff --git a/zk/plonk/verify.go b/zk/plonk/verify.go new file mode 100644 index 0000000..a5f5932 --- /dev/null +++ b/zk/plonk/verify.go @@ -0,0 +1,49 @@ +package plonk + +import ( + "bytes" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/backend/witness" + "github.com/multiversx/mx-chain-crypto-go/zk/lowLevelFeatures" +) + +// VerifyPlonk verifies the plonk signature on the given curveID +func VerifyPlonk(curveID uint16, proofBytes, vkBytes, pubWitnessBytes []byte) (bool, error) { + if len(proofBytes) == 0 || len(vkBytes) == 0 || len(pubWitnessBytes) == 0 { + return false, lowLevelFeatures.ErrNilOrEmptyInput + } + + _, ok := lowLevelFeatures.SupportedCurvesRegistry[ecc.ID(curveID)] + if !ok { + return false, lowLevelFeatures.ErrInvalidCurve + } + + vk := plonk.NewVerifyingKey(ecc.ID(curveID)) + if _, err := vk.ReadFrom(bytes.NewReader(vkBytes)); err != nil { + return false, err + } + + proof := plonk.NewProof(ecc.ID(curveID)) + if _, err := proof.ReadFrom(bytes.NewReader(proofBytes)); err != nil { + return false, err + } + + w, err := witness.New(ecc.ID(curveID).ScalarField()) + if err != nil { + return false, err + } + + err = w.UnmarshalBinary(pubWitnessBytes) + if err != nil { + return false, err + } + + err = plonk.Verify(proof, vk, w) + if err != nil { + return false, nil + } + + return true, nil +} diff --git a/zk/plonk/verify_test.go b/zk/plonk/verify_test.go new file mode 100644 index 0000000..130b2be --- /dev/null +++ b/zk/plonk/verify_test.go @@ -0,0 +1,108 @@ +package plonk + +import ( + "bytes" + "github.com/multiversx/mx-chain-crypto-go/zk/lowLevelFeatures" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/examples/exponentiate" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test/unsafekzg" + "github.com/stretchr/testify/require" +) + +func TestVerifyPlonk(t *testing.T) { + css, err := frontend.Compile(ecc.BLS12_381.ScalarField(), scs.NewBuilder, &exponentiate.Circuit{}) + require.Nil(t, err) + + srs, srsLagrange, err := unsafekzg.NewSRS(css) + require.Nil(t, err) + + // Setup on the prover side + pk, vk, err := plonk.Setup(css, srs, srsLagrange) + require.Nil(t, err) + + homework := &exponentiate.Circuit{ + X: 2, + Y: 16, + E: 4, + } + + witness, err := frontend.NewWitness(homework, ecc.BLS12_381.ScalarField()) + require.Nil(t, err) + + proof, err := plonk.Prove(css, pk, witness) + require.Nil(t, err) + + var serializedProof bytes.Buffer + _, err = proof.WriteTo(&serializedProof) + require.Nil(t, err) + + var serializedVK bytes.Buffer + _, err = vk.WriteTo(&serializedVK) + require.Nil(t, err) + + // There are two ways to generate the public witness - either from the prover full witness, either recreate + // using the circuit with only the public inputs into it + pubW, err := witness.Public() + require.Nil(t, err) + pubWBytes, err := pubW.MarshalBinary() + require.Nil(t, err) + + // Now a tx can do: verify@proof_bytes@pub_witness_bytes; the curve_id and vk should be in the contract state + verified, err := VerifyPlonk(uint16(ecc.BLS12_381), serializedProof.Bytes(), serializedVK.Bytes(), pubWBytes) + require.True(t, verified) + require.Nil(t, err) + + // Invalid proof + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte{}, serializedVK.Bytes(), pubWBytes) + require.False(t, verified) + require.Error(t, err) + + // Invalid public witness + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), serializedProof.Bytes(), serializedVK.Bytes(), []byte{}) + require.False(t, verified) + require.Error(t, err) + + // Invalid vk + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), serializedProof.Bytes(), []byte{}, pubWBytes) + require.False(t, verified) + require.Error(t, err) + + _, err = VerifyPlonk(uint16(ecc.UNKNOWN), serializedProof.Bytes(), serializedVK.Bytes(), pubWBytes) + require.Error(t, err) + _, err = VerifyPlonk(42, serializedProof.Bytes(), serializedVK.Bytes(), pubWBytes) + require.Error(t, err) +} + +func TestVerifyPlonk_NilOrEmptyInput(t *testing.T) { + // Invalid proof + verified, err := VerifyPlonk(uint16(ecc.BLS12_381), nil, []byte("vk"), []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte{}, []byte("vk"), []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + // Invalid vk + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte("proof"), nil, []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte("proof"), []byte{}, []byte("pubw")) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + // Invalid public witness + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte("proof"), []byte("vk"), nil) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) + + verified, err = VerifyPlonk(uint16(ecc.BLS12_381), []byte("proof"), []byte("vk"), []byte{}) + require.False(t, verified) + require.ErrorIs(t, err, lowLevelFeatures.ErrNilOrEmptyInput) +}