diff --git a/cmd/devp2p/internal/ethtest/snap.go b/cmd/devp2p/internal/ethtest/snap.go
index b8bad6477..131a22d8d 100644
--- a/cmd/devp2p/internal/ethtest/snap.go
+++ b/cmd/devp2p/internal/ethtest/snap.go
@@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
"golang.org/x/crypto/sha3"
@@ -938,10 +939,14 @@ func (s *Suite) snapGetTrieNodes(t *utesting.T, tc *trieNodesTest) error {
}
// write0 request
+ paths, err := rlp.EncodeToRawList(tc.paths)
+ if err != nil {
+ panic(err)
+ }
req := &snap.GetTrieNodesPacket{
ID: uint64(rand.Int63()),
Root: tc.root,
- Paths: tc.paths,
+ Paths: paths,
Bytes: tc.nBytes,
}
msg, err := conn.snapRequest(snap.GetTrieNodesMsg, req)
diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go
index 01eb0fd54..0db206b24 100644
--- a/cmd/devp2p/internal/ethtest/suite.go
+++ b/cmd/devp2p/internal/ethtest/suite.go
@@ -18,6 +18,7 @@ package ethtest
import (
"crypto/rand"
+ "fmt"
"math/big"
"reflect"
@@ -31,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256"
)
@@ -150,7 +152,11 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
if err != nil {
t.Fatalf("failed to get headers for given request: %v", err)
}
- if !headersMatch(expected, headers.BlockHeadersRequest) {
+ received, err := headers.List.Items()
+ if err != nil {
+ t.Fatalf("invalid headers received: %v", err)
+ }
+ if !headersMatch(expected, received) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
}
}
@@ -201,31 +207,23 @@ concurrently, with different request IDs.`)
}
// Wait for responses.
- headers1 := new(eth.BlockHeadersPacket)
- if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil {
- t.Fatalf("error reading block headers msg: %v", err)
- }
- if got, want := headers1.RequestId, req1.RequestId; got != want {
- t.Fatalf("unexpected request id in response: got %d, want %d", got, want)
- }
- headers2 := new(eth.BlockHeadersPacket)
- if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil {
- t.Fatalf("error reading block headers msg: %v", err)
- }
- if got, want := headers2.RequestId, req2.RequestId; got != want {
- t.Fatalf("unexpected request id in response: got %d, want %d", got, want)
+ // Note they can arrive in either order.
+ resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
+ if msg.RequestId != 111 && msg.RequestId != 222 {
+ t.Fatalf("response with unknown request ID: %v", msg.RequestId)
+ }
+ return msg.RequestId
+ })
+ if err != nil {
+ t.Fatal(err)
}
- // Check received headers for accuracy.
- if expected, err := s.chain.GetHeaders(req1); err != nil {
- t.Fatalf("failed to get expected headers for request 1: %v", err)
- } else if !headersMatch(expected, headers1.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1)
+ // Check if headers match.
+ if err := s.checkHeadersAgainstChain(req1, resp[111]); err != nil {
+ t.Fatal(err)
}
- if expected, err := s.chain.GetHeaders(req2); err != nil {
- t.Fatalf("failed to get expected headers for request 2: %v", err)
- } else if !headersMatch(expected, headers2.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2)
+ if err := s.checkHeadersAgainstChain(req2, resp[222]); err != nil {
+ t.Fatal(err)
}
}
@@ -259,7 +257,7 @@ same request ID. The node should handle the request by responding to both reques
Origin: eth.HashOrNumber{
Number: 33,
},
- Amount: 2,
+ Amount: 3,
},
}
@@ -271,33 +269,57 @@ same request ID. The node should handle the request by responding to both reques
t.Fatalf("failed to write to connection: %v", err)
}
- // Wait for the responses.
- headers1 := new(eth.BlockHeadersPacket)
- if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil {
- t.Fatalf("error reading from connection: %v", err)
- }
- if got, want := headers1.RequestId, request1.RequestId; got != want {
- t.Fatalf("unexpected request id: got %d, want %d", got, want)
+ // Wait for the responses. They can arrive in either order, and we can't tell them
+ // apart by their request ID, so use the number of headers instead.
+ resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
+ id := uint64(msg.List.Len())
+ if id != 2 && id != 3 {
+ t.Fatalf("invalid number of headers in response: %d", id)
+ }
+ return id
+ })
+ if err != nil {
+ t.Fatal(err)
}
- headers2 := new(eth.BlockHeadersPacket)
- if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil {
- t.Fatalf("error reading from connection: %v", err)
+
+ // Check if headers match.
+ if err := s.checkHeadersAgainstChain(request1, resp[2]); err != nil {
+ t.Fatal(err)
}
- if got, want := headers2.RequestId, request2.RequestId; got != want {
- t.Fatalf("unexpected request id: got %d, want %d", got, want)
+ if err := s.checkHeadersAgainstChain(request2, resp[3]); err != nil {
+ t.Fatal(err)
}
+}
- // Check if headers match.
- if expected, err := s.chain.GetHeaders(request1); err != nil {
- t.Fatalf("failed to get expected block headers: %v", err)
- } else if !headersMatch(expected, headers1.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1)
+func (s *Suite) checkHeadersAgainstChain(req *eth.GetBlockHeadersPacket, resp *eth.BlockHeadersPacket) error {
+ received2, err := resp.List.Items()
+ if err != nil {
+ return fmt.Errorf("invalid headers in response with request ID %v (%d items): %v", resp.RequestId, resp.List.Len(), err)
}
- if expected, err := s.chain.GetHeaders(request2); err != nil {
- t.Fatalf("failed to get expected block headers: %v", err)
- } else if !headersMatch(expected, headers2.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2)
+ if expected, err := s.chain.GetHeaders(req); err != nil {
+ return fmt.Errorf("test chain failed to get expected headers for request: %v", err)
+ } else if !headersMatch(expected, received2) {
+ return fmt.Errorf("header mismatch for request ID %v (%d items): \nexpected %v \ngot %v", resp.RequestId, resp.List.Len(), expected, resp)
}
+ return nil
+}
+
+// collectHeaderResponses waits for n BlockHeadersPacket messages on the given
+// connection. The messages are keyed in the result map by the 'identity' function.
+func collectHeaderResponses(conn *Conn, n int, identity func(*eth.BlockHeadersPacket) uint64) (map[uint64]*eth.BlockHeadersPacket, error) {
+ resp := make(map[uint64]*eth.BlockHeadersPacket, n)
+ for range n {
+ r := new(eth.BlockHeadersPacket)
+ if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil {
+ return resp, fmt.Errorf("read error: %v", err)
+ }
+ id := identity(r)
+ if resp[id] != nil {
+ return resp, fmt.Errorf("duplicate response %v", r)
+ }
+ resp[id] = r
+ }
+ return resp, nil
}
func (s *Suite) TestZeroRequestID(t *utesting.T) {
@@ -329,10 +351,8 @@ and expects a response.`)
if got, want := headers.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id")
}
- if expected, err := s.chain.GetHeaders(req); err != nil {
- t.Fatalf("failed to get expected block headers: %v", err)
- } else if !headersMatch(expected, headers.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
+ if err := s.checkHeadersAgainstChain(req, headers); err != nil {
+ t.Fatal(err)
}
}
@@ -366,9 +386,52 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) {
if got, want := resp.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in respond", got, want)
}
- bodies := resp.BlockBodiesResponse
- if len(bodies) != len(req.GetBlockBodiesRequest) {
- t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), len(bodies))
+ if resp.List.Len() != len(req.GetBlockBodiesRequest) {
+ t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), resp.List.Len())
+ }
+}
+
+func (s *Suite) TestGetReceipts(t *utesting.T) {
+ t.Log(`This test sends GetReceipts requests to the node for known blocks in the test chain.`)
+ conn, err := s.dial()
+ if err != nil {
+ t.Fatalf("dial failed: %v", err)
+ }
+ defer conn.Close()
+ if err := conn.peer(s.chain, nil); err != nil {
+ t.Fatalf("peering failed: %v", err)
+ }
+
+ // Find some blocks containing receipts.
+ var hashes = make([]common.Hash, 0, 3)
+ for i := range s.chain.Len() {
+ block := s.chain.GetBlock(i)
+ if len(block.Transactions()) > 0 {
+ hashes = append(hashes, block.Hash())
+ }
+ if len(hashes) == cap(hashes) {
+ break
+ }
+ }
+
+ // Create receipts request.
+ req := ð.GetReceiptsPacket{
+ RequestId: 66,
+ GetReceiptsRequest: (eth.GetReceiptsRequest)(hashes),
+ }
+ if err := conn.Write(ethProto, eth.GetReceiptsMsg, req); err != nil {
+ t.Fatalf("could not write to connection: %v", err)
+ }
+ // Wait for response.
+ resp := new(eth.ReceiptsPacket)
+ if err := conn.ReadMsg(ethProto, eth.ReceiptsMsg, &resp); err != nil {
+ t.Fatalf("error reading block bodies msg: %v", err)
+ }
+ if got, want := resp.RequestId, req.RequestId; got != want {
+ t.Fatalf("unexpected request id in respond", got, want)
+ }
+ if resp.List.Len() != len(req.GetReceiptsRequest) {
+ t.Fatalf("wrong receipts in response: expected %d receipts, got %d", len(req.GetReceiptsRequest), resp.List.Len())
}
}
@@ -675,7 +738,11 @@ on another peer connection using GetPooledTransactions.`)
if got, want := msg.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in response: got %d, want %d", got, want)
}
- for _, got := range msg.PooledTransactionsResponse {
+ responseTxs, err := msg.List.Items()
+ if err != nil {
+ t.Fatalf("invalid transactions in response: %v", err)
+ }
+ for _, got := range responseTxs {
if _, exists := set[got.Hash()]; !exists {
t.Fatalf("unexpected tx received: %v", got.Hash())
}
@@ -779,7 +846,7 @@ func (s *Suite) makeBlobTxs(count, blobs int, discriminator byte) (txs types.Tra
from, nonce := s.chain.GetSender(5)
for i := 0; i < count; i++ {
// Make blob data, max of 2 blobs per tx.
- blobdata := make([]byte, blobs%3)
+ blobdata := make([]byte, min(blobs, 2))
for i := range blobdata {
blobdata[i] = discriminator
blobs -= 1
@@ -851,7 +918,8 @@ func (s *Suite) TestBlobViolations(t *utesting.T) {
if err := conn.ReadMsg(ethProto, eth.GetPooledTransactionsMsg, req); err != nil {
t.Fatalf("reading pooled tx request failed: %v", err)
}
- resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: test.resp}
+ encTxs, _ := rlp.EncodeToRawList(test.resp)
+ resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
t.Fatalf("writing pooled tx response failed: %v", err)
}
diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go
index 80b5d8074..16f9a3ad8 100644
--- a/cmd/devp2p/internal/ethtest/transaction.go
+++ b/cmd/devp2p/internal/ethtest/transaction.go
@@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/rlp"
)
// sendTxs sends the given transactions to the node and
@@ -51,7 +52,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
return fmt.Errorf("peering failed: %v", err)
}
- if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket(txs)); err != nil {
+ encTxs, _ := rlp.EncodeToRawList(txs)
+ if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket{RawList: encTxs}); err != nil {
return fmt.Errorf("failed to write message to connection: %v", err)
}
@@ -68,7 +70,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
}
switch msg := msg.(type) {
case *eth.TransactionsPacket:
- for _, tx := range *msg {
+ txs, _ := msg.Items()
+ for _, tx := range txs {
got[tx.Hash()] = true
}
case *eth.NewPooledTransactionHashesPacket:
@@ -80,9 +83,10 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
+ encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{
- RequestId: msg.RequestId,
- BlockHeadersRequest: headers,
+ RequestId: msg.RequestId,
+ List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth wire msg: %s", pretty.Sdump(msg))
@@ -167,9 +171,10 @@ func (s *Suite) sendInvalidTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
+ encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{
- RequestId: msg.RequestId,
- BlockHeadersRequest: headers,
+ RequestId: msg.RequestId,
+ List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth message: %v", pretty.Sdump(msg))
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 2468e1a98..27dc47a72 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -266,10 +266,12 @@ func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int,
func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) {
blobs := eth.ServiceGetBlockBodiesQuery(dlp.chain, hashes)
- bodies := make([]*eth.BlockBody, len(blobs))
+ bodies := make([]*types.Body, len(blobs))
+ ethbodies := make([]eth.BlockBody, len(blobs))
for i, blob := range blobs {
- bodies[i] = new(eth.BlockBody)
+ bodies[i] = new(types.Body)
rlp.DecodeBytes(blob, bodies[i])
+ rlp.DecodeBytes(blob, ðbodies[i])
}
var (
txsHashes = make([]common.Hash, len(bodies))
@@ -285,9 +287,13 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et
Peer: dlp.id,
}
res := ð.Response{
- Req: req,
- Res: (*eth.BlockBodiesResponse)(&bodies),
- Meta: [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes},
+ Req: req,
+ Res: (*eth.BlockBodiesResponse)(ðbodies),
+ Meta: eth.BlockBodyHashes{
+ TransactionRoots: txsHashes,
+ UncleHashes: uncleHashes,
+ WithdrawalRoots: withdrawalHashes,
+ },
Time: 1,
Done: make(chan error, 1), // Ignore the returned status
}
@@ -303,14 +309,14 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et
func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) {
blobs := eth.ServiceGetReceiptsQuery(dlp.chain, hashes)
- receipts := make([][]*types.Receipt, len(blobs))
+ receipts := make([]types.Receipts, len(blobs))
for i, blob := range blobs {
rlp.DecodeBytes(blob, &receipts[i])
}
hasher := trie.NewStackTrie(nil)
hashes = make([]common.Hash, len(receipts))
for i, receipt := range receipts {
- hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher)
+ hashes[i] = types.DeriveSha(receipt, hasher)
}
req := ð.Request{
Peer: dlp.id,
@@ -335,14 +341,14 @@ func (dlp *downloadTesterPeer) ID() string {
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
-func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error {
// Create the request and service it
req := &snap.GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req)
@@ -361,7 +367,7 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi
// RequestStorageRanges fetches a batch of storage slots belonging to one or
// more accounts. If slots from only one account is requested, an origin marker
// may also be used to retrieve from there.
-func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
// Create the request and service it
req := &snap.GetStorageRangesPacket{
ID: id,
@@ -369,7 +375,7 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req)
@@ -386,25 +392,28 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash,
}
// RequestByteCodes fetches a batch of bytecodes by hash.
-func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
req := &snap.GetByteCodesPacket{
ID: id,
Hashes: hashes,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
codes := snap.ServiceGetByteCodesQuery(dlp.chain, req)
go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes)
return nil
}
-// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
-// a specific state trie.
-func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, paths []snap.TrieNodePathSet, bytes uint64) error {
+// RequestTrieNodes fetches a batch of account or storage trie nodes.
+func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []snap.TrieNodePathSet, bytes int) error {
+ encPaths, err := rlp.EncodeToRawList(paths)
+ if err != nil {
+ panic(err)
+ }
req := &snap.GetTrieNodesPacket{
ID: id,
Root: root,
- Paths: paths,
- Bytes: bytes,
+ Paths: encPaths,
+ Bytes: uint64(bytes),
}
nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now())
go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes)
diff --git a/eth/downloader/fetchers_concurrent_bodies.go b/eth/downloader/fetchers_concurrent_bodies.go
index 5105fda66..aa5f8e180 100644
--- a/eth/downloader/fetchers_concurrent_bodies.go
+++ b/eth/downloader/fetchers_concurrent_bodies.go
@@ -89,15 +89,14 @@ func (q *bodyQueue) request(peer *peerConnection, req *fetchRequest, resCh chan
// deliver is responsible for taking a generic response packet from the concurrent
// fetcher, unpacking the body data and delivering it to the downloader's queue.
func (q *bodyQueue) deliver(peer *peerConnection, packet *eth.Response) (int, error) {
- txs, uncles, withdrawals := packet.Res.(*eth.BlockBodiesResponse).Unpack()
- hashsets := packet.Meta.([][]common.Hash) // {txs hashes, uncle hashes, withdrawal hashes}
-
- accepted, err := q.queue.DeliverBodies(peer.id, txs, hashsets[0], uncles, hashsets[1], withdrawals, hashsets[2])
+ resp := packet.Res.(*eth.BlockBodiesResponse)
+ meta := packet.Meta.(eth.BlockBodyHashes)
+ accepted, err := q.queue.DeliverBodies(peer.id, meta, *resp)
switch {
- case err == nil && len(txs) == 0:
+ case err == nil && len(*resp) == 0:
peer.log.Trace("Requested bodies delivered")
case err == nil:
- peer.log.Trace("Delivered new batch of bodies", "count", len(txs), "accepted", accepted)
+ peer.log.Trace("Delivered new batch of bodies", "count", len(*resp), "accepted", accepted)
default:
peer.log.Debug("Failed to deliver retrieved bodies", "err", err)
}
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index fe08ff64c..e50c876a3 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -29,10 +29,9 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/prque"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/crypto/kzg4844"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
- "github.com/ethereum/go-ethereum/params"
)
const (
@@ -772,62 +771,53 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, hashes []comm
// DeliverBodies injects a block body retrieval response into the results queue.
// The method returns the number of blocks bodies accepted from the delivery and
// also wakes any threads waiting for data delivery.
-func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListHashes []common.Hash,
- uncleLists [][]*types.Header, uncleListHashes []common.Hash,
- withdrawalLists [][]*types.Withdrawal, withdrawalListHashes []common.Hash) (int, error) {
+func (q *queue) DeliverBodies(id string, hashes eth.BlockBodyHashes, bodies []eth.BlockBody) (int, error) {
q.lock.Lock()
defer q.lock.Unlock()
+ var txLists [][]*types.Transaction
+ var uncleLists [][]*types.Header
+ var withdrawalLists [][]*types.Withdrawal
+
validate := func(index int, header *types.Header) error {
- if txListHashes[index] != header.TxHash {
+ if hashes.TransactionRoots[index] != header.TxHash {
return errInvalidBody
}
- if uncleListHashes[index] != header.UncleHash {
+ if hashes.UncleHashes[index] != header.UncleHash {
return errInvalidBody
}
if header.WithdrawalsHash == nil {
// nil hash means that withdrawals should not be present in body
- if withdrawalLists[index] != nil {
+ if bodies[index].Withdrawals != nil {
return errInvalidBody
}
} else { // non-nil hash: body must have withdrawals
- if withdrawalLists[index] == nil {
+ if bodies[index].Withdrawals == nil {
return errInvalidBody
}
- if withdrawalListHashes[index] != *header.WithdrawalsHash {
+ if hashes.WithdrawalRoots[index] != *header.WithdrawalsHash {
return errInvalidBody
}
}
- // Blocks must have a number of blobs corresponding to the header gas usage,
- // and zero before the Cancun hardfork.
- var blobs int
- for _, tx := range txLists[index] {
- // Count the number of blobs to validate against the header's blobGasUsed
- blobs += len(tx.BlobHashes())
-
- // Validate the data blobs individually too
- if tx.Type() == types.BlobTxType {
- if len(tx.BlobHashes()) == 0 {
- return errInvalidBody
- }
- for _, hash := range tx.BlobHashes() {
- if !kzg4844.IsValidVersionedHash(hash[:]) {
- return errInvalidBody
- }
- }
- if tx.BlobTxSidecar() != nil {
- return errInvalidBody
- }
- }
+ // Decode the RLP-encoded lists; a malformed list invalidates the whole body.
+ txs, err := bodies[index].Transactions.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad transactions: %v", errInvalidBody, err)
}
- if header.BlobGasUsed != nil {
- if want := *header.BlobGasUsed / params.BlobTxBlobGasPerBlob; uint64(blobs) != want { // div because the header is surely good vs the body might be bloated
- return errInvalidBody
+ txLists = append(txLists, txs)
+ uncles, err := bodies[index].Uncles.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad uncles: %v", errInvalidBody, err)
+ }
+ uncleLists = append(uncleLists, uncles)
+ if bodies[index].Withdrawals != nil {
+ withdrawals, err := bodies[index].Withdrawals.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad withdrawals: %v", errInvalidBody, err)
}
+ withdrawalLists = append(withdrawalLists, withdrawals)
} else {
- if blobs != 0 {
- return errInvalidBody
- }
+ withdrawalLists = append(withdrawalLists, nil)
}
return nil
}
@@ -838,14 +828,15 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListH
result.Withdrawals = withdrawalLists[index]
result.SetBodyDone()
}
+ nresults := len(hashes.TransactionRoots)
return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool,
- bodyReqTimer, bodyInMeter, bodyDropMeter, len(txLists), validate, reconstruct)
+ bodyReqTimer, bodyInMeter, bodyDropMeter, nresults, validate, reconstruct)
}
// DeliverReceipts injects a receipt retrieval response into the results queue.
// The method returns the number of transaction receipts accepted from the delivery
// and also wakes any threads waiting for data delivery.
-func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt, receiptListHashes []common.Hash) (int, error) {
+func (q *queue) DeliverReceipts(id string, receiptList []types.Receipts, receiptListHashes []common.Hash) (int, error) {
q.lock.Lock()
defer q.lock.Unlock()
diff --git a/eth/downloader/queue_test.go b/eth/downloader/queue_test.go
index 857ac4813..255e1030b 100644
--- a/eth/downloader/queue_test.go
+++ b/eth/downloader/queue_test.go
@@ -30,8 +30,10 @@ import (
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@@ -323,26 +325,30 @@ func XTestDelivery(t *testing.T) {
emptyList []*types.Header
txset [][]*types.Transaction
uncleset [][]*types.Header
+ bodies []eth.BlockBody
)
numToSkip := rand.Intn(len(f.Headers))
for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] {
- txset = append(txset, world.getTransactions(hdr.Number.Uint64()))
+ txs := world.getTransactions(hdr.Number.Uint64())
+ txset = append(txset, txs)
uncleset = append(uncleset, emptyList)
+ txsList, _ := rlp.EncodeToRawList(txs)
+ bodies = append(bodies, eth.BlockBody{Transactions: txsList})
+ }
+ hashes := eth.BlockBodyHashes{
+ TransactionRoots: make([]common.Hash, len(txset)),
+ UncleHashes: make([]common.Hash, len(uncleset)),
+ WithdrawalRoots: make([]common.Hash, len(txset)),
}
- var (
- txsHashes = make([]common.Hash, len(txset))
- uncleHashes = make([]common.Hash, len(uncleset))
- )
hasher := trie.NewStackTrie(nil)
for i, txs := range txset {
- txsHashes[i] = types.DeriveSha(types.Transactions(txs), hasher)
+ hashes.TransactionRoots[i] = types.DeriveSha(types.Transactions(txs), hasher)
}
for i, uncles := range uncleset {
- uncleHashes[i] = types.CalcUncleHash(uncles)
+ hashes.UncleHashes[i] = types.CalcUncleHash(uncles)
}
time.Sleep(100 * time.Millisecond)
- _, err := q.DeliverBodies(peer.id, txset, txsHashes, uncleset, uncleHashes, nil, nil)
- if err != nil {
+ if _, err := q.DeliverBodies(peer.id, hashes, bodies); err != nil {
fmt.Printf("delivered %d bodies %v\n", len(txset), err)
}
} else {
@@ -358,14 +364,14 @@ func XTestDelivery(t *testing.T) {
for {
f, _, _ := q.ReserveReceipts(peer, rand.Intn(50))
if f != nil {
- var rcs [][]*types.Receipt
+ var rcs []types.Receipts
for _, hdr := range f.Headers {
rcs = append(rcs, world.getReceipts(hdr.Number.Uint64()))
}
hasher := trie.NewStackTrie(nil)
hashes := make([]common.Hash, len(rcs))
for i, receipt := range rcs {
- hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher)
+ hashes[i] = types.DeriveSha(receipt, hasher)
}
_, err := q.DeliverReceipts(peer.id, rcs, hashes)
if err != nil {
diff --git a/eth/fetcher/block_fetcher.go b/eth/fetcher/block_fetcher.go
index 603f6246c..b4cdcf9d5 100644
--- a/eth/fetcher/block_fetcher.go
+++ b/eth/fetcher/block_fetcher.go
@@ -541,7 +541,11 @@ func (f *BlockFetcher) loop() {
case res := <-resCh:
res.Done <- nil
// Ignoring withdrawals here, since the block fetcher is not used post-merge.
- txs, uncles, _ := res.Res.(*eth.BlockBodiesResponse).Unpack()
+ txs, uncles, _, err := res.Res.(*eth.BlockBodiesResponse).Unpack()
+ if err != nil {
+ f.dropPeer(peer)
+ return
+ }
f.FilterBodies(peer, txs, uncles, time.Now())
case <-timeout.C:
diff --git a/eth/fetcher/block_fetcher_test.go b/eth/fetcher/block_fetcher_test.go
index cb7cbaf79..4f4292b81 100644
--- a/eth/fetcher/block_fetcher_test.go
+++ b/eth/fetcher/block_fetcher_test.go
@@ -32,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/triedb"
)
@@ -234,21 +235,11 @@ func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*typ
// Create a function that returns blocks from the closure
return func(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) {
// Gather the block bodies to return
- transactions := make([][]*types.Transaction, 0, len(hashes))
- uncles := make([][]*types.Header, 0, len(hashes))
+ bodies := make([]eth.BlockBody, len(hashes))
- for _, hash := range hashes {
+ for i, hash := range hashes {
if block, ok := closure[hash]; ok {
- transactions = append(transactions, block.Transactions())
- uncles = append(uncles, block.Uncles())
- }
- }
- // Return on a new thread
- bodies := make([]*eth.BlockBody, len(transactions))
- for i, txs := range transactions {
- bodies[i] = ð.BlockBody{
- Transactions: txs,
- Uncles: uncles[i],
+ bodies[i] = encodeBody(block)
}
}
req := ð.Request{
@@ -267,6 +258,26 @@ func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*typ
}
}
+func encodeBody(b *types.Block) eth.BlockBody {
+ body := eth.BlockBody{
+ Transactions: encodeRL([]*types.Transaction(b.Transactions())),
+ Uncles: encodeRL(b.Uncles()),
+ }
+ if b.Withdrawals() != nil {
+ wd := encodeRL([]*types.Withdrawal(b.Withdrawals()))
+ body.Withdrawals = &wd
+ }
+ return body
+}
+
+func encodeRL[T any](slice []T) rlp.RawList[T] {
+ rl, err := rlp.EncodeToRawList(slice)
+ if err != nil {
+ panic(err)
+ }
+ return rl
+}
+
// verifyFetchingEvent verifies that one single event arrive on a fetching channel.
func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) {
t.Helper()
diff --git a/eth/handler_eth.go b/eth/handler_eth.go
index 8edf60e10..fbe096c29 100644
--- a/eth/handler_eth.go
+++ b/eth/handler_eth.go
@@ -71,15 +71,24 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
return h.txFetcher.Notify(peer.ID(), packet.Types, packet.Sizes, packet.Hashes)
case *eth.TransactionsPacket:
- for _, tx := range *packet {
- if tx.Type() == types.BlobTxType {
- return errors.New("disallowed broadcast blob transaction")
- }
+ txs, err := packet.Items()
+ if err != nil {
+ return fmt.Errorf("Transactions: %v", err)
}
- return h.txFetcher.Enqueue(peer.ID(), *packet, false)
+ if err := handleTransactions(peer, txs, true); err != nil {
+ return fmt.Errorf("Transactions: %v", err)
+ }
+ return h.txFetcher.Enqueue(peer.ID(), txs, false)
- case *eth.PooledTransactionsResponse:
- return h.txFetcher.Enqueue(peer.ID(), *packet, true)
+ case *eth.PooledTransactionsPacket:
+ txs, err := packet.List.Items()
+ if err != nil {
+ return fmt.Errorf("PooledTransactions: %v", err)
+ }
+ if err := handleTransactions(peer, txs, false); err != nil {
+ return fmt.Errorf("PooledTransactions: %v", err)
+ }
+ return h.txFetcher.Enqueue(peer.ID(), txs, true)
default:
return fmt.Errorf("unexpected eth packet type: %T", packet)
@@ -137,3 +146,33 @@ func (h *ethHandler) handleBlockBroadcast(peer *eth.Peer, block *types.Block, td
}
return nil
}
+
+// handleTransactions marks all given transactions as known to the peer
+// and performs basic validations.
+func handleTransactions(peer *eth.Peer, list []*types.Transaction, directBroadcast bool) error {
+ seen := make(map[common.Hash]struct{})
+ for _, tx := range list {
+ if tx.Type() == types.BlobTxType {
+ if directBroadcast {
+ return errors.New("disallowed broadcast blob transaction")
+ } else {
+ // If we receive any blob transactions missing sidecars,
+ // disconnect from the sending peer.
+ if tx.BlobTxSidecar() == nil {
+ return errors.New("received sidecar-less blob transaction")
+ }
+ }
+ }
+
+ // Check for duplicates.
+ hash := tx.Hash()
+ if _, exists := seen[hash]; exists {
+ return fmt.Errorf("multiple copies of the same hash %v", hash)
+ }
+ seen[hash] = struct{}{}
+
+ // Mark as known.
+ peer.MarkTransaction(hash)
+ }
+ return nil
+}
diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go
index 964f8422e..72ddd0078 100644
--- a/eth/handler_eth_test.go
+++ b/eth/handler_eth_test.go
@@ -63,11 +63,19 @@ func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
return nil
case *eth.TransactionsPacket:
- h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ txs, err := packet.Items()
+ if err != nil {
+ return err
+ }
+ h.txBroadcasts.Send(txs)
return nil
- case *eth.PooledTransactionsResponse:
- h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ case *eth.PooledTransactionsPacket:
+ txs, err := packet.List.Items()
+ if err != nil {
+ return err
+ }
+ h.txBroadcasts.Send(txs)
return nil
default:
diff --git a/eth/protocols/eth/dispatcher.go b/eth/protocols/eth/dispatcher.go
index ae98820cd..eed377cf6 100644
--- a/eth/protocols/eth/dispatcher.go
+++ b/eth/protocols/eth/dispatcher.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
)
var (
@@ -47,9 +48,10 @@ type Request struct {
sink chan *Response // Channel to deliver the response on
cancel chan struct{} // Channel to cancel requests ahead of time
- code uint64 // Message code of the request packet
- want uint64 // Message code of the response packet
- data interface{} // Data content of the request packet
+ code uint64 // Message code of the request packet
+ want uint64 // Message code of the response packet
+ numItems int // Number of requested items
+ data interface{} // Data content of the request packet
Peer string // Demultiplexer if cross-peer requests are batched together
Sent time.Time // Timestamp when the request was sent
@@ -136,7 +138,7 @@ func (p *Peer) dispatchRequest(req *Request) error {
}
}
-// dispatchRequest fulfils a pending request and delivers it to the requested
+// dispatchResponse fulfils a pending request and delivers it to the requested
// sink.
func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) error {
resOp := &response{
@@ -188,20 +190,34 @@ func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) erro
func (p *Peer) dispatcher() {
pending := make(map[uint64]*Request)
+loop:
for {
select {
case reqOp := <-p.reqDispatch:
req := reqOp.req
req.Sent = time.Now()
- requestTracker.Track(p.id, p.version, req.code, req.want, req.id)
- err := p2p.Send(p.rw, req.code, req.data)
- reqOp.fail <- err
-
- if err == nil {
- pending[req.id] = req
+ // Register the request with the tracker before sending it on the
+ // wire. This ensures any incoming response can be matched even if
+ // it arrives before the send call returns.
+ treq := tracker.Request{
+ ID: req.id,
+ ReqCode: req.code,
+ RespCode: req.want,
+ Size: req.numItems,
+ }
+ if err := p.tracker.Track(treq); err != nil {
+ reqOp.fail <- err
+ continue loop
+ }
+ if err := p2p.Send(p.rw, req.code, req.data); err != nil {
+ reqOp.fail <- err
+ continue loop
}
+ pending[req.id] = req
+ reqOp.fail <- nil
+
case cancelOp := <-p.reqCancel:
// Retrieve the pending request to cancel and short circuit if it
// has already been serviced and is not available anymore
@@ -218,9 +234,6 @@ func (p *Peer) dispatcher() {
res := resOp.res
res.Req = pending[res.id]
- // Independent if the request exists or not, track this packet
- requestTracker.Fulfil(p.id, p.version, res.code, res.id)
-
switch {
case res.Req == nil:
// Response arrived with an untracked ID. Since even cancelled
@@ -247,6 +260,7 @@ func (p *Peer) dispatcher() {
}
case <-p.term:
+ p.tracker.Stop()
return
}
}
diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go
index 35c99d26a..45bbae3d7 100644
--- a/eth/protocols/eth/handler.go
+++ b/eth/protocols/eth/handler.go
@@ -194,7 +194,12 @@ func handleMessage(backend Backend, peer *Peer) error {
}
defer msg.Discard()
- var handlers = eth68
+ var handlers map[uint64]msgHandler
+ if peer.version == ETH68 {
+ handlers = eth68
+ } else {
+ return fmt.Errorf("unknown eth protocol version: %v", peer.version)
+ }
// Track the amount of time it takes to serve the request and run the handler
if metrics.Enabled() {
diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go
index 307b13d38..4c43e052e 100644
--- a/eth/protocols/eth/handler_test.go
+++ b/eth/protocols/eth/handler_test.go
@@ -20,8 +20,10 @@ import (
"math"
"math/big"
"math/rand"
+ "reflect"
"testing"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/consensus/beacon"
@@ -37,6 +39,8 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
)
var (
@@ -339,8 +343,8 @@ func testGetBlockHeaders(t *testing.T, protocol uint) {
GetBlockHeadersRequest: tt.query,
})
if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, &BlockHeadersPacket{
- RequestId: 123,
- BlockHeadersRequest: headers,
+ RequestId: 123,
+ List: encodeRL(headers),
}); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
@@ -353,7 +357,7 @@ func testGetBlockHeaders(t *testing.T, protocol uint) {
RequestId: 456,
GetBlockHeadersRequest: tt.query,
})
- expected := &BlockHeadersPacket{RequestId: 456, BlockHeadersRequest: headers}
+ expected := &BlockHeadersPacket{RequestId: 456, List: encodeRL(headers)}
if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, expected); err != nil {
t.Errorf("test %d by hash: headers mismatch: %v", i, err)
}
@@ -417,7 +421,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
// Collect the hashes to request, and the response to expect
var (
hashes []common.Hash
- bodies []*BlockBody
+ bodies []BlockBody
seen = make(map[int64]bool)
)
for j := 0; j < tt.random; j++ {
@@ -429,7 +433,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
block := backend.chain.GetBlockByNumber(uint64(num))
hashes = append(hashes, block.Hash())
if len(bodies) < tt.expected {
- bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()})
+ bodies = append(bodies, encodeBody(block))
}
break
}
@@ -439,7 +443,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
hashes = append(hashes, hash)
if tt.available[j] && len(bodies) < tt.expected {
block := backend.chain.GetBlockByHash(hash)
- bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()})
+ bodies = append(bodies, encodeBody(block))
}
}
@@ -449,14 +453,65 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
GetBlockBodiesRequest: hashes,
})
if err := p2p.ExpectMsg(peer.app, BlockBodiesMsg, &BlockBodiesPacket{
- RequestId: 123,
- BlockBodiesResponse: bodies,
+ RequestId: 123,
+ List: encodeRL(bodies),
}); err != nil {
t.Fatalf("test %d: bodies mismatch: %v", i, err)
}
}
}
+func encodeBody(b *types.Block) BlockBody {
+ body := BlockBody{
+ Transactions: encodeRL([]*types.Transaction(b.Transactions())),
+ Uncles: encodeRL(b.Uncles()),
+ }
+ if b.Withdrawals() != nil {
+ wd := encodeRL([]*types.Withdrawal(b.Withdrawals()))
+ body.Withdrawals = &wd
+ }
+ return body
+}
+
+func TestHashBody(t *testing.T) {
+ key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+ signer := types.NewCancunSigner(big.NewInt(1))
+
+ // create block 1
+ header := &types.Header{Number: big.NewInt(11)}
+ txs := []*types.Transaction{
+ types.MustSignNewTx(key, signer, &types.DynamicFeeTx{
+ ChainID: big.NewInt(1),
+ Nonce: 1,
+ Data: []byte("testing"),
+ }),
+ types.MustSignNewTx(key, signer, &types.LegacyTx{
+ Nonce: 2,
+ Data: []byte("testing"),
+ }),
+ }
+ uncles := []*types.Header{{Number: big.NewInt(10)}}
+ block1 := types.NewBlockWithWithdrawals(header, txs, uncles, nil, nil, trie.NewStackTrie(nil))
+
+ // create block 2 (has withdrawals)
+ header2 := &types.Header{Number: big.NewInt(12)}
+ block2 := types.NewBlockWithWithdrawals(header2, nil, nil, nil, []*types.Withdrawal{{Index: 10}, {Index: 11}}, trie.NewStackTrie(nil))
+
+ expectedHashes := BlockBodyHashes{
+ TransactionRoots: []common.Hash{block1.TxHash(), block2.TxHash()},
+ WithdrawalRoots: []common.Hash{{}, *block2.Header().WithdrawalsHash},
+ UncleHashes: []common.Hash{block1.UncleHash(), block2.UncleHash()},
+ }
+
+ // compute hash like protocol handler does
+ protocolBodies := []BlockBody{encodeBody(block1), encodeBody(block2)}
+ hashes := hashBodyParts(protocolBodies)
+ if !reflect.DeepEqual(hashes, expectedHashes) {
+ t.Errorf("wrong hashes: %s", spew.Sdump(hashes))
+ t.Logf("expected: %s", spew.Sdump(expectedHashes))
+ }
+}
+
// Tests that the transaction receipts can be retrieved based on hashes.
func TestGetBlockReceipts68(t *testing.T) { testGetBlockReceipts(t, ETH68) }
@@ -508,13 +563,13 @@ func testGetBlockReceipts(t *testing.T, protocol uint) {
// Collect the hashes to request, and the response to expect
var (
hashes []common.Hash
- receipts [][]*types.Receipt
+ receipts rlp.RawList[rlp.RawList[*types.Receipt]]
)
for i := uint64(0); i <= backend.chain.CurrentBlock().Number.Uint64(); i++ {
block := backend.chain.GetBlockByNumber(i)
-
hashes = append(hashes, block.Hash())
- receipts = append(receipts, backend.chain.GetReceiptsByHash(block.Hash()))
+ trs := backend.chain.GetReceiptsByHash(block.Hash())
+ receipts.Append(encodeRL(trs))
}
// Send the hash request and verify the response
p2p.Send(peer.app, GetReceiptsMsg, &GetReceiptsPacket{
@@ -522,9 +577,17 @@ func testGetBlockReceipts(t *testing.T, protocol uint) {
GetReceiptsRequest: hashes,
})
if err := p2p.ExpectMsg(peer.app, ReceiptsMsg, &ReceiptsPacket{
- RequestId: 123,
- ReceiptsResponse: receipts,
+ RequestId: 123,
+ List: receipts,
}); err != nil {
t.Errorf("receipts mismatch: %v", err)
}
}
+
+func encodeRL[T any](slice []T) rlp.RawList[T] {
+ rl, err := rlp.EncodeToRawList(slice)
+ if err != nil {
+ panic(err)
+ }
+ return rl
+}
diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go
index 5d4f47e49..d693cbc5c 100644
--- a/eth/protocols/eth/handlers.go
+++ b/eth/protocols/eth/handlers.go
@@ -17,13 +17,17 @@
package eth
import (
+ "bytes"
"encoding/json"
"fmt"
+ "math"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@@ -326,9 +330,17 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockHeadersMsg, Size: res.List.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("BlockHeaders: %w", err)
+ }
+ headers, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("BlockHeaders: %w", err)
+ }
metadata := func() interface{} {
- hashes := make([]common.Hash, len(res.BlockHeadersRequest))
- for i, header := range res.BlockHeadersRequest {
+ hashes := make([]common.Hash, len(headers))
+ for i, header := range headers {
hashes[i] = header.Hash()
}
return hashes
@@ -336,7 +348,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
return peer.dispatchResponse(&Response{
id: res.RequestId,
code: BlockHeadersMsg,
- Res: &res.BlockHeadersRequest,
+ Res: (*BlockHeadersRequest)(&headers),
}, metadata)
}
@@ -346,47 +358,156 @@ func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- metadata := func() interface{} {
- var (
- txsHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- uncleHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- withdrawalHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- )
- hasher := trie.NewStackTrie(nil)
- for i, body := range res.BlockBodiesResponse {
- txsHashes[i] = types.DeriveSha(types.Transactions(body.Transactions), hasher)
- uncleHashes[i] = types.CalcUncleHash(body.Uncles)
- if body.Withdrawals != nil {
- withdrawalHashes[i] = types.DeriveSha(types.Withdrawals(body.Withdrawals), hasher)
- }
- }
- return [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes}
+
+ // Check against the request.
+ length := res.List.Len()
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockBodiesMsg, Size: length}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("BlockBodies: %w", err)
+ }
+
+ // Collect items and dispatch.
+ items, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("BlockBodies: %w", err)
}
+ metadata := func() any { return hashBodyParts(items) }
return peer.dispatchResponse(&Response{
id: res.RequestId,
code: BlockBodiesMsg,
- Res: &res.BlockBodiesResponse,
+ Res: (*BlockBodiesResponse)(&items),
}, metadata)
}
+// BlockBodyHashes contains the lists of block body part roots for a list of block bodies.
+type BlockBodyHashes struct {
+ TransactionRoots []common.Hash
+ WithdrawalRoots []common.Hash
+ UncleHashes []common.Hash
+}
+
+// hashBodyParts computes the MPT root hashes of the transactions and withdrawals,
+// and the uncle hash, for each block body. It operates directly on the raw RLP
+// bytes in each RawList, avoiding a full decode of the body contents.
+func hashBodyParts(items []BlockBody) BlockBodyHashes {
+ h := BlockBodyHashes{
+ TransactionRoots: make([]common.Hash, len(items)),
+ WithdrawalRoots: make([]common.Hash, len(items)),
+ UncleHashes: make([]common.Hash, len(items)),
+ }
+ hasher := trie.NewStackTrie(nil)
+ for i, body := range items {
+ // txs
+ txsList := newDerivableRawList(&body.Transactions, writeTxForHash)
+ h.TransactionRoots[i] = types.DeriveSha(txsList, hasher)
+ // uncles
+ if body.Uncles.Len() == 0 {
+ h.UncleHashes[i] = types.EmptyUncleHash
+ } else {
+ h.UncleHashes[i] = crypto.Keccak256Hash(body.Uncles.Bytes())
+ }
+ // withdrawals
+ if body.Withdrawals != nil {
+ wdlist := newDerivableRawList(body.Withdrawals, nil)
+ h.WithdrawalRoots[i] = types.DeriveSha(wdlist, hasher)
+ }
+ }
+ return h
+}
+
+// derivableRawList implements types.DerivableList for a serialized RLP list.
+type derivableRawList struct {
+ data []byte
+ offsets []uint32
+ write func([]byte, *bytes.Buffer)
+}
+
+// newDerivableRawList creates a derivableRawList from a RawList. The write
+// function transforms each raw element before it is written to the hash buffer;
+// pass nil to use the identity transform (write raw bytes as-is).
+func newDerivableRawList[T any](list *rlp.RawList[T], write func([]byte, *bytes.Buffer)) *derivableRawList {
+ dl := derivableRawList{data: list.Content(), write: write}
+ if dl.write == nil {
+ // default transform is identity
+ dl.write = func(b []byte, buf *bytes.Buffer) { buf.Write(b) }
+ }
+ // Assert to ensure 32-bit offsets are valid. This can never trigger
+ // unless a block body component is larger than 4GB.
+ if uint(len(dl.data)) > math.MaxUint32 {
+ panic("list data too big for derivableRawList")
+ }
+ it := list.ContentIterator()
+ dl.offsets = make([]uint32, list.Len())
+ for i := 0; it.Next(); i++ {
+ dl.offsets[i] = uint32(it.Offset())
+ }
+ return &dl
+}
+
+// Len returns the number of items in the list.
+func (dl *derivableRawList) Len() int {
+ return len(dl.offsets)
+}
+
+// EncodeIndex writes the i'th item to the buffer.
+func (dl *derivableRawList) EncodeIndex(i int, buf *bytes.Buffer) {
+ start := dl.offsets[i]
+ end := uint32(len(dl.data))
+ if i != len(dl.offsets)-1 {
+ end = dl.offsets[i+1]
+ }
+ dl.write(dl.data[start:end], buf)
+}
+
+// writeTxForHash changes a transaction in 'network encoding' into the format used for
+// the transactions MPT.
+func writeTxForHash(tx []byte, buf *bytes.Buffer) {
+ k, content, _, _ := rlp.Split(tx)
+ if k == rlp.List {
+ buf.Write(tx) // legacy tx
+ } else {
+ buf.Write(content) // typed tx
+ }
+}
+
func handleReceipts(backend Backend, msg Decoder, peer *Peer) error {
// A batch of receipts arrived to one of our previous requests
res := new(ReceiptsPacket)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: ReceiptsMsg, Size: res.List.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("Receipts: %w", err)
+ }
+
+ // Decode the two-level list: outer level is per-block, inner level is per-receipt.
+ blockLists, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("Receipts: %w", err)
+ }
+ decoded := make([]types.Receipts, len(blockLists))
+ for i := range blockLists {
+ items, err := blockLists[i].Items()
+ if err != nil {
+ return fmt.Errorf("Receipts: invalid list %d: %w", i, err)
+ }
+ decoded[i] = items
+ }
metadata := func() interface{} {
hasher := trie.NewStackTrie(nil)
- hashes := make([]common.Hash, len(res.ReceiptsResponse))
- for i, receipt := range res.ReceiptsResponse {
- hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher)
+ hashes := make([]common.Hash, len(decoded))
+ for i, receipts := range decoded {
+ hashes[i] = types.DeriveSha(receipts, hasher)
}
return hashes
}
+ enc := ReceiptsResponse(decoded)
return peer.dispatchResponse(&Response{
id: res.RequestId,
code: ReceiptsMsg,
- Res: &res.ReceiptsResponse,
+ Res: &enc,
}, metadata)
}
@@ -405,7 +526,7 @@ func handleNewPooledTransactionHashes(backend Backend, msg Decoder, peer *Peer)
}
// Schedule all the unknown hashes for retrieval
for _, hash := range ann.Hashes {
- peer.markTransaction(hash)
+ peer.MarkTransaction(hash)
}
return backend.Handle(peer, ann)
}
@@ -458,12 +579,8 @@ func handleTransactions(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(&txs); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- for i, tx := range txs {
- // Validate and mark the remote transaction
- if tx == nil {
- return fmt.Errorf("%w: transaction %d is nil", errDecode, i)
- }
- peer.markTransaction(tx.Hash())
+ if txs.Len() > maxTransactionAnnouncements {
+ return fmt.Errorf("too many transactions")
}
return backend.Handle(peer, &txs)
}
@@ -473,19 +590,20 @@ func handlePooledTransactions(backend Backend, msg Decoder, peer *Peer) error {
if !backend.AcceptTxs() {
return nil
}
- // Transactions can be processed, parse all of them and deliver to the pool
- var txs PooledTransactionsPacket
- if err := msg.Decode(&txs); err != nil {
+
+ // Check against request and decode.
+ var resp PooledTransactionsPacket
+ if err := msg.Decode(&resp); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- for i, tx := range txs.PooledTransactionsResponse {
- // Validate and mark the remote transaction
- if tx == nil {
- return fmt.Errorf("%w: transaction %d is nil", errDecode, i)
- }
- peer.markTransaction(tx.Hash())
+ tresp := tracker.Response{
+ ID: resp.RequestId,
+ MsgCode: PooledTransactionsMsg,
+ Size: resp.List.Len(),
+ }
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("PooledTransactions: %w", err)
}
- requestTracker.Fulfil(peer.id, peer.version, PooledTransactionsMsg, txs.RequestId)
- return backend.Handle(peer, &txs.PooledTransactionsResponse)
+ return backend.Handle(peer, &resp)
}
diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go
index d5c501a49..6f57f61e0 100644
--- a/eth/protocols/eth/peer.go
+++ b/eth/protocols/eth/peer.go
@@ -22,11 +22,13 @@ import (
"math/big"
"math/rand"
"sync"
+ "time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -68,11 +70,12 @@ func max(a, b int) int {
// Peer is a collection of relevant information we have about a `eth` peer.
type Peer struct {
+ *p2p.Peer // The embedded P2P package peer
+
id string // Unique ID for the peer, cached
- *p2p.Peer // The embedded P2P package peer
- rw p2p.MsgReadWriter // Input/output streams for snap
- version uint // Protocol version negotiated
+ rw p2p.MsgReadWriter // Input/output streams for snap
+ version uint // Protocol version negotiated
head common.Hash // Latest advertised head block hash
td *big.Int // Latest advertised head block total difficulty
@@ -86,9 +89,10 @@ type Peer struct {
txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
- reqDispatch chan *request // Dispatch channel to send requests and track then until fulfillment
- reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them
- resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them
+ tracker *tracker.Tracker // Per-peer request/response tracker with timeout enforcement
+ reqDispatch chan *request // Dispatch channel to send requests and track them until fulfillment
+ reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them
+ resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them
term chan struct{} // Termination channel to stop the broadcasters
lock sync.RWMutex // Mutex protecting the internal fields
@@ -99,8 +103,10 @@ type Peer struct {
// NewPeer creates a wrapper for a network connection and negotiated protocol
// version.
func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
+ id := p.ID().String()
peer := &Peer{
- id: p.ID().String(),
+ id: id,
Peer: p,
rw: rw,
version: version,
@@ -110,6 +116,7 @@ func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Pe
queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
txBroadcast: make(chan []common.Hash),
txAnnounce: make(chan []common.Hash),
+ tracker: tracker.New(cap, id, 5*time.Minute),
reqDispatch: make(chan *request),
reqCancel: make(chan *cancel),
resDispatch: make(chan *response),
@@ -177,9 +184,9 @@ func (p *Peer) markBlock(hash common.Hash) {
p.knownBlocks.Add(hash)
}
-// markTransaction marks a transaction as known for the peer, ensuring that it
+// MarkTransaction marks a transaction as known for the peer, ensuring that it
// will never be propagated to this particular peer.
-func (p *Peer) markTransaction(hash common.Hash) {
+func (p *Peer) MarkTransaction(hash common.Hash) {
// If we reached the memory allowance, drop a previously known transaction hash
p.knownTxs.Add(hash)
}
@@ -333,10 +340,11 @@ func (p *Peer) RequestOneHeader(hash common.Hash, sink chan *Response) (*Request
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: 1,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -360,10 +368,11 @@ func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, re
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: amount,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -387,10 +396,11 @@ func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, rever
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: amount,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -414,10 +424,11 @@ func (p *Peer) RequestBodies(hashes []common.Hash, sink chan *Response) (*Reques
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockBodiesMsg,
- want: BlockBodiesMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockBodiesMsg,
+ want: BlockBodiesMsg,
+ numItems: len(hashes),
data: &GetBlockBodiesPacket{
RequestId: id,
GetBlockBodiesRequest: hashes,
@@ -435,10 +446,11 @@ func (p *Peer) RequestReceipts(hashes []common.Hash, sink chan *Response) (*Requ
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetReceiptsMsg,
- want: ReceiptsMsg,
+ id: id,
+ sink: sink,
+ code: GetReceiptsMsg,
+ want: ReceiptsMsg,
+ numItems: len(hashes),
data: &GetReceiptsPacket{
RequestId: id,
GetReceiptsRequest: hashes,
@@ -455,7 +467,15 @@ func (p *Peer) RequestTxs(hashes []common.Hash) error {
p.Log().Debug("Fetching batch of transactions", "count", len(hashes))
id := rand.Uint64()
- requestTracker.Track(p.id, p.version, GetPooledTransactionsMsg, PooledTransactionsMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ID: id,
+ ReqCode: GetPooledTransactionsMsg,
+ RespCode: PooledTransactionsMsg,
+ Size: len(hashes),
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetPooledTransactionsMsg, &GetPooledTransactionsPacket{
RequestId: id,
GetPooledTransactionsRequest: hashes,
diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go
index 47e8d9724..d98d10ab5 100644
--- a/eth/protocols/eth/protocol.go
+++ b/eth/protocols/eth/protocol.go
@@ -48,6 +48,9 @@ var protocolLengths = map[uint]uint64{ETH68: 17}
// maxMessageSize is the maximum cap on the size of a protocol message.
const maxMessageSize = 10 * 1024 * 1024
+// This is the maximum number of transactions in a Transactions message.
+const maxTransactionAnnouncements = 5000
+
const (
StatusMsg = 0x00
NewBlockHashesMsg = 0x01
@@ -112,7 +115,9 @@ func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) {
}
// TransactionsPacket is the network packet for broadcasting new transactions.
-type TransactionsPacket []*types.Transaction
+type TransactionsPacket struct {
+ rlp.RawList[*types.Transaction]
+}
// GetBlockHeadersRequest represents a block header query.
type GetBlockHeadersRequest struct {
@@ -170,7 +175,7 @@ type BlockHeadersRequest []*types.Header
// BlockHeadersPacket represents a block header response over with request ID wrapping.
type BlockHeadersPacket struct {
RequestId uint64
- BlockHeadersRequest
+ List rlp.RawList[*types.Header]
}
// BlockHeadersRLPResponse represents a block header response, to use when we already
@@ -211,14 +216,11 @@ type GetBlockBodiesPacket struct {
GetBlockBodiesRequest
}
-// BlockBodiesResponse is the network packet for block content distribution.
-type BlockBodiesResponse []*BlockBody
-
// BlockBodiesPacket is the network packet for block content distribution with
// request ID wrapping.
type BlockBodiesPacket struct {
RequestId uint64
- BlockBodiesResponse
+ List rlp.RawList[BlockBody]
}
// BlockBodiesRLPResponse is used for replying to block body requests, in cases
@@ -232,26 +234,41 @@ type BlockBodiesRLPPacket struct {
BlockBodiesRLPResponse
}
-// BlockBody represents the data content of a single block.
-type BlockBody struct {
- Transactions []*types.Transaction // Transactions contained within a block
- Uncles []*types.Header // Uncles contained within a block
- Withdrawals []*types.Withdrawal `rlp:"optional"` // Withdrawals contained within a block
+// BlockBodiesResponse is the network packet for block content distribution.
+type BlockBodiesResponse []BlockBody
+
+// Unpack retrieves the transactions, uncles, and withdrawals from each block body.
+func (b *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal, error) {
+ txLists := make([][]*types.Transaction, len(*b))
+ uncleLists := make([][]*types.Header, len(*b))
+ withdrawalLists := make([][]*types.Withdrawal, len(*b))
+ for i, body := range *b {
+ txs, err := body.Transactions.Items()
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("body %d: bad transactions: %w", i, err)
+ }
+ txLists[i] = txs
+ uncles, err := body.Uncles.Items()
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("body %d: bad uncles: %w", i, err)
+ }
+ uncleLists[i] = uncles
+ if body.Withdrawals != nil {
+ withdrawals, err := body.Withdrawals.Items()
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("body %d: bad withdrawals: %w", i, err)
+ }
+ withdrawalLists[i] = withdrawals
+ }
+ }
+ return txLists, uncleLists, withdrawalLists, nil
}
-// Unpack retrieves the transactions and uncles from the range packet and returns
-// them in a split flat format that's more consistent with the internal data structures.
-func (p *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal) {
- // TODO(matt): add support for withdrawals to fetchers
- var (
- txset = make([][]*types.Transaction, len(*p))
- uncleset = make([][]*types.Header, len(*p))
- withdrawalset = make([][]*types.Withdrawal, len(*p))
- )
- for i, body := range *p {
- txset[i], uncleset[i], withdrawalset[i] = body.Transactions, body.Uncles, body.Withdrawals
- }
- return txset, uncleset, withdrawalset
+// BlockBody represents the data content of a single block.
+type BlockBody struct {
+ Transactions rlp.RawList[*types.Transaction] // Transactions contained within a block
+ Uncles rlp.RawList[*types.Header] // Uncles contained within a block
+ Withdrawals *rlp.RawList[*types.Withdrawal] `rlp:"optional"` // Withdrawals contained within a block
}
// GetReceiptsRequest represents a block receipts query.
@@ -264,13 +281,15 @@ type GetReceiptsPacket struct {
}
// ReceiptsResponse is the network packet for block receipts distribution.
-type ReceiptsResponse [][]*types.Receipt
+type ReceiptsResponse []types.Receipts
// ReceiptsPacket is the network packet for block receipts distribution with
-// request ID wrapping.
+// request ID wrapping. The outer list contains one entry per requested block;
+// each entry is itself a list of receipts for that block. Both levels are kept
+// as RawList to defer RLP decoding until the data is actually needed.
type ReceiptsPacket struct {
RequestId uint64
- ReceiptsResponse
+ List rlp.RawList[rlp.RawList[*types.Receipt]]
}
// ReceiptsRLPResponse is used for receipts, when we already have it encoded
@@ -305,7 +324,7 @@ type PooledTransactionsResponse []*types.Transaction
// with request ID wrapping.
type PooledTransactionsPacket struct {
RequestId uint64
- PooledTransactionsResponse
+ List rlp.RawList[*types.Transaction]
}
// PooledTransactionsRLPResponse is the network packet for transaction distribution, used
@@ -348,8 +367,8 @@ func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransac
func (*GetPooledTransactionsRequest) Name() string { return "GetPooledTransactions" }
func (*GetPooledTransactionsRequest) Kind() byte { return GetPooledTransactionsMsg }
-func (*PooledTransactionsResponse) Name() string { return "PooledTransactions" }
-func (*PooledTransactionsResponse) Kind() byte { return PooledTransactionsMsg }
+func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" }
+func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg }
func (*GetReceiptsRequest) Name() string { return "GetReceipts" }
func (*GetReceiptsRequest) Kind() byte { return GetReceiptsMsg }
diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go
index bc2545dea..1633f7dbf 100644
--- a/eth/protocols/eth/protocol_test.go
+++ b/eth/protocols/eth/protocol_test.go
@@ -78,34 +78,33 @@ func TestEmptyMessages(t *testing.T) {
for i, msg := range []interface{}{
// Headers
GetBlockHeadersPacket{1111, nil},
- BlockHeadersPacket{1111, nil},
// Bodies
GetBlockBodiesPacket{1111, nil},
- BlockBodiesPacket{1111, nil},
BlockBodiesRLPPacket{1111, nil},
// Receipts
GetReceiptsPacket{1111, nil},
- ReceiptsPacket{1111, nil},
// Transactions
GetPooledTransactionsPacket{1111, nil},
- PooledTransactionsPacket{1111, nil},
PooledTransactionsRLPPacket{1111, nil},
// Headers
- BlockHeadersPacket{1111, BlockHeadersRequest([]*types.Header{})},
+ BlockHeadersPacket{1111, encodeRL([]*types.Header{})},
// Bodies
GetBlockBodiesPacket{1111, GetBlockBodiesRequest([]common.Hash{})},
- BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{})},
+ BlockBodiesPacket{1111, encodeRL([]BlockBody{})},
BlockBodiesRLPPacket{1111, BlockBodiesRLPResponse([]rlp.RawValue{})},
// Receipts
GetReceiptsPacket{1111, GetReceiptsRequest([]common.Hash{})},
- ReceiptsPacket{1111, ReceiptsResponse([][]*types.Receipt{})},
+ ReceiptsPacket{1111, encodeRL([]rlp.RawList[*types.Receipt]{})},
// Transactions
GetPooledTransactionsPacket{1111, GetPooledTransactionsRequest([]common.Hash{})},
- PooledTransactionsPacket{1111, PooledTransactionsResponse([]*types.Transaction{})},
+ PooledTransactionsPacket{1111, encodeRL([]*types.Transaction{})},
PooledTransactionsRLPPacket{1111, PooledTransactionsRLPResponse([]rlp.RawValue{})},
} {
- if have, _ := rlp.EncodeToBytes(msg); !bytes.Equal(have, want) {
+ have, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ t.Errorf("test %d, type %T, error: %v", i, msg, err)
+ } else if !bytes.Equal(have, want) {
t.Errorf("test %d, type %T, have\n\t%x\nwant\n\t%x", i, msg, have, want)
}
}
@@ -116,7 +115,7 @@ func TestMessages(t *testing.T) {
// Some basic structs used during testing
var (
header *types.Header
- blockBody *BlockBody
+ blockBody BlockBody
blockBodyRlp rlp.RawValue
txs []*types.Transaction
txRlps []rlp.RawValue
@@ -150,9 +149,9 @@ func TestMessages(t *testing.T) {
}
}
// init the block body data, both object and rlp form
- blockBody = &BlockBody{
- Transactions: txs,
- Uncles: []*types.Header{header},
+ blockBody = BlockBody{
+ Transactions: encodeRL(txs),
+ Uncles: encodeRL([]*types.Header{header}),
}
blockBodyRlp, err = rlp.EncodeToBytes(blockBody)
if err != nil {
@@ -201,7 +200,7 @@ func TestMessages(t *testing.T) {
common.FromHex("ca820457c682270f050580"),
},
{
- BlockHeadersPacket{1111, BlockHeadersRequest{header}},
+ BlockHeadersPacket{1111, encodeRL([]*types.Header{header})},
common.FromHex("f90202820457f901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"),
},
{
@@ -209,7 +208,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{blockBody})},
+ BlockBodiesPacket{1111, encodeRL([]BlockBody{blockBody})},
common.FromHex("f902dc820457f902d6f902d3f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afbf901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"),
},
{ // Identical to non-rlp-shortcut version
@@ -221,7 +220,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- ReceiptsPacket{1111, ReceiptsResponse([][]*types.Receipt{receipts})},
+ ReceiptsPacket{1111, encodeRL([]rlp.RawList[*types.Receipt]{encodeRL(receipts)})},
common.FromHex("f90172820457f9016cf90169f901668001b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ff"),
},
{
@@ -233,7 +232,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- PooledTransactionsPacket{1111, PooledTransactionsResponse(txs)},
+ PooledTransactionsPacket{1111, encodeRL(txs)},
common.FromHex("f8d7820457f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afb"),
},
{
diff --git a/eth/protocols/eth/tracker.go b/eth/protocols/eth/tracker.go
deleted file mode 100644
index 324fd2283..000000000
--- a/eth/protocols/eth/tracker.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package eth
-
-import (
- "time"
-
- "github.com/ethereum/go-ethereum/p2p/tracker"
-)
-
-// requestTracker is a singleton tracker for eth/66 and newer request times.
-var requestTracker = tracker.New(ProtocolName, 5*time.Minute)
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
index 003e014f4..c4c0f131f 100644
--- a/eth/protocols/snap/handler.go
+++ b/eth/protocols/snap/handler.go
@@ -29,6 +29,8 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
)
@@ -97,6 +99,7 @@ func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
Length: protocolLengths[version],
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error {
+ defer peer.Close()
return Handle(backend, peer)
})
},
@@ -153,7 +156,6 @@ func HandleMessage(backend Backend, peer *Peer) error {
// Handle the message depending on its contents
switch {
case msg.Code == GetAccountRangeMsg:
- // Decode the account retrieval request
var req GetAccountRangePacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -169,23 +171,40 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == AccountRangeMsg:
- // A range of accounts arrived to one of our previous requests
- res := new(AccountRangePacket)
+ res := new(accountRangeInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+
+ // Check response validity.
+ if len := res.Proof.Len(); len > 128 {
+ return fmt.Errorf("AccountRange: invalid proof (length %d)", len)
+ }
+ tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return err
+ }
+
+ // Decode.
+ accounts, err := res.Accounts.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
+ }
+ proof, err := res.Proof.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid proof: %v", err)
+ }
+
// Ensure the range is monotonically increasing
- for i := 1; i < len(res.Accounts); i++ {
- if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
- return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
+ for i := 1; i < len(accounts); i++ {
+ if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 {
+ return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:])
}
}
- requestTracker.Fulfil(peer.id, peer.version, AccountRangeMsg, res.ID)
- return backend.Handle(peer, res)
+ return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof})
case msg.Code == GetStorageRangesMsg:
- // Decode the storage retrieval request
var req GetStorageRangesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -201,25 +220,42 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == StorageRangesMsg:
- // A range of storage slots arrived to one of our previous requests
- res := new(StorageRangesPacket)
+ res := new(storageRangesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+
+ // Check response validity.
+ if len := res.Proof.Len(); len > 128 {
+ return fmt.Errorf("StorageRanges: invalid proof (length %d)", len)
+ }
+ tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("StorageRanges: %w", err)
+ }
+
+ // Decode.
+ slotLists, err := res.Slots.Items()
+ if err != nil {
+ return fmt.Errorf("StorageRanges: invalid slot list: %v", err)
+ }
+ proof, err := res.Proof.Items()
+ if err != nil {
+ return fmt.Errorf("StorageRanges: invalid proof: %v", err)
+ }
+
// Ensure the ranges are monotonically increasing
- for i, slots := range res.Slots {
+ for i, slots := range slotLists {
for j := 1; j < len(slots); j++ {
if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
}
}
}
- requestTracker.Fulfil(peer.id, peer.version, StorageRangesMsg, res.ID)
- return backend.Handle(peer, res)
+ return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof})
case msg.Code == GetByteCodesMsg:
- // Decode bytecode retrieval request
var req GetByteCodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -234,17 +270,25 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == ByteCodesMsg:
- // A batch of byte codes arrived to one of our previous requests
- res := new(ByteCodesPacket)
+ res := new(byteCodesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- requestTracker.Fulfil(peer.id, peer.version, ByteCodesMsg, res.ID)
- return backend.Handle(peer, res)
+ length := res.Codes.Len()
+ tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("ByteCodes: %w", err)
+ }
+
+ codes, err := res.Codes.Items()
+ if err != nil {
+ return fmt.Errorf("ByteCodes: %w", err)
+ }
+
+ return backend.Handle(peer, &ByteCodesPacket{res.ID, codes})
case msg.Code == GetTrieNodesMsg:
- // Decode trie node retrieval request
var req GetTrieNodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -261,14 +305,21 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == TrieNodesMsg:
- // A batch of trie nodes arrived to one of our previous requests
- res := new(TrieNodesPacket)
+ res := new(trieNodesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- requestTracker.Fulfil(peer.id, peer.version, TrieNodesMsg, res.ID)
- return backend.Handle(peer, res)
+ tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("TrieNodes: %w", err)
+ }
+ nodes, err := res.Nodes.Items()
+ if err != nil {
+ return fmt.Errorf("TrieNodes: %w", err)
+ }
+
+ return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes})
default:
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
@@ -496,19 +547,29 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
snap := chain.Snapshots().Snapshot(req.Root)
// Retrieve trie nodes until the packet size limit is reached
var (
- nodes [][]byte
- bytes uint64
- loads int // Trie hash expansions to count database reads
+ outerIt = req.Paths.ContentIterator()
+ nodes [][]byte
+ bytes uint64
+ loads int // Trie hash expansions to count database reads
)
- for _, pathset := range req.Paths {
- switch len(pathset) {
+ for outerIt.Next() {
+ innerIt, err := rlp.NewListIterator(outerIt.Value())
+ if err != nil {
+ return nodes, err
+ }
+
+ switch innerIt.Count() {
case 0:
// Ensure we penalize invalid requests
return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
case 1:
// If we're only retrieving an account trie node, fetch it directly
- blob, resolved, err := accTrie.GetNode(pathset[0])
+ accKey := nextBytes(&innerIt)
+ if accKey == nil {
+ return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest)
+ }
+ blob, resolved, err := accTrie.GetNode(accKey)
loads += resolved // always account database reads, even for failures
if err != nil {
break
@@ -517,32 +578,41 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
bytes += uint64(len(blob))
default:
- var stRoot common.Hash
// Storage slots requested, open the storage trie and retrieve from there
+ accKey := nextBytes(&innerIt)
+ if accKey == nil {
+ return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest)
+ }
+ var stRoot common.Hash
if snap == nil {
// We don't have the requested state snapshotted yet (or it is stale),
// but can look up the account via the trie instead.
- account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0]))
+ account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey))
loads += 8 // We don't know the exact cost of lookup, this is an estimate
if err != nil || account == nil {
break
}
stRoot = account.Root
} else {
- account, err := snap.Account(common.BytesToHash(pathset[0]))
+ account, err := snap.Account(common.BytesToHash(accKey))
loads++ // always account database reads, even for failures
if err != nil || account == nil {
break
}
stRoot = common.BytesToHash(account.Root)
}
- id := trie.StorageTrieID(req.Root, common.BytesToHash(pathset[0]), stRoot)
+
+ id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot)
stTrie, err := trie.NewStateTrie(id, triedb)
loads++ // always account database reads, even for failures
if err != nil {
break
}
- for _, path := range pathset[1:] {
+ for innerIt.Next() {
+ path, _, err := rlp.SplitString(innerIt.Value())
+ if err != nil {
+ return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err)
+ }
blob, resolved, err := stTrie.GetNode(path)
loads += resolved // always account database reads, even for failures
if err != nil {
@@ -565,6 +635,17 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
return nodes, nil
}
+func nextBytes(it *rlp.Iterator) []byte {
+ if !it.Next() {
+ return nil
+ }
+ content, _, err := rlp.SplitString(it.Value())
+ if err != nil {
+ return nil
+ }
+ return content
+}
+
// NodeInfo represents a short summary of the `snap` sub-protocol metadata
// known about the host peer.
type NodeInfo struct{}
diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go
index c57931678..7b5c70146 100644
--- a/eth/protocols/snap/peer.go
+++ b/eth/protocols/snap/peer.go
@@ -17,9 +17,13 @@
package snap
import (
+ "time"
+
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
+ "github.com/ethereum/go-ethereum/rlp"
)
// Peer is a collection of relevant information we have about a `snap` peer.
@@ -29,6 +33,7 @@ type Peer struct {
*p2p.Peer // The embedded P2P package peer
rw p2p.MsgReadWriter // Input/output streams for snap
version uint // Protocol version negotiated
+ tracker *tracker.Tracker
logger log.Logger // Contextual logger with the peer id injected
}
@@ -36,22 +41,26 @@ type Peer struct {
// NewPeer creates a wrapper for a network connection and negotiated protocol
// version.
func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
id := p.ID().String()
return &Peer{
id: id,
Peer: p,
rw: rw,
version: version,
+ tracker: tracker.New(cap, id, 1*time.Minute),
logger: log.New("peer", id[:8]),
}
}
// NewFakePeer creates a fake snap peer without a backing p2p peer, for testing purposes.
func NewFakePeer(version uint, id string, rw p2p.MsgReadWriter) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
return &Peer{
id: id,
rw: rw,
version: version,
+ tracker: tracker.New(cap, id, 1*time.Minute),
logger: log.New("peer", id[:8]),
}
}
@@ -71,63 +80,102 @@ func (p *Peer) Log() log.Logger {
return p.logger
}
+// Close releases resources associated with the peer.
+func (p *Peer) Close() {
+ p.tracker.Stop()
+}
+
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
-func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
+func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes int) error {
p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetAccountRangeMsg, AccountRangeMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetAccountRangeMsg,
+ RespCode: AccountRangeMsg,
+ ID: id,
+ Size: 2 * bytes,
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestStorageRanges fetches a batch of storage slots belonging to one or more
// accounts. If slots from only one account is requested, an origin marker may also
// be used to retrieve from there.
-func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
if len(accounts) == 1 && origin != nil {
p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
} else {
p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
}
- requestTracker.Track(p.id, p.version, GetStorageRangesMsg, StorageRangesMsg, id)
+
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetStorageRangesMsg,
+ RespCode: StorageRangesMsg,
+ ID: id,
+ Size: 2 * bytes,
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
ID: id,
Root: root,
Accounts: accounts,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestByteCodes fetches a batch of bytecodes by hash.
-func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetByteCodesMsg, ByteCodesMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetByteCodesMsg,
+ RespCode: ByteCodesMsg,
+ ID: id,
+ Size: len(hashes), // ByteCodes is limited by the length of the hash list.
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
ID: id,
Hashes: hashes,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
-// a specific state trie.
-func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+// a specific state trie. The `count` is the total count of paths being requested.
+func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error {
p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetTrieNodesMsg, TrieNodesMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetTrieNodesMsg,
+ RespCode: TrieNodesMsg,
+ ID: id,
+ Size: count, // TrieNodes is limited by number of items.
+ })
+ if err != nil {
+ return err
+ }
+ encPaths, _ := rlp.EncodeToRawList(paths)
return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
ID: id,
Root: root,
- Paths: paths,
- Bytes: bytes,
+ Paths: encPaths,
+ Bytes: uint64(bytes),
})
}
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
index 0db206b08..25fe25822 100644
--- a/eth/protocols/snap/protocol.go
+++ b/eth/protocols/snap/protocol.go
@@ -78,6 +78,12 @@ type GetAccountRangePacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type accountRangeInput struct {
+ ID uint64 // ID of the request this is a response for
+ Accounts rlp.RawList[*AccountData] // List of consecutive accounts from the trie
+ Proof rlp.RawList[[]byte] // List of trie nodes proving the account range
+}
+
// AccountRangePacket represents an account query response.
type AccountRangePacket struct {
ID uint64 // ID of the request this is a response for
@@ -123,6 +129,12 @@ type GetStorageRangesPacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type storageRangesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Slots rlp.RawList[[]*StorageData] // Lists of consecutive storage slots for the requested accounts
+ Proof rlp.RawList[[]byte] // Merkle proofs for the *last* slot range, if it's incomplete
+}
+
// StorageRangesPacket represents a storage slot query response.
type StorageRangesPacket struct {
ID uint64 // ID of the request this is a response for
@@ -161,6 +173,11 @@ type GetByteCodesPacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type byteCodesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Codes rlp.RawList[[]byte] // Requested contract bytecodes
+}
+
// ByteCodesPacket represents a contract bytecode query response.
type ByteCodesPacket struct {
ID uint64 // ID of the request this is a response for
@@ -169,10 +186,10 @@ type ByteCodesPacket struct {
// GetTrieNodesPacket represents a state trie node query.
type GetTrieNodesPacket struct {
- ID uint64 // Request ID to match up responses with
- Root common.Hash // Root hash of the account trie to serve
- Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for
- Bytes uint64 // Soft limit at which to stop returning data
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Paths rlp.RawList[TrieNodePathSet] // Trie node hashes to retrieve the nodes for
+ Bytes uint64 // Soft limit at which to stop returning data
}
// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to
@@ -187,6 +204,11 @@ type GetTrieNodesPacket struct {
// that a slot is accessed before the account path is fully expanded.
type TrieNodePathSet [][]byte
+type trieNodesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Nodes rlp.RawList[[]byte] // Requested state trie nodes
+}
+
// TrieNodesPacket represents a state trie node query response.
type TrieNodesPacket struct {
ID uint64 // ID of the request this is a response for
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
index 050c2c4cc..0dbb1c886 100644
--- a/eth/protocols/snap/sync.go
+++ b/eth/protocols/snap/sync.go
@@ -413,19 +413,19 @@ type SyncPeer interface {
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
- RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error
+ RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error
// RequestStorageRanges fetches a batch of storage slots belonging to one or
// more accounts. If slots from only one account is requested, an origin marker
// may also be used to retrieve from there.
- RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error
+ RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error
// RequestByteCodes fetches a batch of bytecodes by hash.
- RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error
+ RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error
// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
// a specific state trie.
- RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error
+ RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error
// Log retrieves the peer's own contextual logger.
Log() log.Logger
@@ -1083,7 +1083,7 @@ func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *ac
if cap < minRequestSize { // Don't bother with peers below a bare minimum performance
cap = minRequestSize
}
- if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, uint64(cap)); err != nil {
+ if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, cap); err != nil {
peer.Log().Debug("Failed to request account range", "err", err)
s.scheduleRevertAccountRequest(req)
}
@@ -1340,7 +1340,7 @@ func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *st
if subtask != nil {
origin, limit = req.origin[:], req.limit[:]
}
- if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, uint64(cap)); err != nil {
+ if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, cap); err != nil {
log.Debug("Failed to request storage", "err", err)
s.scheduleRevertStorageRequest(req)
}
@@ -1473,7 +1473,7 @@ func (s *Syncer) assignTrienodeHealTasks(success chan *trienodeHealResponse, fai
defer s.pend.Done()
// Attempt to send the remote request and revert if it fails
- if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil {
+ if err := peer.RequestTrieNodes(reqid, root, len(paths), pathsets, maxRequestSize); err != nil {
log.Debug("Failed to request trienode healers", "err", err)
s.scheduleRevertTrienodeHealRequest(req)
}
diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go
index 7d893e454..673a76e8b 100644
--- a/eth/protocols/snap/sync_test.go
+++ b/eth/protocols/snap/sync_test.go
@@ -119,10 +119,10 @@ func BenchmarkHashing(b *testing.B) {
}
type (
- accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error
- storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error
- trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error
- codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error
+ accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error
+ storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error
+ trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error
+ codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max int) error
)
type testPeer struct {
@@ -182,21 +182,21 @@ Trienode requests: %d
`, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests, t.nTrienodeRequests)
}
-func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
+func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error {
t.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
t.nAccountRequests++
go t.accountRequestHandler(t, id, root, origin, limit, bytes)
return nil
}
-func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error {
t.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
t.nTrienodeRequests++
go t.trieRequestHandler(t, id, root, paths, bytes)
return nil
}
-func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
t.nStorageRequests++
if len(accounts) == 1 && origin != nil {
t.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
@@ -207,7 +207,7 @@ func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []
return nil
}
-func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
t.nBytecodeRequests++
t.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
go t.codeRequestHandler(t, id, hashes, bytes)
@@ -215,7 +215,7 @@ func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint6
}
// defaultTrieRequestHandler is a well-behaving handler for trie healing requests
-func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
// Pass the response
var nodes [][]byte
for _, pathset := range paths {
@@ -244,7 +244,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
}
// defaultAccountRequestHandler is a well-behaving handler for AccountRangeRequests
-func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
keys, vals, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
if err := t.remote.OnAccounts(t, id, keys, vals, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -254,8 +254,8 @@ func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, orig
return nil
}
-func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
- var size uint64
+func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap int) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
+ var size int
if limit == (common.Hash{}) {
limit = common.MaxHash
}
@@ -266,7 +266,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H
if bytes.Compare(origin[:], entry.k) <= 0 {
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
}
// If we've exceeded the request threshold, abort
if bytes.Compare(entry.k, limit[:]) >= 0 {
@@ -293,7 +293,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H
}
// defaultStorageRequestHandler is a well-behaving storage request handler
-func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) error {
+func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, bOrigin, bLimit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -302,7 +302,7 @@ func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Has
return nil
}
-func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes {
bytecodes = append(bytecodes, getCodeByHash(h))
@@ -314,8 +314,8 @@ func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
return nil
}
-func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
- var size uint64
+func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size int
for _, account := range accounts {
// The first account might start from a different origin and end sooner
var originHash common.Hash
@@ -341,7 +341,7 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm
}
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
if bytes.Compare(entry.k, limitHash[:]) >= 0 {
break
}
@@ -382,8 +382,8 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm
// createStorageRequestResponseAlwaysProve tests a cornercase, where the peer always
// supplies the proof for the last account, even if it is 'complete'.
-func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
- var size uint64
+func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size int
max = max * 3 / 4
var origin common.Hash
@@ -400,7 +400,7 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco
}
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
if size > max {
exit = true
}
@@ -440,34 +440,34 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco
}
// emptyRequestAccountRangeFn is a rejects AccountRangeRequests
-func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
t.remote.OnAccounts(t, requestId, nil, nil, nil)
return nil
}
-func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
return nil
}
-func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
t.remote.OnTrieNodes(t, requestId, nil)
return nil
}
-func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
return nil
}
-func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
t.remote.OnStorage(t, requestId, nil, nil, nil)
return nil
}
-func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return nil
}
-func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponseAlwaysProve(t, root, accounts, origin, limit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -482,7 +482,7 @@ func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.
// return nil
//}
-func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes {
// Send back the hashes
@@ -496,7 +496,7 @@ func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
return nil
}
-func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes[:1] {
bytecodes = append(bytecodes, getCodeByHash(h))
@@ -510,11 +510,11 @@ func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
}
// starvingStorageRequestHandler is somewhat well-behaving storage handler, but it caps the returned results to be very small
-func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return defaultStorageRequestHandler(t, requestId, root, accounts, origin, limit, 500)
}
-func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
return defaultAccountRequestHandler(t, requestId, root, origin, limit, 500)
}
@@ -522,7 +522,7 @@ func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Ha
// return defaultAccountRequestHandler(t, requestId-1, root, origin, 500)
//}
-func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
hashes, accounts, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
if len(proofs) > 0 {
proofs = proofs[1:]
@@ -536,7 +536,7 @@ func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Has
}
// corruptStorageRequestHandler doesn't provide good proofs
-func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, origin, limit, max)
if len(proofs) > 0 {
proofs = proofs[1:]
@@ -549,7 +549,7 @@ func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Has
return nil
}
-func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, _ := createStorageRequestResponse(t, root, accounts, origin, limit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, nil); err != nil {
t.logger.Info("remote error on delivery (as expected)", "error", err)
@@ -584,7 +584,7 @@ func testSyncBloatedProof(t *testing.T, scheme string) {
source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems
- source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
var (
proofs [][]byte
keys []common.Hash
@@ -1177,7 +1177,7 @@ func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) {
var counter int
syncer := setupSyncer(
nodeScheme,
- mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max int) error {
counter++
return cappedCodeRequestHandler(t, id, hashes, max)
}),
@@ -1444,7 +1444,7 @@ func testSyncWithUnevenStorage(t *testing.T, scheme string) {
source.accountValues = accounts
source.setStorageTries(storageTries)
source.storageValues = storageElems
- source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return defaultStorageRequestHandler(t, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode
}
return source
diff --git a/eth/protocols/snap/tracker.go b/eth/protocols/snap/tracker.go
deleted file mode 100644
index 2cf59cc23..000000000
--- a/eth/protocols/snap/tracker.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package snap
-
-import (
- "time"
-
- "github.com/ethereum/go-ethereum/p2p/tracker"
-)
-
-// requestTracker is a singleton tracker for request times.
-var requestTracker = tracker.New(ProtocolName, time.Minute)
diff --git a/p2p/tracker/tracker.go b/p2p/tracker/tracker.go
index 5b72eb2b8..17ba3688e 100644
--- a/p2p/tracker/tracker.go
+++ b/p2p/tracker/tracker.go
@@ -18,12 +18,14 @@ package tracker
import (
"container/list"
+ "errors"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p"
)
const (
@@ -42,28 +44,44 @@ const (
// maxTrackedPackets is a huge number to act as a failsafe on the number of
// pending requests the node will track. It should never be hit unless an
// attacker figures out a way to spin requests.
- maxTrackedPackets = 100000
+ maxTrackedPackets = 10000
)
-// request tracks sent network requests which have not yet received a response.
-type request struct {
- peer string
- version uint // Protocol version
+var (
+ ErrNoMatchingRequest = errors.New("no matching request")
+ ErrTooManyItems = errors.New("response is larger than request allows")
+ ErrCollision = errors.New("request ID collision")
+ ErrCodeMismatch = errors.New("wrong response code for request")
+ ErrLimitReached = errors.New("request limit reached")
+ ErrStopped = errors.New("tracker stopped")
+)
- reqCode uint64 // Protocol message code of the request
- resCode uint64 // Protocol message code of the expected response
+// Request tracks a sent network request which has not yet received a response.
+type Request struct {
+ ID uint64 // Request ID
+ Size int // Number/size of requested items
+ ReqCode uint64 // Protocol message code of the request
+ RespCode uint64 // Protocol message code of the expected response
time time.Time // Timestamp when the request was made
expire *list.Element // Expiration marker to untrack it
}
+type Response struct {
+ ID uint64 // Request ID of the response
+ MsgCode uint64 // Protocol message code
+ Size int // Number/size of items in response
+}
+
// Tracker is a pending network request tracker to measure how much time it takes
// a remote peer to respond.
type Tracker struct {
- protocol string // Protocol capability identifier for the metrics
- timeout time.Duration // Global timeout after which to drop a tracked packet
+ cap p2p.Cap // Protocol capability identifier for the metrics
+
+ peer string // Peer ID
+ timeout time.Duration // Global timeout after which to drop a tracked packet
- pending map[uint64]*request // Currently pending requests
+ pending map[uint64]*Request // Currently pending requests
expire *list.List // Linked list tracking the expiration order
wake *time.Timer // Timer tracking the expiration of the next item
@@ -72,52 +90,53 @@ type Tracker struct {
// New creates a new network request tracker to monitor how much time it takes to
// fill certain requests and how individual peers perform.
-func New(protocol string, timeout time.Duration) *Tracker {
+func New(cap p2p.Cap, peerID string, timeout time.Duration) *Tracker {
return &Tracker{
- protocol: protocol,
- timeout: timeout,
- pending: make(map[uint64]*request),
- expire: list.New(),
+ cap: cap,
+ peer: peerID,
+ timeout: timeout,
+ pending: make(map[uint64]*Request),
+ expire: list.New(),
}
}
// Track adds a network request to the tracker to wait for a response to arrive
// or until the request it cancelled or times out.
-func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) {
- if !metrics.Enabled() {
- return
- }
+func (t *Tracker) Track(req Request) error {
t.lock.Lock()
defer t.lock.Unlock()
+ if t.expire == nil {
+ return ErrStopped
+ }
+
// If there's a duplicate request, we've just random-collided (or more probably,
// we have a bug), report it. We could also add a metric, but we're not really
// expecting ourselves to be buggy, so a noisy warning should be enough.
- if _, ok := t.pending[id]; ok {
- log.Error("Network request id collision", "protocol", t.protocol, "version", version, "code", reqCode, "id", id)
- return
+ if _, ok := t.pending[req.ID]; ok {
+ log.Error("Network request id collision", "cap", t.cap, "code", req.ReqCode, "id", req.ID)
+ return ErrCollision
}
// If we have too many pending requests, bail out instead of leaking memory
if pending := len(t.pending); pending >= maxTrackedPackets {
- log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "protocol", t.protocol, "version", version, "code", reqCode)
- return
+ log.Error("Request tracker exceeded allowance", "pending", pending, "peer", t.peer, "cap", t.cap, "code", req.ReqCode)
+ return ErrLimitReached
}
+
// Id doesn't exist yet, start tracking it
- t.pending[id] = &request{
- peer: peer,
- version: version,
- reqCode: reqCode,
- resCode: resCode,
- time: time.Now(),
- expire: t.expire.PushBack(id),
+ req.time = time.Now()
+ req.expire = t.expire.PushBack(req.ID)
+ t.pending[req.ID] = &req
+
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Inc(1)
}
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, version, reqCode)
- metrics.GetOrRegisterGauge(g, nil).Inc(1)
// If we've just inserted the first item, start the expiration timer
if t.wake == nil {
t.wake = time.AfterFunc(t.timeout, t.clean)
}
+ return nil
}
// clean is called automatically when a preset time passes without a response
@@ -126,6 +145,10 @@ func (t *Tracker) clean() {
t.lock.Lock()
defer t.lock.Unlock()
+ if t.expire == nil {
+ return // Tracker was stopped.
+ }
+
// Expire anything within a certain threshold (might be no items at all if
// we raced with the delivery)
for t.expire.Len() > 0 {
@@ -142,64 +165,111 @@ func (t *Tracker) clean() {
t.expire.Remove(head)
delete(t.pending, id)
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterGauge(g, nil).Dec(1)
-
- m := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterMeter(m, nil).Mark(1)
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Dec(1)
+ t.lostMeter(req.ReqCode).Mark(1)
+ }
}
t.schedule()
}
-// schedule starts a timer to trigger on the expiration of the first network
-// packet.
+// schedule starts a timer to trigger on the expiration of the first network packet.
func (t *Tracker) schedule() {
if t.expire.Len() == 0 {
t.wake = nil
return
}
- t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean)
+ nextID := t.expire.Front().Value.(uint64)
+ nextTime := t.pending[nextID].time
+ t.wake = time.AfterFunc(time.Until(nextTime.Add(t.timeout)), t.clean)
}
-// Fulfil fills a pending request, if any is available, reporting on various metrics.
-func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) {
- if !metrics.Enabled() {
- return
+// Stop reclaims resources of the tracker.
+func (t *Tracker) Stop() {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ if t.wake != nil {
+ t.wake.Stop()
+ t.wake = nil
+ }
+ if metrics.Enabled() {
+ // Ensure metrics are decremented for pending requests.
+ counts := make(map[uint64]int64)
+ for _, req := range t.pending {
+ counts[req.ReqCode]++
+ }
+ for code, count := range counts {
+ t.trackedGauge(code).Dec(count)
+ }
}
+ clear(t.pending)
+ t.expire = nil
+}
+
+// Fulfil fills a pending request, if any is available, reporting on various metrics.
+func (t *Tracker) Fulfil(resp Response) error {
t.lock.Lock()
defer t.lock.Unlock()
// If it's a non existing request, track as stale response
- req, ok := t.pending[id]
+ req, ok := t.pending[resp.ID]
if !ok {
- m := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.protocol, version, code)
- metrics.GetOrRegisterMeter(m, nil).Mark(1)
- return
+ if metrics.Enabled() {
+ t.staleMeter(resp.MsgCode).Mark(1)
+ }
+ return ErrNoMatchingRequest
}
+
// If the response is funky, it might be some active attack
- if req.peer != peer || req.version != version || req.resCode != code {
- log.Warn("Network response id collision",
- "have", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, version, code),
- "want", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, req.version, req.resCode),
+ if req.RespCode != resp.MsgCode {
+ log.Warn("Network response code collision",
+ "have", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, resp.MsgCode),
+ "want", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, req.RespCode),
)
- return
+ return ErrCodeMismatch
+ }
+ if resp.Size > req.Size {
+ return ErrTooManyItems
}
+
// Everything matches, mark the request serviced and meter it
+ wasHead := req.expire.Prev() == nil
t.expire.Remove(req.expire)
- delete(t.pending, id)
- if req.expire.Prev() == nil {
+ delete(t.pending, req.ID)
+ if wasHead {
if t.wake.Stop() {
t.schedule()
}
}
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterGauge(g, nil).Dec(1)
- h := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.protocol, req.version, req.reqCode)
+ // Update request metrics.
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Dec(1)
+ t.waitHistogram(req.ReqCode).Update(time.Since(req.time).Microseconds())
+ }
+ return nil
+}
+
+func (t *Tracker) trackedGauge(code uint64) *metrics.Gauge {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterGauge(name, nil)
+}
+
+func (t *Tracker) lostMeter(code uint64) *metrics.Meter {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterMeter(name, nil)
+}
+
+func (t *Tracker) staleMeter(code uint64) *metrics.Meter {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterMeter(name, nil)
+}
+
+func (t *Tracker) waitHistogram(code uint64) metrics.Histogram {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.cap.Name, t.cap.Version, code)
sampler := func() metrics.Sample {
- return metrics.ResettingSample(
- metrics.NewExpDecaySample(1028, 0.015),
- )
+ return metrics.ResettingSample(metrics.NewExpDecaySample(1028, 0.015))
}
- metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(req.time).Microseconds())
+ return metrics.GetOrRegisterHistogramLazy(name, nil, sampler)
}
diff --git a/p2p/tracker/tracker_test.go b/p2p/tracker/tracker_test.go
new file mode 100644
index 000000000..95e962964
--- /dev/null
+++ b/p2p/tracker/tracker_test.go
@@ -0,0 +1,83 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package tracker
+
+import (
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+// TestCleanAfterStop verifies that the clean method does not crash when called
+// after Stop. This can happen because clean is scheduled via time.AfterFunc and
+// may fire after Stop sets t.expire to nil.
+func TestCleanAfterStop(t *testing.T) {
+ cap := p2p.Cap{Name: "test", Version: 1}
+ timeout := 50 * time.Millisecond
+ tr := New(cap, "peer1", timeout)
+
+ // Track a request to start the expiration timer.
+ tr.Track(Request{ID: 1, ReqCode: 0x01, RespCode: 0x02, Size: 1})
+
+ // Stop the tracker, then wait for the timer to fire.
+ tr.Stop()
+ time.Sleep(timeout + 50*time.Millisecond)
+
+ // Also verify that calling clean directly after stop doesn't panic.
+ tr.clean()
+}
+
+// This checks that metrics gauges for pending requests are decremented when a
+// Tracker is stopped.
+func TestMetricsOnStop(t *testing.T) {
+ metrics.Enable()
+
+ cap := p2p.Cap{Name: "test", Version: 1}
+ tr := New(cap, "peer1", time.Minute)
+
+ // Track some requests with different ReqCodes.
+ var id uint64
+ for i := 0; i < 3; i++ {
+ tr.Track(Request{ID: id, ReqCode: 0x01, RespCode: 0x02, Size: 1})
+ id++
+ }
+ for i := 0; i < 5; i++ {
+ tr.Track(Request{ID: id, ReqCode: 0x03, RespCode: 0x04, Size: 1})
+ id++
+ }
+
+ gauge1 := tr.trackedGauge(0x01)
+ gauge2 := tr.trackedGauge(0x03)
+
+ if gauge1.Snapshot().Value() != 3 {
+ t.Fatalf("gauge1 value mismatch: got %d, want 3", gauge1.Snapshot().Value())
+ }
+ if gauge2.Snapshot().Value() != 5 {
+ t.Fatalf("gauge2 value mismatch: got %d, want 5", gauge2.Snapshot().Value())
+ }
+
+ tr.Stop()
+
+ if gauge1.Snapshot().Value() != 0 {
+ t.Fatalf("gauge1 value after stop: got %d, want 0", gauge1.Snapshot().Value())
+ }
+ if gauge2.Snapshot().Value() != 0 {
+ t.Fatalf("gauge2 value after stop: got %d, want 0", gauge2.Snapshot().Value())
+ }
+}
diff --git a/rlp/encode.go b/rlp/encode.go
index 3645bbfda..fa2a3008e 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -101,6 +101,29 @@ func EncodeToReader(val interface{}) (size int, r io.Reader, err error) {
return buf.size(), &encReader{buf: buf}, nil
}
+// EncodeToRawList encodes val as an RLP list and returns it as a RawList.
+func EncodeToRawList[T any](val []T) (RawList[T], error) {
+ if len(val) == 0 {
+ return RawList[T]{}, nil
+ }
+
+ // Encode the value to an internal buffer.
+ buf := getEncBuffer()
+ defer encBufferPool.Put(buf)
+ if err := buf.encode(val); err != nil {
+ return RawList[T]{}, err
+ }
+
+ // Create the RawList. RawList assumes the initial list header is padded to
+ // 9 bytes, so we have to determine the offset where the value should be
+ // placed.
+ contentSize := buf.lheads[0].size
+ bytes := make([]byte, contentSize+9)
+ offset := 9 - headsize(uint64(contentSize))
+ buf.copyTo(bytes[offset:])
+ return RawList[T]{enc: bytes, length: len(val)}, nil
+}
+
type listhead struct {
offset int // index of this header in string data
size int // total size of encoded data (including list headers)
diff --git a/rlp/iterator.go b/rlp/iterator.go
index 6be574572..9e41cec94 100644
--- a/rlp/iterator.go
+++ b/rlp/iterator.go
@@ -16,45 +16,74 @@
package rlp
-type listIterator struct {
- data []byte
- next []byte
- err error
+// Iterator is an iterator over the elements of an encoded container.
+type Iterator struct {
+ data []byte
+ next []byte
+ offset int
+ err error
}
-// NewListIterator creates an iterator for the (list) represented by data
-// TODO: Consider removing this implementation, as it is no longer used.
-func NewListIterator(data RawValue) (*listIterator, error) {
+// NewListIterator creates an iterator for the (list) represented by data.
+func NewListIterator(data RawValue) (Iterator, error) {
k, t, c, err := readKind(data)
if err != nil {
- return nil, err
+ return Iterator{}, err
}
if k != List {
- return nil, ErrExpectedList
- }
- it := &listIterator{
- data: data[t : t+c],
+ return Iterator{}, ErrExpectedList
}
+ it := newIterator(data[t:t+c], int(t))
return it, nil
}
-// Next forwards the iterator one step, returns true if it was not at end yet
-func (it *listIterator) Next() bool {
+func newIterator(data []byte, initialOffset int) Iterator {
+ return Iterator{data: data, offset: initialOffset}
+}
+
+// Next forwards the iterator one step.
+// Returns true if there is a next item or an error occurred on this step (check Err()).
+// On parse error, the iterator is marked finished and subsequent calls return false.
+func (it *Iterator) Next() bool {
if len(it.data) == 0 {
return false
}
_, t, c, err := readKind(it.data)
- it.next = it.data[:t+c]
- it.data = it.data[t+c:]
- it.err = err
+ if err != nil {
+ it.next = nil
+ it.err = err
+ // Mark iteration as finished to avoid potential infinite loops on subsequent Next calls.
+ it.data = nil
+ return true
+ }
+ length := t + c
+ it.next = it.data[:length]
+ it.data = it.data[length:]
+ it.offset += int(length)
+ it.err = nil
return true
}
-// Value returns the current value
-func (it *listIterator) Value() []byte {
+// Value returns the current value.
+func (it *Iterator) Value() []byte {
return it.next
}
-func (it *listIterator) Err() error {
+// Count returns the remaining number of items.
+// Note this is O(n) and the result may be incorrect if the list data is invalid.
+// The returned count is always an upper bound on the remaining items
+// that will be visited by the iterator.
+func (it *Iterator) Count() int {
+ count, _ := CountValues(it.data)
+ return count
+}
+
+// Offset returns the offset of the current value into the list data.
+func (it *Iterator) Offset() int {
+ return it.offset - len(it.next)
+}
+
+// Err returns the error that caused Next to return false, if any.
+func (it *Iterator) Err() error {
return it.err
}
diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go
index a22aaec86..275d4371c 100644
--- a/rlp/iterator_test.go
+++ b/rlp/iterator_test.go
@@ -17,6 +17,7 @@
package rlp
import (
+ "io"
"testing"
"github.com/ethereum/go-ethereum/common/hexutil"
@@ -38,14 +39,25 @@ func TestIterator(t *testing.T) {
t.Fatal("expected two elems, got zero")
}
txs := it.Value()
+ if offset := it.Offset(); offset != 3 {
+ t.Fatal("wrong offset", offset, "want 3")
+ }
+
// Check that uncles exist
if !it.Next() {
t.Fatal("expected two elems, got one")
}
+ if offset := it.Offset(); offset != 219 {
+ t.Fatal("wrong offset", offset, "want 219")
+ }
+
txit, err := NewListIterator(txs)
if err != nil {
t.Fatal(err)
}
+ if c := txit.Count(); c != 2 {
+ t.Fatal("wrong Count:", c)
+ }
var i = 0
for txit.Next() {
if txit.err != nil {
@@ -57,3 +69,65 @@ func TestIterator(t *testing.T) {
t.Errorf("count wrong, expected %d got %d", i, exp)
}
}
+
+func TestIteratorErrors(t *testing.T) {
+ tests := []struct {
+ input []byte
+ wantCount int // expected Count before iterating
+ wantErr error
+ }{
+ // Second item string header claims 3 bytes content, but only 2 remain.
+ {unhex("C4 01 83AABB"), 2, ErrValueTooLarge},
+ // Second item truncated: B9 requires 2 size bytes, none available.
+ {unhex("C2 01 B9"), 2, io.ErrUnexpectedEOF},
+ // 0x05 should be encoded directly, not as 81 05.
+ {unhex("C3 01 8105"), 2, ErrCanonSize},
+ // Long-form string header B8 used for 1-byte content (< 56).
+ {unhex("C4 01 B801AA"), 2, ErrCanonSize},
+ // Long-form list header F8 used for 1-byte content (< 56).
+ {unhex("C4 01 F80101"), 2, ErrCanonSize},
+ }
+ for _, tt := range tests {
+ it, err := NewListIterator(tt.input)
+ if err != nil {
+ t.Fatal("NewListIterator error:", err)
+ }
+ if c := it.Count(); c != tt.wantCount {
+ t.Fatalf("%x: Count = %d, want %d", tt.input, c, tt.wantCount)
+ }
+ n := 0
+ for it.Next() {
+ if it.Err() != nil {
+ break
+ }
+ n++
+ }
+ if wantN := tt.wantCount - 1; n != wantN {
+ t.Fatalf("%x: got %d valid items, want %d", tt.input, n, wantN)
+ }
+ if it.Err() != tt.wantErr {
+ t.Fatalf("%x: got error %v, want %v", tt.input, it.Err(), tt.wantErr)
+ }
+ if it.Next() {
+ t.Fatalf("%x: Next returned true after error", tt.input)
+ }
+ }
+}
+
+func FuzzIteratorCount(f *testing.F) {
+ examples := [][]byte{unhex("010203"), unhex("018142"), unhex("01830202")}
+ for _, e := range examples {
+ f.Add(e)
+ }
+ f.Fuzz(func(t *testing.T, in []byte) {
+ it := newIterator(in, 0)
+ count := it.Count()
+ i := 0
+ for it.Next() {
+ i++
+ }
+ if i != count {
+ t.Fatalf("%x: count %d not equal to %d iterations", in, count, i)
+ }
+ })
+}
diff --git a/rlp/raw.go b/rlp/raw.go
index 773aa7e61..2848b44d5 100644
--- a/rlp/raw.go
+++ b/rlp/raw.go
@@ -17,8 +17,10 @@
package rlp
import (
+ "fmt"
"io"
"reflect"
+ "slices"
)
// RawValue represents an encoded RLP value and can be used to delay
@@ -28,6 +30,144 @@ type RawValue []byte
var rawValueType = reflect.TypeOf(RawValue{})
+// RawList represents an encoded RLP list.
+type RawList[T any] struct {
+ // The list is stored in encoded form.
+ // Note this buffer has some special properties:
+ //
+ // - if the buffer is nil, it's the zero value, representing
+ // an empty list.
+ // - if the buffer is non-nil, it must have a length of at least
+ // 9 bytes, which is reserved padding for the encoded list header.
+ // The remaining bytes, enc[9:], store the content bytes of the list.
+ //
+ // The implementation code mostly works with the Content method because it
+ // returns something valid either way.
+ enc []byte
+
+ // length holds the number of items in the list.
+ length int
+}
+
+// Content returns the RLP-encoded data of the list.
+// This does not include the list-header.
+// The return value is a direct reference to the internal buffer, not a copy.
+func (r *RawList[T]) Content() []byte {
+ if r.enc == nil {
+ return nil
+ } else {
+ return r.enc[9:]
+ }
+}
+
+// EncodeRLP writes the encoded list to the writer.
+func (r RawList[T]) EncodeRLP(w io.Writer) error {
+ _, err := w.Write(r.Bytes())
+ return err
+}
+
+// Bytes returns the RLP encoding of the list.
+// Note the return value aliases the internal buffer.
+func (r *RawList[T]) Bytes() []byte {
+ if r == nil || r.enc == nil {
+ return []byte{0xC0} // zero value encodes as empty list
+ }
+ n := puthead(r.enc, 0xC0, 0xF7, uint64(len(r.Content())))
+ copy(r.enc[9-n:], r.enc[:n])
+ return r.enc[9-n:]
+}
+
+// DecodeRLP decodes the list. This does not perform validation of the items!
+func (r *RawList[T]) DecodeRLP(s *Stream) error {
+ k, size, err := s.Kind()
+ if err != nil {
+ return err
+ }
+ if k != List {
+ return fmt.Errorf("%w for %T", ErrExpectedList, r)
+ }
+ enc := make([]byte, 9+size)
+ if err := s.readFull(enc[9:]); err != nil {
+ return err
+ }
+ n, err := CountValues(enc[9:])
+ if err != nil {
+ if err == ErrValueTooLarge {
+ return ErrElemTooLarge
+ }
+ return err
+ }
+ *r = RawList[T]{enc: enc, length: n}
+ return nil
+}
+
+// Items decodes and returns all items in the list.
+func (r *RawList[T]) Items() ([]T, error) {
+ items := make([]T, r.Len())
+ it := r.ContentIterator()
+ for i := 0; it.Next(); i++ {
+ if err := DecodeBytes(it.Value(), &items[i]); err != nil {
+ return items[:i], err
+ }
+ }
+ return items, nil
+}
+
+// Len returns the number of items in the list.
+func (r *RawList[T]) Len() int {
+ return r.length
+}
+
+// Size returns the encoded size of the list.
+func (r *RawList[T]) Size() uint64 {
+ return ListSize(uint64(len(r.Content())))
+}
+
+// ContentIterator returns an iterator over the content of the list.
+// Note that offsets reported by Iterator.Offset are relative to the
+// Content bytes of the list, i.e. they do not include the list header.
+func (r *RawList[T]) ContentIterator() Iterator {
+ return newIterator(r.Content(), 0)
+}
+
+// Append adds an item to the end of the list.
+func (r *RawList[T]) Append(item T) error {
+ if r.enc == nil {
+ r.enc = make([]byte, 9)
+ }
+
+ eb := getEncBuffer()
+ defer encBufferPool.Put(eb)
+
+ if err := eb.encode(item); err != nil {
+ return err
+ }
+ prevEnd := len(r.enc)
+ end := prevEnd + eb.size()
+ r.enc = slices.Grow(r.enc, eb.size())[:end]
+ eb.copyTo(r.enc[prevEnd:end])
+ r.length++
+ return nil
+}
+
+// AppendRaw adds an encoded item to the list.
+// The given byte slice must contain exactly one RLP value.
+func (r *RawList[T]) AppendRaw(b []byte) error {
+ _, tagsize, contentsize, err := readKind(b)
+ if err != nil {
+ return err
+ }
+ if tagsize+contentsize != uint64(len(b)) {
+ return fmt.Errorf("rlp: input has trailing bytes in AppendRaw")
+ }
+ if r.enc == nil {
+ r.enc = make([]byte, 9)
+ }
+ r.enc = append(r.enc, b...)
+ r.length++
+ return nil
+}
+
// StringSize returns the encoded size of a string.
func StringSize(s string) uint64 {
switch {
@@ -143,7 +283,7 @@ func CountValues(b []byte) (int, error) {
for ; len(b) > 0; i++ {
_, tagsize, size, err := readKind(b)
if err != nil {
- return 0, err
+ return i + 1, err
}
b = b[tagsize+size:]
}
diff --git a/rlp/raw_test.go b/rlp/raw_test.go
index 7b3255eca..e80b42718 100644
--- a/rlp/raw_test.go
+++ b/rlp/raw_test.go
@@ -19,11 +19,259 @@ package rlp
import (
"bytes"
"errors"
+ "fmt"
"io"
+ "reflect"
"testing"
"testing/quick"
)
+type rawListTest[T any] struct {
+ input string
+ content string
+ items []T
+ length int
+}
+
+func (test rawListTest[T]) name() string {
+ return fmt.Sprintf("%T-%d", *new(T), test.length)
+}
+
+func (test rawListTest[T]) run(t *testing.T) {
+ // check decoding and properties
+ input := unhex(test.input)
+ inputSize := len(input)
+ var rl RawList[T]
+ if err := DecodeBytes(input, &rl); err != nil {
+ t.Fatal("decode failed:", err)
+ }
+ if l := rl.Len(); l != test.length {
+ t.Fatalf("wrong Len %d, want %d", l, test.length)
+ }
+ if sz := rl.Size(); sz != uint64(inputSize) {
+ t.Fatalf("wrong Size %d, want %d", sz, inputSize)
+ }
+ items, err := rl.Items()
+ if err != nil {
+ t.Fatal("Items failed:", err)
+ }
+ if !reflect.DeepEqual(items, test.items) {
+ t.Fatal("wrong items:", items)
+ }
+ if !bytes.Equal(rl.Content(), unhex(test.content)) {
+ t.Fatalf("wrong Content %x, want %s", rl.Content(), test.content)
+ }
+ if !bytes.Equal(rl.Bytes(), unhex(test.input)) {
+ t.Fatalf("wrong Bytes %x, want %s", rl.Bytes(), test.input)
+ }
+
+ // check iterator
+ it := rl.ContentIterator()
+ i := 0
+ for it.Next() {
+ var item T
+ if err := DecodeBytes(it.Value(), &item); err != nil {
+ t.Fatalf("item %d decode error: %v", i, err)
+ }
+ if !reflect.DeepEqual(item, items[i]) {
+ t.Fatalf("iterator has wrong item %v at %d", item, i)
+ }
+ i++
+ }
+ if i != test.length {
+ t.Fatalf("iterator produced %d values, want %d", i, test.length)
+ }
+ if it.Err() != nil {
+ t.Fatalf("iterator error: %v", it.Err())
+ }
+
+ // check encoding round trip
+ output, err := EncodeToBytes(&rl)
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ if !bytes.Equal(output, unhex(test.input)) {
+ t.Fatalf("encoding does not round trip: %x", output)
+ }
+
+ // check EncodeToRawList on items produces same bytes
+ encRL, err := EncodeToRawList(test.items)
+ if err != nil {
+ t.Fatal("EncodeToRawList error:", err)
+ }
+ encRLOutput, err := EncodeToBytes(&encRL)
+ if err != nil {
+ t.Fatal("EncodeToBytes of encoded list failed:", err)
+ }
+ if !bytes.Equal(encRLOutput, output) {
+ t.Fatalf("wrong encoding of EncodeToRawList result: %x", encRLOutput)
+ }
+}
+
+func TestRawList(t *testing.T) {
+ tests := []interface {
+ name() string
+ run(t *testing.T)
+ }{
+ rawListTest[uint64]{
+ input: "C0",
+ content: "",
+ items: []uint64{},
+ length: 0,
+ },
+ rawListTest[uint64]{
+ input: "C3010203",
+ content: "010203",
+ items: []uint64{1, 2, 3},
+ length: 3,
+ },
+ rawListTest[simplestruct]{
+ input: "C6C20102C20304",
+ content: "C20102C20304",
+ items: []simplestruct{{1, "\x02"}, {3, "\x04"}},
+ length: 2,
+ },
+ rawListTest[string]{
+ input: "F83C836161618362626283636363836464648365656583666666836767678368686883696969836A6A6A836B6B6B836C6C6C836D6D6D836E6E6E836F6F6F",
+ content: "836161618362626283636363836464648365656583666666836767678368686883696969836A6A6A836B6B6B836C6C6C836D6D6D836E6E6E836F6F6F",
+ items: []string{"aaa", "bbb", "ccc", "ddd", "eee", "fff", "ggg", "hhh", "iii", "jjj", "kkk", "lll", "mmm", "nnn", "ooo"},
+ length: 15,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name(), test.run)
+ }
+}
+
+func TestRawListEmpty(t *testing.T) {
+ // zero value list
+ var rl RawList[uint64]
+ b, _ := EncodeToBytes(&rl)
+ if !bytes.Equal(b, unhex("C0")) {
+ t.Fatalf("empty RawList has wrong encoding %x", b)
+ }
+ if rl.Len() != 0 {
+ t.Fatalf("empty list has Len %d", rl.Len())
+ }
+ if rl.Size() != 1 {
+ t.Fatalf("empty list has Size %d", rl.Size())
+ }
+ if len(rl.Content()) > 0 {
+ t.Fatalf("empty list has non-empty Content")
+ }
+ if !bytes.Equal(rl.Bytes(), []byte{0xC0}) {
+ t.Fatalf("empty list has wrong encoding")
+ }
+
+ // nil pointer
+ var nilptr *RawList[uint64]
+ b, _ = EncodeToBytes(nilptr)
+ if !bytes.Equal(b, unhex("C0")) {
+ t.Fatalf("nil pointer to RawList has wrong encoding %x", b)
+ }
+}
+
+// This checks that *RawList works in an 'optional' context.
+func TestRawListOptional(t *testing.T) {
+ type foo struct {
+ L *RawList[uint64] `rlp:"optional"`
+ }
+ // nil pointer encoding
+ var empty foo
+ b, _ := EncodeToBytes(empty)
+ if !bytes.Equal(b, unhex("C0")) {
+ t.Fatalf("nil pointer to RawList has wrong encoding %x", b)
+ }
+ // decoding
+ var dec foo
+ if err := DecodeBytes(unhex("C0"), &dec); err != nil {
+ t.Fatal(err)
+ }
+ if dec.L != nil {
+ t.Fatal("rawlist was decoded as non-nil")
+ }
+}
+
+func TestRawListAppend(t *testing.T) {
+ var rl RawList[simplestruct]
+
+ v1 := simplestruct{1, "one"}
+ v2 := simplestruct{2, "two"}
+ if err := rl.Append(v1); err != nil {
+ t.Fatal("append 1 failed:", err)
+ }
+ if err := rl.Append(v2); err != nil {
+ t.Fatal("append 2 failed:", err)
+ }
+
+ if rl.Len() != 2 {
+ t.Fatalf("wrong Len %d", rl.Len())
+ }
+ if rl.Size() != 13 {
+ t.Fatalf("wrong Size %d", rl.Size())
+ }
+ if !bytes.Equal(rl.Content(), unhex("C501836F6E65 C5028374776F")) {
+ t.Fatalf("wrong Content %x", rl.Content())
+ }
+ encoded, _ := EncodeToBytes(&rl)
+ if !bytes.Equal(encoded, unhex("CC C501836F6E65 C5028374776F")) {
+ t.Fatalf("wrong encoding %x", encoded)
+ }
+}
+
+func TestRawListAppendRaw(t *testing.T) {
+ var rl RawList[uint64]
+
+ if err := rl.AppendRaw(unhex("01")); err != nil {
+ t.Fatal("AppendRaw(01) failed:", err)
+ }
+ if err := rl.AppendRaw(unhex("820102")); err != nil {
+ t.Fatal("AppendRaw(820102) failed:", err)
+ }
+ if rl.Len() != 2 {
+ t.Fatalf("wrong Len %d after valid appends", rl.Len())
+ }
+
+ if err := rl.AppendRaw(nil); err == nil {
+ t.Fatal("AppendRaw(nil) should fail")
+ }
+ if err := rl.AppendRaw(unhex("0102")); err == nil {
+ t.Fatal("AppendRaw(0102) should fail due to trailing bytes")
+ }
+ if err := rl.AppendRaw(unhex("8201")); err == nil {
+ t.Fatal("AppendRaw(8201) should fail due to truncated value")
+ }
+ if rl.Len() != 2 {
+ t.Fatalf("wrong Len %d after invalid appends, want 2", rl.Len())
+ }
+}
+
+func TestRawListDecodeInvalid(t *testing.T) {
+ tests := []struct {
+ input string
+ err error
+ }{
+ // Single item with non-canonical size (0x81 wrapping byte <= 0x7F).
+ {input: "C28142", err: ErrCanonSize},
+ // Single item claiming more bytes than available in the list.
+ {input: "C484020202", err: ErrElemTooLarge},
+ // Two items, second has non-canonical size.
+ {input: "C3018142", err: ErrCanonSize},
+ // Two items, second claims more bytes than remain in the list.
+ {input: "C401830202", err: ErrElemTooLarge},
+ // Item is a sub-list whose declared size exceeds available bytes.
+ {input: "C3C40102", err: ErrElemTooLarge},
+ }
+ for _, test := range tests {
+ var rl RawList[RawValue]
+ err := DecodeBytes(unhex(test.input), &rl)
+ if !errors.Is(err, test.err) {
+ t.Errorf("input %s: error mismatch: got %v, want %v", test.input, err, test.err)
+ }
+ }
+}
+
func TestCountValues(t *testing.T) {
tests := []struct {
input string // note: spaces in input are stripped by unhex
@@ -40,9 +288,9 @@ func TestCountValues(t *testing.T) {
{"820101 820202 8403030303 04", 4, nil},
// size errors
- {"8142", 0, ErrCanonSize},
- {"01 01 8142", 0, ErrCanonSize},
- {"02 84020202", 0, ErrValueTooLarge},
+ {"8142", 1, ErrCanonSize},
+ {"01 01 8142", 3, ErrCanonSize},
+ {"02 84020202", 2, ErrValueTooLarge},
{
input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
diff --git a/trie/node.go b/trie/node.go
index 15bbf62f1..1d0c5556c 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -151,11 +151,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) {
if err != nil {
return nil, fmt.Errorf("decode error: %v", err)
}
- switch c, _ := rlp.CountValues(elems); c {
- case 2:
+ c, err := rlp.CountValues(elems)
+ switch {
+ case err != nil:
+ return nil, fmt.Errorf("invalid node list: %v", err)
+ case c == 2:
n, err := decodeShort(hash, elems)
return n, wrapError(err, "short")
- case 17:
+ case c == 17:
n, err := decodeFull(hash, elems)
return n, wrapError(err, "full")
default: