diff --git a/cmd/devp2p/internal/ethtest/snap.go b/cmd/devp2p/internal/ethtest/snap.go index b8bad6477..131a22d8d 100644 --- a/cmd/devp2p/internal/ethtest/snap.go +++ b/cmd/devp2p/internal/ethtest/snap.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" "golang.org/x/crypto/sha3" @@ -938,10 +939,14 @@ func (s *Suite) snapGetTrieNodes(t *utesting.T, tc *trieNodesTest) error { } // write0 request + paths, err := rlp.EncodeToRawList(tc.paths) + if err != nil { + panic(err) + } req := &snap.GetTrieNodesPacket{ ID: uint64(rand.Int63()), Root: tc.root, - Paths: tc.paths, + Paths: paths, Bytes: tc.nBytes, } msg, err := conn.snapRequest(snap.GetTrieNodesMsg, req) diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index 01eb0fd54..0db206b24 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -18,6 +18,7 @@ package ethtest import ( "crypto/rand" + "fmt" "math/big" "reflect" @@ -31,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/holiman/uint256" ) @@ -150,7 +152,11 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) { if err != nil { t.Fatalf("failed to get headers for given request: %v", err) } - if !headersMatch(expected, headers.BlockHeadersRequest) { + received, err := headers.List.Items() + if err != nil { + t.Fatalf("invalid headers received: %v", err) + } + if !headersMatch(expected, received) { t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) } } @@ -201,31 +207,23 @@ concurrently, with different request IDs.`) } // Wait for responses. - headers1 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil { - t.Fatalf("error reading block headers msg: %v", err) - } - if got, want := headers1.RequestId, req1.RequestId; got != want { - t.Fatalf("unexpected request id in response: got %d, want %d", got, want) - } - headers2 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil { - t.Fatalf("error reading block headers msg: %v", err) - } - if got, want := headers2.RequestId, req2.RequestId; got != want { - t.Fatalf("unexpected request id in response: got %d, want %d", got, want) + // Note they can arrive in either order. + resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + if msg.RequestId != 111 && msg.RequestId != 222 { + t.Fatalf("response with unknown request ID: %v", msg.RequestId) + } + return msg.RequestId + }) + if err != nil { + t.Fatal(err) } - // Check received headers for accuracy. - if expected, err := s.chain.GetHeaders(req1); err != nil { - t.Fatalf("failed to get expected headers for request 1: %v", err) - } else if !headersMatch(expected, headers1.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1) + // Check if headers match. + if err := s.checkHeadersAgainstChain(req1, resp[111]); err != nil { + t.Fatal(err) } - if expected, err := s.chain.GetHeaders(req2); err != nil { - t.Fatalf("failed to get expected headers for request 2: %v", err) - } else if !headersMatch(expected, headers2.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2) + if err := s.checkHeadersAgainstChain(req2, resp[222]); err != nil { + t.Fatal(err) } } @@ -259,7 +257,7 @@ same request ID. The node should handle the request by responding to both reques Origin: eth.HashOrNumber{ Number: 33, }, - Amount: 2, + Amount: 3, }, } @@ -271,33 +269,57 @@ same request ID. The node should handle the request by responding to both reques t.Fatalf("failed to write to connection: %v", err) } - // Wait for the responses. - headers1 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil { - t.Fatalf("error reading from connection: %v", err) - } - if got, want := headers1.RequestId, request1.RequestId; got != want { - t.Fatalf("unexpected request id: got %d, want %d", got, want) + // Wait for the responses. They can arrive in either order, and we can't tell them + // apart by their request ID, so use the number of headers instead. + resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + id := uint64(msg.List.Len()) + if id != 2 && id != 3 { + t.Fatalf("invalid number of headers in response: %d", id) + } + return id + }) + if err != nil { + t.Fatal(err) } - headers2 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil { - t.Fatalf("error reading from connection: %v", err) + + // Check if headers match. + if err := s.checkHeadersAgainstChain(request1, resp[2]); err != nil { + t.Fatal(err) } - if got, want := headers2.RequestId, request2.RequestId; got != want { - t.Fatalf("unexpected request id: got %d, want %d", got, want) + if err := s.checkHeadersAgainstChain(request2, resp[3]); err != nil { + t.Fatal(err) } +} - // Check if headers match. - if expected, err := s.chain.GetHeaders(request1); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers1.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1) +func (s *Suite) checkHeadersAgainstChain(req *eth.GetBlockHeadersPacket, resp *eth.BlockHeadersPacket) error { + received2, err := resp.List.Items() + if err != nil { + return fmt.Errorf("invalid headers in response with request ID %v (%d items): %v", resp.RequestId, resp.List.Len(), err) } - if expected, err := s.chain.GetHeaders(request2); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers2.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2) + if expected, err := s.chain.GetHeaders(req); err != nil { + return fmt.Errorf("test chain failed to get expected headers for request: %v", err) + } else if !headersMatch(expected, received2) { + return fmt.Errorf("header mismatch for request ID %v (%d items): \nexpected %v \ngot %v", resp.RequestId, resp.List.Len(), expected, resp) } + return nil +} + +// collectResponses waits for n messages of type T on the given connection. +// The messsages are collected according to the 'identity' function. +func collectHeaderResponses(conn *Conn, n int, identity func(*eth.BlockHeadersPacket) uint64) (map[uint64]*eth.BlockHeadersPacket, error) { + resp := make(map[uint64]*eth.BlockHeadersPacket, n) + for range n { + r := new(eth.BlockHeadersPacket) + if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil { + return resp, fmt.Errorf("read error: %v", err) + } + id := identity(r) + if resp[id] != nil { + return resp, fmt.Errorf("duplicate response %v", r) + } + resp[id] = r + } + return resp, nil } func (s *Suite) TestZeroRequestID(t *utesting.T) { @@ -329,10 +351,8 @@ and expects a response.`) if got, want := headers.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id") } - if expected, err := s.chain.GetHeaders(req); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) + if err := s.checkHeadersAgainstChain(req, headers); err != nil { + t.Fatal(err) } } @@ -366,9 +386,52 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) { if got, want := resp.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id in respond", got, want) } - bodies := resp.BlockBodiesResponse - if len(bodies) != len(req.GetBlockBodiesRequest) { - t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), len(bodies)) + if resp.List.Len() != len(req.GetBlockBodiesRequest) { + t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), resp.List.Len()) + } +} + +func (s *Suite) TestGetReceipts(t *utesting.T) { + t.Log(`This test sends GetReceipts requests to the node for known blocks in the test chain.`) + conn, err := s.dial() + if err != nil { + t.Fatalf("dial failed: %v", err) + } + defer conn.Close() + if err := conn.peer(s.chain, nil); err != nil { + t.Fatalf("peering failed: %v", err) + } + + // Find some blocks containing receipts. + var hashes = make([]common.Hash, 0, 3) + for i := range s.chain.Len() { + block := s.chain.GetBlock(i) + if len(block.Transactions()) > 0 { + hashes = append(hashes, block.Hash()) + } + if len(hashes) == cap(hashes) { + break + } + } + + // Create receipts request. + req := ð.GetReceiptsPacket{ + RequestId: 66, + GetReceiptsRequest: (eth.GetReceiptsRequest)(hashes), + } + if err := conn.Write(ethProto, eth.GetReceiptsMsg, req); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // Wait for response. + resp := new(eth.ReceiptsPacket) + if err := conn.ReadMsg(ethProto, eth.ReceiptsMsg, &resp); err != nil { + t.Fatalf("error reading block bodies msg: %v", err) + } + if got, want := resp.RequestId, req.RequestId; got != want { + t.Fatalf("unexpected request id in respond", got, want) + } + if resp.List.Len() != len(req.GetReceiptsRequest) { + t.Fatalf("wrong receipts in response: expected %d receipts, got %d", len(req.GetReceiptsRequest), resp.List.Len()) } } @@ -675,7 +738,11 @@ on another peer connection using GetPooledTransactions.`) if got, want := msg.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id in response: got %d, want %d", got, want) } - for _, got := range msg.PooledTransactionsResponse { + responseTxs, err := msg.List.Items() + if err != nil { + t.Fatalf("invalid transactions in response: %v", err) + } + for _, got := range responseTxs { if _, exists := set[got.Hash()]; !exists { t.Fatalf("unexpected tx received: %v", got.Hash()) } @@ -779,7 +846,7 @@ func (s *Suite) makeBlobTxs(count, blobs int, discriminator byte) (txs types.Tra from, nonce := s.chain.GetSender(5) for i := 0; i < count; i++ { // Make blob data, max of 2 blobs per tx. - blobdata := make([]byte, blobs%3) + blobdata := make([]byte, min(blobs, 2)) for i := range blobdata { blobdata[i] = discriminator blobs -= 1 @@ -851,7 +918,8 @@ func (s *Suite) TestBlobViolations(t *utesting.T) { if err := conn.ReadMsg(ethProto, eth.GetPooledTransactionsMsg, req); err != nil { t.Fatalf("reading pooled tx request failed: %v", err) } - resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: test.resp} + encTxs, _ := rlp.EncodeToRawList(test.resp) + resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs} if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil { t.Fatalf("writing pooled tx response failed: %v", err) } diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go index 80b5d8074..16f9a3ad8 100644 --- a/cmd/devp2p/internal/ethtest/transaction.go +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/rlp" ) // sendTxs sends the given transactions to the node and @@ -51,7 +52,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { return fmt.Errorf("peering failed: %v", err) } - if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket(txs)); err != nil { + encTxs, _ := rlp.EncodeToRawList(txs) + if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket{RawList: encTxs}); err != nil { return fmt.Errorf("failed to write message to connection: %v", err) } @@ -68,7 +70,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { } switch msg := msg.(type) { case *eth.TransactionsPacket: - for _, tx := range *msg { + txs, _ := msg.Items() + for _, tx := range txs { got[tx.Hash()] = true } case *eth.NewPooledTransactionHashesPacket: @@ -80,9 +83,10 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { if err != nil { t.Logf("invalid GetBlockHeaders request: %v", err) } + encHeaders, _ := rlp.EncodeToRawList(headers) recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{ - RequestId: msg.RequestId, - BlockHeadersRequest: headers, + RequestId: msg.RequestId, + List: encHeaders, }) default: return fmt.Errorf("unexpected eth wire msg: %s", pretty.Sdump(msg)) @@ -167,9 +171,10 @@ func (s *Suite) sendInvalidTxs(t *utesting.T, txs []*types.Transaction) error { if err != nil { t.Logf("invalid GetBlockHeaders request: %v", err) } + encHeaders, _ := rlp.EncodeToRawList(headers) recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{ - RequestId: msg.RequestId, - BlockHeadersRequest: headers, + RequestId: msg.RequestId, + List: encHeaders, }) default: return fmt.Errorf("unexpected eth message: %v", pretty.Sdump(msg)) diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 2468e1a98..27dc47a72 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -266,10 +266,12 @@ func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) { blobs := eth.ServiceGetBlockBodiesQuery(dlp.chain, hashes) - bodies := make([]*eth.BlockBody, len(blobs)) + bodies := make([]*types.Body, len(blobs)) + ethbodies := make([]eth.BlockBody, len(blobs)) for i, blob := range blobs { - bodies[i] = new(eth.BlockBody) + bodies[i] = new(types.Body) rlp.DecodeBytes(blob, bodies[i]) + rlp.DecodeBytes(blob, ðbodies[i]) } var ( txsHashes = make([]common.Hash, len(bodies)) @@ -285,9 +287,13 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et Peer: dlp.id, } res := ð.Response{ - Req: req, - Res: (*eth.BlockBodiesResponse)(&bodies), - Meta: [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes}, + Req: req, + Res: (*eth.BlockBodiesResponse)(ðbodies), + Meta: eth.BlockBodyHashes{ + TransactionRoots: txsHashes, + UncleHashes: uncleHashes, + WithdrawalRoots: withdrawalHashes, + }, Time: 1, Done: make(chan error, 1), // Ignore the returned status } @@ -303,14 +309,14 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et func (dlp *downloadTesterPeer) RequestReceipts(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) { blobs := eth.ServiceGetReceiptsQuery(dlp.chain, hashes) - receipts := make([][]*types.Receipt, len(blobs)) + receipts := make([]types.Receipts, len(blobs)) for i, blob := range blobs { rlp.DecodeBytes(blob, &receipts[i]) } hasher := trie.NewStackTrie(nil) hashes = make([]common.Hash, len(receipts)) for i, receipt := range receipts { - hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher) + hashes[i] = types.DeriveSha(receipt, hasher) } req := ð.Request{ Peer: dlp.id, @@ -335,14 +341,14 @@ func (dlp *downloadTesterPeer) ID() string { // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. -func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error { // Create the request and service it req := &snap.GetAccountRangePacket{ ID: id, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), } slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req) @@ -361,7 +367,7 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi // RequestStorageRanges fetches a batch of storage slots belonging to one or // more accounts. If slots from only one account is requested, an origin marker // may also be used to retrieve from there. -func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { // Create the request and service it req := &snap.GetStorageRangesPacket{ ID: id, @@ -369,7 +375,7 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), } storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req) @@ -386,25 +392,28 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, } // RequestByteCodes fetches a batch of bytecodes by hash. -func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { req := &snap.GetByteCodesPacket{ ID: id, Hashes: hashes, - Bytes: bytes, + Bytes: uint64(bytes), } codes := snap.ServiceGetByteCodesQuery(dlp.chain, req) go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes) return nil } -// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in -// a specific state trie. -func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, paths []snap.TrieNodePathSet, bytes uint64) error { +// RequestTrieNodes fetches a batch of account or storage trie nodes. +func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []snap.TrieNodePathSet, bytes int) error { + encPaths, err := rlp.EncodeToRawList(paths) + if err != nil { + panic(err) + } req := &snap.GetTrieNodesPacket{ ID: id, Root: root, - Paths: paths, - Bytes: bytes, + Paths: encPaths, + Bytes: uint64(bytes), } nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now()) go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes) diff --git a/eth/downloader/fetchers_concurrent_bodies.go b/eth/downloader/fetchers_concurrent_bodies.go index 5105fda66..aa5f8e180 100644 --- a/eth/downloader/fetchers_concurrent_bodies.go +++ b/eth/downloader/fetchers_concurrent_bodies.go @@ -89,15 +89,14 @@ func (q *bodyQueue) request(peer *peerConnection, req *fetchRequest, resCh chan // deliver is responsible for taking a generic response packet from the concurrent // fetcher, unpacking the body data and delivering it to the downloader's queue. func (q *bodyQueue) deliver(peer *peerConnection, packet *eth.Response) (int, error) { - txs, uncles, withdrawals := packet.Res.(*eth.BlockBodiesResponse).Unpack() - hashsets := packet.Meta.([][]common.Hash) // {txs hashes, uncle hashes, withdrawal hashes} - - accepted, err := q.queue.DeliverBodies(peer.id, txs, hashsets[0], uncles, hashsets[1], withdrawals, hashsets[2]) + resp := packet.Res.(*eth.BlockBodiesResponse) + meta := packet.Meta.(eth.BlockBodyHashes) + accepted, err := q.queue.DeliverBodies(peer.id, meta, *resp) switch { - case err == nil && len(txs) == 0: + case err == nil && len(*resp) == 0: peer.log.Trace("Requested bodies delivered") case err == nil: - peer.log.Trace("Delivered new batch of bodies", "count", len(txs), "accepted", accepted) + peer.log.Trace("Delivered new batch of bodies", "count", len(*resp), "accepted", accepted) default: peer.log.Debug("Failed to deliver retrieved bodies", "err", err) } diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index fe08ff64c..e50c876a3 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -29,10 +29,9 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto/kzg4844" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" - "github.com/ethereum/go-ethereum/params" ) const ( @@ -772,62 +771,53 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, hashes []comm // DeliverBodies injects a block body retrieval response into the results queue. // The method returns the number of blocks bodies accepted from the delivery and // also wakes any threads waiting for data delivery. -func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListHashes []common.Hash, - uncleLists [][]*types.Header, uncleListHashes []common.Hash, - withdrawalLists [][]*types.Withdrawal, withdrawalListHashes []common.Hash) (int, error) { +func (q *queue) DeliverBodies(id string, hashes eth.BlockBodyHashes, bodies []eth.BlockBody) (int, error) { q.lock.Lock() defer q.lock.Unlock() + var txLists [][]*types.Transaction + var uncleLists [][]*types.Header + var withdrawalLists [][]*types.Withdrawal + validate := func(index int, header *types.Header) error { - if txListHashes[index] != header.TxHash { + if hashes.TransactionRoots[index] != header.TxHash { return errInvalidBody } - if uncleListHashes[index] != header.UncleHash { + if hashes.UncleHashes[index] != header.UncleHash { return errInvalidBody } if header.WithdrawalsHash == nil { // nil hash means that withdrawals should not be present in body - if withdrawalLists[index] != nil { + if bodies[index].Withdrawals != nil { return errInvalidBody } } else { // non-nil hash: body must have withdrawals - if withdrawalLists[index] == nil { + if bodies[index].Withdrawals == nil { return errInvalidBody } - if withdrawalListHashes[index] != *header.WithdrawalsHash { + if hashes.WithdrawalRoots[index] != *header.WithdrawalsHash { return errInvalidBody } } - // Blocks must have a number of blobs corresponding to the header gas usage, - // and zero before the Cancun hardfork. - var blobs int - for _, tx := range txLists[index] { - // Count the number of blobs to validate against the header's blobGasUsed - blobs += len(tx.BlobHashes()) - - // Validate the data blobs individually too - if tx.Type() == types.BlobTxType { - if len(tx.BlobHashes()) == 0 { - return errInvalidBody - } - for _, hash := range tx.BlobHashes() { - if !kzg4844.IsValidVersionedHash(hash[:]) { - return errInvalidBody - } - } - if tx.BlobTxSidecar() != nil { - return errInvalidBody - } - } + // decode + txs, err := bodies[index].Transactions.Items() + if err != nil { + return fmt.Errorf("%w: bad transactions: %v", errInvalidBody, err) } - if header.BlobGasUsed != nil { - if want := *header.BlobGasUsed / params.BlobTxBlobGasPerBlob; uint64(blobs) != want { // div because the header is surely good vs the body might be bloated - return errInvalidBody + txLists = append(txLists, txs) + uncles, err := bodies[index].Uncles.Items() + if err != nil { + return fmt.Errorf("%w: bad uncles: %v", errInvalidBody, err) + } + uncleLists = append(uncleLists, uncles) + if bodies[index].Withdrawals != nil { + withdrawals, err := bodies[index].Withdrawals.Items() + if err != nil { + return fmt.Errorf("%w: bad withdrawals: %v", errInvalidBody, err) } + withdrawalLists = append(withdrawalLists, withdrawals) } else { - if blobs != 0 { - return errInvalidBody - } + withdrawalLists = append(withdrawalLists, nil) } return nil } @@ -838,14 +828,15 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListH result.Withdrawals = withdrawalLists[index] result.SetBodyDone() } + nresults := len(hashes.TransactionRoots) return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, - bodyReqTimer, bodyInMeter, bodyDropMeter, len(txLists), validate, reconstruct) + bodyReqTimer, bodyInMeter, bodyDropMeter, nresults, validate, reconstruct) } // DeliverReceipts injects a receipt retrieval response into the results queue. // The method returns the number of transaction receipts accepted from the delivery // and also wakes any threads waiting for data delivery. -func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt, receiptListHashes []common.Hash) (int, error) { +func (q *queue) DeliverReceipts(id string, receiptList []types.Receipts, receiptListHashes []common.Hash) (int, error) { q.lock.Lock() defer q.lock.Unlock() diff --git a/eth/downloader/queue_test.go b/eth/downloader/queue_test.go index 857ac4813..255e1030b 100644 --- a/eth/downloader/queue_test.go +++ b/eth/downloader/queue_test.go @@ -30,8 +30,10 @@ import ( "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -323,26 +325,30 @@ func XTestDelivery(t *testing.T) { emptyList []*types.Header txset [][]*types.Transaction uncleset [][]*types.Header + bodies []eth.BlockBody ) numToSkip := rand.Intn(len(f.Headers)) for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] { - txset = append(txset, world.getTransactions(hdr.Number.Uint64())) + txs := world.getTransactions(hdr.Number.Uint64()) + txset = append(txset, txs) uncleset = append(uncleset, emptyList) + txsList, _ := rlp.EncodeToRawList(txs) + bodies = append(bodies, eth.BlockBody{Transactions: txsList}) + } + hashes := eth.BlockBodyHashes{ + TransactionRoots: make([]common.Hash, len(txset)), + UncleHashes: make([]common.Hash, len(uncleset)), + WithdrawalRoots: make([]common.Hash, len(txset)), } - var ( - txsHashes = make([]common.Hash, len(txset)) - uncleHashes = make([]common.Hash, len(uncleset)) - ) hasher := trie.NewStackTrie(nil) for i, txs := range txset { - txsHashes[i] = types.DeriveSha(types.Transactions(txs), hasher) + hashes.TransactionRoots[i] = types.DeriveSha(types.Transactions(txs), hasher) } for i, uncles := range uncleset { - uncleHashes[i] = types.CalcUncleHash(uncles) + hashes.UncleHashes[i] = types.CalcUncleHash(uncles) } time.Sleep(100 * time.Millisecond) - _, err := q.DeliverBodies(peer.id, txset, txsHashes, uncleset, uncleHashes, nil, nil) - if err != nil { + if _, err := q.DeliverBodies(peer.id, hashes, bodies); err != nil { fmt.Printf("delivered %d bodies %v\n", len(txset), err) } } else { @@ -358,14 +364,14 @@ func XTestDelivery(t *testing.T) { for { f, _, _ := q.ReserveReceipts(peer, rand.Intn(50)) if f != nil { - var rcs [][]*types.Receipt + var rcs []types.Receipts for _, hdr := range f.Headers { rcs = append(rcs, world.getReceipts(hdr.Number.Uint64())) } hasher := trie.NewStackTrie(nil) hashes := make([]common.Hash, len(rcs)) for i, receipt := range rcs { - hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher) + hashes[i] = types.DeriveSha(receipt, hasher) } _, err := q.DeliverReceipts(peer.id, rcs, hashes) if err != nil { diff --git a/eth/fetcher/block_fetcher.go b/eth/fetcher/block_fetcher.go index 603f6246c..b4cdcf9d5 100644 --- a/eth/fetcher/block_fetcher.go +++ b/eth/fetcher/block_fetcher.go @@ -541,7 +541,11 @@ func (f *BlockFetcher) loop() { case res := <-resCh: res.Done <- nil // Ignoring withdrawals here, since the block fetcher is not used post-merge. - txs, uncles, _ := res.Res.(*eth.BlockBodiesResponse).Unpack() + txs, uncles, _, err := res.Res.(*eth.BlockBodiesResponse).Unpack() + if err != nil { + f.dropPeer(peer) + return + } f.FilterBodies(peer, txs, uncles, time.Now()) case <-timeout.C: diff --git a/eth/fetcher/block_fetcher_test.go b/eth/fetcher/block_fetcher_test.go index cb7cbaf79..4f4292b81 100644 --- a/eth/fetcher/block_fetcher_test.go +++ b/eth/fetcher/block_fetcher_test.go @@ -32,6 +32,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/triedb" ) @@ -234,21 +235,11 @@ func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*typ // Create a function that returns blocks from the closure return func(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) { // Gather the block bodies to return - transactions := make([][]*types.Transaction, 0, len(hashes)) - uncles := make([][]*types.Header, 0, len(hashes)) + bodies := make([]eth.BlockBody, len(hashes)) - for _, hash := range hashes { + for i, hash := range hashes { if block, ok := closure[hash]; ok { - transactions = append(transactions, block.Transactions()) - uncles = append(uncles, block.Uncles()) - } - } - // Return on a new thread - bodies := make([]*eth.BlockBody, len(transactions)) - for i, txs := range transactions { - bodies[i] = ð.BlockBody{ - Transactions: txs, - Uncles: uncles[i], + bodies[i] = encodeBody(block) } } req := ð.Request{ @@ -267,6 +258,26 @@ func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*typ } } +func encodeBody(b *types.Block) eth.BlockBody { + body := eth.BlockBody{ + Transactions: encodeRL([]*types.Transaction(b.Transactions())), + Uncles: encodeRL(b.Uncles()), + } + if b.Withdrawals() != nil { + wd := encodeRL([]*types.Withdrawal(b.Withdrawals())) + body.Withdrawals = &wd + } + return body +} + +func encodeRL[T any](slice []T) rlp.RawList[T] { + rl, err := rlp.EncodeToRawList(slice) + if err != nil { + panic(err) + } + return rl +} + // verifyFetchingEvent verifies that one single event arrive on a fetching channel. func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) { t.Helper() diff --git a/eth/handler_eth.go b/eth/handler_eth.go index 8edf60e10..fbe096c29 100644 --- a/eth/handler_eth.go +++ b/eth/handler_eth.go @@ -71,15 +71,24 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error { return h.txFetcher.Notify(peer.ID(), packet.Types, packet.Sizes, packet.Hashes) case *eth.TransactionsPacket: - for _, tx := range *packet { - if tx.Type() == types.BlobTxType { - return errors.New("disallowed broadcast blob transaction") - } + txs, err := packet.Items() + if err != nil { + return fmt.Errorf("Transactions: %v", err) } - return h.txFetcher.Enqueue(peer.ID(), *packet, false) + if err := handleTransactions(peer, txs, true); err != nil { + return fmt.Errorf("Transactions: %v", err) + } + return h.txFetcher.Enqueue(peer.ID(), txs, false) - case *eth.PooledTransactionsResponse: - return h.txFetcher.Enqueue(peer.ID(), *packet, true) + case *eth.PooledTransactionsPacket: + txs, err := packet.List.Items() + if err != nil { + return fmt.Errorf("PooledTransactions: %v", err) + } + if err := handleTransactions(peer, txs, false); err != nil { + return fmt.Errorf("PooledTransactions: %v", err) + } + return h.txFetcher.Enqueue(peer.ID(), txs, true) default: return fmt.Errorf("unexpected eth packet type: %T", packet) @@ -137,3 +146,33 @@ func (h *ethHandler) handleBlockBroadcast(peer *eth.Peer, block *types.Block, td } return nil } + +// handleTransactions marks all given transactions as known to the peer +// and performs basic validations. +func handleTransactions(peer *eth.Peer, list []*types.Transaction, directBroadcast bool) error { + seen := make(map[common.Hash]struct{}) + for _, tx := range list { + if tx.Type() == types.BlobTxType { + if directBroadcast { + return errors.New("disallowed broadcast blob transaction") + } else { + // If we receive any blob transactions missing sidecars, + // disconnect from the sending peer. + if tx.BlobTxSidecar() == nil { + return errors.New("received sidecar-less blob transaction") + } + } + } + + // Check for duplicates. + hash := tx.Hash() + if _, exists := seen[hash]; exists { + return fmt.Errorf("multiple copies of the same hash %v", hash) + } + seen[hash] = struct{}{} + + // Mark as known. + peer.MarkTransaction(hash) + } + return nil +} diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go index 964f8422e..72ddd0078 100644 --- a/eth/handler_eth_test.go +++ b/eth/handler_eth_test.go @@ -63,11 +63,19 @@ func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error { return nil case *eth.TransactionsPacket: - h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + txs, err := packet.Items() + if err != nil { + return err + } + h.txBroadcasts.Send(txs) return nil - case *eth.PooledTransactionsResponse: - h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + case *eth.PooledTransactionsPacket: + txs, err := packet.List.Items() + if err != nil { + return err + } + h.txBroadcasts.Send(txs) return nil default: diff --git a/eth/protocols/eth/dispatcher.go b/eth/protocols/eth/dispatcher.go index ae98820cd..eed377cf6 100644 --- a/eth/protocols/eth/dispatcher.go +++ b/eth/protocols/eth/dispatcher.go @@ -22,6 +22,7 @@ import ( "time" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" ) var ( @@ -47,9 +48,10 @@ type Request struct { sink chan *Response // Channel to deliver the response on cancel chan struct{} // Channel to cancel requests ahead of time - code uint64 // Message code of the request packet - want uint64 // Message code of the response packet - data interface{} // Data content of the request packet + code uint64 // Message code of the request packet + want uint64 // Message code of the response packet + numItems int // Number of requested items + data interface{} // Data content of the request packet Peer string // Demultiplexer if cross-peer requests are batched together Sent time.Time // Timestamp when the request was sent @@ -136,7 +138,7 @@ func (p *Peer) dispatchRequest(req *Request) error { } } -// dispatchRequest fulfils a pending request and delivers it to the requested +// dispatchResponse fulfils a pending request and delivers it to the requested // sink. func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) error { resOp := &response{ @@ -188,20 +190,34 @@ func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) erro func (p *Peer) dispatcher() { pending := make(map[uint64]*Request) +loop: for { select { case reqOp := <-p.reqDispatch: req := reqOp.req req.Sent = time.Now() - requestTracker.Track(p.id, p.version, req.code, req.want, req.id) - err := p2p.Send(p.rw, req.code, req.data) - reqOp.fail <- err - - if err == nil { - pending[req.id] = req + // Register the request with the tracker before sending it on the + // wire. This ensures any incoming response can be matched even if + // it arrives before the send call returns. + treq := tracker.Request{ + ID: req.id, + ReqCode: req.code, + RespCode: req.want, + Size: req.numItems, + } + if err := p.tracker.Track(treq); err != nil { + reqOp.fail <- err + continue loop + } + if err := p2p.Send(p.rw, req.code, req.data); err != nil { + reqOp.fail <- err + continue loop } + pending[req.id] = req + reqOp.fail <- nil + case cancelOp := <-p.reqCancel: // Retrieve the pending request to cancel and short circuit if it // has already been serviced and is not available anymore @@ -218,9 +234,6 @@ func (p *Peer) dispatcher() { res := resOp.res res.Req = pending[res.id] - // Independent if the request exists or not, track this packet - requestTracker.Fulfil(p.id, p.version, res.code, res.id) - switch { case res.Req == nil: // Response arrived with an untracked ID. Since even cancelled @@ -247,6 +260,7 @@ func (p *Peer) dispatcher() { } case <-p.term: + p.tracker.Stop() return } } diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go index 35c99d26a..45bbae3d7 100644 --- a/eth/protocols/eth/handler.go +++ b/eth/protocols/eth/handler.go @@ -194,7 +194,12 @@ func handleMessage(backend Backend, peer *Peer) error { } defer msg.Discard() - var handlers = eth68 + var handlers map[uint64]msgHandler + if peer.version == ETH68 { + handlers = eth68 + } else { + return fmt.Errorf("unknown eth protocol version: %v", peer.version) + } // Track the amount of time it takes to serve the request and run the handler if metrics.Enabled() { diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index 307b13d38..4c43e052e 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -20,8 +20,10 @@ import ( "math" "math/big" "math/rand" + "reflect" "testing" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/beacon" @@ -37,6 +39,8 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" ) var ( @@ -339,8 +343,8 @@ func testGetBlockHeaders(t *testing.T, protocol uint) { GetBlockHeadersRequest: tt.query, }) if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, &BlockHeadersPacket{ - RequestId: 123, - BlockHeadersRequest: headers, + RequestId: 123, + List: encodeRL(headers), }); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } @@ -353,7 +357,7 @@ func testGetBlockHeaders(t *testing.T, protocol uint) { RequestId: 456, GetBlockHeadersRequest: tt.query, }) - expected := &BlockHeadersPacket{RequestId: 456, BlockHeadersRequest: headers} + expected := &BlockHeadersPacket{RequestId: 456, List: encodeRL(headers)} if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, expected); err != nil { t.Errorf("test %d by hash: headers mismatch: %v", i, err) } @@ -417,7 +421,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { // Collect the hashes to request, and the response to expect var ( hashes []common.Hash - bodies []*BlockBody + bodies []BlockBody seen = make(map[int64]bool) ) for j := 0; j < tt.random; j++ { @@ -429,7 +433,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { block := backend.chain.GetBlockByNumber(uint64(num)) hashes = append(hashes, block.Hash()) if len(bodies) < tt.expected { - bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()}) + bodies = append(bodies, encodeBody(block)) } break } @@ -439,7 +443,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { hashes = append(hashes, hash) if tt.available[j] && len(bodies) < tt.expected { block := backend.chain.GetBlockByHash(hash) - bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()}) + bodies = append(bodies, encodeBody(block)) } } @@ -449,14 +453,65 @@ func testGetBlockBodies(t *testing.T, protocol uint) { GetBlockBodiesRequest: hashes, }) if err := p2p.ExpectMsg(peer.app, BlockBodiesMsg, &BlockBodiesPacket{ - RequestId: 123, - BlockBodiesResponse: bodies, + RequestId: 123, + List: encodeRL(bodies), }); err != nil { t.Fatalf("test %d: bodies mismatch: %v", i, err) } } } +func encodeBody(b *types.Block) BlockBody { + body := BlockBody{ + Transactions: encodeRL([]*types.Transaction(b.Transactions())), + Uncles: encodeRL(b.Uncles()), + } + if b.Withdrawals() != nil { + wd := encodeRL([]*types.Withdrawal(b.Withdrawals())) + body.Withdrawals = &wd + } + return body +} + +func TestHashBody(t *testing.T) { + key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + signer := types.NewCancunSigner(big.NewInt(1)) + + // create block 1 + header := &types.Header{Number: big.NewInt(11)} + txs := []*types.Transaction{ + types.MustSignNewTx(key, signer, &types.DynamicFeeTx{ + ChainID: big.NewInt(1), + Nonce: 1, + Data: []byte("testing"), + }), + types.MustSignNewTx(key, signer, &types.LegacyTx{ + Nonce: 2, + Data: []byte("testing"), + }), + } + uncles := []*types.Header{{Number: big.NewInt(10)}} + block1 := types.NewBlockWithWithdrawals(header, txs, uncles, nil, nil, trie.NewStackTrie(nil)) + + // create block 2 (has withdrawals) + header2 := &types.Header{Number: big.NewInt(12)} + block2 := types.NewBlockWithWithdrawals(header2, nil, nil, nil, []*types.Withdrawal{{Index: 10}, {Index: 11}}, trie.NewStackTrie(nil)) + + expectedHashes := BlockBodyHashes{ + TransactionRoots: []common.Hash{block1.TxHash(), block2.TxHash()}, + WithdrawalRoots: []common.Hash{common.Hash{}, *block2.Header().WithdrawalsHash}, + UncleHashes: []common.Hash{block1.UncleHash(), block2.UncleHash()}, + } + + // compute hash like protocol handler does + protocolBodies := []BlockBody{encodeBody(block1), encodeBody(block2)} + hashes := hashBodyParts(protocolBodies) + if !reflect.DeepEqual(hashes, expectedHashes) { + t.Errorf("wrong hashes: %s", spew.Sdump(hashes)) + t.Logf("expected: %s", spew.Sdump(expectedHashes)) + } +} + // Tests that the transaction receipts can be retrieved based on hashes. func TestGetBlockReceipts68(t *testing.T) { testGetBlockReceipts(t, ETH68) } @@ -508,13 +563,13 @@ func testGetBlockReceipts(t *testing.T, protocol uint) { // Collect the hashes to request, and the response to expect var ( hashes []common.Hash - receipts [][]*types.Receipt + receipts rlp.RawList[rlp.RawList[*types.Receipt]] ) for i := uint64(0); i <= backend.chain.CurrentBlock().Number.Uint64(); i++ { block := backend.chain.GetBlockByNumber(i) - hashes = append(hashes, block.Hash()) - receipts = append(receipts, backend.chain.GetReceiptsByHash(block.Hash())) + trs := backend.chain.GetReceiptsByHash(block.Hash()) + receipts.Append(encodeRL(trs)) } // Send the hash request and verify the response p2p.Send(peer.app, GetReceiptsMsg, &GetReceiptsPacket{ @@ -522,9 +577,17 @@ func testGetBlockReceipts(t *testing.T, protocol uint) { GetReceiptsRequest: hashes, }) if err := p2p.ExpectMsg(peer.app, ReceiptsMsg, &ReceiptsPacket{ - RequestId: 123, - ReceiptsResponse: receipts, + RequestId: 123, + List: receipts, }); err != nil { t.Errorf("receipts mismatch: %v", err) } } + +func encodeRL[T any](slice []T) rlp.RawList[T] { + rl, err := rlp.EncodeToRawList(slice) + if err != nil { + panic(err) + } + return rl +} diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go index 5d4f47e49..d693cbc5c 100644 --- a/eth/protocols/eth/handlers.go +++ b/eth/protocols/eth/handlers.go @@ -17,13 +17,17 @@ package eth import ( + "bytes" "encoding/json" "fmt" + "math" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/tracker" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -326,9 +330,17 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockHeadersMsg, Size: res.List.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("BlockHeaders: %w", err) + } + headers, err := res.List.Items() + if err != nil { + return fmt.Errorf("BlockHeaders: %w", err) + } metadata := func() interface{} { - hashes := make([]common.Hash, len(res.BlockHeadersRequest)) - for i, header := range res.BlockHeadersRequest { + hashes := make([]common.Hash, len(headers)) + for i, header := range headers { hashes[i] = header.Hash() } return hashes @@ -336,7 +348,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { return peer.dispatchResponse(&Response{ id: res.RequestId, code: BlockHeadersMsg, - Res: &res.BlockHeadersRequest, + Res: (*BlockHeadersRequest)(&headers), }, metadata) } @@ -346,47 +358,156 @@ func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - metadata := func() interface{} { - var ( - txsHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - uncleHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - withdrawalHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - ) - hasher := trie.NewStackTrie(nil) - for i, body := range res.BlockBodiesResponse { - txsHashes[i] = types.DeriveSha(types.Transactions(body.Transactions), hasher) - uncleHashes[i] = types.CalcUncleHash(body.Uncles) - if body.Withdrawals != nil { - withdrawalHashes[i] = types.DeriveSha(types.Withdrawals(body.Withdrawals), hasher) - } - } - return [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes} + + // Check against the request. + length := res.List.Len() + tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockBodiesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("BlockBodies: %w", err) + } + + // Collect items and dispatch. + items, err := res.List.Items() + if err != nil { + return fmt.Errorf("BlockBodies: %w", err) } + metadata := func() any { return hashBodyParts(items) } return peer.dispatchResponse(&Response{ id: res.RequestId, code: BlockBodiesMsg, - Res: &res.BlockBodiesResponse, + Res: (*BlockBodiesResponse)(&items), }, metadata) } +// BlockBodyHashes contains the lists of block body part roots for a list of block bodies. +type BlockBodyHashes struct { + TransactionRoots []common.Hash + WithdrawalRoots []common.Hash + UncleHashes []common.Hash +} + +// hashBodyParts computes the MPT root hashes of the transactions and withdrawals, +// and the uncle hash, for each block body. It operates directly on the raw RLP +// bytes in each RawList, avoiding a full decode of the body contents. +func hashBodyParts(items []BlockBody) BlockBodyHashes { + h := BlockBodyHashes{ + TransactionRoots: make([]common.Hash, len(items)), + WithdrawalRoots: make([]common.Hash, len(items)), + UncleHashes: make([]common.Hash, len(items)), + } + hasher := trie.NewStackTrie(nil) + for i, body := range items { + // txs + txsList := newDerivableRawList(&body.Transactions, writeTxForHash) + h.TransactionRoots[i] = types.DeriveSha(txsList, hasher) + // uncles + if body.Uncles.Len() == 0 { + h.UncleHashes[i] = types.EmptyUncleHash + } else { + h.UncleHashes[i] = crypto.Keccak256Hash(body.Uncles.Bytes()) + } + // withdrawals + if body.Withdrawals != nil { + wdlist := newDerivableRawList(body.Withdrawals, nil) + h.WithdrawalRoots[i] = types.DeriveSha(wdlist, hasher) + } + } + return h +} + +// derivableRawList implements types.DerivableList for a serialized RLP list. +type derivableRawList struct { + data []byte + offsets []uint32 + write func([]byte, *bytes.Buffer) +} + +// newDerivableRawList creates a derivableRawList from a RawList. The write +// function transforms each raw element before it is written to the hash buffer; +// pass nil to use the identity transform (write raw bytes as-is). +func newDerivableRawList[T any](list *rlp.RawList[T], write func([]byte, *bytes.Buffer)) *derivableRawList { + dl := derivableRawList{data: list.Content(), write: write} + if dl.write == nil { + // default transform is identity + dl.write = func(b []byte, buf *bytes.Buffer) { buf.Write(b) } + } + // Assert to ensure 32-bit offsets are valid. This can never trigger + // unless a block body component is larger than 4GB. + if uint(len(dl.data)) > math.MaxUint32 { + panic("list data too big for derivableRawList") + } + it := list.ContentIterator() + dl.offsets = make([]uint32, list.Len()) + for i := 0; it.Next(); i++ { + dl.offsets[i] = uint32(it.Offset()) + } + return &dl +} + +// Len returns the number of items in the list. +func (dl *derivableRawList) Len() int { + return len(dl.offsets) +} + +// EncodeIndex writes the i'th item to the buffer. +func (dl *derivableRawList) EncodeIndex(i int, buf *bytes.Buffer) { + start := dl.offsets[i] + end := uint32(len(dl.data)) + if i != len(dl.offsets)-1 { + end = dl.offsets[i+1] + } + dl.write(dl.data[start:end], buf) +} + +// writeTxForHash changes a transaction in 'network encoding' into the format used for +// the transactions MPT. +func writeTxForHash(tx []byte, buf *bytes.Buffer) { + k, content, _, _ := rlp.Split(tx) + if k == rlp.List { + buf.Write(tx) // legacy tx + } else { + buf.Write(content) // typed tx + } +} + func handleReceipts(backend Backend, msg Decoder, peer *Peer) error { // A batch of receipts arrived to one of our previous requests res := new(ReceiptsPacket) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + + tresp := tracker.Response{ID: res.RequestId, MsgCode: ReceiptsMsg, Size: res.List.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("Receipts: %w", err) + } + + // Decode the two-level list: outer level is per-block, inner level is per-receipt. + blockLists, err := res.List.Items() + if err != nil { + return fmt.Errorf("Receipts: %w", err) + } + decoded := make([]types.Receipts, len(blockLists)) + for i := range blockLists { + items, err := blockLists[i].Items() + if err != nil { + return fmt.Errorf("Receipts: invalid list %d: %v", i, err) + } + decoded[i] = items + } metadata := func() interface{} { hasher := trie.NewStackTrie(nil) - hashes := make([]common.Hash, len(res.ReceiptsResponse)) - for i, receipt := range res.ReceiptsResponse { - hashes[i] = types.DeriveSha(types.Receipts(receipt), hasher) + hashes := make([]common.Hash, len(decoded)) + for i, receipts := range decoded { + hashes[i] = types.DeriveSha(receipts, hasher) } return hashes } + enc := ReceiptsResponse(decoded) return peer.dispatchResponse(&Response{ id: res.RequestId, code: ReceiptsMsg, - Res: &res.ReceiptsResponse, + Res: &enc, }, metadata) } @@ -405,7 +526,7 @@ func handleNewPooledTransactionHashes(backend Backend, msg Decoder, peer *Peer) } // Schedule all the unknown hashes for retrieval for _, hash := range ann.Hashes { - peer.markTransaction(hash) + peer.MarkTransaction(hash) } return backend.Handle(peer, ann) } @@ -458,12 +579,8 @@ func handleTransactions(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(&txs); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - for i, tx := range txs { - // Validate and mark the remote transaction - if tx == nil { - return fmt.Errorf("%w: transaction %d is nil", errDecode, i) - } - peer.markTransaction(tx.Hash()) + if txs.Len() > maxTransactionAnnouncements { + return fmt.Errorf("too many transactions") } return backend.Handle(peer, &txs) } @@ -473,19 +590,20 @@ func handlePooledTransactions(backend Backend, msg Decoder, peer *Peer) error { if !backend.AcceptTxs() { return nil } - // Transactions can be processed, parse all of them and deliver to the pool - var txs PooledTransactionsPacket - if err := msg.Decode(&txs); err != nil { + + // Check against request and decode. + var resp PooledTransactionsPacket + if err := msg.Decode(&resp); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - for i, tx := range txs.PooledTransactionsResponse { - // Validate and mark the remote transaction - if tx == nil { - return fmt.Errorf("%w: transaction %d is nil", errDecode, i) - } - peer.markTransaction(tx.Hash()) + tresp := tracker.Response{ + ID: resp.RequestId, + MsgCode: PooledTransactionsMsg, + Size: resp.List.Len(), + } + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("PooledTransactions: %w", err) } - requestTracker.Fulfil(peer.id, peer.version, PooledTransactionsMsg, txs.RequestId) - return backend.Handle(peer, &txs.PooledTransactionsResponse) + return backend.Handle(peer, &resp) } diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go index d5c501a49..6f57f61e0 100644 --- a/eth/protocols/eth/peer.go +++ b/eth/protocols/eth/peer.go @@ -22,11 +22,13 @@ import ( "math/big" "math/rand" "sync" + "time" mapset "github.com/deckarep/golang-set/v2" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" "github.com/ethereum/go-ethereum/rlp" ) @@ -68,11 +70,12 @@ func max(a, b int) int { // Peer is a collection of relevant information we have about a `eth` peer. type Peer struct { + *p2p.Peer // The embedded P2P package peer + id string // Unique ID for the peer, cached - *p2p.Peer // The embedded P2P package peer - rw p2p.MsgReadWriter // Input/output streams for snap - version uint // Protocol version negotiated + rw p2p.MsgReadWriter // Input/output streams for snap + version uint // Protocol version negotiated head common.Hash // Latest advertised head block hash td *big.Int // Latest advertised head block total difficulty @@ -86,9 +89,10 @@ type Peer struct { txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests - reqDispatch chan *request // Dispatch channel to send requests and track then until fulfillment - reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them - resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them + tracker *tracker.Tracker // Per-peer request/response tracker with timeout enforcement + reqDispatch chan *request // Dispatch channel to send requests and track then until fulfillment + reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them + resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them term chan struct{} // Termination channel to stop the broadcasters lock sync.RWMutex // Mutex protecting the internal fields @@ -99,8 +103,10 @@ type Peer struct { // NewPeer creates a wrapper for a network connection and negotiated protocol // version. func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} + id := p.ID().String() peer := &Peer{ - id: p.ID().String(), + id: id, Peer: p, rw: rw, version: version, @@ -110,6 +116,7 @@ func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Pe queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns), txBroadcast: make(chan []common.Hash), txAnnounce: make(chan []common.Hash), + tracker: tracker.New(cap, id, 5*time.Minute), reqDispatch: make(chan *request), reqCancel: make(chan *cancel), resDispatch: make(chan *response), @@ -177,9 +184,9 @@ func (p *Peer) markBlock(hash common.Hash) { p.knownBlocks.Add(hash) } -// markTransaction marks a transaction as known for the peer, ensuring that it +// MarkTransaction marks a transaction as known for the peer, ensuring that it // will never be propagated to this particular peer. -func (p *Peer) markTransaction(hash common.Hash) { +func (p *Peer) MarkTransaction(hash common.Hash) { // If we reached the memory allowance, drop a previously known transaction hash p.knownTxs.Add(hash) } @@ -333,10 +340,11 @@ func (p *Peer) RequestOneHeader(hash common.Hash, sink chan *Response) (*Request id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: 1, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -360,10 +368,11 @@ func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, re id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: amount, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -387,10 +396,11 @@ func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, rever id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: amount, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -414,10 +424,11 @@ func (p *Peer) RequestBodies(hashes []common.Hash, sink chan *Response) (*Reques id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockBodiesMsg, - want: BlockBodiesMsg, + id: id, + sink: sink, + code: GetBlockBodiesMsg, + want: BlockBodiesMsg, + numItems: len(hashes), data: &GetBlockBodiesPacket{ RequestId: id, GetBlockBodiesRequest: hashes, @@ -435,10 +446,11 @@ func (p *Peer) RequestReceipts(hashes []common.Hash, sink chan *Response) (*Requ id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetReceiptsMsg, - want: ReceiptsMsg, + id: id, + sink: sink, + code: GetReceiptsMsg, + want: ReceiptsMsg, + numItems: len(hashes), data: &GetReceiptsPacket{ RequestId: id, GetReceiptsRequest: hashes, @@ -455,7 +467,15 @@ func (p *Peer) RequestTxs(hashes []common.Hash) error { p.Log().Debug("Fetching batch of transactions", "count", len(hashes)) id := rand.Uint64() - requestTracker.Track(p.id, p.version, GetPooledTransactionsMsg, PooledTransactionsMsg, id) + err := p.tracker.Track(tracker.Request{ + ID: id, + ReqCode: GetPooledTransactionsMsg, + RespCode: PooledTransactionsMsg, + Size: len(hashes), + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetPooledTransactionsMsg, &GetPooledTransactionsPacket{ RequestId: id, GetPooledTransactionsRequest: hashes, diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go index 47e8d9724..d98d10ab5 100644 --- a/eth/protocols/eth/protocol.go +++ b/eth/protocols/eth/protocol.go @@ -48,6 +48,9 @@ var protocolLengths = map[uint]uint64{ETH68: 17} // maxMessageSize is the maximum cap on the size of a protocol message. const maxMessageSize = 10 * 1024 * 1024 +// This is the maximum number of transactions in a Transactions message. +const maxTransactionAnnouncements = 5000 + const ( StatusMsg = 0x00 NewBlockHashesMsg = 0x01 @@ -112,7 +115,9 @@ func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) { } // TransactionsPacket is the network packet for broadcasting new transactions. -type TransactionsPacket []*types.Transaction +type TransactionsPacket struct { + rlp.RawList[*types.Transaction] +} // GetBlockHeadersRequest represents a block header query. type GetBlockHeadersRequest struct { @@ -170,7 +175,7 @@ type BlockHeadersRequest []*types.Header // BlockHeadersPacket represents a block header response over with request ID wrapping. type BlockHeadersPacket struct { RequestId uint64 - BlockHeadersRequest + List rlp.RawList[*types.Header] } // BlockHeadersRLPResponse represents a block header response, to use when we already @@ -211,14 +216,11 @@ type GetBlockBodiesPacket struct { GetBlockBodiesRequest } -// BlockBodiesResponse is the network packet for block content distribution. -type BlockBodiesResponse []*BlockBody - // BlockBodiesPacket is the network packet for block content distribution with // request ID wrapping. type BlockBodiesPacket struct { RequestId uint64 - BlockBodiesResponse + List rlp.RawList[BlockBody] } // BlockBodiesRLPResponse is used for replying to block body requests, in cases @@ -232,26 +234,41 @@ type BlockBodiesRLPPacket struct { BlockBodiesRLPResponse } -// BlockBody represents the data content of a single block. -type BlockBody struct { - Transactions []*types.Transaction // Transactions contained within a block - Uncles []*types.Header // Uncles contained within a block - Withdrawals []*types.Withdrawal `rlp:"optional"` // Withdrawals contained within a block +// BlockBodiesResponse is the network packet for block content distribution. +type BlockBodiesResponse []BlockBody + +// Unpack retrieves the transactions, uncles, and withdrawals from each block body. +func (b *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal, error) { + txLists := make([][]*types.Transaction, len(*b)) + uncleLists := make([][]*types.Header, len(*b)) + withdrawalLists := make([][]*types.Withdrawal, len(*b)) + for i, body := range *b { + txs, err := body.Transactions.Items() + if err != nil { + return nil, nil, nil, fmt.Errorf("body %d: bad transactions: %v", i, err) + } + txLists[i] = txs + uncles, err := body.Uncles.Items() + if err != nil { + return nil, nil, nil, fmt.Errorf("body %d: bad uncles: %v", i, err) + } + uncleLists[i] = uncles + if body.Withdrawals != nil { + withdrawals, err := body.Withdrawals.Items() + if err != nil { + return nil, nil, nil, fmt.Errorf("body %d: bad withdrawals: %v", i, err) + } + withdrawalLists[i] = withdrawals + } + } + return txLists, uncleLists, withdrawalLists, nil } -// Unpack retrieves the transactions and uncles from the range packet and returns -// them in a split flat format that's more consistent with the internal data structures. -func (p *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal) { - // TODO(matt): add support for withdrawals to fetchers - var ( - txset = make([][]*types.Transaction, len(*p)) - uncleset = make([][]*types.Header, len(*p)) - withdrawalset = make([][]*types.Withdrawal, len(*p)) - ) - for i, body := range *p { - txset[i], uncleset[i], withdrawalset[i] = body.Transactions, body.Uncles, body.Withdrawals - } - return txset, uncleset, withdrawalset +// BlockBody represents the data content of a single block. +type BlockBody struct { + Transactions rlp.RawList[*types.Transaction] // Transactions contained within a block + Uncles rlp.RawList[*types.Header] // Uncles contained within a block + Withdrawals *rlp.RawList[*types.Withdrawal] `rlp:"optional"` // Withdrawals contained within a block } // GetReceiptsRequest represents a block receipts query. @@ -264,13 +281,15 @@ type GetReceiptsPacket struct { } // ReceiptsResponse is the network packet for block receipts distribution. -type ReceiptsResponse [][]*types.Receipt +type ReceiptsResponse []types.Receipts // ReceiptsPacket is the network packet for block receipts distribution with -// request ID wrapping. +// request ID wrapping. The outer list contains one entry per requested block; +// each entry is itself a list of receipts for that block. Both levels are kept +// as RawList to defer RLP decoding until the data is actually needed. type ReceiptsPacket struct { RequestId uint64 - ReceiptsResponse + List rlp.RawList[rlp.RawList[*types.Receipt]] } // ReceiptsRLPResponse is used for receipts, when we already have it encoded @@ -305,7 +324,7 @@ type PooledTransactionsResponse []*types.Transaction // with request ID wrapping. type PooledTransactionsPacket struct { RequestId uint64 - PooledTransactionsResponse + List rlp.RawList[*types.Transaction] } // PooledTransactionsRLPResponse is the network packet for transaction distribution, used @@ -348,8 +367,8 @@ func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransac func (*GetPooledTransactionsRequest) Name() string { return "GetPooledTransactions" } func (*GetPooledTransactionsRequest) Kind() byte { return GetPooledTransactionsMsg } -func (*PooledTransactionsResponse) Name() string { return "PooledTransactions" } -func (*PooledTransactionsResponse) Kind() byte { return PooledTransactionsMsg } +func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" } +func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg } func (*GetReceiptsRequest) Name() string { return "GetReceipts" } func (*GetReceiptsRequest) Kind() byte { return GetReceiptsMsg } diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go index bc2545dea..1633f7dbf 100644 --- a/eth/protocols/eth/protocol_test.go +++ b/eth/protocols/eth/protocol_test.go @@ -78,34 +78,33 @@ func TestEmptyMessages(t *testing.T) { for i, msg := range []interface{}{ // Headers GetBlockHeadersPacket{1111, nil}, - BlockHeadersPacket{1111, nil}, // Bodies GetBlockBodiesPacket{1111, nil}, - BlockBodiesPacket{1111, nil}, BlockBodiesRLPPacket{1111, nil}, // Receipts GetReceiptsPacket{1111, nil}, - ReceiptsPacket{1111, nil}, // Transactions GetPooledTransactionsPacket{1111, nil}, - PooledTransactionsPacket{1111, nil}, PooledTransactionsRLPPacket{1111, nil}, // Headers - BlockHeadersPacket{1111, BlockHeadersRequest([]*types.Header{})}, + BlockHeadersPacket{1111, encodeRL([]*types.Header{})}, // Bodies GetBlockBodiesPacket{1111, GetBlockBodiesRequest([]common.Hash{})}, - BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{})}, + BlockBodiesPacket{1111, encodeRL([]BlockBody{})}, BlockBodiesRLPPacket{1111, BlockBodiesRLPResponse([]rlp.RawValue{})}, // Receipts GetReceiptsPacket{1111, GetReceiptsRequest([]common.Hash{})}, - ReceiptsPacket{1111, ReceiptsResponse([][]*types.Receipt{})}, + ReceiptsPacket{1111, encodeRL([]rlp.RawList[*types.Receipt]{})}, // Transactions GetPooledTransactionsPacket{1111, GetPooledTransactionsRequest([]common.Hash{})}, - PooledTransactionsPacket{1111, PooledTransactionsResponse([]*types.Transaction{})}, + PooledTransactionsPacket{1111, encodeRL([]*types.Transaction{})}, PooledTransactionsRLPPacket{1111, PooledTransactionsRLPResponse([]rlp.RawValue{})}, } { - if have, _ := rlp.EncodeToBytes(msg); !bytes.Equal(have, want) { + have, err := rlp.EncodeToBytes(msg) + if err != nil { + t.Errorf("test %d, type %T, error: %v", i, msg, err) + } else if !bytes.Equal(have, want) { t.Errorf("test %d, type %T, have\n\t%x\nwant\n\t%x", i, msg, have, want) } } @@ -116,7 +115,7 @@ func TestMessages(t *testing.T) { // Some basic structs used during testing var ( header *types.Header - blockBody *BlockBody + blockBody BlockBody blockBodyRlp rlp.RawValue txs []*types.Transaction txRlps []rlp.RawValue @@ -150,9 +149,9 @@ func TestMessages(t *testing.T) { } } // init the block body data, both object and rlp form - blockBody = &BlockBody{ - Transactions: txs, - Uncles: []*types.Header{header}, + blockBody = BlockBody{ + Transactions: encodeRL(txs), + Uncles: encodeRL([]*types.Header{header}), } blockBodyRlp, err = rlp.EncodeToBytes(blockBody) if err != nil { @@ -201,7 +200,7 @@ func TestMessages(t *testing.T) { common.FromHex("ca820457c682270f050580"), }, { - BlockHeadersPacket{1111, BlockHeadersRequest{header}}, + BlockHeadersPacket{1111, encodeRL([]*types.Header{header})}, common.FromHex("f90202820457f901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"), }, { @@ -209,7 +208,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{blockBody})}, + BlockBodiesPacket{1111, encodeRL([]BlockBody{blockBody})}, common.FromHex("f902dc820457f902d6f902d3f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afbf901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"), }, { // Identical to non-rlp-shortcut version @@ -221,7 +220,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - ReceiptsPacket{1111, ReceiptsResponse([][]*types.Receipt{receipts})}, + ReceiptsPacket{1111, encodeRL([]rlp.RawList[*types.Receipt]{encodeRL(receipts)})}, common.FromHex("f90172820457f9016cf90169f901668001b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ff"), }, { @@ -233,7 +232,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - PooledTransactionsPacket{1111, PooledTransactionsResponse(txs)}, + PooledTransactionsPacket{1111, encodeRL(txs)}, common.FromHex("f8d7820457f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afb"), }, { diff --git a/eth/protocols/eth/tracker.go b/eth/protocols/eth/tracker.go deleted file mode 100644 index 324fd2283..000000000 --- a/eth/protocols/eth/tracker.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2021 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package eth - -import ( - "time" - - "github.com/ethereum/go-ethereum/p2p/tracker" -) - -// requestTracker is a singleton tracker for eth/66 and newer request times. -var requestTracker = tracker.New(ProtocolName, 5*time.Minute) diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 003e014f4..c4c0f131f 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -29,6 +29,8 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" ) @@ -97,6 +99,7 @@ func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol { Length: protocolLengths[version], Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error { + defer peer.Close() return Handle(backend, peer) }) }, @@ -153,7 +156,6 @@ func HandleMessage(backend Backend, peer *Peer) error { // Handle the message depending on its contents switch { case msg.Code == GetAccountRangeMsg: - // Decode the account retrieval request var req GetAccountRangePacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -169,23 +171,40 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == AccountRangeMsg: - // A range of accounts arrived to one of our previous requests - res := new(AccountRangePacket) + res := new(accountRangeInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("AccountRange: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return err + } + + // Decode. + accounts, err := res.Accounts.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + // Ensure the range is monotonically increasing - for i := 1; i < len(res.Accounts); i++ { - if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 { - return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:]) + for i := 1; i < len(accounts); i++ { + if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { + return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) } } - requestTracker.Fulfil(peer.id, peer.version, AccountRangeMsg, res.ID) - return backend.Handle(peer, res) + return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) case msg.Code == GetStorageRangesMsg: - // Decode the storage retrieval request var req GetStorageRangesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -201,25 +220,42 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == StorageRangesMsg: - // A range of storage slots arrived to one of our previous requests - res := new(StorageRangesPacket) + res := new(storageRangesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("StorageRanges: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("StorageRanges: %w", err) + } + + // Decode. + slotLists, err := res.Slots.Items() + if err != nil { + return fmt.Errorf("StorageRanges: invalid slot list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("StorageRanges: invalid proof: %v", err) + } + // Ensure the ranges are monotonically increasing - for i, slots := range res.Slots { + for i, slots := range slotLists { for j := 1; j < len(slots); j++ { if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) } } } - requestTracker.Fulfil(peer.id, peer.version, StorageRangesMsg, res.ID) - return backend.Handle(peer, res) + return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) case msg.Code == GetByteCodesMsg: - // Decode bytecode retrieval request var req GetByteCodesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -234,17 +270,25 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == ByteCodesMsg: - // A batch of byte codes arrived to one of our previous requests - res := new(ByteCodesPacket) + res := new(byteCodesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - requestTracker.Fulfil(peer.id, peer.version, ByteCodesMsg, res.ID) - return backend.Handle(peer, res) + length := res.Codes.Len() + tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + codes, err := res.Codes.Items() + if err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) case msg.Code == GetTrieNodesMsg: - // Decode trie node retrieval request var req GetTrieNodesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -261,14 +305,21 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == TrieNodesMsg: - // A batch of trie nodes arrived to one of our previous requests - res := new(TrieNodesPacket) + res := new(trieNodesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - requestTracker.Fulfil(peer.id, peer.version, TrieNodesMsg, res.ID) - return backend.Handle(peer, res) + tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + nodes, err := res.Nodes.Items() + if err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + + return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) default: return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) @@ -496,19 +547,29 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s snap := chain.Snapshots().Snapshot(req.Root) // Retrieve trie nodes until the packet size limit is reached var ( - nodes [][]byte - bytes uint64 - loads int // Trie hash expansions to count database reads + outerIt = req.Paths.ContentIterator() + nodes [][]byte + bytes uint64 + loads int // Trie hash expansions to count database reads ) - for _, pathset := range req.Paths { - switch len(pathset) { + for outerIt.Next() { + innerIt, err := rlp.NewListIterator(outerIt.Value()) + if err != nil { + return nodes, err + } + + switch innerIt.Count() { case 0: // Ensure we penalize invalid requests return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) case 1: // If we're only retrieving an account trie node, fetch it directly - blob, resolved, err := accTrie.GetNode(pathset[0]) + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) + } + blob, resolved, err := accTrie.GetNode(accKey) loads += resolved // always account database reads, even for failures if err != nil { break @@ -517,32 +578,41 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s bytes += uint64(len(blob)) default: - var stRoot common.Hash // Storage slots requested, open the storage trie and retrieve from there + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) + } + var stRoot common.Hash if snap == nil { // We don't have the requested state snapshotted yet (or it is stale), // but can look up the account via the trie instead. - account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0])) + account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) loads += 8 // We don't know the exact cost of lookup, this is an estimate if err != nil || account == nil { break } stRoot = account.Root } else { - account, err := snap.Account(common.BytesToHash(pathset[0])) + account, err := snap.Account(common.BytesToHash(accKey)) loads++ // always account database reads, even for failures if err != nil || account == nil { break } stRoot = common.BytesToHash(account.Root) } - id := trie.StorageTrieID(req.Root, common.BytesToHash(pathset[0]), stRoot) + + id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) stTrie, err := trie.NewStateTrie(id, triedb) loads++ // always account database reads, even for failures if err != nil { break } - for _, path := range pathset[1:] { + for innerIt.Next() { + path, _, err := rlp.SplitString(innerIt.Value()) + if err != nil { + return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) + } blob, resolved, err := stTrie.GetNode(path) loads += resolved // always account database reads, even for failures if err != nil { @@ -565,6 +635,17 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s return nodes, nil } +func nextBytes(it *rlp.Iterator) []byte { + if !it.Next() { + return nil + } + content, _, err := rlp.SplitString(it.Value()) + if err != nil { + return nil + } + return content +} + // NodeInfo represents a short summary of the `snap` sub-protocol metadata // known about the host peer. type NodeInfo struct{} diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go index c57931678..7b5c70146 100644 --- a/eth/protocols/snap/peer.go +++ b/eth/protocols/snap/peer.go @@ -17,9 +17,13 @@ package snap import ( + "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" ) // Peer is a collection of relevant information we have about a `snap` peer. @@ -29,6 +33,7 @@ type Peer struct { *p2p.Peer // The embedded P2P package peer rw p2p.MsgReadWriter // Input/output streams for snap version uint // Protocol version negotiated + tracker *tracker.Tracker logger log.Logger // Contextual logger with the peer id injected } @@ -36,22 +41,26 @@ type Peer struct { // NewPeer creates a wrapper for a network connection and negotiated protocol // version. func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} id := p.ID().String() return &Peer{ id: id, Peer: p, rw: rw, version: version, + tracker: tracker.New(cap, id, 1*time.Minute), logger: log.New("peer", id[:8]), } } // NewFakePeer creates a fake snap peer without a backing p2p peer, for testing purposes. func NewFakePeer(version uint, id string, rw p2p.MsgReadWriter) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} return &Peer{ id: id, rw: rw, version: version, + tracker: tracker.New(cap, id, 1*time.Minute), logger: log.New("peer", id[:8]), } } @@ -71,63 +80,102 @@ func (p *Peer) Log() log.Logger { return p.logger } +// Close releases resources associated with the peer. +func (p *Peer) Close() { + p.tracker.Stop() +} + // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. -func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error { +func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes int) error { p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetAccountRangeMsg, AccountRangeMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetAccountRangeMsg, + RespCode: AccountRangeMsg, + ID: id, + Size: 2 * bytes, + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{ ID: id, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestStorageRanges fetches a batch of storage slots belonging to one or more // accounts. If slots from only one account is requested, an origin marker may also // be used to retrieve from there. -func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { if len(accounts) == 1 && origin != nil { p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes)) } else { p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes)) } - requestTracker.Track(p.id, p.version, GetStorageRangesMsg, StorageRangesMsg, id) + + err := p.tracker.Track(tracker.Request{ + ReqCode: GetStorageRangesMsg, + RespCode: StorageRangesMsg, + ID: id, + Size: 2 * bytes, + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{ ID: id, Root: root, Accounts: accounts, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestByteCodes fetches a batch of bytecodes by hash. -func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetByteCodesMsg, ByteCodesMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetByteCodesMsg, + RespCode: ByteCodesMsg, + ID: id, + Size: len(hashes), // ByteCodes is limited by the length of the hash list. + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{ ID: id, Hashes: hashes, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestTrieNodes fetches a batch of account or storage trie nodes rooted in -// a specific state trie. -func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error { +// a specific state trie. The `count` is the total count of paths being requested. +func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error { p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetTrieNodesMsg, TrieNodesMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetTrieNodesMsg, + RespCode: TrieNodesMsg, + ID: id, + Size: count, // TrieNodes is limited by number of items. + }) + if err != nil { + return err + } + encPaths, _ := rlp.EncodeToRawList(paths) return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{ ID: id, Root: root, - Paths: paths, - Bytes: bytes, + Paths: encPaths, + Bytes: uint64(bytes), }) } diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 0db206b08..25fe25822 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -78,6 +78,12 @@ type GetAccountRangePacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type accountRangeInput struct { + ID uint64 // ID of the request this is a response for + Accounts rlp.RawList[*AccountData] // List of consecutive accounts from the trie + Proof rlp.RawList[[]byte] // List of trie nodes proving the account range +} + // AccountRangePacket represents an account query response. type AccountRangePacket struct { ID uint64 // ID of the request this is a response for @@ -123,6 +129,12 @@ type GetStorageRangesPacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type storageRangesInput struct { + ID uint64 // ID of the request this is a response for + Slots rlp.RawList[[]*StorageData] // Lists of consecutive storage slots for the requested accounts + Proof rlp.RawList[[]byte] // Merkle proofs for the *last* slot range, if it's incomplete +} + // StorageRangesPacket represents a storage slot query response. type StorageRangesPacket struct { ID uint64 // ID of the request this is a response for @@ -161,6 +173,11 @@ type GetByteCodesPacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type byteCodesInput struct { + ID uint64 // ID of the request this is a response for + Codes rlp.RawList[[]byte] // Requested contract bytecodes +} + // ByteCodesPacket represents a contract bytecode query response. type ByteCodesPacket struct { ID uint64 // ID of the request this is a response for @@ -169,10 +186,10 @@ type ByteCodesPacket struct { // GetTrieNodesPacket represents a state trie node query. type GetTrieNodesPacket struct { - ID uint64 // Request ID to match up responses with - Root common.Hash // Root hash of the account trie to serve - Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for - Bytes uint64 // Soft limit at which to stop returning data + ID uint64 // Request ID to match up responses with + Root common.Hash // Root hash of the account trie to serve + Paths rlp.RawList[TrieNodePathSet] // Trie node hashes to retrieve the nodes for + Bytes uint64 // Soft limit at which to stop returning data } // TrieNodePathSet is a list of trie node paths to retrieve. A naive way to @@ -187,6 +204,11 @@ type GetTrieNodesPacket struct { // that a slot is accessed before the account path is fully expanded. type TrieNodePathSet [][]byte +type trieNodesInput struct { + ID uint64 // ID of the request this is a response for + Nodes rlp.RawList[[]byte] // Requested state trie nodes +} + // TrieNodesPacket represents a state trie node query response. type TrieNodesPacket struct { ID uint64 // ID of the request this is a response for diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index 050c2c4cc..0dbb1c886 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -413,19 +413,19 @@ type SyncPeer interface { // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. - RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error + RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error // RequestStorageRanges fetches a batch of storage slots belonging to one or // more accounts. If slots from only one account is requested, an origin marker // may also be used to retrieve from there. - RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error + RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error // RequestByteCodes fetches a batch of bytecodes by hash. - RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error + RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error // RequestTrieNodes fetches a batch of account or storage trie nodes rooted in // a specific state trie. - RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error + RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error // Log retrieves the peer's own contextual logger. Log() log.Logger @@ -1083,7 +1083,7 @@ func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *ac if cap < minRequestSize { // Don't bother with peers below a bare minimum performance cap = minRequestSize } - if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, uint64(cap)); err != nil { + if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, cap); err != nil { peer.Log().Debug("Failed to request account range", "err", err) s.scheduleRevertAccountRequest(req) } @@ -1340,7 +1340,7 @@ func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *st if subtask != nil { origin, limit = req.origin[:], req.limit[:] } - if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, uint64(cap)); err != nil { + if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, cap); err != nil { log.Debug("Failed to request storage", "err", err) s.scheduleRevertStorageRequest(req) } @@ -1473,7 +1473,7 @@ func (s *Syncer) assignTrienodeHealTasks(success chan *trienodeHealResponse, fai defer s.pend.Done() // Attempt to send the remote request and revert if it fails - if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil { + if err := peer.RequestTrieNodes(reqid, root, len(paths), pathsets, maxRequestSize); err != nil { log.Debug("Failed to request trienode healers", "err", err) s.scheduleRevertTrienodeHealRequest(req) } diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 7d893e454..673a76e8b 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -119,10 +119,10 @@ func BenchmarkHashing(b *testing.B) { } type ( - accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error - storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error - trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error - codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error + accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error + storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error + trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error + codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max int) error ) type testPeer struct { @@ -182,21 +182,21 @@ Trienode requests: %d `, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests, t.nTrienodeRequests) } -func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error { +func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error { t.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes)) t.nAccountRequests++ go t.accountRequestHandler(t, id, root, origin, limit, bytes) return nil } -func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error { +func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error { t.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes)) t.nTrienodeRequests++ go t.trieRequestHandler(t, id, root, paths, bytes) return nil } -func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { t.nStorageRequests++ if len(accounts) == 1 && origin != nil { t.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes)) @@ -207,7 +207,7 @@ func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts [] return nil } -func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { t.nBytecodeRequests++ t.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes)) go t.codeRequestHandler(t, id, hashes, bytes) @@ -215,7 +215,7 @@ func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint6 } // defaultTrieRequestHandler is a well-behaving handler for trie healing requests -func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { // Pass the response var nodes [][]byte for _, pathset := range paths { @@ -244,7 +244,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, } // defaultAccountRequestHandler is a well-behaving handler for AccountRangeRequests -func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { keys, vals, proofs := createAccountRequestResponse(t, root, origin, limit, cap) if err := t.remote.OnAccounts(t, id, keys, vals, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -254,8 +254,8 @@ func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, orig return nil } -func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) { - var size uint64 +func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap int) (keys []common.Hash, vals [][]byte, proofs [][]byte) { + var size int if limit == (common.Hash{}) { limit = common.MaxHash } @@ -266,7 +266,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H if bytes.Compare(origin[:], entry.k) <= 0 { keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) } // If we've exceeded the request threshold, abort if bytes.Compare(entry.k, limit[:]) >= 0 { @@ -293,7 +293,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H } // defaultStorageRequestHandler is a well-behaving storage request handler -func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) error { +func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, bOrigin, bLimit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -302,7 +302,7 @@ func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Has return nil } -func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes { bytecodes = append(bytecodes, getCodeByHash(h)) @@ -314,8 +314,8 @@ func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max return nil } -func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { - var size uint64 +func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { + var size int for _, account := range accounts { // The first account might start from a different origin and end sooner var originHash common.Hash @@ -341,7 +341,7 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm } keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) if bytes.Compare(entry.k, limitHash[:]) >= 0 { break } @@ -382,8 +382,8 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm // createStorageRequestResponseAlwaysProve tests a cornercase, where the peer always // supplies the proof for the last account, even if it is 'complete'. -func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { - var size uint64 +func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { + var size int max = max * 3 / 4 var origin common.Hash @@ -400,7 +400,7 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco } keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) if size > max { exit = true } @@ -440,34 +440,34 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco } // emptyRequestAccountRangeFn is a rejects AccountRangeRequests -func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { t.remote.OnAccounts(t, requestId, nil, nil, nil) return nil } -func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { return nil } -func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { t.remote.OnTrieNodes(t, requestId, nil) return nil } -func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { return nil } -func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { t.remote.OnStorage(t, requestId, nil, nil, nil) return nil } -func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return nil } -func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponseAlwaysProve(t, root, accounts, origin, limit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -482,7 +482,7 @@ func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common. // return nil //} -func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes { // Send back the hashes @@ -496,7 +496,7 @@ func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max return nil } -func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes[:1] { bytecodes = append(bytecodes, getCodeByHash(h)) @@ -510,11 +510,11 @@ func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max } // starvingStorageRequestHandler is somewhat well-behaving storage handler, but it caps the returned results to be very small -func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return defaultStorageRequestHandler(t, requestId, root, accounts, origin, limit, 500) } -func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { return defaultAccountRequestHandler(t, requestId, root, origin, limit, 500) } @@ -522,7 +522,7 @@ func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Ha // return defaultAccountRequestHandler(t, requestId-1, root, origin, 500) //} -func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { hashes, accounts, proofs := createAccountRequestResponse(t, root, origin, limit, cap) if len(proofs) > 0 { proofs = proofs[1:] @@ -536,7 +536,7 @@ func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Has } // corruptStorageRequestHandler doesn't provide good proofs -func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, origin, limit, max) if len(proofs) > 0 { proofs = proofs[1:] @@ -549,7 +549,7 @@ func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Has return nil } -func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, _ := createStorageRequestResponse(t, root, accounts, origin, limit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, nil); err != nil { t.logger.Info("remote error on delivery (as expected)", "error", err) @@ -584,7 +584,7 @@ func testSyncBloatedProof(t *testing.T, scheme string) { source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { + source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { var ( proofs [][]byte keys []common.Hash @@ -1177,7 +1177,7 @@ func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) { var counter int syncer := setupSyncer( nodeScheme, - mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { + mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max int) error { counter++ return cappedCodeRequestHandler(t, id, hashes, max) }), @@ -1444,7 +1444,7 @@ func testSyncWithUnevenStorage(t *testing.T, scheme string) { source.accountValues = accounts source.setStorageTries(storageTries) source.storageValues = storageElems - source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { + source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return defaultStorageRequestHandler(t, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode } return source diff --git a/eth/protocols/snap/tracker.go b/eth/protocols/snap/tracker.go deleted file mode 100644 index 2cf59cc23..000000000 --- a/eth/protocols/snap/tracker.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2021 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package snap - -import ( - "time" - - "github.com/ethereum/go-ethereum/p2p/tracker" -) - -// requestTracker is a singleton tracker for request times. -var requestTracker = tracker.New(ProtocolName, time.Minute) diff --git a/p2p/tracker/tracker.go b/p2p/tracker/tracker.go index 5b72eb2b8..17ba3688e 100644 --- a/p2p/tracker/tracker.go +++ b/p2p/tracker/tracker.go @@ -18,12 +18,14 @@ package tracker import ( "container/list" + "errors" "fmt" "sync" "time" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" ) const ( @@ -42,28 +44,44 @@ const ( // maxTrackedPackets is a huge number to act as a failsafe on the number of // pending requests the node will track. It should never be hit unless an // attacker figures out a way to spin requests. - maxTrackedPackets = 100000 + maxTrackedPackets = 10000 ) -// request tracks sent network requests which have not yet received a response. -type request struct { - peer string - version uint // Protocol version +var ( + ErrNoMatchingRequest = errors.New("no matching request") + ErrTooManyItems = errors.New("response is larger than request allows") + ErrCollision = errors.New("request ID collision") + ErrCodeMismatch = errors.New("wrong response code for request") + ErrLimitReached = errors.New("request limit reached") + ErrStopped = errors.New("tracker stopped") +) - reqCode uint64 // Protocol message code of the request - resCode uint64 // Protocol message code of the expected response +// request tracks sent network requests which have not yet received a response. +type Request struct { + ID uint64 // Request ID + Size int // Number/size of requested items + ReqCode uint64 // Protocol message code of the request + RespCode uint64 // Protocol message code of the expected response time time.Time // Timestamp when the request was made expire *list.Element // Expiration marker to untrack it } +type Response struct { + ID uint64 // Request ID of the response + MsgCode uint64 // Protocol message code + Size int // number/size of items in response +} + // Tracker is a pending network request tracker to measure how much time it takes // a remote peer to respond. type Tracker struct { - protocol string // Protocol capability identifier for the metrics - timeout time.Duration // Global timeout after which to drop a tracked packet + cap p2p.Cap // Protocol capability identifier for the metrics + + peer string // Peer ID + timeout time.Duration // Global timeout after which to drop a tracked packet - pending map[uint64]*request // Currently pending requests + pending map[uint64]*Request // Currently pending requests expire *list.List // Linked list tracking the expiration order wake *time.Timer // Timer tracking the expiration of the next item @@ -72,52 +90,53 @@ type Tracker struct { // New creates a new network request tracker to monitor how much time it takes to // fill certain requests and how individual peers perform. -func New(protocol string, timeout time.Duration) *Tracker { +func New(cap p2p.Cap, peerID string, timeout time.Duration) *Tracker { return &Tracker{ - protocol: protocol, - timeout: timeout, - pending: make(map[uint64]*request), - expire: list.New(), + cap: cap, + peer: peerID, + timeout: timeout, + pending: make(map[uint64]*Request), + expire: list.New(), } } // Track adds a network request to the tracker to wait for a response to arrive // or until the request it cancelled or times out. -func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) { - if !metrics.Enabled() { - return - } +func (t *Tracker) Track(req Request) error { t.lock.Lock() defer t.lock.Unlock() + if t.expire == nil { + return ErrStopped + } + // If there's a duplicate request, we've just random-collided (or more probably, // we have a bug), report it. We could also add a metric, but we're not really // expecting ourselves to be buggy, so a noisy warning should be enough. - if _, ok := t.pending[id]; ok { - log.Error("Network request id collision", "protocol", t.protocol, "version", version, "code", reqCode, "id", id) - return + if _, ok := t.pending[req.ID]; ok { + log.Error("Network request id collision", "cap", t.cap, "code", req.ReqCode, "id", req.ID) + return ErrCollision } // If we have too many pending requests, bail out instead of leaking memory if pending := len(t.pending); pending >= maxTrackedPackets { - log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "protocol", t.protocol, "version", version, "code", reqCode) - return + log.Error("Request tracker exceeded allowance", "pending", pending, "peer", t.peer, "cap", t.cap, "code", req.ReqCode) + return ErrLimitReached } + // Id doesn't exist yet, start tracking it - t.pending[id] = &request{ - peer: peer, - version: version, - reqCode: reqCode, - resCode: resCode, - time: time.Now(), - expire: t.expire.PushBack(id), + req.time = time.Now() + req.expire = t.expire.PushBack(req.ID) + t.pending[req.ID] = &req + + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Inc(1) } - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, version, reqCode) - metrics.GetOrRegisterGauge(g, nil).Inc(1) // If we've just inserted the first item, start the expiration timer if t.wake == nil { t.wake = time.AfterFunc(t.timeout, t.clean) } + return nil } // clean is called automatically when a preset time passes without a response @@ -126,6 +145,10 @@ func (t *Tracker) clean() { t.lock.Lock() defer t.lock.Unlock() + if t.expire == nil { + return // Tracker was stopped. + } + // Expire anything within a certain threshold (might be no items at all if // we raced with the delivery) for t.expire.Len() > 0 { @@ -142,64 +165,111 @@ func (t *Tracker) clean() { t.expire.Remove(head) delete(t.pending, id) - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterGauge(g, nil).Dec(1) - - m := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterMeter(m, nil).Mark(1) + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Dec(1) + t.lostMeter(req.ReqCode).Mark(1) + } } t.schedule() } -// schedule starts a timer to trigger on the expiration of the first network -// packet. +// schedule starts a timer to trigger on the expiration of the first network packet. func (t *Tracker) schedule() { if t.expire.Len() == 0 { t.wake = nil return } - t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean) + nextID := t.expire.Front().Value.(uint64) + nextTime := t.pending[nextID].time + t.wake = time.AfterFunc(time.Until(nextTime.Add(t.timeout)), t.clean) } -// Fulfil fills a pending request, if any is available, reporting on various metrics. -func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) { - if !metrics.Enabled() { - return +// Stop reclaims resources of the tracker. +func (t *Tracker) Stop() { + t.lock.Lock() + defer t.lock.Unlock() + + if t.wake != nil { + t.wake.Stop() + t.wake = nil + } + if metrics.Enabled() { + // Ensure metrics are decremented for pending requests. + counts := make(map[uint64]int64) + for _, req := range t.pending { + counts[req.ReqCode]++ + } + for code, count := range counts { + t.trackedGauge(code).Dec(count) + } } + clear(t.pending) + t.expire = nil +} + +// Fulfil fills a pending request, if any is available, reporting on various metrics. +func (t *Tracker) Fulfil(resp Response) error { t.lock.Lock() defer t.lock.Unlock() // If it's a non existing request, track as stale response - req, ok := t.pending[id] + req, ok := t.pending[resp.ID] if !ok { - m := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.protocol, version, code) - metrics.GetOrRegisterMeter(m, nil).Mark(1) - return + if metrics.Enabled() { + t.staleMeter(resp.MsgCode).Mark(1) + } + return ErrNoMatchingRequest } + // If the response is funky, it might be some active attack - if req.peer != peer || req.version != version || req.resCode != code { - log.Warn("Network response id collision", - "have", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, version, code), - "want", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, req.version, req.resCode), + if req.RespCode != resp.MsgCode { + log.Warn("Network response code collision", + "have", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, resp.MsgCode), + "want", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, req.RespCode), ) - return + return ErrCodeMismatch + } + if resp.Size > req.Size { + return ErrTooManyItems } + // Everything matches, mark the request serviced and meter it + wasHead := req.expire.Prev() == nil t.expire.Remove(req.expire) - delete(t.pending, id) - if req.expire.Prev() == nil { + delete(t.pending, req.ID) + if wasHead { if t.wake.Stop() { t.schedule() } } - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterGauge(g, nil).Dec(1) - h := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.protocol, req.version, req.reqCode) + // Update request metrics. + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Dec(1) + t.waitHistogram(req.ReqCode).Update(time.Since(req.time).Microseconds()) + } + return nil +} + +func (t *Tracker) trackedGauge(code uint64) *metrics.Gauge { + name := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterGauge(name, nil) +} + +func (t *Tracker) lostMeter(code uint64) *metrics.Meter { + name := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterMeter(name, nil) +} + +func (t *Tracker) staleMeter(code uint64) *metrics.Meter { + name := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterMeter(name, nil) +} + +func (t *Tracker) waitHistogram(code uint64) metrics.Histogram { + name := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.cap.Name, t.cap.Version, code) sampler := func() metrics.Sample { - return metrics.ResettingSample( - metrics.NewExpDecaySample(1028, 0.015), - ) + return metrics.ResettingSample(metrics.NewExpDecaySample(1028, 0.015)) } - metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(req.time).Microseconds()) + return metrics.GetOrRegisterHistogramLazy(name, nil, sampler) } diff --git a/p2p/tracker/tracker_test.go b/p2p/tracker/tracker_test.go new file mode 100644 index 000000000..95e962964 --- /dev/null +++ b/p2p/tracker/tracker_test.go @@ -0,0 +1,83 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package tracker + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" +) + +// TestCleanAfterStop verifies that the clean method does not crash when called +// after Stop. This can happen because clean is scheduled via time.AfterFunc and +// may fire after Stop sets t.expire to nil. +func TestCleanAfterStop(t *testing.T) { + cap := p2p.Cap{Name: "test", Version: 1} + timeout := 50 * time.Millisecond + tr := New(cap, "peer1", timeout) + + // Track a request to start the expiration timer. + tr.Track(Request{ID: 1, ReqCode: 0x01, RespCode: 0x02, Size: 1}) + + // Stop the tracker, then wait for the timer to fire. + tr.Stop() + time.Sleep(timeout + 50*time.Millisecond) + + // Also verify that calling clean directly after stop doesn't panic. + tr.clean() +} + +// This checks that metrics gauges for pending requests are be decremented when a +// Tracker is stopped. +func TestMetricsOnStop(t *testing.T) { + metrics.Enable() + + cap := p2p.Cap{Name: "test", Version: 1} + tr := New(cap, "peer1", time.Minute) + + // Track some requests with different ReqCodes. + var id uint64 + for i := 0; i < 3; i++ { + tr.Track(Request{ID: id, ReqCode: 0x01, RespCode: 0x02, Size: 1}) + id++ + } + for i := 0; i < 5; i++ { + tr.Track(Request{ID: id, ReqCode: 0x03, RespCode: 0x04, Size: 1}) + id++ + } + + gauge1 := tr.trackedGauge(0x01) + gauge2 := tr.trackedGauge(0x03) + + if gauge1.Snapshot().Value() != 3 { + t.Fatalf("gauge1 value mismatch: got %d, want 3", gauge1.Snapshot().Value()) + } + if gauge2.Snapshot().Value() != 5 { + t.Fatalf("gauge2 value mismatch: got %d, want 5", gauge2.Snapshot().Value()) + } + + tr.Stop() + + if gauge1.Snapshot().Value() != 0 { + t.Fatalf("gauge1 value after stop: got %d, want 0", gauge1.Snapshot().Value()) + } + if gauge2.Snapshot().Value() != 0 { + t.Fatalf("gauge2 value after stop: got %d, want 0", gauge2.Snapshot().Value()) + } +} diff --git a/rlp/encode.go b/rlp/encode.go index 3645bbfda..fa2a3008e 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -101,6 +101,29 @@ func EncodeToReader(val interface{}) (size int, r io.Reader, err error) { return buf.size(), &encReader{buf: buf}, nil } +// EncodeToRawList encodes val as an RLP list and returns it as a RawList. +func EncodeToRawList[T any](val []T) (RawList[T], error) { + if len(val) == 0 { + return RawList[T]{}, nil + } + + // Encode the value to an internal buffer. + buf := getEncBuffer() + defer encBufferPool.Put(buf) + if err := buf.encode(val); err != nil { + return RawList[T]{}, err + } + + // Create the RawList. RawList assumes the initial list header is padded + // 9 bytes, so we have to determine the offset where the value should be + // placed. + contentSize := buf.lheads[0].size + bytes := make([]byte, contentSize+9) + offset := 9 - headsize(uint64(contentSize)) + buf.copyTo(bytes[offset:]) + return RawList[T]{enc: bytes, length: len(val)}, nil +} + type listhead struct { offset int // index of this header in string data size int // total size of encoded data (including list headers) diff --git a/rlp/iterator.go b/rlp/iterator.go index 6be574572..9e41cec94 100644 --- a/rlp/iterator.go +++ b/rlp/iterator.go @@ -16,45 +16,74 @@ package rlp -type listIterator struct { - data []byte - next []byte - err error +// Iterator is an iterator over the elements of an encoded container. +type Iterator struct { + data []byte + next []byte + offset int + err error } -// NewListIterator creates an iterator for the (list) represented by data -// TODO: Consider removing this implementation, as it is no longer used. -func NewListIterator(data RawValue) (*listIterator, error) { +// NewListIterator creates an iterator for the (list) represented by data. +func NewListIterator(data RawValue) (Iterator, error) { k, t, c, err := readKind(data) if err != nil { - return nil, err + return Iterator{}, err } if k != List { - return nil, ErrExpectedList - } - it := &listIterator{ - data: data[t : t+c], + return Iterator{}, ErrExpectedList } + it := newIterator(data[t:t+c], int(t)) return it, nil } -// Next forwards the iterator one step, returns true if it was not at end yet -func (it *listIterator) Next() bool { +func newIterator(data []byte, initialOffset int) Iterator { + return Iterator{data: data, offset: initialOffset} +} + +// Next forwards the iterator one step. +// Returns true if there is a next item or an error occurred on this step (check Err()). +// On parse error, the iterator is marked finished and subsequent calls return false. +func (it *Iterator) Next() bool { if len(it.data) == 0 { return false } _, t, c, err := readKind(it.data) - it.next = it.data[:t+c] - it.data = it.data[t+c:] - it.err = err + if err != nil { + it.next = nil + it.err = err + // Mark iteration as finished to avoid potential infinite loops on subsequent Next calls. + it.data = nil + return true + } + length := t + c + it.next = it.data[:length] + it.data = it.data[length:] + it.offset += int(length) + it.err = nil return true } -// Value returns the current value -func (it *listIterator) Value() []byte { +// Value returns the current value. +func (it *Iterator) Value() []byte { return it.next } -func (it *listIterator) Err() error { +// Count returns the remaining number of items. +// Note this is O(n) and the result may be incorrect if the list data is invalid. +// The returned count is always an upper bound on the remaining items +// that will be visited by the iterator. +func (it *Iterator) Count() int { + count, _ := CountValues(it.data) + return count +} + +// Offset returns the offset of the current value into the list data. +func (it *Iterator) Offset() int { + return it.offset - len(it.next) +} + +// Err returns the error that caused Next to return false, if any. +func (it *Iterator) Err() error { return it.err } diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go index a22aaec86..275d4371c 100644 --- a/rlp/iterator_test.go +++ b/rlp/iterator_test.go @@ -17,6 +17,7 @@ package rlp import ( + "io" "testing" "github.com/ethereum/go-ethereum/common/hexutil" @@ -38,14 +39,25 @@ func TestIterator(t *testing.T) { t.Fatal("expected two elems, got zero") } txs := it.Value() + if offset := it.Offset(); offset != 3 { + t.Fatal("wrong offset", offset, "want 3") + } + // Check that uncles exist if !it.Next() { t.Fatal("expected two elems, got one") } + if offset := it.Offset(); offset != 219 { + t.Fatal("wrong offset", offset, "want 219") + } + txit, err := NewListIterator(txs) if err != nil { t.Fatal(err) } + if c := txit.Count(); c != 2 { + t.Fatal("wrong Count:", c) + } var i = 0 for txit.Next() { if txit.err != nil { @@ -57,3 +69,65 @@ func TestIterator(t *testing.T) { t.Errorf("count wrong, expected %d got %d", i, exp) } } + +func TestIteratorErrors(t *testing.T) { + tests := []struct { + input []byte + wantCount int // expected Count before iterating + wantErr error + }{ + // Second item string header claims 3 bytes content, but only 2 remain. + {unhex("C4 01 83AABB"), 2, ErrValueTooLarge}, + // Second item truncated: B9 requires 2 size bytes, none available. + {unhex("C2 01 B9"), 2, io.ErrUnexpectedEOF}, + // 0x05 should be encoded directly, not as 81 05. + {unhex("C3 01 8105"), 2, ErrCanonSize}, + // Long-form string header B8 used for 1-byte content (< 56). + {unhex("C4 01 B801AA"), 2, ErrCanonSize}, + // Long-form list header F8 used for 1-byte content (< 56). + {unhex("C4 01 F80101"), 2, ErrCanonSize}, + } + for _, tt := range tests { + it, err := NewListIterator(tt.input) + if err != nil { + t.Fatal("NewListIterator error:", err) + } + if c := it.Count(); c != tt.wantCount { + t.Fatalf("%x: Count = %d, want %d", tt.input, c, tt.wantCount) + } + n := 0 + for it.Next() { + if it.Err() != nil { + break + } + n++ + } + if wantN := tt.wantCount - 1; n != wantN { + t.Fatalf("%x: got %d valid items, want %d", tt.input, n, wantN) + } + if it.Err() != tt.wantErr { + t.Fatalf("%x: got error %v, want %v", tt.input, it.Err(), tt.wantErr) + } + if it.Next() { + t.Fatalf("%x: Next returned true after error", tt.input) + } + } +} + +func FuzzIteratorCount(f *testing.F) { + examples := [][]byte{unhex("010203"), unhex("018142"), unhex("01830202")} + for _, e := range examples { + f.Add(e) + } + f.Fuzz(func(t *testing.T, in []byte) { + it := newIterator(in, 0) + count := it.Count() + i := 0 + for it.Next() { + i++ + } + if i != count { + t.Fatalf("%x: count %d not equal to %d iterations", in, count, i) + } + }) +} diff --git a/rlp/raw.go b/rlp/raw.go index 773aa7e61..2848b44d5 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -17,8 +17,10 @@ package rlp import ( + "fmt" "io" "reflect" + "slices" ) // RawValue represents an encoded RLP value and can be used to delay @@ -28,6 +30,144 @@ type RawValue []byte var rawValueType = reflect.TypeOf(RawValue{}) +// RawList represents an encoded RLP list. +type RawList[T any] struct { + // The list is stored in encoded form. + // Note this buffer has some special properties: + // + // - if the buffer is nil, it's the zero value, representing + // an empty list. + // - if the buffer is non-nil, it must have a length of at least + // 9 bytes, which is reserved padding for the encoded list header. + // The remaining bytes, enc[9:], store the content bytes of the list. + // + // The implementation code mostly works with the Content method because it + // returns something valid either way. + enc []byte + + // length holds the number of items in the list. + length int +} + +// Content returns the RLP-encoded data of the list. +// This does not include the list-header. +// The return value is a direct reference to the internal buffer, not a copy. +func (r *RawList[T]) Content() []byte { + if r.enc == nil { + return nil + } else { + return r.enc[9:] + } +} + +// EncodeRLP writes the encoded list to the writer. +func (r RawList[T]) EncodeRLP(w io.Writer) error { + _, err := w.Write(r.Bytes()) + return err +} + +// Bytes returns the RLP encoding of the list. +// Note the return value aliases the internal buffer. +func (r *RawList[T]) Bytes() []byte { + if r == nil || r.enc == nil { + return []byte{0xC0} // zero value encodes as empty list + } + n := puthead(r.enc, 0xC0, 0xF7, uint64(len(r.Content()))) + copy(r.enc[9-n:], r.enc[:n]) + return r.enc[9-n:] +} + +// DecodeRLP decodes the list. This does not perform validation of the items! +func (r *RawList[T]) DecodeRLP(s *Stream) error { + k, size, err := s.Kind() + if err != nil { + return err + } + if k != List { + return fmt.Errorf("%w for %T", ErrExpectedList, r) + } + enc := make([]byte, 9+size) + if err := s.readFull(enc[9:]); err != nil { + return err + } + n, err := CountValues(enc[9:]) + if err != nil { + if err == ErrValueTooLarge { + return ErrElemTooLarge + } + return err + } + *r = RawList[T]{enc: enc, length: n} + return nil +} + +// Items decodes and returns all items in the list. +func (r *RawList[T]) Items() ([]T, error) { + items := make([]T, r.Len()) + it := r.ContentIterator() + for i := 0; it.Next(); i++ { + if err := DecodeBytes(it.Value(), &items[i]); err != nil { + return items[:i], err + } + } + return items, nil +} + +// Len returns the number of items in the list. +func (r *RawList[T]) Len() int { + return r.length +} + +// Size returns the encoded size of the list. +func (r *RawList[T]) Size() uint64 { + return ListSize(uint64(len(r.Content()))) +} + +// ContentIterator returns an iterator over the content of the list. +// Note the offsets returned by iterator.Offset are relative to the +// Content bytes of the list. +func (r *RawList[T]) ContentIterator() Iterator { + return newIterator(r.Content(), 0) +} + +// Append adds an item to the end of the list. +func (r *RawList[T]) Append(item T) error { + if r.enc == nil { + r.enc = make([]byte, 9) + } + + eb := getEncBuffer() + defer encBufferPool.Put(eb) + + if err := eb.encode(item); err != nil { + return err + } + prevEnd := len(r.enc) + end := prevEnd + eb.size() + r.enc = slices.Grow(r.enc, eb.size())[:end] + eb.copyTo(r.enc[prevEnd:end]) + r.length++ + return nil +} + +// AppendRaw adds an encoded item to the list. +// The given byte slice must contain exactly one RLP value. +func (r *RawList[T]) AppendRaw(b []byte) error { + _, tagsize, contentsize, err := readKind(b) + if err != nil { + return err + } + if tagsize+contentsize != uint64(len(b)) { + return fmt.Errorf("rlp: input has trailing bytes in AppendRaw") + } + if r.enc == nil { + r.enc = make([]byte, 9) + } + r.enc = append(r.enc, b...) + r.length++ + return nil +} + // StringSize returns the encoded size of a string. func StringSize(s string) uint64 { switch { @@ -143,7 +283,7 @@ func CountValues(b []byte) (int, error) { for ; len(b) > 0; i++ { _, tagsize, size, err := readKind(b) if err != nil { - return 0, err + return i + 1, err } b = b[tagsize+size:] } diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 7b3255eca..e80b42718 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -19,11 +19,259 @@ package rlp import ( "bytes" "errors" + "fmt" "io" + "reflect" "testing" "testing/quick" ) +type rawListTest[T any] struct { + input string + content string + items []T + length int +} + +func (test rawListTest[T]) name() string { + return fmt.Sprintf("%T-%d", *new(T), test.length) +} + +func (test rawListTest[T]) run(t *testing.T) { + // check decoding and properties + input := unhex(test.input) + inputSize := len(input) + var rl RawList[T] + if err := DecodeBytes(input, &rl); err != nil { + t.Fatal("decode failed:", err) + } + if l := rl.Len(); l != test.length { + t.Fatalf("wrong Len %d, want %d", l, test.length) + } + if sz := rl.Size(); sz != uint64(inputSize) { + t.Fatalf("wrong Size %d, want %d", sz, inputSize) + } + items, err := rl.Items() + if err != nil { + t.Fatal("Items failed:", err) + } + if !reflect.DeepEqual(items, test.items) { + t.Fatal("wrong items:", items) + } + if !bytes.Equal(rl.Content(), unhex(test.content)) { + t.Fatalf("wrong Content %x, want %s", rl.Content(), test.content) + } + if !bytes.Equal(rl.Bytes(), unhex(test.input)) { + t.Fatalf("wrong Bytes %x, want %s", rl.Bytes(), test.input) + } + + // check iterator + it := rl.ContentIterator() + i := 0 + for it.Next() { + var item T + if err := DecodeBytes(it.Value(), &item); err != nil { + t.Fatalf("item %d decode error: %v", i, err) + } + if !reflect.DeepEqual(item, items[i]) { + t.Fatalf("iterator has wrong item %v at %d", item, i) + } + i++ + } + if i != test.length { + t.Fatalf("iterator produced %d values, want %d", i, test.length) + } + if it.Err() != nil { + t.Fatalf("iterator error: %v", it.Err()) + } + + // check encoding round trip + output, err := EncodeToBytes(&rl) + if err != nil { + t.Fatal("encode error:", err) + } + if !bytes.Equal(output, unhex(test.input)) { + t.Fatalf("encoding does not round trip: %x", output) + } + + // check EncodeToRawList on items produces same bytes + encRL, err := EncodeToRawList(test.items) + if err != nil { + t.Fatal("EncodeToRawList error:", err) + } + encRLOutput, err := EncodeToBytes(&encRL) + if err != nil { + t.Fatal("EncodeToBytes of encoded list failed:", err) + } + if !bytes.Equal(encRLOutput, output) { + t.Fatalf("wrong encoding of EncodeToRawList result: %x", encRLOutput) + } +} + +func TestRawList(t *testing.T) { + tests := []interface { + name() string + run(t *testing.T) + }{ + rawListTest[uint64]{ + input: "C0", + content: "", + items: []uint64{}, + length: 0, + }, + rawListTest[uint64]{ + input: "C3010203", + content: "010203", + items: []uint64{1, 2, 3}, + length: 3, + }, + rawListTest[simplestruct]{ + input: "C6C20102C20304", + content: "C20102C20304", + items: []simplestruct{{1, "\x02"}, {3, "\x04"}}, + length: 2, + }, + rawListTest[string]{ + input: "F83C836161618362626283636363836464648365656583666666836767678368686883696969836A6A6A836B6B6B836C6C6C836D6D6D836E6E6E836F6F6F", + content: "836161618362626283636363836464648365656583666666836767678368686883696969836A6A6A836B6B6B836C6C6C836D6D6D836E6E6E836F6F6F", + items: []string{"aaa", "bbb", "ccc", "ddd", "eee", "fff", "ggg", "hhh", "iii", "jjj", "kkk", "lll", "mmm", "nnn", "ooo"}, + length: 15, + }, + } + + for _, test := range tests { + t.Run(test.name(), test.run) + } +} + +func TestRawListEmpty(t *testing.T) { + // zero value list + var rl RawList[uint64] + b, _ := EncodeToBytes(&rl) + if !bytes.Equal(b, unhex("C0")) { + t.Fatalf("empty RawList has wrong encoding %x", b) + } + if rl.Len() != 0 { + t.Fatalf("empty list has Len %d", rl.Len()) + } + if rl.Size() != 1 { + t.Fatalf("empty list has Size %d", rl.Size()) + } + if len(rl.Content()) > 0 { + t.Fatalf("empty list has non-empty Content") + } + if !bytes.Equal(rl.Bytes(), []byte{0xC0}) { + t.Fatalf("empty list has wrong encoding") + } + + // nil pointer + var nilptr *RawList[uint64] + b, _ = EncodeToBytes(nilptr) + if !bytes.Equal(b, unhex("C0")) { + t.Fatalf("nil pointer to RawList has wrong encoding %x", b) + } +} + +// This checks that *RawList works in an 'optional' context. +func TestRawListOptional(t *testing.T) { + type foo struct { + L *RawList[uint64] `rlp:"optional"` + } + // nil pointer encoding + var empty foo + b, _ := EncodeToBytes(empty) + if !bytes.Equal(b, unhex("C0")) { + t.Fatalf("nil pointer to RawList has wrong encoding %x", b) + } + // decoding + var dec foo + if err := DecodeBytes(unhex("C0"), &dec); err != nil { + t.Fatal(err) + } + if dec.L != nil { + t.Fatal("rawlist was decoded as non-nil") + } +} + +func TestRawListAppend(t *testing.T) { + var rl RawList[simplestruct] + + v1 := simplestruct{1, "one"} + v2 := simplestruct{2, "two"} + if err := rl.Append(v1); err != nil { + t.Fatal("append 1 failed:", err) + } + if err := rl.Append(v2); err != nil { + t.Fatal("append 2 failed:", err) + } + + if rl.Len() != 2 { + t.Fatalf("wrong Len %d", rl.Len()) + } + if rl.Size() != 13 { + t.Fatalf("wrong Size %d", rl.Size()) + } + if !bytes.Equal(rl.Content(), unhex("C501836F6E65 C5028374776F")) { + t.Fatalf("wrong Content %x", rl.Content()) + } + encoded, _ := EncodeToBytes(&rl) + if !bytes.Equal(encoded, unhex("CC C501836F6E65 C5028374776F")) { + t.Fatalf("wrong encoding %x", encoded) + } +} + +func TestRawListAppendRaw(t *testing.T) { + var rl RawList[uint64] + + if err := rl.AppendRaw(unhex("01")); err != nil { + t.Fatal("AppendRaw(01) failed:", err) + } + if err := rl.AppendRaw(unhex("820102")); err != nil { + t.Fatal("AppendRaw(820102) failed:", err) + } + if rl.Len() != 2 { + t.Fatalf("wrong Len %d after valid appends", rl.Len()) + } + + if err := rl.AppendRaw(nil); err == nil { + t.Fatal("AppendRaw(nil) should fail") + } + if err := rl.AppendRaw(unhex("0102")); err == nil { + t.Fatal("AppendRaw(0102) should fail due to trailing bytes") + } + if err := rl.AppendRaw(unhex("8201")); err == nil { + t.Fatal("AppendRaw(8201) should fail due to truncated value") + } + if rl.Len() != 2 { + t.Fatalf("wrong Len %d after invalid appends, want 2", rl.Len()) + } +} + +func TestRawListDecodeInvalid(t *testing.T) { + tests := []struct { + input string + err error + }{ + // Single item with non-canonical size (0x81 wrapping byte <= 0x7F). + {input: "C28142", err: ErrCanonSize}, + // Single item claiming more bytes than available in the list. + {input: "C484020202", err: ErrElemTooLarge}, + // Two items, second has non-canonical size. + {input: "C3018142", err: ErrCanonSize}, + // Two items, second claims more bytes than remain in the list. + {input: "C401830202", err: ErrElemTooLarge}, + // Item is a sub-list whose declared size exceeds available bytes. + {input: "C3C40102", err: ErrElemTooLarge}, + } + for _, test := range tests { + var rl RawList[RawValue] + err := DecodeBytes(unhex(test.input), &rl) + if !errors.Is(err, test.err) { + t.Errorf("input %s: error mismatch: got %v, want %v", test.input, err, test.err) + } + } +} + func TestCountValues(t *testing.T) { tests := []struct { input string // note: spaces in input are stripped by unhex @@ -40,9 +288,9 @@ func TestCountValues(t *testing.T) { {"820101 820202 8403030303 04", 4, nil}, // size errors - {"8142", 0, ErrCanonSize}, - {"01 01 8142", 0, ErrCanonSize}, - {"02 84020202", 0, ErrValueTooLarge}, + {"8142", 1, ErrCanonSize}, + {"01 01 8142", 3, ErrCanonSize}, + {"02 84020202", 2, ErrValueTooLarge}, { input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470", diff --git a/trie/node.go b/trie/node.go index 15bbf62f1..1d0c5556c 100644 --- a/trie/node.go +++ b/trie/node.go @@ -151,11 +151,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { if err != nil { return nil, fmt.Errorf("decode error: %v", err) } - switch c, _ := rlp.CountValues(elems); c { - case 2: + c, err := rlp.CountValues(elems) + switch { + case err != nil: + return nil, fmt.Errorf("invalid node list: %v", err) + case c == 2: n, err := decodeShort(hash, elems) return n, wrapError(err, "short") - case 17: + case c == 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") default: