From 41576c9487b9c64937ef54f6f92653e7dbd9cb8d Mon Sep 17 00:00:00 2001 From: Jen-Hung Yu Date: Tue, 2 Sep 2025 03:00:30 +0800 Subject: [PATCH] feat: replace custom LRU cache with otter cache library MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace internal LRU cache implementation for prepared statements and routing key info with the more efficient otter cache library. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- conn.go | 114 +++++++++--------------- go.mod | 14 +-- go.sum | 15 ++-- prepared_cache.go | 73 ++-------------- session.go | 219 ++++++++++++++++------------------------------ 5 files changed, 146 insertions(+), 289 deletions(-) diff --git a/conn.go b/conn.go index 3daca6250..80c768b53 100644 --- a/conn.go +++ b/conn.go @@ -39,8 +39,8 @@ import ( "sync/atomic" "time" - "github.com/gocql/gocql/internal/lru" "github.com/gocql/gocql/internal/streams" + "github.com/maypok86/otter/v2" ) var ( @@ -1236,78 +1236,50 @@ type inflightPrepare struct { } func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) - flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { - flight := &inflightPrepare{ - done: make(chan struct{}), + loader := otter.LoaderFunc[string, *preparedStatment](func(ctx context.Context, key string) (*preparedStatment, error) { + prep := &writePrepareFrame{ + statement: stmt, + } + if c.version > protoVersion4 { + prep.keyspace = c.currentKeyspace } - lru.Add(stmtCacheKey, flight) - return flight - }) - - if !ok { - go func() { - defer close(flight.done) - - prep := &writePrepareFrame{ - statement: stmt, - } - if c.version > protoVersion4 { - prep.keyspace = c.currentKeyspace - } - - // we won the race to do the load, if our context is canceled we shouldnt - // stop the load as other callers are waiting for it but this caller should get - // their context cancelled error. - framer, err := c.exec(c.ctx, prep, tracer) - if err != nil { - flight.err = err - c.session.stmtsLRU.remove(stmtCacheKey) - return - } - frame, err := framer.parseFrame() - if err != nil { - flight.err = err - c.session.stmtsLRU.remove(stmtCacheKey) - return - } + framer, err := c.exec(c.ctx, prep, tracer) + if err != nil { + return nil, err + } - // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated - // everytime we need to parse a frame. - if len(framer.traceID) > 0 && tracer != nil { - tracer.Trace(framer.traceID) - } + frame, err := framer.parseFrame() + if err != nil { + return nil, err + } - switch x := frame.(type) { - case *resultPreparedFrame: - flight.preparedStatment = &preparedStatment{ - // defensively copy as we will recycle the underlying buffer after we - // return. - id: copyBytes(x.preparedID), - // the type info's should _not_ have a reference to the framers read buffer, - // therefore we can just copy them directly. - request: x.reqMeta, - response: x.respMeta, - } - case error: - flight.err = x - default: - flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x) - } + // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated + // everytime we need to parse a frame. + if len(framer.traceID) > 0 && tracer != nil { + tracer.Trace(framer.traceID) + } - if flight.err != nil { - c.session.stmtsLRU.remove(stmtCacheKey) - } - }() - } + switch x := frame.(type) { + case *resultPreparedFrame: + return &preparedStatment{ + // defensively copy as we will recycle the underlying buffer after we + // return. + id: copyBytes(x.preparedID), + // the type info's should _not_ have a reference to the framers read buffer, + // therefore we can just copy them directly. + request: x.reqMeta, + response: x.respMeta, + }, nil + case error: + return nil, x + default: + return nil, NewErrProtocol("Unknown type in response to prepare frame: %s", x) + } + }) + stmtCacheKey := keyForPreparedStatement(c.host.HostID(), c.currentKeyspace, stmt) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-flight.done: - return flight.preparedStatment, flight.err - } + return c.session.stmtsLRU.Get(ctx, stmtCacheKey, loader) } func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { @@ -1477,8 +1449,8 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) - c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) + stmtCacheKey := keyForPreparedStatement(c.host.HostID(), c.currentKeyspace, qry.stmt) + c.session.stmtsLRU.Invalidate(stmtCacheKey) return c.executeQuery(ctx, qry) case error: return &Iter{err: x, framer: framer} @@ -1623,8 +1595,8 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { - key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) - c.session.stmtsLRU.evictPreparedID(key, x.StatementId) + key := keyForPreparedStatement(c.host.HostID(), c.currentKeyspace, stmt) + c.session.stmtsLRU.Invalidate(key) } return c.executeBatch(ctx, batch) case *resultRowsFrame: diff --git a/go.mod b/go.mod index 0aea881ec..54b45a845 100644 --- a/go.mod +++ b/go.mod @@ -18,13 +18,17 @@ module github.com/gocql/gocql require ( - github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect - github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed - github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/maypok86/otter/v2 v2.2.1 gopkg.in/inf.v0 v0.9.1 ) -go 1.13 +require ( + github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/kr/pretty v0.1.0 // indirect + golang.org/x/sys v0.34.0 // indirect +) + +go 1.24.6 diff --git a/go.sum b/go.sum index 2e3892bcb..1f0f5e42d 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= @@ -13,10 +13,15 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/maypok86/otter/v2 v2.2.1 h1:hnGssisMFkdisYcvQ8L019zpYQcdtPse+g0ps2i7cfI= +github.com/maypok86/otter/v2 v2.2.1/go.mod h1:1NKY9bY+kB5jwCXBJfE59u+zAwOt6C7ni1FTlFFMqVs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/prepared_cache.go b/prepared_cache.go index 3fd256d33..a3806b3f8 100644 --- a/prepared_cache.go +++ b/prepared_cache.go @@ -25,78 +25,19 @@ package gocql import ( - "bytes" - "sync" - - "github.com/gocql/gocql/internal/lru" + "github.com/maypok86/otter/v2" ) const defaultMaxPreparedStmts = 1000 -// preparedLRU is the prepared statement cache -type preparedLRU struct { - mu sync.Mutex - lru *lru.Cache -} - -func (p *preparedLRU) clear() { - p.mu.Lock() - defer p.mu.Unlock() - - for p.lru.Len() > 0 { - p.lru.RemoveOldest() - } -} - -func (p *preparedLRU) add(key string, val *inflightPrepare) { - p.mu.Lock() - defer p.mu.Unlock() - p.lru.Add(key, val) -} - -func (p *preparedLRU) remove(key string) bool { - p.mu.Lock() - defer p.mu.Unlock() - return p.lru.Remove(key) -} - -func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) { - p.mu.Lock() - defer p.mu.Unlock() - - val, ok := p.lru.Get(key) - if ok { - return val.(*inflightPrepare), true - } - - return fn(p.lru), false +func NewPreparedLRU(size int) *otter.Cache[string, *preparedStatment] { + return otter.Must(&otter.Options[string, *preparedStatment]{ + InitialCapacity: size, + MaximumSize: size, + }) } -func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string { +func keyForPreparedStatement(hostID, keyspace, statement string) string { // TODO: we should just use a struct for the key in the map return hostID + keyspace + statement } - -func (p *preparedLRU) evictPreparedID(key string, id []byte) { - p.mu.Lock() - defer p.mu.Unlock() - - val, ok := p.lru.Get(key) - if !ok { - return - } - - ifp, ok := val.(*inflightPrepare) - if !ok { - return - } - - select { - case <-ifp.done: - if bytes.Equal(id, ifp.preparedStatment.id) { - p.lru.Remove(key) - } - default: - } - -} diff --git a/session.go b/session.go index a600b95f3..ef162004c 100644 --- a/session.go +++ b/session.go @@ -38,7 +38,7 @@ import ( "time" "unicode" - "github.com/gocql/gocql/internal/lru" + "github.com/maypok86/otter/v2" ) // Session is the interface used by users to interact with the database. @@ -54,7 +54,7 @@ type Session struct { cons Consistency pageSize int prefetch float64 - routingKeyInfoCache routingKeyInfoLRU + routingKeyInfoCache *otter.Cache[string, *routingKeyInfo] schemaDescriber *schemaDescriber trace Tracer queryObserver QueryObserver @@ -64,7 +64,7 @@ type Session struct { streamObserver StreamObserver hostSource *ringDescriber ringRefresher *refreshDebouncer - stmtsLRU *preparedLRU + stmtsLRU *otter.Cache[string, *preparedStatment] connCfg *ConnConfig @@ -152,7 +152,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) { prefetch: 0.25, cfg: cfg, pageSize: cfg.PageSize, - stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)}, + stmtsLRU: NewPreparedLRU(cfg.MaxPreparedStmts), connectObserver: cfg.ConnectObserver, ctx: ctx, cancel: cancel, @@ -164,7 +164,7 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, s.logger) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, s.logger) - s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) + s.routingKeyInfoCache = NewRoutingKeyInfoLRU(cfg.MaxRoutingKeyInfo) s.hostSource = &ringDescriber{session: s} s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s.hostSource) }) @@ -593,138 +593,95 @@ func (s *Session) getConn() *Conn { // returns routing key indexes and type info func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) { - s.routingKeyInfoCache.mu.Lock() - - entry, cached := s.routingKeyInfoCache.lru.Get(stmt) - if cached { - // done accessing the cache - s.routingKeyInfoCache.mu.Unlock() - // the entry is an inflight struct similar to that used by - // Conn to prepare statements - inflight := entry.(*inflightCachedEntry) - - // wait for any inflight work - inflight.wg.Wait() - - if inflight.err != nil { - return nil, inflight.err + loader := otter.LoaderFunc[string, *routingKeyInfo](func(ctx context.Context, key string) (*routingKeyInfo, error) { + conn := s.getConn() + if conn == nil { + return nil, errors.New("gocql: unable to fetch prepared info: no connection available") } - key, _ := inflight.value.(*routingKeyInfo) - - return key, nil - } + // get the query info for the statement + info, err := conn.prepareStatement(ctx, key, nil) + if err != nil { + return nil, err + } - // create a new inflight entry while the data is created - inflight := new(inflightCachedEntry) - inflight.wg.Add(1) - defer inflight.wg.Done() - s.routingKeyInfoCache.lru.Add(stmt, inflight) - s.routingKeyInfoCache.mu.Unlock() + // TODO: it would be nice to mark hosts here but as we are not using the policies + // to fetch hosts we cant - var ( - info *preparedStatment - partitionKey []*ColumnMetadata - ) + if info.request.colCount == 0 { + // no arguments, no routing key, and no error + return nil, nil + } - conn := s.getConn() - if conn == nil { - // TODO: better error? - inflight.err = errors.New("gocql: unable to fetch prepared info: no connection available") - return nil, inflight.err - } + table := info.request.table + keyspace := info.request.keyspace - // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil) - if inflight.err != nil { - // don't cache this error - s.routingKeyInfoCache.Remove(stmt) - return nil, inflight.err - } + if len(info.request.pkeyColumns) > 0 { + // proto v4 dont need to calculate primary key columns + types := make([]TypeInfo, len(info.request.pkeyColumns)) + for i, col := range info.request.pkeyColumns { + types[i] = info.request.columns[col].TypeInfo + } - // TODO: it would be nice to mark hosts here but as we are not using the policies - // to fetch hosts we cant + routingKeyInfo := &routingKeyInfo{ + indexes: info.request.pkeyColumns, + types: types, + keyspace: keyspace, + table: table, + } - if info.request.colCount == 0 { - // no arguments, no routing key, and no error - return nil, nil - } + return routingKeyInfo, nil + } - table := info.request.table - keyspace := info.request.keyspace + var keyspaceMetadata *KeyspaceMetadata + keyspaceMetadata, err = s.KeyspaceMetadata(info.request.columns[0].Keyspace) + if err != nil { + return nil, err + } - if len(info.request.pkeyColumns) > 0 { - // proto v4 dont need to calculate primary key columns - types := make([]TypeInfo, len(info.request.pkeyColumns)) - for i, col := range info.request.pkeyColumns { - types[i] = info.request.columns[col].TypeInfo + tableMetadata, found := keyspaceMetadata.Tables[table] + if !found { + // unlikely that the statement could be prepared and the metadata for + // the table couldn't be found, but this may indicate either a bug + // in the metadata code, or that the table was just dropped. + return nil, ErrNoMetadata } + partitionKey := tableMetadata.PartitionKey + + size := len(partitionKey) routingKeyInfo := &routingKeyInfo{ - indexes: info.request.pkeyColumns, - types: types, + indexes: make([]int, size), + types: make([]TypeInfo, size), keyspace: keyspace, table: table, } - inflight.value = routingKeyInfo - return routingKeyInfo, nil - } - - var keyspaceMetadata *KeyspaceMetadata - keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.columns[0].Keyspace) - if inflight.err != nil { - // don't cache this error - s.routingKeyInfoCache.Remove(stmt) - return nil, inflight.err - } - - tableMetadata, found := keyspaceMetadata.Tables[table] - if !found { - // unlikely that the statement could be prepared and the metadata for - // the table couldn't be found, but this may indicate either a bug - // in the metadata code, or that the table was just dropped. - inflight.err = ErrNoMetadata - // don't cache this error - s.routingKeyInfoCache.Remove(stmt) - return nil, inflight.err - } - - partitionKey = tableMetadata.PartitionKey - - size := len(partitionKey) - routingKeyInfo := &routingKeyInfo{ - indexes: make([]int, size), - types: make([]TypeInfo, size), - keyspace: keyspace, - table: table, - } - - for keyIndex, keyColumn := range partitionKey { - // set an indicator for checking if the mapping is missing - routingKeyInfo.indexes[keyIndex] = -1 - - // find the column in the query info - for argIndex, boundColumn := range info.request.columns { - if keyColumn.Name == boundColumn.Name { - // there may be many such bound columns, pick the first - routingKeyInfo.indexes[keyIndex] = argIndex - routingKeyInfo.types[keyIndex] = boundColumn.TypeInfo - break + for keyIndex, keyColumn := range partitionKey { + // set an indicator for checking if the mapping is missing + routingKeyInfo.indexes[keyIndex] = -1 + + // find the column in the query info + for argIndex, boundColumn := range info.request.columns { + if keyColumn.Name == boundColumn.Name { + // there may be many such bound columns, pick the first + routingKeyInfo.indexes[keyIndex] = argIndex + routingKeyInfo.types[keyIndex] = boundColumn.TypeInfo + break + } } - } - if routingKeyInfo.indexes[keyIndex] == -1 { - // missing a routing key column mapping - // no routing key, and no error - return nil, nil + if routingKeyInfo.indexes[keyIndex] == -1 { + // missing a routing key column mapping + // no routing key, and no error + return nil, nil + } } - } - // cache this result - inflight.value = routingKeyInfo + return routingKeyInfo, nil + }) - return routingKeyInfo, nil + return s.routingKeyInfoCache.Get(ctx, stmt, loader) } func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { @@ -2068,10 +2025,11 @@ func (c ColumnInfo) String() string { return fmt.Sprintf("[column keyspace=%s table=%s name=%s type=%v]", c.Keyspace, c.Table, c.Name, c.TypeInfo) } -// routing key indexes LRU cache -type routingKeyInfoLRU struct { - lru *lru.Cache - mu sync.Mutex +func NewRoutingKeyInfoLRU(size int) *otter.Cache[string, *routingKeyInfo] { + return otter.Must(&otter.Options[string, *routingKeyInfo]{ + InitialCapacity: size, + MaximumSize: size, + }) } type routingKeyInfo struct { @@ -2085,29 +2043,6 @@ func (r *routingKeyInfo) String() string { return fmt.Sprintf("routing key index=%v types=%v", r.indexes, r.types) } -func (r *routingKeyInfoLRU) Remove(key string) { - r.mu.Lock() - r.lru.Remove(key) - r.mu.Unlock() -} - -// Max adjusts the maximum size of the cache and cleans up the oldest records if -// the new max is lower than the previous value. Not concurrency safe. -func (r *routingKeyInfoLRU) Max(max int) { - r.mu.Lock() - for r.lru.Len() > max { - r.lru.RemoveOldest() - } - r.lru.MaxEntries = max - r.mu.Unlock() -} - -type inflightCachedEntry struct { - wg sync.WaitGroup - err error - value interface{} -} - // Tracer is the interface implemented by query tracers. Tracers have the // ability to obtain a detailed event log of all events that happened during // the execution of a query from Cassandra. Gathering this information might