Skip to content

Commit 74890d9

Browse files
authored
feat!(eds): make Size method return error (#4091)
The Size method in the eds.Accessor interface now returns an error to handle cases where: - The header is not initialized in ODS - The extended data square is not initialized in Rsmt2D - The accessor is closed in closeOnce This change makes the interface more consistent with other methods that return errors and allows proper error handling in cases where size information cannot be retrieved. Updates: - eds.Accessor interface to include error in Size method signature - All implementations (ODS, ODSQ4, Rsmt2D, NoopFile, validation, proofsCache, closeOnce) - Tests and mocks to handle the error return Resolves TODO in store/file/ods_q4.go about Size method returning error.
1 parent ecbf331 commit 74890d9

11 files changed

+100
-41
lines changed

share/eds/accessor.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ var EmptyAccessor = &Rsmt2D{ExtendedDataSquare: share.EmptyEDS()}
1717
// Accessor is an interface for accessing extended data square data.
1818
type Accessor interface {
1919
// Size returns square size of the Accessor.
20-
Size(ctx context.Context) int
20+
Size(ctx context.Context) (int, error)
2121
// DataHash returns data hash of the Accessor.
2222
DataHash(ctx context.Context) (share.DataHash, error)
2323
// AxisRoots returns share.AxisRoots (DataAvailabilityHeader) of the Accessor.

share/eds/close_once.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ func (c *closeOnce) Close() error {
3636
return err
3737
}
3838

39-
func (c *closeOnce) Size(ctx context.Context) int {
39+
func (c *closeOnce) Size(ctx context.Context) (int, error) {
4040
if c.closed.Load() {
41-
return 0
41+
return 0, errAccessorClosed
4242
}
4343
return c.f.Size(ctx)
4444
}

share/eds/close_once_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ type stubEdsAccessorCloser struct {
4747
closed bool
4848
}
4949

50-
func (s *stubEdsAccessorCloser) Size(context.Context) int {
51-
return 0
50+
func (s *stubEdsAccessorCloser) Size(context.Context) (int, error) {
51+
return 0, nil
5252
}
5353

5454
func (s *stubEdsAccessorCloser) DataHash(context.Context) (share.DataHash, error) {

share/eds/proofs_cache.go

+42-13
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,18 @@ func WithProofsCache(ac AccessorStreamer) AccessorStreamer {
7575
}
7676
}
7777

78-
func (c *proofsCache) Size(ctx context.Context) int {
78+
func (c *proofsCache) Size(ctx context.Context) (int, error) {
7979
size := c.size.Load()
80-
if size == 0 {
81-
size = int32(c.inner.Size(ctx))
82-
c.size.Store(size)
80+
if size != 0 {
81+
return int(size), nil
8382
}
84-
return int(size)
83+
84+
loaded, err := c.inner.Size(ctx)
85+
if err != nil {
86+
return 0, fmt.Errorf("loading size from inner accessor: %w", err)
87+
}
88+
c.size.Store(int32(loaded))
89+
return loaded, nil
8590
}
8691

8792
func (c *proofsCache) DataHash(ctx context.Context) (share.DataHash, error) {
@@ -121,7 +126,11 @@ func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap
121126

122127
// build share proof from proofs cached for given axis
123128
share := ax.shares[shrIdx]
124-
proofs, err := ipld.GetProof(ctx, ax.proofs, ax.root, shrIdx, c.Size(ctx))
129+
size, err := c.Size(ctx)
130+
if err != nil {
131+
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
132+
}
133+
proofs, err := ipld.GetProof(ctx, ax.proofs, ax.root, shrIdx, size)
125134
if err != nil {
126135
return shwap.Sample{}, fmt.Errorf("building proof from cache: %w", err)
127136
}
@@ -159,9 +168,13 @@ func (c *proofsCache) axisWithProofs(ctx context.Context, axisType rsmt2d.Axis,
159168
}
160169

161170
// build proofs from Shares and cache them
162-
adder := ipld.NewProofsAdder(c.Size(ctx), true)
171+
size, err := c.Size(ctx)
172+
if err != nil {
173+
return axisWithProofs{}, fmt.Errorf("getting size: %w", err)
174+
}
175+
adder := ipld.NewProofsAdder(size, true)
163176
tree := wrapper.NewErasuredNamespacedMerkleTree(
164-
uint64(c.Size(ctx)/2),
177+
uint64(size/2),
165178
uint(axisIdx),
166179
nmt.NodeVisitor(adder.VisitFn()),
167180
)
@@ -221,7 +234,11 @@ func (c *proofsCache) RowNamespaceData(
221234
return shwap.RowNamespaceData{}, err
222235
}
223236

224-
row, proof, err := ipld.GetSharesByNamespace(ctx, ax.proofs, ax.root, namespace, c.Size(ctx))
237+
size, err := c.Size(ctx)
238+
if err != nil {
239+
return shwap.RowNamespaceData{}, fmt.Errorf("getting size: %w", err)
240+
}
241+
row, proof, err := ipld.GetSharesByNamespace(ctx, ax.proofs, ax.root, namespace, size)
225242
if err != nil {
226243
return shwap.RowNamespaceData{}, fmt.Errorf("shares by namespace %s for row %v: %w", namespace.String(), rowIdx, err)
227244
}
@@ -233,9 +250,13 @@ func (c *proofsCache) RowNamespaceData(
233250
}
234251

235252
func (c *proofsCache) Shares(ctx context.Context) ([]libshare.Share, error) {
236-
odsSize := c.Size(ctx) / 2
253+
size, err := c.Size(ctx)
254+
if err != nil {
255+
return nil, fmt.Errorf("getting size: %w", err)
256+
}
257+
odsSize := size / 2
237258
shares := make([]libshare.Share, 0, odsSize*odsSize)
238-
for i := 0; i < c.Size(ctx)/2; i++ {
259+
for i := 0; i < odsSize; i++ {
239260
ax, err := c.AxisHalf(ctx, rsmt2d.Row, i)
240261
if err != nil {
241262
return nil, err
@@ -256,7 +277,11 @@ func (c *proofsCache) Shares(ctx context.Context) ([]libshare.Share, error) {
256277
}
257278

258279
func (c *proofsCache) Reader() (io.Reader, error) {
259-
odsSize := c.Size(context.TODO()) / 2
280+
size, err := c.Size(context.TODO())
281+
if err != nil {
282+
return nil, fmt.Errorf("getting size: %w", err)
283+
}
284+
odsSize := size / 2
260285
reader := NewShareReader(odsSize, c.getShare)
261286
return reader, nil
262287
}
@@ -307,7 +332,11 @@ func (c *proofsCache) getAxisFromCache(axisType rsmt2d.Axis, axisIdx int) (axisW
307332

308333
func (c *proofsCache) getShare(rowIdx, colIdx int) (libshare.Share, error) {
309334
ctx := context.TODO()
310-
odsSize := c.Size(ctx) / 2
335+
size, err := c.Size(ctx)
336+
if err != nil {
337+
return libshare.Share{}, fmt.Errorf("getting size: %w", err)
338+
}
339+
odsSize := size / 2
311340
half, err := c.AxisHalf(ctx, rsmt2d.Row, rowIdx)
312341
if err != nil {
313342
return libshare.Share{}, fmt.Errorf("reading axis half: %w", err)

share/eds/rsmt2d.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ type Rsmt2D struct {
2121
}
2222

2323
// Size returns the size of the Extended Data Square.
24-
func (eds *Rsmt2D) Size(context.Context) int {
25-
return int(eds.Width())
24+
func (eds *Rsmt2D) Size(context.Context) (int, error) {
25+
return int(eds.Width()), nil
2626
}
2727

2828
// DataHash returns data hash of the Accessor.

share/eds/testing.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,16 @@ func BenchGetHalfAxisFromAccessor(
418418
name := fmt.Sprintf("Size:%v/ProofType:%s/squareHalf:%s", size, axisType, strconv.Itoa(squareHalf))
419419
b.Run(name, func(b *testing.B) {
420420
// warm up cache
421-
_, err := acc.AxisHalf(ctx, axisType, acc.Size(ctx)/2*(squareHalf))
421+
size, err := acc.Size(ctx)
422+
require.NoError(b, err)
423+
_, err = acc.AxisHalf(ctx, axisType, size/2*(squareHalf))
422424
require.NoError(b, err)
423425

424426
b.ResetTimer()
425427
for i := 0; i < b.N; i++ {
426-
_, err := acc.AxisHalf(ctx, axisType, acc.Size(ctx)/2*(squareHalf))
428+
size, err := acc.Size(ctx)
429+
require.NoError(b, err)
430+
_, err = acc.AxisHalf(ctx, axisType, size/2*(squareHalf))
427431
require.NoError(b, err)
428432
}
429433
})
@@ -446,11 +450,13 @@ func BenchGetSampleFromAccessor(
446450
for _, q := range quadrants {
447451
name := fmt.Sprintf("Size:%v/quadrant:%s", size, q)
448452
b.Run(name, func(b *testing.B) {
449-
rowIdx, colIdx := q.coordinates(acc.Size(ctx))
453+
edsSize, err := acc.Size(ctx)
454+
require.NoError(b, err)
455+
rowIdx, colIdx := q.coordinates(edsSize)
450456
idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx}
451457

452458
// warm up cache
453-
_, err := acc.Sample(ctx, idx)
459+
_, err = acc.Sample(ctx, idx)
454460
require.NoError(b, err, q.String())
455461

456462
b.ResetTimer()

share/eds/validation.go

+25-9
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,38 @@ func WithValidation(f Accessor) Accessor {
2424
return &validation{Accessor: f, size: new(atomic.Int32)}
2525
}
2626

27-
func (f validation) Size(ctx context.Context) int {
27+
func (f validation) Size(ctx context.Context) (int, error) {
2828
size := f.size.Load()
29-
if size == 0 {
30-
loaded := f.Accessor.Size(ctx)
31-
f.size.Store(int32(loaded))
32-
return loaded
29+
if size != 0 {
30+
return int(size), nil
3331
}
34-
return int(size)
32+
33+
loaded, err := f.Accessor.Size(ctx)
34+
if err != nil {
35+
return 0, fmt.Errorf("loading size: %w", err)
36+
}
37+
f.size.Store(int32(loaded))
38+
return loaded, nil
3539
}
3640

3741
func (f validation) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) {
38-
_, err := shwap.NewSampleID(1, idx, f.Size(ctx))
42+
size, err := f.Size(ctx)
43+
if err != nil {
44+
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
45+
}
46+
_, err = shwap.NewSampleID(1, idx, size)
3947
if err != nil {
4048
return shwap.Sample{}, fmt.Errorf("sample validation: %w", err)
4149
}
4250
return f.Accessor.Sample(ctx, idx)
4351
}
4452

4553
func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
46-
_, err := shwap.NewRowID(1, axisIdx, f.Size(ctx))
54+
size, err := f.Size(ctx)
55+
if err != nil {
56+
return AxisHalf{}, fmt.Errorf("getting size: %w", err)
57+
}
58+
_, err = shwap.NewRowID(1, axisIdx, size)
4759
if err != nil {
4860
return AxisHalf{}, fmt.Errorf("axis half validation: %w", err)
4961
}
@@ -55,7 +67,11 @@ func (f validation) RowNamespaceData(
5567
namespace libshare.Namespace,
5668
rowIdx int,
5769
) (shwap.RowNamespaceData, error) {
58-
_, err := shwap.NewRowNamespaceDataID(1, rowIdx, namespace, f.Size(ctx))
70+
size, err := f.Size(ctx)
71+
if err != nil {
72+
return shwap.RowNamespaceData{}, fmt.Errorf("getting size: %w", err)
73+
}
74+
_, err = shwap.NewRowNamespaceDataID(1, rowIdx, namespace, size)
5975
if err != nil {
6076
return shwap.RowNamespaceData{}, fmt.Errorf("row namespace data validation: %w", err)
6177
}

store/cache/accessor_cache_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ type mockAccessor struct {
303303
isClosed bool
304304
}
305305

306-
func (m *mockAccessor) Size(context.Context) int {
306+
func (m *mockAccessor) Size(context.Context) (int, error) {
307307
panic("implement me")
308308
}
309309

store/cache/noop.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ func (n NoopFile) Reader() (io.Reader, error) {
4747
return noopReader{}, nil
4848
}
4949

50-
func (n NoopFile) Size(context.Context) int {
51-
return 0
50+
func (n NoopFile) Size(context.Context) (int, error) {
51+
return 0, nil
5252
}
5353

5454
func (n NoopFile) DataHash(context.Context) (share.DataHash, error) {

store/file/ods.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ func OpenODS(path string) (*ODS, error) {
183183
}
184184

185185
// Size returns EDS size stored in file's header.
186-
func (o *ODS) Size(context.Context) int {
187-
return o.size()
186+
func (o *ODS) Size(context.Context) (int, error) {
187+
return o.size(), nil
188188
}
189189

190190
func (o *ODS) size() int {
@@ -238,8 +238,13 @@ func (o *ODS) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample,
238238
// to calculate the sample
239239
rowIdx, colIdx := idx.Row, idx.Col
240240

241+
size, err := o.Size(ctx)
242+
if err != nil {
243+
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
244+
}
245+
241246
axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx
242-
if colIdx < o.size()/2 && rowIdx >= o.size()/2 {
247+
if colIdx < size/2 && rowIdx >= size/2 {
243248
axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx
244249
}
245250

store/file/ods_q4.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (odsq4 *ODSQ4) tryLoadQ4() *q4 {
110110
return q4
111111
}
112112

113-
func (odsq4 *ODSQ4) Size(ctx context.Context) int {
113+
func (odsq4 *ODSQ4) Size(ctx context.Context) (int, error) {
114114
return odsq4.ods.Size(ctx)
115115
}
116116

@@ -137,7 +137,10 @@ func (odsq4 *ODSQ4) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.S
137137
}
138138

139139
func (odsq4 *ODSQ4) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
140-
size := odsq4.Size(ctx) // TODO(@Wondertan): Should return error.
140+
size, err := odsq4.Size(ctx)
141+
if err != nil {
142+
return eds.AxisHalf{}, fmt.Errorf("getting size: %w", err)
143+
}
141144

142145
if axisIdx >= size/2 {
143146
// lazy load Q4 file and read axis from it if loaded

0 commit comments

Comments
 (0)