Skip to content

Commit b5d3220

Browse files
authored
eth/protocols/snap: fix block accessList encoding rule (#34644)
This PR refactors the encoding rules for `AccessListsPacket` in the wire protocol. Specifically: - The response is now encoded as a list of `rlp.RawValue` - `rlp.EmptyString` is used as a placeholder for unavailable BAL objects
1 parent bd6530a commit b5d3220

File tree

3 files changed

+174
-55
lines changed

3 files changed

+174
-55
lines changed

eth/protocols/snap/handler_test.go

Lines changed: 160 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,48 @@ package snap
1818

1919
import (
2020
"bytes"
21+
"encoding/binary"
22+
"reflect"
2123
"testing"
2224
"time"
2325

2426
"github.com/ethereum/go-ethereum/common"
27+
"github.com/ethereum/go-ethereum/consensus/beacon"
2528
"github.com/ethereum/go-ethereum/consensus/ethash"
2629
"github.com/ethereum/go-ethereum/core"
2730
"github.com/ethereum/go-ethereum/core/rawdb"
31+
"github.com/ethereum/go-ethereum/core/types/bal"
2832
"github.com/ethereum/go-ethereum/params"
2933
"github.com/ethereum/go-ethereum/rlp"
3034
)
3135

36+
func makeTestBAL(minSize int) *bal.BlockAccessList {
37+
n := minSize/33 + 1 // 33 bytes per storage read slot in RLP
38+
access := bal.AccountAccess{
39+
Address: common.HexToAddress("0x01"),
40+
StorageReads: make([][32]byte, n),
41+
}
42+
for i := range access.StorageReads {
43+
binary.BigEndian.PutUint64(access.StorageReads[i][24:], uint64(i))
44+
}
45+
return &bal.BlockAccessList{Accesses: []bal.AccountAccess{access}}
46+
}
47+
3248
// getChainWithBALs creates a minimal test chain with BALs stored for each block.
3349
// It returns the chain, block hashes, and the stored BAL data.
3450
func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) {
3551
gspec := &core.Genesis{
36-
Config: params.TestChainConfig,
52+
Config: params.MergedTestChainConfig,
3753
}
3854
db := rawdb.NewMemoryDatabase()
39-
_, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {})
55+
engine := beacon.New(ethash.NewFaker())
56+
_, blocks, _ := core.GenerateChainWithGenesis(gspec, engine, nBlocks, func(i int, gen *core.BlockGen) {})
4057
options := &core.BlockChainConfig{
41-
TrieCleanLimit: 0,
42-
TrieDirtyLimit: 0,
43-
TrieTimeLimit: 5 * time.Minute,
44-
NoPrefetch: true,
45-
SnapshotLimit: 0,
58+
StateScheme: rawdb.PathScheme,
59+
TrieTimeLimit: 5 * time.Minute,
60+
NoPrefetch: true,
4661
}
47-
bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options)
62+
bc, err := core.NewBlockChain(db, gspec, engine, options)
4863
if err != nil {
4964
panic(err)
5065
}
@@ -53,20 +68,22 @@ func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash
5368
}
5469

5570
// Store BALs for each block
56-
var hashes []common.Hash
57-
var bals []rlp.RawValue
71+
var (
72+
hashes []common.Hash
73+
bals []rlp.RawValue
74+
)
5875
for _, block := range blocks {
5976
hash := block.Hash()
6077
number := block.NumberU64()
61-
bal := make(rlp.RawValue, balSize)
6278

6379
// Fill with data based on block number
64-
for j := range bal {
65-
bal[j] = byte(number + uint64(j))
80+
bytes, err := rlp.EncodeToBytes(makeTestBAL(balSize))
81+
if err != nil {
82+
panic(err)
6683
}
67-
rawdb.WriteAccessListRLP(db, hash, number, bal)
84+
rawdb.WriteAccessListRLP(db, hash, number, bytes)
6885
hashes = append(hashes, hash)
69-
bals = append(bals, bal)
86+
bals = append(bals, bytes)
7087
}
7188
return bc, hashes, bals
7289
}
@@ -85,13 +102,18 @@ func TestServiceGetAccessListsQuery(t *testing.T) {
85102
result := ServiceGetAccessListsQuery(bc, req)
86103

87104
// Verify the results
88-
if len(result) != len(hashes) {
89-
t.Fatalf("expected %d results, got %d", len(hashes), len(result))
105+
if result.Len() != len(hashes) {
106+
t.Fatalf("expected %d results, got %d", len(hashes), result.Len())
90107
}
91-
for i, bal := range result {
92-
if !bytes.Equal(bal, bals[i]) {
93-
t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i])
108+
var (
109+
index int
110+
it = result.ContentIterator()
111+
)
112+
for it.Next() {
113+
if !bytes.Equal(it.Value(), bals[index]) {
114+
t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), bals[index])
94115
}
116+
index++
95117
}
96118
}
97119

@@ -111,25 +133,23 @@ func TestServiceGetAccessListsQueryEmpty(t *testing.T) {
111133
result := ServiceGetAccessListsQuery(bc, req)
112134

113135
// Verify length
114-
if len(result) != len(mixed) {
115-
t.Fatalf("expected %d results, got %d", len(mixed), len(result))
136+
if result.Len() != len(mixed) {
137+
t.Fatalf("expected %d results, got %d", len(mixed), result.Len())
116138
}
117139

118140
// Check positional correspondence
119-
if !bytes.Equal(result[0], bals[0]) {
120-
t.Errorf("index 0: expected known BAL, got %x", result[0])
121-
}
122-
if result[1] != nil {
123-
t.Errorf("index 1: expected nil for unknown hash, got %x", result[1])
141+
var expectVal = []rlp.RawValue{
142+
bals[0], rlp.EmptyString, bals[1], rlp.EmptyString, bals[2],
124143
}
125-
if !bytes.Equal(result[2], bals[1]) {
126-
t.Errorf("index 2: expected known BAL, got %x", result[2])
127-
}
128-
if result[3] != nil {
129-
t.Errorf("index 3: expected nil for unknown hash, got %x", result[3])
130-
}
131-
if !bytes.Equal(result[4], bals[2]) {
132-
t.Errorf("index 4: expected known BAL, got %x", result[4])
144+
var (
145+
index int
146+
it = result.ContentIterator()
147+
)
148+
for it.Next() {
149+
if !bytes.Equal(it.Value(), expectVal[index]) {
150+
t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), expectVal[index])
151+
}
152+
index++
133153
}
134154
}
135155

@@ -154,8 +174,8 @@ func TestServiceGetAccessListsQueryCap(t *testing.T) {
154174
result := ServiceGetAccessListsQuery(bc, req)
155175

156176
// Can't get more than maxAccessListLookups results
157-
if len(result) > maxAccessListLookups {
158-
t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result))
177+
if result.Len() > maxAccessListLookups {
178+
t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, result.Len())
159179
}
160180
}
161181

@@ -179,21 +199,116 @@ func TestServiceGetAccessListsQueryByteLimit(t *testing.T) {
179199
result := ServiceGetAccessListsQuery(bc, req)
180200

181201
// Should have stopped before returning all blocks
182-
if len(result) >= nBlocks {
183-
t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result))
202+
if result.Len() >= nBlocks {
203+
t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, result.Len())
184204
}
185205

186206
// Should have returned at least one
187-
if len(result) == 0 {
207+
if result.Len() == 0 {
188208
t.Fatal("expected at least one result")
189209
}
190210

191211
// The total size should exceed the limit (the entry that crosses it is included)
192-
var total uint64
193-
for _, bal := range result {
194-
total += uint64(len(bal))
212+
if result.Size() <= softResponseLimit {
213+
t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", result.Size(), softResponseLimit)
214+
}
215+
}
216+
217+
// TestGetAccessListResponseDecoding verifies that an AccessListsPacket
218+
// round-trips through RLP encode/decode, preserving positional
219+
// correspondence and correctly representing absent BALs as empty strings.
220+
func TestGetAccessListResponseDecoding(t *testing.T) {
221+
t.Parallel()
222+
223+
// Build two real BALs of different sizes.
224+
bal1 := makeTestBAL(100)
225+
bal2 := makeTestBAL(200)
226+
bytes1, _ := rlp.EncodeToBytes(bal1)
227+
bytes2, _ := rlp.EncodeToBytes(bal2)
228+
229+
tests := []struct {
230+
name string
231+
items []rlp.RawValue // nil entry = unavailable BAL
232+
counts int // expected decoded length
233+
}{
234+
{
235+
name: "all present",
236+
items: []rlp.RawValue{bytes1, bytes2},
237+
counts: 2,
238+
},
239+
{
240+
name: "all absent",
241+
items: []rlp.RawValue{rlp.EmptyString, rlp.EmptyString, rlp.EmptyString},
242+
counts: 3,
243+
},
244+
{
245+
name: "mixed present and absent",
246+
items: []rlp.RawValue{bytes1, rlp.EmptyString, bytes2, rlp.EmptyString},
247+
counts: 4,
248+
},
249+
{
250+
name: "empty response",
251+
items: []rlp.RawValue{},
252+
counts: 0,
253+
},
195254
}
196-
if total <= softResponseLimit {
197-
t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit)
255+
for _, tt := range tests {
256+
t.Run(tt.name, func(t *testing.T) {
257+
// Build the packet using Append.
258+
var orig AccessListsPacket
259+
orig.ID = 42
260+
for _, item := range tt.items {
261+
if err := orig.AccessLists.AppendRaw(item); err != nil {
262+
t.Fatalf("AppendRaw failed: %v", err)
263+
}
264+
}
265+
266+
// Encode -> Decode round-trip.
267+
enc, err := rlp.EncodeToBytes(&orig)
268+
if err != nil {
269+
t.Fatalf("encode failed: %v", err)
270+
}
271+
var dec AccessListsPacket
272+
if err := rlp.DecodeBytes(enc, &dec); err != nil {
273+
t.Fatalf("decode failed: %v", err)
274+
}
275+
276+
// Verify ID preserved.
277+
if dec.ID != orig.ID {
278+
t.Fatalf("ID mismatch: got %d, want %d", dec.ID, orig.ID)
279+
}
280+
281+
// Verify element count.
282+
if dec.AccessLists.Len() != tt.counts {
283+
t.Fatalf("length mismatch: got %d, want %d", dec.AccessLists.Len(), tt.counts)
284+
}
285+
286+
// Verify each element positionally.
287+
it := dec.AccessLists.ContentIterator()
288+
for i, want := range tt.items {
289+
if !it.Next() {
290+
t.Fatalf("iterator exhausted at index %d", i)
291+
}
292+
got := it.Value()
293+
if !bytes.Equal(got, want) {
294+
t.Errorf("element %d: got %x, want %x", i, got, want)
295+
}
296+
if !bytes.Equal(got, rlp.EmptyString) {
297+
obj := new(bal.BlockAccessList)
298+
if err := rlp.DecodeBytes(got, obj); err != nil {
299+
t.Fatalf("decode failed: %v", err)
300+
}
301+
if bytes.Equal(got, bytes1) && !reflect.DeepEqual(obj, bal1) {
302+
t.Fatalf("decode failed: got %x, want %x", obj, bal1)
303+
}
304+
if bytes.Equal(got, bytes2) && !reflect.DeepEqual(obj, bal2) {
305+
t.Fatalf("decode failed: got %x, want %x", obj, bal2)
306+
}
307+
}
308+
}
309+
if it.Next() {
310+
t.Error("iterator has extra elements after expected end")
311+
}
312+
})
198313
}
199314
}

eth/protocols/snap/handlers.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -559,16 +559,15 @@ func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error {
559559
if err := msg.Decode(&req); err != nil {
560560
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
561561
}
562-
bals := ServiceGetAccessListsQuery(backend.Chain(), &req)
563562
return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{
564563
ID: req.ID,
565-
AccessLists: bals,
564+
AccessLists: ServiceGetAccessListsQuery(backend.Chain(), &req),
566565
})
567566
}
568567

569568
// ServiceGetAccessListsQuery assembles the response to an access list query.
570569
// It is exposed to allow external packages to test protocol behavior.
571-
func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue {
570+
func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) rlp.RawList[rlp.RawValue] {
572571
if req.Bytes > softResponseLimit {
573572
req.Bytes = softResponseLimit
574573
}
@@ -577,20 +576,25 @@ func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacke
577576
req.Hashes = req.Hashes[:maxAccessListLookups]
578577
}
579578
var (
580-
bals []rlp.RawValue
581-
bytes uint64
579+
err error
580+
bytes uint64
581+
response = rlp.RawList[rlp.RawValue]{}
582582
)
583583
for _, hash := range req.Hashes {
584584
if bal := chain.GetAccessListRLP(hash); len(bal) > 0 {
585-
bals = append(bals, bal)
585+
err = response.AppendRaw(bal)
586586
bytes += uint64(len(bal))
587587
} else {
588588
// Either the block is unknown or the BAL doesn't exist
589-
bals = append(bals, nil)
589+
err = response.AppendRaw(rlp.EmptyString)
590+
bytes += 1
591+
}
592+
if err != nil {
593+
break
590594
}
591595
if bytes > req.Bytes {
592596
break
593597
}
594598
}
595-
return bals
599+
return response
596600
}

eth/protocols/snap/protocol.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ type GetAccessListsPacket struct {
229229
// Each entry corresponds to the requested hash at the same index.
230230
// Empty entries indicate the BAL is unavailable.
231231
type AccessListsPacket struct {
232-
ID uint64 // ID of the request this is a response for
233-
AccessLists []rlp.RawValue // Requested BALs
232+
ID uint64 // ID of the request this is a response for
233+
AccessLists rlp.RawList[rlp.RawValue] // Requested BALs
234234
}
235235

236236
func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }

0 commit comments

Comments
 (0)