Skip to content

Commit 6e53443

Browse files
committed
Add tests for TLS ticket resumption
Add test that checks if TLS ticket was picked up by TLS layer and reused after reconnection
1 parent 4a608ce commit 6e53443

File tree

2 files changed

+216
-0
lines changed

2 files changed

+216
-0
lines changed

integration_only.go

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//go:build integration
2+
// +build integration
3+
4+
package gocql
5+
6+
// This file contains code to enable easy access to driver internals
7+
// To be used only for integration test
8+
9+
import "fmt"
10+
11+
func (pool *hostConnPool) MissingConnections() (int, error) {
12+
pool.mu.Lock()
13+
defer pool.mu.Unlock()
14+
15+
if pool.closed {
16+
return 0, fmt.Errorf("pool is closed")
17+
}
18+
_, missing := pool.connPicker.Size()
19+
return missing, nil
20+
}
21+
22+
func (p *policyConnPool) MissingConnections() (int, error) {
23+
p.mu.Lock()
24+
defer p.mu.Unlock()
25+
26+
total := 0
27+
28+
// close the pools
29+
for _, pool := range p.hostConnPools {
30+
missing, err := pool.MissingConnections()
31+
if err != nil {
32+
return 0, err
33+
}
34+
total += missing
35+
}
36+
return total, nil
37+
}
38+
39+
func (s *Session) MissingConnections() (int, error) {
40+
if s.pool == nil {
41+
return 0, fmt.Errorf("pool is nil")
42+
}
43+
return s.pool.MissingConnections()
44+
}
45+
46+
type ConnPickerIntegration interface {
47+
Pick(Token, ExecutableQuery) *Conn
48+
Put(*Conn)
49+
Remove(conn *Conn)
50+
InFlight() int
51+
Size() (int, int)
52+
Close()
53+
CloseAllConnections()
54+
55+
// NextShard returns the shardID to connect to.
56+
// nrShard specifies how many shards the host has.
57+
// If nrShards is zero, the caller shouldn't use shard-aware port.
58+
NextShard() (shardID, nrShards int)
59+
}
60+
61+
func (p *scyllaConnPicker) CloseAllConnections() {
62+
p.nrConns = 0
63+
closeConns(p.conns...)
64+
for id := range p.conns {
65+
p.conns[id] = nil
66+
}
67+
}
68+
69+
func (p *defaultConnPicker) CloseAllConnections() {
70+
closeConns(p.conns...)
71+
p.conns = p.conns[:0]
72+
}
73+
74+
func (p *nopConnPicker) CloseAllConnections() {
75+
}
76+
77+
func (pool *hostConnPool) CloseAllConnections() {
78+
if !pool.closed {
79+
return
80+
}
81+
pool.mu.Lock()
82+
println("Closing all connections in a pool")
83+
pool.connPicker.(ConnPickerIntegration).CloseAllConnections()
84+
println("Filling the pool")
85+
pool.mu.Unlock()
86+
pool.fill()
87+
}
88+
89+
func (p *policyConnPool) CloseAllConnections() {
90+
p.mu.Lock()
91+
defer p.mu.Unlock()
92+
93+
// close the pools
94+
for _, pool := range p.hostConnPools {
95+
pool.CloseAllConnections()
96+
}
97+
}
98+
99+
func (s *Session) CloseAllConnections() {
100+
if s.pool != nil {
101+
s.pool.CloseAllConnections()
102+
}
103+
}

session_test.go

+113
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ package gocql
2929

3030
import (
3131
"context"
32+
"crypto/tls"
3233
"fmt"
3334
"net"
35+
"sync"
3436
"testing"
37+
"time"
3538
)
3639

3740
func TestSessionAPI(t *testing.T) {
@@ -424,3 +427,113 @@ func TestRetryType_IgnoreRethrow(t *testing.T) {
424427
resetObserved()
425428
}
426429
}
430+
431+
type sessionCache struct {
432+
orig tls.ClientSessionCache
433+
values map[string][][]byte
434+
caches map[string][]int64
435+
valuesLock sync.Mutex
436+
}
437+
438+
func (c *sessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
439+
return c.orig.Get(sessionKey)
440+
}
441+
442+
func (c *sessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
443+
ticket, _, err := cs.ResumptionState()
444+
if err != nil {
445+
panic(err)
446+
}
447+
if len(ticket) == 0 {
448+
panic("ticket should not be empty")
449+
}
450+
c.valuesLock.Lock()
451+
c.values[sessionKey] = append(c.values[sessionKey], ticket)
452+
c.valuesLock.Unlock()
453+
c.orig.Put(sessionKey, cs)
454+
}
455+
456+
func (c *sessionCache) NumberOfTickets() int {
457+
c.valuesLock.Lock()
458+
defer c.valuesLock.Unlock()
459+
total := 0
460+
for _, tickets := range c.values {
461+
total += len(tickets)
462+
}
463+
return total
464+
}
465+
466+
func newSessionCache() *sessionCache {
467+
return &sessionCache{
468+
orig: tls.NewLRUClientSessionCache(1024),
469+
values: make(map[string][][]byte),
470+
caches: make(map[string][]int64),
471+
valuesLock: sync.Mutex{},
472+
}
473+
}
474+
475+
func withSessionCache(cache tls.ClientSessionCache) func(config *ClusterConfig) {
476+
return func(config *ClusterConfig) {
477+
config.SslOpts = &SslOptions{
478+
EnableHostVerification: false,
479+
Config: &tls.Config{
480+
ClientSessionCache: cache,
481+
InsecureSkipVerify: true,
482+
},
483+
}
484+
}
485+
}
486+
487+
func TestTLSTicketResumption(t *testing.T) {
488+
t.Skip("TLS ticket resumption is only supported by 2025.2 and later")
489+
490+
c := newSessionCache()
491+
session := createSession(t, withSessionCache(c))
492+
defer session.Close()
493+
494+
waitAllConnectionsOpened := func() error {
495+
println("wait all connections opened")
496+
defer println("end of wait all connections closed")
497+
endtime := time.Now().UTC().Add(time.Second * 10)
498+
for {
499+
if time.Now().UTC().After(endtime) {
500+
return fmt.Errorf("timed out waiting for all connections opened")
501+
}
502+
missing, err := session.MissingConnections()
503+
if err != nil {
504+
return fmt.Errorf("failed to get missing connections count: %w", err)
505+
}
506+
if missing == 0 {
507+
return nil
508+
}
509+
time.Sleep(time.Millisecond * 100)
510+
}
511+
}
512+
513+
if err := waitAllConnectionsOpened(); err != nil {
514+
t.Fatal(err)
515+
}
516+
tickets := c.NumberOfTickets()
517+
if tickets == 0 {
518+
t.Fatal("no tickets learned, which means that server does not support TLS tickets")
519+
}
520+
521+
session.CloseAllConnections()
522+
if err := waitAllConnectionsOpened(); err != nil {
523+
t.Fatal(err)
524+
}
525+
newTickets1 := c.NumberOfTickets()
526+
527+
session.CloseAllConnections()
528+
if err := waitAllConnectionsOpened(); err != nil {
529+
t.Fatal(err)
530+
}
531+
newTickets2 := c.NumberOfTickets()
532+
533+
if newTickets1 != tickets {
534+
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets1, tickets)
535+
}
536+
if newTickets2 != tickets {
537+
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets2, tickets)
538+
}
539+
}

0 commit comments

Comments
 (0)