Skip to content

Commit 0444ee3

Browse files
committed
Support of sending queries to the specific node
1 parent 974fa12 commit 0444ee3

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

cassandra_test.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3253,7 +3253,6 @@ func TestUnsetColBatch(t *testing.T) {
32533253
}
32543254
var id, mInt, count int
32553255
var mText string
3256-
32573256
if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
32583257
t.Fatalf("Failed to select with err: %v", err)
32593258
} else if count != 2 {
@@ -3288,3 +3287,33 @@ func TestQuery_NamedValues(t *testing.T) {
32883287
t.Fatal(err)
32893288
}
32903289
}
3290+
3291+
func TestQuery_SetHost(t *testing.T) {
3292+
session := createSession(t)
3293+
defer session.Close()
3294+
3295+
hosts, err := session.GetHosts()
3296+
if err != nil {
3297+
t.Fatal(err)
3298+
}
3299+
3300+
for _, expectedHost := range hosts {
3301+
const iterations = 5
3302+
for i := 0; i < iterations; i++ {
3303+
var actualHostID string
3304+
err := session.Query("SELECT host_id FROM system.local").
3305+
SetHost(expectedHost).
3306+
Scan(&actualHostID)
3307+
if err != nil {
3308+
t.Fatal(err)
3309+
}
3310+
3311+
if expectedHost.HostID() != actualHostID {
3312+
t.Fatalf("Expected query to be executed on host %s, but it was executed on %s",
3313+
expectedHost.HostID(),
3314+
actualHostID,
3315+
)
3316+
}
3317+
}
3318+
}
3319+
}

query_executor.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,28 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
8383
}
8484

8585
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
86-
hostIter := q.policy.Pick(qry)
86+
type hostGetter interface {
87+
getHost() *HostInfo
88+
}
89+
90+
var hostIter NextHost
91+
// checking if the qry implements hostGetter interface
92+
if hostGetter, ok := qry.(hostGetter); ok {
93+
// checking if the host is specified for the query,
94+
// if it is, the query should be executed at the specified host
95+
if host := hostGetter.getHost(); host != nil {
96+
hostIter = func() SelectedHost {
97+
return (*selectedHost)(host)
98+
}
99+
}
100+
}
101+
102+
// if host is not specified for the query,
103+
// or it doesn't implement hostGetter interface,
104+
// then a host will be picked by HostSelectionPolicy
105+
if hostIter == nil {
106+
hostIter = q.policy.Pick(qry)
107+
}
87108

88109
// check if the query is not marked as idempotent, if
89110
// it is, we force the policy to NonSpeculative

session.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,10 @@ type Query struct {
936936

937937
// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
938938
routingInfo *queryRoutingInfo
939+
940+
// host specifies the host on which the query should be executed.
941+
// If it is nil, then the host is picked by HostSelectionPolicy
942+
host *HostInfo
939943
}
940944

941945
type queryRoutingInfo struct {
@@ -1423,6 +1427,17 @@ func (q *Query) releaseAfterExecution() {
14231427
q.decRefCount()
14241428
}
14251429

1430+
// SetHosts allows to define on which host the query should be executed.
1431+
// If host == nil, then the HostSelectionPolicy will be used to pick a host
1432+
func (q *Query) SetHost(host *HostInfo) *Query {
1433+
q.host = host
1434+
return q
1435+
}
1436+
1437+
func (q *Query) getHost() *HostInfo {
1438+
return q.host
1439+
}
1440+
14261441
// Iter represents an iterator that can be used to iterate over all rows that
14271442
// were returned by a query. The iterator might send additional queries to the
14281443
// database during the iteration if paging was enabled.
@@ -2174,6 +2189,15 @@ func (t *traceWriter) Trace(traceId []byte) {
21742189
}
21752190
}
21762191

2192+
// GetHosts returns a list of hosts found via queries to system.local and system.peers
2193+
func (s *Session) GetHosts() ([]*HostInfo, error) {
2194+
hosts, _, err := s.hostSource.GetHosts()
2195+
if err != nil {
2196+
return nil, err
2197+
}
2198+
return hosts, nil
2199+
}
2200+
21772201
type ObservedQuery struct {
21782202
Keyspace string
21792203
Statement string

0 commit comments

Comments
 (0)