From 291756bb4787d11d3b067006a6fc2308730e61dc Mon Sep 17 00:00:00 2001
From: James Hartig
Date: Thu, 23 Oct 2025 18:01:15 +0000
Subject: [PATCH] CASSGO-92: add public method to retrieve StatementMetadata
 and LogField methods
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

StatementMetadata can be called on a session to get the bind, result, and
pk information for a given query. Previously this wasn't publicly exposed,
but it is necessary for some implementations of HostSelectionPolicy, such
as token-aware policies. This might also be useful for CI tooling or
runtime analysis of queries and the types of their columns.

NewLogField* are functions that return a LogField with a given name and a
specific type.

Finally, session init was cleaned up to prevent a HostSelectionPolicy from
causing a panic if it tried to make a query during init. The interface
documentation now states that queries should not be attempted during Init.

Patch by James Hartig for CASSGO-92; reviewed by João Reis for CASSGO-92
---
 CHANGELOG.md      |  11 ++
 cassandra_test.go | 139 +++++++++++-----------
 cluster.go        |   4 +-
 conn.go           |  16 +--
 conn_test.go      |  14 +--
 connectionpool.go |  18 +--
 control.go        |  62 +++++-----
 events.go         |  18 +--
 frame.go          |  12 +-
 host_source.go    |  10 +-
 logger.go         |  20 ++--
 policies.go       |  12 +-
 query_executor.go |  20 ++--
 session.go        | 296 +++++++++++++++++++++++-----------------------
 topology.go       |   6 +-
 15 files changed, 342 insertions(+), 316 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2fa50aa00..66a6067e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [2.1.0]
+
+### Added
+
+- Session.StatementMetadata (CASSGO-92)
+- NewLogFieldIP, NewLogFieldError, NewLogFieldStringer, NewLogFieldString, NewLogFieldInt, NewLogFieldBool (CASSGO-92)
+
+### Fixed
+
+- Prevent panic with queries during session init (CASSGO-92)
+
 ## [2.0.0]
 
 ### Removed
diff --git a/cassandra_test.go b/cassandra_test.go
index bc260e440..937a7c054 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -2794,66 +2794,62 @@ func TestKeyspaceMetadata(t *testing.T) {
 }
 
 // Integration test of the routing key calculation
-func TestRoutingKey(t *testing.T) {
+func TestRoutingStatementMetadata(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()
 
-	if err := createTable(session, "CREATE TABLE gocql_test.test_single_routing_key (first_id int, second_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
+	if err := createTable(session, "CREATE TABLE gocql_test.test_single_routing_key (first_id int, second_id varchar, PRIMARY KEY (first_id, second_id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
-	if err := createTable(session, "CREATE TABLE gocql_test.test_composite_routing_key (first_id int, second_id int, PRIMARY KEY ((first_id, second_id)))"); err != nil {
+	if err := createTable(session, "CREATE TABLE gocql_test.test_composite_routing_key (first_id int, second_id varchar, PRIMARY KEY ((first_id, second_id)))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
 
-	routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
+	meta, err := session.routingStatementMetadata(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=?
AND first_id=?", "") if err != nil { - t.Fatalf("failed to get routing key info due to error: %v", err) + t.Fatalf("failed to get routing statement metadata due to error: %v", err) } - if routingKeyInfo == nil { - t.Fatal("Expected routing key info, but was nil") + if meta == nil { + t.Fatal("Expected routing statement metadata, but was nil") } - if len(routingKeyInfo.indexes) != 1 { - t.Fatalf("Expected routing key indexes length to be 1 but was %d", len(routingKeyInfo.indexes)) + if len(meta.PKBindColumnIndexes) != 1 { + t.Fatalf("Expected routing statement metadata PKBindColumnIndexes length to be 1 but was %d", len(meta.PKBindColumnIndexes)) } - if routingKeyInfo.indexes[0] != 1 { - t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) + if meta.PKBindColumnIndexes[0] != 1 { + t.Errorf("Expected routing statement metadata PKBindColumnIndexes[0] to be 1 but was %d", meta.PKBindColumnIndexes[0]) } - if len(routingKeyInfo.types) != 1 { - t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) + if len(meta.BindColumns) != 2 { + t.Fatalf("Expected routing statement metadata BindColumns length to be 2 but was %d", len(meta.BindColumns)) } - if routingKeyInfo.types[0] == nil { - t.Fatal("Expected routing key types[0] to be non-nil") + if meta.BindColumns[0].TypeInfo.Type() != TypeVarchar { + t.Fatalf("Expected routing statement metadata BindColumns[0].TypeInfo.Type to be %v but was %v", TypeVarchar, meta.BindColumns[0].TypeInfo.Type()) } - if routingKeyInfo.types[0].Type() != TypeInt { - t.Fatalf("Expected routing key types[0].Type to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) + if meta.BindColumns[1].TypeInfo.Type() != TypeInt { + t.Fatalf("Expected routing statement metadata BindColumns[1].TypeInfo.Type to be %v but was %v", TypeInt, meta.BindColumns[1].TypeInfo.Type()) } - - // verify the cache is working - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? 
AND first_id=?", "") - if err != nil { - t.Fatalf("failed to get routing key info due to error: %v", err) + if len(meta.ResultColumns) != 2 { + t.Fatalf("Expected routing statement metadata ResultColumns length to be 2 but was %d", len(meta.ResultColumns)) } - if len(routingKeyInfo.indexes) != 1 { - t.Fatalf("Expected routing key indexes length to be 1 but was %d", len(routingKeyInfo.indexes)) + if meta.ResultColumns[0].Name != "first_id" { + t.Fatalf("Expected routing statement metadata ResultColumns[0].Name to be %v but was %v", "first_id", meta.ResultColumns[0].Name) } - if routingKeyInfo.indexes[0] != 1 { - t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) + if meta.ResultColumns[0].TypeInfo.Type() != TypeInt { + t.Fatalf("Expected routing statement metadata ResultColumns[0].TypeInfo.Type to be %v but was %v", TypeInt, meta.ResultColumns[0].TypeInfo.Type()) } - if len(routingKeyInfo.types) != 1 { - t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) + if meta.ResultColumns[1].Name != "second_id" { + t.Fatalf("Expected routing statement metadata ResultColumns[1].Name to be %v but was %v", "second_id", meta.ResultColumns[1].Name) } - if routingKeyInfo.types[0] == nil { - t.Fatal("Expected routing key types[0] to be non-nil") + if meta.ResultColumns[1].TypeInfo.Type() != TypeVarchar { + t.Fatalf("Expected routing statement metadata ResultColumns[1].TypeInfo.Type to be %v but was %v", TypeVarchar, meta.ResultColumns[1].TypeInfo.Type()) } - if routingKeyInfo.types[0].Type() != TypeInt { - t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) - } - cacheSize := session.routingKeyInfoCache.lru.Len() + + // verify the cache is working + cacheSize := session.routingMetadataCache.lru.Len() if cacheSize != 1 { t.Errorf("Expected cache size to be 1 but was %d", cacheSize) } - query := newInternalQuery(session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil) + query := newInternalQuery(session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "1", 2), nil) routingKey, err := query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) @@ -2863,50 +2859,59 @@ func TestRoutingKey(t *testing.T) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } - routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "") + meta, err = session.routingStatementMetadata(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? 
AND first_id=?", "") if err != nil { - t.Fatalf("failed to get routing key info due to error: %v", err) + t.Fatalf("failed to get routing statement metadata due to error: %v", err) + } + if meta == nil { + t.Fatal("Expected routing statement metadata, but was nil") + } + if len(meta.PKBindColumnIndexes) != 2 { + t.Fatalf("Expected routing statement metadata PKBindColumnIndexes length to be 2 but was %d", len(meta.PKBindColumnIndexes)) + } + if meta.PKBindColumnIndexes[0] != 1 { + t.Errorf("Expected routing statement metadata PKBindColumnIndexes[0] to be 1 but was %d", meta.PKBindColumnIndexes[0]) } - if routingKeyInfo == nil { - t.Fatal("Expected routing key info, but was nil") + if meta.PKBindColumnIndexes[1] != 0 { + t.Errorf("Expected routing statement metadata PKBindColumnIndexes[1] to be 0 but was %d", meta.PKBindColumnIndexes[1]) } - if len(routingKeyInfo.indexes) != 2 { - t.Fatalf("Expected routing key indexes length to be 2 but was %d", len(routingKeyInfo.indexes)) + if len(meta.BindColumns) != 2 { + t.Fatalf("Expected routing statement metadata BindColumns length to be 2 but was %d", len(meta.BindColumns)) } - if routingKeyInfo.indexes[0] != 1 { - t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) + if meta.BindColumns[0].TypeInfo.Type() != TypeVarchar { + t.Fatalf("Expected routing statement metadata BindColumns[0].TypeInfo.Type to be %v but was %v", TypeVarchar, meta.BindColumns[0].TypeInfo.Type()) } - if routingKeyInfo.indexes[1] != 0 { - t.Errorf("Expected routing key index[1] to be 0 but was %d", routingKeyInfo.indexes[1]) + if meta.BindColumns[1].TypeInfo.Type() != TypeInt { + t.Fatalf("Expected routing statement metadata BindColumns[1].TypeInfo.Type to be %v but was %v", TypeInt, meta.BindColumns[1].TypeInfo.Type()) } - if len(routingKeyInfo.types) != 2 { - t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) + if len(meta.ResultColumns) != 2 { + t.Fatalf("Expected routing statement metadata ResultColumns length to be 2 but was %d", len(meta.ResultColumns)) } - if routingKeyInfo.types[0] == nil { - t.Fatal("Expected routing key types[0] to be non-nil") + if meta.ResultColumns[0].Name != "first_id" { + t.Fatalf("Expected routing statement metadata ResultColumns[0].Name to be %v but was %v", "first_id", meta.ResultColumns[0].Name) } - if routingKeyInfo.types[0].Type() != TypeInt { - t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) + if meta.ResultColumns[0].TypeInfo.Type() != TypeInt { + t.Fatalf("Expected routing statement metadata ResultColumns[0].TypeInfo.Type to be %v but was %v", TypeInt, meta.ResultColumns[0].TypeInfo.Type()) } - if routingKeyInfo.types[1] == nil { - t.Fatal("Expected routing key types[1] to be non-nil") + if meta.ResultColumns[1].Name != "second_id" { + t.Fatalf("Expected routing statement metadata ResultColumns[1].Name to be %v but was %v", "second_id", meta.ResultColumns[1].Name) } - if routingKeyInfo.types[1].Type() != TypeInt { - t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[1].Type()) + if meta.ResultColumns[1].TypeInfo.Type() != TypeVarchar { + t.Fatalf("Expected routing statement metadata ResultColumns[1].TypeInfo.Type to be %v but was %v", TypeVarchar, meta.ResultColumns[1].TypeInfo.Type()) } - query = newInternalQuery(session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? 
AND first_id=?", 1, 2), nil) + query = newInternalQuery(session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "1", 2), nil) routingKey, err = query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) } - expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 4, 0, 0, 0, 1, 0} + expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 1, 49, 0} if !reflect.DeepEqual(expectedRoutingKey, routingKey) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } // verify the cache is working - cacheSize = session.routingKeyInfoCache.lru.Len() + cacheSize = session.routingMetadataCache.lru.Len() if cacheSize != 2 { t.Errorf("Expected cache size to be 2 but was %d", cacheSize) } @@ -3956,17 +3961,17 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { t.Fatal(err) } - getRoutingKeyInfo := func(key string) *routingKeyInfo { + getStatementMetadata := func(key string) *StatementMetadata { t.Helper() - session.routingKeyInfoCache.mu.Lock() - value, ok := session.routingKeyInfoCache.lru.Get(key) + session.routingMetadataCache.mu.Lock() + value, ok := session.routingMetadataCache.lru.Get(key) if !ok { t.Fatalf("routing key not found in cache for key %v", key) } - session.routingKeyInfoCache.mu.Unlock() + session.routingMetadataCache.mu.Unlock() inflight := value.(*inflightCachedEntry) - return inflight.value.(*routingKeyInfo) + return inflight.value.(*StatementMetadata) } const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)" @@ -3979,8 +3984,8 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { require.NoError(t, err) // Ensuring that the cache contains the query with default ks - routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt) - require.Equal(t, "gocql_test", routingKeyInfo1.keyspace) + meta1 := getStatementMetadata("gocql_test" + b1.Entries[0].Stmt) + require.Equal(t, "gocql_test", meta1.Keyspace) // Running batch in gocql_test_routing_key_cache ks b2 := session.Batch(LoggedBatch) @@ -3991,8 +3996,8 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) { require.NoError(t, err) // Ensuring that the cache contains the query with gocql_test_routing_key_cache ks - routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt) - require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace) + meta2 := getStatementMetadata("gocql_test_routing_key_cache" + b2.Entries[0].Stmt) + require.Equal(t, "gocql_test_routing_key_cache", meta2.Keyspace) const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?" 
diff --git a/cluster.go b/cluster.go index aec8be64e..ebd4a1868 100644 --- a/cluster.go +++ b/cluster.go @@ -352,8 +352,8 @@ func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int, logger Str } newAddr, newPort := cfg.AddressTranslator.Translate(addr, port) logger.Debug("Translating address.", - newLogFieldIp("old_addr", addr), newLogFieldInt("old_port", port), - newLogFieldIp("new_addr", newAddr), newLogFieldInt("new_port", newPort)) + NewLogFieldIP("old_addr", addr), NewLogFieldInt("old_port", port), + NewLogFieldIP("new_addr", newAddr), NewLogFieldInt("new_port", newPort)) return newAddr, newPort } diff --git a/conn.go b/conn.go index ddf130818..40044565d 100644 --- a/conn.go +++ b/conn.go @@ -709,7 +709,7 @@ func (c *Conn) processFrame(ctx context.Context, r io.Reader) error { delete(c.calls, head.stream) c.mu.Unlock() if call == nil || !ok { - c.logger.Warning("Received response for stream which has no handler.", newLogFieldString("header", head.String())) + c.logger.Warning("Received response for stream which has no handler.", NewLogFieldString("header", head.String())) return c.discardFrame(r, head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) @@ -1316,7 +1316,7 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer responseFrame, err := resp.framer.parseFrame() if err != nil { c.logger.Warning("Framer error while attempting to parse potential protocol error.", - newLogFieldError("err", err)) + NewLogFieldError("err", err)) return nil, errProtocol } //goland:noinspection GoTypeAssertionOnErrors @@ -1333,17 +1333,17 @@ func (c *Conn) execInternal(ctx context.Context, req frameBuilder, tracer Tracer case <-timeoutCh: close(call.timeout) c.logger.Debug("Request timed out on connection.", - newLogFieldString("host_id", c.host.HostID()), newLogFieldIp("addr", c.host.ConnectAddress())) + NewLogFieldString("host_id", c.host.HostID()), NewLogFieldIP("addr", c.host.ConnectAddress())) return nil, ErrTimeoutNoResponse case <-ctxDone: c.logger.Debug("Request failed because context elapsed out on connection.", - newLogFieldString("host_id", c.host.HostID()), newLogFieldIp("addr", c.host.ConnectAddress()), - newLogFieldError("ctx_err", ctx.Err())) + NewLogFieldString("host_id", c.host.HostID()), NewLogFieldIP("addr", c.host.ConnectAddress()), + NewLogFieldError("ctx_err", ctx.Err())) close(call.timeout) return nil, ctx.Err() case <-c.ctx.Done(): c.logger.Debug("Request failed because connection closed.", - newLogFieldString("host_id", c.host.HostID()), newLogFieldIp("addr", c.host.ConnectAddress())) + NewLogFieldString("host_id", c.host.HostID()), NewLogFieldIP("addr", c.host.ConnectAddress())) close(call.timeout) return nil, ErrConnectionClosed } @@ -1698,7 +1698,7 @@ func (c *Conn) executeQuery(ctx context.Context, q *internalQuery) *Iter { iter.framer = framer if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag - c.logger.Warning("Error while awaiting for schema agreement after a schema change event.", newLogFieldError("err", err)) + c.logger.Warning("Error while awaiting for schema agreement after a schema change event.", NewLogFieldError("err", err)) } // dont return an error from this, might be a good idea to give a warning // though. 
The impact of this returning an error would be that the cluster @@ -1956,7 +1956,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { goto cont } if !isValidPeer(host) || host.schemaVersion == "" { - c.logger.Warning("Invalid peer or peer with empty schema_version.", newLogFieldIp("peer", host.ConnectAddress())) + c.logger.Warning("Invalid peer or peer with empty schema_version.", NewLogFieldIP("peer", host.ConnectAddress())) continue } diff --git a/conn_test.go b/conn_test.go index 3646dcc85..60e4a2a8a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -356,13 +356,13 @@ func (o *testQueryObserver) ObserveQuery(ctx context.Context, q ObservedQuery) { host := q.Host.ConnectAddress().String() o.metrics[host] = q.Metrics o.logger.Debug("Observed query.", - newLogFieldString("stmt", q.Statement), - newLogFieldInt("rows", q.Rows), - newLogFieldString("duration", q.End.Sub(q.Start).String()), - newLogFieldString("host", host), - newLogFieldInt("attempts", q.Metrics.Attempts), - newLogFieldString("latency", strconv.FormatInt(q.Metrics.TotalLatency, 10)), - newLogFieldError("err", q.Err)) + NewLogFieldString("stmt", q.Statement), + NewLogFieldInt("rows", q.Rows), + NewLogFieldString("duration", q.End.Sub(q.Start).String()), + NewLogFieldString("host", host), + NewLogFieldInt("attempts", q.Metrics.Attempts), + NewLogFieldString("latency", strconv.FormatInt(q.Metrics.TotalLatency, 10)), + NewLogFieldError("err", q.Err)) } func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics { diff --git a/connectionpool.go b/connectionpool.go index f316b5695..56ca53702 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -494,11 +494,11 @@ func (pool *hostConnPool) logConnectErr(err error) { // connection refused // these are typical during a node outage so avoid log spam. 
pool.logger.Debug("Pool unable to establish a connection to host.", - newLogFieldIp("host_addr", pool.host.ConnectAddress()), newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err)) + NewLogFieldIP("host_addr", pool.host.ConnectAddress()), NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err)) } else if err != nil { // unexpected error pool.logger.Debug("Pool failed to connect to host due to error.", - newLogFieldIp("host_addr", pool.host.ConnectAddress()), newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err)) + NewLogFieldIP("host_addr", pool.host.ConnectAddress()), NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err)) } } @@ -506,7 +506,7 @@ func (pool *hostConnPool) logConnectErr(err error) { func (pool *hostConnPool) fillingStopped(err error) { if err != nil { pool.logger.Warning("Connection pool filling failed.", - newLogFieldIp("host_addr", pool.host.ConnectAddress()), newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err)) + NewLogFieldIP("host_addr", pool.host.ConnectAddress()), NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err)) // wait for some time to avoid back-to-back filling // this provides some time between failed attempts // to fill the pool for the host to recover @@ -523,7 +523,7 @@ func (pool *hostConnPool) fillingStopped(err error) { // if we errored and the size is now zero, make sure the host is marked as down // see https://github.com/apache/cassandra-gocql-driver/issues/1614 pool.logger.Debug("Logging number of connections of pool after filling stopped.", - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID()), newLogFieldInt("count", count)) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldInt("count", count)) if err != nil && count == 0 { if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) { pool.session.handleNodeDown(host.ConnectAddress(), port) @@ -580,10 +580,10 @@ func (pool *hostConnPool) connect() (err error) { } } pool.logger.Warning("Pool failed to connect to host. 
Reconnecting according to the reconnection policy.", - newLogFieldIp("host", pool.host.ConnectAddress()), - newLogFieldString("host_id", pool.host.HostID()), - newLogFieldError("err", err), - newLogFieldString("reconnectionPolicy", fmt.Sprintf("%T", reconnectionPolicy))) + NewLogFieldIP("host", pool.host.ConnectAddress()), + NewLogFieldString("host_id", pool.host.HostID()), + NewLogFieldError("err", err), + NewLogFieldString("reconnectionPolicy", fmt.Sprintf("%T", reconnectionPolicy))) time.Sleep(reconnectionPolicy.GetInterval(i)) } @@ -631,7 +631,7 @@ func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) { } pool.logger.Info("Pool connection error.", - newLogFieldString("addr", conn.addr), newLogFieldError("err", err)) + NewLogFieldString("addr", conn.addr), NewLogFieldError("err", err)) // find the connection index for i, candidate := range pool.conns { diff --git a/control.go b/control.go index de374c589..e59acb402 100644 --- a/control.go +++ b/control.go @@ -105,7 +105,7 @@ func (c *controlConn) heartBeat() { resp, err := c.writeFrame(&writeOptionsFrame{}) if err != nil { - c.session.logger.Debug("Control connection failed to send heartbeat.", newLogFieldError("err", err)) + c.session.logger.Debug("Control connection failed to send heartbeat.", NewLogFieldError("err", err)) goto reconn } @@ -115,10 +115,10 @@ func (c *controlConn) heartBeat() { sleepTime = 5 * time.Second continue case error: - c.session.logger.Debug("Control connection heartbeat failed.", newLogFieldError("err", actualResp)) + c.session.logger.Debug("Control connection heartbeat failed.", NewLogFieldError("err", actualResp)) goto reconn default: - c.session.logger.Error("Unknown frame in response to options.", newLogFieldString("frame_type", fmt.Sprintf("%T", resp))) + c.session.logger.Error("Unknown frame in response to options.", NewLogFieldString("frame_type", fmt.Sprintf("%T", resp))) } reconn: @@ -270,18 +270,18 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { if err == nil { c.session.logger.Debug("Discovered protocol version using host.", - newLogFieldInt("protocol_version", connCfg.ProtoVersion), newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID())) + NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) return connCfg.ProtoVersion, nil } if proto := parseProtocolFromError(err); proto > 0 { c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.", - newLogFieldInt("protocol_version", proto), newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID())) + NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) return proto, nil } c.session.logger.Debug("Failed to discover protocol version using host.", - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID()), newLogFieldError("err", err)) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err)) } return 0, err @@ -305,10 +305,10 @@ func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error { conn, err = c.session.dial(c.session.ctx, host, &cfg, c) if err != nil { c.session.logger.Info("Control connection failed to establish a connection to host.", - newLogFieldIp("host_addr", host.ConnectAddress()), - 
newLogFieldInt("port", host.Port()), - newLogFieldString("host_id", host.HostID()), - newLogFieldError("err", err)) + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("port", host.Port()), + NewLogFieldString("host_id", host.HostID()), + NewLogFieldError("err", err)) continue } err = c.setupConn(conn, sessionInit) @@ -316,10 +316,10 @@ func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error { break } c.session.logger.Info("Control connection setup failed after connecting to host.", - newLogFieldIp("host_addr", host.ConnectAddress()), - newLogFieldInt("port", host.Port()), - newLogFieldString("host_id", host.HostID()), - newLogFieldError("err", err)) + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("port", host.Port()), + NewLogFieldString("host_id", host.HostID()), + NewLogFieldError("err", err)) conn.Close() conn = nil } @@ -368,7 +368,7 @@ func (c *controlConn) setupConn(conn *Conn, sessionInit bool) error { msg = "Added control host (session initialization)." } logHelper(c.session.logger, logLevel, msg, - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID())) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) } if err := c.registerEvents(conn); err != nil { @@ -383,7 +383,7 @@ func (c *controlConn) setupConn(conn *Conn, sessionInit bool) error { c.conn.Store(ch) c.session.logger.Info("Control connection connected to host.", - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID())) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) if c.session.initialized() { // We connected to control conn, so add the connect the host in pool as well. @@ -445,14 +445,14 @@ func (c *controlConn) reconnect() { if err != nil { c.session.logger.Error("Unable to reconnect control connection.", - newLogFieldError("err", err)) + NewLogFieldError("err", err)) return } err = c.session.refreshRing() if err != nil { c.session.logger.Warning("Unable to refresh ring.", - newLogFieldError("err", err)) + NewLogFieldError("err", err)) } } @@ -482,7 +482,7 @@ func (c *controlConn) attemptReconnect() (*Conn, error) { return conn, err } - c.session.logger.Error("Unable to connect to any ring node, control connection falling back to initial contact points.", newLogFieldError("err", err)) + c.session.logger.Error("Unable to connect to any ring node, control connection falling back to initial contact points.", NewLogFieldError("err", err)) // Fallback to initial contact points, as it may be the case that all known initialHosts // changed their IPs while keeping the same hostname(s). 
initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger) @@ -500,10 +500,10 @@ func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, er conn, err = c.session.connect(c.session.ctx, host, c) if err != nil { c.session.logger.Info("During reconnection, control connection failed to establish a connection to host.", - newLogFieldIp("host_addr", host.ConnectAddress()), - newLogFieldInt("port", host.Port()), - newLogFieldString("host_id", host.HostID()), - newLogFieldError("err", err)) + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("port", host.Port()), + NewLogFieldString("host_id", host.HostID()), + NewLogFieldError("err", err)) continue } err = c.setupConn(conn, false) @@ -511,10 +511,10 @@ func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, er break } c.session.logger.Info("During reconnection, control connection setup failed after connecting to host.", - newLogFieldIp("host_addr", host.ConnectAddress()), - newLogFieldInt("port", host.Port()), - newLogFieldString("host_id", host.HostID()), - newLogFieldError("err", err)) + NewLogFieldIP("host_addr", host.ConnectAddress()), + NewLogFieldInt("port", host.Port()), + NewLogFieldString("host_id", host.HostID()), + NewLogFieldError("err", err)) conn.Close() conn = nil } @@ -535,9 +535,9 @@ func (c *controlConn) HandleError(conn *Conn, err error, closed bool) { } c.session.logger.Warning("Control connection error.", - newLogFieldIp("host_addr", conn.host.ConnectAddress()), - newLogFieldString("host_id", conn.host.HostID()), - newLogFieldError("err", err)) + NewLogFieldIP("host_addr", conn.host.ConnectAddress()), + NewLogFieldString("host_id", conn.host.HostID()), + NewLogFieldError("err", err)) c.reconnect() } @@ -602,7 +602,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter if iter.err != nil { c.session.logger.Warning("Error executing control connection statement.", - newLogFieldString("statement", statement), newLogFieldError("err", iter.err)) + NewLogFieldString("statement", statement), NewLogFieldError("err", iter.err)) } qry.metrics.attempt(0) diff --git a/events.go b/events.go index 8f4bd1db8..d511d9aee 100644 --- a/events.go +++ b/events.go @@ -104,7 +104,7 @@ func (e *eventDebouncer) debounce(frame frame) { e.events = append(e.events, frame) } else { e.logger.Warning("Event buffer full, dropping event frame.", - newLogFieldString("event_name", e.name), newLogFieldStringer("frame", frame)) + NewLogFieldString("event_name", e.name), NewLogFieldStringer("frame", frame)) } e.mu.Unlock() @@ -113,11 +113,11 @@ func (e *eventDebouncer) debounce(frame frame) { func (s *Session) handleEvent(framer *framer) { frame, err := framer.parseFrame() if err != nil { - s.logger.Error("Unable to parse event frame.", newLogFieldError("err", err)) + s.logger.Error("Unable to parse event frame.", NewLogFieldError("err", err)) return } - s.logger.Debug("Handling event frame.", newLogFieldStringer("frame", frame)) + s.logger.Debug("Handling event frame.", NewLogFieldStringer("frame", frame)) switch f := frame.(type) { case *schemaChangeKeyspace, *schemaChangeFunction, @@ -128,7 +128,7 @@ func (s *Session) handleEvent(framer *framer) { s.nodeEvents.debounce(frame) default: s.logger.Error("Invalid event frame.", - newLogFieldString("frame_type", fmt.Sprintf("%T", f)), newLogFieldStringer("frame", f)) + NewLogFieldString("frame_type", fmt.Sprintf("%T", f)), NewLogFieldStringer("frame", f)) } } @@ -181,7 +181,7 @@ func (s 
*Session) handleNodeEvent(frames []frame) { switch f := frame.(type) { case *topologyChangeEventFrame: s.logger.Info("Received topology change event.", - newLogFieldString("frame", strings.Join([]string{f.change, "->", f.host.String(), ":", strconv.Itoa(f.port)}, ""))) + NewLogFieldString("frame", strings.Join([]string{f.change, "->", f.host.String(), ":", strconv.Itoa(f.port)}, ""))) topologyEventReceived = true case *statusChangeEventFrame: event, ok := sEvents[f.host.String()] @@ -199,7 +199,7 @@ func (s *Session) handleNodeEvent(frames []frame) { for _, f := range sEvents { s.logger.Info("Dispatching status change event.", - newLogFieldString("frame", strings.Join([]string{f.change, "->", f.host.String(), ":", strconv.Itoa(f.port)}, ""))) + NewLogFieldString("frame", strings.Join([]string{f.change, "->", f.host.String(), ":", strconv.Itoa(f.port)}, ""))) // ignore events we received if they were disabled // see https://github.com/apache/cassandra-gocql-driver/issues/1591 @@ -218,7 +218,7 @@ func (s *Session) handleNodeEvent(frames []frame) { func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { s.logger.Info("Node is UP.", - newLogFieldStringer("event_ip", eventIp), newLogFieldInt("event_port", eventPort)) + NewLogFieldStringer("event_ip", eventIp), NewLogFieldInt("event_port", eventPort)) host, ok := s.ring.getHostByIP(eventIp.String()) if !ok { @@ -244,7 +244,7 @@ func (s *Session) startPoolFill(host *HostInfo) { func (s *Session) handleNodeConnected(host *HostInfo) { s.logger.Debug("Pool connected to node.", - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldInt("port", host.Port()), newLogFieldString("host_id", host.HostID())) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldInt("port", host.Port()), NewLogFieldString("host_id", host.HostID())) host.setState(NodeUp) @@ -255,7 +255,7 @@ func (s *Session) handleNodeConnected(host *HostInfo) { func (s *Session) handleNodeDown(ip net.IP, port int) { s.logger.Warning("Node is DOWN.", - newLogFieldIp("host_addr", ip), newLogFieldInt("port", port)) + NewLogFieldIP("host_addr", ip), NewLogFieldInt("port", port)) host, ok := s.ring.getHostByIP(ip.String()) if ok { diff --git a/frame.go b/frame.go index 0ee6a2e18..403da877b 100644 --- a/frame.go +++ b/frame.go @@ -1046,12 +1046,11 @@ func (f *framer) readTypeInfo() (TypeInfo, error) { type preparedMetadata struct { resultMetadata - // proto v4+ - pkeyColumns []int - - keyspace string - - table string + // pkeyColumns is only present in protocol v4+ + pkeyColumns []int + supportsPKeyColumns bool + keyspace string + table string } func (r preparedMetadata) String() string { @@ -1090,6 +1089,7 @@ func (f *framer) parsePreparedMetadata() (preparedMetadata, error) { pkeys[i] = int(c) } meta.pkeyColumns = pkeys + meta.supportsPKeyColumns = true } if meta.flags&flagHasMorePages == flagHasMorePages { diff --git a/host_source.go b/host_source.go index ab6537409..622d22948 100644 --- a/host_source.go +++ b/host_source.go @@ -784,7 +784,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er return nil, fmt.Errorf("unable to fetch peer host info: %s", iterErr) } // skip over peers that we couldn't parse - r.session.logger.Warning("Failed to parse peer this host will be ignored.", newLogFieldError("err", err)) + r.session.logger.Warning("Failed to parse peer this host will be ignored.", NewLogFieldError("err", err)) continue } // if nil then none left @@ -794,7 +794,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) 
([]*HostInfo, er if !isValidPeer(host) { // If it's not a valid peer r.session.logger.Warning("Found invalid peer "+ - "likely due to a gossip or snitch issue, this host will be ignored.", newLogFieldStringer("host", host)) + "likely due to a gossip or snitch issue, this host will be ignored.", NewLogFieldStringer("host", host)) continue } @@ -866,7 +866,7 @@ func refreshRing(r *ringDescriber) error { } if host, ok := r.session.ring.addHostIfMissing(h); !ok { - r.session.logger.Info("Adding host.", newLogFieldIp("host_addr", h.ConnectAddress()), newLogFieldString("host_id", h.HostID())) + r.session.logger.Info("Adding host.", NewLogFieldIP("host_addr", h.ConnectAddress()), NewLogFieldString("host_id", h.HostID())) r.session.startPoolFill(h) } else { // host (by hostID) already exists; determine if IP has changed @@ -885,7 +885,7 @@ func refreshRing(r *ringDescriber) error { if _, alreadyExists := r.session.ring.addHostIfMissing(h); alreadyExists { return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists) } - r.session.logger.Info("Adding host with new IP after removing old host.", newLogFieldIp("host_addr", h.ConnectAddress()), newLogFieldString("host_id", h.HostID())) + r.session.logger.Info("Adding host with new IP after removing old host.", NewLogFieldIP("host_addr", h.ConnectAddress()), NewLogFieldString("host_id", h.HostID())) // add new HostInfo (same hostID, new IP) r.session.startPoolFill(h) } @@ -899,7 +899,7 @@ func refreshRing(r *ringDescriber) error { r.session.metadata.setPartitioner(partitioner) r.session.policy.SetPartitioner(partitioner) - r.session.logger.Info("Refreshed ring.", newLogFieldString("ring", ringString(r.session.ring.allHosts()))) + r.session.logger.Info("Refreshed ring.", NewLogFieldString("ring", ringString(r.session.ring.allHosts()))) return nil } diff --git a/logger.go b/logger.go index df865c849..ad8c2e5f0 100644 --- a/logger.go +++ b/logger.go @@ -48,7 +48,7 @@ func logHelper(logger StructuredLogger, level LogLevel, msg string, fields ...Lo case LogLevelError: logger.Error(msg, fields...) default: - logger.Error("Unknown log level", newLogFieldInt("level", int(level)), newLogFieldString("msg", msg)) + logger.Error("Unknown log level", NewLogFieldInt("level", int(level)), NewLogFieldString("msg", msg)) } } @@ -229,7 +229,8 @@ func newLogField(name string, value LogFieldValue) LogField { } } -func newLogFieldIp(name string, value net.IP) LogField { +// NewLogFieldIP creates a new LogField with the given name and net.IP. +func NewLogFieldIP(name string, value net.IP) LogField { var str string if value == nil { str = "" @@ -239,7 +240,8 @@ func newLogFieldIp(name string, value net.IP) LogField { return newLogField(name, logFieldValueString(str)) } -func newLogFieldError(name string, value error) LogField { +// NewLogFieldError creates a new LogField with the given name and error. +func NewLogFieldError(name string, value error) LogField { var str string if value != nil { str = value.Error() @@ -247,7 +249,8 @@ func newLogFieldError(name string, value error) LogField { return newLogField(name, logFieldValueString(str)) } -func newLogFieldStringer(name string, value fmt.Stringer) LogField { +// NewLogFieldStringer creates a new LogField with the given name and fmt.Stringer. 
+func NewLogFieldStringer(name string, value fmt.Stringer) LogField {
 	var str string
 	if value != nil {
 		str = value.String()
@@ -255,15 +258,18 @@
 	return newLogField(name, logFieldValueString(str))
 }
 
-func newLogFieldString(name string, value string) LogField {
+// NewLogFieldString creates a new LogField with the given name and string.
+func NewLogFieldString(name string, value string) LogField {
 	return newLogField(name, logFieldValueString(value))
 }
 
-func newLogFieldInt(name string, value int) LogField {
+// NewLogFieldInt creates a new LogField with the given name and int.
+func NewLogFieldInt(name string, value int) LogField {
 	return newLogField(name, logFieldValueInt64(int64(value)))
 }
 
-func newLogFieldBool(name string, value bool) LogField {
+// NewLogFieldBool creates a new LogField with the given name and bool.
+func NewLogFieldBool(name string, value bool) LogField {
 	return newLogField(name, logFieldValueBool(value))
 }
diff --git a/policies.go b/policies.go
index 8f9680874..5fc61d781 100644
--- a/policies.go
+++ b/policies.go
@@ -304,9 +304,19 @@ type HostTierer interface {
 type HostSelectionPolicy interface {
 	HostStateNotifier
 	SetPartitioner
+
+	// KeyspaceChanged is called when the driver receives a keyspace change event.
 	KeyspaceChanged(KeyspaceUpdateEvent)
+
+	// Init is called automatically during session creation so the policy can store
+	// a reference to the attached session. The session is not yet usable when it
+	// is passed to this method; implementations must not attempt queries in Init.
 	Init(*Session)
+
+	// IsLocal should return true if the given Host is considered "local" by some
+	// criteria. "Local" hosts are preferred over non-local hosts.
 	IsLocal(host *HostInfo) bool
+
 	// Pick returns an iteration function over selected hosts.
// Multiple attempts of a single query execution won't call the returned NextHost function concurrently, // so it's safe to have internal state without additional synchronization as long as every call to Pick returns @@ -576,7 +586,7 @@ func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo, logg // create a new token ring tokenRing, err := newTokenRing(partitioner, hosts) if err != nil { - logger.Warning("Unable to update the token ring due to error.", newLogFieldError("err", err)) + logger.Warning("Unable to update the token ring due to error.", NewLogFieldError("err", err)) return } diff --git a/query_executor.go b/query_executor.go index 35422ffeb..2d7a62335 100644 --- a/query_executor.go +++ b/query_executor.go @@ -423,18 +423,18 @@ func (q *internalQuery) GetRoutingKey() ([]byte, error) { } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.qryOpts.stmt, q.qryOpts.keyspace) + meta, err := q.session.routingStatementMetadata(q.Context(), q.qryOpts.stmt, q.qryOpts.keyspace) if err != nil { return nil, err } - if routingKeyInfo != nil { + if meta != nil { q.routingInfo.mu.Lock() - q.routingInfo.keyspace = routingKeyInfo.keyspace - q.routingInfo.table = routingKeyInfo.table + q.routingInfo.keyspace = meta.Keyspace + q.routingInfo.table = meta.Table q.routingInfo.mu.Unlock() } - return createRoutingKey(routingKeyInfo, q.qryOpts.values) + return createRoutingKey(meta, q.qryOpts.values) } func (q *internalQuery) Keyspace() string { @@ -643,19 +643,19 @@ func (b *internalBatch) GetRoutingKey() ([]byte, error) { return nil, nil } // try to determine the routing key - routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.batchOpts.keyspace) + meta, err := b.session.routingStatementMetadata(b.Context(), entry.Stmt, b.batchOpts.keyspace) if err != nil { return nil, err } - if routingKeyInfo != nil { + if meta != nil { b.routingInfo.mu.Lock() - b.routingInfo.keyspace = routingKeyInfo.keyspace - b.routingInfo.table = routingKeyInfo.table + b.routingInfo.keyspace = meta.Keyspace + b.routingInfo.table = meta.Table b.routingInfo.mu.Unlock() } - return createRoutingKey(routingKeyInfo, entry.Args) + return createRoutingKey(meta, entry.Args) } func (b *internalBatch) Keyspace() string { diff --git a/session.go b/session.go index 111c7d01c..264ed8922 100644 --- a/session.go +++ b/session.go @@ -51,21 +51,21 @@ import ( // and automatically sets a default consistency level on all operations // that do not have a consistency level set. 
type Session struct { - cons Consistency - pageSize int - prefetch float64 - routingKeyInfoCache routingKeyInfoLRU - schemaDescriber *schemaDescriber - trace Tracer - queryObserver QueryObserver - batchObserver BatchObserver - connectObserver ConnectObserver - frameObserver FrameHeaderObserver - streamObserver StreamObserver - hostSource *ringDescriber - ringRefresher *refreshDebouncer - stmtsLRU *preparedLRU - types *RegisteredTypes + cons Consistency + pageSize int + prefetch float64 + routingMetadataCache routingKeyInfoLRU + schemaDescriber *schemaDescriber + trace Tracer + queryObserver QueryObserver + batchObserver BatchObserver + connectObserver ConnectObserver + frameObserver FrameHeaderObserver + streamObserver StreamObserver + hostSource *ringDescriber + ringRefresher *refreshDebouncer + stmtsLRU *preparedLRU + types *RegisteredTypes connCfg *ConnConfig @@ -111,7 +111,7 @@ func addrsToHosts(addrs []string, defaultPort int, logger StructuredLogger) ([]* if err != nil { // Try other hosts if unable to resolve DNS name if _, ok := err.(*net.DNSError); ok { - logger.Error("DNS error.", newLogFieldError("err", err)) + logger.Error("DNS error.", NewLogFieldError("err", err)) continue } return nil, err @@ -167,24 +167,11 @@ func NewSession(cfg ClusterConfig) (*Session, error) { s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, s.logger) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, s.logger) - s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) + s.routingMetadataCache.lru = lru.New(cfg.MaxRoutingKeyInfo) s.hostSource = &ringDescriber{session: s} s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s.hostSource) }) - if cfg.PoolConfig.HostSelectionPolicy == nil { - cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() - } - s.pool = cfg.PoolConfig.buildPool(s) - - s.policy = cfg.PoolConfig.HostSelectionPolicy - s.policy.Init(s) - - s.executor = &queryExecutor{ - pool: s.pool, - policy: cfg.PoolConfig.HostSelectionPolicy, - } - s.queryObserver = cfg.QueryObserver s.batchObserver = cfg.BatchObserver s.connectObserver = cfg.ConnectObserver @@ -199,6 +186,20 @@ func NewSession(cfg ClusterConfig) (*Session, error) { } s.connCfg = connCfg + if cfg.PoolConfig.HostSelectionPolicy == nil { + cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() + } + s.pool = cfg.PoolConfig.buildPool(s) + s.policy = cfg.PoolConfig.HostSelectionPolicy + + // set the executor here in case the policy needs to execute queries in Init + s.executor = &queryExecutor{ + pool: s.pool, + policy: cfg.PoolConfig.HostSelectionPolicy, + } + + s.policy.Init(s) + if err := s.init(); err != nil { s.Close() if err == ErrNoConnectionsStarted { @@ -234,7 +235,7 @@ func (s *Session) init() error { // TODO(zariel): we really only need this in 1 place s.cfg.ProtoVersion = proto s.connCfg.ProtoVersion = proto - s.logger.Info("Discovered protocol version.", newLogFieldInt("protocol_version", proto)) + s.logger.Info("Discovered protocol version.", NewLogFieldInt("protocol_version", proto)) } if err := s.control.connect(hosts, true); err != nil { @@ -256,7 +257,7 @@ func (s *Session) init() error { } hosts = filteredHosts - s.logger.Info("Refreshed ring.", newLogFieldString("ring", ringString(hosts))) + s.logger.Info("Refreshed ring.", NewLogFieldString("ring", ringString(hosts))) } else { s.logger.Info("Not performing a ring refresh because DisableInitialHostLookup is true.") } @@ -295,7 +296,7 @@ func (s *Session) init() error { } if 
!exists { s.logger.Info("Adding host (session initialization).", - newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", host.HostID())) + NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID())) } atomic.AddInt64(&left, 1) @@ -404,16 +405,16 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) { hosts := s.ring.allHosts() // Print session.ring for debug. - s.logger.Debug("Logging current ring state.", newLogFieldString("ring", ringString(hosts))) + s.logger.Debug("Logging current ring state.", NewLogFieldString("ring", ringString(hosts))) for _, h := range hosts { if h.IsUp() { continue } s.logger.Debug("Reconnecting to downed host.", - newLogFieldIp("host_addr", h.ConnectAddress()), - newLogFieldInt("host_port", h.Port()), - newLogFieldString("host_id", h.HostID())) + NewLogFieldIP("host_addr", h.ConnectAddress()), + NewLogFieldInt("host_port", h.Port()), + NewLogFieldString("host_id", h.HostID())) // we let the pool call handleNodeConnected to change the host state s.pool.addHost(h) } @@ -578,7 +579,7 @@ func (s *Session) executeQuery(qry *internalQuery) (it *Iter) { } func (s *Session) removeHost(h *HostInfo) { - s.logger.Warning("Removing host.", newLogFieldIp("host_addr", h.ConnectAddress()), newLogFieldString("host_id", h.HostID())) + s.logger.Warning("Removing host.", NewLogFieldIP("host_addr", h.ConnectAddress()), NewLogFieldString("host_id", h.HostID())) s.policy.RemoveHost(h) hostID := h.HostID() s.pool.removeHost(hostID) @@ -599,6 +600,7 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) { func (s *Session) getConn() *Conn { hosts := s.ring.allHosts() + for _, host := range hosts { if !host.IsUp() { continue @@ -615,23 +617,22 @@ func (s *Session) getConn() *Conn { return nil } -// Returns routing key indexes and type info. +// Returns statement metadata for the purposes of generating a routing key. 
// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace -func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) { +func (s *Session) routingStatementMetadata(ctx context.Context, stmt string, keyspace string) (*StatementMetadata, error) { if keyspace == "" { keyspace = s.cfg.Keyspace } - routingKeyInfoCacheKey := keyspace + stmt - - s.routingKeyInfoCache.mu.Lock() + key := keyspace + stmt + s.routingMetadataCache.mu.Lock() // Using here keyspace + stmt as a cache key because // the query keyspace could be overridden via SetKeyspace - entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey) + entry, cached := s.routingMetadataCache.lru.Get(key) if cached { // done accessing the cache - s.routingKeyInfoCache.mu.Unlock() + s.routingMetadataCache.mu.Unlock() // the entry is an inflight struct similar to that used by // Conn to prepare statements inflight := entry.(*inflightCachedEntry) @@ -643,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace stri return nil, inflight.err } - key, _ := inflight.value.(*routingKeyInfo) + key, _ := inflight.value.(*StatementMetadata) return key, nil } @@ -652,114 +653,113 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace stri inflight := new(inflightCachedEntry) inflight.wg.Add(1) defer inflight.wg.Done() - s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight) - s.routingKeyInfoCache.mu.Unlock() - - var ( - info *preparedStatment - partitionKey []*ColumnMetadata - ) + s.routingMetadataCache.lru.Add(key, inflight) + s.routingMetadataCache.mu.Unlock() - conn := s.getConn() - if conn == nil { - // TODO: better error? - inflight.err = errors.New("gocql: unable to fetch prepared info: no connection available") - return nil, inflight.err - } - - // get the query info for the statement - info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace) + var meta StatementMetadata + meta, inflight.err = s.StatementMetadata(ctx, stmt, keyspace) if inflight.err != nil { // don't cache this error - s.routingKeyInfoCache.Remove(stmt) + s.routingMetadataCache.Remove(key) return nil, inflight.err } - // TODO: it would be nice to mark hosts here but as we are not using the policies - // to fetch hosts we cant + inflight.value = &meta - if info.request.colCount == 0 { - // no arguments, no routing key, and no error - return nil, nil - } + return &meta, nil +} - table := info.request.table - if info.request.keyspace != "" { - keyspace = info.request.keyspace - } +// StatementMetadata represents various metadata about a statement. +type StatementMetadata struct { + // Keyspace is the keyspace of the table for the statement. + Keyspace string - if len(info.request.pkeyColumns) > 0 { - // proto v4 dont need to calculate primary key columns - types := make([]TypeInfo, len(info.request.pkeyColumns)) - for i, col := range info.request.pkeyColumns { - types[i] = info.request.columns[col].TypeInfo - } + // Table is the table of the statement. + Table string - routingKeyInfo := &routingKeyInfo{ - indexes: info.request.pkeyColumns, - types: types, - keyspace: keyspace, - table: table, - } + // BindColumns are columns bound to the statement. + BindColumns []ColumnInfo - inflight.value = routingKeyInfo - return routingKeyInfo, nil - } + // PKBindColumnIndexes are the indexes of the BindColumns that correspond to + // partition key columns. 
If this is empty, one or more columns in the
+	// partition key were not bound to the statement.
+	PKBindColumnIndexes []int
 
-	var keyspaceMetadata *KeyspaceMetadata
-	keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.columns[0].Keyspace)
-	if inflight.err != nil {
-		// don't cache this error
-		s.routingKeyInfoCache.Remove(stmt)
-		return nil, inflight.err
+	// ResultColumns are the columns that are returned by the statement.
+	ResultColumns []ColumnInfo
+}
+
+// StatementMetadata returns metadata for a statement. If keyspace is empty,
+// the session's keyspace is used.
+func (s *Session) StatementMetadata(ctx context.Context, stmt, keyspace string) (StatementMetadata, error) {
+	if keyspace == "" {
+		keyspace = s.cfg.Keyspace
 	}
 
-	tableMetadata, found := keyspaceMetadata.Tables[table]
-	if !found {
-		// unlikely that the statement could be prepared and the metadata for
-		// the table couldn't be found, but this may indicate either a bug
-		// in the metadata code, or that the table was just dropped.
-		inflight.err = ErrNoMetadata
-		// don't cache this error
-		s.routingKeyInfoCache.Remove(stmt)
-		return nil, inflight.err
+	conn := s.getConn()
+	if conn == nil {
+		return StatementMetadata{}, ErrNoConnections
 	}
 
-	partitionKey = tableMetadata.PartitionKey
+	// get the query info for the statement
+	info, err := conn.prepareStatement(ctx, stmt, nil, keyspace)
+	if err != nil {
+		// TODO: it would be nice to mark hosts here, but we are not using the
+		// policies to fetch hosts, and we can't use the policies because they
+		// might require token awareness, which in turn requires this method.
+		return StatementMetadata{}, err
+	}
 
-	size := len(partitionKey)
-	routingKeyInfo := &routingKeyInfo{
-		indexes:  make([]int, size),
-		types:    make([]TypeInfo, size),
-		keyspace: keyspace,
-		table:    table,
+	if info.request.keyspace != "" {
+		keyspace = info.request.keyspace
 	}
 
-	for keyIndex, keyColumn := range partitionKey {
-		// set an indicator for checking if the mapping is missing
-		routingKeyInfo.indexes[keyIndex] = -1
+	meta := StatementMetadata{
+		Keyspace:            keyspace,
+		Table:               info.request.table,
+		BindColumns:         info.request.columns,
+		PKBindColumnIndexes: info.request.pkeyColumns,
+		ResultColumns:       info.response.columns,
+	}
 
-		// find the column in the query info
-		for argIndex, boundColumn := range info.request.columns {
-			if keyColumn.Name == boundColumn.Name {
-				// there may be many such bound columns, pick the first
-				routingKeyInfo.indexes[keyIndex] = argIndex
-				routingKeyInfo.types[keyIndex] = boundColumn.TypeInfo
-				break
-			}
+	// if it is protocol < v4 then we need to calculate the routing key info
+	if !info.request.supportsPKeyColumns && len(info.request.columns) > 0 {
+		keyspaceMetadata, err := s.KeyspaceMetadata(meta.Keyspace)
+		if err != nil {
+			// the keyspace metadata could not be fetched
+			return StatementMetadata{}, err
 		}
 
-		if routingKeyInfo.indexes[keyIndex] == -1 {
-			// missing a routing key column mapping
-			// no routing key, and no error
-			return nil, nil
+		tableMetadata, found := keyspaceMetadata.Tables[meta.Table]
+		if !found {
+			// unlikely that the statement could be prepared and the metadata for
+			// the table couldn't be found, but this may indicate either a bug
+			// in the metadata code, or that the table was just dropped.
+ return StatementMetadata{}, ErrNoMetadata } - } - // cache this result - inflight.value = routingKeyInfo + meta.PKBindColumnIndexes = make([]int, len(tableMetadata.PartitionKey)) + for keyIndex, keyColumn := range tableMetadata.PartitionKey { + // set an indicator for checking if the mapping is missing + meta.PKBindColumnIndexes[keyIndex] = -1 + + // find the column in the query info + for colIndex, boundColumn := range info.request.columns { + if keyColumn.Name == boundColumn.Name { + // there may be many such bound columns, pick the first + meta.PKBindColumnIndexes[keyIndex] = colIndex + break + } + } - return routingKeyInfo, nil + if meta.PKBindColumnIndexes[keyIndex] == -1 { + // the partition key column is not bound to the statement + meta.PKBindColumnIndexes = nil + break + } + } + } + return meta, nil } // Exec executes a batch operation and returns nil if successful @@ -2102,16 +2102,20 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch { return b } -func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]byte, error) { - if routingKeyInfo == nil { +func createRoutingKey(meta *StatementMetadata, values []interface{}) ([]byte, error) { + if meta == nil || len(meta.PKBindColumnIndexes) == 0 { return nil, nil } - if len(routingKeyInfo.indexes) == 1 { + if len(values) != len(meta.BindColumns) { + return nil, errors.New("gocql: number of values does not match the number of bind columns") + } + + if len(meta.PKBindColumnIndexes) == 1 { // single column routing key routingKey, err := Marshal( - routingKeyInfo.types[0], - values[routingKeyInfo.indexes[0]], + meta.BindColumns[meta.PKBindColumnIndexes[0]].TypeInfo, + values[meta.PKBindColumnIndexes[0]], ) if err != nil { return nil, err @@ -2121,22 +2125,23 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]b // composite routing key buf := bytes.NewBuffer(make([]byte, 0, 256)) - for i := range routingKeyInfo.indexes { + lenBuf := make([]byte, 2) + for i := range meta.PKBindColumnIndexes { encoded, err := Marshal( - routingKeyInfo.types[i], - values[routingKeyInfo.indexes[i]], + meta.BindColumns[meta.PKBindColumnIndexes[i]].TypeInfo, + values[meta.PKBindColumnIndexes[i]], ) if err != nil { return nil, err } - lenBuf := []byte{0x00, 0x00} + // first write the length of the encoded value as a 16-bit big endian integer binary.BigEndian.PutUint16(lenBuf, uint16(len(encoded))) buf.Write(lenBuf) + // then write the encoded value and a null byte to separate the values buf.Write(encoded) buf.WriteByte(0x00) } - routingKey := buf.Bytes() - return routingKey, nil + return buf.Bytes(), nil } // SetKeyspace will enable keyspace flag on the query. 
@@ -2195,17 +2200,6 @@ type routingKeyInfoLRU struct { mu sync.Mutex } -type routingKeyInfo struct { - indexes []int - types []TypeInfo - keyspace string - table string -} - -func (r *routingKeyInfo) String() string { - return fmt.Sprintf("routing key index=%v types=%v", r.indexes, r.types) -} - func (r *routingKeyInfoLRU) Remove(key string) { r.mu.Lock() r.lru.Remove(key) diff --git a/topology.go b/topology.go index 5a9b50fa3..e5741ff23 100644 --- a/topology.go +++ b/topology.go @@ -96,7 +96,7 @@ func getStrategy(ks *KeyspaceMetadata, logger StructuredLogger) placementStrateg rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"]) if err != nil { logger.Warning("Failed to parse replication factor of keyspace configured with SimpleStrategy.", - newLogFieldString("keyspace", ks.Name), newLogFieldError("err", err)) + NewLogFieldString("keyspace", ks.Name), NewLogFieldError("err", err)) return nil } return &simpleStrategy{rf: rf} @@ -110,7 +110,7 @@ func getStrategy(ks *KeyspaceMetadata, logger StructuredLogger) placementStrateg rf, err := getReplicationFactorFromOpts(rf) if err != nil { logger.Warning("Failed to parse replication factors of keyspace configured with NetworkTopologyStrategy.", - newLogFieldString("keyspace", ks.Name), newLogFieldString("dc", dc), newLogFieldError("err", err)) + NewLogFieldString("keyspace", ks.Name), NewLogFieldString("dc", dc), NewLogFieldError("err", err)) // skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs. continue } @@ -122,7 +122,7 @@ func getStrategy(ks *KeyspaceMetadata, logger StructuredLogger) placementStrateg return nil default: logger.Warning("Failed to parse replication factor of keyspace due to unknown strategy class.", - newLogFieldString("keyspace", ks.Name), newLogFieldString("strategy_class", ks.StrategyClass)) + NewLogFieldString("keyspace", ks.Name), NewLogFieldString("strategy_class", ks.StrategyClass)) return nil } }
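
Two closing notes with usage sketches (illustrative only, not part of the
patch). First, the newly exported log field constructors as they might appear
in user code, mirroring testQueryObserver above; MyObserver and its logger
field are hypothetical:

	func (o *MyObserver) ObserveQuery(ctx context.Context, q gocql.ObservedQuery) {
		// each constructor pairs a field name with a typed value
		o.logger.Debug("Observed query.",
			gocql.NewLogFieldString("stmt", q.Statement),
			gocql.NewLogFieldInt("attempts", q.Metrics.Attempts),
			gocql.NewLogFieldIP("host", q.Host.ConnectAddress()),
			gocql.NewLogFieldError("err", q.Err))
	}

Second, a worked reading of the composite routing key bytes asserted in
TestRoutingStatementMetadata, following the format documented in
createRoutingKey: for first_id=2 (int) and second_id="1" (varchar), the
expected key {0,4, 0,0,0,2, 0, 0,1, 49, 0} is, per partition key component, a
16-bit big-endian length, then the marshalled value, then a 0x00 separator.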