1
1
package gocql
2
2
3
3
import (
4
+ "context"
4
5
"fmt"
5
6
"math"
6
7
"runtime"
7
8
"sync"
8
9
"testing"
10
+ "time"
9
11
10
12
"github.com/gocql/gocql/internal/streams"
11
13
)
@@ -167,6 +169,9 @@ func TestScyllaRandomConnPIcker(t *testing.T) {
167
169
})
168
170
169
171
t .Run ("async access of max iterations" , func (t * testing.T ) {
172
+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
173
+ defer cancel ()
174
+
170
175
s := & scyllaConnPicker {
171
176
nrShards : 4 ,
172
177
msbIgnore : 12 ,
@@ -175,26 +180,34 @@ func TestScyllaRandomConnPIcker(t *testing.T) {
175
180
}
176
181
177
182
var wg sync.WaitGroup
183
+ connCh := make (chan * Conn , 9 )
178
184
for i := 0 ; i < 3 ; i ++ {
179
185
wg .Add (1 )
180
- go pickLoop (t , s , 3 , & wg )
186
+ go func () {
187
+ defer wg .Done ()
188
+ for i := 0 ; i < 3 ; i ++ {
189
+ select {
190
+ case connCh <- s .Pick (token (nil )):
191
+ case <- ctx .Done ():
192
+ }
193
+ }
194
+ }()
181
195
}
182
196
wg .Wait ()
197
+ close (connCh )
183
198
184
199
if s .pos != 8 {
185
200
t .Fatalf ("expected position to be 8 | actual %d" , s .pos )
186
201
}
187
- })
188
- }
189
-
190
- func pickLoop (t * testing.T , s * scyllaConnPicker , c int , wg * sync.WaitGroup ) {
191
- t .Helper ()
192
- for i := 0 ; i < c ; i ++ {
193
- if s .Pick (token (nil )) == nil {
194
- t .Fatal ("expected connection" )
202
+ if len (connCh ) != 9 {
203
+ t .Fatalf ("expected 9 connection picks, got %d" , len (connCh ))
195
204
}
196
- }
197
- wg .Done ()
205
+ for conn := range connCh {
206
+ if conn == nil {
207
+ t .Fatal ("expected connection, got nil" )
208
+ }
209
+ }
210
+ })
198
211
}
199
212
200
213
func TestScyllaLWTExtParsing (t * testing.T ) {
0 commit comments