diff --git a/CHANGELOG.md b/CHANGELOG.md index b17c5f008..f642ef900 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for Native Protocol 5. Following protocol changes exposed new API Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1) - Externally-defined type registration (CASSGO-43) +- Add Query and Batch to ObservedQuery and ObservedBatch (CASSGO-73) ### Changed @@ -43,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - inet columns default to net.IP when using MapScan or SliceMap (CASSGO-43) - NativeType removed (CASSGO-43) - `New` and `NewWithError` removed and replaced with `Zero` (CASSGO-43) +- Changes to Query and Batch to make them safely reusable (CASSGO-22) ### Fixed diff --git a/batch_test.go b/batch_test.go index 47adff83f..7f7d00253 100644 --- a/batch_test.go +++ b/batch_test.go @@ -28,9 +28,10 @@ package gocql import ( - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestBatch_Errors(t *testing.T) { diff --git a/cass1batch_test.go b/cass1batch_test.go index 8e4eb99b2..b699a10bd 100644 --- a/cass1batch_test.go +++ b/cass1batch_test.go @@ -77,7 +77,7 @@ func TestShouldPrepareFunction(t *testing.T) { } for _, test := range shouldPrepareTests { - q := &Query{stmt: test.Stmt, routingInfo: &queryRoutingInfo{}} + q := &Query{stmt: test.Stmt} if got := q.shouldPrepare(); got != test.Result { t.Fatalf("%q: got %v, expected %v\n", test.Stmt, got, test.Result) } diff --git a/cassandra_test.go b/cassandra_test.go index 9fa2a0d1a..2613f17d4 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1302,8 +1302,8 @@ func Test_RetryPolicyIdempotence(t *testing.T) { q.RetryPolicy(&MyRetryPolicy{}) q.Consistency(All) - _ = q.Exec() - require.Equal(t, tc.expectedNumberOfTries, q.Attempts()) + it := q.Iter() + require.Equal(t, tc.expectedNumberOfTries, it.Attempts()) }) } } @@ -1673,7 +1673,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { defer s.Close() insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5) - if err := conn.executeQuery(ctx, insertQry).err; err == nil { + if err := conn.executeQuery(ctx, newInternalQuery(insertQry, nil)).err; err == nil { t.Fatal("expected error, but got nil.") } @@ -1681,7 +1681,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { t.Fatal("create table:", err) } - if err := conn.executeQuery(ctx, insertQry).err; err != nil { + if err := conn.executeQuery(ctx, newInternalQuery(insertQry, nil)).err; err != nil { t.Fatal(err) // unconfigured columnfamily } } @@ -1695,7 +1695,7 @@ func TestPrepare_ReprepareStatement(t *testing.T) { stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement") query := session.Query(stmt, "bar") - if err := conn.executeQuery(ctx, query).Close(); err != nil { + if err := conn.executeQuery(ctx, newInternalQuery(query, nil)).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -1714,7 +1714,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) { stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") batch := session.Batch(UnloggedBatch) batch.Query(stmt, "bar") - if err := conn.executeBatch(ctx, batch).Close(); err != nil { + if err := conn.executeBatch(ctx, newInternalBatch(batch, nil)).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -2059,14 +2059,16 @@ func TestQueryStats(t *testing.T) { session := createSession(t) defer session.Close() qry := session.Query("SELECT * FROM system.peers") - if err := qry.Exec(); err != nil { + iter := qry.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if qry.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if qry.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", qry.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -2099,15 +2101,16 @@ func TestBatchStats(t *testing.T) { b := session.Batch(LoggedBatch) b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) - - if err := b.Exec(); err != nil { + iter := b.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -2850,7 +2853,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected cache size to be 1 but was %d", cacheSize) } - query := session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2) + query := newInternalQuery(session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil) routingKey, err := query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) @@ -2892,7 +2895,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[1].Type()) } - query = session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2) + query = newInternalQuery(session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil) routingKey, err = query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) @@ -3441,15 +3444,16 @@ func TestUnsetColBatch(t *testing.T) { b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) - - if err := b.Exec(); err != nil { + iter := b.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } var id, mInt, count int @@ -3702,7 +3706,9 @@ func TestQueryCompressionNotWorthIt(t *testing.T) { // The driver should handle this by updating its prepared statement inside the cache // when it receives RESULT/ROWS with Metadata_changed flag func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { - session := createSession(t) + session := createSession(t, func(config *ClusterConfig) { + config.NumConns = 1 + }) defer session.Close() if session.cfg.ProtoVersion < protoVersion5 { @@ -3726,13 +3732,17 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { t.Fatal(err) } - // We have to specify conn for all queries to ensure that + // We have to specify host for all queries to ensure that // all queries are running on the same node - conn := session.getConn() + hosts := session.GetHosts() + if len(hosts) == 0 { + t.Fatal("no hosts found") + } + hostid := hosts[0].HostID() const selectStmt = "SELECT * FROM gocql_test.metadata_changed" queryBeforeTableAltering := session.Query(selectStmt) - queryBeforeTableAltering.conn = conn + queryBeforeTableAltering.SetHostID(hostid) row := make(map[string]interface{}) err = queryBeforeTableAltering.MapScan(row) if err != nil { @@ -3742,13 +3752,16 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { require.Len(t, row, 1, "Expected to retrieve a single column") require.Equal(t, 1, row["id"]) - stmtCacheKey := session.stmtsLRU.keyFor(conn.host.HostID(), conn.currentKeyspace, queryBeforeTableAltering.stmt) - inflight, _ := session.stmtsLRU.get(stmtCacheKey) + stmtCacheKey := session.stmtsLRU.keyFor(hostid, "gocql_test", queryBeforeTableAltering.stmt) + inflight, ok := session.stmtsLRU.get(stmtCacheKey) + if !ok { + t.Fatalf("failed to find inflight entry for key %v", stmtCacheKey) + } preparedStatementBeforeTableAltering := inflight.preparedStatment // Changing table schema in order to cause C* to return RESULT/ROWS Metadata_changed alteringTableQuery := session.Query("ALTER TABLE gocql_test.metadata_changed ADD new_col int") - alteringTableQuery.conn = conn + alteringTableQuery.SetHostID(hostid) err = alteringTableQuery.Exec() if err != nil { t.Fatal(err) @@ -3800,7 +3813,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { // Expecting C* will return RESULT/ROWS Metadata_changed // and it will be properly handled queryAfterTableAltering := session.Query(selectStmt) - queryAfterTableAltering.conn = conn + queryAfterTableAltering.SetHostID(hostid) iter := queryAfterTableAltering.Iter() handleRows(iter) @@ -3825,7 +3838,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { defer cancel() queryAfterTableAltering2 := session.Query(selectStmt).WithContext(ctx) - queryAfterTableAltering2.conn = conn + queryAfterTableAltering2.SetHostID(hostid) iter = queryAfterTableAltering2.Iter() handleRows(iter) err = iter.Close() @@ -3842,7 +3855,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { // Executing prepared stmt and expecting that C* won't return // Metadata_changed because the table is not being changed. queryAfterTableAltering3 := session.Query(selectStmt).WithContext(ctx) - queryAfterTableAltering3.conn = conn + queryAfterTableAltering3.SetHostID(hostid) iter = queryAfterTableAltering2.Iter() handleRows(iter) @@ -3946,7 +3959,10 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { getRoutingKeyInfo := func(key string) *routingKeyInfo { t.Helper() session.routingKeyInfoCache.mu.Lock() - value, _ := session.routingKeyInfoCache.lru.Get(key) + value, ok := session.routingKeyInfoCache.lru.Get(key) + if !ok { + t.Fatalf("routing key not found in cache for key %v", key) + } session.routingKeyInfoCache.mu.Unlock() inflight := value.(*inflightCachedEntry) @@ -3956,9 +3972,10 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" // Running batch in default ks gocql_test - b1 := session.NewBatch(LoggedBatch) + b1 := session.Batch(LoggedBatch) b1.Query(insertQuery, 1) - _, err = b1.GetRoutingKey() + internalB := newInternalBatch(b1, nil) + _, err = internalB.GetRoutingKey() require.NoError(t, err) // Ensuring that the cache contains the query with default ks @@ -3966,10 +3983,11 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) // Running batch in gocql_test_routing_key_cache ks - b2 := session.NewBatch(LoggedBatch) + b2 := session.Batch(LoggedBatch) b2.SetKeyspace("gocql_test_routing_key_cache") b2.Query(insertQuery, 2) - _, err = b2.GetRoutingKey() + internalB2 := newInternalBatch(b2, nil) + _, err = internalB2.GetRoutingKey() require.NoError(t, err) // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks @@ -3980,15 +3998,18 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { // Running query in default ks gocql_test q1 := session.Query(selectStmt, 1) - _, err = q1.GetRoutingKey() + iter := q1.Iter() + err = iter.Close() require.NoError(t, err) - require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + require.Equal(t, "gocql_test", iter.Keyspace()) // Running query in gocql_test_routing_key_cache ks q2 := session.Query(selectStmt, 1) - _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + q2.SetKeyspace("gocql_test_routing_key_cache") + iter = q2.Iter() + err = iter.Close() require.NoError(t, err) - require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + require.Equal(t, "gocql_test_routing_key_cache", iter.Keyspace()) session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() } diff --git a/conn.go b/conn.go index c9efe723b..4e1bb4f2e 100644 --- a/conn.go +++ b/conn.go @@ -1497,32 +1497,34 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error return nil } -func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { +func (c *Conn) executeQuery(ctx context.Context, q *internalQuery) *Iter { + qryOpts := q.qryOpts params := queryParams{ - consistency: qry.cons, + consistency: q.GetConsistency(), } + iter := newIter(q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace) // frame checks that it is not 0 - params.serialConsistency = qry.serialCons - params.defaultTimestamp = qry.defaultTimestamp - params.defaultTimestampValue = qry.defaultTimestampValue + params.serialConsistency = qryOpts.serialCons + params.defaultTimestamp = qryOpts.defaultTimestamp + params.defaultTimestampValue = qryOpts.defaultTimestampValue - if len(qry.pageState) > 0 { - params.pagingState = qry.pageState + if len(q.pageState) > 0 { + params.pagingState = q.pageState } - if qry.pageSize > 0 { - params.pageSize = qry.pageSize + if qryOpts.pageSize > 0 { + params.pageSize = qryOpts.pageSize } if c.version > protoVersion4 { - params.keyspace = qry.keyspace - params.nowInSeconds = qry.nowInSecondsValue + params.keyspace = qryOpts.keyspace + params.nowInSeconds = qryOpts.nowInSecondsValue } // If a keyspace for the qry is overriden, // then we should use it to create stmt cache key usedKeyspace := c.currentKeyspace - if qry.keyspace != "" { - usedKeyspace = qry.keyspace + if qryOpts.keyspace != "" { + usedKeyspace = qryOpts.keyspace } var ( @@ -1530,17 +1532,18 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { info *preparedStatment ) - if !qry.skipPrepare && qry.shouldPrepare() { + if !qryOpts.skipPrepare && shouldPrepare(qryOpts.stmt) { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) + info, err = c.prepareStatement(ctx, qryOpts.stmt, qryOpts.trace, usedKeyspace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } - values := qry.values - if qry.binding != nil { - values, err = qry.binding(&QueryInfo{ + values := qryOpts.values + if qryOpts.binding != nil { + values, err = qryOpts.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, @@ -1548,12 +1551,14 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { }) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} + iter.err = fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values)) + return iter } params.values = make([]queryValues, len(values)) @@ -1562,59 +1567,63 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { value := values[i] typ := info.request.columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + iter.err = err + return iter } } // if the metadata was not present in the response then we should not skip it - params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData == 0 + params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qryOpts.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData == 0 frame = &writeExecuteFrame{ preparedID: info.id, params: params, - customPayload: qry.customPayload, + customPayload: qryOpts.customPayload, resultMetadataID: info.resultMetadataID, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata - qry.routingInfo.mu.Lock() - qry.routingInfo.keyspace = info.request.keyspace + q.routingInfo.mu.Lock() + q.routingInfo.keyspace = info.request.keyspace if info.request.keyspace == "" { - qry.routingInfo.keyspace = usedKeyspace + q.routingInfo.keyspace = usedKeyspace } - qry.routingInfo.table = info.request.table - qry.routingInfo.mu.Unlock() + q.routingInfo.table = info.request.table + q.routingInfo.mu.Unlock() } else { frame = &writeQueryFrame{ - statement: qry.stmt, + statement: qryOpts.stmt, params: params, - customPayload: qry.customPayload, + customPayload: qryOpts.customPayload, } } - framer, err := c.exec(ctx, frame, qry.trace) + framer, err := c.exec(ctx, frame, qryOpts.trace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err} + iter.err = err + return iter } - if len(framer.traceID) > 0 && qry.trace != nil { - qry.trace.Trace(framer.traceID) + if len(framer.traceID) > 0 && qryOpts.trace != nil { + qryOpts.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: - return &Iter{framer: framer} + iter.framer = framer + return iter case *resultRowsFrame: if x.meta.newMetadataID != nil { // If a RESULT/Rows message reports // changed resultset metadata with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qryOpts.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ @@ -1635,33 +1644,32 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { info = newInflight.preparedStatment } } - - iter := &Iter{ - meta: x.meta, - framer: framer, - numRows: x.numRows, - } + iter.meta = x.meta + iter.framer = framer + iter.numRows = x.numRows if x.meta.noMetaData() { if info != nil { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { - return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} + iter = newErrIter(errors.New("gocql: did not receive metadata but prepared info is nil"), q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace) + iter.framer = framer + return iter } } else { iter.meta = x.meta } - if x.meta.morePages() && !qry.disableAutoPage { - newQry := new(Query) - *newQry = *qry + if x.meta.morePages() && !qryOpts.disableAutoPage { + newQry := new(internalQuery) + *newQry = *q newQry.pageState = copyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ - qry: newQry, - pos: int((1 - qry.prefetch) * float64(x.numRows)), + q: newQry, + pos: int((1 - qryOpts.prefetch) * float64(x.numRows)), } if iter.next.pos < 1 { @@ -1671,9 +1679,10 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { return iter case *resultKeyspaceFrame: - return &Iter{framer: framer} + iter.framer = framer + return iter case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: - iter := &Iter{framer: framer} + iter.framer = framer if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag c.logger.Println(err) @@ -1683,16 +1692,17 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qryOpts.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) - return c.executeQuery(ctx, qry) + return c.executeQuery(ctx, q) case error: - return &Iter{err: x, framer: framer} + iter.err = x + iter.framer = framer + return iter default: - return &Iter{ - err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), - framer: framer, - } + iter.err = NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x) + iter.framer = framer + return iter } } @@ -1744,38 +1754,40 @@ func (c *Conn) UseKeyspace(keyspace string) error { return nil } -func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { - n := len(batch.Entries) +func (c *Conn) executeBatch(ctx context.Context, b *internalBatch) *Iter { + iter := newIter(b.metrics, b.Keyspace(), b.routingInfo, nil) + n := len(b.batchOpts.entries) req := &writeBatchFrame{ - typ: batch.Type, + typ: b.batchOpts.bType, statements: make([]batchStatment, n), - consistency: batch.Cons, - serialConsistency: batch.serialCons, - defaultTimestamp: batch.defaultTimestamp, - defaultTimestampValue: batch.defaultTimestampValue, - customPayload: batch.CustomPayload, + consistency: b.GetConsistency(), + serialConsistency: b.batchOpts.serialCons, + defaultTimestamp: b.batchOpts.defaultTimestamp, + defaultTimestampValue: b.batchOpts.defaultTimestampValue, + customPayload: b.batchOpts.customPayload, } if c.version > protoVersion4 { - req.keyspace = batch.keyspace - req.nowInSeconds = batch.nowInSeconds + req.keyspace = b.batchOpts.keyspace + req.nowInSeconds = b.batchOpts.nowInSeconds } usedKeyspace := c.currentKeyspace - if batch.keyspace != "" { - usedKeyspace = batch.keyspace + if b.batchOpts.keyspace != "" { + usedKeyspace = b.batchOpts.keyspace } - stmts := make(map[string]string, len(batch.Entries)) + stmts := make(map[string]string, len(b.batchOpts.entries)) for i := 0; i < n; i++ { - entry := &batch.Entries[i] - b := &req.statements[i] + entry := &b.batchOpts.entries[i] + batchStmt := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) + info, err := c.prepareStatement(ctx, entry.Stmt, b.batchOpts.trace, usedKeyspace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } var values []interface{} @@ -1789,68 +1801,75 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { PKeyColumns: info.request.pkeyColumns, }) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} + iter.err = fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values)) + return iter } - b.preparedID = info.id + batchStmt.preparedID = info.id stmts[string(info.id)] = entry.Stmt - b.values = make([]queryValues, info.request.actualColCount) + batchStmt.values = make([]queryValues, info.request.actualColCount) for j := 0; j < info.request.actualColCount; j++ { - v := &b.values[j] + v := &batchStmt.values[j] value := values[j] typ := info.request.columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + iter.err = err + return iter } } } else { - b.statement = entry.Stmt + batchStmt.statement = entry.Stmt } } - framer, err := c.exec(batch.Context(), req, batch.trace) + framer, err := c.exec(ctx, req, b.batchOpts.trace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err, framer: framer} + iter.err = err + iter.framer = framer + return iter } - if len(framer.traceID) > 0 && batch.trace != nil { - batch.trace.Trace(framer.traceID) + if len(framer.traceID) > 0 && b.batchOpts.trace != nil { + b.batchOpts.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: - return &Iter{} + return iter case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } - return c.executeBatch(ctx, batch) + return c.executeBatch(ctx, b) case *resultRowsFrame: - iter := &Iter{ - meta: x.meta, - framer: framer, - numRows: x.numRows, - } - + iter.meta = x.meta + iter.framer = framer + iter.numRows = x.numRows return iter case error: - return &Iter{err: x, framer: framer} + iter.err = x + iter.framer = framer + return iter default: - return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} + iter.err = NewErrProtocol("Unknown type in response to batch statement: %s", x) + iter.framer = framer + return iter } } @@ -1858,9 +1877,9 @@ func (c *Conn) query(ctx context.Context, statement string, values ...interface{ q := c.session.Query(statement, values...).Consistency(One).Trace(nil) q.skipPrepare = true q.disableSkipMetadata = true + // we want to keep the query on this connection - q.conn = c - return c.executeQuery(ctx, q) + return q.iterInternal(c, ctx) } func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter { diff --git a/conn_test.go b/conn_test.go index 4b7ed732d..9a3586196 100644 --- a/conn_test.go +++ b/conn_test.go @@ -392,12 +392,14 @@ func TestQueryRetry(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 1} qry := db.Query("kill").RetryPolicy(rt) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + err = iter.Close() + if err == nil { t.Fatalf("expected error") } requests := atomic.LoadInt64(&srv.nKillReq) - attempts := qry.Attempts() + attempts := iter.Attempts() if requests != int64(attempts) { t.Fatalf("expected requests %v to match query attempts %v", requests, attempts) } @@ -439,13 +441,15 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 3} observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false, logger: log} qry := db.Query("kill").RetryPolicy(rt).Observer(observer).Idempotent(true) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + err = iter.Close() + if err == nil { t.Fatalf("expected error") } for i, ip := range addresses { host := &HostInfo{connectAddress: net.ParseIP(ip)} - queryMetric := qry.metrics.hostMetrics(host) + queryMetric := iter.metrics.hostMetrics(host) observedMetrics := observer.GetMetrics(host) requests := int(atomic.LoadInt64(&nodes[i].nKillReq)) @@ -465,7 +469,7 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { } } // the query will only be attempted once, but is being retried - attempts := qry.Attempts() + attempts := iter.Attempts() if attempts != rt.NumRetries { t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, attempts) } diff --git a/control.go b/control.go index 113bfaaa0..dfc7dc021 100644 --- a/control.go +++ b/control.go @@ -502,7 +502,7 @@ func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { return fn(ch) } - return &Iter{err: errNoControl} + return newErrIter(errNoControl, newQueryMetrics(), "", nil, nil) } func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { @@ -514,20 +514,20 @@ func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { // query will return nil if the connection is closed or nil func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) + qry := newInternalQuery(q, context.TODO()) for { iter = c.withConn(func(conn *Conn) *Iter { - // we want to keep the query on the control connection - q.conn = conn - return conn.executeQuery(context.TODO(), q) + qry.conn = conn + return conn.executeQuery(qry.Context(), qry) }) if gocqlDebug && iter.err != nil { c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err) } - q.AddAttempts(1, c.getConn().host) - if iter.err == nil || !c.retry.Attempt(q) { + iter.metrics.attempt(1, 0, c.getConn().host, false) + if iter.err == nil || !c.retry.Attempt(qry) { break } } @@ -537,7 +537,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter func (c *controlConn) awaitSchemaAgreement() error { return c.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(context.TODO())} + return newErrIter(conn.awaitSchemaAgreement(context.TODO()), newQueryMetrics(), "", nil, nil) }).err } diff --git a/hostpool/hostpool.go b/hostpool/hostpool.go index e4a648598..f3a3d0f6f 100644 --- a/hostpool/hostpool.go +++ b/hostpool/hostpool.go @@ -100,7 +100,7 @@ func (r *hostPoolHostPolicy) HostDown(host *gocql.HostInfo) { r.RemoveHost(host) } -func (r *hostPoolHostPolicy) Pick(qry gocql.ExecutableQuery) gocql.NextHost { +func (r *hostPoolHostPolicy) Pick(qry gocql.ExecutableStatement) gocql.NextHost { return func() gocql.SelectedHost { r.mu.RLock() defer r.mu.RUnlock() diff --git a/keyspace_table_test.go b/keyspace_table_test.go index f3d51f3f8..7862b4892 100644 --- a/keyspace_table_test.go +++ b/keyspace_table_test.go @@ -33,7 +33,7 @@ import ( "testing" ) -// Keyspace_table checks if Query.Keyspace() is updated based on prepared statement +// Keyspace_table checks if Iter.Keyspace() is updated based on prepared statement func TestKeyspaceTable(t *testing.T) { cluster := createCluster() @@ -45,8 +45,7 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("createSession:", err) } - cluster.Keyspace = "wrong_keyspace" - + wrongKeyspace := "testwrong" keyspace := "test1" table := "table1" @@ -55,6 +54,11 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("unable to drop keyspace:", err) } + err = createTable(session, `DROP KEYSPACE IF EXISTS `+wrongKeyspace) + if err != nil { + t.Fatal("unable to drop keyspace:", err) + } + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s WITH replication = { 'class' : 'SimpleStrategy', @@ -65,6 +69,16 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("unable to create keyspace:", err) } + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }`, wrongKeyspace)) + + if err != nil { + t.Fatal("unable to create keyspace:", err) + } + if err := session.control.awaitSchemaAgreement(); err != nil { t.Fatal(err) } @@ -80,11 +94,24 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal(err) } + session.Close() + + cluster = createCluster() + + fallback = RoundRobinHostPolicy() + cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback) + cluster.Keyspace = wrongKeyspace + + session, err = cluster.CreateSession() + if err != nil { + t.Fatal("createSession:", err) + } + ctx := context.Background() // insert a row if err := session.Query(`INSERT INTO test1.table1(pk, ck, v) VALUES (?, ?, ?)`, - 1, 2, 3).WithContext(ctx).Consistency(One).Exec(); err != nil { + 1, 2, 3).Consistency(One).ExecContext(ctx); err != nil { t.Fatal(err) } @@ -93,13 +120,20 @@ func TestKeyspaceTable(t *testing.T) { /* Search for a specific set of records whose 'pk' column matches * the value of inserted row. */ qry := session.Query(`SELECT pk FROM test1.table1 WHERE pk = ? LIMIT 1`, - 1).WithContext(ctx).Consistency(One) - if err := qry.Scan(&pk); err != nil { + 1).Consistency(One) + iter := qry.IterContext(ctx) + ok := iter.Scan(&pk) + err = iter.Close() + if err != nil { t.Fatal(err) } + if !ok { + t.Fatal("expected pk to be scanned") + } - // cluster.Keyspace was set to "wrong_keyspace", but during prepering statement + // cluster.Keyspace was set to "testwrong", but during prepering statement // Keyspace in Query should be changed to "test" and Table should be changed to table1 - assertEqual(t, "qry.Keyspace()", "test1", qry.Keyspace()) - assertEqual(t, "qry.Table()", "table1", qry.Table()) + assertEqual(t, "qry.Keyspace()", "testwrong", qry.Keyspace()) + assertEqual(t, "iter.Keyspace()", "test1", iter.Keyspace()) + assertEqual(t, "iter.Table()", "table1", iter.Table()) } diff --git a/lz4/lz4_test.go b/lz4/lz4_test.go index ea64371c5..5c00749a2 100644 --- a/lz4/lz4_test.go +++ b/lz4/lz4_test.go @@ -28,9 +28,10 @@ package lz4 import ( - "github.com/pierrec/lz4/v4" "testing" + "github.com/pierrec/lz4/v4" + "github.com/stretchr/testify/require" ) diff --git a/policies.go b/policies.go index 219f472df..c9f3399fd 100644 --- a/policies.go +++ b/policies.go @@ -304,7 +304,7 @@ type HostSelectionPolicy interface { // Multiple attempts of a single query execution won't call the returned NextHost function concurrently, // so it's safe to have internal state without additional synchronization as long as every call to Pick returns // a different instance of NextHost. - Pick(ExecutableQuery) NextHost + Pick(statement ExecutableStatement) NextHost } // SelectedHost is an interface returned when picking a host from a host @@ -342,7 +342,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} func (r *roundRobinHostPolicy) Init(*Session) {} -func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost { +func (r *roundRobinHostPolicy) Pick(qry ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&r.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), r.hosts.get()) } @@ -577,7 +577,7 @@ func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo, logg m.tokenRing = tokenRing } -func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { +func (t *tokenAwareHostPolicy) Pick(qry ExecutableStatement) NextHost { if qry == nil { return t.fallback.Pick(qry) } @@ -771,7 +771,7 @@ func roundRobbin(shift int, hosts ...[]*HostInfo) NextHost { } } -func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { +func (d *dcAwareRR) Pick(q ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) } @@ -833,7 +833,7 @@ func (d *rackAwareRR) RemoveHost(host *HostInfo) { func (d *rackAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } func (d *rackAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } -func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost { +func (d *rackAwareRR) Pick(q ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get()) } diff --git a/policies_test.go b/policies_test.go index e8bda8908..540742a0f 100644 --- a/policies_test.go +++ b/policies_test.go @@ -76,7 +76,7 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { return nil, errors.New("not initalized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -129,7 +129,7 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("20")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first token-aware hosts expectHosts(t, "hosts[0]", iter, "1") expectHosts(t, "hosts[1]", iter, "2") @@ -182,11 +182,11 @@ func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) { } policy.SetPartitioner("OrderedPartitioner") - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return "myKeyspace" } query.RoutingKey([]byte("20")) - iter := policy.Pick(query) + iter := policy.Pick(newInternalQuery(query, nil)) next := iter() if next == nil { t.Fatal("got nil host") @@ -240,7 +240,7 @@ func TestCOWList_Add(t *testing.T) { // TestSimpleRetryPolicy makes sure that we only allow 1 + numRetries attempts func TestSimpleRetryPolicy(t *testing.T) { - q := &Query{routingInfo: &queryRoutingInfo{}} + q := newInternalQuery(&Query{}, nil) // this should allow a total of 3 tries. rt := &SimpleRetryPolicy{NumRetries: 2} @@ -298,7 +298,7 @@ func TestExponentialBackoffPolicy(t *testing.T) { func TestDowngradingConsistencyRetryPolicy(t *testing.T) { - q := &Query{cons: LocalQuorum, routingInfo: &queryRoutingInfo{}} + q := newInternalQuery(&Query{initialConsistency: LocalQuorum}, nil) rewt0 := &RequestErrWriteTimeout{ Received: 0, @@ -459,7 +459,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { return nil, errors.New("not initialized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -497,7 +497,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { } query.RoutingKey([]byte("30")) - if actual := policy.Pick(query)(); actual == nil { + if actual := policy.Pick(newInternalQuery(query, nil))(); actual == nil { t.Fatal("expected to get host from fallback got nil") } @@ -541,7 +541,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("23")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC expectHosts(t, "matching token from local DC", iter, "4") // next are in non-deterministic order @@ -561,7 +561,7 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { return nil, errors.New("not initialized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -632,7 +632,7 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("18")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be hosts with matching token from the local DC expectHosts(t, "matching token from local DC", iter, "4", "7") // rest should be hosts with matching token from remote DCs @@ -688,7 +688,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { policyWithFallbackInternal.getKeyspaceName = policyInternal.getKeyspaceName policyWithFallbackInternal.getKeyspaceMetadata = policyInternal.getKeyspaceMetadata - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -727,7 +727,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { } query.RoutingKey([]byte("30")) - if actual := policy.Pick(query)(); actual == nil { + if actual := policy.Pick(newInternalQuery(query, nil))(); actual == nil { t.Fatal("expected to get host from fallback got nil") } @@ -775,7 +775,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { // now the token ring is configured // Test the policy with fallback - iter = policyWithFallback.Pick(query) + iter = policyWithFallback.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC & rack expectHosts(t, "matching token from local DC and local rack", iter, "7") @@ -792,7 +792,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { expectNoMoreHosts(t, iter) // Test the policy without fallback - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC & Rack expectHosts(t, "matching token from local DC and local rack", iter, "7") diff --git a/query_executor.go b/query_executor.go index 61a43c8c6..552d0b97c 100644 --- a/query_executor.go +++ b/query_executor.go @@ -31,22 +31,41 @@ import ( "time" ) -type ExecutableQuery interface { - borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine. - releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error. - execute(ctx context.Context, conn *Conn) *Iter - attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) - retryPolicy() RetryPolicy - speculativeExecutionPolicy() SpeculativeExecutionPolicy +// Deprecated: Will be removed in a future major release. Also Query and Batch no longer implement this interface. +// +// Please use Statement (for Query / Batch objects) or ExecutableStatement (in HostSelectionPolicy implementations) instead. +type ExecutableQuery = ExecutableStatement + +// ExecutableStatement is an interface that represents a query or batch statement that +// exposes the correct functions for the HostSelectionPolicy to operate correctly. +type ExecutableStatement interface { GetRoutingKey() ([]byte, error) Keyspace() string Table() string IsIdempotent() bool GetHostID() string + Statement() Statement +} - withContext(context.Context) ExecutableQuery +// Statement is an interface that represents a CQL statement that the driver can execute +// (currently Query and Batch via Session.Query and Session.Batch) +type Statement interface { + Iter() *Iter + IterContext(ctx context.Context) *Iter + Exec() error + ExecContext(ctx context.Context) error +} +type internalRequest interface { + execute(ctx context.Context, conn *Conn) *Iter + attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) + retryPolicy() RetryPolicy + speculativeExecutionPolicy() SpeculativeExecutionPolicy + getQueryMetrics() *queryMetrics + getRoutingInfo() *queryRoutingInfo + getKeyspaceFunc() func() string RetryableQuery + ExecutableStatement } type queryExecutor struct { @@ -54,7 +73,7 @@ type queryExecutor struct { policy HostSelectionPolicy } -func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { +func (q *queryExecutor) attemptQuery(ctx context.Context, qry internalRequest, conn *Conn) *Iter { start := time.Now() iter := qry.execute(ctx, conn) end := time.Now() @@ -64,7 +83,7 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c return iter } -func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, +func (q *queryExecutor) speculate(ctx context.Context, qry internalRequest, sp SpeculativeExecutionPolicy, hostIter NextHost, results chan *Iter) *Iter { ticker := time.NewTicker(sp.Delay()) defer ticker.Stop() @@ -72,10 +91,9 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S for i := 0; i < sp.Attempts(); i++ { select { case <-ticker.C: - qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) case <-ctx.Done(): - return &Iter{err: ctx.Err()} + return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) case iter := <-results: return iter } @@ -84,7 +102,7 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S return nil } -func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { +func (q *queryExecutor) executeQuery(qry internalRequest) (*Iter, error) { var hostIter NextHost // check if the host id is specified for the query, @@ -132,7 +150,6 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { results := make(chan *Iter, 1) // Launch the main execution - qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) // The speculative executions are launched _in addition_ to the main @@ -146,11 +163,11 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { case iter := <-results: return iter, nil case <-ctx.Done(): - return &Iter{err: ctx.Err()}, nil + return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()), nil } } -func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { +func (q *queryExecutor) do(ctx context.Context, qry internalRequest, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() @@ -213,7 +230,7 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne stopRetries = true default: // Undefined? Return nil and error, this will panic in the requester - return &Iter{err: ErrUnknownRetryType} + return newErrIter(ErrUnknownRetryType, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } if stopRetries || attemptsReached { @@ -225,16 +242,447 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne } if lastErr != nil { - return &Iter{err: lastErr} + return newErrIter(lastErr, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } - return &Iter{err: ErrNoConnections} + return newErrIter(ErrNoConnections, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } -func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) { +func (q *queryExecutor) run(ctx context.Context, qry internalRequest, hostIter NextHost, results chan<- *Iter) { select { case results <- q.do(ctx, qry, hostIter): case <-ctx.Done(): } - qry.releaseAfterExecution() +} + +type queryOptions struct { + stmt string + + // Paging + pageSize int + disableAutoPage bool + + // Monitoring + trace Tracer + observer QueryObserver + + // Parameters + values []interface{} + binding func(q *QueryInfo) ([]interface{}, error) + + // Timestamp + defaultTimestamp bool + defaultTimestampValue int64 + + // Consistency + serialCons SerialConsistency + + // Protocol flag + disableSkipMetadata bool + + customPayload map[string][]byte + prefetch float64 + rt RetryPolicy + spec SpeculativeExecutionPolicy + context context.Context + idempotent bool + keyspace string + skipPrepare bool + routingKey []byte + nowInSecondsValue *int + hostID string + + // getKeyspace is field so that it can be overriden in tests + getKeyspace func() string +} + +func newQueryOptions(q *Query, ctx context.Context) *queryOptions { + var newRoutingKey []byte + if q.routingKey != nil { + routingKey := q.routingKey + newRoutingKey = make([]byte, len(routingKey)) + copy(newRoutingKey, routingKey) + } + if ctx == nil { + ctx = q.Context() + } + return &queryOptions{ + stmt: q.stmt, + values: q.values, + pageSize: q.pageSize, + prefetch: q.prefetch, + trace: q.trace, + observer: q.observer, + rt: q.rt, + spec: q.spec, + binding: q.binding, + serialCons: q.serialCons, + defaultTimestamp: q.defaultTimestamp, + defaultTimestampValue: q.defaultTimestampValue, + disableSkipMetadata: q.disableSkipMetadata, + context: ctx, + idempotent: q.idempotent, + customPayload: q.customPayload, + disableAutoPage: q.disableAutoPage, + skipPrepare: q.skipPrepare, + routingKey: newRoutingKey, + getKeyspace: q.getKeyspace, + nowInSecondsValue: q.nowInSecondsValue, + keyspace: q.keyspace, + hostID: q.hostID, + } +} + +type internalQuery struct { + originalQuery *Query + qryOpts *queryOptions + pageState []byte + metrics *queryMetrics + conn *Conn + consistency uint32 + session *Session + routingInfo *queryRoutingInfo +} + +func newInternalQuery(q *Query, ctx context.Context) *internalQuery { + var newPageState []byte + if q.initialPageState != nil { + pageState := q.initialPageState + newPageState = make([]byte, len(pageState)) + copy(newPageState, pageState) + } + return &internalQuery{ + originalQuery: q, + qryOpts: newQueryOptions(q, ctx), + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + consistency: uint32(q.initialConsistency), + pageState: newPageState, + conn: nil, + session: q.session, + routingInfo: &queryRoutingInfo{}, + } +} + +// Attempts returns the number of times the query was executed. +func (q *internalQuery) Attempts() int { + return q.metrics.attempts() +} + +func (q *internalQuery) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + latency := end.Sub(start) + attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.qryOpts.observer != nil) + + if q.qryOpts.observer != nil { + q.qryOpts.observer.ObserveQuery(q.qryOpts.context, ObservedQuery{ + Keyspace: keyspace, + Statement: q.qryOpts.stmt, + Values: q.qryOpts.values, + Start: start, + End: end, + Rows: iter.numRows, + Host: host, + Metrics: metricsForHost, + Err: iter.err, + Attempt: attempt, + Query: q.originalQuery, + }) + } +} + +func (q *internalQuery) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeQuery(ctx, q) +} + +func (q *internalQuery) retryPolicy() RetryPolicy { + return q.qryOpts.rt +} + +func (q *internalQuery) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return q.qryOpts.spec +} + +func (q *internalQuery) GetRoutingKey() ([]byte, error) { + if q.qryOpts.routingKey != nil { + return q.qryOpts.routingKey, nil + } + + if q.qryOpts.binding != nil && len(q.qryOpts.values) == 0 { + // If this query was created using session.Bind we wont have the query + // values yet, so we have to pass down to the next policy. + // TODO: Remove this and handle this case + return nil, nil + } + + // try to determine the routing key + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.qryOpts.stmt, q.qryOpts.keyspace) + if err != nil { + return nil, err + } + + if routingKeyInfo != nil { + q.routingInfo.mu.Lock() + q.routingInfo.keyspace = routingKeyInfo.keyspace + q.routingInfo.table = routingKeyInfo.table + q.routingInfo.mu.Unlock() + } + return createRoutingKey(routingKeyInfo, q.qryOpts.values) +} + +func (q *internalQuery) Keyspace() string { + if q.qryOpts.getKeyspace != nil { + return q.qryOpts.getKeyspace() + } + + qrKs := q.routingInfo.getKeyspace() + if qrKs != "" { + return qrKs + } + if q.qryOpts.keyspace != "" { + return q.qryOpts.keyspace + } + + if q.session == nil { + return "" + } + // TODO(chbannis): this should be parsed from the query or we should let + // this be set by users. + return q.session.cfg.Keyspace +} + +func (q *internalQuery) Table() string { + return q.routingInfo.getTable() +} + +func (q *internalQuery) IsIdempotent() bool { + return q.qryOpts.idempotent +} + +func (q *internalQuery) getQueryMetrics() *queryMetrics { + return q.metrics +} + +func (q *internalQuery) SetConsistency(c Consistency) { + atomic.StoreUint32(&q.consistency, uint32(c)) +} + +func (q *internalQuery) GetConsistency() Consistency { + return Consistency(atomic.LoadUint32(&q.consistency)) +} + +func (q *internalQuery) Context() context.Context { + return q.qryOpts.context +} + +func (q *internalQuery) Statement() Statement { + return q.originalQuery +} + +func (q *internalQuery) GetHostID() string { + return q.qryOpts.hostID +} + +func (q *internalQuery) getRoutingInfo() *queryRoutingInfo { + return q.routingInfo +} + +func (q *internalQuery) getKeyspaceFunc() func() string { + return q.qryOpts.getKeyspace +} + +type batchOptions struct { + trace Tracer + observer BatchObserver + + bType BatchType + entries []BatchEntry + + defaultTimestamp bool + defaultTimestampValue int64 + + serialCons SerialConsistency + + customPayload map[string][]byte + rt RetryPolicy + spec SpeculativeExecutionPolicy + context context.Context + keyspace string + idempotent bool + routingKey []byte + nowInSeconds *int +} + +func newBatchOptions(b *Batch, ctx context.Context) *batchOptions { + // make a new array so if user keeps appending entries on the Batch object it doesn't affect this execution + newEntries := make([]BatchEntry, len(b.Entries)) + for i, e := range b.Entries { + newEntries[i] = e + } + var newRoutingKey []byte + if b.routingKey != nil { + routingKey := b.routingKey + newRoutingKey = make([]byte, len(routingKey)) + copy(newRoutingKey, routingKey) + } + if ctx == nil { + ctx = b.Context() + } + return &batchOptions{ + bType: b.Type, + entries: newEntries, + customPayload: b.CustomPayload, + rt: b.rt, + spec: b.spec, + trace: b.trace, + observer: b.observer, + serialCons: b.serialCons, + defaultTimestamp: b.defaultTimestamp, + defaultTimestampValue: b.defaultTimestampValue, + context: ctx, + keyspace: b.Keyspace(), + idempotent: b.IsIdempotent(), + routingKey: newRoutingKey, + nowInSeconds: b.nowInSeconds, + } +} + +type internalBatch struct { + originalBatch *Batch + batchOpts *batchOptions + metrics *queryMetrics + consistency uint32 + routingInfo *queryRoutingInfo + session *Session +} + +func newInternalBatch(batch *Batch, ctx context.Context) *internalBatch { + return &internalBatch{ + originalBatch: batch, + batchOpts: newBatchOptions(batch, ctx), + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + routingInfo: &queryRoutingInfo{}, + session: batch.session, + consistency: uint32(batch.GetConsistency()), + } +} + +// Attempts returns the number of attempts made to execute the batch. +func (b *internalBatch) Attempts() int { + return b.metrics.attempts() +} + +func (b *internalBatch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + latency := end.Sub(start) + attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.batchOpts.observer != nil) + + if b.batchOpts.observer == nil { + return + } + + statements := make([]string, len(b.batchOpts.entries)) + values := make([][]interface{}, len(b.batchOpts.entries)) + + for i, entry := range b.batchOpts.entries { + statements[i] = entry.Stmt + values[i] = entry.Args + } + + b.batchOpts.observer.ObserveBatch(b.batchOpts.context, ObservedBatch{ + Keyspace: keyspace, + Statements: statements, + Values: values, + Start: start, + End: end, + // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS + Host: host, + Metrics: metricsForHost, + Err: iter.err, + Attempt: attempt, + Batch: b.originalBatch, + }) +} + +func (b *internalBatch) retryPolicy() RetryPolicy { + return b.batchOpts.rt +} + +func (b *internalBatch) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return b.batchOpts.spec +} + +func (b *internalBatch) GetRoutingKey() ([]byte, error) { + if b.batchOpts.routingKey != nil { + return b.batchOpts.routingKey, nil + } + + if len(b.batchOpts.entries) == 0 { + return nil, nil + } + + entry := b.batchOpts.entries[0] + if entry.binding != nil { + // bindings do not have the values let's skip it like Query does. + return nil, nil + } + // try to determine the routing key + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.batchOpts.keyspace) + if err != nil { + return nil, err + } + + if routingKeyInfo != nil { + b.routingInfo.mu.Lock() + b.routingInfo.keyspace = routingKeyInfo.keyspace + b.routingInfo.table = routingKeyInfo.table + b.routingInfo.mu.Unlock() + } + + return createRoutingKey(routingKeyInfo, entry.Args) +} + +func (b *internalBatch) Keyspace() string { + return b.batchOpts.keyspace +} + +func (b *internalBatch) Table() string { + return b.routingInfo.getTable() +} + +func (b *internalBatch) IsIdempotent() bool { + return b.batchOpts.idempotent +} + +func (b *internalBatch) getQueryMetrics() *queryMetrics { + return b.metrics +} + +func (b *internalBatch) SetConsistency(c Consistency) { + atomic.StoreUint32(&b.consistency, uint32(c)) +} + +func (b *internalBatch) GetConsistency() Consistency { + return Consistency(atomic.LoadUint32(&b.consistency)) +} + +func (b *internalBatch) Context() context.Context { + return b.batchOpts.context +} + +func (b *internalBatch) Statement() Statement { + return b.originalBatch +} + +func (b *internalBatch) GetHostID() string { + return "" +} + +func (b *internalBatch) getRoutingInfo() *queryRoutingInfo { + return b.routingInfo +} + +func (b *internalBatch) getKeyspaceFunc() func() string { + return nil +} + +func (b *internalBatch) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeBatch(ctx, b) } diff --git a/session.go b/session.go index bdda06406..dcd0b6207 100644 --- a/session.go +++ b/session.go @@ -104,12 +104,6 @@ type Session struct { logger StdLogger } -var queryPool = &sync.Pool{ - New: func() interface{} { - return &Query{routingInfo: &queryRoutingInfo{}, refCount: 1} - }, -} - func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) { var hosts []*HostInfo for _, hostaddr := range addrs { @@ -386,7 +380,7 @@ func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { return errNoControl } return s.control.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(ctx)} + return newErrIter(conn.awaitSchemaAgreement(ctx), newQueryMetrics(), "", nil, nil) }).err } @@ -426,7 +420,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) { // value before the query is executed. Query is automatically prepared // if it has not previously been executed. func (s *Session) Query(stmt string, values ...interface{}) *Query { - qry := queryPool.Get().(*Query) + qry := &Query{} qry.session = s qry.stmt = stmt qry.values = values @@ -449,7 +443,7 @@ type QueryInfo struct { // During execution, the meta data of the prepared query will be routed to the // binding callback, which is responsible for producing the query argument values. func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query { - qry := queryPool.Get().(*Query) + qry := &Query{} qry.session = s qry.stmt = stmt qry.binding = b @@ -512,15 +506,15 @@ func (s *Session) initialized() bool { return initialized } -func (s *Session) executeQuery(qry *Query) (it *Iter) { +func (s *Session) executeQuery(qry *internalQuery) (it *Iter) { // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return newErrIter(ErrSessionClosed, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } iter, err := s.executor.executeQuery(qry) if err != nil { - return &Iter{err: err} + return newErrIter(err, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } if iter == nil { panic("nil iter") @@ -713,33 +707,47 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace stri return routingKeyInfo, nil } -func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { - return conn.executeBatch(ctx, b) -} - // Exec executes a batch operation and returns nil if successful // otherwise an error is returned describing the failure. func (b *Batch) Exec() error { - iter := b.session.executeBatch(b) + iter := b.session.executeBatch(b, nil) + return iter.Close() +} + +// ExecContext executes a batch operation with the provided context and returns nil if successful +// otherwise an error is returned describing the failure. +func (b *Batch) ExecContext(ctx context.Context) error { + iter := b.session.executeBatch(b, ctx) return iter.Close() } -func (s *Session) executeBatch(batch *Batch) *Iter { +// Iter executes a batch operation and returns an Iter object +// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency +func (b *Batch) Iter() *Iter { return b.IterContext(nil) } + +// IterContext executes a batch operation with the provided context and returns an Iter object +// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency +func (b *Batch) IterContext(ctx context.Context) *Iter { + return b.session.executeBatch(b, ctx) +} + +func (s *Session) executeBatch(batch *Batch, ctx context.Context) *Iter { + b := newInternalBatch(batch, ctx) // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return newErrIter(ErrSessionClosed, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } // Prevent the execution of the batch if greater than the limit // Currently batches have a limit of 65536 queries. // https://datastax-oss.atlassian.net/browse/JAVA-229 if batch.Size() > BatchSizeMaximum { - return &Iter{err: ErrTooManyStmts} + return newErrIter(ErrTooManyStmts, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } - iter, err := s.executor.executeQuery(batch) + iter, err := s.executor.executeQuery(b) if err != nil { - return &Iter{err: err} + return newErrIter(err, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } return iter @@ -749,7 +757,7 @@ func (s *Session) executeBatch(batch *Batch) *Iter { // ExecuteBatch executes a batch operation and returns nil if successful // otherwise an error is returned describing the failure. func (s *Session) ExecuteBatch(batch *Batch) error { - iter := s.executeBatch(batch) + iter := s.executeBatch(batch, nil) return iter.Close() } @@ -769,7 +777,7 @@ func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bo // Further scans on the interator must also remember to include // the applied boolean as the first argument to *Iter.Scan func (b *Batch) ExecCAS(dest ...interface{}) (applied bool, iter *Iter, err error) { - iter = b.session.executeBatch(b) + iter = b.session.executeBatch(b, nil) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err @@ -797,7 +805,7 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) // however it accepts a map rather than a list of arguments for the initial // scan. func (b *Batch) MapExecCAS(dest map[string]interface{}) (applied bool, iter *Iter, err error) { - iter = b.session.executeBatch(b) + iter = b.session.executeBatch(b, nil) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err @@ -835,6 +843,10 @@ type queryMetrics struct { totalAttempts int } +func newQueryMetrics() *queryMetrics { + return &queryMetrics{m: make(map[string]*hostMetrics)} +} + // preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} @@ -921,15 +933,14 @@ func (qm *queryMetrics) attempt(addAttempts int, addLatency time.Duration, type Query struct { stmt string values []interface{} - cons Consistency + initialConsistency Consistency pageSize int routingKey []byte - pageState []byte + initialPageState []byte prefetch float64 trace Tracer observer QueryObserver session *Session - conn *Conn rt RetryPolicy spec SpeculativeExecutionPolicy binding func(q *QueryInfo) ([]interface{}, error) @@ -940,8 +951,6 @@ type Query struct { context context.Context idempotent bool customPayload map[string][]byte - metrics *queryMetrics - refCount uint32 disableAutoPage bool @@ -952,9 +961,6 @@ type Query struct { // tables in AWS MCS see skipPrepare bool - // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. - routingInfo *queryRoutingInfo - // hostID specifies the host on which the query should be executed. // If it is empty, then the host is picked by HostSelectionPolicy hostID string @@ -972,10 +978,22 @@ type queryRoutingInfo struct { table string } +func (qr *queryRoutingInfo) getKeyspace() string { + qr.mu.RLock() + defer qr.mu.RUnlock() + return qr.keyspace +} + +func (qr *queryRoutingInfo) getTable() string { + qr.mu.RLock() + defer qr.mu.RUnlock() + return qr.table +} + func (q *Query) defaultsFromSession() { s := q.session - q.cons = s.cons + q.initialConsistency = s.cons q.pageSize = s.pageSize q.trace = s.trace q.observer = s.queryObserver @@ -984,7 +1002,6 @@ func (q *Query) defaultsFromSession() { q.serialCons = s.cfg.SerialConsistency q.defaultTimestamp = s.cfg.DefaultTimestamp q.idempotent = s.cfg.DefaultIdempotence - q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} q.spec = &NonSpeculativeExecution{} } @@ -1002,47 +1019,30 @@ func (q Query) Values() []interface{} { // String implements the stringer interface. func (q Query) String() string { - return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.cons) -} - -// Attempts returns the number of times the query was executed. -func (q *Query) Attempts() int { - return q.metrics.attempts() -} - -func (q *Query) AddAttempts(i int, host *HostInfo) { - q.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average amount of nanoseconds per attempt of the query. -func (q *Query) Latency() int64 { - return q.metrics.latency() -} - -func (q *Query) AddLatency(l int64, host *HostInfo) { - q.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) + return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.initialConsistency) } // Consistency sets the consistency level for this query. If no consistency // level have been set, the default consistency level of the cluster // is used. func (q *Query) Consistency(c Consistency) *Query { - q.cons = c + q.initialConsistency = c return q } // GetConsistency returns the currently configured consistency level for // the query. func (q *Query) GetConsistency() Consistency { - return q.cons + return q.initialConsistency } -// Same as Consistency but without a return value +// Deprecated: use Query.Consistency instead func (q *Query) SetConsistency(c Consistency) { - q.cons = c + q.initialConsistency = c } -// CustomPayload sets the custom payload level for this query. +// CustomPayload sets the custom payload level for this query. The map is not copied internally +// so it shouldn't be modified after the query is scheduled for execution. func (q *Query) CustomPayload(customPayload map[string][]byte) *Query { q.customPayload = customPayload return q @@ -1108,11 +1108,8 @@ func (q *Query) RoutingKey(routingKey []byte) *Query { return q } -func (q *Query) withContext(ctx context.Context) ExecutableQuery { - // I really wish go had covariant types - return q.WithContext(ctx) -} - +// Deprecated: Use Query.ExecContext or Query.IterContext instead. This will be removed in a future major version. +// // WithContext returns a shallow copy of q with its context // set to ctx. // @@ -1125,47 +1122,11 @@ func (q *Query) WithContext(ctx context.Context) *Query { return &q2 } -// Deprecate: does nothing, cancel the context passed to WithContext -func (q *Query) Cancel() { - // TODO: delete -} - -func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { - return conn.executeQuery(ctx, q) -} - -func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { - latency := end.Sub(start) - attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.observer != nil) - - if q.observer != nil { - q.observer.ObserveQuery(q.Context(), ObservedQuery{ - Keyspace: keyspace, - Statement: q.stmt, - Values: q.values, - Start: start, - End: end, - Rows: iter.numRows, - Host: host, - Metrics: metricsForHost, - Err: iter.err, - Attempt: attempt, - }) - } -} - -func (q *Query) retryPolicy() RetryPolicy { - return q.rt -} - // Keyspace returns the keyspace the query will be executed against. func (q *Query) Keyspace() string { if q.getKeyspace != nil { return q.getKeyspace() } - if q.routingInfo.keyspace != "" { - return q.routingInfo.keyspace - } if q.keyspace != "" { return q.keyspace } @@ -1178,45 +1139,12 @@ func (q *Query) Keyspace() string { return q.session.cfg.Keyspace } -// Table returns name of the table the query will be executed against. -func (q *Query) Table() string { - return q.routingInfo.table -} - -// GetRoutingKey gets the routing key to use for routing this query. If -// a routing key has not been explicitly set, then the routing key will -// be constructed if possible using the keyspace's schema and the query -// info for this query statement. If the routing key cannot be determined -// then nil will be returned with no error. On any error condition, -// an error description will be returned. -func (q *Query) GetRoutingKey() ([]byte, error) { - if q.routingKey != nil { - return q.routingKey, nil - } else if q.binding != nil && len(q.values) == 0 { - // If this query was created using session.Bind we wont have the query - // values yet, so we have to pass down to the next policy. - // TODO: Remove this and handle this case - return nil, nil - } - - // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) - if err != nil { - return nil, err - } - - if routingKeyInfo != nil { - q.routingInfo.mu.Lock() - q.routingInfo.keyspace = routingKeyInfo.keyspace - q.routingInfo.table = routingKeyInfo.table - q.routingInfo.mu.Unlock() - } - return createRoutingKey(routingKeyInfo, q.values) -} - func (q *Query) shouldPrepare() bool { + return shouldPrepare(q.stmt) +} - stmt := strings.TrimLeftFunc(strings.TrimRightFunc(q.stmt, func(r rune) bool { +func shouldPrepare(s string) bool { + stmt := strings.TrimLeftFunc(strings.TrimRightFunc(s, func(r rune) bool { return unicode.IsSpace(r) || r == ';' }), unicode.IsSpace) @@ -1256,11 +1184,6 @@ func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Qu return q } -// speculativeExecutionPolicy fetches the policy -func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy { - return q.spec -} - // IsIdempotent returns whether the query is marked as idempotent. // Non-idempotent query won't be retried. // See "Retries and speculative execution" in package docs for more details. @@ -1281,7 +1204,6 @@ func (q *Query) Idempotent(value bool) *Query { // to an existing query instance. func (q *Query) Bind(v ...interface{}) *Query { q.values = v - q.pageState = nil return q } @@ -1302,7 +1224,7 @@ func (q *Query) SerialConsistency(cons Consistency) *Query { // point in time. Setting this will disable to query paging for this query, and // must be used for all subsequent pages. func (q *Query) PageState(state []byte) *Query { - q.pageState = state + q.initialPageState = state q.disableAutoPage = true return q } @@ -1325,6 +1247,11 @@ func (q *Query) Exec() error { return q.Iter().Close() } +// ExecContext executes the query with the provided context without returning any rows. +func (q *Query) ExecContext(ctx context.Context) error { + return q.IterContext(ctx).Close() +} + func isUseStatement(stmt string) bool { if len(stmt) < 3 { return false @@ -1336,15 +1263,25 @@ func isUseStatement(stmt string) bool { // Iter executes the query and returns an iterator capable of iterating // over all results. func (q *Query) Iter() *Iter { + return q.IterContext(nil) +} + +// IterContext executes the query with the provided context and returns an iterator capable of iterating +// over all results. +func (q *Query) IterContext(ctx context.Context) *Iter { if isUseStatement(q.stmt) { - return &Iter{err: ErrUseStmt} + return newErrIter(ErrUseStmt, newQueryMetrics(), q.Keyspace(), nil, q.getKeyspace) } - // if the query was specifically run on a connection then re-use that - // connection when fetching the next results - if q.conn != nil { - return q.conn.executeQuery(q.Context(), q) - } - return q.session.executeQuery(q) + + internalQry := newInternalQuery(q, ctx) + return q.session.executeQuery(internalQry) +} + +func (q *Query) iterInternal(c *Conn, ctx context.Context) *Iter { + internalQry := newInternalQuery(q, ctx) + internalQry.conn = c + + return c.executeQuery(internalQry.Context(), internalQry) } // MapScan executes the query, copies the columns of the first selected @@ -1421,43 +1358,6 @@ func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error return applied, iter.Close() } -// Release releases a query back into a pool of queries. Released Queries -// cannot be reused. -// -// Example: -// -// qry := session.Query("SELECT * FROM my_table") -// qry.Exec() -// qry.Release() -func (q *Query) Release() { - q.decRefCount() -} - -// reset zeroes out all fields of a query so that it can be safely pooled. -func (q *Query) reset() { - *q = Query{routingInfo: &queryRoutingInfo{}, refCount: 1} -} - -func (q *Query) incRefCount() { - atomic.AddUint32(&q.refCount, 1) -} - -func (q *Query) decRefCount() { - if res := atomic.AddUint32(&q.refCount, ^uint32(0)); res == 0 { - // do release - q.reset() - queryPool.Put(q) - } -} - -func (q *Query) borrowForExecution() { - q.incRefCount() -} - -func (q *Query) releaseAfterExecution() { - q.decRefCount() -} - // SetHostID allows to define the host the query should be executed against. If the // host was filtered or otherwise unavailable, then the query will error. If an empty // string is sent, the default behavior, using the configured HostSelectionPolicy will @@ -1490,9 +1390,13 @@ func (q *Query) WithNowInSeconds(now int) *Query { return q } -// Iter represents an iterator that can be used to iterate over all rows that -// were returned by a query. The iterator might send additional queries to the +// Iter represents the result that was returned by the execution of a statement. +// +// If the statement is a query then this can be seen as an iterator that can be used to iterate over all rows that +// were returned by the query. The iterator might send additional queries to the // database during the iteration if paging was enabled. +// +// It also contains metadata about the request that can be accessed by Iter.Keyspace(), Iter.Table(), Iter.Attempts(), Iter.Latency(). type Iter struct { err error pos int @@ -1500,12 +1404,27 @@ type Iter struct { numRows int next *nextIter host *HostInfo + metrics *queryMetrics + + getKeyspace func() string + keyspace string + routingInfo *queryRoutingInfo framer *framer closed int32 } -// Host returns the host which the query was sent to. +func newErrIter(err error, metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter { + iter := newIter(metrics, keyspace, routingInfo, getKeyspace) + iter.err = err + return iter +} + +func newIter(metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter { + return &Iter{metrics: metrics, keyspace: keyspace, routingInfo: routingInfo, getKeyspace: getKeyspace} +} + +// Host returns the host which the statement was sent to. func (iter *Iter) Host() *HostInfo { return iter.host } @@ -1515,6 +1434,39 @@ func (iter *Iter) Columns() []ColumnInfo { return iter.meta.columns } +// Attempts returns the number of times the statement was executed. +func (iter *Iter) Attempts() int { + return iter.metrics.attempts() +} + +// Latency returns the average amount of nanoseconds per attempt of the statement. +func (iter *Iter) Latency() int64 { + return iter.metrics.latency() +} + +// Keyspace returns the keyspace the statement was executed against if the driver could determine it. +func (iter *Iter) Keyspace() string { + if iter.getKeyspace != nil { + return iter.getKeyspace() + } + + if iter.routingInfo != nil { + if ks := iter.routingInfo.getKeyspace(); ks != "" { + return ks + } + } + + return iter.keyspace +} + +// Table returns name of the table the statement was executed against if the driver could determine it. +func (iter *Iter) Table() string { + if iter.routingInfo != nil { + return iter.routingInfo.getTable() + } + return "" +} + type Scanner interface { // Next advances the row pointer to point at the next row, the row is valid until // the next call of Next. It returns true if there is a row which is available to be @@ -1765,7 +1717,7 @@ func (iter *Iter) NumRows() int { // nextIter holds state for fetching a single page in an iterator. // single page might be attempted multiple times due to retries. type nextIter struct { - qry *Query + q *internalQuery pos int oncea sync.Once once sync.Once @@ -1782,10 +1734,10 @@ func (n *nextIter) fetch() *Iter { n.once.Do(func() { // if the query was specifically run on a connection then re-use that // connection when fetching the next results - if n.qry.conn != nil { - n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry) + if n.q.conn != nil { + n.next = n.q.conn.executeQuery(n.q.qryOpts.context, n.q) } else { - n.next = n.qry.session.executeQuery(n.qry) + n.next = n.q.session.executeQuery(n.q) } }) return n.next @@ -1806,17 +1758,14 @@ type Batch struct { defaultTimestamp bool defaultTimestampValue int64 context context.Context - cancelBatch func() keyspace string - metrics *queryMetrics nowInSeconds *int - - // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. - routingInfo *queryRoutingInfo } // Deprecated: use Session.Batch instead // NewBatch creates a new batch operation using defaults defined in the cluster +// +// Deprecated: use Session.Batch instead func (s *Session) NewBatch(typ BatchType) *Batch { return s.Batch(typ) } @@ -1833,9 +1782,7 @@ func (s *Session) Batch(typ BatchType) *Batch { Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp, keyspace: s.cfg.Keyspace, - metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, spec: &NonSpeculativeExecution{}, - routingInfo: &queryRoutingInfo{}, } return batch @@ -1859,27 +1806,12 @@ func (b *Batch) Keyspace() string { return b.keyspace } -// Batch has no reasonable eqivalent of Query.Table(). -func (b *Batch) Table() string { - return b.routingInfo.table -} - -// Attempts returns the number of attempts made to execute the batch. -func (b *Batch) Attempts() int { - return b.metrics.attempts() -} - -func (b *Batch) AddAttempts(i int, host *HostInfo) { - b.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average number of nanoseconds to execute a single attempt of the batch. -func (b *Batch) Latency() int64 { - return b.metrics.latency() -} - -func (b *Batch) AddLatency(l int64, host *HostInfo) { - b.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) +// Consistency sets the consistency level for this batch. If no consistency +// level have been set, the default consistency level of the cluster +// is used. +func (b *Batch) Consistency(cons Consistency) *Batch { + b.Cons = cons + return b } // GetConsistency returns the currently configured consistency level for the batch @@ -1888,8 +1820,7 @@ func (b *Batch) GetConsistency() Consistency { return b.Cons } -// SetConsistency sets the currently configured consistency level for the batch -// operation. +// Deprecated: Use Batch.Consistency func (b *Batch) SetConsistency(c Consistency) { b.Cons = c } @@ -1932,20 +1863,14 @@ func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error) b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind}) } -func (b *Batch) retryPolicy() RetryPolicy { - return b.rt -} - // RetryPolicy sets the retry policy to use when executing the batch operation func (b *Batch) RetryPolicy(r RetryPolicy) *Batch { b.rt = r return b } -func (b *Batch) withContext(ctx context.Context) ExecutableQuery { - return b.WithContext(ctx) -} - +// Deprecated: Use Batch.ExecContext or Batch.IterContext instead. This will be removed in a future major version. +// // WithContext returns a shallow copy of b with its context // set to ctx. // @@ -1958,11 +1883,6 @@ func (b *Batch) WithContext(ctx context.Context) *Batch { return &b2 } -// Deprecate: does nothing, cancel the context passed to WithContext -func (*Batch) Cancel() { - // TODO: delete -} - // Size returns the number of batch statements to be executed by the batch operation. func (b *Batch) Size() int { return len(b.Entries) @@ -2006,59 +1926,6 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch { return b } -func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { - latency := end.Sub(start) - attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) - - if b.observer == nil { - return - } - - statements := make([]string, len(b.Entries)) - values := make([][]interface{}, len(b.Entries)) - - for i, entry := range b.Entries { - statements[i] = entry.Stmt - values[i] = entry.Args - } - - b.observer.ObserveBatch(b.Context(), ObservedBatch{ - Keyspace: keyspace, - Statements: statements, - Values: values, - Start: start, - End: end, - // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS - Host: host, - Metrics: metricsForHost, - Err: iter.err, - Attempt: attempt, - }) -} - -func (b *Batch) GetRoutingKey() ([]byte, error) { - if b.routingKey != nil { - return b.routingKey, nil - } - - if len(b.Entries) == 0 { - return nil, nil - } - - entry := b.Entries[0] - if entry.binding != nil { - // bindings do not have the values let's skip it like Query does. - return nil, nil - } - // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) - if err != nil { - return nil, err - } - - return createRoutingKey(routingKeyInfo, entry.Args) -} - func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]byte, error) { if routingKeyInfo == nil { return nil, nil @@ -2096,21 +1963,6 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]b return routingKey, nil } -func (b *Batch) borrowForExecution() { - // empty, because Batch has no equivalent of Query.Release() - // that would race with speculative executions. -} - -func (b *Batch) releaseAfterExecution() { - // empty, because Batch has no equivalent of Query.Release() - // that would race with speculative executions. -} - -// GetHostID satisfies ExecutableQuery interface but does noop. -func (b *Batch) GetHostID() string { - return "" -} - // SetKeyspace will enable keyspace flag on the query. // It allows to specify the keyspace that the query should be executed in // @@ -2295,6 +2147,9 @@ type ObservedQuery struct { // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int + + // Query object associated with this request. Should be used as read only. + Query *Query } // QueryObserver is the interface implemented by query observers / stat collectors. @@ -2332,6 +2187,9 @@ type ObservedBatch struct { // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int + + // Batch object associated with this request. Should be used as read only. + Batch *Batch } // BatchObserver is the interface implemented by batch observers / stat collectors. diff --git a/session_test.go b/session_test.go index 48f7fe7dc..d414d6f41 100644 --- a/session_test.go +++ b/session_test.go @@ -30,7 +30,6 @@ package gocql import ( "context" "fmt" - "net" "testing" ) @@ -69,7 +68,7 @@ func TestSessionAPI(t *testing.T) { t.Fatalf("expected qry.stmt to be 'test', got '%v'", boundQry.stmt) } - itr := s.executeQuery(qry) + itr := s.executeQuery(newInternalQuery(qry, nil)) if itr.err != ErrNoConnections { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } @@ -102,28 +101,7 @@ func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) { } func TestQueryBasicAPI(t *testing.T) { - qry := &Query{routingInfo: &queryRoutingInfo{}} - - // Initiate host - ip := "127.0.0.1" - - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) - if qry.Latency() != 0 { - t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency()) - } - - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) - if qry.Attempts() != 2 { - t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts()) - } - if qry.Latency() != 2 { - t.Fatalf("expected Query.Latency() to return 2, got %v", qry.Latency()) - } - - qry.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if qry.Attempts() != 4 { - t.Fatalf("expected Query.Attempts() to return 4, got %v", qry.Attempts()) - } + qry := &Query{} qry.Consistency(All) if qry.GetConsistency() != All { @@ -166,7 +144,7 @@ func TestQueryBasicAPI(t *testing.T) { func TestQueryShouldPrepare(t *testing.T) { toPrepare := []string{"select * ", "INSERT INTO", "update table", "delete from", "begin batch"} cantPrepare := []string{"create table", "USE table", "LIST keyspaces", "alter table", "drop table", "grant user", "revoke user"} - q := &Query{routingInfo: &queryRoutingInfo{}} + q := &Query{} for i := 0; i < len(toPrepare); i++ { q.stmt = toPrepare[i] @@ -210,29 +188,6 @@ func TestBatchBasicAPI(t *testing.T) { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) } - ip := "127.0.0.1" - - // Test attempts - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1}}) - if b.Attempts() != 1 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts()) - } - - b.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if b.Attempts() != 3 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 3, b.Attempts()) - } - - // Test latency - if b.Latency() != 0 { - t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency()) - } - - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1, TotalLatency: 4}}) - if b.Latency() != 4 { - t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency()) - } - // Test Consistency b.Cons = One if b.GetConsistency() != One { @@ -277,6 +232,43 @@ func TestBatchBasicAPI(t *testing.T) { } +func TestQueryIterBasicApi(t *testing.T) { + session := createSession(t) + defer session.Close() + + qry := session.Query("INSERT INTO gocql_test.invalid_table(value) VALUES(1)") + iter1 := qry.Iter() + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to return 1, got %v", iter1.Attempts()) + } + iter2 := qry.Iter() + if iter2.Attempts() != 1 { + t.Fatalf("expected iter2 Iter.Attempts() to return 1, got %v", iter2.Attempts()) + } + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to still return 1, got %v", iter1.Attempts()) + } +} + +func TestBatchIterBasicApi(t *testing.T) { + session := createSession(t) + defer session.Close() + + b := session.Batch(LoggedBatch) + b.Query("INSERT INTO gocql_test.invalid_table(value) VALUES(1)") + iter1 := b.Iter() + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to return 1, got %v", iter1.Attempts()) + } + iter2 := b.Iter() + if iter2.Attempts() != 1 { + t.Fatalf("expected iter2 Iter.Attempts() to return 1, got %v", iter2.Attempts()) + } + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to still return 1, got %v", iter1.Attempts()) + } +} + func TestConsistencyNames(t *testing.T) { names := map[fmt.Stringer]string{ Any: "ANY", diff --git a/stress_test.go b/stress_test.go index be41707aa..bf4f24874 100644 --- a/stress_test.go +++ b/stress_test.go @@ -29,7 +29,6 @@ package gocql import ( "sync/atomic" - "testing" ) @@ -82,7 +81,7 @@ func BenchmarkConnRoutingKey(b *testing.B) { query := session.Query("insert into routing_key_stress (id) values (?)") for pb.Next() { - if _, err := query.Bind(i * seed).GetRoutingKey(); err != nil { + if _, err := newInternalQuery(query.Bind(i*seed), nil).GetRoutingKey(); err != nil { b.Error(err) return }