From 68e805a42c1365c546db47fb9bb6233bb45b6447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Reis?= Date: Tue, 10 Jun 2025 19:36:58 +0100 Subject: [PATCH] Changes to Query and Batch API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this change queries were mutated while being executed (the query metrics and the consistency for example). Instead copy query properties to an internal query object and move query metrics to Iter. This allows users to reuse Query and Batch objects. Query object pooling was also removed. Some query and batch properties were not accessible via ObservedBatch and ObservedQuery. Added the original Batch and Query objects to ObservedBatch and ObservedQuery to fix this. Patch by João Reis; reviewed by James Hartig and Stanislav Bychkov for CASSGO-22 and CASSGO-73 --- CHANGELOG.md | 2 + batch_test.go | 3 +- cass1batch_test.go | 2 +- cassandra_test.go | 103 +++++---- conn.go | 219 +++++++++--------- conn_test.go | 14 +- control.go | 14 +- hostpool/hostpool.go | 2 +- keyspace_table_test.go | 52 ++++- lz4/lz4_test.go | 3 +- policies.go | 10 +- policies_test.go | 30 +-- query_executor.go | 490 +++++++++++++++++++++++++++++++++++++++-- session.go | 470 ++++++++++++++------------------------- session_test.go | 88 ++++---- stress_test.go | 3 +- 16 files changed, 942 insertions(+), 563 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b17c5f008..f642ef900 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for Native Protocol 5. Following protocol changes exposed new API Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1) - Externally-defined type registration (CASSGO-43) +- Add Query and Batch to ObservedQuery and ObservedBatch (CASSGO-73) ### Changed @@ -43,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - inet columns default to net.IP when using MapScan or SliceMap (CASSGO-43) - NativeType removed (CASSGO-43) - `New` and `NewWithError` removed and replaced with `Zero` (CASSGO-43) +- Changes to Query and Batch to make them safely reusable (CASSGO-22) ### Fixed diff --git a/batch_test.go b/batch_test.go index 47adff83f..7f7d00253 100644 --- a/batch_test.go +++ b/batch_test.go @@ -28,9 +28,10 @@ package gocql import ( - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestBatch_Errors(t *testing.T) { diff --git a/cass1batch_test.go b/cass1batch_test.go index 8e4eb99b2..b699a10bd 100644 --- a/cass1batch_test.go +++ b/cass1batch_test.go @@ -77,7 +77,7 @@ func TestShouldPrepareFunction(t *testing.T) { } for _, test := range shouldPrepareTests { - q := &Query{stmt: test.Stmt, routingInfo: &queryRoutingInfo{}} + q := &Query{stmt: test.Stmt} if got := q.shouldPrepare(); got != test.Result { t.Fatalf("%q: got %v, expected %v\n", test.Stmt, got, test.Result) } diff --git a/cassandra_test.go b/cassandra_test.go index 9fa2a0d1a..2613f17d4 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -1302,8 +1302,8 @@ func Test_RetryPolicyIdempotence(t *testing.T) { q.RetryPolicy(&MyRetryPolicy{}) q.Consistency(All) - _ = q.Exec() - require.Equal(t, tc.expectedNumberOfTries, q.Attempts()) + it := q.Iter() + require.Equal(t, tc.expectedNumberOfTries, it.Attempts()) }) } } @@ -1673,7 +1673,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { defer s.Close() insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5) - if err := conn.executeQuery(ctx, insertQry).err; err == nil { + if err := conn.executeQuery(ctx, newInternalQuery(insertQry, nil)).err; err == nil { t.Fatal("expected error, but got nil.") } @@ -1681,7 +1681,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { t.Fatal("create table:", err) } - if err := conn.executeQuery(ctx, insertQry).err; err != nil { + if err := conn.executeQuery(ctx, newInternalQuery(insertQry, nil)).err; err != nil { t.Fatal(err) // unconfigured columnfamily } } @@ -1695,7 +1695,7 @@ func TestPrepare_ReprepareStatement(t *testing.T) { stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement") query := session.Query(stmt, "bar") - if err := conn.executeQuery(ctx, query).Close(); err != nil { + if err := conn.executeQuery(ctx, newInternalQuery(query, nil)).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -1714,7 +1714,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) { stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") batch := session.Batch(UnloggedBatch) batch.Query(stmt, "bar") - if err := conn.executeBatch(ctx, batch).Close(); err != nil { + if err := conn.executeBatch(ctx, newInternalBatch(batch, nil)).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -2059,14 +2059,16 @@ func TestQueryStats(t *testing.T) { session := createSession(t) defer session.Close() qry := session.Query("SELECT * FROM system.peers") - if err := qry.Exec(); err != nil { + iter := qry.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if qry.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if qry.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", qry.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -2099,15 +2101,16 @@ func TestBatchStats(t *testing.T) { b := session.Batch(LoggedBatch) b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) - - if err := b.Exec(); err != nil { + iter := b.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -2850,7 +2853,7 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected cache size to be 1 but was %d", cacheSize) } - query := session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2) + query := newInternalQuery(session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil) routingKey, err := query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) @@ -2892,7 +2895,7 @@ func TestRoutingKey(t *testing.T) { t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[1].Type()) } - query = session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2) + query = newInternalQuery(session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil) routingKey, err = query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) @@ -3441,15 +3444,16 @@ func TestUnsetColBatch(t *testing.T) { b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) - - if err := b.Exec(); err != nil { + iter := b.Iter() + err := iter.Close() + if err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } var id, mInt, count int @@ -3702,7 +3706,9 @@ func TestQueryCompressionNotWorthIt(t *testing.T) { // The driver should handle this by updating its prepared statement inside the cache // when it receives RESULT/ROWS with Metadata_changed flag func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { - session := createSession(t) + session := createSession(t, func(config *ClusterConfig) { + config.NumConns = 1 + }) defer session.Close() if session.cfg.ProtoVersion < protoVersion5 { @@ -3726,13 +3732,17 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { t.Fatal(err) } - // We have to specify conn for all queries to ensure that + // We have to specify host for all queries to ensure that // all queries are running on the same node - conn := session.getConn() + hosts := session.GetHosts() + if len(hosts) == 0 { + t.Fatal("no hosts found") + } + hostid := hosts[0].HostID() const selectStmt = "SELECT * FROM gocql_test.metadata_changed" queryBeforeTableAltering := session.Query(selectStmt) - queryBeforeTableAltering.conn = conn + queryBeforeTableAltering.SetHostID(hostid) row := make(map[string]interface{}) err = queryBeforeTableAltering.MapScan(row) if err != nil { @@ -3742,13 +3752,16 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { require.Len(t, row, 1, "Expected to retrieve a single column") require.Equal(t, 1, row["id"]) - stmtCacheKey := session.stmtsLRU.keyFor(conn.host.HostID(), conn.currentKeyspace, queryBeforeTableAltering.stmt) - inflight, _ := session.stmtsLRU.get(stmtCacheKey) + stmtCacheKey := session.stmtsLRU.keyFor(hostid, "gocql_test", queryBeforeTableAltering.stmt) + inflight, ok := session.stmtsLRU.get(stmtCacheKey) + if !ok { + t.Fatalf("failed to find inflight entry for key %v", stmtCacheKey) + } preparedStatementBeforeTableAltering := inflight.preparedStatment // Changing table schema in order to cause C* to return RESULT/ROWS Metadata_changed alteringTableQuery := session.Query("ALTER TABLE gocql_test.metadata_changed ADD new_col int") - alteringTableQuery.conn = conn + alteringTableQuery.SetHostID(hostid) err = alteringTableQuery.Exec() if err != nil { t.Fatal(err) @@ -3800,7 +3813,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { // Expecting C* will return RESULT/ROWS Metadata_changed // and it will be properly handled queryAfterTableAltering := session.Query(selectStmt) - queryAfterTableAltering.conn = conn + queryAfterTableAltering.SetHostID(hostid) iter := queryAfterTableAltering.Iter() handleRows(iter) @@ -3825,7 +3838,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { defer cancel() queryAfterTableAltering2 := session.Query(selectStmt).WithContext(ctx) - queryAfterTableAltering2.conn = conn + queryAfterTableAltering2.SetHostID(hostid) iter = queryAfterTableAltering2.Iter() handleRows(iter) err = iter.Close() @@ -3842,7 +3855,7 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) { // Executing prepared stmt and expecting that C* won't return // Metadata_changed because the table is not being changed. queryAfterTableAltering3 := session.Query(selectStmt).WithContext(ctx) - queryAfterTableAltering3.conn = conn + queryAfterTableAltering3.SetHostID(hostid) iter = queryAfterTableAltering2.Iter() handleRows(iter) @@ -3946,7 +3959,10 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { getRoutingKeyInfo := func(key string) *routingKeyInfo { t.Helper() session.routingKeyInfoCache.mu.Lock() - value, _ := session.routingKeyInfoCache.lru.Get(key) + value, ok := session.routingKeyInfoCache.lru.Get(key) + if !ok { + t.Fatalf("routing key not found in cache for key %v", key) + } session.routingKeyInfoCache.mu.Unlock() inflight := value.(*inflightCachedEntry) @@ -3956,9 +3972,10 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" // Running batch in default ks gocql_test - b1 := session.NewBatch(LoggedBatch) + b1 := session.Batch(LoggedBatch) b1.Query(insertQuery, 1) - _, err = b1.GetRoutingKey() + internalB := newInternalBatch(b1, nil) + _, err = internalB.GetRoutingKey() require.NoError(t, err) // Ensuring that the cache contains the query with default ks @@ -3966,10 +3983,11 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) // Running batch in gocql_test_routing_key_cache ks - b2 := session.NewBatch(LoggedBatch) + b2 := session.Batch(LoggedBatch) b2.SetKeyspace("gocql_test_routing_key_cache") b2.Query(insertQuery, 2) - _, err = b2.GetRoutingKey() + internalB2 := newInternalBatch(b2, nil) + _, err = internalB2.GetRoutingKey() require.NoError(t, err) // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks @@ -3980,15 +3998,18 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { // Running query in default ks gocql_test q1 := session.Query(selectStmt, 1) - _, err = q1.GetRoutingKey() + iter := q1.Iter() + err = iter.Close() require.NoError(t, err) - require.Equal(t, "gocql_test", q1.routingInfo.keyspace) + require.Equal(t, "gocql_test", iter.Keyspace()) // Running query in gocql_test_routing_key_cache ks q2 := session.Query(selectStmt, 1) - _, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey() + q2.SetKeyspace("gocql_test_routing_key_cache") + iter = q2.Iter() + err = iter.Close() require.NoError(t, err) - require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace) + require.Equal(t, "gocql_test_routing_key_cache", iter.Keyspace()) session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec() } diff --git a/conn.go b/conn.go index c9efe723b..4e1bb4f2e 100644 --- a/conn.go +++ b/conn.go @@ -1497,32 +1497,34 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error return nil } -func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { +func (c *Conn) executeQuery(ctx context.Context, q *internalQuery) *Iter { + qryOpts := q.qryOpts params := queryParams{ - consistency: qry.cons, + consistency: q.GetConsistency(), } + iter := newIter(q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace) // frame checks that it is not 0 - params.serialConsistency = qry.serialCons - params.defaultTimestamp = qry.defaultTimestamp - params.defaultTimestampValue = qry.defaultTimestampValue + params.serialConsistency = qryOpts.serialCons + params.defaultTimestamp = qryOpts.defaultTimestamp + params.defaultTimestampValue = qryOpts.defaultTimestampValue - if len(qry.pageState) > 0 { - params.pagingState = qry.pageState + if len(q.pageState) > 0 { + params.pagingState = q.pageState } - if qry.pageSize > 0 { - params.pageSize = qry.pageSize + if qryOpts.pageSize > 0 { + params.pageSize = qryOpts.pageSize } if c.version > protoVersion4 { - params.keyspace = qry.keyspace - params.nowInSeconds = qry.nowInSecondsValue + params.keyspace = qryOpts.keyspace + params.nowInSeconds = qryOpts.nowInSecondsValue } // If a keyspace for the qry is overriden, // then we should use it to create stmt cache key usedKeyspace := c.currentKeyspace - if qry.keyspace != "" { - usedKeyspace = qry.keyspace + if qryOpts.keyspace != "" { + usedKeyspace = qryOpts.keyspace } var ( @@ -1530,17 +1532,18 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { info *preparedStatment ) - if !qry.skipPrepare && qry.shouldPrepare() { + if !qryOpts.skipPrepare && shouldPrepare(qryOpts.stmt) { // Prepare all DML queries. Other queries can not be prepared. var err error - info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace) + info, err = c.prepareStatement(ctx, qryOpts.stmt, qryOpts.trace, usedKeyspace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } - values := qry.values - if qry.binding != nil { - values, err = qry.binding(&QueryInfo{ + values := qryOpts.values + if qryOpts.binding != nil { + values, err = qryOpts.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, @@ -1548,12 +1551,14 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { }) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} + iter.err = fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values)) + return iter } params.values = make([]queryValues, len(values)) @@ -1562,59 +1567,63 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { value := values[i] typ := info.request.columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + iter.err = err + return iter } } // if the metadata was not present in the response then we should not skip it - params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData == 0 + params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qryOpts.disableSkipMetadata) && info != nil && info.response.flags&flagNoMetaData == 0 frame = &writeExecuteFrame{ preparedID: info.id, params: params, - customPayload: qry.customPayload, + customPayload: qryOpts.customPayload, resultMetadataID: info.resultMetadataID, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata - qry.routingInfo.mu.Lock() - qry.routingInfo.keyspace = info.request.keyspace + q.routingInfo.mu.Lock() + q.routingInfo.keyspace = info.request.keyspace if info.request.keyspace == "" { - qry.routingInfo.keyspace = usedKeyspace + q.routingInfo.keyspace = usedKeyspace } - qry.routingInfo.table = info.request.table - qry.routingInfo.mu.Unlock() + q.routingInfo.table = info.request.table + q.routingInfo.mu.Unlock() } else { frame = &writeQueryFrame{ - statement: qry.stmt, + statement: qryOpts.stmt, params: params, - customPayload: qry.customPayload, + customPayload: qryOpts.customPayload, } } - framer, err := c.exec(ctx, frame, qry.trace) + framer, err := c.exec(ctx, frame, qryOpts.trace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err} + iter.err = err + return iter } - if len(framer.traceID) > 0 && qry.trace != nil { - qry.trace.Trace(framer.traceID) + if len(framer.traceID) > 0 && qryOpts.trace != nil { + qryOpts.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: - return &Iter{framer: framer} + iter.framer = framer + return iter case *resultRowsFrame: if x.meta.newMetadataID != nil { // If a RESULT/Rows message reports // changed resultset metadata with the Metadata_changed flag, the reported new // resultset metadata must be used in subsequent executions - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qryOpts.stmt) oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey) if ok { newInflight := &inflightPrepare{ @@ -1635,33 +1644,32 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { info = newInflight.preparedStatment } } - - iter := &Iter{ - meta: x.meta, - framer: framer, - numRows: x.numRows, - } + iter.meta = x.meta + iter.framer = framer + iter.numRows = x.numRows if x.meta.noMetaData() { if info != nil { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { - return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} + iter = newErrIter(errors.New("gocql: did not receive metadata but prepared info is nil"), q.metrics, q.Keyspace(), q.routingInfo, q.qryOpts.getKeyspace) + iter.framer = framer + return iter } } else { iter.meta = x.meta } - if x.meta.morePages() && !qry.disableAutoPage { - newQry := new(Query) - *newQry = *qry + if x.meta.morePages() && !qryOpts.disableAutoPage { + newQry := new(internalQuery) + *newQry = *q newQry.pageState = copyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ - qry: newQry, - pos: int((1 - qry.prefetch) * float64(x.numRows)), + q: newQry, + pos: int((1 - qryOpts.prefetch) * float64(x.numRows)), } if iter.next.pos < 1 { @@ -1671,9 +1679,10 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { return iter case *resultKeyspaceFrame: - return &Iter{framer: framer} + iter.framer = framer + return iter case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: - iter := &Iter{framer: framer} + iter.framer = framer if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag c.logger.Println(err) @@ -1683,16 +1692,17 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: - stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt) + stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qryOpts.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) - return c.executeQuery(ctx, qry) + return c.executeQuery(ctx, q) case error: - return &Iter{err: x, framer: framer} + iter.err = x + iter.framer = framer + return iter default: - return &Iter{ - err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), - framer: framer, - } + iter.err = NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x) + iter.framer = framer + return iter } } @@ -1744,38 +1754,40 @@ func (c *Conn) UseKeyspace(keyspace string) error { return nil } -func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { - n := len(batch.Entries) +func (c *Conn) executeBatch(ctx context.Context, b *internalBatch) *Iter { + iter := newIter(b.metrics, b.Keyspace(), b.routingInfo, nil) + n := len(b.batchOpts.entries) req := &writeBatchFrame{ - typ: batch.Type, + typ: b.batchOpts.bType, statements: make([]batchStatment, n), - consistency: batch.Cons, - serialConsistency: batch.serialCons, - defaultTimestamp: batch.defaultTimestamp, - defaultTimestampValue: batch.defaultTimestampValue, - customPayload: batch.CustomPayload, + consistency: b.GetConsistency(), + serialConsistency: b.batchOpts.serialCons, + defaultTimestamp: b.batchOpts.defaultTimestamp, + defaultTimestampValue: b.batchOpts.defaultTimestampValue, + customPayload: b.batchOpts.customPayload, } if c.version > protoVersion4 { - req.keyspace = batch.keyspace - req.nowInSeconds = batch.nowInSeconds + req.keyspace = b.batchOpts.keyspace + req.nowInSeconds = b.batchOpts.nowInSeconds } usedKeyspace := c.currentKeyspace - if batch.keyspace != "" { - usedKeyspace = batch.keyspace + if b.batchOpts.keyspace != "" { + usedKeyspace = b.batchOpts.keyspace } - stmts := make(map[string]string, len(batch.Entries)) + stmts := make(map[string]string, len(b.batchOpts.entries)) for i := 0; i < n; i++ { - entry := &batch.Entries[i] - b := &req.statements[i] + entry := &b.batchOpts.entries[i] + batchStmt := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace) + info, err := c.prepareStatement(ctx, entry.Stmt, b.batchOpts.trace, usedKeyspace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } var values []interface{} @@ -1789,68 +1801,75 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { PKeyColumns: info.request.pkeyColumns, }) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} + iter.err = fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values)) + return iter } - b.preparedID = info.id + batchStmt.preparedID = info.id stmts[string(info.id)] = entry.Stmt - b.values = make([]queryValues, info.request.actualColCount) + batchStmt.values = make([]queryValues, info.request.actualColCount) for j := 0; j < info.request.actualColCount; j++ { - v := &b.values[j] + v := &batchStmt.values[j] value := values[j] typ := info.request.columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + iter.err = err + return iter } } } else { - b.statement = entry.Stmt + batchStmt.statement = entry.Stmt } } - framer, err := c.exec(batch.Context(), req, batch.trace) + framer, err := c.exec(ctx, req, b.batchOpts.trace) if err != nil { - return &Iter{err: err} + iter.err = err + return iter } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err, framer: framer} + iter.err = err + iter.framer = framer + return iter } - if len(framer.traceID) > 0 && batch.trace != nil { - batch.trace.Trace(framer.traceID) + if len(framer.traceID) > 0 && b.batchOpts.trace != nil { + b.batchOpts.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: - return &Iter{} + return iter case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } - return c.executeBatch(ctx, batch) + return c.executeBatch(ctx, b) case *resultRowsFrame: - iter := &Iter{ - meta: x.meta, - framer: framer, - numRows: x.numRows, - } - + iter.meta = x.meta + iter.framer = framer + iter.numRows = x.numRows return iter case error: - return &Iter{err: x, framer: framer} + iter.err = x + iter.framer = framer + return iter default: - return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} + iter.err = NewErrProtocol("Unknown type in response to batch statement: %s", x) + iter.framer = framer + return iter } } @@ -1858,9 +1877,9 @@ func (c *Conn) query(ctx context.Context, statement string, values ...interface{ q := c.session.Query(statement, values...).Consistency(One).Trace(nil) q.skipPrepare = true q.disableSkipMetadata = true + // we want to keep the query on this connection - q.conn = c - return c.executeQuery(ctx, q) + return q.iterInternal(c, ctx) } func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter { diff --git a/conn_test.go b/conn_test.go index 4b7ed732d..9a3586196 100644 --- a/conn_test.go +++ b/conn_test.go @@ -392,12 +392,14 @@ func TestQueryRetry(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 1} qry := db.Query("kill").RetryPolicy(rt) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + err = iter.Close() + if err == nil { t.Fatalf("expected error") } requests := atomic.LoadInt64(&srv.nKillReq) - attempts := qry.Attempts() + attempts := iter.Attempts() if requests != int64(attempts) { t.Fatalf("expected requests %v to match query attempts %v", requests, attempts) } @@ -439,13 +441,15 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 3} observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false, logger: log} qry := db.Query("kill").RetryPolicy(rt).Observer(observer).Idempotent(true) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + err = iter.Close() + if err == nil { t.Fatalf("expected error") } for i, ip := range addresses { host := &HostInfo{connectAddress: net.ParseIP(ip)} - queryMetric := qry.metrics.hostMetrics(host) + queryMetric := iter.metrics.hostMetrics(host) observedMetrics := observer.GetMetrics(host) requests := int(atomic.LoadInt64(&nodes[i].nKillReq)) @@ -465,7 +469,7 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { } } // the query will only be attempted once, but is being retried - attempts := qry.Attempts() + attempts := iter.Attempts() if attempts != rt.NumRetries { t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, attempts) } diff --git a/control.go b/control.go index 113bfaaa0..dfc7dc021 100644 --- a/control.go +++ b/control.go @@ -502,7 +502,7 @@ func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { return fn(ch) } - return &Iter{err: errNoControl} + return newErrIter(errNoControl, newQueryMetrics(), "", nil, nil) } func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { @@ -514,20 +514,20 @@ func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { // query will return nil if the connection is closed or nil func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) + qry := newInternalQuery(q, context.TODO()) for { iter = c.withConn(func(conn *Conn) *Iter { - // we want to keep the query on the control connection - q.conn = conn - return conn.executeQuery(context.TODO(), q) + qry.conn = conn + return conn.executeQuery(qry.Context(), qry) }) if gocqlDebug && iter.err != nil { c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err) } - q.AddAttempts(1, c.getConn().host) - if iter.err == nil || !c.retry.Attempt(q) { + iter.metrics.attempt(1, 0, c.getConn().host, false) + if iter.err == nil || !c.retry.Attempt(qry) { break } } @@ -537,7 +537,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter func (c *controlConn) awaitSchemaAgreement() error { return c.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(context.TODO())} + return newErrIter(conn.awaitSchemaAgreement(context.TODO()), newQueryMetrics(), "", nil, nil) }).err } diff --git a/hostpool/hostpool.go b/hostpool/hostpool.go index e4a648598..f3a3d0f6f 100644 --- a/hostpool/hostpool.go +++ b/hostpool/hostpool.go @@ -100,7 +100,7 @@ func (r *hostPoolHostPolicy) HostDown(host *gocql.HostInfo) { r.RemoveHost(host) } -func (r *hostPoolHostPolicy) Pick(qry gocql.ExecutableQuery) gocql.NextHost { +func (r *hostPoolHostPolicy) Pick(qry gocql.ExecutableStatement) gocql.NextHost { return func() gocql.SelectedHost { r.mu.RLock() defer r.mu.RUnlock() diff --git a/keyspace_table_test.go b/keyspace_table_test.go index f3d51f3f8..7862b4892 100644 --- a/keyspace_table_test.go +++ b/keyspace_table_test.go @@ -33,7 +33,7 @@ import ( "testing" ) -// Keyspace_table checks if Query.Keyspace() is updated based on prepared statement +// Keyspace_table checks if Iter.Keyspace() is updated based on prepared statement func TestKeyspaceTable(t *testing.T) { cluster := createCluster() @@ -45,8 +45,7 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("createSession:", err) } - cluster.Keyspace = "wrong_keyspace" - + wrongKeyspace := "testwrong" keyspace := "test1" table := "table1" @@ -55,6 +54,11 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("unable to drop keyspace:", err) } + err = createTable(session, `DROP KEYSPACE IF EXISTS `+wrongKeyspace) + if err != nil { + t.Fatal("unable to drop keyspace:", err) + } + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s WITH replication = { 'class' : 'SimpleStrategy', @@ -65,6 +69,16 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal("unable to create keyspace:", err) } + err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }`, wrongKeyspace)) + + if err != nil { + t.Fatal("unable to create keyspace:", err) + } + if err := session.control.awaitSchemaAgreement(); err != nil { t.Fatal(err) } @@ -80,11 +94,24 @@ func TestKeyspaceTable(t *testing.T) { t.Fatal(err) } + session.Close() + + cluster = createCluster() + + fallback = RoundRobinHostPolicy() + cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback) + cluster.Keyspace = wrongKeyspace + + session, err = cluster.CreateSession() + if err != nil { + t.Fatal("createSession:", err) + } + ctx := context.Background() // insert a row if err := session.Query(`INSERT INTO test1.table1(pk, ck, v) VALUES (?, ?, ?)`, - 1, 2, 3).WithContext(ctx).Consistency(One).Exec(); err != nil { + 1, 2, 3).Consistency(One).ExecContext(ctx); err != nil { t.Fatal(err) } @@ -93,13 +120,20 @@ func TestKeyspaceTable(t *testing.T) { /* Search for a specific set of records whose 'pk' column matches * the value of inserted row. */ qry := session.Query(`SELECT pk FROM test1.table1 WHERE pk = ? LIMIT 1`, - 1).WithContext(ctx).Consistency(One) - if err := qry.Scan(&pk); err != nil { + 1).Consistency(One) + iter := qry.IterContext(ctx) + ok := iter.Scan(&pk) + err = iter.Close() + if err != nil { t.Fatal(err) } + if !ok { + t.Fatal("expected pk to be scanned") + } - // cluster.Keyspace was set to "wrong_keyspace", but during prepering statement + // cluster.Keyspace was set to "testwrong", but during prepering statement // Keyspace in Query should be changed to "test" and Table should be changed to table1 - assertEqual(t, "qry.Keyspace()", "test1", qry.Keyspace()) - assertEqual(t, "qry.Table()", "table1", qry.Table()) + assertEqual(t, "qry.Keyspace()", "testwrong", qry.Keyspace()) + assertEqual(t, "iter.Keyspace()", "test1", iter.Keyspace()) + assertEqual(t, "iter.Table()", "table1", iter.Table()) } diff --git a/lz4/lz4_test.go b/lz4/lz4_test.go index ea64371c5..5c00749a2 100644 --- a/lz4/lz4_test.go +++ b/lz4/lz4_test.go @@ -28,9 +28,10 @@ package lz4 import ( - "github.com/pierrec/lz4/v4" "testing" + "github.com/pierrec/lz4/v4" + "github.com/stretchr/testify/require" ) diff --git a/policies.go b/policies.go index 219f472df..c9f3399fd 100644 --- a/policies.go +++ b/policies.go @@ -304,7 +304,7 @@ type HostSelectionPolicy interface { // Multiple attempts of a single query execution won't call the returned NextHost function concurrently, // so it's safe to have internal state without additional synchronization as long as every call to Pick returns // a different instance of NextHost. - Pick(ExecutableQuery) NextHost + Pick(statement ExecutableStatement) NextHost } // SelectedHost is an interface returned when picking a host from a host @@ -342,7 +342,7 @@ func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} func (r *roundRobinHostPolicy) Init(*Session) {} -func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost { +func (r *roundRobinHostPolicy) Pick(qry ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&r.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), r.hosts.get()) } @@ -577,7 +577,7 @@ func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo, logg m.tokenRing = tokenRing } -func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { +func (t *tokenAwareHostPolicy) Pick(qry ExecutableStatement) NextHost { if qry == nil { return t.fallback.Pick(qry) } @@ -771,7 +771,7 @@ func roundRobbin(shift int, hosts ...[]*HostInfo) NextHost { } } -func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { +func (d *dcAwareRR) Pick(q ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) } @@ -833,7 +833,7 @@ func (d *rackAwareRR) RemoveHost(host *HostInfo) { func (d *rackAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } func (d *rackAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } -func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost { +func (d *rackAwareRR) Pick(q ExecutableStatement) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get()) } diff --git a/policies_test.go b/policies_test.go index e8bda8908..540742a0f 100644 --- a/policies_test.go +++ b/policies_test.go @@ -76,7 +76,7 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { return nil, errors.New("not initalized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -129,7 +129,7 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("20")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first token-aware hosts expectHosts(t, "hosts[0]", iter, "1") expectHosts(t, "hosts[1]", iter, "2") @@ -182,11 +182,11 @@ func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) { } policy.SetPartitioner("OrderedPartitioner") - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return "myKeyspace" } query.RoutingKey([]byte("20")) - iter := policy.Pick(query) + iter := policy.Pick(newInternalQuery(query, nil)) next := iter() if next == nil { t.Fatal("got nil host") @@ -240,7 +240,7 @@ func TestCOWList_Add(t *testing.T) { // TestSimpleRetryPolicy makes sure that we only allow 1 + numRetries attempts func TestSimpleRetryPolicy(t *testing.T) { - q := &Query{routingInfo: &queryRoutingInfo{}} + q := newInternalQuery(&Query{}, nil) // this should allow a total of 3 tries. rt := &SimpleRetryPolicy{NumRetries: 2} @@ -298,7 +298,7 @@ func TestExponentialBackoffPolicy(t *testing.T) { func TestDowngradingConsistencyRetryPolicy(t *testing.T) { - q := &Query{cons: LocalQuorum, routingInfo: &queryRoutingInfo{}} + q := newInternalQuery(&Query{initialConsistency: LocalQuorum}, nil) rewt0 := &RequestErrWriteTimeout{ Received: 0, @@ -459,7 +459,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { return nil, errors.New("not initialized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -497,7 +497,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { } query.RoutingKey([]byte("30")) - if actual := policy.Pick(query)(); actual == nil { + if actual := policy.Pick(newInternalQuery(query, nil))(); actual == nil { t.Fatal("expected to get host from fallback got nil") } @@ -541,7 +541,7 @@ func TestHostPolicy_TokenAware(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("23")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC expectHosts(t, "matching token from local DC", iter, "4") // next are in non-deterministic order @@ -561,7 +561,7 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { return nil, errors.New("not initialized") } - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -632,7 +632,7 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { // now the token ring is configured query.RoutingKey([]byte("18")) - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be hosts with matching token from the local DC expectHosts(t, "matching token from local DC", iter, "4", "7") // rest should be hosts with matching token from remote DCs @@ -688,7 +688,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { policyWithFallbackInternal.getKeyspaceName = policyInternal.getKeyspaceName policyWithFallbackInternal.getKeyspaceMetadata = policyInternal.getKeyspaceMetadata - query := &Query{routingInfo: &queryRoutingInfo{}} + query := &Query{} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) @@ -727,7 +727,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { } query.RoutingKey([]byte("30")) - if actual := policy.Pick(query)(); actual == nil { + if actual := policy.Pick(newInternalQuery(query, nil))(); actual == nil { t.Fatal("expected to get host from fallback got nil") } @@ -775,7 +775,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { // now the token ring is configured // Test the policy with fallback - iter = policyWithFallback.Pick(query) + iter = policyWithFallback.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC & rack expectHosts(t, "matching token from local DC and local rack", iter, "7") @@ -792,7 +792,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) { expectNoMoreHosts(t, iter) // Test the policy without fallback - iter = policy.Pick(query) + iter = policy.Pick(newInternalQuery(query, nil)) // first should be host with matching token from the local DC & Rack expectHosts(t, "matching token from local DC and local rack", iter, "7") diff --git a/query_executor.go b/query_executor.go index 61a43c8c6..552d0b97c 100644 --- a/query_executor.go +++ b/query_executor.go @@ -31,22 +31,41 @@ import ( "time" ) -type ExecutableQuery interface { - borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine. - releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error. - execute(ctx context.Context, conn *Conn) *Iter - attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) - retryPolicy() RetryPolicy - speculativeExecutionPolicy() SpeculativeExecutionPolicy +// Deprecated: Will be removed in a future major release. Also Query and Batch no longer implement this interface. +// +// Please use Statement (for Query / Batch objects) or ExecutableStatement (in HostSelectionPolicy implementations) instead. +type ExecutableQuery = ExecutableStatement + +// ExecutableStatement is an interface that represents a query or batch statement that +// exposes the correct functions for the HostSelectionPolicy to operate correctly. +type ExecutableStatement interface { GetRoutingKey() ([]byte, error) Keyspace() string Table() string IsIdempotent() bool GetHostID() string + Statement() Statement +} - withContext(context.Context) ExecutableQuery +// Statement is an interface that represents a CQL statement that the driver can execute +// (currently Query and Batch via Session.Query and Session.Batch) +type Statement interface { + Iter() *Iter + IterContext(ctx context.Context) *Iter + Exec() error + ExecContext(ctx context.Context) error +} +type internalRequest interface { + execute(ctx context.Context, conn *Conn) *Iter + attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) + retryPolicy() RetryPolicy + speculativeExecutionPolicy() SpeculativeExecutionPolicy + getQueryMetrics() *queryMetrics + getRoutingInfo() *queryRoutingInfo + getKeyspaceFunc() func() string RetryableQuery + ExecutableStatement } type queryExecutor struct { @@ -54,7 +73,7 @@ type queryExecutor struct { policy HostSelectionPolicy } -func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { +func (q *queryExecutor) attemptQuery(ctx context.Context, qry internalRequest, conn *Conn) *Iter { start := time.Now() iter := qry.execute(ctx, conn) end := time.Now() @@ -64,7 +83,7 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c return iter } -func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, +func (q *queryExecutor) speculate(ctx context.Context, qry internalRequest, sp SpeculativeExecutionPolicy, hostIter NextHost, results chan *Iter) *Iter { ticker := time.NewTicker(sp.Delay()) defer ticker.Stop() @@ -72,10 +91,9 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S for i := 0; i < sp.Attempts(); i++ { select { case <-ticker.C: - qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) case <-ctx.Done(): - return &Iter{err: ctx.Err()} + return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) case iter := <-results: return iter } @@ -84,7 +102,7 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S return nil } -func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { +func (q *queryExecutor) executeQuery(qry internalRequest) (*Iter, error) { var hostIter NextHost // check if the host id is specified for the query, @@ -132,7 +150,6 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { results := make(chan *Iter, 1) // Launch the main execution - qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) // The speculative executions are launched _in addition_ to the main @@ -146,11 +163,11 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { case iter := <-results: return iter, nil case <-ctx.Done(): - return &Iter{err: ctx.Err()}, nil + return newErrIter(ctx.Err(), qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()), nil } } -func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { +func (q *queryExecutor) do(ctx context.Context, qry internalRequest, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() @@ -213,7 +230,7 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne stopRetries = true default: // Undefined? Return nil and error, this will panic in the requester - return &Iter{err: ErrUnknownRetryType} + return newErrIter(ErrUnknownRetryType, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } if stopRetries || attemptsReached { @@ -225,16 +242,447 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne } if lastErr != nil { - return &Iter{err: lastErr} + return newErrIter(lastErr, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } - return &Iter{err: ErrNoConnections} + return newErrIter(ErrNoConnections, qry.getQueryMetrics(), qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } -func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) { +func (q *queryExecutor) run(ctx context.Context, qry internalRequest, hostIter NextHost, results chan<- *Iter) { select { case results <- q.do(ctx, qry, hostIter): case <-ctx.Done(): } - qry.releaseAfterExecution() +} + +type queryOptions struct { + stmt string + + // Paging + pageSize int + disableAutoPage bool + + // Monitoring + trace Tracer + observer QueryObserver + + // Parameters + values []interface{} + binding func(q *QueryInfo) ([]interface{}, error) + + // Timestamp + defaultTimestamp bool + defaultTimestampValue int64 + + // Consistency + serialCons SerialConsistency + + // Protocol flag + disableSkipMetadata bool + + customPayload map[string][]byte + prefetch float64 + rt RetryPolicy + spec SpeculativeExecutionPolicy + context context.Context + idempotent bool + keyspace string + skipPrepare bool + routingKey []byte + nowInSecondsValue *int + hostID string + + // getKeyspace is field so that it can be overriden in tests + getKeyspace func() string +} + +func newQueryOptions(q *Query, ctx context.Context) *queryOptions { + var newRoutingKey []byte + if q.routingKey != nil { + routingKey := q.routingKey + newRoutingKey = make([]byte, len(routingKey)) + copy(newRoutingKey, routingKey) + } + if ctx == nil { + ctx = q.Context() + } + return &queryOptions{ + stmt: q.stmt, + values: q.values, + pageSize: q.pageSize, + prefetch: q.prefetch, + trace: q.trace, + observer: q.observer, + rt: q.rt, + spec: q.spec, + binding: q.binding, + serialCons: q.serialCons, + defaultTimestamp: q.defaultTimestamp, + defaultTimestampValue: q.defaultTimestampValue, + disableSkipMetadata: q.disableSkipMetadata, + context: ctx, + idempotent: q.idempotent, + customPayload: q.customPayload, + disableAutoPage: q.disableAutoPage, + skipPrepare: q.skipPrepare, + routingKey: newRoutingKey, + getKeyspace: q.getKeyspace, + nowInSecondsValue: q.nowInSecondsValue, + keyspace: q.keyspace, + hostID: q.hostID, + } +} + +type internalQuery struct { + originalQuery *Query + qryOpts *queryOptions + pageState []byte + metrics *queryMetrics + conn *Conn + consistency uint32 + session *Session + routingInfo *queryRoutingInfo +} + +func newInternalQuery(q *Query, ctx context.Context) *internalQuery { + var newPageState []byte + if q.initialPageState != nil { + pageState := q.initialPageState + newPageState = make([]byte, len(pageState)) + copy(newPageState, pageState) + } + return &internalQuery{ + originalQuery: q, + qryOpts: newQueryOptions(q, ctx), + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + consistency: uint32(q.initialConsistency), + pageState: newPageState, + conn: nil, + session: q.session, + routingInfo: &queryRoutingInfo{}, + } +} + +// Attempts returns the number of times the query was executed. +func (q *internalQuery) Attempts() int { + return q.metrics.attempts() +} + +func (q *internalQuery) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + latency := end.Sub(start) + attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.qryOpts.observer != nil) + + if q.qryOpts.observer != nil { + q.qryOpts.observer.ObserveQuery(q.qryOpts.context, ObservedQuery{ + Keyspace: keyspace, + Statement: q.qryOpts.stmt, + Values: q.qryOpts.values, + Start: start, + End: end, + Rows: iter.numRows, + Host: host, + Metrics: metricsForHost, + Err: iter.err, + Attempt: attempt, + Query: q.originalQuery, + }) + } +} + +func (q *internalQuery) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeQuery(ctx, q) +} + +func (q *internalQuery) retryPolicy() RetryPolicy { + return q.qryOpts.rt +} + +func (q *internalQuery) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return q.qryOpts.spec +} + +func (q *internalQuery) GetRoutingKey() ([]byte, error) { + if q.qryOpts.routingKey != nil { + return q.qryOpts.routingKey, nil + } + + if q.qryOpts.binding != nil && len(q.qryOpts.values) == 0 { + // If this query was created using session.Bind we wont have the query + // values yet, so we have to pass down to the next policy. + // TODO: Remove this and handle this case + return nil, nil + } + + // try to determine the routing key + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.qryOpts.stmt, q.qryOpts.keyspace) + if err != nil { + return nil, err + } + + if routingKeyInfo != nil { + q.routingInfo.mu.Lock() + q.routingInfo.keyspace = routingKeyInfo.keyspace + q.routingInfo.table = routingKeyInfo.table + q.routingInfo.mu.Unlock() + } + return createRoutingKey(routingKeyInfo, q.qryOpts.values) +} + +func (q *internalQuery) Keyspace() string { + if q.qryOpts.getKeyspace != nil { + return q.qryOpts.getKeyspace() + } + + qrKs := q.routingInfo.getKeyspace() + if qrKs != "" { + return qrKs + } + if q.qryOpts.keyspace != "" { + return q.qryOpts.keyspace + } + + if q.session == nil { + return "" + } + // TODO(chbannis): this should be parsed from the query or we should let + // this be set by users. + return q.session.cfg.Keyspace +} + +func (q *internalQuery) Table() string { + return q.routingInfo.getTable() +} + +func (q *internalQuery) IsIdempotent() bool { + return q.qryOpts.idempotent +} + +func (q *internalQuery) getQueryMetrics() *queryMetrics { + return q.metrics +} + +func (q *internalQuery) SetConsistency(c Consistency) { + atomic.StoreUint32(&q.consistency, uint32(c)) +} + +func (q *internalQuery) GetConsistency() Consistency { + return Consistency(atomic.LoadUint32(&q.consistency)) +} + +func (q *internalQuery) Context() context.Context { + return q.qryOpts.context +} + +func (q *internalQuery) Statement() Statement { + return q.originalQuery +} + +func (q *internalQuery) GetHostID() string { + return q.qryOpts.hostID +} + +func (q *internalQuery) getRoutingInfo() *queryRoutingInfo { + return q.routingInfo +} + +func (q *internalQuery) getKeyspaceFunc() func() string { + return q.qryOpts.getKeyspace +} + +type batchOptions struct { + trace Tracer + observer BatchObserver + + bType BatchType + entries []BatchEntry + + defaultTimestamp bool + defaultTimestampValue int64 + + serialCons SerialConsistency + + customPayload map[string][]byte + rt RetryPolicy + spec SpeculativeExecutionPolicy + context context.Context + keyspace string + idempotent bool + routingKey []byte + nowInSeconds *int +} + +func newBatchOptions(b *Batch, ctx context.Context) *batchOptions { + // make a new array so if user keeps appending entries on the Batch object it doesn't affect this execution + newEntries := make([]BatchEntry, len(b.Entries)) + for i, e := range b.Entries { + newEntries[i] = e + } + var newRoutingKey []byte + if b.routingKey != nil { + routingKey := b.routingKey + newRoutingKey = make([]byte, len(routingKey)) + copy(newRoutingKey, routingKey) + } + if ctx == nil { + ctx = b.Context() + } + return &batchOptions{ + bType: b.Type, + entries: newEntries, + customPayload: b.CustomPayload, + rt: b.rt, + spec: b.spec, + trace: b.trace, + observer: b.observer, + serialCons: b.serialCons, + defaultTimestamp: b.defaultTimestamp, + defaultTimestampValue: b.defaultTimestampValue, + context: ctx, + keyspace: b.Keyspace(), + idempotent: b.IsIdempotent(), + routingKey: newRoutingKey, + nowInSeconds: b.nowInSeconds, + } +} + +type internalBatch struct { + originalBatch *Batch + batchOpts *batchOptions + metrics *queryMetrics + consistency uint32 + routingInfo *queryRoutingInfo + session *Session +} + +func newInternalBatch(batch *Batch, ctx context.Context) *internalBatch { + return &internalBatch{ + originalBatch: batch, + batchOpts: newBatchOptions(batch, ctx), + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + routingInfo: &queryRoutingInfo{}, + session: batch.session, + consistency: uint32(batch.GetConsistency()), + } +} + +// Attempts returns the number of attempts made to execute the batch. +func (b *internalBatch) Attempts() int { + return b.metrics.attempts() +} + +func (b *internalBatch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + latency := end.Sub(start) + attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.batchOpts.observer != nil) + + if b.batchOpts.observer == nil { + return + } + + statements := make([]string, len(b.batchOpts.entries)) + values := make([][]interface{}, len(b.batchOpts.entries)) + + for i, entry := range b.batchOpts.entries { + statements[i] = entry.Stmt + values[i] = entry.Args + } + + b.batchOpts.observer.ObserveBatch(b.batchOpts.context, ObservedBatch{ + Keyspace: keyspace, + Statements: statements, + Values: values, + Start: start, + End: end, + // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS + Host: host, + Metrics: metricsForHost, + Err: iter.err, + Attempt: attempt, + Batch: b.originalBatch, + }) +} + +func (b *internalBatch) retryPolicy() RetryPolicy { + return b.batchOpts.rt +} + +func (b *internalBatch) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return b.batchOpts.spec +} + +func (b *internalBatch) GetRoutingKey() ([]byte, error) { + if b.batchOpts.routingKey != nil { + return b.batchOpts.routingKey, nil + } + + if len(b.batchOpts.entries) == 0 { + return nil, nil + } + + entry := b.batchOpts.entries[0] + if entry.binding != nil { + // bindings do not have the values let's skip it like Query does. + return nil, nil + } + // try to determine the routing key + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.batchOpts.keyspace) + if err != nil { + return nil, err + } + + if routingKeyInfo != nil { + b.routingInfo.mu.Lock() + b.routingInfo.keyspace = routingKeyInfo.keyspace + b.routingInfo.table = routingKeyInfo.table + b.routingInfo.mu.Unlock() + } + + return createRoutingKey(routingKeyInfo, entry.Args) +} + +func (b *internalBatch) Keyspace() string { + return b.batchOpts.keyspace +} + +func (b *internalBatch) Table() string { + return b.routingInfo.getTable() +} + +func (b *internalBatch) IsIdempotent() bool { + return b.batchOpts.idempotent +} + +func (b *internalBatch) getQueryMetrics() *queryMetrics { + return b.metrics +} + +func (b *internalBatch) SetConsistency(c Consistency) { + atomic.StoreUint32(&b.consistency, uint32(c)) +} + +func (b *internalBatch) GetConsistency() Consistency { + return Consistency(atomic.LoadUint32(&b.consistency)) +} + +func (b *internalBatch) Context() context.Context { + return b.batchOpts.context +} + +func (b *internalBatch) Statement() Statement { + return b.originalBatch +} + +func (b *internalBatch) GetHostID() string { + return "" +} + +func (b *internalBatch) getRoutingInfo() *queryRoutingInfo { + return b.routingInfo +} + +func (b *internalBatch) getKeyspaceFunc() func() string { + return nil +} + +func (b *internalBatch) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeBatch(ctx, b) } diff --git a/session.go b/session.go index bdda06406..dcd0b6207 100644 --- a/session.go +++ b/session.go @@ -104,12 +104,6 @@ type Session struct { logger StdLogger } -var queryPool = &sync.Pool{ - New: func() interface{} { - return &Query{routingInfo: &queryRoutingInfo{}, refCount: 1} - }, -} - func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) { var hosts []*HostInfo for _, hostaddr := range addrs { @@ -386,7 +380,7 @@ func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { return errNoControl } return s.control.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(ctx)} + return newErrIter(conn.awaitSchemaAgreement(ctx), newQueryMetrics(), "", nil, nil) }).err } @@ -426,7 +420,7 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) { // value before the query is executed. Query is automatically prepared // if it has not previously been executed. func (s *Session) Query(stmt string, values ...interface{}) *Query { - qry := queryPool.Get().(*Query) + qry := &Query{} qry.session = s qry.stmt = stmt qry.values = values @@ -449,7 +443,7 @@ type QueryInfo struct { // During execution, the meta data of the prepared query will be routed to the // binding callback, which is responsible for producing the query argument values. func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query { - qry := queryPool.Get().(*Query) + qry := &Query{} qry.session = s qry.stmt = stmt qry.binding = b @@ -512,15 +506,15 @@ func (s *Session) initialized() bool { return initialized } -func (s *Session) executeQuery(qry *Query) (it *Iter) { +func (s *Session) executeQuery(qry *internalQuery) (it *Iter) { // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return newErrIter(ErrSessionClosed, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } iter, err := s.executor.executeQuery(qry) if err != nil { - return &Iter{err: err} + return newErrIter(err, qry.metrics, qry.Keyspace(), qry.getRoutingInfo(), qry.getKeyspaceFunc()) } if iter == nil { panic("nil iter") @@ -713,33 +707,47 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace stri return routingKeyInfo, nil } -func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { - return conn.executeBatch(ctx, b) -} - // Exec executes a batch operation and returns nil if successful // otherwise an error is returned describing the failure. func (b *Batch) Exec() error { - iter := b.session.executeBatch(b) + iter := b.session.executeBatch(b, nil) + return iter.Close() +} + +// ExecContext executes a batch operation with the provided context and returns nil if successful +// otherwise an error is returned describing the failure. +func (b *Batch) ExecContext(ctx context.Context) error { + iter := b.session.executeBatch(b, ctx) return iter.Close() } -func (s *Session) executeBatch(batch *Batch) *Iter { +// Iter executes a batch operation and returns an Iter object +// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency +func (b *Batch) Iter() *Iter { return b.IterContext(nil) } + +// IterContext executes a batch operation with the provided context and returns an Iter object +// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency +func (b *Batch) IterContext(ctx context.Context) *Iter { + return b.session.executeBatch(b, ctx) +} + +func (s *Session) executeBatch(batch *Batch, ctx context.Context) *Iter { + b := newInternalBatch(batch, ctx) // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return newErrIter(ErrSessionClosed, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } // Prevent the execution of the batch if greater than the limit // Currently batches have a limit of 65536 queries. // https://datastax-oss.atlassian.net/browse/JAVA-229 if batch.Size() > BatchSizeMaximum { - return &Iter{err: ErrTooManyStmts} + return newErrIter(ErrTooManyStmts, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } - iter, err := s.executor.executeQuery(batch) + iter, err := s.executor.executeQuery(b) if err != nil { - return &Iter{err: err} + return newErrIter(err, b.metrics, b.Keyspace(), b.getRoutingInfo(), b.getKeyspaceFunc()) } return iter @@ -749,7 +757,7 @@ func (s *Session) executeBatch(batch *Batch) *Iter { // ExecuteBatch executes a batch operation and returns nil if successful // otherwise an error is returned describing the failure. func (s *Session) ExecuteBatch(batch *Batch) error { - iter := s.executeBatch(batch) + iter := s.executeBatch(batch, nil) return iter.Close() } @@ -769,7 +777,7 @@ func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bo // Further scans on the interator must also remember to include // the applied boolean as the first argument to *Iter.Scan func (b *Batch) ExecCAS(dest ...interface{}) (applied bool, iter *Iter, err error) { - iter = b.session.executeBatch(b) + iter = b.session.executeBatch(b, nil) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err @@ -797,7 +805,7 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) // however it accepts a map rather than a list of arguments for the initial // scan. func (b *Batch) MapExecCAS(dest map[string]interface{}) (applied bool, iter *Iter, err error) { - iter = b.session.executeBatch(b) + iter = b.session.executeBatch(b, nil) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err @@ -835,6 +843,10 @@ type queryMetrics struct { totalAttempts int } +func newQueryMetrics() *queryMetrics { + return &queryMetrics{m: make(map[string]*hostMetrics)} +} + // preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} @@ -921,15 +933,14 @@ func (qm *queryMetrics) attempt(addAttempts int, addLatency time.Duration, type Query struct { stmt string values []interface{} - cons Consistency + initialConsistency Consistency pageSize int routingKey []byte - pageState []byte + initialPageState []byte prefetch float64 trace Tracer observer QueryObserver session *Session - conn *Conn rt RetryPolicy spec SpeculativeExecutionPolicy binding func(q *QueryInfo) ([]interface{}, error) @@ -940,8 +951,6 @@ type Query struct { context context.Context idempotent bool customPayload map[string][]byte - metrics *queryMetrics - refCount uint32 disableAutoPage bool @@ -952,9 +961,6 @@ type Query struct { // tables in AWS MCS see skipPrepare bool - // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. - routingInfo *queryRoutingInfo - // hostID specifies the host on which the query should be executed. // If it is empty, then the host is picked by HostSelectionPolicy hostID string @@ -972,10 +978,22 @@ type queryRoutingInfo struct { table string } +func (qr *queryRoutingInfo) getKeyspace() string { + qr.mu.RLock() + defer qr.mu.RUnlock() + return qr.keyspace +} + +func (qr *queryRoutingInfo) getTable() string { + qr.mu.RLock() + defer qr.mu.RUnlock() + return qr.table +} + func (q *Query) defaultsFromSession() { s := q.session - q.cons = s.cons + q.initialConsistency = s.cons q.pageSize = s.pageSize q.trace = s.trace q.observer = s.queryObserver @@ -984,7 +1002,6 @@ func (q *Query) defaultsFromSession() { q.serialCons = s.cfg.SerialConsistency q.defaultTimestamp = s.cfg.DefaultTimestamp q.idempotent = s.cfg.DefaultIdempotence - q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} q.spec = &NonSpeculativeExecution{} } @@ -1002,47 +1019,30 @@ func (q Query) Values() []interface{} { // String implements the stringer interface. func (q Query) String() string { - return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.cons) -} - -// Attempts returns the number of times the query was executed. -func (q *Query) Attempts() int { - return q.metrics.attempts() -} - -func (q *Query) AddAttempts(i int, host *HostInfo) { - q.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average amount of nanoseconds per attempt of the query. -func (q *Query) Latency() int64 { - return q.metrics.latency() -} - -func (q *Query) AddLatency(l int64, host *HostInfo) { - q.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) + return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.initialConsistency) } // Consistency sets the consistency level for this query. If no consistency // level have been set, the default consistency level of the cluster // is used. func (q *Query) Consistency(c Consistency) *Query { - q.cons = c + q.initialConsistency = c return q } // GetConsistency returns the currently configured consistency level for // the query. func (q *Query) GetConsistency() Consistency { - return q.cons + return q.initialConsistency } -// Same as Consistency but without a return value +// Deprecated: use Query.Consistency instead func (q *Query) SetConsistency(c Consistency) { - q.cons = c + q.initialConsistency = c } -// CustomPayload sets the custom payload level for this query. +// CustomPayload sets the custom payload level for this query. The map is not copied internally +// so it shouldn't be modified after the query is scheduled for execution. func (q *Query) CustomPayload(customPayload map[string][]byte) *Query { q.customPayload = customPayload return q @@ -1108,11 +1108,8 @@ func (q *Query) RoutingKey(routingKey []byte) *Query { return q } -func (q *Query) withContext(ctx context.Context) ExecutableQuery { - // I really wish go had covariant types - return q.WithContext(ctx) -} - +// Deprecated: Use Query.ExecContext or Query.IterContext instead. This will be removed in a future major version. +// // WithContext returns a shallow copy of q with its context // set to ctx. // @@ -1125,47 +1122,11 @@ func (q *Query) WithContext(ctx context.Context) *Query { return &q2 } -// Deprecate: does nothing, cancel the context passed to WithContext -func (q *Query) Cancel() { - // TODO: delete -} - -func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { - return conn.executeQuery(ctx, q) -} - -func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { - latency := end.Sub(start) - attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.observer != nil) - - if q.observer != nil { - q.observer.ObserveQuery(q.Context(), ObservedQuery{ - Keyspace: keyspace, - Statement: q.stmt, - Values: q.values, - Start: start, - End: end, - Rows: iter.numRows, - Host: host, - Metrics: metricsForHost, - Err: iter.err, - Attempt: attempt, - }) - } -} - -func (q *Query) retryPolicy() RetryPolicy { - return q.rt -} - // Keyspace returns the keyspace the query will be executed against. func (q *Query) Keyspace() string { if q.getKeyspace != nil { return q.getKeyspace() } - if q.routingInfo.keyspace != "" { - return q.routingInfo.keyspace - } if q.keyspace != "" { return q.keyspace } @@ -1178,45 +1139,12 @@ func (q *Query) Keyspace() string { return q.session.cfg.Keyspace } -// Table returns name of the table the query will be executed against. -func (q *Query) Table() string { - return q.routingInfo.table -} - -// GetRoutingKey gets the routing key to use for routing this query. If -// a routing key has not been explicitly set, then the routing key will -// be constructed if possible using the keyspace's schema and the query -// info for this query statement. If the routing key cannot be determined -// then nil will be returned with no error. On any error condition, -// an error description will be returned. -func (q *Query) GetRoutingKey() ([]byte, error) { - if q.routingKey != nil { - return q.routingKey, nil - } else if q.binding != nil && len(q.values) == 0 { - // If this query was created using session.Bind we wont have the query - // values yet, so we have to pass down to the next policy. - // TODO: Remove this and handle this case - return nil, nil - } - - // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace) - if err != nil { - return nil, err - } - - if routingKeyInfo != nil { - q.routingInfo.mu.Lock() - q.routingInfo.keyspace = routingKeyInfo.keyspace - q.routingInfo.table = routingKeyInfo.table - q.routingInfo.mu.Unlock() - } - return createRoutingKey(routingKeyInfo, q.values) -} - func (q *Query) shouldPrepare() bool { + return shouldPrepare(q.stmt) +} - stmt := strings.TrimLeftFunc(strings.TrimRightFunc(q.stmt, func(r rune) bool { +func shouldPrepare(s string) bool { + stmt := strings.TrimLeftFunc(strings.TrimRightFunc(s, func(r rune) bool { return unicode.IsSpace(r) || r == ';' }), unicode.IsSpace) @@ -1256,11 +1184,6 @@ func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Qu return q } -// speculativeExecutionPolicy fetches the policy -func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy { - return q.spec -} - // IsIdempotent returns whether the query is marked as idempotent. // Non-idempotent query won't be retried. // See "Retries and speculative execution" in package docs for more details. @@ -1281,7 +1204,6 @@ func (q *Query) Idempotent(value bool) *Query { // to an existing query instance. func (q *Query) Bind(v ...interface{}) *Query { q.values = v - q.pageState = nil return q } @@ -1302,7 +1224,7 @@ func (q *Query) SerialConsistency(cons Consistency) *Query { // point in time. Setting this will disable to query paging for this query, and // must be used for all subsequent pages. func (q *Query) PageState(state []byte) *Query { - q.pageState = state + q.initialPageState = state q.disableAutoPage = true return q } @@ -1325,6 +1247,11 @@ func (q *Query) Exec() error { return q.Iter().Close() } +// ExecContext executes the query with the provided context without returning any rows. +func (q *Query) ExecContext(ctx context.Context) error { + return q.IterContext(ctx).Close() +} + func isUseStatement(stmt string) bool { if len(stmt) < 3 { return false @@ -1336,15 +1263,25 @@ func isUseStatement(stmt string) bool { // Iter executes the query and returns an iterator capable of iterating // over all results. func (q *Query) Iter() *Iter { + return q.IterContext(nil) +} + +// IterContext executes the query with the provided context and returns an iterator capable of iterating +// over all results. +func (q *Query) IterContext(ctx context.Context) *Iter { if isUseStatement(q.stmt) { - return &Iter{err: ErrUseStmt} + return newErrIter(ErrUseStmt, newQueryMetrics(), q.Keyspace(), nil, q.getKeyspace) } - // if the query was specifically run on a connection then re-use that - // connection when fetching the next results - if q.conn != nil { - return q.conn.executeQuery(q.Context(), q) - } - return q.session.executeQuery(q) + + internalQry := newInternalQuery(q, ctx) + return q.session.executeQuery(internalQry) +} + +func (q *Query) iterInternal(c *Conn, ctx context.Context) *Iter { + internalQry := newInternalQuery(q, ctx) + internalQry.conn = c + + return c.executeQuery(internalQry.Context(), internalQry) } // MapScan executes the query, copies the columns of the first selected @@ -1421,43 +1358,6 @@ func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error return applied, iter.Close() } -// Release releases a query back into a pool of queries. Released Queries -// cannot be reused. -// -// Example: -// -// qry := session.Query("SELECT * FROM my_table") -// qry.Exec() -// qry.Release() -func (q *Query) Release() { - q.decRefCount() -} - -// reset zeroes out all fields of a query so that it can be safely pooled. -func (q *Query) reset() { - *q = Query{routingInfo: &queryRoutingInfo{}, refCount: 1} -} - -func (q *Query) incRefCount() { - atomic.AddUint32(&q.refCount, 1) -} - -func (q *Query) decRefCount() { - if res := atomic.AddUint32(&q.refCount, ^uint32(0)); res == 0 { - // do release - q.reset() - queryPool.Put(q) - } -} - -func (q *Query) borrowForExecution() { - q.incRefCount() -} - -func (q *Query) releaseAfterExecution() { - q.decRefCount() -} - // SetHostID allows to define the host the query should be executed against. If the // host was filtered or otherwise unavailable, then the query will error. If an empty // string is sent, the default behavior, using the configured HostSelectionPolicy will @@ -1490,9 +1390,13 @@ func (q *Query) WithNowInSeconds(now int) *Query { return q } -// Iter represents an iterator that can be used to iterate over all rows that -// were returned by a query. The iterator might send additional queries to the +// Iter represents the result that was returned by the execution of a statement. +// +// If the statement is a query then this can be seen as an iterator that can be used to iterate over all rows that +// were returned by the query. The iterator might send additional queries to the // database during the iteration if paging was enabled. +// +// It also contains metadata about the request that can be accessed by Iter.Keyspace(), Iter.Table(), Iter.Attempts(), Iter.Latency(). type Iter struct { err error pos int @@ -1500,12 +1404,27 @@ type Iter struct { numRows int next *nextIter host *HostInfo + metrics *queryMetrics + + getKeyspace func() string + keyspace string + routingInfo *queryRoutingInfo framer *framer closed int32 } -// Host returns the host which the query was sent to. +func newErrIter(err error, metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter { + iter := newIter(metrics, keyspace, routingInfo, getKeyspace) + iter.err = err + return iter +} + +func newIter(metrics *queryMetrics, keyspace string, routingInfo *queryRoutingInfo, getKeyspace func() string) *Iter { + return &Iter{metrics: metrics, keyspace: keyspace, routingInfo: routingInfo, getKeyspace: getKeyspace} +} + +// Host returns the host which the statement was sent to. func (iter *Iter) Host() *HostInfo { return iter.host } @@ -1515,6 +1434,39 @@ func (iter *Iter) Columns() []ColumnInfo { return iter.meta.columns } +// Attempts returns the number of times the statement was executed. +func (iter *Iter) Attempts() int { + return iter.metrics.attempts() +} + +// Latency returns the average amount of nanoseconds per attempt of the statement. +func (iter *Iter) Latency() int64 { + return iter.metrics.latency() +} + +// Keyspace returns the keyspace the statement was executed against if the driver could determine it. +func (iter *Iter) Keyspace() string { + if iter.getKeyspace != nil { + return iter.getKeyspace() + } + + if iter.routingInfo != nil { + if ks := iter.routingInfo.getKeyspace(); ks != "" { + return ks + } + } + + return iter.keyspace +} + +// Table returns name of the table the statement was executed against if the driver could determine it. +func (iter *Iter) Table() string { + if iter.routingInfo != nil { + return iter.routingInfo.getTable() + } + return "" +} + type Scanner interface { // Next advances the row pointer to point at the next row, the row is valid until // the next call of Next. It returns true if there is a row which is available to be @@ -1765,7 +1717,7 @@ func (iter *Iter) NumRows() int { // nextIter holds state for fetching a single page in an iterator. // single page might be attempted multiple times due to retries. type nextIter struct { - qry *Query + q *internalQuery pos int oncea sync.Once once sync.Once @@ -1782,10 +1734,10 @@ func (n *nextIter) fetch() *Iter { n.once.Do(func() { // if the query was specifically run on a connection then re-use that // connection when fetching the next results - if n.qry.conn != nil { - n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry) + if n.q.conn != nil { + n.next = n.q.conn.executeQuery(n.q.qryOpts.context, n.q) } else { - n.next = n.qry.session.executeQuery(n.qry) + n.next = n.q.session.executeQuery(n.q) } }) return n.next @@ -1806,17 +1758,14 @@ type Batch struct { defaultTimestamp bool defaultTimestampValue int64 context context.Context - cancelBatch func() keyspace string - metrics *queryMetrics nowInSeconds *int - - // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. - routingInfo *queryRoutingInfo } // Deprecated: use Session.Batch instead // NewBatch creates a new batch operation using defaults defined in the cluster +// +// Deprecated: use Session.Batch instead func (s *Session) NewBatch(typ BatchType) *Batch { return s.Batch(typ) } @@ -1833,9 +1782,7 @@ func (s *Session) Batch(typ BatchType) *Batch { Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp, keyspace: s.cfg.Keyspace, - metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, spec: &NonSpeculativeExecution{}, - routingInfo: &queryRoutingInfo{}, } return batch @@ -1859,27 +1806,12 @@ func (b *Batch) Keyspace() string { return b.keyspace } -// Batch has no reasonable eqivalent of Query.Table(). -func (b *Batch) Table() string { - return b.routingInfo.table -} - -// Attempts returns the number of attempts made to execute the batch. -func (b *Batch) Attempts() int { - return b.metrics.attempts() -} - -func (b *Batch) AddAttempts(i int, host *HostInfo) { - b.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average number of nanoseconds to execute a single attempt of the batch. -func (b *Batch) Latency() int64 { - return b.metrics.latency() -} - -func (b *Batch) AddLatency(l int64, host *HostInfo) { - b.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) +// Consistency sets the consistency level for this batch. If no consistency +// level have been set, the default consistency level of the cluster +// is used. +func (b *Batch) Consistency(cons Consistency) *Batch { + b.Cons = cons + return b } // GetConsistency returns the currently configured consistency level for the batch @@ -1888,8 +1820,7 @@ func (b *Batch) GetConsistency() Consistency { return b.Cons } -// SetConsistency sets the currently configured consistency level for the batch -// operation. +// Deprecated: Use Batch.Consistency func (b *Batch) SetConsistency(c Consistency) { b.Cons = c } @@ -1932,20 +1863,14 @@ func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error) b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind}) } -func (b *Batch) retryPolicy() RetryPolicy { - return b.rt -} - // RetryPolicy sets the retry policy to use when executing the batch operation func (b *Batch) RetryPolicy(r RetryPolicy) *Batch { b.rt = r return b } -func (b *Batch) withContext(ctx context.Context) ExecutableQuery { - return b.WithContext(ctx) -} - +// Deprecated: Use Batch.ExecContext or Batch.IterContext instead. This will be removed in a future major version. +// // WithContext returns a shallow copy of b with its context // set to ctx. // @@ -1958,11 +1883,6 @@ func (b *Batch) WithContext(ctx context.Context) *Batch { return &b2 } -// Deprecate: does nothing, cancel the context passed to WithContext -func (*Batch) Cancel() { - // TODO: delete -} - // Size returns the number of batch statements to be executed by the batch operation. func (b *Batch) Size() int { return len(b.Entries) @@ -2006,59 +1926,6 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch { return b } -func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { - latency := end.Sub(start) - attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) - - if b.observer == nil { - return - } - - statements := make([]string, len(b.Entries)) - values := make([][]interface{}, len(b.Entries)) - - for i, entry := range b.Entries { - statements[i] = entry.Stmt - values[i] = entry.Args - } - - b.observer.ObserveBatch(b.Context(), ObservedBatch{ - Keyspace: keyspace, - Statements: statements, - Values: values, - Start: start, - End: end, - // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS - Host: host, - Metrics: metricsForHost, - Err: iter.err, - Attempt: attempt, - }) -} - -func (b *Batch) GetRoutingKey() ([]byte, error) { - if b.routingKey != nil { - return b.routingKey, nil - } - - if len(b.Entries) == 0 { - return nil, nil - } - - entry := b.Entries[0] - if entry.binding != nil { - // bindings do not have the values let's skip it like Query does. - return nil, nil - } - // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace) - if err != nil { - return nil, err - } - - return createRoutingKey(routingKeyInfo, entry.Args) -} - func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]byte, error) { if routingKeyInfo == nil { return nil, nil @@ -2096,21 +1963,6 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]b return routingKey, nil } -func (b *Batch) borrowForExecution() { - // empty, because Batch has no equivalent of Query.Release() - // that would race with speculative executions. -} - -func (b *Batch) releaseAfterExecution() { - // empty, because Batch has no equivalent of Query.Release() - // that would race with speculative executions. -} - -// GetHostID satisfies ExecutableQuery interface but does noop. -func (b *Batch) GetHostID() string { - return "" -} - // SetKeyspace will enable keyspace flag on the query. // It allows to specify the keyspace that the query should be executed in // @@ -2295,6 +2147,9 @@ type ObservedQuery struct { // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int + + // Query object associated with this request. Should be used as read only. + Query *Query } // QueryObserver is the interface implemented by query observers / stat collectors. @@ -2332,6 +2187,9 @@ type ObservedBatch struct { // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int + + // Batch object associated with this request. Should be used as read only. + Batch *Batch } // BatchObserver is the interface implemented by batch observers / stat collectors. diff --git a/session_test.go b/session_test.go index 48f7fe7dc..d414d6f41 100644 --- a/session_test.go +++ b/session_test.go @@ -30,7 +30,6 @@ package gocql import ( "context" "fmt" - "net" "testing" ) @@ -69,7 +68,7 @@ func TestSessionAPI(t *testing.T) { t.Fatalf("expected qry.stmt to be 'test', got '%v'", boundQry.stmt) } - itr := s.executeQuery(qry) + itr := s.executeQuery(newInternalQuery(qry, nil)) if itr.err != ErrNoConnections { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } @@ -102,28 +101,7 @@ func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) { } func TestQueryBasicAPI(t *testing.T) { - qry := &Query{routingInfo: &queryRoutingInfo{}} - - // Initiate host - ip := "127.0.0.1" - - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) - if qry.Latency() != 0 { - t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency()) - } - - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) - if qry.Attempts() != 2 { - t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts()) - } - if qry.Latency() != 2 { - t.Fatalf("expected Query.Latency() to return 2, got %v", qry.Latency()) - } - - qry.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if qry.Attempts() != 4 { - t.Fatalf("expected Query.Attempts() to return 4, got %v", qry.Attempts()) - } + qry := &Query{} qry.Consistency(All) if qry.GetConsistency() != All { @@ -166,7 +144,7 @@ func TestQueryBasicAPI(t *testing.T) { func TestQueryShouldPrepare(t *testing.T) { toPrepare := []string{"select * ", "INSERT INTO", "update table", "delete from", "begin batch"} cantPrepare := []string{"create table", "USE table", "LIST keyspaces", "alter table", "drop table", "grant user", "revoke user"} - q := &Query{routingInfo: &queryRoutingInfo{}} + q := &Query{} for i := 0; i < len(toPrepare); i++ { q.stmt = toPrepare[i] @@ -210,29 +188,6 @@ func TestBatchBasicAPI(t *testing.T) { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) } - ip := "127.0.0.1" - - // Test attempts - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1}}) - if b.Attempts() != 1 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts()) - } - - b.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if b.Attempts() != 3 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 3, b.Attempts()) - } - - // Test latency - if b.Latency() != 0 { - t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency()) - } - - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1, TotalLatency: 4}}) - if b.Latency() != 4 { - t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency()) - } - // Test Consistency b.Cons = One if b.GetConsistency() != One { @@ -277,6 +232,43 @@ func TestBatchBasicAPI(t *testing.T) { } +func TestQueryIterBasicApi(t *testing.T) { + session := createSession(t) + defer session.Close() + + qry := session.Query("INSERT INTO gocql_test.invalid_table(value) VALUES(1)") + iter1 := qry.Iter() + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to return 1, got %v", iter1.Attempts()) + } + iter2 := qry.Iter() + if iter2.Attempts() != 1 { + t.Fatalf("expected iter2 Iter.Attempts() to return 1, got %v", iter2.Attempts()) + } + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to still return 1, got %v", iter1.Attempts()) + } +} + +func TestBatchIterBasicApi(t *testing.T) { + session := createSession(t) + defer session.Close() + + b := session.Batch(LoggedBatch) + b.Query("INSERT INTO gocql_test.invalid_table(value) VALUES(1)") + iter1 := b.Iter() + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to return 1, got %v", iter1.Attempts()) + } + iter2 := b.Iter() + if iter2.Attempts() != 1 { + t.Fatalf("expected iter2 Iter.Attempts() to return 1, got %v", iter2.Attempts()) + } + if iter1.Attempts() != 1 { + t.Fatalf("expected iter1 Iter.Attempts() to still return 1, got %v", iter1.Attempts()) + } +} + func TestConsistencyNames(t *testing.T) { names := map[fmt.Stringer]string{ Any: "ANY", diff --git a/stress_test.go b/stress_test.go index be41707aa..bf4f24874 100644 --- a/stress_test.go +++ b/stress_test.go @@ -29,7 +29,6 @@ package gocql import ( "sync/atomic" - "testing" ) @@ -82,7 +81,7 @@ func BenchmarkConnRoutingKey(b *testing.B) { query := session.Query("insert into routing_key_stress (id) values (?)") for pb.Next() { - if _, err := query.Bind(i * seed).GetRoutingKey(); err != nil { + if _, err := newInternalQuery(query.Bind(i*seed), nil).GetRoutingKey(); err != nil { b.Error(err) return }