From 8b51c534279c8760a1a7d60b730666e0f0ab2b38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Reis?= Date: Fri, 13 Jun 2025 17:30:05 +0100 Subject: [PATCH] Add a way to create HostInfo objects for testing purposes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It is not possible to test implementations of public gocql interfaces such as HostFilter because it requires creating HostInfo objects (HostInfo has private fields). With this change, it is now possible to create HostInfo objects from system.local / system.peers rows using NewTestHostInfoFromRow(). Patch by João Reis; reviewed by James Hartig for CASSGO-71 --- CHANGELOG.md | 2 + conn.go | 6 +-- control.go | 4 +- host_source.go | 82 ++++++++++++++++++++++++++++++--------- hostpool/hostpool_test.go | 40 ++++++++++++------- session.go | 4 +- 6 files changed, 97 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a1c888bb..6262875bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Cleanup of deprecated elements (CASSGO-12) - Remove global NewBatch function (CASSGO-15) - Remove deprecated global logger (CASSGO-24) +- HostInfo.SetHostID is no longer exported (CASSGO-71) ### Added @@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1) - Externally-defined type registration (CASSGO-43) - Add Query and Batch to ObservedQuery and ObservedBatch (CASSGO-73) +- Add way to create HostInfo objects for testing purposes (CASSGO-71) ### Changed diff --git a/conn.go b/conn.go index 41f7db41a..f80d7ca27 100644 --- a/conn.go +++ b/conn.go @@ -1938,11 +1938,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { for _, row := range rows { var host *HostInfo - host, err = NewHostInfo(c.host.ConnectAddress(), c.session.cfg.Port) - if err != nil { - goto cont - } - host, err = c.session.hostInfoFromMap(row, host) + host, err = c.session.newHostInfoFromMap(c.host.ConnectAddress(), c.session.cfg.Port, row) if err != nil { goto cont } diff --git a/control.go b/control.go index c2cb4cb13..fa30cd059 100644 --- a/control.go +++ b/control.go @@ -148,7 +148,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) { // Check if host is a literal IP address if ip := net.ParseIP(host); ip != nil { - h, err := NewHostInfo(ip, port) + h, err := NewHostInfoFromAddrPort(ip, port) if err != nil { return nil, err } @@ -178,7 +178,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) { } for _, ip := range ips { - h, err := NewHostInfo(ip, port) + h, err := NewHostInfoFromAddrPort(ip, port) if err != nil { return nil, err } diff --git a/host_source.go b/host_source.go index 7e88cf9da..fbc9134d6 100644 --- a/host_source.go +++ b/host_source.go @@ -155,6 +155,8 @@ func (c cassVersion) nodeUpDelay() time.Duration { return 10 * time.Second } +// HostInfo represents a server Host/Node. You can create a HostInfo object with either NewHostInfoFromAddrPort or +// NewTestHostInfoFromRow. type HostInfo struct { // TODO(zariel): reduce locking maybe, not all values will change, but to ensure // that we are thread safe use a mutex to access all fields. @@ -181,9 +183,12 @@ type HostInfo struct { tokens []string } -// NewHostInfo creates HostInfo with provided connectAddress and port. +// NewHostInfoFromAddrPort creates HostInfo with provided connectAddress and port. // It returns an error if addr is invalid. -func NewHostInfo(addr net.IP, port int) (*HostInfo, error) { +// +// If you're looking for a way to create a HostInfo object with more than just an address and port for +// testing purposes then you can use NewTestHostInfoFromRow +func NewHostInfoFromAddrPort(addr net.IP, port int) (*HostInfo, error) { if !validIpAddr(addr) { return nil, errors.New("invalid host address") } @@ -251,7 +256,7 @@ func (h *HostInfo) nodeToNodeAddress() net.IP { return net.IPv4zero } -// Returns the address that should be used to connect to the host. +// ConnectAddress Returns the address that should be used to connect to the host. // If you wish to override this, use an AddressTranslator func (h *HostInfo) ConnectAddress() net.IP { h.mu.RLock() @@ -305,7 +310,7 @@ func (h *HostInfo) HostID() string { return h.hostId } -func (h *HostInfo) SetHostID(hostID string) { +func (h *HostInfo) setHostID(hostID string) { h.mu.Lock() defer h.mu.Unlock() h.hostId = hostID @@ -492,17 +497,42 @@ func checkSystemSchema(control *controlConn) (bool, error) { return true, nil } -func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]interface{}) (*HostInfo, error) { - return s.hostInfoFromMap(row, &HostInfo{connectAddress: addr, port: port}) -} - // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo -func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { +func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]interface{}) (*HostInfo, error) { + return newHostInfoFromRow(s, addr, port, row) +} + +// NewTestHostInfoFromRow creates a new HostInfo object from a system.peers or system.local row. The port +// defaults to 9042. +// +// You can create a HostInfo object for testing purposes using this function: +// +// Example usage: +// +// row := map[string]interface{}{ +// "broadcast_address": net.ParseIP("10.0.0.1"), +// "listen_address": net.ParseIP("10.0.0.1"), +// "rpc_address": net.ParseIP("10.0.0.1"), +// "peer": net.ParseIP("10.0.0.1"), // system.peers only +// "data_center": "dc1", +// "rack": "rack1", +// "host_id": MustRandomUUID(), // can also use ParseUUID("550e8400-e29b-41d4-a716-446655440000") +// "release_version": "4.0.0", +// "native_port": 9042, +// } +// host, err := NewTestHostInfoFromRow(row) +func NewTestHostInfoFromRow(row map[string]interface{}) (*HostInfo, error) { + return newHostInfoFromRow(nil, nil, 9042, row) +} + +func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map[string]interface{}) (*HostInfo, error) { const assertErrorMsg = "Assertion failed for %s, type was %T" var ok bool - // Default to our connected port if the cluster doesn't have port information + host := &HostInfo{connectAddress: defaultAddr, port: defaultPort} + + // Process all fields from the row for key, value := range row { switch key { case "data_center": @@ -606,18 +636,34 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* } host.schemaVersion = schemaVersion.String() } - // TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete - // Not sure what the port field will be called until the JIRA issue is complete } - ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port, s.logger) - if !validIpAddr(ip) { - return nil, fmt.Errorf("invalid host address (before translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), host.port, ip.String(), port) + // Determine the connect address from available addresses + if validIpAddr(host.rpcAddress) { + host.connectAddress = host.rpcAddress + } else if validIpAddr(host.preferredIP) { + host.connectAddress = host.preferredIP + } else if validIpAddr(host.broadcastAddress) { + host.connectAddress = host.broadcastAddress + } else if validIpAddr(host.peer) { + host.connectAddress = host.peer } - host.connectAddress = ip - host.port = port - return host, nil + if s != nil && s.cfg.AddressTranslator != nil { + ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port, s.logger) + if !validIpAddr(ip) { + return nil, fmt.Errorf("invalid host address (before translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), host.port, ip.String(), port) + } + host.connectAddress = ip + host.port = port + } + + if validIpAddr(host.connectAddress) { + host.hostname = host.connectAddress.String() + return host, nil + } else { + return nil, errors.New("invalid host address") + } } func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int) (*HostInfo, error) { diff --git a/hostpool/hostpool_test.go b/hostpool/hostpool_test.go index 75e3cc563..eb374d9ab 100644 --- a/hostpool/hostpool_test.go +++ b/hostpool/hostpool_test.go @@ -35,21 +35,33 @@ func TestHostPolicy_HostPool(t *testing.T) { policy := HostPoolHostPolicy(hostpool.New(nil)) //hosts := []*gocql.HostInfo{ - // {hostId: "0", connectAddress: net.IPv4(10, 0, 0, 0)}, - // {hostId: "1", connectAddress: net.IPv4(10, 0, 0, 1)}, + // {hostId: "f1935733-af5f-4995-bd1e-94a7a3e67bfd", connectAddress: net.ParseIP("10.0.0.0")}, + // {hostId: "93ca4489-b322-4fda-b5a5-12d4436271df", connectAddress: net.ParseIP("10.0.0.1")}, //} + firstHostId, err1 := gocql.ParseUUID("f1935733-af5f-4995-bd1e-94a7a3e67bfd") + secondHostId, err2 := gocql.ParseUUID("93ca4489-b322-4fda-b5a5-12d4436271df") - firstHost, err := gocql.NewHostInfo(net.IPv4(10, 0, 0, 0), 9042) + if err1 != nil || err2 != nil { + t.Fatal(err1, err2) + } + + firstHost, err := gocql.NewTestHostInfoFromRow( + map[string]interface{}{ + "peer": net.ParseIP("10.0.0.0"), + "native_port": 9042, + "host_id": firstHostId}) if err != nil { t.Errorf("Error creating first host: %v", err) } - firstHost.SetHostID("0") - secHost, err := gocql.NewHostInfo(net.IPv4(10, 0, 0, 1), 9042) + secHost, err := gocql.NewTestHostInfoFromRow( + map[string]interface{}{ + "peer": net.ParseIP("10.0.0.1"), + "native_port": 9042, + "host_id": secondHostId}) if err != nil { t.Errorf("Error creating second host: %v", err) } - secHost.SetHostID("1") hosts := []*gocql.HostInfo{firstHost, secHost} // Using set host to control the ordering of the hosts as calling "AddHost" iterates the map // which will result in an unpredictable ordering @@ -59,26 +71,26 @@ func TestHostPolicy_HostPool(t *testing.T) { // interleaved iteration should always increment the host iter := policy.Pick(nil) actualA := iter() - if actualA.Info().HostID() != "0" { - t.Errorf("Expected hosts[0] but was hosts[%s]", actualA.Info().HostID()) + if actualA.Info().HostID() != firstHostId.String() { + t.Errorf("Expected first host id but was %s", actualA.Info().HostID()) } actualA.Mark(nil) actualB := iter() - if actualB.Info().HostID() != "1" { - t.Errorf("Expected hosts[1] but was hosts[%s]", actualB.Info().HostID()) + if actualB.Info().HostID() != secondHostId.String() { + t.Errorf("Expected second host id but was %s", actualB.Info().HostID()) } actualB.Mark(fmt.Errorf("error")) actualC := iter() - if actualC.Info().HostID() != "0" { - t.Errorf("Expected hosts[0] but was hosts[%s]", actualC.Info().HostID()) + if actualC.Info().HostID() != firstHostId.String() { + t.Errorf("Expected first host id but was %s", actualC.Info().HostID()) } actualC.Mark(nil) actualD := iter() - if actualD.Info().HostID() != "0" { - t.Errorf("Expected hosts[0] but was hosts[%s]", actualD.Info().HostID()) + if actualD.Info().HostID() != firstHostId.String() { + t.Errorf("Expected first host id but was %s", actualD.Info().HostID()) } actualD.Mark(nil) } diff --git a/session.go b/session.go index 45206cd04..606c68f4b 100644 --- a/session.go +++ b/session.go @@ -268,7 +268,7 @@ func (s *Session) init() error { // by internal logic. // Associate random UUIDs here with all hosts missing this information. if len(host.HostID()) == 0 { - host.SetHostID(MustRandomUUID().String()) + host.setHostID(MustRandomUUID().String()) } } @@ -2143,7 +2143,7 @@ type ObservedQuery struct { // Rows is not used in batch queries and remains at the default value Rows int - // Host is the informations about the host that performed the query + // Host is the information about the host that performed the query Host *HostInfo // The metrics per this host