Skip to content

Commit f4b7316

Browse files
committed
perf: add marshalOutputPool for fast-path marshal buffer reuse
Add marshalOutputPool (sync.Pool) to recycle []byte slices returned by the 10 type-specialized marshal functions (vectors and lists/sets). The connection layer (executeQuery, executeBatch) returns these buffers to the pool after the framer copies them via writeBytes. Key changes: - getMarshalOutput/putMarshalOutput: pool management with cap guard - pooledMarshalType: identifies types using pooled marshal fast paths - executeQuery: scan columns for poolable types, install defer before marshal loop so buffers are returned even on mid-loop errors - executeBatch: unconditional defer with pooledBufs collection, also handles error-path cleanup correctly - All 10 fast-path marshal functions use getMarshalOutput instead of make([]byte, size) - 9 new unit tests covering pool mechanics, round-trip reuse, and pooledMarshalType with 25 type coverage subcases
1 parent a7c78e2 commit f4b7316

File tree

3 files changed

+272
-10
lines changed

3 files changed

+272
-10
lines changed

conn.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,34 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) (iter *Iter) {
16781678
}
16791679

16801680
params.values = make([]queryValues, len(values))
1681+
1682+
// Return pooled marshal output buffers after the framer copies
1683+
// them (which happens inside c.exec → buildFrame → writeBytes).
1684+
// Installed before the marshal loop so that buffers are returned
1685+
// even if a later marshalQueryValue call fails mid-loop.
1686+
// Only install the defer when at least one column uses a poolable
1687+
// type to avoid ~50ns defer overhead on non-pooled queries.
1688+
{
1689+
cols := info.request.columns
1690+
vals := params.values
1691+
hasPooled := false
1692+
for _, col := range cols {
1693+
if pooledMarshalType(col.TypeInfo) {
1694+
hasPooled = true
1695+
break
1696+
}
1697+
}
1698+
if hasPooled {
1699+
defer func() {
1700+
for i, col := range cols {
1701+
if pooledMarshalType(col.TypeInfo) {
1702+
putMarshalOutput(vals[i].value)
1703+
}
1704+
}
1705+
}()
1706+
}
1707+
}
1708+
16811709
for i := 0; i < len(values); i++ {
16821710
v := &params.values[i]
16831711
value := values[i]
@@ -1866,6 +1894,17 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) (iter *Iter) {
18661894

18671895
hasLwtEntries := false
18681896

1897+
// pooledBufs collects marshalled byte slices from fast-path marshal
1898+
// functions so they can be returned to marshalOutputPool after the
1899+
// framer copies them. The defer is installed before the loop so that
1900+
// buffers are returned even if a later marshalQueryValue call fails.
1901+
var pooledBufs [][]byte
1902+
defer func() {
1903+
for _, buf := range pooledBufs {
1904+
putMarshalOutput(buf)
1905+
}
1906+
}()
1907+
18691908
for i := 0; i < n; i++ {
18701909
entry := &batch.Entries[i]
18711910
b := &req.statements[i]
@@ -1907,6 +1946,9 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) (iter *Iter) {
19071946
if err := marshalQueryValue(typ, value, v); err != nil {
19081947
return &Iter{err: err}
19091948
}
1949+
if pooledMarshalType(typ) {
1950+
pooledBufs = append(pooledBufs, v.value)
1951+
}
19101952
}
19111953

19121954
if !hasLwtEntries && info.request.lwt {

marshal.go

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,59 @@ func finishMarshalBuf(buf *bytes.Buffer) []byte {
122122
return result
123123
}
124124

125+
// marshalOutputPool pools []byte slices returned by fast-path marshal functions
126+
// (vectors and lists/sets). These slices are the final marshal output that gets
127+
// copied into the framer buffer by writeBytes. After the framer copies them,
128+
// the connection layer returns them to this pool via putMarshalOutput.
129+
var marshalOutputPool sync.Pool
130+
131+
// getMarshalOutput returns a []byte of exactly the requested size, from the
132+
// pool if a suitable buffer is available, or freshly allocated otherwise.
133+
func getMarshalOutput(size int) []byte {
134+
if bp := marshalOutputPool.Get(); bp != nil {
135+
buf := bp.([]byte)
136+
if cap(buf) >= size {
137+
return buf[:size]
138+
}
139+
}
140+
return make([]byte, size)
141+
}
142+
143+
// putMarshalOutput returns a []byte to the output pool. Nil slices are ignored.
144+
// Buffers larger than marshalBufMaxCap are discarded to avoid holding excessive
145+
// memory.
146+
func putMarshalOutput(buf []byte) {
147+
if buf == nil {
148+
return
149+
}
150+
if cap(buf) > marshalBufMaxCap {
151+
return
152+
}
153+
marshalOutputPool.Put(buf) //nolint:staticcheck // SA6002: []byte is a value type; boxing cost is acceptable for pool reuse
154+
}
155+
156+
// pooledMarshalType returns true if the given TypeInfo uses a marshal fast path
157+
// that allocates from marshalOutputPool. This is used by the connection layer to
158+
// determine which queryValues.value slices can be returned to the pool after
159+
// the framer copies them.
160+
func pooledMarshalType(info TypeInfo) bool {
161+
switch ti := info.(type) {
162+
case VectorType:
163+
switch ti.SubType.Type() {
164+
case TypeFloat, TypeDouble, TypeInt, TypeBigInt, TypeTimestamp, TypeCounter, TypeUUID, TypeTimeUUID:
165+
return true
166+
}
167+
case CollectionType:
168+
if ti.typ == TypeList || ti.typ == TypeSet {
169+
switch ti.Elem.Type() {
170+
case TypeFloat, TypeDouble, TypeInt, TypeBigInt, TypeTimestamp, TypeCounter:
171+
return true
172+
}
173+
}
174+
}
175+
return false
176+
}
177+
125178
// Marshaler is an interface for custom unmarshaler.
126179
// Each value of the 'CQL binary protocol' consist of <value_len> and <value_data>.
127180
// <value_len> can be 'unset'(-2), 'nil'(-1), 'zero'(0) or any value up to 2147483647.
@@ -1329,7 +1382,7 @@ func marshalVectorFloat32(vec []float32, dim int) ([]byte, error) {
13291382
if err != nil {
13301383
return nil, err
13311384
}
1332-
buf := make([]byte, size)
1385+
buf := getMarshalOutput(size)
13331386
if dim > 0 {
13341387
_ = buf[dim*4-1] // BCE hint
13351388
}
@@ -1347,7 +1400,7 @@ func marshalVectorFloat64(vec []float64, dim int) ([]byte, error) {
13471400
if err != nil {
13481401
return nil, err
13491402
}
1350-
buf := make([]byte, size)
1403+
buf := getMarshalOutput(size)
13511404
if dim > 0 {
13521405
_ = buf[dim*8-1] // BCE hint
13531406
}
@@ -1365,7 +1418,7 @@ func marshalVectorInt32(vec []int32, dim int) ([]byte, error) {
13651418
if err != nil {
13661419
return nil, err
13671420
}
1368-
buf := make([]byte, size)
1421+
buf := getMarshalOutput(size)
13691422
if dim > 0 {
13701423
_ = buf[dim*4-1] // BCE hint
13711424
}
@@ -1383,7 +1436,7 @@ func marshalVectorInt64(vec []int64, dim int) ([]byte, error) {
13831436
if err != nil {
13841437
return nil, err
13851438
}
1386-
buf := make([]byte, size)
1439+
buf := getMarshalOutput(size)
13871440
if dim > 0 {
13881441
_ = buf[dim*8-1] // BCE hint
13891442
}
@@ -1403,7 +1456,7 @@ func marshalVectorCounter(vec []int64, dim int) ([]byte, error) {
14031456
if err != nil {
14041457
return nil, err
14051458
}
1406-
buf := make([]byte, size)
1459+
buf := getMarshalOutput(size)
14071460
off := 0
14081461
for _, v := range vec {
14091462
buf[off] = 8
@@ -1422,7 +1475,7 @@ func marshalVectorUUID(vec []UUID, dim int) ([]byte, error) {
14221475
if err != nil {
14231476
return nil, err
14241477
}
1425-
buf := make([]byte, size)
1478+
buf := getMarshalOutput(size)
14261479
if dim > 0 {
14271480
_ = buf[dim*16-1] // BCE hint
14281481
}
@@ -1615,7 +1668,7 @@ func marshalListFloat32(list []float32) ([]byte, error) {
16151668
if err != nil {
16161669
return nil, err
16171670
}
1618-
buf := make([]byte, size)
1671+
buf := getMarshalOutput(size)
16191672
binary.BigEndian.PutUint32(buf, uint32(n))
16201673
off := 4
16211674
for _, v := range list {
@@ -1635,7 +1688,7 @@ func marshalListFloat64(list []float64) ([]byte, error) {
16351688
if err != nil {
16361689
return nil, err
16371690
}
1638-
buf := make([]byte, size)
1691+
buf := getMarshalOutput(size)
16391692
binary.BigEndian.PutUint32(buf, uint32(n))
16401693
off := 4
16411694
for _, v := range list {
@@ -1655,7 +1708,7 @@ func marshalListInt32(list []int32) ([]byte, error) {
16551708
if err != nil {
16561709
return nil, err
16571710
}
1658-
buf := make([]byte, size)
1711+
buf := getMarshalOutput(size)
16591712
binary.BigEndian.PutUint32(buf, uint32(n))
16601713
off := 4
16611714
for _, v := range list {
@@ -1675,7 +1728,7 @@ func marshalListInt64(list []int64) ([]byte, error) {
16751728
if err != nil {
16761729
return nil, err
16771730
}
1678-
buf := make([]byte, size)
1731+
buf := getMarshalOutput(size)
16791732
binary.BigEndian.PutUint32(buf, uint32(n))
16801733
off := 4
16811734
for _, v := range list {

marshal_buf_pool_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,170 @@ func BenchmarkMarshalMapStringInt(b *testing.B) {
586586
})
587587
}
588588
}
589+
590+
// --- marshalOutputPool tests ---
591+
592+
func TestGetMarshalOutputFresh(t *testing.T) {
593+
buf := getMarshalOutput(64)
594+
if len(buf) != 64 {
595+
t.Fatalf("expected len 64, got %d", len(buf))
596+
}
597+
if cap(buf) < 64 {
598+
t.Fatalf("expected cap >= 64, got %d", cap(buf))
599+
}
600+
}
601+
602+
func TestGetMarshalOutputFromPool(t *testing.T) {
603+
// Put a buffer into the pool, then retrieve it.
604+
orig := make([]byte, 0, 128)
605+
marshalOutputPool.Put(orig)
606+
607+
buf := getMarshalOutput(64)
608+
if len(buf) != 64 {
609+
t.Fatalf("expected len 64, got %d", len(buf))
610+
}
611+
// The pool should have returned the 128-cap buffer.
612+
if cap(buf) < 128 {
613+
t.Logf("pool did not return expected buffer (cap %d); may have been GC'd", cap(buf))
614+
}
615+
}
616+
617+
func TestGetMarshalOutputPoolTooSmall(t *testing.T) {
618+
// Put a small buffer, request a larger one — should get a fresh allocation.
619+
small := make([]byte, 0, 8)
620+
marshalOutputPool.Put(small)
621+
622+
buf := getMarshalOutput(64)
623+
if len(buf) != 64 {
624+
t.Fatalf("expected len 64, got %d", len(buf))
625+
}
626+
}
627+
628+
func TestPutMarshalOutputNil(t *testing.T) {
629+
// Should not panic.
630+
putMarshalOutput(nil)
631+
}
632+
633+
func TestPutMarshalOutputOversized(t *testing.T) {
634+
// Buffers larger than marshalBufMaxCap should be discarded.
635+
huge := make([]byte, marshalBufMaxCap+1)
636+
putMarshalOutput(huge)
637+
// If we get it back, the pool ignored the cap limit (unlikely).
638+
// Can't reliably test pool internals, just verify no panic.
639+
}
640+
641+
func TestMarshalOutputPoolRoundTrip(t *testing.T) {
642+
// Verify that a buffer returned to the pool can be reused.
643+
buf := getMarshalOutput(32)
644+
for i := range buf {
645+
buf[i] = byte(i)
646+
}
647+
putMarshalOutput(buf)
648+
649+
buf2 := getMarshalOutput(16)
650+
if len(buf2) != 16 {
651+
t.Fatalf("expected len 16, got %d", len(buf2))
652+
}
653+
// buf2 may or may not be the same underlying array (GC can collect pool entries).
654+
// Just verify it's usable.
655+
for i := range buf2 {
656+
buf2[i] = 0xff
657+
}
658+
}
659+
660+
func TestPooledMarshalType(t *testing.T) {
661+
tests := []struct {
662+
name string
663+
info TypeInfo
664+
expect bool
665+
}{
666+
// Vectors with pooled subtypes.
667+
{"vector<float>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeFloat}}, true},
668+
{"vector<double>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeDouble}}, true},
669+
{"vector<int>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeInt}}, true},
670+
{"vector<bigint>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeBigInt}}, true},
671+
{"vector<timestamp>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeTimestamp}}, true},
672+
{"vector<counter>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeCounter}}, true},
673+
{"vector<uuid>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeUUID}}, true},
674+
{"vector<timeuuid>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeTimeUUID}}, true},
675+
676+
// Vectors with non-pooled subtypes.
677+
{"vector<text>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeVarchar}}, false},
678+
{"vector<blob>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeBlob}}, false},
679+
{"vector<boolean>", VectorType{Dimensions: 3, SubType: NativeType{proto: protoVersion4, typ: TypeBoolean}}, false},
680+
681+
// Lists/sets with pooled elem types.
682+
{"list<float>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeFloat}}, true},
683+
{"list<double>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeDouble}}, true},
684+
{"list<int>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeInt}}, true},
685+
{"list<bigint>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeBigInt}}, true},
686+
{"list<timestamp>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeTimestamp}}, true},
687+
{"list<counter>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeCounter}}, true},
688+
{"set<float>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeSet}, Elem: NativeType{proto: protoVersion4, typ: TypeFloat}}, true},
689+
{"set<int>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeSet}, Elem: NativeType{proto: protoVersion4, typ: TypeInt}}, true},
690+
691+
// Lists/sets with non-pooled elem types.
692+
{"list<text>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeVarchar}}, false},
693+
{"set<blob>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeSet}, Elem: NativeType{proto: protoVersion4, typ: TypeBlob}}, false},
694+
{"list<uuid>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeList}, Elem: NativeType{proto: protoVersion4, typ: TypeUUID}}, false},
695+
696+
// Maps are never pooled.
697+
{"map<int,int>", CollectionType{NativeType: NativeType{proto: protoVersion4, typ: TypeMap}, Key: NativeType{proto: protoVersion4, typ: TypeInt}, Elem: NativeType{proto: protoVersion4, typ: TypeInt}}, false},
698+
699+
// Native types are never pooled.
700+
{"int", NativeType{proto: protoVersion4, typ: TypeInt}, false},
701+
{"text", NativeType{proto: protoVersion4, typ: TypeVarchar}, false},
702+
}
703+
704+
for _, tt := range tests {
705+
t.Run(tt.name, func(t *testing.T) {
706+
got := pooledMarshalType(tt.info)
707+
if got != tt.expect {
708+
t.Errorf("pooledMarshalType(%s) = %v, want %v", tt.name, got, tt.expect)
709+
}
710+
})
711+
}
712+
}
713+
714+
func TestMarshalVectorFloat32UsesPool(t *testing.T) {
715+
// Marshal, put back, marshal again — second call should reuse the buffer.
716+
vec := []float32{1.0, 2.0, 3.0}
717+
buf1, err := marshalVectorFloat32(vec, 3)
718+
if err != nil {
719+
t.Fatal(err)
720+
}
721+
// Copy the data before returning to pool.
722+
data1 := make([]byte, len(buf1))
723+
copy(data1, buf1)
724+
putMarshalOutput(buf1)
725+
726+
buf2, err := marshalVectorFloat32(vec, 3)
727+
if err != nil {
728+
t.Fatal(err)
729+
}
730+
// Verify data is correct regardless of pool reuse.
731+
if !bytes.Equal(data1, buf2) {
732+
t.Fatalf("data mismatch after pool reuse")
733+
}
734+
putMarshalOutput(buf2)
735+
}
736+
737+
func TestMarshalListInt32UsesPool(t *testing.T) {
738+
list := []int32{10, 20, 30}
739+
buf1, err := marshalListInt32(list)
740+
if err != nil {
741+
t.Fatal(err)
742+
}
743+
data1 := make([]byte, len(buf1))
744+
copy(data1, buf1)
745+
putMarshalOutput(buf1)
746+
747+
buf2, err := marshalListInt32(list)
748+
if err != nil {
749+
t.Fatal(err)
750+
}
751+
if !bytes.Equal(data1, buf2) {
752+
t.Fatalf("data mismatch after pool reuse")
753+
}
754+
putMarshalOutput(buf2)
755+
}

0 commit comments

Comments
 (0)