Skip to content

Commit 19dd38f

Browse files
authored
refactor(ipld): use Set/GetCell API from rstm2d (#1173)
1 parent 8356c21 commit 19dd38f

File tree

2 files changed

+60
-93
lines changed

2 files changed

+60
-93
lines changed

Diff for: share/eds/retriever.go

+43-55
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,18 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader
104104
// quadrant request retries. Also, provides an API
105105
// to reconstruct the block once enough shares are fetched.
106106
type retrievalSession struct {
107+
dah *da.DataAvailabilityHeader
107108
bget blockservice.BlockGetter
108109
adder *ipld.NmtNodeAdder
109110

110-
treeFn rsmt2d.TreeConstructorFn
111-
codec rsmt2d.Codec
112-
113-
dah *da.DataAvailabilityHeader
114-
squareImported *rsmt2d.ExtendedDataSquare
115-
116-
quadrants []*quadrant
117-
sharesLks []sync.Mutex
118-
sharesCount uint32
119-
120-
squareLk sync.RWMutex
121-
square [][]byte
122-
squareSig chan struct{}
123-
squareDn chan struct{}
111+
// TODO(@Wondertan): Extract into a separate data structure https://github.com/celestiaorg/rsmt2d/issues/135
112+
squareQuadrants []*quadrant
113+
squareCellsLks [][]sync.Mutex
114+
squareCellsCount uint32
115+
squareSig chan struct{}
116+
squareDn chan struct{}
117+
squareLk sync.RWMutex
118+
square *rsmt2d.ExtendedDataSquare
124119

125120
span trace.Span
126121
}
@@ -133,29 +128,31 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead
133128
r.bServ,
134129
ipld.MaxSizeBatchOption(size),
135130
)
136-
ses := &retrievalSession{
137-
bget: blockservice.NewSession(ctx, r.bServ),
138-
adder: adder,
139-
treeFn: func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
140-
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
141-
return &tree
142-
},
143-
codec: share.DefaultRSMT2DCodec(),
144-
dah: dah,
145-
quadrants: newQuadrants(dah),
146-
sharesLks: make([]sync.Mutex, size*size),
147-
square: make([][]byte, size*size),
148-
squareSig: make(chan struct{}, 1),
149-
squareDn: make(chan struct{}),
150-
span: trace.SpanFromContext(ctx),
131+
132+
treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
133+
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
134+
return &tree
151135
}
152136

153-
square, err := rsmt2d.ImportExtendedDataSquare(ses.square, ses.codec, ses.treeFn)
137+
square, err := rsmt2d.ImportExtendedDataSquare(make([][]byte, size*size), share.DefaultRSMT2DCodec(), treeFn)
154138
if err != nil {
155139
return nil, err
156140
}
157141

158-
ses.squareImported = square
142+
ses := &retrievalSession{
143+
dah: dah,
144+
bget: blockservice.NewSession(ctx, r.bServ),
145+
adder: adder,
146+
squareQuadrants: newQuadrants(dah),
147+
squareCellsLks: make([][]sync.Mutex, size),
148+
squareSig: make(chan struct{}, 1),
149+
squareDn: make(chan struct{}),
150+
square: square,
151+
span: trace.SpanFromContext(ctx),
152+
}
153+
for i := range ses.squareCellsLks {
154+
ses.squareCellsLks[i] = make([]sync.Mutex, size)
155+
}
159156
go ses.request(ctx)
160157
return ses, nil
161158
}
@@ -170,36 +167,24 @@ func (rs *retrievalSession) Done() <-chan struct{} {
170167
// Reconstruct tries to reconstruct the data square and returns it on success.
171168
func (rs *retrievalSession) Reconstruct(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) {
172169
if rs.isReconstructed() {
173-
return rs.squareImported, nil
170+
return rs.square, nil
174171
}
175172
// prevent further writes to the square
176173
rs.squareLk.Lock()
177174
defer rs.squareLk.Unlock()
178175

179-
// TODO(@Wondertan): This is bad!
180-
// * We should not reimport the square multiple times
181-
// * We should set shares into imported square via
182-
// SetShare(https://github.com/celestiaorg/rsmt2d/issues/83) to accomplish the above point.
183-
{
184-
squareImported, err := rsmt2d.ImportExtendedDataSquare(rs.square, rs.codec, rs.treeFn)
185-
if err != nil {
186-
return nil, err
187-
}
188-
rs.squareImported = squareImported
189-
}
190-
191176
_, span := tracer.Start(ctx, "reconstruct-square")
192177
defer span.End()
193178

194179
// and try to repair with what we have
195-
err := rs.squareImported.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
180+
err := rs.square.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
196181
if err != nil {
197182
span.RecordError(err)
198183
return nil, err
199184
}
200185
log.Infow("data square reconstructed", "data_hash", hex.EncodeToString(rs.dah.Hash()), "size", len(rs.dah.RowsRoots))
201186
close(rs.squareDn)
202-
return rs.squareImported, nil
187+
return rs.square, nil
203188
}
204189

205190
// isReconstructed report true whether the square attached to the session
@@ -232,16 +217,16 @@ func (rs *retrievalSession) Close() error {
232217
func (rs *retrievalSession) request(ctx context.Context) {
233218
t := time.NewTicker(RetrieveQuadrantTimeout)
234219
defer t.Stop()
235-
for retry := 0; retry < len(rs.quadrants); retry++ {
236-
q := rs.quadrants[retry]
220+
for retry := 0; retry < len(rs.squareQuadrants); retry++ {
221+
q := rs.squareQuadrants[retry]
237222
log.Debugw("requesting quadrant",
238223
"axis", q.source,
239224
"x", q.x,
240225
"y", q.y,
241226
"size", len(q.roots),
242227
)
243228
rs.span.AddEvent("requesting quadrant", trace.WithAttributes(
244-
attribute.Int("axis", q.source),
229+
attribute.Int("axis", int(q.source)),
245230
attribute.Int("x", q.x),
246231
attribute.Int("y", q.y),
247232
attribute.Int("size", len(q.roots)),
@@ -260,7 +245,7 @@ func (rs *retrievalSession) request(ctx context.Context) {
260245
"size", len(q.roots),
261246
)
262247
rs.span.AddEvent("quadrant request timeout", trace.WithAttributes(
263-
attribute.Int("axis", q.source),
248+
attribute.Int("axis", int(q.source)),
264249
attribute.Int("x", q.x),
265250
attribute.Int("y", q.y),
266251
attribute.Int("size", len(q.roots)),
@@ -292,10 +277,10 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
292277
// in the square.
293278
// NOTE-2: We never actually fetch shares from the network *twice*.
294279
// Once a share is downloaded from the network it is cached on the IPLD(blockservice) level.
295-
// calc index of the share
296-
idx := q.index(i, j)
280+
// calc position of the share
281+
x, y := q.pos(i, j)
297282
// try to lock the share
298-
ok := rs.sharesLks[idx].TryLock()
283+
ok := rs.squareCellsLks[x][y].TryLock()
299284
if !ok {
300285
// if already locked and written - do nothing
301286
return
@@ -312,14 +297,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
312297
if rs.isReconstructed() {
313298
return
314299
}
315-
rs.square[idx] = share
300+
if rs.square.GetCell(uint(x), uint(y)) != nil {
301+
return
302+
}
303+
rs.square.SetCell(uint(x), uint(y), share)
316304
// if we have >= 1/4 of the square we can start trying to Reconstruct
317305
// TODO(@Wondertan): This is not an ideal way to know when to start
318306
// reconstruction and can cause idle reconstruction tries in some cases,
319307
// but it is totally fine for the happy case and for now.
320308
// The earlier we correctly know that we have the full square - the earlier
321309
// we cancel ongoing requests - the less data is being wastedly transferred.
322-
if atomic.AddUint32(&rs.sharesCount, 1) >= uint32(size*size) {
310+
if atomic.AddUint32(&rs.squareCellsCount, 1) >= uint32(size*size) {
323311
select {
324312
case rs.squareSig <- struct{}{}:
325313
default:

Diff for: share/eds/retriever_quadrant.go

+17-38
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package eds
22

33
import (
4-
"math"
54
"math/rand"
65
"time"
76

87
"github.com/ipfs/go-cid"
98

109
"github.com/celestiaorg/celestia-app/pkg/da"
10+
"github.com/celestiaorg/rsmt2d"
1111

1212
"github.com/celestiaorg/celestia-node/share/ipld"
1313
)
@@ -42,10 +42,8 @@ type quadrant struct {
4242
// |(0;1)| |(1;1)|
4343
// ------ -------
4444
x, y int
45-
// source defines the axis for quadrant
46-
// it can be either 1 or 0 similar to x and y
47-
// where 0 is Row source and 1 is Col respectively
48-
source int
45+
// source defines the axis(Row or Col) to fetch the quadrant from
46+
source rsmt2d.Axis
4947
}
5048

5149
// newQuadrants constructs a slice of quadrants from DAHeader.
@@ -70,17 +68,13 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
7068
}
7169

7270
for i := range quadrants {
73-
// convert quadrant index into coordinates
71+
// convert quadrant 1D into into 2D coordinates
7472
x, y := i%2, i/2
75-
if source == 1 { // swap coordinates for column
76-
x, y = y, x
77-
}
78-
7973
quadrants[i] = &quadrant{
8074
roots: roots[qsize*y : qsize*(y+1)],
8175
x: x,
8276
y: y,
83-
source: source,
77+
source: rsmt2d.Axis(source),
8478
}
8579
}
8680
}
@@ -93,31 +87,16 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
9387
return quadrants
9488
}
9589

96-
// index calculates index for a share in a data square slice flattened by rows.
97-
//
98-
// NOTE: The complexity of the formula below comes from:
99-
// - Goal to avoid share copying
100-
// - Goal to make formula generic for both rows and cols
101-
// - While data square is flattened by rows only
102-
//
103-
// TODO(@Wondertan): This can be simplified by making rsmt2d working over 3D byte slice(not
104-
// flattened)
105-
func (q *quadrant) index(rootIdx, cellIdx int) int {
106-
size := len(q.roots)
107-
// half square offsets, e.g. share is from Q3,
108-
// so we add to index Q1+Q2
109-
halfSquareOffsetCol := pow(size*2, q.source)
110-
halfSquareOffsetRow := pow(size*2, q.source^1)
111-
// offsets for the axis, e.g. share is from Q4.
112-
// so we add to index Q3
113-
offsetX := q.x * halfSquareOffsetCol * size
114-
offsetY := q.y * halfSquareOffsetRow * size
115-
116-
rootIdx *= halfSquareOffsetRow
117-
cellIdx *= halfSquareOffsetCol
118-
return rootIdx + cellIdx + offsetX + offsetY
119-
}
120-
121-
func pow(x, y int) int {
122-
return int(math.Pow(float64(x), float64(y)))
90+
// pos calculates position of a share in a data square.
91+
func (q *quadrant) pos(rootIdx, cellIdx int) (int, int) {
92+
cellIdx += len(q.roots) * q.x
93+
rootIdx += len(q.roots) * q.y
94+
switch q.source {
95+
case rsmt2d.Row:
96+
return rootIdx, cellIdx
97+
case rsmt2d.Col:
98+
return cellIdx, rootIdx
99+
default:
100+
panic("unknown axis")
101+
}
123102
}

0 commit comments

Comments
 (0)