Skip to content

Commit 6515446

Browse files
authored
Merge pull request #157 from rootulp/rp/errors
chore!: add error return params to tree interface
2 parents 1e85aab + 1956b16 commit 6515446

5 files changed

+129
-57
lines changed

datasquare.go

+31-17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"fmt"
66
"math"
77
"sync"
8+
9+
"golang.org/x/sync/errgroup"
810
)
911

1012
// ErrUnevenChunks is thrown when non-nil chunks are not all of equal size.
@@ -187,29 +189,41 @@ func (ds *dataSquare) resetRoots() {
187189
}
188190
}
189191

190-
func (ds *dataSquare) computeRoots() {
191-
var wg sync.WaitGroup
192+
func (ds *dataSquare) computeRoots() error {
193+
var g errgroup.Group
192194

193195
rowRoots := make([][]byte, ds.width)
194196
colRoots := make([][]byte, ds.width)
195197

196198
for i := uint(0); i < ds.width; i++ {
197-
wg.Add(2)
198-
199-
go func(i uint) {
200-
defer wg.Done()
201-
rowRoots[i] = ds.getRowRoot(i)
202-
}(i)
199+
i := i // https://go.dev/doc/faq#closures_and_goroutines
200+
g.Go(func() error {
201+
rowRoot, err := ds.getRowRoot(i)
202+
if err != nil {
203+
return err
204+
}
205+
rowRoots[i] = rowRoot
206+
return nil
207+
})
208+
209+
g.Go(func() error {
210+
colRoot, err := ds.getColRoot(i)
211+
if err != nil {
212+
return err
213+
}
214+
colRoots[i] = colRoot
215+
return nil
216+
})
217+
}
203218

204-
go func(i uint) {
205-
defer wg.Done()
206-
colRoots[i] = ds.getColRoot(i)
207-
}(i)
219+
err := g.Wait()
220+
if err != nil {
221+
return err
208222
}
209223

210-
wg.Wait()
211224
ds.rowRoots = rowRoots
212225
ds.colRoots = colRoots
226+
return nil
213227
}
214228

215229
// getRowRoots returns the Merkle roots of all the rows in the square.
@@ -223,9 +237,9 @@ func (ds *dataSquare) getRowRoots() [][]byte {
223237

224238
// getRowRoot calculates and returns the root of the selected row. Note: unlike the
225239
// getRowRoots method, getRowRoot does not write to the built-in cache.
226-
func (ds *dataSquare) getRowRoot(x uint) []byte {
240+
func (ds *dataSquare) getRowRoot(x uint) ([]byte, error) {
227241
if ds.rowRoots != nil {
228-
return ds.rowRoots[x]
242+
return ds.rowRoots[x], nil
229243
}
230244

231245
tree := ds.createTreeFn(Row, x)
@@ -247,9 +261,9 @@ func (ds *dataSquare) getColRoots() [][]byte {
247261

248262
// getColRoot calculates and returns the root of the selected row. Note: unlike the
249263
// getColRoots method, getColRoot does not write to the built-in cache.
250-
func (ds *dataSquare) getColRoot(y uint) []byte {
264+
func (ds *dataSquare) getColRoot(y uint) ([]byte, error) {
251265
if ds.colRoots != nil {
252-
return ds.colRoots[y]
266+
return ds.colRoots[y], nil
253267
}
254268

255269
tree := ds.createTreeFn(Col, y)

datasquare_test.go

+57-20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"reflect"
66
"testing"
77

8+
"github.com/celestiaorg/merkletree"
9+
"github.com/minio/sha256-simd"
810
"github.com/stretchr/testify/assert"
911
)
1012

@@ -138,36 +140,60 @@ func TestLazyRootGeneration(t *testing.T) {
138140
var colRoots [][]byte
139141

140142
for i := uint(0); i < square.width; i++ {
141-
rowRoots = append(rowRoots, square.getRowRoot(i))
142-
colRoots = append(rowRoots, square.getColRoot(i))
143+
rowRoot, err := square.getRowRoot(i)
144+
assert.NoError(t, err)
145+
colRoot, err := square.getColRoot(i)
146+
assert.NoError(t, err)
147+
rowRoots = append(rowRoots, rowRoot)
148+
colRoots = append(colRoots, colRoot)
143149
}
144150

145-
square.computeRoots()
151+
err = square.computeRoots()
152+
assert.NoError(t, err)
146153

147154
if !reflect.DeepEqual(square.rowRoots, rowRoots) && !reflect.DeepEqual(square.colRoots, colRoots) {
148155
t.Error("getRowRoot or getColRoot did not produce identical roots to computeRoots")
149156
}
150157
}
151158

159+
func TestComputeRoots(t *testing.T) {
160+
t.Run("default tree computeRoots() returns no error", func(t *testing.T) {
161+
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
162+
assert.NoError(t, err)
163+
err = square.computeRoots()
164+
assert.NoError(t, err)
165+
})
166+
t.Run("error tree computeRoots() returns an error", func(t *testing.T) {
167+
square, err := newDataSquare([][]byte{{1}}, newErrorTree)
168+
assert.NoError(t, err)
169+
err = square.computeRoots()
170+
assert.Error(t, err)
171+
})
172+
}
173+
152174
func TestRootAPI(t *testing.T) {
153175
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
154176
if err != nil {
155177
panic(err)
156178
}
157179

158180
for i := uint(0); i < square.width; i++ {
159-
if !reflect.DeepEqual(square.getRowRoots()[i], square.getRowRoot(i)) {
181+
rowRoot, err := square.getRowRoot(i)
182+
assert.NoError(t, err)
183+
if !reflect.DeepEqual(square.getRowRoots()[i], rowRoot) {
160184
t.Errorf(
161185
"Row root API results in different roots, expected %v got %v",
162186
square.getRowRoots()[i],
163-
square.getRowRoot(i),
187+
rowRoot,
164188
)
165189
}
166-
if !reflect.DeepEqual(square.getColRoots()[i], square.getColRoot(i)) {
190+
colRoot, err := square.getColRoot(i)
191+
assert.NoError(t, err)
192+
if !reflect.DeepEqual(square.getColRoots()[i], colRoot) {
167193
t.Errorf(
168194
"Column root API results in different roots, expected %v got %v",
169195
square.getColRoots()[i],
170-
square.getColRoot(i),
196+
colRoot,
171197
)
172198
}
173199
}
@@ -205,7 +231,8 @@ func BenchmarkEDSRoots(b *testing.B) {
205231
func(b *testing.B) {
206232
for n := 0; n < b.N; n++ {
207233
square.resetRoots()
208-
square.computeRoots()
234+
err := square.computeRoots()
235+
assert.NoError(b, err)
209236
}
210237
},
211238
)
@@ -224,18 +251,6 @@ func computeRowProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, ui
224251
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
225252
}
226253

227-
func computeColProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
228-
tree := ds.createTreeFn(Col, y)
229-
data := ds.col(y)
230-
231-
for i := uint(0); i < ds.width; i++ {
232-
tree.Push(data[i])
233-
}
234-
// TODO(ismail): check for overflow when casting from uint -> int
235-
merkleRoot, proof, proofIndex, numLeaves := treeProve(tree.(*DefaultTree), int(x))
236-
return merkleRoot, proof, uint(proofIndex), uint(numLeaves), nil
237-
}
238-
239254
func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, proofIndex uint64, numLeaves uint64) {
240255
if err := d.Tree.SetIndex(uint64(idx)); err != nil {
241256
panic(fmt.Sprintf("don't call prove on a already used tree: %v", err))
@@ -245,3 +260,25 @@ func treeProve(d *DefaultTree, idx int) (merkleRoot []byte, proofSet [][]byte, p
245260
}
246261
return d.Tree.Prove()
247262
}
263+
264+
type errorTree struct {
265+
*merkletree.Tree
266+
leaves [][]byte
267+
}
268+
269+
func newErrorTree(axis Axis, index uint) Tree {
270+
return &errorTree{
271+
Tree: merkletree.New(sha256.New()),
272+
leaves: make([][]byte, 0, 128),
273+
}
274+
}
275+
276+
func (d *errorTree) Push(data []byte) error {
277+
// ignore the idx, as this implementation doesn't need that info
278+
d.leaves = append(d.leaves, data)
279+
return nil
280+
}
281+
282+
func (d *errorTree) Root() ([]byte, error) {
283+
return nil, fmt.Errorf("error")
284+
}

extendeddatacrossword.go

+28-12
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ func (e *ErrByzantineData) Error() string {
5151
// square (EDS), comparing repaired rows and columns against expected Merkle
5252
// roots.
5353
//
54-
// Input
54+
// # Input
5555
//
5656
// Missing shares must be nil.
5757
//
58-
// Output
58+
// # Output
5959
//
6060
// The EDS is modified in-place. If repairing is successful, the EDS will be
6161
// complete. If repairing is unsuccessful, the EDS will be the most-repaired
@@ -282,10 +282,14 @@ func (eds *ExtendedDataSquare) verifyAgainstRowRoots(
282282
rebuiltShare []byte,
283283
) error {
284284
var root []byte
285+
var err error
285286
if rebuiltIndex < 0 || rebuiltShare == nil {
286-
root = eds.computeSharesRoot(oldShares, Row, r)
287+
root, err = eds.computeSharesRoot(oldShares, Row, r)
287288
} else {
288-
root = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
289+
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Row, r, rebuiltIndex, rebuiltShare)
290+
}
291+
if err != nil {
292+
return err
289293
}
290294

291295
if !bytes.Equal(root, rowRoots[r]) {
@@ -303,10 +307,14 @@ func (eds *ExtendedDataSquare) verifyAgainstColRoots(
303307
rebuiltShare []byte,
304308
) error {
305309
var root []byte
310+
var err error
306311
if rebuiltIndex < 0 || rebuiltShare == nil {
307-
root = eds.computeSharesRoot(oldShares, Col, c)
312+
root, err = eds.computeSharesRoot(oldShares, Col, c)
308313
} else {
309-
root = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
314+
root, err = eds.computeSharesRootWithRebuiltShare(oldShares, Col, c, rebuiltIndex, rebuiltShare)
315+
}
316+
if err != nil {
317+
return err
310318
}
311319

312320
if !bytes.Equal(root, colRoots[c]) {
@@ -331,8 +339,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
331339
if rowIsComplete {
332340
errs.Go(func() error {
333341
// ensure that the roots are equal
334-
if !bytes.Equal(rowRoots[i], eds.getRowRoot(i)) {
335-
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], eds.getRowRoot(i))
342+
rowRoot, err := eds.getRowRoot(i)
343+
if err != nil {
344+
return err
345+
}
346+
if !bytes.Equal(rowRoots[i], rowRoot) {
347+
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], rowRoot)
336348
}
337349
return nil
338350
})
@@ -342,8 +354,12 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(
342354
if colIsComplete {
343355
errs.Go(func() error {
344356
// ensure that the roots are equal
345-
if !bytes.Equal(colRoots[i], eds.getColRoot(i)) {
346-
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], eds.getColRoot(i))
357+
colRoot, err := eds.getColRoot(i)
358+
if err != nil {
359+
return err
360+
}
361+
if !bytes.Equal(colRoots[i], colRoot) {
362+
return fmt.Errorf("bad root input: col %d expected %v got %v", i, colRoots[i], colRoot)
347363
}
348364
return nil
349365
})
@@ -391,15 +407,15 @@ func noMissingData(input [][]byte, rebuiltIndex int) bool {
391407
return true
392408
}
393409

394-
func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) []byte {
410+
func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) ([]byte, error) {
395411
tree := eds.createTreeFn(axis, i)
396412
for _, d := range shares {
397413
tree.Push(d)
398414
}
399415
return tree.Root()
400416
}
401417

402-
func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) []byte {
418+
func (eds *ExtendedDataSquare) computeSharesRootWithRebuiltShare(shares [][]byte, axis Axis, i uint, rebuiltIndex int, rebuiltShare []byte) ([]byte, error) {
403419
tree := eds.createTreeFn(axis, i)
404420
for _, d := range shares[:rebuiltIndex] {
405421
tree.Push(d)

extendeddatacrossword_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,11 @@ func TestValidFraudProof(t *testing.T) {
119119
if err != nil {
120120
t.Errorf("could not decode fraud proof shares; got: %v", err)
121121
}
122-
root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
123-
if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) {
122+
root, err := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
123+
assert.NoError(t, err)
124+
rowRoot, err := corrupted.getRowRoot(fraudProof.Index)
125+
assert.NoError(t, err)
126+
if bytes.Equal(root, rowRoot) {
124127
// If the roots match, then the fraud proof should be for invalid erasure coding.
125128
parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth])
126129
if err != nil {

tree.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import (
66
"github.com/celestiaorg/merkletree"
77
)
88

9-
// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle inside of rsmt2d.
9+
// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle tree
10+
// inside of rsmt2d.
1011
type TreeConstructorFn = func(axis Axis, index uint) Tree
1112

1213
// SquareIndex contains all information needed to identify the cell that is being
@@ -17,8 +18,8 @@ type SquareIndex struct {
1718

1819
// Tree wraps Merkle tree implementations to work with rsmt2d
1920
type Tree interface {
20-
Push(data []byte)
21-
Root() []byte
21+
Push(data []byte) error
22+
Root() ([]byte, error)
2223
}
2324

2425
var _ Tree = &DefaultTree{}
@@ -36,17 +37,18 @@ func NewDefaultTree(axis Axis, index uint) Tree {
3637
}
3738
}
3839

39-
func (d *DefaultTree) Push(data []byte) {
40+
func (d *DefaultTree) Push(data []byte) error {
4041
// ignore the idx, as this implementation doesn't need that info
4142
d.leaves = append(d.leaves, data)
43+
return nil
4244
}
4345

44-
func (d *DefaultTree) Root() []byte {
46+
func (d *DefaultTree) Root() ([]byte, error) {
4547
if d.root == nil {
4648
for _, l := range d.leaves {
4749
d.Tree.Push(l)
4850
}
4951
d.root = d.Tree.Root()
5052
}
51-
return d.root
53+
return d.root, nil
5254
}

0 commit comments

Comments
 (0)