Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (h *hostConnPool) String() string {
h.filling, h.closed, size, h.size, h.host)
}

func newHostConnPool(session *Session, host *HostInfo, port, size int,
func newHostConnPool(session *Session, host *HostInfo, port, size int, // FIXME: Remove unused port parameter
keyspace string) *hostConnPool {

pool := &hostConnPool{
Expand Down Expand Up @@ -544,7 +544,9 @@ func (pool *hostConnPool) connect() (err error) {

// lazily initialize the connPicker when we know the required type
pool.initConnPicker(conn)
pool.connPicker.Put(conn)
if err := pool.connPicker.Put(conn); err != nil {
return err
Copy link
Collaborator

@dkropachev dkropachev Nov 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to close the connection if Put fails. Look at the Put code: the connection is probably being closed there too, so, if possible, extract all of the connection-closing code out to here.

}
conn.finalizeConnection()

return nil
Expand Down
8 changes: 5 additions & 3 deletions connpicker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type ConnPicker interface {
Pick(Token, ExecutableQuery) *Conn
Put(*Conn)
Put(*Conn) error
Remove(conn *Conn)
InFlight() int
Size() (int, int)
Expand Down Expand Up @@ -96,10 +96,11 @@ func (p *defaultConnPicker) Pick(Token, ExecutableQuery) *Conn {
return leastBusyConn
}

func (p *defaultConnPicker) Put(conn *Conn) {
// Put appends conn to the picker's connection list while holding p.mu.
// It never fails; the error result is always nil.
func (p *defaultConnPicker) Put(conn *Conn) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.conns = append(p.conns, conn)
	return nil
}

func (*defaultConnPicker) NextShard() (shardID, nrShards int) {
Expand All @@ -115,7 +116,8 @@ func (nopConnPicker) Pick(Token, ExecutableQuery) *Conn {
return nil
}

func (nopConnPicker) Put(*Conn) {
// Put ignores the given connection and always returns nil.
func (nopConnPicker) Put(*Conn) error {
return nil
}

func (nopConnPicker) Remove(conn *Conn) {
Expand Down
2 changes: 1 addition & 1 deletion integration_only.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (s *Session) MissingConnections() (int, error) {

type ConnPickerIntegration interface {
Pick(Token, ExecutableQuery) *Conn
Put(*Conn)
Put(*Conn) error
Remove(conn *Conn)
InFlight() int
Size() (int, int)
Expand Down
54 changes: 48 additions & 6 deletions scylla.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,25 @@ func (p *scyllaConnPicker) shardOf(token int64Token) int {
return int(sum >> 32)
}

func (p *scyllaConnPicker) Put(conn *Conn) {
func (p *scyllaConnPicker) Put(conn *Conn) error {
var (
nrShards = conn.scyllaSupported.nrShards
shard = conn.scyllaSupported.shard
)

if nrShards == 0 {
panic(fmt.Sprintf("scylla: %s not a sharded connection", p.address))
return errors.New("server reported that it has no shards")
}

if nrShards != len(p.conns) {
if nrShards != p.nrShards {
panic(fmt.Sprintf("scylla: %s invalid number of shards", p.address))
if nrShards != p.nrShards {
if gocqlDebug {
p.logger.Printf("scylla: %s shard count changed from %d to %d, rebuilding connection pool",
p.address, p.nrShards, nrShards)
}
p.handleShardTopologyChange(conn, nrShards)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
p.handleShardTopologyChange(conn, nrShards)
p.handleShardCountChange(conn, nrShards)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, before doing so, please check that the new shard count is not zero; if it is zero, you need to drop the connection.

We also need to consider that this condition could persist, which may lead to a connection storm.
Let's address that in #614.

} else if nrShards != len(p.conns) {
conns := p.conns
p.conns = make([]*Conn, nrShards, nrShards)
p.conns = make([]*Conn, nrShards)
copy(p.conns, conns)
}

Expand Down Expand Up @@ -509,6 +512,45 @@ func (p *scyllaConnPicker) Put(conn *Conn) {
if p.shouldCloseExcessConns() {
p.closeExcessConns()
}

return nil
}

func (p *scyllaConnPicker) handleShardTopologyChange(newConn *Conn, newShardCount int) {
oldShardCount := p.nrShards
oldConns := make([]*Conn, len(p.conns))
copy(oldConns, p.conns)

if gocqlDebug {
p.logger.Printf("scylla: %s handling shard topology change from %d to %d", p.address, oldShardCount, newShardCount)
}

newConns := make([]*Conn, newShardCount)
var toClose []*Conn
migratedCount := 0

for i, conn := range oldConns {
if conn != nil && i < newShardCount {
newConns[i] = conn
migratedCount++
} else if conn != nil {
toClose = append(toClose, conn)
}
}
Comment on lines +532 to +539
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i, conn := range oldConns {
if conn != nil && i < newShardCount {
newConns[i] = conn
migratedCount++
} else if conn != nil {
toClose = append(toClose, conn)
}
}
for i, conn := range oldConns {
if conn == nil {
continue
}
if i < newShardCount {
newConns[i] = conn
migratedCount++
} else {
toClose = append(toClose, conn)
}
}


p.nrShards = newShardCount
p.msbIgnore = newConn.scyllaSupported.msbIgnore
p.conns = newConns
p.nrConns = migratedCount
p.lastAttemptedShard = 0

if len(toClose) > 0 {
go closeConns(toClose...)
}

if gocqlDebug {
p.logger.Printf("scylla: %s migrated %d/%d connections to new shard topology, closing %d excess connections", p.address, migratedCount, len(oldConns), len(toClose))
}
}

func (p *scyllaConnPicker) shouldCloseExcessConns() bool {
Expand Down
177 changes: 163 additions & 14 deletions scylla_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ import (
"context"
"fmt"
"math"
"net"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gocql/gocql/internal/streams"
)

Expand Down Expand Up @@ -128,19 +133,6 @@ func TestScyllaConnPickerRemove(t *testing.T) {
}
}

// mockConn returns a *Conn stub for picker tests. The connection gets a fresh
// stream set and reports the given shard out of a fixed 4 shards, with
// msbIgnore 12, the Murmur3 partitioner and the biased-token-round-robin
// sharding algorithm.
func mockConn(shard int) *Conn {
return &Conn{
streams: streams.New(),
scyllaSupported: scyllaSupported{
shard: shard,
nrShards: 4,
msbIgnore: 12,
partitioner: "org.apache.cassandra.dht.Murmur3Partitioner",
shardingAlgorithm: "biased-token-round-robin",
},
}
}

func TestScyllaConnPickerShardOf(t *testing.T) {
t.Parallel()

Expand All @@ -155,7 +147,7 @@ func TestScyllaConnPickerShardOf(t *testing.T) {
}
}

func TestScyllaRandomConnPIcker(t *testing.T) {
func TestScyllaRandomConnPicker(t *testing.T) {
t.Parallel()

t.Run("max iterations", func(t *testing.T) {
Expand Down Expand Up @@ -328,3 +320,160 @@ func TestScyllaPortIterator(t *testing.T) {
})
}
}

// TestScyllaConnPickerHandleShardTopologyChange checks how scyllaConnPicker.Put
// reacts when a new connection reports a shard count different from the one
// the picker currently tracks.
//
// Each case seeds the picker with connections for initialConns (shard IDs,
// reported with initialShards shards), then puts one more connection for
// shard 0 that reports newShards shards. Afterwards it verifies that:
//   - picker.nrShards and len(picker.conns) equal newShards;
//   - the number of non-nil slots in picker.conns equals expectedMigrated;
//   - the seeded connections whose shard ID is >= newShards end up closed,
//     and their count equals expectedClosed.
func TestScyllaConnPickerHandleShardTopologyChange(t *testing.T) {
	tests := []struct {
		name             string
		initialShards    int
		newShards        int
		initialConns     []int // shard IDs of initial connections
		expectedMigrated int
		expectedClosed   int
	}{
		{
			name:             "shard increase from 4 to 8",
			initialShards:    4,
			newShards:        8,
			initialConns:     []int{0, 2, 3},
			expectedMigrated: 3, // All initial connections survive
			expectedClosed:   0,
		},
		{
			name:             "shard decrease from 8 to 4",
			initialShards:    8,
			newShards:        4,
			initialConns:     []int{0, 2, 5, 7},
			expectedMigrated: 2, // Only shards 0, 2 survive
			expectedClosed:   2, // Shards 5, 7 get closed
		},
		{
			name:             "no change same count",
			initialShards:    8,
			newShards:        8,
			initialConns:     []int{1, 3, 5},
			expectedMigrated: 4, // All initial connections survive + new one
			expectedClosed:   0,
		},
		{
			name:             "massive decrease from 16 to 2",
			initialShards:    16,
			newShards:        2,
			initialConns:     []int{0, 1, 5, 8, 12, 15},
			expectedMigrated: 2, // Only shards 0, 1 survive
			expectedClosed:   4, // Shards 5, 8, 12, 15 get closed
		},
	}

	for _, tt := range tests {
		// The subtest calls t.Parallel, so its closure runs after this loop
		// has advanced. Give it a private copy of the case so the test is
		// correct under pre-Go 1.22 loop-variable semantics as well.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			logger := &testLogger{}
			picker := &scyllaConnPicker{
				logger:                     logger,
				disableShardAwarePortUntil: new(atomic.Value),
				hostId:                     "test-host-id",
				shardAwareAddress:          "192.168.1.1:19042", // Shard-aware port
				address:                    "192.168.1.1:9042",  // Regular port
				conns:                      make([]*Conn, tt.initialShards),
				excessConns:                make([]*Conn, 0),
				nrShards:                   tt.initialShards,
				msbIgnore:                  12,
				nrConns:                    0,
				pos:                        0,
				lastAttemptedShard:         0,
				shardAwarePortDisabled:     false,
				excessConnsLimitRate:       0.1,
			}
			picker.disableShardAwarePortUntil.Store(time.Time{})

			// Seeded connections whose shard no longer exists after the change.
			var connectionsToCheck []*Conn

			// Add initial connections
			for _, shardID := range tt.initialConns {
				conn := mockConnForPicker(shardID, tt.initialShards)
				err := picker.Put(conn)
				require.NoError(t, err)

				if shardID >= tt.newShards {
					connectionsToCheck = append(connectionsToCheck, conn)
				}
			}

			// Verify initial state
			assert.Equal(t, tt.initialShards, picker.nrShards)
			assert.Equal(t, len(tt.initialConns), picker.nrConns)

			// Execute topology change
			newConn := mockConnForPicker(0, tt.newShards)
			err := picker.Put(newConn)
			require.NoError(t, err)

			// Verify new topology
			assert.Equal(t, tt.newShards, picker.nrShards)
			assert.Len(t, picker.conns, tt.newShards)

			// Count migrated connections
			migratedCount := 0
			for _, conn := range picker.conns {
				if conn != nil {
					migratedCount++
				}
			}

			assert.Equal(t, tt.expectedMigrated, migratedCount)

			countClosed := func() int {
				closed := 0
				for _, conn := range connectionsToCheck {
					if conn.Closed() {
						closed++
					}
				}
				return closed
			}

			// Dropped connections are closed by a background goroutine, so
			// poll with a bounded deadline instead of sleeping for a fixed
			// interval: the test finishes as soon as the closes are visible
			// and tolerates a slow scheduler.
			deadline := time.Now().Add(time.Second)
			for countClosed() < tt.expectedClosed && time.Now().Before(deadline) {
				time.Sleep(time.Millisecond)
			}

			// Verify connections that should be closed are actually closed
			closedCount := countClosed()
			assert.Equal(t, tt.expectedClosed, closedCount,
				"Expected %d connections to be closed, but %d were closed",
				tt.expectedClosed, closedCount)
		})
	}
}

// mockConn builds a minimal *Conn for picker tests: a fresh stream set plus
// Scylla sharding info that places the connection on the given shard of a
// 4-shard node (msbIgnore 12, Murmur3 partitioner, biased-token-round-robin).
func mockConn(shard int) *Conn {
	sharding := scyllaSupported{
		shard:             shard,
		nrShards:          4,
		msbIgnore:         12,
		partitioner:       "org.apache.cassandra.dht.Murmur3Partitioner",
		shardingAlgorithm: "biased-token-round-robin",
	}

	return &Conn{
		streams:         streams.New(),
		scyllaSupported: sharding,
	}
}

// mockConnForPicker builds a *Conn that can be handed to a conn picker and
// later closed. It reports the given shard and shard count (msbIgnore 12),
// is backed by one end of an in-memory net.Pipe whose other end is closed
// immediately, and carries its own cancellable context, logger, call map and
// stream set. The address is derived from the shard: 192.168.1.<shard+1>:9042.
func mockConnForPicker(shard, nrShards int) *Conn {
	// Only the local end of the pipe is kept; the peer end is closed right away.
	localEnd, peerEnd := net.Pipe()
	_ = peerEnd.Close()

	connCtx, connCancel := context.WithCancel(context.Background())

	sharding := scyllaSupported{
		shard:     shard,
		nrShards:  nrShards,
		msbIgnore: 12,
	}

	return &Conn{
		scyllaSupported: sharding,
		conn:            localEnd,
		addr:            fmt.Sprintf("192.168.1.%d:9042", shard+1),
		closed:          false,
		mu:              sync.Mutex{},
		logger:          &testLogger{},
		ctx:             connCtx,
		cancel:          connCancel,
		calls:           make(map[int]*callReq),
		streams:         streams.New(),
	}
}
Loading