diff --git a/conn.go b/conn.go index 18e180c67..8cb48684c 100644 --- a/conn.go +++ b/conn.go @@ -1410,12 +1410,12 @@ type inflightPrepare struct { } func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, requestTimeout time.Duration) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) - flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { + cacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) + flight, ok := c.session.stmtsLRU.execIfMissing(cacheKey, func(cache *lru.Cache[stmtCacheKey]) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), } - lru.Add(stmtCacheKey, flight) + cache.Add(cacheKey, flight) return flight }) @@ -1436,14 +1436,14 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, framer, err := c.exec(c.ctx, prep, tracer, requestTimeout) if err != nil { flight.err = err - c.session.stmtsLRU.remove(stmtCacheKey) + c.session.stmtsLRU.remove(cacheKey) return } frame, err := framer.parseFrame() if err != nil { flight.err = err - c.session.stmtsLRU.remove(stmtCacheKey) + c.session.stmtsLRU.remove(cacheKey) return } @@ -1471,7 +1471,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, } if flight.err != nil { - c.session.stmtsLRU.remove(stmtCacheKey) + c.session.stmtsLRU.remove(cacheKey) } }() } @@ -1670,8 +1670,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) (iter *Iter) { // is not consistent with regards to its schema. 
return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) - c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) + cacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) + c.session.stmtsLRU.evictPreparedID(cacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: return &Iter{err: x, framer: framer} diff --git a/internal/lru/lru.go b/internal/lru/lru.go index 899d778c6..6ddd9c721 100644 --- a/internal/lru/lru.go +++ b/internal/lru/lru.go @@ -44,50 +44,50 @@ package lru import "container/list" -// Cache is an LRU cache. It is not safe for concurrent access. +// Cache is a generic LRU cache. It is not safe for concurrent access. // -// This cache has been forked from github.com/golang/groupcache/lru, but -// specialized with string keys to avoid the allocations caused by wrapping them -// in interface{}. -type Cache struct { +// This cache has been forked from github.com/golang/groupcache/lru and +// generalized with a comparable type parameter to avoid the allocations +// caused by wrapping keys in interface{}. +type Cache[K comparable] struct { // OnEvicted optionally specifies a callback function to be // executed when an entry is purged from the cache. - OnEvicted func(key string, value interface{}) + OnEvicted func(key K, value interface{}) ll *list.List - cache map[string]*list.Element + cache map[K]*list.Element // MaxEntries is the maximum number of cache entries before // an item is evicted. Zero means no limit. MaxEntries int } -type entry struct { +type entry[K comparable] struct { value interface{} - key string + key K } // New creates a new Cache. // If maxEntries is zero, the cache has no limit and it's assumed // that eviction is done by the caller. 
-func New(maxEntries int) *Cache { - return &Cache{ +func New[K comparable](maxEntries int) *Cache[K] { + return &Cache[K]{ MaxEntries: maxEntries, ll: list.New(), - cache: make(map[string]*list.Element), + cache: make(map[K]*list.Element), } } // Add adds a value to the cache. -func (c *Cache) Add(key string, value interface{}) { +func (c *Cache[K]) Add(key K, value interface{}) { if c.cache == nil { - c.cache = make(map[string]*list.Element) + c.cache = make(map[K]*list.Element) c.ll = list.New() } if ee, ok := c.cache[key]; ok { c.ll.MoveToFront(ee) - ee.Value.(*entry).value = value + ee.Value.(*entry[K]).value = value return } - ele := c.ll.PushFront(&entry{key: key, value: value}) + ele := c.ll.PushFront(&entry[K]{key: key, value: value}) c.cache[key] = ele if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries { c.RemoveOldest() @@ -95,19 +95,19 @@ func (c *Cache) Add(key string, value interface{}) { } // Get looks up a key's value from the cache. -func (c *Cache) Get(key string) (value interface{}, ok bool) { +func (c *Cache[K]) Get(key K) (value interface{}, ok bool) { if c.cache == nil { return } if ele, hit := c.cache[key]; hit { c.ll.MoveToFront(ele) - return ele.Value.(*entry).value, true + return ele.Value.(*entry[K]).value, true } return } // Remove removes the provided key from the cache. -func (c *Cache) Remove(key string) bool { +func (c *Cache[K]) Remove(key K) bool { if c.cache == nil { return false } @@ -121,7 +121,7 @@ func (c *Cache) Remove(key string) bool { } // RemoveOldest removes the oldest item from the cache. 
-func (c *Cache) RemoveOldest() { +func (c *Cache[K]) RemoveOldest() { if c.cache == nil { return } @@ -131,9 +131,9 @@ func (c *Cache) RemoveOldest() { } } -func (c *Cache) removeElement(e *list.Element) { +func (c *Cache[K]) removeElement(e *list.Element) { c.ll.Remove(e) - kv := e.Value.(*entry) + kv := e.Value.(*entry[K]) delete(c.cache, kv.key) if c.OnEvicted != nil { c.OnEvicted(kv.key, kv.value) @@ -141,7 +141,7 @@ func (c *Cache) removeElement(e *list.Element) { } // Len returns the number of items in the cache. -func (c *Cache) Len() int { +func (c *Cache[K]) Len() int { if c.cache == nil { return 0 } diff --git a/internal/lru/lru_test.go b/internal/lru/lru_test.go index 8fb64f408..10a1daff3 100644 --- a/internal/lru/lru_test.go +++ b/internal/lru/lru_test.go @@ -45,6 +45,7 @@ limitations under the License. package lru import ( + "fmt" "testing" ) @@ -64,7 +65,7 @@ func TestGet(t *testing.T) { t.Parallel() for _, tt := range getTests { - lru := New(0) + lru := New[string](0) lru.Add(tt.keyToAdd, 1234) val, ok := lru.Get(tt.keyToGet) if ok != tt.expectedOk { @@ -78,7 +79,7 @@ func TestGet(t *testing.T) { func TestRemove(t *testing.T) { t.Parallel() - lru := New(0) + lru := New[string](0) lru.Add("mystring", 1234) if val, ok := lru.Get("mystring"); !ok { t.Fatal("TestRemove returned no match") @@ -91,3 +92,97 @@ func TestRemove(t *testing.T) { t.Fatal("TestRemove returned a removed entry") } } + +// TestStructKey verifies that struct keys work correctly with the generic cache. 
+func TestStructKey(t *testing.T) { + t.Parallel() + + type compositeKey struct { + A string + B string + } + + c := New[compositeKey](0) + k1 := compositeKey{A: "ab", B: "cd"} + k2 := compositeKey{A: "a", B: "bcd"} + + c.Add(k1, "value1") + c.Add(k2, "value2") + + if val, ok := c.Get(k1); !ok || val != "value1" { + t.Fatalf("expected value1 for k1, got %v (ok=%v)", val, ok) + } + if val, ok := c.Get(k2); !ok || val != "value2" { + t.Fatalf("expected value2 for k2, got %v (ok=%v)", val, ok) + } + + // Verify that keys with same concatenation but different field boundaries + // are distinct (this was a bug with string concatenation keys). + if c.Len() != 2 { + t.Fatalf("expected 2 entries, got %d", c.Len()) + } +} + +type stmtKey struct { + hostID string + keyspace string + statement string +} + +// BenchmarkStructKeyLookup benchmarks the hot path: looking up a struct key +// in a populated cache. +func BenchmarkStructKeyLookup(b *testing.B) { + c := New[stmtKey](1000) + key := stmtKey{ + hostID: "550e8400-e29b-41d4-a716-446655440000", + keyspace: "my_keyspace", + statement: "SELECT id, name, email FROM users WHERE id = ?", + } + c.Add(key, "prepared-id") + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + c.Get(key) + } +} + +// BenchmarkStringKeyLookup benchmarks the old approach: looking up a +// concatenated string key in a populated cache. +func BenchmarkStringKeyLookup(b *testing.B) { + c := New[string](1000) + key := "550e8400-e29b-41d4-a716-446655440000" + "my_keyspace" + "SELECT id, name, email FROM users WHERE id = ?" + c.Add(key, "prepared-id") + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + c.Get(key) + } +} + +// BenchmarkStructKeyInsert benchmarks inserting entries with struct keys, +// including eviction when the cache is full. 
+func BenchmarkStructKeyInsert(b *testing.B) { + c := New[stmtKey](1000) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + k := stmtKey{ + hostID: "550e8400-e29b-41d4-a716-446655440000", + keyspace: "my_keyspace", + statement: fmt.Sprintf("SELECT id FROM users WHERE id = %d", i), + } + c.Add(k, "prepared-id") + } +} + +// BenchmarkStringKeyInsert benchmarks inserting entries with concatenated +// string keys, including the per-query allocation cost of key construction. +func BenchmarkStringKeyInsert(b *testing.B) { + c := New[string](1000) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + k := "550e8400-e29b-41d4-a716-446655440000" + "my_keyspace" + fmt.Sprintf("SELECT id FROM users WHERE id = %d", i) + c.Add(k, "prepared-id") + } +} diff --git a/prepared_cache.go b/prepared_cache.go index d63000dd8..95d3de013 100644 --- a/prepared_cache.go +++ b/prepared_cache.go @@ -33,9 +33,20 @@ import ( const defaultMaxPreparedStmts = 1000 +// stmtCacheKey is a composite key for the prepared statement cache. +// Using a struct avoids the string concatenation allocation that occurred +// on every query and fixes the theoretical key collision bug where +// different (hostID, keyspace, statement) tuples could produce the same +// concatenated string. 
+type stmtCacheKey struct { + hostID string + keyspace string + statement string +} + // preparedLRU is the prepared statement cache type preparedLRU struct { - lru *lru.Cache + lru *lru.Cache[stmtCacheKey] mu sync.Mutex } @@ -48,19 +59,19 @@ func (p *preparedLRU) clear() { } } -func (p *preparedLRU) add(key string, val *inflightPrepare) { +func (p *preparedLRU) add(key stmtCacheKey, val *inflightPrepare) { p.mu.Lock() defer p.mu.Unlock() p.lru.Add(key, val) } -func (p *preparedLRU) remove(key string) bool { +func (p *preparedLRU) remove(key stmtCacheKey) bool { p.mu.Lock() defer p.mu.Unlock() return p.lru.Remove(key) } -func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) { +func (p *preparedLRU) execIfMissing(key stmtCacheKey, fn func(cache *lru.Cache[stmtCacheKey]) *inflightPrepare) (*inflightPrepare, bool) { p.mu.Lock() defer p.mu.Unlock() @@ -72,12 +83,18 @@ func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *infligh return fn(p.lru), false } -func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string { - // TODO: we should just use a struct for the key in the map - return hostID + keyspace + statement +// keyFor constructs a zero-allocation composite cache key from the given +// components. The returned struct references the original strings without +// copying, so no heap allocation occurs. 
+func (p *preparedLRU) keyFor(hostID, keyspace, statement string) stmtCacheKey { + return stmtCacheKey{ + hostID: hostID, + keyspace: keyspace, + statement: statement, + } } -func (p *preparedLRU) evictPreparedID(key string, id []byte) { +func (p *preparedLRU) evictPreparedID(key stmtCacheKey, id []byte) { p.mu.Lock() defer p.mu.Unlock() @@ -98,5 +115,4 @@ func (p *preparedLRU) evictPreparedID(key string, id []byte) { } default: } - } diff --git a/session.go b/session.go index 878f744ec..b26195c8d 100644 --- a/session.go +++ b/session.go @@ -134,7 +134,7 @@ func newSessionCommon(cfg ClusterConfig) (*Session, error) { prefetch: 0.25, cfg: cfg, pageSize: cfg.PageSize, - stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)}, + stmtsLRU: &preparedLRU{lru: lru.New[stmtCacheKey](cfg.MaxPreparedStmts)}, connectObserver: cfg.ConnectObserver, ctx: ctx, cancel: cancel, @@ -166,7 +166,7 @@ func newSessionCommon(cfg ClusterConfig) (*Session, error) { s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, s.logger) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, s.logger) - s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) + s.routingKeyInfoCache.lru = lru.New[string](cfg.MaxRoutingKeyInfo) s.hostSource = &ringDescriber{cfg: &s.cfg, logger: s.logger} s.ringRefresher = debounce.NewRefreshDebouncer(debounce.RingRefreshDebounceTime, func() error { @@ -2505,7 +2505,7 @@ func (c ColumnInfo) String() string { // routing key indexes LRU cache type routingKeyInfoLRU struct { - lru *lru.Cache + lru *lru.Cache[string] mu sync.Mutex }