diff --git a/integration_only.go b/integration_only.go new file mode 100644 index 000000000..fc82fdeda --- /dev/null +++ b/integration_only.go @@ -0,0 +1,103 @@ +//go:build integration +// +build integration + +package gocql + +// This file contains code to enable easy access to driver internals +// To be used only for integration test + +import "fmt" + +func (pool *hostConnPool) MissingConnections() (int, error) { + pool.mu.Lock() + defer pool.mu.Unlock() + + if pool.closed { + return 0, fmt.Errorf("pool is closed") + } + _, missing := pool.connPicker.Size() + return missing, nil +} + +func (p *policyConnPool) MissingConnections() (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + + total := 0 + + // close the pools + for _, pool := range p.hostConnPools { + missing, err := pool.MissingConnections() + if err != nil { + return 0, err + } + total += missing + } + return total, nil +} + +func (s *Session) MissingConnections() (int, error) { + if s.pool == nil { + return 0, fmt.Errorf("pool is nil") + } + return s.pool.MissingConnections() +} + +type ConnPickerIntegration interface { + Pick(Token, ExecutableQuery) *Conn + Put(*Conn) + Remove(conn *Conn) + InFlight() int + Size() (int, int) + Close() + CloseAllConnections() + + // NextShard returns the shardID to connect to. + // nrShard specifies how many shards the host has. + // If nrShards is zero, the caller shouldn't use shard-aware port. + NextShard() (shardID, nrShards int) +} + +func (p *scyllaConnPicker) CloseAllConnections() { + p.nrConns = 0 + closeConns(p.conns...) + for id := range p.conns { + p.conns[id] = nil + } +} + +func (p *defaultConnPicker) CloseAllConnections() { + closeConns(p.conns...) + p.conns = p.conns[:0] +} + +func (p *nopConnPicker) CloseAllConnections() { +} + +func (pool *hostConnPool) CloseAllConnections() { + if !pool.closed { + return + } + pool.mu.Lock() + println("Closing all connections in a pool") + pool.connPicker.(ConnPickerIntegration).CloseAllConnections() + println("Filling the pool") + pool.mu.Unlock() + pool.fill() +} + +func (p *policyConnPool) CloseAllConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + // close the pools + for _, pool := range p.hostConnPools { + pool.CloseAllConnections() + } +} + +func (s *Session) CloseAllConnections() { + if s.pool != nil { + s.pool.CloseAllConnections() + } +} diff --git a/session_test.go b/session_test.go index 15d496082..38ae7b2b3 100644 --- a/session_test.go +++ b/session_test.go @@ -29,9 +29,12 @@ package gocql import ( "context" + "crypto/tls" "fmt" "net" + "sync" "testing" + "time" ) func TestSessionAPI(t *testing.T) { @@ -424,3 +427,113 @@ func TestRetryType_IgnoreRethrow(t *testing.T) { resetObserved() } } + +type sessionCache struct { + orig tls.ClientSessionCache + values map[string][][]byte + caches map[string][]int64 + valuesLock sync.Mutex +} + +func (c *sessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) { + return c.orig.Get(sessionKey) +} + +func (c *sessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { + ticket, _, err := cs.ResumptionState() + if err != nil { + panic(err) + } + if len(ticket) == 0 { + panic("ticket should not be empty") + } + c.valuesLock.Lock() + c.values[sessionKey] = append(c.values[sessionKey], ticket) + c.valuesLock.Unlock() + c.orig.Put(sessionKey, cs) +} + +func (c *sessionCache) NumberOfTickets() int { + c.valuesLock.Lock() + defer c.valuesLock.Unlock() + total := 0 + for _, tickets := range c.values { + total += len(tickets) + } + return total +} + +func newSessionCache() *sessionCache { + return &sessionCache{ + orig: tls.NewLRUClientSessionCache(1024), + values: make(map[string][][]byte), + caches: make(map[string][]int64), + valuesLock: sync.Mutex{}, + } +} + +func withSessionCache(cache tls.ClientSessionCache) func(config *ClusterConfig) { + return func(config *ClusterConfig) { + config.SslOpts = &SslOptions{ + EnableHostVerification: false, + Config: &tls.Config{ + ClientSessionCache: cache, + InsecureSkipVerify: true, + }, + } + } +} + +func TestTLSTicketResumption(t *testing.T) { + t.Skip("TLS ticket resumption is only supported by 2025.2 and later") + + c := newSessionCache() + session := createSession(t, withSessionCache(c)) + defer session.Close() + + waitAllConnectionsOpened := func() error { + println("wait all connections opened") + defer println("end of wait all connections closed") + endtime := time.Now().UTC().Add(time.Second * 10) + for { + if time.Now().UTC().After(endtime) { + return fmt.Errorf("timed out waiting for all connections opened") + } + missing, err := session.MissingConnections() + if err != nil { + return fmt.Errorf("failed to get missing connections count: %w", err) + } + if missing == 0 { + return nil + } + time.Sleep(time.Millisecond * 100) + } + } + + if err := waitAllConnectionsOpened(); err != nil { + t.Fatal(err) + } + tickets := c.NumberOfTickets() + if tickets == 0 { + t.Fatal("no tickets learned, which means that server does not support TLS tickets") + } + + session.CloseAllConnections() + if err := waitAllConnectionsOpened(); err != nil { + t.Fatal(err) + } + newTickets1 := c.NumberOfTickets() + + session.CloseAllConnections() + if err := waitAllConnectionsOpened(); err != nil { + t.Fatal(err) + } + newTickets2 := c.NumberOfTickets() + + if newTickets1 != tickets { + t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets1, tickets) + } + if newTickets2 != tickets { + t.Fatalf("new tickets learned, it looks like tls tickets where not reused: new %d, was %d", newTickets2, tickets) + } +}