Skip to content

Commit ee9f508

Browse files
committed
solver, solverrpc, csppsolver: Add RootFactors
The RootFactors function provides the result of the polynomial factorization, including too few results or repeated roots, without erroring for these conditions.
1 parent 09ca221 commit ee9f508

File tree

5 files changed

+128
-20
lines changed

5 files changed

+128
-20
lines changed

cmd/csppsolver/solver.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,38 @@ type Args struct {
3636

3737
type Result struct {
3838
Roots []*big.Int
39+
Exponents []int
3940
RepeatedRoot *big.Int
4041
}
4142

43+
func (*Solver) RootFactors(args Args, res *Result) error {
44+
roots, exps, err := solver.RootFactors(args.A, args.F)
45+
if err != nil {
46+
return err
47+
}
48+
res.Roots = roots
49+
res.Exponents = exps
50+
return nil
51+
}
52+
4253
type repeatedRoot interface {
4354
RepeatedRoot() *big.Int
4455
}
4556

4657
func (*Solver) Roots(args Args, res *Result) error {
47-
roots, err := solver.Roots(args.A, args.F)
48-
if rr, ok := err.(repeatedRoot); ok {
49-
res.RepeatedRoot = rr.RepeatedRoot()
50-
return nil // error set by client package
51-
}
58+
roots, exps, err := solver.RootFactors(args.A, args.F)
5259
if err != nil {
5360
return err
5461
}
62+
for i, exp := range exps {
63+
if exp != 1 {
64+
res.RepeatedRoot = roots[i]
65+
return nil // error set by client package
66+
}
67+
}
68+
5569
res.Roots = roots
70+
res.Exponents = exps
5671
return nil
5772
}
5873

solver/solver.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,12 @@ func factorPoly(fac *C.fmpz_mod_poly_factor_struct, i uintptr) *C.fmpz_mod_poly_
3535
return (*C.fmpz_mod_poly_struct)(unsafe.Pointer(uintptr(unsafe.Pointer(fac.poly)) + i*C.sizeof_fmpz_mod_poly_struct))
3636
}
3737

38-
type repeatedRoot big.Int
39-
40-
func (r *repeatedRoot) Error() string { return "repeated roots" }
41-
func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) }
42-
43-
// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F).
44-
// Repeated roots are considered an error for the purposes of unique slot assignment.
45-
func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
38+
// RootFactors returns the roots and their number of solutions in the
39+
// factorized polynomial. Repeated roots are an error in the mixing protocol
40+
// but unlike the Roots function are not returned as an error here.
41+
func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) {
4642
if len(a) < 2 {
47-
return nil, errors.New("too few coefficients")
43+
return nil, nil, errors.New("too few coefficients")
4844
}
4945

5046
var mod C.fmpz_t
@@ -80,6 +76,7 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
8076
C.fmpz_mod_poly_factor(&factor[0], &poly[0], &modctx[0])
8177

8278
roots := make([]*big.Int, 0, len(a)-1)
79+
exps := make([]int, 0, len(a)-1)
8380
var m C.fmpz_t
8481
C.fmpz_init(&m[0])
8582
defer C.fmpz_clear(&m[0])
@@ -93,19 +90,36 @@ func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
9390

9491
b, ok := new(big.Int).SetString(str, base)
9592
if !ok {
96-
return nil, errors.New("failed to read fmpz")
93+
return nil, nil, errors.New("failed to read fmpz")
9794
}
9895
b.Neg(b)
9996
b.Mod(b, F)
10097

101-
if factorExp(&factor[0], uintptr(i)) != 1 {
102-
return nil, (*repeatedRoot)(b)
103-
}
10498
roots = append(roots, b)
99+
exps = append(exps, int(factorExp(&factor[0], uintptr(i))))
100+
}
101+
102+
return roots, exps, nil
103+
}
104+
105+
type repeatedRoot big.Int
106+
107+
func (r *repeatedRoot) Error() string { return "repeated roots" }
108+
func (r *repeatedRoot) RepeatedRoot() *big.Int { return (*big.Int)(r) }
109+
110+
// Roots solves for len(a)-1 roots of the polynomial with coefficients a (mod F).
111+
// Repeated roots are considered an error for the purposes of unique slot,
112+
// assignment, and an error with method RepeatedRoot() *big.Int is returned.
113+
func Roots(a []*big.Int, F *big.Int) ([]*big.Int, error) {
114+
roots, exps, err := RootFactors(a, F)
115+
if err != nil {
116+
return roots, err
105117
}
106118

107-
if len(roots) != len(a)-1 {
108-
return nil, errors.New("too few roots")
119+
for i, exp := range exps {
120+
if exp != 1 {
121+
return nil, (*repeatedRoot)(roots[i])
122+
}
109123
}
110124

111125
return roots, nil

solver/solver_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,33 @@ func TestRoots(t *testing.T) {
429429
}
430430
}
431431

432+
func TestRootFactors(t *testing.T) {
433+
for i := range tests {
434+
roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field)
435+
if err != nil {
436+
t.Error(err)
437+
continue
438+
}
439+
for i, exp := range exps {
440+
if exp != 1 {
441+
t.Errorf("repeated root %v at index %v", roots[i], i)
442+
continue
443+
}
444+
}
445+
if len(roots) != len(tests[i].messages) {
446+
t.Error("wrong root count")
447+
continue
448+
}
449+
sortBig(tests[i].messages)
450+
sortBig(roots)
451+
for j := range roots {
452+
if roots[j].Cmp(tests[i].messages[j]) != 0 {
453+
t.Error("recovered wrong message")
454+
}
455+
}
456+
}
457+
}
458+
432459
func BenchmarkRoots(b *testing.B) {
433460
for i := range tests {
434461
b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) {

solverrpc/rpc.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,31 @@ func StartSolver() error {
7777
return onceErr
7878
}
7979

80+
// RootFactors returns the roots and their number of solutions in the
81+
// factorized polynomial. Repeated roots are an error in the mixing protocol
82+
// but unlike the Roots function are not returned as an error here.
83+
func RootFactors(a []*big.Int, F *big.Int) ([]*big.Int, []int, error) {
84+
if err := StartSolver(); err != nil {
85+
return nil, nil, err
86+
}
87+
88+
var args struct {
89+
A []*big.Int
90+
F *big.Int
91+
}
92+
args.A = a
93+
args.F = F
94+
var result struct {
95+
Roots []*big.Int
96+
Exponents []int
97+
}
98+
err := client.Call("Solver.RootFactors", args, &result)
99+
if err != nil {
100+
return nil, nil, err
101+
}
102+
return result.Roots, result.Exponents, nil
103+
}
104+
80105
type repeatedRoot big.Int
81106

82107
func (r *repeatedRoot) Error() string { return "repeated roots" }

solverrpc/solver_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,33 @@ func TestRoots(t *testing.T) {
429429
}
430430
}
431431

432+
func TestRootFactors(t *testing.T) {
433+
for i := range tests {
434+
roots, exps, err := RootFactors(tests[i].coeffs, tests[i].field)
435+
if err != nil {
436+
t.Error(err)
437+
continue
438+
}
439+
for i, exp := range exps {
440+
if exp != 1 {
441+
t.Errorf("repeated root %v at index %v", roots[i], i)
442+
continue
443+
}
444+
}
445+
if len(roots) != len(tests[i].messages) {
446+
t.Error("wrong root count")
447+
continue
448+
}
449+
sortBig(tests[i].messages)
450+
sortBig(roots)
451+
for j := range roots {
452+
if roots[j].Cmp(tests[i].messages[j]) != 0 {
453+
t.Error("recovered wrong message")
454+
}
455+
}
456+
}
457+
}
458+
432459
func BenchmarkRoots(b *testing.B) {
433460
for i := range tests {
434461
b.Run(fmt.Sprintf("%d", tests[i].n), func(b *testing.B) {

0 commit comments

Comments
 (0)