Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1410,12 +1410,12 @@ type inflightPrepare struct {
}

func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, requestTimeout time.Duration) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
cacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(cacheKey, func(cache *lru.Cache[stmtCacheKey]) *inflightPrepare {
flight := &inflightPrepare{
done: make(chan struct{}),
}
lru.Add(stmtCacheKey, flight)
cache.Add(cacheKey, flight)
return flight
})

Expand All @@ -1436,14 +1436,14 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer,
framer, err := c.exec(c.ctx, prep, tracer, requestTimeout)
if err != nil {
flight.err = err
c.session.stmtsLRU.remove(stmtCacheKey)
c.session.stmtsLRU.remove(cacheKey)
return
}

frame, err := framer.parseFrame()
if err != nil {
flight.err = err
c.session.stmtsLRU.remove(stmtCacheKey)
c.session.stmtsLRU.remove(cacheKey)
return
}

Expand Down Expand Up @@ -1471,7 +1471,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer,
}

if flight.err != nil {
c.session.stmtsLRU.remove(stmtCacheKey)
c.session.stmtsLRU.remove(cacheKey)
}
}()
}
Expand Down Expand Up @@ -1670,8 +1670,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) (iter *Iter) {
// is not consistent with regards to its schema.
return iter
case *RequestErrUnprepared:
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId)
cacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
c.session.stmtsLRU.evictPreparedID(cacheKey, x.StatementId)
return c.executeQuery(ctx, qry)
case error:
return &Iter{err: x, framer: framer}
Expand Down
46 changes: 23 additions & 23 deletions internal/lru/lru.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,70 +44,70 @@ package lru

import "container/list"

// Cache is an LRU cache. It is not safe for concurrent access.
// Cache is a generic LRU cache. It is not safe for concurrent access.
//
// This cache has been forked from github.com/golang/groupcache/lru, but
// specialized with string keys to avoid the allocations caused by wrapping them
// in interface{}.
type Cache struct {
// This cache has been forked from github.com/golang/groupcache/lru and
// generalized with a comparable type parameter to avoid the allocations
// caused by wrapping keys in interface{}.
type Cache[K comparable] struct {
// OnEvicted optionally specifies a callback function to be
// executed when an entry is purged from the cache.
OnEvicted func(key string, value interface{})
OnEvicted func(key K, value interface{})
ll *list.List
cache map[string]*list.Element
cache map[K]*list.Element
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int
}

type entry struct {
type entry[K comparable] struct {
value interface{}
key string
key K
}

// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New(maxEntries int) *Cache {
return &Cache{
func New[K comparable](maxEntries int) *Cache[K] {
return &Cache[K]{
MaxEntries: maxEntries,
ll: list.New(),
cache: make(map[string]*list.Element),
cache: make(map[K]*list.Element),
}
}

// Add adds a value to the cache.
func (c *Cache) Add(key string, value interface{}) {
func (c *Cache[K]) Add(key K, value interface{}) {
if c.cache == nil {
c.cache = make(map[string]*list.Element)
c.cache = make(map[K]*list.Element)
c.ll = list.New()
}
if ee, ok := c.cache[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry).value = value
ee.Value.(*entry[K]).value = value
return
}
ele := c.ll.PushFront(&entry{key: key, value: value})
ele := c.ll.PushFront(&entry[K]{key: key, value: value})
c.cache[key] = ele
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
c.RemoveOldest()
}
}

// Get looks up a key's value from the cache.
func (c *Cache) Get(key string) (value interface{}, ok bool) {
func (c *Cache[K]) Get(key K) (value interface{}, ok bool) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry).value, true
return ele.Value.(*entry[K]).value, true
}
return
}

// Remove removes the provided key from the cache.
func (c *Cache) Remove(key string) bool {
func (c *Cache[K]) Remove(key K) bool {
if c.cache == nil {
return false
}
Expand All @@ -121,7 +121,7 @@ func (c *Cache) Remove(key string) bool {
}

// RemoveOldest removes the oldest item from the cache.
func (c *Cache) RemoveOldest() {
func (c *Cache[K]) RemoveOldest() {
if c.cache == nil {
return
}
Expand All @@ -131,17 +131,17 @@ func (c *Cache) RemoveOldest() {
}
}

func (c *Cache) removeElement(e *list.Element) {
func (c *Cache[K]) removeElement(e *list.Element) {
c.ll.Remove(e)
kv := e.Value.(*entry)
kv := e.Value.(*entry[K])
delete(c.cache, kv.key)
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}

// Len returns the number of items in the cache.
func (c *Cache) Len() int {
func (c *Cache[K]) Len() int {
if c.cache == nil {
return 0
}
Expand Down
99 changes: 97 additions & 2 deletions internal/lru/lru_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License.
package lru

import (
"fmt"
"testing"
)

Expand All @@ -64,7 +65,7 @@ func TestGet(t *testing.T) {
t.Parallel()

for _, tt := range getTests {
lru := New(0)
lru := New[string](0)
lru.Add(tt.keyToAdd, 1234)
val, ok := lru.Get(tt.keyToGet)
if ok != tt.expectedOk {
Expand All @@ -78,7 +79,7 @@ func TestGet(t *testing.T) {
func TestRemove(t *testing.T) {
t.Parallel()

lru := New(0)
lru := New[string](0)
lru.Add("mystring", 1234)
if val, ok := lru.Get("mystring"); !ok {
t.Fatal("TestRemove returned no match")
Expand All @@ -91,3 +92,97 @@ func TestRemove(t *testing.T) {
t.Fatal("TestRemove returned a removed entry")
}
}

// TestStructKey verifies that struct keys work correctly with the generic cache.
func TestStructKey(t *testing.T) {
t.Parallel()

type compositeKey struct {
A string
B string
}

c := New[compositeKey](0)
k1 := compositeKey{A: "ab", B: "cd"}
k2 := compositeKey{A: "a", B: "bcd"}

c.Add(k1, "value1")
c.Add(k2, "value2")

if val, ok := c.Get(k1); !ok || val != "value1" {
t.Fatalf("expected value1 for k1, got %v (ok=%v)", val, ok)
}
if val, ok := c.Get(k2); !ok || val != "value2" {
t.Fatalf("expected value2 for k2, got %v (ok=%v)", val, ok)
}

// Verify that keys with same concatenation but different field boundaries
// are distinct (this was a bug with string concatenation keys).
if c.Len() != 2 {
t.Fatalf("expected 2 entries, got %d", c.Len())
}
}

type stmtKey struct {
hostID string
keyspace string
statement string
}

// BenchmarkStructKeyLookup benchmarks the hot path: looking up a struct key
// in a populated cache.
func BenchmarkStructKeyLookup(b *testing.B) {
c := New[stmtKey](1000)
key := stmtKey{
hostID: "550e8400-e29b-41d4-a716-446655440000",
keyspace: "my_keyspace",
statement: "SELECT id, name, email FROM users WHERE id = ?",
}
c.Add(key, "prepared-id")

b.ReportAllocs()
for i := 0; i < b.N; i++ {
c.Get(key)
}
}

// BenchmarkStringKeyLookup benchmarks the old approach: looking up a
// concatenated string key in a populated cache.
func BenchmarkStringKeyLookup(b *testing.B) {
c := New[string](1000)
key := "550e8400-e29b-41d4-a716-446655440000" + "my_keyspace" + "SELECT id, name, email FROM users WHERE id = ?"
c.Add(key, "prepared-id")

b.ReportAllocs()
for i := 0; i < b.N; i++ {
c.Get(key)
}
}

// BenchmarkStructKeyInsert benchmarks inserting entries with struct keys,
// including eviction when the cache is full.
func BenchmarkStructKeyInsert(b *testing.B) {
c := New[stmtKey](1000)

b.ReportAllocs()
for i := 0; i < b.N; i++ {
k := stmtKey{
hostID: "550e8400-e29b-41d4-a716-446655440000",
keyspace: "my_keyspace",
statement: fmt.Sprintf("SELECT id FROM users WHERE id = %d", i),
}
c.Add(k, "prepared-id")
}
}

// BenchmarkStringKeyInsert benchmarks inserting entries with concatenated
// string keys, including the per-query allocation cost of key construction.
func BenchmarkStringKeyInsert(b *testing.B) {
c := New[string](1000)

b.ReportAllocs()
for i := 0; i < b.N; i++ {
k := fmt.Sprintf("%s%s%s", "550e8400-e29b-41d4-a716-446655440000", "my_keyspace", fmt.Sprintf("SELECT id FROM users WHERE id = %d", i))
c.Add(k, "prepared-id")
}
}
34 changes: 25 additions & 9 deletions prepared_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,20 @@ import (

const defaultMaxPreparedStmts = 1000

// stmtCacheKey is a composite key for the prepared statement cache.
// Using a struct avoids the string concatenation allocation that occurred
// on every query and fixes the theoretical key collision bug where
// different (hostID, keyspace, statement) tuples could produce the same
// concatenated string.
type stmtCacheKey struct {
hostID string
keyspace string
statement string
}

// preparedLRU is the prepared statement cache
type preparedLRU struct {
lru *lru.Cache
lru *lru.Cache[stmtCacheKey]
mu sync.Mutex
}

Expand All @@ -48,19 +59,19 @@ func (p *preparedLRU) clear() {
}
}

func (p *preparedLRU) add(key string, val *inflightPrepare) {
func (p *preparedLRU) add(key stmtCacheKey, val *inflightPrepare) {
p.mu.Lock()
defer p.mu.Unlock()
p.lru.Add(key, val)
}

func (p *preparedLRU) remove(key string) bool {
func (p *preparedLRU) remove(key stmtCacheKey) bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.lru.Remove(key)
}

func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
func (p *preparedLRU) execIfMissing(key stmtCacheKey, fn func(cache *lru.Cache[stmtCacheKey]) *inflightPrepare) (*inflightPrepare, bool) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -72,12 +83,18 @@ func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *infligh
return fn(p.lru), false
}

func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string {
// TODO: we should just use a struct for the key in the map
return hostID + keyspace + statement
// keyFor constructs a zero-allocation composite cache key from the given
// components. The returned struct references the original strings without
// copying, so no heap allocation occurs.
func (p *preparedLRU) keyFor(hostID, keyspace, statement string) stmtCacheKey {
return stmtCacheKey{
hostID: hostID,
keyspace: keyspace,
statement: statement,
}
}

func (p *preparedLRU) evictPreparedID(key string, id []byte) {
func (p *preparedLRU) evictPreparedID(key stmtCacheKey, id []byte) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -98,5 +115,4 @@ func (p *preparedLRU) evictPreparedID(key string, id []byte) {
}
default:
}

}
6 changes: 3 additions & 3 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func newSessionCommon(cfg ClusterConfig) (*Session, error) {
prefetch: 0.25,
cfg: cfg,
pageSize: cfg.PageSize,
stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)},
stmtsLRU: &preparedLRU{lru: lru.New[stmtCacheKey](cfg.MaxPreparedStmts)},
connectObserver: cfg.ConnectObserver,
ctx: ctx,
cancel: cancel,
Expand Down Expand Up @@ -166,7 +166,7 @@ func newSessionCommon(cfg ClusterConfig) (*Session, error) {
s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, s.logger)
s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, s.logger)

s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
s.routingKeyInfoCache.lru = lru.New[string](cfg.MaxRoutingKeyInfo)

s.hostSource = &ringDescriber{cfg: &s.cfg, logger: s.logger}
s.ringRefresher = debounce.NewRefreshDebouncer(debounce.RingRefreshDebounceTime, func() error {
Expand Down Expand Up @@ -2505,7 +2505,7 @@ func (c ColumnInfo) String() string {

// routing key indexes LRU cache
type routingKeyInfoLRU struct {
lru *lru.Cache
lru *lru.Cache[string]
mu sync.Mutex
}

Expand Down
Loading