Skip to content

Commit 8301822

Browse files
authored
perf(data): arena decoder retains big.Int word backing across resets (#272)
Signed-off-by: Chris Gianelloni <wolf31o2@blinklabs.io>
1 parent ea3c757 commit 8301822

3 files changed

Lines changed: 158 additions & 1 deletion

File tree

data/arena_decoder.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ func (a *arenaChunks[S]) reset(retainCap int) {
8989
a.offset = 0
9090
}
9191

92+
func resetBigIntChunks(a *arenaChunks[big.Int], retainCap int) {
93+
// Keep big.Int word backing arrays across decoder reuse; integer decode
94+
// overwrites every allocated value before returning it.
95+
if len(a.chunks) > retainCap {
96+
for i := retainCap; i < len(a.chunks); i++ {
97+
a.chunks[i] = nil
98+
}
99+
a.chunks = a.chunks[:retainCap]
100+
}
101+
a.chunkIdx = 0
102+
a.offset = 0
103+
}
104+
92105
type arenaSlices[S any] struct {
93106
chunks [][]S
94107
pos int
@@ -182,7 +195,7 @@ func NewDecoder() *Decoder {
182195

183196
// Reset clears all internal arena pools so the Decoder can be reused; previously returned values become invalid.
184197
func (d *Decoder) Reset() {
185-
d.bigInts.reset(dataDecodeRetainCap)
198+
resetBigIntChunks(&d.bigInts, dataDecodeRetainCap)
186199
d.integers.reset(dataDecodeRetainCap)
187200
d.byteStrings.reset(dataDecodeRetainCap)
188201
d.constrs.reset(dataDecodeRetainCap)

data/data_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,40 @@ func TestDecoderReuseMatchesDecode(t *testing.T) {
487487
})
488488
}
489489

490+
func TestDecoderResetOverwritesRetainedBigInts(t *testing.T) {
491+
decoder := NewDecoder()
492+
493+
large := new(big.Int).Lsh(big.NewInt(1), 80)
494+
largeEncoded, err := Encode(NewInteger(large))
495+
if err != nil {
496+
t.Fatalf("Encode large failed: %v", err)
497+
}
498+
smallEncoded, err := Encode(NewInteger(big.NewInt(-3)))
499+
if err != nil {
500+
t.Fatalf("Encode small failed: %v", err)
501+
}
502+
503+
if _, err := decoder.Decode(largeEncoded); err != nil {
504+
t.Fatalf("Decode large failed: %v", err)
505+
}
506+
decoder.Reset()
507+
decoded, err := decoder.Decode(smallEncoded)
508+
if err != nil {
509+
t.Fatalf("Decode small failed: %v", err)
510+
}
511+
512+
integer, ok := decoded.(*Integer)
513+
if !ok {
514+
t.Fatalf("decoded value = %T, want *Integer", decoded)
515+
}
516+
if got, want := integer.Inner.Int64(), int64(-3); got != want {
517+
t.Fatalf("decoded integer = %d, want %d", got, want)
518+
}
519+
if !decoded.Equal(NewInteger(big.NewInt(-3))) {
520+
t.Fatalf("decoded value changed after overwrite: got %s", decoded)
521+
}
522+
}
523+
490524
func TestDecodeCBORTag(t *testing.T) {
491525
tests := []struct {
492526
name string

syn/encode_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,116 @@ func TestDeBruijnDecoderReuse(t *testing.T) {
426426
}
427427
}
428428

429+
func TestDeBruijnDecoderReuseOverwritesBigInts(t *testing.T) {
430+
tests := []struct {
431+
name string
432+
programSequence []string
433+
}{
434+
{
435+
name: "large_then_small",
436+
programSequence: []string{"large", "small"},
437+
},
438+
{
439+
name: "small_then_large",
440+
programSequence: []string{"small", "large"},
441+
},
442+
{
443+
name: "same_size_large_values",
444+
programSequence: []string{"same_size_a", "same_size_b"},
445+
},
446+
{
447+
name: "large_then_zero",
448+
programSequence: []string{"large", "zero"},
449+
},
450+
}
451+
452+
for _, tt := range tests {
453+
t.Run(tt.name, func(t *testing.T) {
454+
decoder := NewDeBruijnDecoder()
455+
456+
large := new(big.Int).Lsh(big.NewInt(1), 80)
457+
sameSizeA := new(big.Int).Add(new(big.Int).Lsh(big.NewInt(1), 80), big.NewInt(5))
458+
sameSizeB := new(big.Int).Add(new(big.Int).Lsh(big.NewInt(1), 80), big.NewInt(9))
459+
460+
largeProgram := &Program[DeBruijn]{
461+
Version: lang.LanguageVersionV3,
462+
Term: &Constant{Con: &Integer{Inner: large}},
463+
}
464+
smallProgram := &Program[DeBruijn]{
465+
Version: lang.LanguageVersionV3,
466+
Term: &Constant{Con: &Integer{Inner: big.NewInt(-3)}},
467+
}
468+
sameSizeAProgram := &Program[DeBruijn]{
469+
Version: lang.LanguageVersionV3,
470+
Term: &Constant{Con: &Integer{Inner: sameSizeA}},
471+
}
472+
sameSizeBProgram := &Program[DeBruijn]{
473+
Version: lang.LanguageVersionV3,
474+
Term: &Constant{Con: &Integer{Inner: sameSizeB}},
475+
}
476+
zeroProgram := &Program[DeBruijn]{
477+
Version: lang.LanguageVersionV3,
478+
Term: &Constant{Con: &Integer{Inner: big.NewInt(0)}},
479+
}
480+
481+
programs := map[string]*Program[DeBruijn]{
482+
"large": largeProgram,
483+
"small": smallProgram,
484+
"same_size_a": sameSizeAProgram,
485+
"same_size_b": sameSizeBProgram,
486+
"zero": zeroProgram,
487+
}
488+
expectedValues := map[string]*big.Int{
489+
"large": large,
490+
"small": big.NewInt(-3),
491+
"same_size_a": sameSizeA,
492+
"same_size_b": sameSizeB,
493+
"zero": big.NewInt(0),
494+
}
495+
496+
for _, programName := range tt.programSequence {
497+
program := programs[programName]
498+
expectedValue := expectedValues[programName]
499+
500+
expectedEncoded, err := Encode(program)
501+
if err != nil {
502+
t.Fatalf("Encode %s failed: %v", programName, err)
503+
}
504+
505+
decoded, err := decoder.Decode(expectedEncoded)
506+
if err != nil {
507+
t.Fatalf("Decode %s failed: %v", programName, err)
508+
}
509+
510+
constant, ok := decoded.Term.(*Constant)
511+
if !ok {
512+
t.Fatalf("decoded %s term = %T, want *Constant", programName, decoded.Term)
513+
}
514+
integer, ok := constant.Con.(*Integer)
515+
if !ok {
516+
t.Fatalf("decoded %s constant = %T, want *Integer", programName, constant.Con)
517+
}
518+
if expectedValue.IsInt64() {
519+
if got, want := integer.Inner.Int64(), expectedValue.Int64(); got != want {
520+
t.Fatalf("decoded %s integer = %d, want %d", programName, got, want)
521+
}
522+
}
523+
if integer.Inner.Cmp(expectedValue) != 0 {
524+
t.Fatalf("decoded %s integer = %s, want %s", programName, integer.Inner, expectedValue)
525+
}
526+
527+
reencoded, err := Encode(decoded)
528+
if err != nil {
529+
t.Fatalf("Re-encode %s failed: %v", programName, err)
530+
}
531+
if !bytes.Equal(expectedEncoded, reencoded) {
532+
t.Fatalf("%s integer roundtrip mismatch after decoder reuse", programName)
533+
}
534+
}
535+
})
536+
}
537+
}
538+
429539
func TestFlatRoundtripConstantTypes(t *testing.T) {
430540
tests := []struct {
431541
name string

0 commit comments

Comments
 (0)