Skip to content

Commit 7fdf255

Browse files
ale-linuxadecaro
authored andcommitted
Expose CurveID
Expose getter to retrieve ID for a curve given an element. Fixes #16 Signed-off-by: Alessandro Sorniotti <aso@zurich.ibm.com>
1 parent 7296296 commit 7fdf255

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

math.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ type Zr struct {
177177
curveID CurveID
178178
}
179179

180+
func (z *Zr) CurveID() CurveID {
181+
return z.curveID
182+
}
183+
180184
func (z *Zr) Plus(a *Zr) *Zr {
181185
return &Zr{zr: z.zr.Plus(a.zr), curveID: z.curveID}
182186
}
@@ -244,6 +248,10 @@ type G1 struct {
244248
curveID CurveID
245249
}
246250

251+
func (g *G1) CurveID() CurveID {
252+
return g.curveID
253+
}
254+
247255
func (g *G1) Clone(a *G1) {
248256
g.g1.Clone(a.g1)
249257
}
@@ -299,6 +307,10 @@ type G2 struct {
299307
curveID CurveID
300308
}
301309

310+
func (g *G2) CurveID() CurveID {
311+
return g.curveID
312+
}
313+
302314
func (g *G2) Clone(a *G2) {
303315
g.g2.Clone(a.g2)
304316
}
@@ -346,6 +358,10 @@ type Gt struct {
346358
curveID CurveID
347359
}
348360

361+
func (g *Gt) CurveID() CurveID {
362+
return g.curveID
363+
}
364+
349365
func (g *Gt) Equals(a *Gt) bool {
350366
return g.gt.Equals(a.gt)
351367
}

math_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,24 @@ func TestImmutability(t *testing.T) {
3232
}
3333
}
3434

35+
func TestCurveId(t *testing.T) {
36+
for _, curve := range Curves {
37+
rng, err := curve.Rand()
38+
assert.NoError(t, err)
39+
40+
runCurveIdTest(t, curve, rng)
41+
}
42+
}
43+
44+
func runCurveIdTest(t *testing.T, c *Curve, rng io.Reader) {
45+
r := c.NewRandomZr(rng)
46+
47+
assert.Equal(t, r.CurveID(), c.curveID)
48+
assert.Equal(t, c.GenG1.Mul(r).CurveID(), c.curveID)
49+
assert.Equal(t, c.GenG2.Mul(r).CurveID(), c.curveID)
50+
assert.Equal(t, c.GenGt.Exp(r).CurveID(), c.curveID)
51+
}
52+
3553
var r *Zr
3654
var g1 *G1
3755
var g2 *G2

0 commit comments

Comments
 (0)