Skip to content

Perf/bounded comparator #1460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions frontend/api.go
Original file line number Diff line number Diff line change
@@ -99,7 +99,7 @@ type API interface {
//
// If the absolute difference between the variables i1 and i2 is known, then
// it is more efficient to use the bounded methods in package
// [github.com/consensys/gnark/std/math/bits].
// [https://github.com/Consensys/gnark/blob/master/std/math/cmp].
Cmp(i1, i2 Variable) Variable

// ---------------------------------------------------------------------------------------------
@@ -121,7 +121,7 @@ type API interface {
//
// If the absolute difference between the variables b and bound is known, then
// it is more efficient to use the bounded methods in package
// [github.com/consensys/gnark/std/math/bits].
// [https://github.com/Consensys/gnark/blob/master/std/math/cmp].
AssertIsLessOrEqual(v Variable, bound Variable)

// Println behaves like fmt.Println but accepts cd.Variable as parameter
11 changes: 10 additions & 1 deletion std/math/cmp/bounded.go
Original file line number Diff line number Diff line change
@@ -2,10 +2,11 @@ package cmp

import (
"fmt"
"math/big"

"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
"math/big"
)

func init() {
@@ -151,6 +152,10 @@ func (bc BoundedComparator) AssertIsLess(a, b frontend.Variable) {
}

// IsLess returns 1 if a < b, and returns 0 if a >= b.
// When |a - b| >= 2^absDiffUpp.BitLen(), a panic is occurred,
// then the method has no return value, and a proof can not be generated.
// It is recommended to use the IsLess method to get a valid return value
// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go
func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable {
res, err := bc.api.Compiler().NewHint(isLessOutputHint, 1, a, b)
if err != nil {
@@ -164,6 +169,10 @@ func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable {
}

// IsLessEq returns 1 if a <= b, and returns 0 if a > b.
// When |a - b| > 2^absDiffUpp.BitLen(), a panic is occurred,
// then the method has no return value, and a proof can not be generated.
// It is recommended to use the IsLessOrEqual method to get a valid return value
// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go
func (bc BoundedComparator) IsLessEq(a, b frontend.Variable) frontend.Variable {
// a <= b <==> a < b + 1
return bc.IsLess(a, bc.api.Add(b, 1))
76 changes: 74 additions & 2 deletions std/math/cmp/bounded_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package cmp_test

import (
"fmt"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/test"
"math/big"
"testing"
)

func TestAssertIsLessEq(t *testing.T) {
@@ -143,3 +146,72 @@ func (c *minCircuit) Define(api frontend.API) error {

return nil
}

type boundedComparatorCircuit struct {
A frontend.Variable

WantIsLess int
WantIsLessEq int
Bound int
}

func (c *boundedComparatorCircuit) Define(api frontend.API) error {
comparator := cmp.NewBoundedComparator(api, big.NewInt(int64(c.Bound)), true)
if c.WantIsLess == 1 {
comparator.AssertIsLess(c.A, c.Bound)
}
if c.WantIsLessEq == 1 {
comparator.AssertIsLessEq(c.A, c.Bound)
}

api.AssertIsEqual(c.WantIsLess, comparator.IsLess(c.A, c.Bound))
api.AssertIsEqual(c.WantIsLessEq, comparator.IsLessEq(c.A, c.Bound))

return nil
}

type boundedComparatorTestCase struct {
A int

WantIsLess int
WantIsLessEq int
Bound int

expectedSuccess bool
}

func TestBoundedComparator(t *testing.T) {
assert := test.NewAssert(t)

var testCases []boundedComparatorTestCase
for bound := 2; bound <= 15; bound++ {
c := 1 << (big.NewInt(int64(bound)).BitLen())
for i := 0; i <= bound+5; i++ {
testCase := boundedComparatorTestCase{
A: i, Bound: bound, WantIsLess: 1, WantIsLessEq: 1, expectedSuccess: true}
if i >= bound {
testCase.WantIsLess = 0
if i > bound {
testCase.WantIsLessEq = 0
}
}
if i-bound >= c {
testCase.expectedSuccess = false
}
testCases = append(testCases, testCase)
}
}

for _, tc := range testCases {
assert.Run(func(assert *test.Assert) {
circuit := &boundedComparatorCircuit{Bound: tc.Bound, WantIsLess: tc.WantIsLess, WantIsLessEq: tc.WantIsLessEq}
assignment := &boundedComparatorCircuit{A: tc.A}
err := test.IsSolved(circuit, assignment, ecc.BN254.ScalarField())
if tc.expectedSuccess {
assert.NoError(err)
} else {
assert.Error(err)
}
}, fmt.Sprintf("bound=%d a=%d", tc.Bound, tc.A))
}
}