
Commit ffd3eb7

nikagra and dkropachev committed
feat(scylla): handle dynamic shard topology changes
- Implement connection migration during shard count changes
- Preserve existing connections during shard count increases
- Close excess connections when shard count decreases
- Add comprehensive unit tests for topology change scenarios

Fixes: "invalid number of shards" panic during connection pooling

Co-authored-by: Dmitry Kropachev <[email protected]>
1 parent 426c16f commit ffd3eb7
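
At its core the fix is a simple migration rule: when a node reports a new shard count, every pooled connection whose shard ID still exists keeps its slot, and the rest are closed. A standalone sketch of that rule (migrateConns and the []string stand-in for connections are illustrative, not the driver's actual types):

package main

import "fmt"

// migrateConns keeps connections whose shard index is still valid under
// the new shard count and returns the rest for closing. Sketch only;
// the real pool stores *Conn values and closes surplus ones asynchronously.
func migrateConns(old []string, newShardCount int) (kept, toClose []string) {
    kept = make([]string, newShardCount)
    for i, c := range old {
        switch {
        case c == "":
            // no connection was pooled for this shard
        case i < newShardCount:
            kept[i] = c // shard still exists: keep the connection in place
        default:
            toClose = append(toClose, c) // shard is gone: close it
        }
    }
    return kept, toClose
}

func main() {
    // 4 shards shrink to 2: connections on shards 0 and 1 survive,
    // connections on shards 2 and 3 are handed back for closing.
    old := []string{"conn0", "conn1", "conn2", "conn3"}
    kept, toClose := migrateConns(old, 2)
    fmt.Println(kept, toClose) // [conn0 conn1] [conn2 conn3]
}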

5 files changed (+228 −26): connectionpool.go, connpicker.go, integration_only.go, scylla.go, scylla_test.go


connectionpool.go

Lines changed: 4 additions & 2 deletions
@@ -272,7 +272,7 @@ func (h *hostConnPool) String() string {
 		h.filling, h.closed, size, h.size, h.host)
 }

-func newHostConnPool(session *Session, host *HostInfo, port, size int,
+func newHostConnPool(session *Session, host *HostInfo, port, size int, // FIXME: Remove unused port parameter
 	keyspace string) *hostConnPool {

 	pool := &hostConnPool{
@@ -544,7 +544,9 @@ func (pool *hostConnPool) connect() (err error) {

 	// lazily initialize the connPicker when we know the required type
 	pool.initConnPicker(conn)
-	pool.connPicker.Put(conn)
+	if err := pool.connPicker.Put(conn); err != nil {
+		return err
+	}
 	conn.finalizeConnection()

 	return nil
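
For applications, the visible effect of this hunk is that a sharding problem during pool fill now surfaces as an ordinary error instead of a panic inside the driver. A sketch of the caller's view, using the standard gocql session setup (the host address is a placeholder):

package main

import (
    "log"

    "github.com/gocql/gocql"
)

func main() {
    cluster := gocql.NewCluster("192.168.1.1") // placeholder host
    session, err := cluster.CreateSession()
    if err != nil {
        // Errors returned by connPicker.Put, such as
        // "server reported that it has no shards", now propagate
        // here instead of crashing the process during pooling.
        log.Fatal(err)
    }
    defer session.Close()
}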

connpicker.go

Lines changed: 5 additions & 3 deletions
@@ -8,7 +8,7 @@ import (

 type ConnPicker interface {
 	Pick(Token, ExecutableQuery) *Conn
-	Put(*Conn)
+	Put(*Conn) error
 	Remove(conn *Conn)
 	InFlight() int
 	Size() (int, int)
@@ -96,10 +96,11 @@ func (p *defaultConnPicker) Pick(Token, ExecutableQuery) *Conn {
 	return leastBusyConn
 }

-func (p *defaultConnPicker) Put(conn *Conn) {
+func (p *defaultConnPicker) Put(conn *Conn) error {
 	p.mu.Lock()
 	p.conns = append(p.conns, conn)
 	p.mu.Unlock()
+	return nil
 }

 func (*defaultConnPicker) NextShard() (shardID, nrShards int) {
@@ -115,7 +116,8 @@ func (nopConnPicker) Pick(Token, ExecutableQuery) *Conn {
 	return nil
 }

-func (nopConnPicker) Put(*Conn) {
+func (nopConnPicker) Put(*Conn) error {
+	return nil
 }

 func (nopConnPicker) Remove(conn *Conn) {
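
Because Put is part of the exported ConnPicker interface, the new signature is a breaking change for any out-of-tree picker. A minimal sketch of a conforming implementation (singleConnPicker is hypothetical, and only the methods visible in this diff are shown; any remaining interface methods such as Close or NextShard would need the same treatment):

// singleConnPicker is a hypothetical out-of-tree ConnPicker, shown only
// to illustrate the updated method set; it always hands out one connection.
type singleConnPicker struct {
    conn *Conn
}

func (p *singleConnPicker) Pick(Token, ExecutableQuery) *Conn { return p.conn }

// Put now returns an error, so a picker can reject a connection
// (for example on a shard mismatch) instead of panicking.
func (p *singleConnPicker) Put(conn *Conn) error {
    p.conn = conn
    return nil
}

func (p *singleConnPicker) Remove(conn *Conn) {
    if p.conn == conn {
        p.conn = nil
    }
}

func (p *singleConnPicker) InFlight() int    { return 0 }
func (p *singleConnPicker) Size() (int, int) { return 1, 0 }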

integration_only.go

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ func (s *Session) MissingConnections() (int, error) {

 type ConnPickerIntegration interface {
 	Pick(Token, ExecutableQuery) *Conn
-	Put(*Conn)
+	Put(*Conn) error
 	Remove(conn *Conn)
 	InFlight() int
 	Size() (int, int)

scylla.go

Lines changed: 55 additions & 6 deletions
@@ -395,6 +395,9 @@ outer:
 		idx = p.shardOf(mmt)
 	}

+	// FIXME: Add bounds check: shardOf() result may exceed len(p.conns) during topology transitions
+	// FIXME: Prevent connection leaks when shard count decreases
+
 	if c := p.conns[idx]; c != nil {
 		// We have this shard's connection
 		// so let's give it to the caller.
@@ -452,22 +455,25 @@ func (p *scyllaConnPicker) shardOf(token int64Token) int {
 	return int(sum >> 32)
 }

-func (p *scyllaConnPicker) Put(conn *Conn) {
+func (p *scyllaConnPicker) Put(conn *Conn) error {
 	var (
 		nrShards = conn.scyllaSupported.nrShards
 		shard    = conn.scyllaSupported.shard
 	)

 	if nrShards == 0 {
-		panic(fmt.Sprintf("scylla: %s not a sharded connection", p.address))
+		return errors.New("server reported that it has no shards")
 	}

-	if nrShards != len(p.conns) {
-		if nrShards != p.nrShards {
-			panic(fmt.Sprintf("scylla: %s invalid number of shards", p.address))
+	if nrShards != p.nrShards {
+		if gocqlDebug {
+			p.logger.Printf("scylla: %s shard count changed from %d to %d, rebuilding connection pool",
+				p.address, p.nrShards, nrShards)
 		}
+		p.handleShardTopologyChange(conn, nrShards)
+	} else if nrShards != len(p.conns) {
 		conns := p.conns
-		p.conns = make([]*Conn, nrShards, nrShards)
+		p.conns = make([]*Conn, nrShards)
 		copy(p.conns, conns)
 	}
@@ -509,6 +515,45 @@ func (p *scyllaConnPicker) Put(conn *Conn) {
 	if p.shouldCloseExcessConns() {
 		p.closeExcessConns()
 	}
+
+	return nil
+}
+
+func (p *scyllaConnPicker) handleShardTopologyChange(newConn *Conn, newShardCount int) {
+	oldShardCount := p.nrShards
+	oldConns := make([]*Conn, len(p.conns))
+	copy(oldConns, p.conns)
+
+	if gocqlDebug {
+		p.logger.Printf("scylla: %s handling shard topology change from %d to %d", p.address, oldShardCount, newShardCount)
+	}
+
+	newConns := make([]*Conn, newShardCount)
+	var toClose []*Conn
+	migratedCount := 0
+
+	for i, conn := range oldConns {
+		if conn != nil && i < newShardCount {
+			newConns[i] = conn
+			migratedCount++
+		} else if conn != nil {
+			toClose = append(toClose, conn)
+		}
+	}
+
+	p.nrShards = newShardCount
+	p.msbIgnore = newConn.scyllaSupported.msbIgnore
+	p.conns = newConns
+	p.nrConns = migratedCount
+	p.lastAttemptedShard = 0
+
+	if len(toClose) > 0 {
+		go closeConns(toClose...)
+	}
+
+	if gocqlDebug {
+		p.logger.Printf("scylla: %s migrated %d/%d connections to new shard topology, closing %d excess connections", p.address, migratedCount, len(oldConns), len(toClose))
+	}
 }

 func (p *scyllaConnPicker) shouldCloseExcessConns() bool {
@@ -624,6 +669,10 @@ func (p *scyllaConnPicker) NextShard() (shardID, nrShards int) {
 	// to consider the next shard after the previously attempted one
 	for i := 1; i <= p.nrShards; i++ {
 		shardID := (p.lastAttemptedShard + i) % p.nrShards
+
+		// FIXME: Replace p.nrShards with len(p.conns) in loop bounds to prevent index out of bounds
+		// FIXME: Handle topology changes where cached p.nrShards != actual len(p.conns)
+
 		if p.conns == nil || p.conns[shardID] == nil {
 			p.lastAttemptedShard = shardID
 			return shardID, p.nrShards
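
The FIXME comments left in Pick and NextShard flag the remaining gap: shard indices are computed from the cached p.nrShards, which can disagree with len(p.conns) in the middle of a topology change. A hedged sketch of the kind of bounds check they suggest (connForShard is a hypothetical helper, not part of this commit):

// connForShard is a hypothetical bounds-checked accessor for p.conns:
// if a shard index computed from a stale p.nrShards falls outside the
// current slice, it reports "no dedicated connection" so the caller can
// fall back to its existing any-connection path instead of panicking
// with an index out of range.
func (p *scyllaConnPicker) connForShard(idx int) *Conn {
    if idx < 0 || idx >= len(p.conns) {
        return nil
    }
    return p.conns[idx]
}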

scylla_test.go

Lines changed: 163 additions & 14 deletions
@@ -7,11 +7,16 @@ import (
 	"context"
 	"fmt"
 	"math"
+	"net"
 	"runtime"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"

+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/gocql/gocql/internal/streams"
 )

@@ -128,19 +133,6 @@ func TestScyllaConnPickerRemove(t *testing.T) {
 	}
 }

-func mockConn(shard int) *Conn {
-	return &Conn{
-		streams: streams.New(),
-		scyllaSupported: scyllaSupported{
-			shard:             shard,
-			nrShards:          4,
-			msbIgnore:         12,
-			partitioner:       "org.apache.cassandra.dht.Murmur3Partitioner",
-			shardingAlgorithm: "biased-token-round-robin",
-		},
-	}
-}
-
 func TestScyllaConnPickerShardOf(t *testing.T) {
 	t.Parallel()

@@ -155,7 +147,7 @@ func TestScyllaConnPickerShardOf(t *testing.T) {
 	}
 }

-func TestScyllaRandomConnPIcker(t *testing.T) {
+func TestScyllaRandomConnPicker(t *testing.T) {
 	t.Parallel()

 	t.Run("max iterations", func(t *testing.T) {
@@ -328,3 +320,160 @@ func TestScyllaPortIterator(t *testing.T) {
 		})
 	}
 }
+
+func TestScyllaConnPickerHandleShardTopologyChange(t *testing.T) {
+	tests := []struct {
+		name             string
+		initialShards    int
+		newShards        int
+		initialConns     []int // shard IDs of initial connections
+		expectedMigrated int
+		expectedClosed   int
+	}{
+		{
+			name:             "shard increase from 4 to 8",
+			initialShards:    4,
+			newShards:        8,
+			initialConns:     []int{0, 2, 3},
+			expectedMigrated: 3, // All initial connections survive
+			expectedClosed:   0,
+		},
+		{
+			name:             "shard decrease from 8 to 4",
+			initialShards:    8,
+			newShards:        4,
+			initialConns:     []int{0, 2, 5, 7},
+			expectedMigrated: 2, // Only shards 0, 2 survive
+			expectedClosed:   2, // Shards 5, 7 get closed
+		},
+		{
+			name:             "no change same count",
+			initialShards:    8,
+			newShards:        8,
+			initialConns:     []int{1, 3, 5},
+			expectedMigrated: 4, // All initial connections survive + new one
+			expectedClosed:   0,
+		},
+		{
+			name:             "massive decrease from 16 to 2",
+			initialShards:    16,
+			newShards:        2,
+			initialConns:     []int{0, 1, 5, 8, 12, 15},
+			expectedMigrated: 2, // Only shards 0, 1 survive
+			expectedClosed:   4, // Shards 5, 8, 12, 15 get closed
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			logger := &testLogger{}
+			picker := &scyllaConnPicker{
+				logger:                     logger,
+				disableShardAwarePortUntil: new(atomic.Value),
+				hostId:                     "test-host-id",
+				shardAwareAddress:          "192.168.1.1:19042", // Shard-aware port
+				address:                    "192.168.1.1:9042",  // Regular port
+				conns:                      make([]*Conn, tt.initialShards),
+				excessConns:                make([]*Conn, 0),
+				nrShards:                   tt.initialShards,
+				msbIgnore:                  12,
+				nrConns:                    0,
+				pos:                        0,
+				lastAttemptedShard:         0,
+				shardAwarePortDisabled:     false,
+				excessConnsLimitRate:       0.1,
+			}
+			picker.disableShardAwarePortUntil.Store(time.Time{})
+
+			var connectionsToCheck []*Conn
+
+			// Add initial connections
+			for _, shardID := range tt.initialConns {
+				conn := mockConnForPicker(shardID, tt.initialShards)
+				err := picker.Put(conn)
+				require.NoError(t, err)
+
+				if shardID >= tt.newShards {
+					connectionsToCheck = append(connectionsToCheck, conn)
+				}
+			}
+
+			// Verify initial state
+			assert.Equal(t, tt.initialShards, picker.nrShards)
+			assert.Equal(t, len(tt.initialConns), picker.nrConns)
+
+			// Execute topology change
+			newConn := mockConnForPicker(0, tt.newShards)
+			err := picker.Put(newConn)
+			require.NoError(t, err)
+
+			// Allow background goroutine to complete
+			time.Sleep(50 * time.Millisecond)
+
+			// Verify new topology
+			assert.Equal(t, tt.newShards, picker.nrShards)
+			assert.Equal(t, len(picker.conns), tt.newShards)
+
+			// Count migrated connections
+			migratedCount := 0
+			for _, conn := range picker.conns {
+				if conn != nil {
+					migratedCount++
+				}
+			}
+
+			assert.Equal(t, tt.expectedMigrated, migratedCount)
+
+			// Verify connections that should be closed are actually closed
+			closedCount := 0
+			for _, conn := range connectionsToCheck {
+				if conn.Closed() {
+					closedCount++
+				}
+			}
+
+			assert.Equal(t, tt.expectedClosed, closedCount,
+				"Expected %d connections to be closed, but %d were closed",
+				tt.expectedClosed, closedCount)
+		})
+	}
+}
+
+func mockConn(shard int) *Conn {
+	return &Conn{
+		streams: streams.New(),
+		scyllaSupported: scyllaSupported{
+			shard:             shard,
+			nrShards:          4,
+			msbIgnore:         12,
+			partitioner:       "org.apache.cassandra.dht.Murmur3Partitioner",
+			shardingAlgorithm: "biased-token-round-robin",
+		},
+	}
+}
+
+func mockConnForPicker(shard, nrShards int) *Conn {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	conn1, conn2 := net.Pipe()
+	_ = conn2.Close()
+
+	return &Conn{
+		scyllaSupported: scyllaSupported{
+			shard:     shard,
+			nrShards:  nrShards,
+			msbIgnore: 12,
+		},
+		conn:    conn1,
+		addr:    fmt.Sprintf("192.168.1.%d:9042", shard+1),
+		closed:  false,
+		mu:      sync.Mutex{},
+		logger:  &testLogger{},
+		ctx:     ctx,
+		cancel:  cancel,
+		calls:   make(map[int]*callReq),
+		streams: streams.New(),
+	}
+}
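
One detail worth noting in the new mockConnForPicker helper: unlike mockConn, it backs the mock with a real net.Conn, because handleShardTopologyChange genuinely closes surplus connections and the test asserts Closed() afterwards. net.Pipe provides an in-memory connection that can be closed without any network setup; a self-contained illustration:

package main

import (
    "fmt"
    "net"
)

func main() {
    // net.Pipe returns both ends of an in-memory, synchronous duplex
    // connection. The test helper keeps one end inside the mock Conn
    // and immediately closes the other, since the topology tests only
    // need Close() to work, never actual reads or writes.
    local, remote := net.Pipe()
    _ = remote.Close()
    fmt.Println(local.Close()) // <nil>: the mock end closes cleanly
}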
