Skip to content
This repository was archived by the owner on Jul 21, 2025. It is now read-only.

Commit bd3757b

Browse files
committed
transport: replace atomic.Value with atomic.Pointer[T]
Fixes #269
1 parent 655c377 commit bd3757b

File tree

4 files changed

+13
-17
lines changed

4 files changed

+13
-17
lines changed

transport/cluster.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@ import (
88
"sort"
99
"strconv"
1010
"strings"
11+
"sync/atomic"
1112
"time"
1213

1314
"github.com/scylladb/scylla-go-driver/frame"
1415
. "github.com/scylladb/scylla-go-driver/frame/response"
15-
16-
"go.uber.org/atomic"
1716
)
1817

1918
type (
@@ -26,7 +25,7 @@ type (
2625
)
2726

2827
type Cluster struct {
29-
topology atomic.Value // *topology
28+
topology atomic.Pointer[topology]
3029
control *Conn
3130
cfg ConnConfig
3231
handledEvents []frame.EventType // This will probably be moved to config.
@@ -121,7 +120,7 @@ func (c *Cluster) NewTokenAwareQueryInfo(t Token, ks string) (QueryInfo, error)
121120

122121
// TODO overflow and negative modulo.
123122
func (c *Cluster) generateOffset() uint64 {
124-
return c.queryInfoCounter.Inc() - 1
123+
return c.queryInfoCounter.Add(1) - 1
125124
}
126125

127126
// NewCluster also creates control connection and starts handling events and refreshing topology.
@@ -443,7 +442,7 @@ func parseTokensFromRow(n *Node, r frame.Row, ring *Ring) error {
443442
}
444443

445444
func (c *Cluster) Topology() *topology {
446-
return c.topology.Load().(*topology)
445+
return c.topology.Load()
447446
}
448447

449448
func (c *Cluster) setTopology(t *topology) {

transport/cluster_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func TestClusterIntegration(t *testing.T) {
8080
}
8181

8282
// There should be at least system keyspaces present.
83-
if len(c.topology.Load().(*topology).keyspaces) == 0 {
83+
if len(c.Topology().keyspaces) == 0 {
8484
t.Fatalf("Keyspaces failed to load")
8585
}
8686

transport/export_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package transport
22

33
func (p *ConnPool) AllConns() []*Conn {
44
var conns = make([]*Conn, len(p.conns))
5-
for i, v := range p.conns {
6-
conns[i], _ = v.Load().(*Conn)
5+
for i := range conns {
6+
conns[i] = p.loadConn(i)
77
}
88
return conns
99
}

transport/pool.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ import (
66
"log"
77
"math"
88
"net"
9+
"sync/atomic"
910
"time"
1011

1112
. "github.com/scylladb/scylla-go-driver/frame/response"
12-
13-
"go.uber.org/atomic"
1413
)
1514

1615
const poolCloseShard = -1
@@ -19,7 +18,7 @@ type ConnPool struct {
1918
host string
2019
nrShards int
2120
msbIgnore uint8
22-
conns []atomic.Value
21+
conns []atomic.Pointer[Conn]
2322
connClosedCh chan int // notification channel for when connection is closed
2423
connObs ConnObserver
2524
}
@@ -99,13 +98,11 @@ func (p *ConnPool) storeConn(conn *Conn) {
9998
}
10099

101100
func (p *ConnPool) loadConn(shard int) *Conn {
102-
conn, _ := p.conns[shard].Load().(*Conn)
103-
return conn
101+
return p.conns[shard].Load()
104102
}
105103

106104
func (p *ConnPool) clearConn(shard int) bool {
107-
conn, _ := p.conns[shard].Swap((*Conn)(nil)).(*Conn)
108-
return conn != nil
105+
return p.conns[shard].Swap(nil) != nil
109106
}
110107

111108
func (p *ConnPool) Close() {
@@ -115,7 +112,7 @@ func (p *ConnPool) Close() {
115112
// closeAll is called by PoolRefiller.
116113
func (p *ConnPool) closeAll() {
117114
for i := range p.conns {
118-
if conn, ok := p.conns[i].Swap((*Conn)(nil)).(*Conn); ok {
115+
if conn := p.conns[i].Swap(nil); conn != nil {
119116
conn.Close()
120117
}
121118
}
@@ -168,7 +165,7 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error {
168165
host: host,
169166
nrShards: int(ss.NrShards),
170167
msbIgnore: ss.MsbIgnore,
171-
conns: make([]atomic.Value, int(ss.NrShards)),
168+
conns: make([]atomic.Pointer[Conn], int(ss.NrShards)),
172169
connClosedCh: make(chan int, int(ss.NrShards)+1),
173170
connObs: r.cfg.ConnObserver,
174171
}

0 commit comments

Comments
 (0)