@@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) {
14831483 defer session .Close ()
14841484
14851485 conn := getRandomConn (t , session )
1486- info , err := conn .prepareStatement (context .Background (), "SELECT release_version, host_id FROM system.local WHERE key = ?" , nil )
1486+ info , err := conn .prepareStatement (context .Background (), "SELECT release_version, host_id FROM system.local WHERE key = ?" , nil , conn . currentKeyspace )
14871487
14881488 if err != nil {
14891489 t .Fatalf ("Failed to execute query for preparing statement: %v" , err )
@@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) {
26022602 t .Fatalf ("failed to create table with error '%v'" , err )
26032603 }
26042604
2605- routingKeyInfo , err := session .routingKeyInfo (context .Background (), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?" )
2605+ routingKeyInfo , err := session .routingKeyInfo (context .Background (), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?" , "" )
26062606 if err != nil {
26072607 t .Fatalf ("failed to get routing key info due to error: %v" , err )
26082608 }
@@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) {
26262626 }
26272627
26282628 // verify the cache is working
2629- routingKeyInfo , err = session .routingKeyInfo (context .Background (), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?" )
2629+ routingKeyInfo , err = session .routingKeyInfo (context .Background (), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?" , "" )
26302630 if err != nil {
26312631 t .Fatalf ("failed to get routing key info due to error: %v" , err )
26322632 }
@@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) {
26602660 t .Errorf ("Expected routing key %v but was %v" , expectedRoutingKey , routingKey )
26612661 }
26622662
2663- routingKeyInfo , err = session .routingKeyInfo (context .Background (), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?" )
2663+ routingKeyInfo , err = session .routingKeyInfo (context .Background (), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?" , "" )
26642664 if err != nil {
26652665 t .Fatalf ("failed to get routing key info due to error: %v" , err )
26662666 }
@@ -3606,3 +3606,135 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
36063606 require .Equal (t , preparedStatementAfterTableAltering2 .resultMetadataID , preparedStatementAfterTableAltering3 .resultMetadataID )
36073607 require .Equal (t , preparedStatementAfterTableAltering2 .response , preparedStatementAfterTableAltering3 .response )
36083608}
3609+
3610+ func TestStmtCacheUsesOverriddenKeyspace (t * testing.T ) {
3611+ session := createSession (t )
3612+ defer session .Close ()
3613+
3614+ const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
3615+ WITH replication = {
3616+ 'class' : 'SimpleStrategy',
3617+ 'replication_factor' : 1
3618+ }`
3619+
3620+ err := createTable (session , fmt .Sprintf (createKeyspaceStmt , "gocql_test_stmt_cache" ))
3621+ if err != nil {
3622+ t .Fatal (err )
3623+ }
3624+
3625+ err = createTable (session , "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))" )
3626+ if err != nil {
3627+ t .Fatal (err )
3628+ }
3629+
3630+ err = createTable (session , "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))" )
3631+ if err != nil {
3632+ t .Fatal (err )
3633+ }
3634+
3635+ const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)"
3636+
3637+ // Inserting data via Batch to ensure that batches
3638+ // properly accounts for keyspace overriding
3639+ b1 := session .NewBatch (LoggedBatch )
3640+ b1 .Query (insertQuery , 1 )
3641+ err = session .ExecuteBatch (b1 )
3642+ require .NoError (t , err )
3643+
3644+ b2 := session .NewBatch (LoggedBatch )
3645+ b2 .SetKeyspace ("gocql_test_stmt_cache" )
3646+ b2 .Query (insertQuery , 2 )
3647+ err = session .ExecuteBatch (b2 )
3648+ require .NoError (t , err )
3649+
3650+ var scannedID int
3651+
3652+ const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks"
3653+
3654+ // By default in our test suite session uses gocql_test ks
3655+ err = session .Query (selectStmt ).Scan (& scannedID )
3656+ require .NoError (t , err )
3657+ require .Equal (t , 1 , scannedID )
3658+
3659+ scannedID = 0
3660+ err = session .Query (selectStmt ).SetKeyspace ("gocql_test_stmt_cache" ).Scan (& scannedID )
3661+ require .NoError (t , err )
3662+ require .Equal (t , 2 , scannedID )
3663+
3664+ session .Query ("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache" ).Exec ()
3665+ }
3666+
3667+ func TestRoutingKeyCacheUsesOverriddenKeyspace (t * testing.T ) {
3668+ session := createSession (t )
3669+ defer session .Close ()
3670+
3671+ const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
3672+ WITH replication = {
3673+ 'class' : 'SimpleStrategy',
3674+ 'replication_factor' : 1
3675+ }`
3676+
3677+ err := createTable (session , fmt .Sprintf (createKeyspaceStmt , "gocql_test_routing_key_cache" ))
3678+ if err != nil {
3679+ t .Fatal (err )
3680+ }
3681+
3682+ err = createTable (session , "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))" )
3683+ if err != nil {
3684+ t .Fatal (err )
3685+ }
3686+
3687+ err = createTable (session , "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))" )
3688+ if err != nil {
3689+ t .Fatal (err )
3690+ }
3691+
3692+ getRoutingKeyInfo := func (key string ) * routingKeyInfo {
3693+ t .Helper ()
3694+ session .routingKeyInfoCache .mu .Lock ()
3695+ value , _ := session .routingKeyInfoCache .lru .Get (key )
3696+ session .routingKeyInfoCache .mu .Unlock ()
3697+
3698+ inflight := value .(* inflightCachedEntry )
3699+ return inflight .value .(* routingKeyInfo )
3700+ }
3701+
3702+ const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)"
3703+
3704+ // Running batch in default ks gocql_test
3705+ b1 := session .NewBatch (LoggedBatch )
3706+ b1 .Query (insertQuery , 1 )
3707+ _ , err = b1 .GetRoutingKey ()
3708+ require .NoError (t , err )
3709+
3710+ // Ensuring that the cache contains the query with default ks
3711+ routingKeyInfo1 := getRoutingKeyInfo ("gocql_test" + b1 .Entries [0 ].Stmt )
3712+ require .Equal (t , "gocql_test" , routingKeyInfo1 .keyspace )
3713+
3714+ // Running batch in gocql_test_routing_key_cache ks
3715+ b2 := session .NewBatch (LoggedBatch )
3716+ b2 .SetKeyspace ("gocql_test_routing_key_cache" )
3717+ b2 .Query (insertQuery , 2 )
3718+ _ , err = b2 .GetRoutingKey ()
3719+ require .NoError (t , err )
3720+
3721+ // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks
3722+ routingKeyInfo2 := getRoutingKeyInfo ("gocql_test_routing_key_cache" + b2 .Entries [0 ].Stmt )
3723+ require .Equal (t , "gocql_test_routing_key_cache" , routingKeyInfo2 .keyspace )
3724+
3725+ const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?"
3726+
3727+ // Running query in default ks gocql_test
3728+ q1 := session .Query (selectStmt , 1 )
3729+ _ , err = q1 .GetRoutingKey ()
3730+ require .NoError (t , err )
3731+ require .Equal (t , "gocql_test" , q1 .routingInfo .keyspace )
3732+
3733+ // Running query in gocql_test_routing_key_cache ks
3734+ q2 := session .Query (selectStmt , 1 )
3735+ _ , err = q2 .SetKeyspace ("gocql_test_routing_key_cache" ).GetRoutingKey ()
3736+ require .NoError (t , err )
3737+ require .Equal (t , "gocql_test_routing_key_cache" , q2 .routingInfo .keyspace )
3738+
3739+ session .Query ("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache" ).Exec ()
3740+ }
0 commit comments