diff --git a/CHANGELOG.md b/CHANGELOG.md index 15f0102cf..65f1163f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Support of sending queries to the specific node with Query.SetHostID() (CASSGO-4) + ### Changed - Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19) diff --git a/cassandra_test.go b/cassandra_test.go index 9fa9abad5..d20df8ffe 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -3312,7 +3312,6 @@ func TestUnsetColBatch(t *testing.T) { } var id, mInt, count int var mText string - if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { @@ -3347,3 +3346,52 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +// This test ensures that queries are sent to the specified host only +func TestQuery_SetHostID(t *testing.T) { + session := createSession(t) + defer session.Close() + + hosts := session.GetHosts() + + const iterations = 5 + for _, expectedHost := range hosts { + for i := 0; i < iterations; i++ { + var actualHostID string + err := session.Query("SELECT host_id FROM system.local"). + SetHostID(expectedHost.HostID()). + Scan(&actualHostID) + if err != nil { + t.Fatal(err) + } + + if expectedHost.HostID() != actualHostID { + t.Fatalf("Expected query to be executed on host %s, but it was executed on %s", + expectedHost.HostID(), + actualHostID, + ) + } + } + } + + // ensuring properly handled invalid host id + err := session.Query("SELECT host_id FROM system.local"). + SetHostID("[invalid]"). + Exec() + if !errors.Is(err, ErrNoConnections) { + t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err) + } + + // ensuring that the driver properly handles the case + // when specified host for the query is down + host := hosts[0] + pool, _ := session.pool.getPoolByHostID(host.HostID()) + // simulating specified host is down + pool.host.setState(NodeDown) + err = session.Query("SELECT host_id FROM system.local"). + SetHostID(host.HostID()). + Exec() + if !errors.Is(err, ErrNoConnections) { + t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err) + } +} diff --git a/connectionpool.go b/connectionpool.go index 2ccd3c8a7..9b8295e70 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -243,6 +243,13 @@ func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) { return } +func (p *policyConnPool) getPoolByHostID(hostID string) (pool *hostConnPool, ok bool) { + p.mu.RLock() + pool, ok = p.hostConnPools[hostID] + p.mu.RUnlock() + return +} + func (p *policyConnPool) Close() { p.mu.Lock() defer p.mu.Unlock() diff --git a/query_executor.go b/query_executor.go index d6be02e53..9eaf19dbb 100644 --- a/query_executor.go +++ b/query_executor.go @@ -41,6 +41,7 @@ type ExecutableQuery interface { Keyspace() string Table() string IsIdempotent() bool + GetHostID() string withContext(context.Context) ExecutableQuery @@ -83,12 +84,32 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - hostIter := q.policy.Pick(qry) + var hostIter NextHost + + // check if the host id is specified for the query, + // if it is, the query should be executed at the corresponding host. + if hostID := qry.GetHostID(); hostID != "" { + hostIter = func() SelectedHost { + pool, ok := q.pool.getPoolByHostID(hostID) + // if the specified host is down + // we return nil to avoid endless query execution in queryExecutor.do() + if !ok || !pool.host.IsUp() { + return nil + } + return (*selectedHost)(pool.host) + } + } + + // if host is not specified for the query, + // then a host will be picked by HostSelectionPolicy + if hostIter == nil { + hostIter = q.policy.Pick(qry) + } // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() - if !qry.IsIdempotent() || sp.Attempts() == 0 { + if qry.GetHostID() != "" || !qry.IsIdempotent() || sp.Attempts() == 0 { return q.do(qry.Context(), qry, hostIter), nil } diff --git a/session.go b/session.go index c47e753d9..8965f0f54 100644 --- a/session.go +++ b/session.go @@ -456,6 +456,7 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query { qry.session = s qry.stmt = stmt qry.values = values + qry.hostID = "" qry.defaultsFromSession() return qry } @@ -949,6 +950,10 @@ type Query struct { // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo + + // hostID specifies the host on which the query should be executed. + // If it is empty, then the host is picked by HostSelectionPolicy + hostID string } type queryRoutingInfo struct { @@ -1442,6 +1447,20 @@ func (q *Query) releaseAfterExecution() { q.decRefCount() } +// SetHostID allows to define the host the query should be executed against. If the +// host was filtered or otherwise unavailable, then the query will error. If an empty +// string is sent, the default behavior, using the configured HostSelectionPolicy will +// be used. A hostID can be obtained from HostInfo.HostID() after calling GetHosts(). +func (q *Query) SetHostID(hostID string) *Query { + q.hostID = hostID + return q +} + +// GetHostID returns id of the host on which query should be executed. +func (q *Query) GetHostID() string { + return q.hostID +} + // Iter represents an iterator that can be used to iterate over all rows that // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. @@ -2057,6 +2076,11 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +// GetHostID satisfies ExecutableQuery interface but does noop. +func (b *Batch) GetHostID() string { + return "" +} + type BatchType byte const ( @@ -2189,6 +2213,11 @@ func (t *traceWriter) Trace(traceId []byte) { } } +// GetHosts return a list of hosts in the ring the driver knows of. +func (s *Session) GetHosts() []*HostInfo { + return s.ring.allHosts() +} + type ObservedQuery struct { Keyspace string Statement string