@@ -3,8 +3,8 @@ package transport
33import (
44 "context"
55 "fmt"
6- "log"
76 "math"
7+ "math/bits"
88 "net"
99 "time"
1010
@@ -126,10 +126,11 @@ func (p *ConnPool) closeAll() {
126126}
127127
128128type PoolRefiller struct {
129- addr string
130- pool ConnPool
131- cfg ConnConfig
132- active int
129+ addr string
130+ pool ConnPool
131+ cfg ConnConfig
132+ active int
133+ fillRandomly bool
133134}
134135
135136func (r * PoolRefiller ) init (ctx context.Context , host string ) error {
@@ -158,13 +159,17 @@ func (r *PoolRefiller) init(ctx context.Context, host string) error {
158159 if v , ok := s .Options [ScyllaShardAwarePortSSL ]; ok {
159160 r .addr = net .JoinHostPort (host , v [0 ])
160161 } else {
161- return fmt .Errorf ("missing encrypted shard aware port information %v" , s .Options )
162+ r .cfg .Logger .Warnf ("missing encrypted shard aware port information %v; falling back to random shard discovery" , s .Options )
163+ r .addr = net .JoinHostPort (host , r .cfg .DefaultPort )
164+ r .fillRandomly = true
162165 }
163166 } else {
164167 if v , ok := s .Options [ScyllaShardAwarePort ]; ok {
165168 r .addr = net .JoinHostPort (host , v [0 ])
166169 } else {
167- return fmt .Errorf ("missing shard aware port information %v" , s .Options )
170+ r .cfg .Logger .Warnf ("missing shard aware port information %v; falling back to random shard discovery" , s .Options )
171+ r .addr = net .JoinHostPort (host , r .cfg .DefaultPort )
172+ r .fillRandomly = true
168173 }
169174 }
170175
@@ -222,17 +227,39 @@ func (r *PoolRefiller) loop(ctx context.Context) {
222227 }
223228}
224229
230+ // storeShard assumes conn is non-nil.
231+ func (r * PoolRefiller ) storeShard (conn * Conn , span span ) bool {
232+ if r .pool .loadConn (conn .Shard ()) != nil {
233+ if r .pool .connObs != nil {
234+ r .pool .connObs .OnConnect (ConnectEvent {
235+ ConnEvent : conn .Event (),
236+ span : span ,
237+ Err : fmt .Errorf ("shard already in pool" ),
238+ })
239+ }
240+ conn .Close ()
241+ return false
242+ }
243+
244+ if r .pool .connObs != nil {
245+ r .pool .connObs .OnConnect (ConnectEvent {ConnEvent : conn .Event (), span : span })
246+ }
247+ conn .setOnClose (r .onConnClose )
248+ r .pool .storeConn (conn )
249+ r .active ++
250+ return true
251+ }
252+
225253func (r * PoolRefiller ) fill (ctx context.Context ) {
226- if ! r . needsFilling () {
227- return
254+ if r . fillRandomly {
255+ r . fillRandom ( ctx )
228256 }
229257
230258 si := ShardInfo {
231259 NrShards : uint16 (r .pool .nrShards ),
232260 MsbIgnore : r .pool .msbIgnore ,
233261 }
234-
235- for i := 0 ; i < r .pool .nrShards ; i ++ {
262+ for i := 0 ; r .needsFilling () && i < r .pool .nrShards ; i ++ {
236263 if r .pool .loadConn (i ) != nil {
237264 continue
238265 }
@@ -245,23 +272,14 @@ func (r *PoolRefiller) fill(ctx context.Context) {
245272 if r .pool .connObs != nil {
246273 r .pool .connObs .OnConnect (ConnectEvent {ConnEvent : ConnEvent {Addr : r .addr , Shard : si .Shard }, span : span , Err : err })
247274 }
248- if conn != nil {
249- conn .Close ()
250- }
251275 continue
252276 }
253- if r .pool .connObs != nil {
254- r .pool .connObs .OnConnect (ConnectEvent {ConnEvent : conn .Event (), span : span })
255- }
256277
278+ r .storeShard (conn , span )
257279 if conn .Shard () != i {
258- log .Fatalf ("opened conn to wrong shard: expected %d got %d" , i , conn .Shard ())
259- }
260- conn .setOnClose (r .onConnClose )
261- r .pool .storeConn (conn )
262- r .active ++
263-
264- if ! r .needsFilling () {
280+ r .cfg .Logger .Warnf ("opened conn to wrong shard: expected %d got %d; falling back to random discovery" , i , conn .Shard ())
281+ r .fillRandomly = true
282+ r .fillRandom (ctx )
265283 return
266284 }
267285 }
@@ -270,3 +288,21 @@ func (r *PoolRefiller) fill(ctx context.Context) {
270288func (r * PoolRefiller ) needsFilling () bool {
271289 return r .active < r .pool .nrShards
272290}
291+
292+ func (r * PoolRefiller ) fillRandom (ctx context.Context ) {
293+ // https://en.wikipedia.org/wiki/Coupon_collector%27s_problem
294+ maxTries := r .pool .nrShards * bits .Len (uint (r .pool .nrShards ))
295+ for try := 0 ; r .needsFilling () && try < maxTries ; try ++ {
296+ span := startSpan ()
297+ conn , err := OpenConn (ctx , r .addr , nil , r .cfg )
298+ span .stop ()
299+ if err != nil {
300+ if r .pool .connObs != nil {
301+ r .pool .connObs .OnConnect (ConnectEvent {ConnEvent : ConnEvent {Addr : r .addr , Shard : UnknownShard }, span : span , Err : err })
302+ }
303+ continue
304+ }
305+
306+ r .storeShard (conn , span )
307+ }
308+ }
0 commit comments