Skip to content

Commit 0efebff

Browse files
rlpgen (#19513)
for #11116 this PR does introduce `rlpgen` package, but doesn't enable it for any our type ``` go test -bench=BenchmarkLogCustomVsGenerated -benchmem ./execution/types/ goos: linux goarch: amd64 pkg: github.com/erigontech/erigon/execution/types cpu: AMD EPYC 4344P 8-Core Processor BenchmarkLogCustomVsGenerated/Custom/Encode-16 8309613 139.7 ns/op 104 B/op 2 allocs/op BenchmarkLogCustomVsGenerated/Custom/Decode-16 5085690 228.0 ns/op 272 B/op 5 allocs/op BenchmarkLogCustomVsGenerated/Generated/Encode-16 18465400 60.34 ns/op 32 B/op 1 allocs/op BenchmarkLogCustomVsGenerated/Generated/Decode-16 5429574 215.8 ns/op 240 B/op 4 allocs/op PASS ok github.com/erigontech/erigon/execution/types 7.332s ``` Reason of bench numbers: - encode: we don't have custom hand-written encoder (reflection based now) - decode: we have custom hand-written decoder (I ported all existing optimizations to `rlpgen`) --------- Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
1 parent c72b6d4 commit 0efebff

File tree

5 files changed

+461
-22
lines changed

5 files changed

+461
-22
lines changed

cmd/rlpgen/handlers.go

Lines changed: 134 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,62 @@ func addToImports(named *types.Named) (typ string) {
8989
return
9090
}
9191

92+
func boolHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
93+
// size - bool encoded as 0 or 1 (1 byte)
94+
fmt.Fprintf(b1, " size += 1\n")
95+
96+
// encode - bool encoded as 0 or 1
97+
fmt.Fprintf(b2, " var bval uint64\n")
98+
fmt.Fprintf(b2, " if obj.%s {\n", fieldName)
99+
fmt.Fprintf(b2, " bval = 1\n")
100+
fmt.Fprintf(b2, " }\n")
101+
fmt.Fprintf(b2, " if err := rlp.EncodeInt(bval, w, b[:]); err != nil {\n")
102+
fmt.Fprintf(b2, " return err\n")
103+
fmt.Fprintf(b2, " }\n")
104+
105+
// decode
106+
fmt.Fprintf(b3, " if n, err := s.Uint(); err != nil {\n")
107+
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
108+
fmt.Fprintf(b3, " } else {\n")
109+
fmt.Fprintf(b3, " obj.%s = n != 0\n", fieldName)
110+
fmt.Fprintf(b3, " }\n")
111+
}
112+
113+
func boolPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
114+
// size - bool encoded as 0 or 1 (1 byte) or empty if nil
115+
fmt.Fprintf(b1, " if obj.%s != nil {\n", fieldName)
116+
fmt.Fprintf(b1, " size += 1\n")
117+
fmt.Fprintf(b1, " }\n")
118+
119+
// encode - bool encoded as 0 or 1, or empty if nil
120+
fmt.Fprintf(b2, " var bval uint64\n")
121+
fmt.Fprintf(b2, " if obj.%s != nil {\n", fieldName)
122+
fmt.Fprintf(b2, " if *obj.%s {\n", fieldName)
123+
fmt.Fprintf(b2, " bval = 1\n")
124+
fmt.Fprintf(b2, " }\n")
125+
fmt.Fprintf(b2, " if err := rlp.EncodeInt(bval, w, b[:]); err != nil {\n")
126+
fmt.Fprintf(b2, " return err\n")
127+
fmt.Fprintf(b2, " }\n")
128+
fmt.Fprintf(b2, " } else {\n")
129+
fmt.Fprintf(b2, " if err := rlp.EncodeInt(0, w, b[:]); err != nil {\n")
130+
fmt.Fprintf(b2, " return err\n")
131+
fmt.Fprintf(b2, " }\n")
132+
fmt.Fprintf(b2, " }\n")
133+
134+
// decode
135+
fmt.Fprintf(b3, " if n, err := s.Uint(); err != nil {\n")
136+
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
137+
fmt.Fprintf(b3, " } else {\n")
138+
fmt.Fprintf(b3, " bval := n != 0\n")
139+
fmt.Fprintf(b3, " obj.%s = &bval\n", fieldName)
140+
fmt.Fprintf(b3, " }\n")
141+
}
142+
92143
func uint64CastTo(kind types.BasicKind) string {
93144
var cast string
94145
switch kind {
146+
case types.Int8:
147+
cast = "int8"
95148
case types.Int16:
96149
cast = "int16"
97150
case types.Int32:
@@ -100,6 +153,8 @@ func uint64CastTo(kind types.BasicKind) string {
100153
cast = "int"
101154
case types.Int64:
102155
cast = "int64"
156+
case types.Uint8:
157+
cast = "uint8"
103158
case types.Uint16:
104159
cast = "uint16"
105160
case types.Uint32:
@@ -294,15 +349,10 @@ func _shortArrayHandle(b1, b2, b3 *bytes.Buffer, fieldName string, size int) { /
294349
fmt.Fprintf(b2, " return err\n")
295350
fmt.Fprintf(b2, " }\n")
296351

297-
// decode
298-
addDecodeBuf(b3)
299-
fmt.Fprintf(b3, " if b, err = s.Bytes(); err != nil {\n")
300-
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
352+
// decode - optimized: use s.ReadBytes() directly into fixed-size array
353+
fmt.Fprintf(b3, " if err = s.ReadBytes(obj.%s[:]); err != nil {\n", fieldName)
354+
fmt.Fprintf(b3, " return fmt.Errorf(\"error decoding field %s, err: %%w\", err)\n", fieldName)
301355
fmt.Fprintf(b3, " }\n")
302-
fmt.Fprintf(b3, " if len(b) > 0 && len(b) != %d {\n", size)
303-
fmt.Fprintf(b3, " %s\n", decodeLenMismatch(size))
304-
fmt.Fprintf(b3, " }\n")
305-
fmt.Fprintf(b3, " copy(obj.%s[:], b)\n", fieldName)
306356
}
307357

308358
func _shortArrayPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string, size int) error {
@@ -339,16 +389,11 @@ func _shortArrayPtrHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldN
339389
fmt.Fprintf(b2, " }\n")
340390
fmt.Fprintf(b2, " }\n")
341391

342-
// decode
343-
addDecodeBuf(b3)
344-
fmt.Fprintf(b3, " if b, err = s.Bytes(); err != nil {\n")
345-
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
346-
fmt.Fprintf(b3, " }\n")
347-
fmt.Fprintf(b3, " if len(b) > 0 && len(b) != %d {\n", size)
348-
fmt.Fprintf(b3, " %s\n", decodeLenMismatch(size))
349-
fmt.Fprintf(b3, " }\n")
392+
// decode - optimized: use s.ReadBytes() directly
350393
fmt.Fprintf(b3, " obj.%s = &%s{}\n", fieldName, typ)
351-
fmt.Fprintf(b3, " copy((*obj.%s)[:], b)\n", fieldName)
394+
fmt.Fprintf(b3, " if err = s.ReadBytes((*obj.%s)[:]); err != nil {\n", fieldName)
395+
fmt.Fprintf(b3, " return fmt.Errorf(\"error decoding field %s, err: %%w\", err)\n", fieldName)
396+
fmt.Fprintf(b3, " }\n")
352397

353398
return nil
354399
}
@@ -463,8 +508,7 @@ func byteSliceHandle(b1, b2, b3 *bytes.Buffer, _ types.Type, fieldName string) {
463508
fmt.Fprintf(b2, " return err\n")
464509
fmt.Fprintf(b2, " }\n")
465510

466-
// decode
467-
addDecodeBuf(b3)
511+
// decode - no buffer needed, s.Bytes() returns directly
468512
fmt.Fprintf(b3, " if obj.%s, err = s.Bytes(); err != nil {\n", fieldName)
469513
fmt.Fprintf(b3, " %s\n", decodeErrorMsg(fieldName))
470514
fmt.Fprintf(b3, " }\n")
@@ -661,6 +705,77 @@ func hashSliceHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName s
661705
_shortArraySliceHandle(b1, b2, b3, fieldType, fieldName, 32)
662706
}
663707

708+
func hashSliceHandleOptimized(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
709+
// Optimized handler for []common.Hash with pre-allocation limit to prevent DoS attacks
710+
// Similar to decodeTopics2 in log.go - limits pre-allocation to 128 elements
711+
712+
var typ string
713+
if slc, ok := fieldType.(*types.Slice); !ok {
714+
_exit("hashSliceHandleOptimized: expected fieldType to be Slice")
715+
} else {
716+
if named, ok := slc.Elem().(*types.Named); !ok {
717+
_exit("hashSliceHandleOptimized: expected fieldType to be Slice Named")
718+
} else {
719+
typ = addToImports(named)
720+
}
721+
}
722+
723+
// size
724+
addIntSize(b1)
725+
fmt.Fprintf(b1, " gidx = (32 + 1) * len(obj.%s)\n", fieldName)
726+
fmt.Fprintf(b1, " size += rlp.ListPrefixLen(gidx) + gidx\n")
727+
728+
// encode
729+
addIntEncode(b2)
730+
fmt.Fprintf(b2, " gidx = (32 + 1) * len(obj.%s)\n", fieldName)
731+
fmt.Fprintf(b2, " if err := rlp.EncodeStructSizePrefix(gidx, w, b[:]); err != nil {\n")
732+
fmt.Fprintf(b2, " return err\n")
733+
fmt.Fprintf(b2, " }\n")
734+
fmt.Fprintf(b2, " for i := 0; i < len(obj.%s); i++ {\n", fieldName)
735+
fmt.Fprintf(b2, " if err := rlp.EncodeString(obj.%s[i][:], w, b[:]); err != nil {\n", fieldName)
736+
fmt.Fprintf(b2, " return err\n")
737+
fmt.Fprintf(b2, " }\n")
738+
fmt.Fprintf(b2, " }\n")
739+
740+
// decode - with pre-allocation optimization and fast-path for common cases
741+
// Calculate expected list length and apply hard limit of 128 to prevent DoS
742+
// Only call s.List() ONCE to get size, calculate listLen once
743+
// No buffer needed since both paths use direct ReadBytes
744+
fmt.Fprintf(b3, " l, err := s.List()\n")
745+
fmt.Fprintf(b3, " if err != nil {\n")
746+
fmt.Fprintf(b3, " return fmt.Errorf(\"error decoding field %s - expected list start, err: %%w\", err)\n", fieldName)
747+
fmt.Fprintf(b3, " }\n")
748+
fmt.Fprintf(b3, " var listLen int\n")
749+
fmt.Fprintf(b3, " if l > 0 {\n")
750+
fmt.Fprintf(b3, " listLen = int(l / (1 + 32)) // Each hash: 1-byte RLP prefix + 32-byte hash\n")
751+
fmt.Fprintf(b3, " preAlloc := min(128, listLen) // Hard limit against DoS\n")
752+
fmt.Fprintf(b3, " obj.%s = make([]%s, 0, preAlloc)\n", fieldName, typ)
753+
fmt.Fprintf(b3, " } else {\n")
754+
fmt.Fprintf(b3, " obj.%s = []%s{}\n", fieldName, typ)
755+
fmt.Fprintf(b3, " }\n")
756+
// Fast-path: Read directly into pre-allocated slice (zero-alloc, zero-copy)
757+
// Slow-path: Still use direct ReadBytes but allocate full size needed
758+
fmt.Fprintf(b3, " if listLen <= 128 {\n")
759+
fmt.Fprintf(b3, " // Fast-path: within pre-alloc limit, use pre-allocated buffer\n")
760+
fmt.Fprintf(b3, " obj.%s = obj.%s[:listLen]\n", fieldName, fieldName)
761+
fmt.Fprintf(b3, " for i := 0; i < listLen; i++ {\n")
762+
fmt.Fprintf(b3, " if err = s.ReadBytes(obj.%s[i][:]); err != nil {\n", fieldName)
763+
fmt.Fprintf(b3, " return err\n")
764+
fmt.Fprintf(b3, " }\n")
765+
fmt.Fprintf(b3, " }\n")
766+
fmt.Fprintf(b3, " } else if listLen > 128 {\n")
767+
fmt.Fprintf(b3, " // Slow-path: exceeded pre-alloc limit, allocate exact size and use direct ReadBytes\n")
768+
fmt.Fprintf(b3, " obj.%s = make([]%s, listLen)\n", fieldName, typ)
769+
fmt.Fprintf(b3, " for i := 0; i < listLen; i++ {\n")
770+
fmt.Fprintf(b3, " if err = s.ReadBytes(obj.%s[i][:]); err != nil {\n", fieldName)
771+
fmt.Fprintf(b3, " return err\n")
772+
fmt.Fprintf(b3, " }\n")
773+
fmt.Fprintf(b3, " }\n")
774+
fmt.Fprintf(b3, " }\n")
775+
776+
endListDecode(b3, fieldName)
777+
}
778+
664779
func hashPtrSliceHandle(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName string) {
665780
_shortArrayPtrSliceHandle(b1, b2, b3, fieldType, fieldName, 32)
666781
}

cmd/rlpgen/matcher.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ type handle func(b1, b2, b3 *bytes.Buffer, fieldType types.Type, fieldName strin
2929
// all possible types that this generator can handle for the time being.
3030
// to add a new type add a string representation of type here and write the handle function for it in the `handlers.go`
3131
var handlers = map[string]handle{
32+
"bool": boolHandle,
33+
"*bool": boolPtrHandle,
3234
"uint64": uintHandle,
3335
"*uint64": uintPtrHandle,
3436
"big.Int": bigIntHandle,
@@ -50,7 +52,7 @@ var handlers = map[string]handle{
5052
"[]*types.BlockNonce": blockNoncePtrSliceHandle,
5153
"[]common.Address": addressSliceHandle,
5254
"[]*common.Address": addressPtrSliceHandle,
53-
"[]common.Hash": hashSliceHandle,
55+
"[]common.Hash": hashSliceHandleOptimized,
5456
"[]*common.Hash": hashPtrSliceHandle,
5557
"[n]byte": byteArrayHandle,
5658
"*[n]byte": byteArrayPtrHandle,
@@ -76,9 +78,13 @@ func matchTypeToString(fieldType types.Type, in string) string {
7678
// matches string representation of a type to a corresponding function
7779
func matchStrTypeToFunc(strType string) handle {
7880
switch strType {
79-
case "int16", "int32", "int", "int64", "uint16", "uint32", "uint", "uint64":
81+
case "bool":
82+
return handlers["bool"]
83+
case "*bool":
84+
return handlers["*bool"]
85+
case "int8", "int16", "int32", "int", "int64", "uint8", "uint16", "uint32", "uint", "uint64":
8086
return handlers["uint64"]
81-
case "*int16", "*int32", "*int", "*int64", "*uint16", "*uint32", "*uint", "*uint64":
87+
case "*int8", "*int16", "*int32", "*int", "*int64", "*uint8", "*uint16", "*uint32", "*uint", "*uint64":
8288
return handlers["*uint64"]
8389
default:
8490
if fn, ok := handlers[strType]; ok {

0 commit comments

Comments
 (0)