Skip to content

Add TLS ticket resumption test #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions integration_only.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//go:build integration
// +build integration

package gocql

// This file contains code to enable easy access to driver internals
// To be used only for integration test

import "fmt"

func (pool *hostConnPool) MissingConnections() (int, error) {
pool.mu.Lock()
defer pool.mu.Unlock()

if pool.closed {
return 0, fmt.Errorf("pool is closed")
}
_, missing := pool.connPicker.Size()
return missing, nil
}

func (p *policyConnPool) MissingConnections() (int, error) {
p.mu.Lock()
defer p.mu.Unlock()

total := 0

// close the pools
for _, pool := range p.hostConnPools {
missing, err := pool.MissingConnections()
if err != nil {
return 0, err
}
total += missing
}
return total, nil
}

func (s *Session) MissingConnections() (int, error) {
if s.pool == nil {
return 0, fmt.Errorf("pool is nil")
}
return s.pool.MissingConnections()
}

type ConnPickerIntegration interface {
Pick(Token, ExecutableQuery) *Conn
Put(*Conn)
Remove(conn *Conn)
InFlight() int
Size() (int, int)
Close()
CloseAllConnections()

// NextShard returns the shardID to connect to.
// nrShard specifies how many shards the host has.
// If nrShards is zero, the caller shouldn't use shard-aware port.
NextShard() (shardID, nrShards int)
}

func (p *scyllaConnPicker) CloseAllConnections() {
p.nrConns = 0
closeConns(p.conns...)
for id := range p.conns {
p.conns[id] = nil
}
}

func (p *defaultConnPicker) CloseAllConnections() {
closeConns(p.conns...)
p.conns = p.conns[:0]
}

func (p *nopConnPicker) CloseAllConnections() {
}

func (pool *hostConnPool) CloseAllConnections() {
if !pool.closed {
return
}
pool.mu.Lock()
println("Closing all connections in a pool")
pool.connPicker.(ConnPickerIntegration).CloseAllConnections()
println("Filling the pool")
pool.mu.Unlock()
pool.fill()
}

func (p *policyConnPool) CloseAllConnections() {
p.mu.Lock()
defer p.mu.Unlock()

// close the pools
for _, pool := range p.hostConnPools {
pool.CloseAllConnections()
}
}

func (s *Session) CloseAllConnections() {
if s.pool != nil {
s.pool.CloseAllConnections()
}
}
113 changes: 113 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ package gocql

import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"testing"
"time"
)

func TestSessionAPI(t *testing.T) {
Expand Down Expand Up @@ -424,3 +427,113 @@ func TestRetryType_IgnoreRethrow(t *testing.T) {
resetObserved()
}
}

type sessionCache struct {
orig tls.ClientSessionCache
values map[string][][]byte
caches map[string][]int64
valuesLock sync.Mutex
}

func (c *sessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
return c.orig.Get(sessionKey)
}

func (c *sessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
ticket, _, err := cs.ResumptionState()
if err != nil {
panic(err)
}
if len(ticket) == 0 {
panic("ticket should not be empty")
}
c.valuesLock.Lock()
c.values[sessionKey] = append(c.values[sessionKey], ticket)
c.valuesLock.Unlock()
c.orig.Put(sessionKey, cs)
}

func (c *sessionCache) NumberOfTickets() int {
c.valuesLock.Lock()
defer c.valuesLock.Unlock()
total := 0
for _, tickets := range c.values {
total += len(tickets)
}
return total
}

func newSessionCache() *sessionCache {
return &sessionCache{
orig: tls.NewLRUClientSessionCache(1024),
values: make(map[string][][]byte),
caches: make(map[string][]int64),
valuesLock: sync.Mutex{},
}
}

func withSessionCache(cache tls.ClientSessionCache) func(config *ClusterConfig) {
return func(config *ClusterConfig) {
config.SslOpts = &SslOptions{
EnableHostVerification: false,
Config: &tls.Config{
ClientSessionCache: cache,
InsecureSkipVerify: true,
},
}
}
}

func TestTLSTicketResumption(t *testing.T) {
t.Skip("TLS ticket resumption is only supported by 2025.2 and later")

c := newSessionCache()
session := createSession(t, withSessionCache(c))
defer session.Close()

waitAllConnectionsOpened := func() error {
println("wait all connections opened")
defer println("end of wait all connections closed")
endtime := time.Now().UTC().Add(time.Second * 10)
for {
if time.Now().UTC().After(endtime) {
return fmt.Errorf("timed out waiting for all connections opened")
}
missing, err := session.MissingConnections()
if err != nil {
return fmt.Errorf("failed to get missing connections count: %w", err)
}
if missing == 0 {
return nil
}
time.Sleep(time.Millisecond * 100)
}
}

if err := waitAllConnectionsOpened(); err != nil {
t.Fatal(err)
}
tickets := c.NumberOfTickets()
if tickets == 0 {
t.Fatal("no tickets learned, which means that server does not support TLS tickets")
}

session.CloseAllConnections()
if err := waitAllConnectionsOpened(); err != nil {
t.Fatal(err)
}
newTickets1 := c.NumberOfTickets()

session.CloseAllConnections()
if err := waitAllConnectionsOpened(); err != nil {
t.Fatal(err)
}
newTickets2 := c.NumberOfTickets()

if newTickets1 != tickets {
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets1, tickets)
}
if newTickets2 != tickets {
t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets2, tickets)
}
}
Loading