Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Support for sending queries to a specific node with Query.SetHostID() (CASSGO-4)

### Changed

- Don't restrict server authenticator unless PasswordAuthenticator.AllowedAuthenticators is provided (CASSGO-19)
Expand Down
50 changes: 49 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3312,7 +3312,6 @@ func TestUnsetColBatch(t *testing.T) {
}
var id, mInt, count int
var mText string

if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
t.Fatalf("Failed to select with err: %v", err)
} else if count != 2 {
Expand Down Expand Up @@ -3347,3 +3346,52 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

// TestQuery_SetHostID ensures that queries are sent to the specified host only.
// It covers three scenarios: every known host is targeted explicitly, an
// unknown host ID is rejected, and a host marked as down is rejected.
func TestQuery_SetHostID(t *testing.T) {
	session := createSession(t)
	defer session.Close()

	hosts := session.GetHosts()
	// The down-host scenario below indexes hosts[0]; fail with a clear
	// message instead of panicking when the ring is empty.
	if len(hosts) == 0 {
		t.Fatal("Expected at least one host in the ring, but got none")
	}

	const iterations = 5
	for _, expectedHost := range hosts {
		// Repeat the query several times per host so that a lucky pick by
		// the host selection policy cannot make the test pass by accident.
		for i := 0; i < iterations; i++ {
			var actualHostID string
			err := session.Query("SELECT host_id FROM system.local").
				SetHostID(expectedHost.HostID()).
				Scan(&actualHostID)
			if err != nil {
				t.Fatal(err)
			}

			if expectedHost.HostID() != actualHostID {
				t.Fatalf("Expected query to be executed on host %s, but it was executed on %s",
					expectedHost.HostID(),
					actualHostID,
				)
			}
		}
	}

	// ensuring properly handled invalid host id
	err := session.Query("SELECT host_id FROM system.local").
		SetHostID("[invalid]").
		Exec()
	if !errors.Is(err, ErrNoConnections) {
		t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err)
	}

	// ensuring that the driver properly handles the case
	// when specified host for the query is down
	host := hosts[0]
	pool, ok := session.pool.getPoolByHostID(host.HostID())
	if !ok {
		// Without this check a missing pool would surface as a nil-pointer
		// panic on pool.host below.
		t.Fatalf("Expected connection pool for host %s to exist", host.HostID())
	}
	// simulating specified host is down
	pool.host.setState(NodeDown)
	err = session.Query("SELECT host_id FROM system.local").
		SetHostID(host.HostID()).
		Exec()
	if !errors.Is(err, ErrNoConnections) {
		t.Fatalf("Expected error to be: %v, but got %v", ErrNoConnections, err)
	}
}
7 changes: 7 additions & 0 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
return
}

// getPoolByHostID looks up the host connection pool stored under hostID.
// The lookup is performed while holding the read lock; ok is false when
// no pool is registered for the given ID.
func (p *policyConnPool) getPoolByHostID(hostID string) (pool *hostConnPool, ok bool) {
	p.mu.RLock()
	defer p.mu.RUnlock()

	pool, ok = p.hostConnPools[hostID]
	return pool, ok
}

func (p *policyConnPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
Expand Down
25 changes: 23 additions & 2 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type ExecutableQuery interface {
Keyspace() string
Table() string
IsIdempotent() bool
GetHostID() string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option here rather than changing the ExecutableQuery is to do:

type hostIDQuery interface {
  GetHostID() string
}

then later in executeQuery you'd do:

var hostID string
if hostIDQry, ok := qry.(hostIDQuery); ok {
  hostID = hostIDQry.GetHostID()
}

Then we wouldn't need to break this interface and worry about modifying batch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was the first approach I've proposed. We discussed this with @joao-r-reis. If you have an opinion on this, let's discuss it here.

TL;DR
We have two solutions for avoiding type casts and making the implementation more obvious:

  1. Extend ExecutableQuery interface, because it can't be implemented outside gocql due to private methods like execute, so we cannot consider it a breaking change.
  2. Extract private methods from ExecutableQuery to an internal interface which also should implement ExecutableQuery and use it internally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't think we need to treat ExecutableQuery as a "public" interface to be implemented by users, it's a bit messy right now and I actually want to spend some time investigating how we could clean up this part of the API as part of #1848 ideally in time for the 2.0 release

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I was just suggesting it so we didn't have to have a weird case for Batch. I think it's fine to add and we can aim to improve it in a future MR.


withContext(context.Context) ExecutableQuery

Expand Down Expand Up @@ -83,12 +84,32 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
}

func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
var hostIter NextHost

// check if the host id is specified for the query,
// if it is, the query should be executed at the corresponding host.
if hostID := qry.GetHostID(); hostID != "" {
hostIter = func() SelectedHost {
pool, ok := q.pool.getPoolByHostID(hostID)
// if the specified host is down
// we return nil to avoid endless query execution in queryExecutor.do()
if !ok || !pool.host.IsUp() {
return nil
}
return (*selectedHost)(pool.host)
}
}

// if host is not specified for the query,
// then a host will be picked by HostSelectionPolicy
if hostIter == nil {
hostIter = q.policy.Pick(qry)
}

// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
if qry.GetHostID() != "" || !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry, hostIter), nil
}

Expand Down
29 changes: 29 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ func (s *Session) Query(stmt string, values ...interface{}) *Query {
qry.session = s
qry.stmt = stmt
qry.values = values
qry.hostID = ""
qry.defaultsFromSession()
return qry
}
Expand Down Expand Up @@ -949,6 +950,10 @@ type Query struct {

// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo

// hostID specifies the host on which the query should be executed.
// If it is empty, then the host is picked by HostSelectionPolicy
hostID string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about queryTargetHostId or targetHostID?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's an internal field so we can call it whatever. again, i think that it's obvious enough as-is without "target" but i'm fine internally if there's a need to clarify.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to leave hostID to be consistent with method names

}

type queryRoutingInfo struct {
Expand Down Expand Up @@ -1442,6 +1447,20 @@ func (q *Query) releaseAfterExecution() {
q.decRefCount()
}

// SetHostID defines the host the query should be executed against. If the
// host was filtered or is otherwise unavailable, then the query will error. If an
// empty string is passed, the default behavior, using the configured
// HostSelectionPolicy, will be used. A hostID can be obtained from
// HostInfo.HostID() after calling GetHosts().
func (q *Query) SetHostID(hostID string) *Query {
q.hostID = hostID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should check if it's a non empty string because if it's empty then the query will be executed as if SetHostID() was never called which is odd behavior. Either that or we change q.hostID to be *string so we can check if it was actually set or not

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. If we add a check then we should panic if it is empty which is probably not the best, but using *string seems to be fine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a pointer seems a bit awkward to use and is very unusual given all of the other functions in this package. What happens if you call RoutingKey(nil) or PageState(nil), etc. Setting it to empty could be a way to unset an existing HostID if one was set for some reason prior. When queries are immutable it could be useful to get a copy of a query without a specific HostID that previously had one set.

We should just document that sending an empty string will restore the default behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent some time thinking about this more and I'm +1 on what @jameshartig says. We at least should have a way to restore Query.hostID to default.

If SetHostID takes *string it is available, but this makes UX of this API way worse IMO. I don't like the idea of creating vars to hold some strings and pass a pointer to them. However, if API is SetHostID(string) there is no way to restore its behavior to default. Once SetHostID() is used, it can't be changed. Ofc we can expose API Query.Default() to restore this behavior, but this becomes something very odd to me.

All of this is based on the case when we want to copy the existing Query and somehow modify it, but I'm unsure how this matches query immutability. By this, do we mean that once the Query obj is created it can't be modified? Or does it mean that the driver doesn't modify Query internally during its execution? As far as I know, gocql writes some metrics to the query obj.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I suggested the pointer I wasn't actually saying we should make the parameter a pointer, I was saying we could make the private field a pointer so we know when the user actually set the value or not. The parameter would still be string.

Copy link
Contributor

@joao-r-reis joao-r-reis Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but yeah I think we can just document it instead of checking and failing. To return an error we would have to change the method signature which would be awkward (or panic'ing but that's just a bad idea anyway)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If HostFilter is causing issues, consider mentioning it in the documentation. For example: 'SetHostID will not work on filtered nodes (i.e., nodes excluded by ClusterConfig.HostFilter)'.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comments for this method are getting lengthy. How about:

// SetHostID allows to define the host the query should be executed against. If the
// host was filtered or otherwise unavailable, then the query will error. If an empty
// string is sent, the default behavior, using the configured HostSelectionPolicy will
// be used. A hostID can be obtained from HostInfo.HostID() after calling GetHosts().

I removed the WithContext part because we plan to fix that in a follow-up issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot. It is much better!

I removed the WithContext

Oh I completely forgot about this, thanks lots!

return q
}

// GetHostID returns the ID of the host on which the query should be executed.
func (q *Query) GetHostID() string {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd name this method GetTargetHostId to make it more specific.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you like it, then SetTargetHostID might be changed as well 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, they both have better semantics... If we want to be consistent with the same API in other drivers it should be SetHost(), but we don't really have to make it the same...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if Target is specifically necessary since it seems kind of obvious what the point of the host is. I do agree that staying closer to the java API is probably ideal, which SetHostID tries to do.

return q.hostID
}

// Iter represents an iterator that can be used to iterate over all rows that
// were returned by a query. The iterator might send additional queries to the
// database during the iteration if paging was enabled.
Expand Down Expand Up @@ -2057,6 +2076,11 @@ func (b *Batch) releaseAfterExecution() {
// that would race with speculative executions.
}

// GetHostID exists only so that Batch satisfies the ExecutableQuery interface.
// It always returns an empty string.
func (b *Batch) GetHostID() string {
	const unsetHostID = ""
	return unsetHostID
}

type BatchType byte

const (
Expand Down Expand Up @@ -2189,6 +2213,11 @@ func (t *traceWriter) Trace(traceId []byte) {
}
}

// GetHosts returns the list of hosts in the ring that the driver knows of.
func (s *Session) GetHosts() []*HostInfo {
	knownHosts := s.ring.allHosts()
	return knownHosts
}

type ObservedQuery struct {
Keyspace string
Statement string
Expand Down
Loading