Skip to content

feat: add modular inverse for variable modulus #1507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions std/math/emulated/custommod.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModMul(a, b *Element[T], modulus *Element[T]) *Element[T] {
// fast path when either of the inputs is zero then result is always zero
if len(a.Limbs) == 0 || len(b.Limbs) == 0 {
if a.isStrictZero() || b.isStrictZero() {
return f.Zero()
}
res := f.mulMod(a, b, 0, modulus)
Expand Down Expand Up @@ -97,8 +97,8 @@ func (f *Field[T]) ModAssertIsEqual(a, b *Element[T], modulus *Element[T]) {
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModExp(base, exp, modulus *Element[T]) *Element[T] {
// fasth path when the base is zero then result is always zero
if len(base.Limbs) == 0 {
if base.isStrictZero() {
// fast path when the base is zero then result is always zero
return f.Zero()
}
expBts := f.ToBits(exp)
Expand All @@ -112,3 +112,27 @@ func (f *Field[T]) ModExp(base, exp, modulus *Element[T]) *Element[T] {
res = f.Select(expBts[n-1], f.ModMul(base, res, modulus), res)
return res
}

// ModInverse computes the modular inverse of a mod modulus. Instead of taking
// modulus as a constant parametrized by T, it is passed as an argument. This
// allows to use a variable modulus in the circuit. Type parameter T should be
// sufficiently big to fit a and modulus. Recommended to use [emparams.Mod1e512]
// or [emparams.Mod1e4096].
//
// The method panics at solving time if the modular inverse does not exist.
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModInverse(a, modulus *Element[T]) *Element[T] {
// fast path when a is zero then result is always zero
if a.isStrictZero() {
return f.Zero()
}
k, err := f.computeInverseHint(a.Limbs, modulus)
if err != nil {
panic("failed to compute inverse hint: " + err.Error())
}
inv := f.packLimbs(k, true)
mul := f.ModMul(a, inv, modulus)
f.ModAssertIsEqual(mul, f.One(), modulus)
return inv
}
45 changes: 45 additions & 0 deletions std/math/emulated/custommod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,48 @@ func TestVariableExp(t *testing.T) {
err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField())
assert.NoError(err)
}

type variableInverse[T FieldParams] struct {
Modulus Element[T]
A Element[T]
Expected Element[T]
}

func (c *variableInverse[T]) Define(api frontend.API) error {
f, err := NewField[T](api)
if err != nil {
return fmt.Errorf("new variable modulus: %w", err)
}
res := f.ModInverse(&c.A, &c.Modulus)
f.ModAssertIsEqual(&c.Expected, res, &c.Modulus)
return nil
}

func TestVariableInverse(t *testing.T) {
assert := test.NewAssert(t)
modulus, _ := new(big.Int).SetString("4294967311", 10)
a, _ := rand.Int(rand.Reader, modulus)
expected := new(big.Int).ModInverse(a, modulus)
if expected == nil {
t.Fatal("modular inverse should exists. Check modulus")
}
circuit := &variableInverse[emparams.Mod1e512]{}
assignment := &variableInverse[emparams.Mod1e512]{
Modulus: ValueOf[emparams.Mod1e512](modulus),
A: ValueOf[emparams.Mod1e512](a),
Expected: ValueOf[emparams.Mod1e512](expected),
}
err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField())
assert.NoError(err)

modulus2 := new(big.Int).Add(modulus, big.NewInt(1))
base2 := big.NewInt(16)
expected2 := big.NewInt(0) // no modular inverse
assignment2 := &variableInverse[emparams.Mod1e512]{
Modulus: ValueOf[emparams.Mod1e512](modulus2),
A: ValueOf[emparams.Mod1e512](base2),
Expected: ValueOf[emparams.Mod1e512](expected2),
}
err = test.IsSolved(circuit, assignment2, ecc.BLS12_377.ScalarField())
assert.Error(err)
}
2 changes: 1 addition & 1 deletion std/math/emulated/field_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (f *Field[T]) inverse(a, _ *Element[T], _ uint) *Element[T] {
if !f.fParams.IsPrime() {
panic("modulus not a prime")
}
k, err := f.computeInverseHint(a.Limbs)
k, err := f.computeInverseHint(a.Limbs, f.Modulus())
if err != nil {
panic(fmt.Sprintf("compute inverse: %v", err))
}
Expand Down
8 changes: 4 additions & 4 deletions std/math/emulated/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ func nbMultiplicationResLimbs(lenLeft, lenRight int) int {
return res
}

// computeInverseHint packs the inputs for the InverseHint hint function.
func (f *Field[T]) computeInverseHint(inLimbs []frontend.Variable) (inverseLimbs []frontend.Variable, err error) {
// computeInverseHint packs the inputs for the InverseHint hint function. The modulus is passed as an argument,
// allowing the method to be used both for fixed and variable modulus cases.
func (f *Field[T]) computeInverseHint(inLimbs []frontend.Variable, modulus *Element[T]) (inverseLimbs []frontend.Variable, err error) {
hintInputs := []frontend.Variable{
f.fParams.BitsPerLimb(),
f.fParams.NbLimbs(),
}
p := f.Modulus()
hintInputs = append(hintInputs, p.Limbs...)
hintInputs = append(hintInputs, modulus.Limbs...)
hintInputs = append(hintInputs, inLimbs...)
return f.api.NewHint(InverseHint, int(f.fParams.NbLimbs()), hintInputs...)
}
Expand Down