Skip to content

Commit 62389b9

Browse files
committed
Add tests for TLS ticket resumption
Add test that checks if TLS ticket was picked up by TLS layer and reused after reconnection
1 parent 4a608ce commit 62389b9

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

integration_only.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package gocql
2+
3+
import "fmt"
4+
5+
func (pool *hostConnPool) MissingConnections() (int, error) {
6+
pool.mu.Lock()
7+
defer pool.mu.Unlock()
8+
9+
if pool.closed {
10+
return 0, fmt.Errorf("pool is closed")
11+
}
12+
_, missing := pool.connPicker.Size()
13+
return missing, nil
14+
}
15+
16+
func (p *policyConnPool) MissingConnections() (int, error) {
17+
p.mu.Lock()
18+
defer p.mu.Unlock()
19+
20+
total := 0
21+
22+
// close the pools
23+
for _, pool := range p.hostConnPools {
24+
missing, err := pool.MissingConnections()
25+
if err != nil {
26+
return 0, err
27+
}
28+
total += missing
29+
}
30+
return total, nil
31+
}
32+
33+
func (s *Session) MissingConnections() (int, error) {
34+
if s.pool == nil {
35+
return 0, fmt.Errorf("pool is nil")
36+
}
37+
return s.pool.MissingConnections()
38+
}
39+
40+
type ConnPickerIntegration interface {
41+
Pick(Token, ExecutableQuery) *Conn
42+
Put(*Conn)
43+
Remove(conn *Conn)
44+
InFlight() int
45+
Size() (int, int)
46+
Close()
47+
CloseAllConnections()
48+
49+
// NextShard returns the shardID to connect to.
50+
// nrShard specifies how many shards the host has.
51+
// If nrShards is zero, the caller shouldn't use shard-aware port.
52+
NextShard() (shardID, nrShards int)
53+
}
54+
55+
func (p *scyllaConnPicker) CloseAllConnections() {
56+
p.nrConns = 0
57+
closeConns(p.conns...)
58+
for id := range p.conns {
59+
p.conns[id] = nil
60+
}
61+
}
62+
63+
func (p *defaultConnPicker) CloseAllConnections() {
64+
closeConns(p.conns...)
65+
p.conns = p.conns[:0]
66+
}
67+
68+
func (p *nopConnPicker) CloseAllConnections() {
69+
}
70+
71+
func (pool *hostConnPool) CloseAllConnections() {
72+
if !pool.closed {
73+
return
74+
}
75+
pool.mu.Lock()
76+
println("Closing all connections in a pool")
77+
pool.connPicker.(ConnPickerIntegration).CloseAllConnections()
78+
println("Filling the pool")
79+
pool.mu.Unlock()
80+
pool.fill()
81+
}
82+
83+
func (p *policyConnPool) CloseAllConnections() {
84+
p.mu.Lock()
85+
defer p.mu.Unlock()
86+
87+
// close the pools
88+
for _, pool := range p.hostConnPools {
89+
pool.CloseAllConnections()
90+
}
91+
}
92+
93+
func (s *Session) CloseAllConnections() {
94+
if s.pool != nil {
95+
s.pool.CloseAllConnections()
96+
}
97+
}

session_test.go

+113
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ package gocql
2929

3030
import (
3131
"context"
32+
"crypto/tls"
3233
"fmt"
3334
"net"
35+
"sync"
3436
"testing"
37+
"time"
3538
)
3639

3740
func TestSessionAPI(t *testing.T) {
@@ -424,3 +427,113 @@ func TestRetryType_IgnoreRethrow(t *testing.T) {
424427
resetObserved()
425428
}
426429
}
430+
431+
type sessionCache struct {
432+
orig tls.ClientSessionCache
433+
values map[string][][]byte
434+
caches map[string][]int64
435+
valuesLock sync.Mutex
436+
}
437+
438+
func (c *sessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
439+
return c.orig.Get(sessionKey)
440+
}
441+
442+
func (c *sessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
443+
ticket, _, err := cs.ResumptionState()
444+
if err != nil {
445+
panic(err)
446+
}
447+
if len(ticket) == 0 {
448+
panic("ticket should not be empty")
449+
}
450+
c.valuesLock.Lock()
451+
c.values[sessionKey] = append(c.values[sessionKey], ticket)
452+
c.valuesLock.Unlock()
453+
c.orig.Put(sessionKey, cs)
454+
}
455+
456+
func (c *sessionCache) NumberOfTickets() int {
457+
c.valuesLock.Lock()
458+
defer c.valuesLock.Unlock()
459+
total := 0
460+
for _, tickets := range c.values {
461+
total += len(tickets)
462+
}
463+
return total
464+
}
465+
466+
func newSessionCache() *sessionCache {
467+
return &sessionCache{
468+
orig: tls.NewLRUClientSessionCache(1024),
469+
values: make(map[string][][]byte),
470+
caches: make(map[string][]int64),
471+
valuesLock: sync.Mutex{},
472+
}
473+
}
474+
475+
func withSessionCache(cache tls.ClientSessionCache) func(config *ClusterConfig) {
476+
return func(config *ClusterConfig) {
477+
config.SslOpts = &SslOptions{
478+
EnableHostVerification: false,
479+
Config: &tls.Config{
480+
ClientSessionCache: cache,
481+
InsecureSkipVerify: true,
482+
},
483+
}
484+
}
485+
}
486+
487+
func TestTLSTicketResumption(t *testing.T) {
488+
t.Skip("TLS ticket resumption is only supported by 2025.2 and later")
489+
490+
c := newSessionCache()
491+
session := createSession(t, withSessionCache(c))
492+
defer session.Close()
493+
494+
waitAllConnectionsOpened := func() error {
495+
println("wait all connections opened")
496+
defer println("end of wait all connections closed")
497+
endtime := time.Now().UTC().Add(time.Second * 10)
498+
for {
499+
if time.Now().UTC().After(endtime) {
500+
return fmt.Errorf("timed out waiting for all connections opened")
501+
}
502+
missing, err := session.MissingConnections()
503+
if err != nil {
504+
return fmt.Errorf("failed to get missing connections count: %w", err)
505+
}
506+
if missing == 0 {
507+
return nil
508+
}
509+
time.Sleep(time.Millisecond * 100)
510+
}
511+
}
512+
513+
if err := waitAllConnectionsOpened(); err != nil {
514+
t.Fatal(err)
515+
}
516+
tickets := c.NumberOfTickets()
517+
if tickets == 0 {
518+
t.Fatal("no tickets learned, which means that server does not support TLS tickets")
519+
}
520+
521+
session.CloseAllConnections()
522+
if err := waitAllConnectionsOpened(); err != nil {
523+
t.Fatal(err)
524+
}
525+
newTickets1 := c.NumberOfTickets()
526+
527+
session.CloseAllConnections()
528+
if err := waitAllConnectionsOpened(); err != nil {
529+
t.Fatal(err)
530+
}
531+
newTickets2 := c.NumberOfTickets()
532+
533+
if newTickets1 != tickets {
534+
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets1, tickets)
535+
}
536+
if newTickets2 != tickets {
537+
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets2, tickets)
538+
}
539+
}

0 commit comments

Comments
 (0)