Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions pkg/preparation/blobs/blobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/ipfs/go-cid"
logging "github.com/ipfs/go-log/v2"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/multiformats/go-multihash"
"github.com/storacha/go-libstoracha/blobindex"
"github.com/storacha/go-ucanto/did"

Expand Down Expand Up @@ -337,13 +336,6 @@ func (a API) fastWriteShard(ctx context.Context, shardID id.ShardID, offset uint
return nil
}

nodes, err := a.Repo.NodesByShard(ctx, shardID, offset)
if err != nil {
log.Debug("Error getting nodes for shard:", err)

return err
}

nodeReader, err := a.OpenNodeReader()
if err != nil {
return fmt.Errorf("failed to open node reader for shard %s: %w", shardID, err)
Expand Down Expand Up @@ -376,12 +368,24 @@ func (a API) fastWriteShard(ctx context.Context, shardID id.ShardID, offset uint

go func() {
defer close(jobs)
for idx, node := range nodes {

idx := 0
for nis, err := range a.Repo.ForEachNodeInShard(ctx, shardID, offset) {
if err != nil {
// If we fail to iterate nodes, send the error and stop producing jobs.
select {
case <-ctx.Done():
return
case results <- result{err: fmt.Errorf("failed to iterate nodes in shard %s: %w", shardID, err)}:
}
return
}
select {
case <-ctx.Done():
return
case jobs <- job{idx: idx, node: node}:
case jobs <- job{idx: idx, node: nis.Node}:
}
idx++
}
}()

Expand Down Expand Up @@ -451,11 +455,15 @@ func (a API) fastWriteShard(ctx context.Context, shardID id.ShardID, offset uint
// node.
func (a API) makeErrBadNodes(ctx context.Context, shardID id.ShardID, nodeReader nodereader.NodeReader) error {
// Collect the nodes first, because we can't read data for each node while
// holding the lock on the database that ForEachNode has. This means holding a
// bunch of nodes in memory, but it's limited to the size of a shard.
nodes, err := a.Repo.NodesByShard(ctx, shardID, 0)
if err != nil {
return fmt.Errorf("failed to iterate over nodes in shard %s: %w", shardID, err)
// holding the lock on the database that ForEachNodeInShard has. This means
// holding a bunch of nodes in memory, but it's limited to the size of a
// shard.
var nodes []dagsmodel.Node
for nis, err := range a.Repo.ForEachNodeInShard(ctx, shardID, 0) {
if err != nil {
return fmt.Errorf("failed to iterate over nodes in shard %s: %w", shardID, err)
}
nodes = append(nodes, nis.Node)
}

var errs []types.BadNodeError
Expand Down Expand Up @@ -499,16 +507,15 @@ func (a API) ReaderForIndex(ctx context.Context, indexID id.IndexID) (io.ReadClo
// Query all nodes across all shards in this index in a single batch query
nodeCount := 0
log.Infow("building index", "index", indexID)
err = a.Repo.ForEachNodeInIndex(ctx, indexID, func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
for nii, err := range a.Repo.ForEachNodeInIndex(ctx, indexID) {
if err != nil {
return nil, fmt.Errorf("iterating nodes in index %s: %w", indexID, err)
}
nodeCount++
if nodeCount%10000 == 0 {
log.Infow("building index", "index", indexID, "nodes", nodeCount)
}
indexView.SetSlice(shardDigest, nodeCID.Hash(), blobindex.Position{Offset: shardOffset, Length: nodeSize})
return nil
})
if err != nil {
return nil, fmt.Errorf("iterating nodes in index %s: %w", indexID, err)
indexView.SetSlice(nii.ShardDigest, nii.NodeCID.Hash(), blobindex.Position{Offset: nii.ShardOffset, Length: nii.NodeSize})
}
log.Infow("built index", "index", indexID, "nodes", nodeCount)

Expand Down
22 changes: 17 additions & 5 deletions pkg/preparation/blobs/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package blobs
import (
"context"
"io"
"iter"

"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
Expand All @@ -24,10 +25,9 @@ type Repo interface {
// NodesNotInShards returns CIDs of nodes for the given upload that are not yet assigned to shards.
NodesNotInShards(ctx context.Context, uploadID id.UploadID, spaceDID did.DID) ([]cid.Cid, error)
FindNodeByCIDAndSpaceDID(ctx context.Context, c cid.Cid, spaceDID did.DID) (dagsmodel.Node, error)
ForEachNode(ctx context.Context, shardID id.ShardID, yield func(node dagsmodel.Node, shardOffset uint64) error) error
// NodesByShard fetches all the nodes for a given shard, returned in the order
// they should appear in the shard.
NodesByShard(ctx context.Context, shardID id.ShardID, startOffset uint64) ([]dagsmodel.Node, error)
// ForEachNodeInShard iterates over all the nodes for a given shard, in the
// order they should appear in the shard.
ForEachNodeInShard(ctx context.Context, shardID id.ShardID, startOffset uint64) iter.Seq2[NodeInShard, error]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, likely a subjective preference: I'd find the names of the iterator functions more idiomatic if the "ForEach" prefix were dropped and the plural were used. I'd call them just NodesInShard and NodesInIndex. Personally, the resulting "for node = range for each node..." feels a bit awkward and redundant 🙂.

GetSpaceByDID(ctx context.Context, spaceDID did.DID) (*spacesmodel.Space, error)
DeleteShard(ctx context.Context, shardID id.ShardID) error

Expand All @@ -41,12 +41,24 @@ type Repo interface {
ShardsForIndex(ctx context.Context, indexID id.IndexID) ([]*model.Shard, error)
// ForEachNodeInIndex iterates over all nodes across all shards in an index,
// ordered by shard. This is a batch query that avoids per-shard round trips.
ForEachNodeInIndex(ctx context.Context, indexID id.IndexID, yield func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error) error
ForEachNodeInIndex(ctx context.Context, indexID id.IndexID) iter.Seq2[NodeInIndex, error]

// Upload methods
GetUploadByID(ctx context.Context, uploadID id.UploadID) (*uploadsmodel.Upload, error)
}

// NodeInShard is the element type yielded by Repo.ForEachNodeInShard: a node
// paired with its position in the shard being iterated.
type NodeInShard struct {
// Node is the DAG node record.
Node dagsmodel.Node
// ShardOffset is the node's offset within the shard (presumably in bytes —
// TODO confirm against the repo implementation).
ShardOffset uint64
}

// NodeInIndex is the element type yielded by Repo.ForEachNodeInIndex: one node
// from one of the index's shards, together with the information needed to
// locate it (which shard, and where in that shard).
type NodeInIndex struct {
// NodeCID is the CID of the node.
NodeCID cid.Cid
// NodeSize is the size of the node (presumably in bytes — TODO confirm).
NodeSize uint64
// ShardDigest is the multihash digest of the shard that contains the node.
ShardDigest multihash.Multihash
// ShardOffset is the node's offset within that shard.
ShardOffset uint64
}

// ShardEncoder is the interface for shard implementations.
type ShardEncoder interface {
// WriteHeader writes a header to the provided writer.
Expand Down
66 changes: 41 additions & 25 deletions pkg/preparation/sqlrepo/indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package sqlrepo
import (
"context"
"fmt"
"iter"

"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
"github.com/storacha/go-ucanto/core/invocation"
"github.com/storacha/guppy/pkg/preparation/blobs"
"github.com/storacha/guppy/pkg/preparation/blobs/model"
"github.com/storacha/guppy/pkg/preparation/sqlrepo/util"
"github.com/storacha/guppy/pkg/preparation/types/id"
Expand Down Expand Up @@ -350,39 +352,53 @@ func (r *Repo) ShardsForIndex(ctx context.Context, indexID id.IndexID) ([]*model
return shards, nil
}

func (r *Repo) ForEachNodeInIndex(ctx context.Context, indexID id.IndexID, yield func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error) error {
stmt, err := r.prepareStmt(ctx, `
func (r *Repo) ForEachNodeInIndex(ctx context.Context, indexID id.IndexID) iter.Seq2[blobs.NodeInIndex, error] {
return func(yield func(blobs.NodeInIndex, error) bool) {
stmt, err := r.prepareStmt(ctx, `
SELECT s.id, s.digest, nodes.cid, nodes.size, nu.shard_offset
FROM shards_in_indexes si
JOIN shards s ON s.id = si.shard_id
JOIN node_uploads nu ON nu.shard_id = si.shard_id
JOIN nodes ON nodes.cid = nu.node_cid AND nodes.space_did = nu.space_did
WHERE si.index_id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
rows, err := stmt.QueryContext(ctx, indexID)
if err != nil {
return fmt.Errorf("failed to query nodes in index %s: %w", indexID, err)
}
defer rows.Close()

for rows.Next() {
var shardID id.ShardID
var shardDigest multihash.Multihash
var nodeCID cid.Cid
var nodeSize uint64
var shardOffset uint64
if err := rows.Scan(&shardID, util.DbBytes(&shardDigest), util.DbCID(&nodeCID), &nodeSize, &shardOffset); err != nil {
return fmt.Errorf("failed to scan node row in index %s: %w", indexID, err)
if err != nil {
yield(blobs.NodeInIndex{}, fmt.Errorf("failed to prepare statement: %w", err))
return
Comment on lines +365 to +366
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: iterator functions should check what yield returns and only return if yield == false. That lets the caller decide whether the loop should end in case of error.

In this specific case it's not a huge deal because these errors will likely happen for every element, but I think checking the result of yield would be nice from a correctness standpoint.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

}
if len(shardDigest) == 0 {
return fmt.Errorf("failed to iterate nodes in index %s, because shard with ID %s has no digest set", indexID, shardID)
rows, err := stmt.QueryContext(ctx, indexID)
if err != nil {
yield(blobs.NodeInIndex{}, fmt.Errorf("failed to query nodes in index %s: %w", indexID, err))
return
}
if err := yield(shardDigest, nodeCID, nodeSize, shardOffset); err != nil {
return fmt.Errorf("failed to yield node %s in index %s: %w", nodeCID, indexID, err)
defer rows.Close()

for rows.Next() {
var shardID id.ShardID
var shardDigest multihash.Multihash
var nodeCID cid.Cid
var nodeSize uint64
var shardOffset uint64
if err := rows.Scan(&shardID, util.DbBytes(&shardDigest), util.DbCID(&nodeCID), &nodeSize, &shardOffset); err != nil {
yield(blobs.NodeInIndex{}, fmt.Errorf("failed to scan node row in index %s: %w", indexID, err))
return
}
if len(shardDigest) == 0 {
yield(blobs.NodeInIndex{}, fmt.Errorf("failed to iterate nodes in index %s, because shard with ID %s has no digest set", indexID, shardID))
return
}
if !yield(blobs.NodeInIndex{
NodeCID: nodeCID,
NodeSize: nodeSize,
ShardDigest: shardDigest,
ShardOffset: shardOffset,
}, nil) {
return
}
}
}

return rows.Err()
err = rows.Err()
if err != nil {
yield(blobs.NodeInIndex{}, fmt.Errorf("error iterating node rows in index %s: %w", indexID, err))
}
}
}
82 changes: 34 additions & 48 deletions pkg/preparation/sqlrepo/indexes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/multiformats/go-multihash"
"github.com/storacha/go-libstoracha/testutil"
"github.com/storacha/go-ucanto/did"
"github.com/storacha/guppy/pkg/preparation/blobs"
"github.com/storacha/guppy/pkg/preparation/internal/testdb"
"github.com/storacha/guppy/pkg/preparation/sqlrepo"
"github.com/storacha/guppy/pkg/preparation/types/id"
Expand Down Expand Up @@ -285,49 +286,35 @@ func TestForEachNodeInIndex(t *testing.T) {
err = repo.AddShardToIndex(t.Context(), index.ID(), shard2.ID())
require.NoError(t, err)

// Collect all yielded rows
type row struct {
shardDigest multihash.Multihash
nodeCID cid.Cid
nodeSize uint64
shardOffset uint64
}
var rows []row
err = repo.ForEachNodeInIndex(t.Context(), index.ID(), func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
rows = append(rows, row{shardDigest: shardDigest, nodeCID: nodeCID, nodeSize: nodeSize, shardOffset: shardOffset})
return nil
})
require.NoError(t, err)
require.Len(t, rows, 4)

// Build a lookup by nodeCID for easier assertions
byCID := map[string]row{}
for _, r := range rows {
byCID[r.nodeCID.String()] = r
byCID := map[string]blobs.NodeInIndex{}
for nii, err := range repo.ForEachNodeInIndex(t.Context(), index.ID()) {
require.NoError(t, err)
byCID[nii.NodeCID.String()] = nii
}
require.Len(t, byCID, 4)

// shard_offset = shard.size + offset at time of AddNodeToShard
// shard1: node1 offset=0+0=0, size becomes 100; node2 offset=100+0=100
// shard2: node3 offset=0+0=0, size becomes 300; node4 offset=300+0=300
r1 := byCID[nodeCID1.String()]
require.Equal(t, digest1, []byte(r1.shardDigest))
require.Equal(t, uint64(100), r1.nodeSize)
require.Equal(t, uint64(0), r1.shardOffset)
require.Equal(t, digest1, []byte(r1.ShardDigest))
require.Equal(t, uint64(100), r1.NodeSize)
require.Equal(t, uint64(0), r1.ShardOffset)

r2 := byCID[nodeCID2.String()]
require.Equal(t, digest1, []byte(r2.shardDigest))
require.Equal(t, uint64(200), r2.nodeSize)
require.Equal(t, uint64(100), r2.shardOffset)
require.Equal(t, digest1, []byte(r2.ShardDigest))
require.Equal(t, uint64(200), r2.NodeSize)
require.Equal(t, uint64(100), r2.ShardOffset)

r3 := byCID[nodeCID3.String()]
require.Equal(t, digest2, []byte(r3.shardDigest))
require.Equal(t, uint64(300), r3.nodeSize)
require.Equal(t, uint64(0), r3.shardOffset)
require.Equal(t, digest2, []byte(r3.ShardDigest))
require.Equal(t, uint64(300), r3.NodeSize)
require.Equal(t, uint64(0), r3.ShardOffset)

r4 := byCID[nodeCID4.String()]
require.Equal(t, digest2, []byte(r4.shardDigest))
require.Equal(t, uint64(400), r4.nodeSize)
require.Equal(t, uint64(300), r4.shardOffset)
require.Equal(t, digest2, []byte(r4.ShardDigest))
require.Equal(t, uint64(400), r4.NodeSize)
require.Equal(t, uint64(300), r4.ShardOffset)
})

t.Run("yields nothing for an empty index", func(t *testing.T) {
Expand All @@ -339,11 +326,9 @@ func TestForEachNodeInIndex(t *testing.T) {
require.NoError(t, err)

called := false
err = repo.ForEachNodeInIndex(t.Context(), index.ID(), func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
for range repo.ForEachNodeInIndex(t.Context(), index.ID()) {
called = true
return nil
})
require.NoError(t, err)
}
require.False(t, called)
})

Expand Down Expand Up @@ -401,21 +386,19 @@ func TestForEachNodeInIndex(t *testing.T) {

// ForEachNodeInIndex for index1 should only yield node1
var nodeCIDs1 []cid.Cid
err = repo.ForEachNodeInIndex(t.Context(), index1.ID(), func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
nodeCIDs1 = append(nodeCIDs1, nodeCID)
return nil
})
require.NoError(t, err)
for nii, err := range repo.ForEachNodeInIndex(t.Context(), index1.ID()) {
require.NoError(t, err)
nodeCIDs1 = append(nodeCIDs1, nii.NodeCID)
}
require.Len(t, nodeCIDs1, 1)
require.Equal(t, nodeCID1, nodeCIDs1[0])

// ForEachNodeInIndex for index2 should only yield node2
var nodeCIDs2 []cid.Cid
err = repo.ForEachNodeInIndex(t.Context(), index2.ID(), func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
nodeCIDs2 = append(nodeCIDs2, nodeCID)
return nil
})
require.NoError(t, err)
for nii, err := range repo.ForEachNodeInIndex(t.Context(), index2.ID()) {
require.NoError(t, err)
nodeCIDs2 = append(nodeCIDs2, nii.NodeCID)
}
require.Len(t, nodeCIDs2, 1)
require.Equal(t, nodeCID2, nodeCIDs2[0])
})
Expand Down Expand Up @@ -465,9 +448,12 @@ func TestForEachNodeInIndex(t *testing.T) {
err = repo.AddShardToIndex(t.Context(), index.ID(), shard2.ID())
require.NoError(t, err)

err = repo.ForEachNodeInIndex(t.Context(), index.ID(), func(shardDigest multihash.Multihash, nodeCID cid.Cid, nodeSize uint64, shardOffset uint64) error {
return nil
})
for _, nodeErr := range repo.ForEachNodeInIndex(t.Context(), index.ID()) {
if nodeErr != nil {
err = nodeErr
break
}
}
require.ErrorContains(t, err, fmt.Sprintf("failed to iterate nodes in index %s, because shard with ID %s has no digest set", index.ID(), shard2.ID()))
})
}
Loading