Skip to content

Commit 0298a00

Browse files
committed
1. Updated the way how the driver constructs stmt cache keys. The current code base uses initial keyspace provided by the user to construct the keys. Since proto v5 we also should account for keyspace bounding for a specific query, so the driver should use the bounded keyspace instead of the initial to construct the key.
2. Changed the way how routing key cache keys are constructed to account the keyspace overriding as well.
1 parent 0592a90 commit 0298a00

File tree

3 files changed

+180
-23
lines changed

3 files changed

+180
-23
lines changed

cassandra_test.go

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) {
14831483
defer session.Close()
14841484

14851485
conn := getRandomConn(t, session)
1486-
info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
1486+
info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil, conn.currentKeyspace)
14871487

14881488
if err != nil {
14891489
t.Fatalf("Failed to execute query for preparing statement: %v", err)
@@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) {
26022602
t.Fatalf("failed to create table with error '%v'", err)
26032603
}
26042604

2605-
routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
2605+
routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
26062606
if err != nil {
26072607
t.Fatalf("failed to get routing key info due to error: %v", err)
26082608
}
@@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) {
26262626
}
26272627

26282628
// verify the cache is working
2629-
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
2629+
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
26302630
if err != nil {
26312631
t.Fatalf("failed to get routing key info due to error: %v", err)
26322632
}
@@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) {
26602660
t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
26612661
}
26622662

2663-
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
2663+
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "")
26642664
if err != nil {
26652665
t.Fatalf("failed to get routing key info due to error: %v", err)
26662666
}
@@ -3606,3 +3606,135 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
36063606
require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID)
36073607
require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response)
36083608
}
3609+
3610+
func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) {
3611+
session := createSession(t)
3612+
defer session.Close()
3613+
3614+
const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
3615+
WITH replication = {
3616+
'class' : 'SimpleStrategy',
3617+
'replication_factor' : 1
3618+
}`
3619+
3620+
err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache"))
3621+
if err != nil {
3622+
t.Fatal(err)
3623+
}
3624+
3625+
err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
3626+
if err != nil {
3627+
t.Fatal(err)
3628+
}
3629+
3630+
err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
3631+
if err != nil {
3632+
t.Fatal(err)
3633+
}
3634+
3635+
const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)"
3636+
3637+
// Inserting data via Batch to ensure that batches
3638+
// properly accounts for keyspace overriding
3639+
b1 := session.NewBatch(LoggedBatch)
3640+
b1.Query(insertQuery, 1)
3641+
err = session.ExecuteBatch(b1)
3642+
require.NoError(t, err)
3643+
3644+
b2 := session.NewBatch(LoggedBatch)
3645+
b2.SetKeyspace("gocql_test_stmt_cache")
3646+
b2.Query(insertQuery, 2)
3647+
err = session.ExecuteBatch(b2)
3648+
require.NoError(t, err)
3649+
3650+
var scannedID int
3651+
3652+
const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks"
3653+
3654+
// By default in our test suite session uses gocql_test ks
3655+
err = session.Query(selectStmt).Scan(&scannedID)
3656+
require.NoError(t, err)
3657+
require.Equal(t, 1, scannedID)
3658+
3659+
scannedID = 0
3660+
err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID)
3661+
require.NoError(t, err)
3662+
require.Equal(t, 2, scannedID)
3663+
3664+
session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec()
3665+
}
3666+
3667+
func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
3668+
session := createSession(t)
3669+
defer session.Close()
3670+
3671+
const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
3672+
WITH replication = {
3673+
'class' : 'SimpleStrategy',
3674+
'replication_factor' : 1
3675+
}`
3676+
3677+
err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache"))
3678+
if err != nil {
3679+
t.Fatal(err)
3680+
}
3681+
3682+
err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
3683+
if err != nil {
3684+
t.Fatal(err)
3685+
}
3686+
3687+
err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
3688+
if err != nil {
3689+
t.Fatal(err)
3690+
}
3691+
3692+
getRoutingKeyInfo := func(key string) *routingKeyInfo {
3693+
t.Helper()
3694+
session.routingKeyInfoCache.mu.Lock()
3695+
value, _ := session.routingKeyInfoCache.lru.Get(key)
3696+
session.routingKeyInfoCache.mu.Unlock()
3697+
3698+
inflight := value.(*inflightCachedEntry)
3699+
return inflight.value.(*routingKeyInfo)
3700+
}
3701+
3702+
const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)"
3703+
3704+
// Running batch in default ks gocql_test
3705+
b1 := session.NewBatch(LoggedBatch)
3706+
b1.Query(insertQuery, 1)
3707+
_, err = b1.GetRoutingKey()
3708+
require.NoError(t, err)
3709+
3710+
// Ensuring that the cache contains the query with default ks
3711+
routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt)
3712+
require.Equal(t, "gocql_test", routingKeyInfo1.keyspace)
3713+
3714+
// Running batch in gocql_test_routing_key_cache ks
3715+
b2 := session.NewBatch(LoggedBatch)
3716+
b2.SetKeyspace("gocql_test_routing_key_cache")
3717+
b2.Query(insertQuery, 2)
3718+
_, err = b2.GetRoutingKey()
3719+
require.NoError(t, err)
3720+
3721+
// Ensuring that the cache contains the query with gocql_test_routing_key_cache ks
3722+
routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt)
3723+
require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace)
3724+
3725+
const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?"
3726+
3727+
// Running query in default ks gocql_test
3728+
q1 := session.Query(selectStmt, 1)
3729+
_, err = q1.GetRoutingKey()
3730+
require.NoError(t, err)
3731+
require.Equal(t, "gocql_test", q1.routingInfo.keyspace)
3732+
3733+
// Running query in gocql_test_routing_key_cache ks
3734+
q2 := session.Query(selectStmt, 1)
3735+
_, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey()
3736+
require.NoError(t, err)
3737+
require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace)
3738+
3739+
session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec()
3740+
}

conn.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,8 +1410,8 @@ type inflightPrepare struct {
14101410
preparedStatment *preparedStatment
14111411
}
14121412

1413-
func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) {
1414-
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
1413+
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) {
1414+
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt)
14151415
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
14161416
flight := &inflightPrepare{
14171417
done: make(chan struct{}),
@@ -1486,10 +1486,6 @@ func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tra
14861486
}
14871487
}
14881488

1489-
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
1490-
return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace)
1491-
}
1492-
14931489
func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
14941490
if named, ok := value.(*namedValue); ok {
14951491
dst.name = named.name
@@ -1531,6 +1527,13 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
15311527
params.nowInSeconds = qry.nowInSecondsValue
15321528
}
15331529

1530+
// If a keyspace for the qry is overriden,
1531+
// then we should use it to create stmt cache key
1532+
usedKeyspace := c.currentKeyspace
1533+
if qry.keyspace != "" {
1534+
usedKeyspace = qry.keyspace
1535+
}
1536+
15341537
var (
15351538
frame frameBuilder
15361539
info *preparedStatment
@@ -1539,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
15391542
if !qry.skipPrepare && qry.shouldPrepare() {
15401543
// Prepare all DML queries. Other queries can not be prepared.
15411544
var err error
1542-
info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace)
1545+
info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace)
15431546
if err != nil {
15441547
return &Iter{err: err}
15451548
}
@@ -1584,6 +1587,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
15841587
// Set "keyspace" and "table" property in the query if it is present in preparedMetadata
15851588
qry.routingInfo.mu.Lock()
15861589
qry.routingInfo.keyspace = info.request.keyspace
1590+
if info.request.keyspace == "" {
1591+
qry.routingInfo.keyspace = usedKeyspace
1592+
}
15871593
qry.routingInfo.table = info.request.table
15881594
qry.routingInfo.mu.Unlock()
15891595
} else {
@@ -1616,7 +1622,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
16161622
// If a RESULT/Rows message reports
16171623
// changed resultset metadata with the Metadata_changed flag, the reported new
16181624
// resultset metadata must be used in subsequent executions
1619-
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
1625+
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt)
16201626
oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey)
16211627
if ok {
16221628
newInflight := &inflightPrepare{
@@ -1685,7 +1691,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
16851691
// is not consistent with regards to its schema.
16861692
return iter
16871693
case *RequestErrUnprepared:
1688-
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
1694+
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt)
16891695
c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId)
16901696
return c.executeQuery(ctx, qry)
16911697
case error:
@@ -1767,14 +1773,19 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
17671773
req.nowInSeconds = batch.nowInSeconds
17681774
}
17691775

1776+
usedKeyspace := c.currentKeyspace
1777+
if batch.keyspace != "" {
1778+
usedKeyspace = batch.keyspace
1779+
}
1780+
17701781
stmts := make(map[string]string, len(batch.Entries))
17711782

17721783
for i := 0; i < n; i++ {
17731784
entry := &batch.Entries[i]
17741785
b := &req.statements[i]
17751786

17761787
if len(entry.Args) > 0 || entry.binding != nil {
1777-
info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace)
1788+
info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace)
17781789
if err != nil {
17791790
return &Iter{err: err}
17801791
}
@@ -1836,7 +1847,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
18361847
case *RequestErrUnprepared:
18371848
stmt, found := stmts[string(x.StatementId)]
18381849
if found {
1839-
key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
1850+
key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt)
18401851
c.session.stmtsLRU.evictPreparedID(key, x.StatementId)
18411852
}
18421853
return c.executeBatch(ctx, batch)

session.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -591,11 +591,20 @@ func (s *Session) getConn() *Conn {
591591
return nil
592592
}
593593

594-
// returns routing key indexes and type info
595-
func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
594+
// Returns routing key indexes and type info.
595+
// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace
596+
func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) {
597+
if keyspace == "" {
598+
keyspace = s.cfg.Keyspace
599+
}
600+
601+
routingKeyInfoCacheKey := keyspace + stmt
602+
596603
s.routingKeyInfoCache.mu.Lock()
597604

598-
entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
605+
// Using here keyspace + stmt as a cache key because
606+
// the query keyspace could be overridden via SetKeyspace
607+
entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey)
599608
if cached {
600609
// done accessing the cache
601610
s.routingKeyInfoCache.mu.Unlock()
@@ -619,7 +628,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
619628
inflight := new(inflightCachedEntry)
620629
inflight.wg.Add(1)
621630
defer inflight.wg.Done()
622-
s.routingKeyInfoCache.lru.Add(stmt, inflight)
631+
s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight)
623632
s.routingKeyInfoCache.mu.Unlock()
624633

625634
var (
@@ -635,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
635644
}
636645

637646
// get the query info for the statement
638-
info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
647+
info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace)
639648
if inflight.err != nil {
640649
// don't cache this error
641650
s.routingKeyInfoCache.Remove(stmt)
@@ -651,7 +660,9 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
651660
}
652661

653662
table := info.request.table
654-
keyspace := info.request.keyspace
663+
if info.request.keyspace != "" {
664+
keyspace = info.request.keyspace
665+
}
655666

656667
if len(info.request.pkeyColumns) > 0 {
657668
// proto v4 dont need to calculate primary key columns
@@ -1146,6 +1157,9 @@ func (q *Query) Keyspace() string {
11461157
if q.routingInfo.keyspace != "" {
11471158
return q.routingInfo.keyspace
11481159
}
1160+
if q.keyspace != "" {
1161+
return q.keyspace
1162+
}
11491163

11501164
if q.session == nil {
11511165
return ""
@@ -1177,7 +1191,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
11771191
}
11781192

11791193
// try to determine the routing key
1180-
routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt)
1194+
routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace)
11811195
if err != nil {
11821196
return nil, err
11831197
}
@@ -2009,7 +2023,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) {
20092023
return nil, nil
20102024
}
20112025
// try to determine the routing key
2012-
routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt)
2026+
routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace)
20132027
if err != nil {
20142028
return nil, err
20152029
}

0 commit comments

Comments
 (0)