From 34b5349ebea4361ea0b21df2244c751b9a7ca138 Mon Sep 17 00:00:00 2001 From: "mykyta.oleksiienko" Date: Wed, 2 Apr 2025 11:35:39 +0300 Subject: [PATCH] Make HostFilter interface easier to test Currently external HostFilter implementations cannot be tested against HostFilter interface. To fix that the Accept method should have an interface as the parameter instead of HostInfo so that it is possible to mock it. Patch by: Mykyta Oleksiienko; Revieved by: ***; For CASSGO-71. --- CHANGELOG.md | 1 + cassandra_test.go | 8 +++++- events.go | 3 -- filters.go | 51 +++++++++++++++++++++++++++------ filters_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++ host_source.go | 13 ++------- integration_test.go | 4 +-- 7 files changed, 125 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37ae55e3e..b1248f23c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1) ### Changed +- Make HostFilter interface easier to test (CASSGO-71) - Move lz4 compressor to lz4 package within the gocql module (CASSGO-32) diff --git a/cassandra_test.go b/cassandra_test.go index 54a54f426..5c6aacf61 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -480,6 +480,12 @@ func TestCAS(t *testing.T) { t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } + // TODO: This test failing with error due to "function dateof" on cassandra side. + // It was officially removed in version 5.0.0. The recommended replacement for dateOf is the toTimestamp function. + // As we are not testing against deprecated cassandra versions, it makes sense to update tests to keep them up to date + // === RUN TestCAS + // cassandra_test.go:487: insert: Unknown function dateof called + // --- FAIL: TestCAS (0.97s) insertBatch := session.Batch(LoggedBatch) insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") @@ -951,7 +957,7 @@ func TestReconnection(t *testing.T) { t.Fatal("Host should be NodeDown but not.") } - time.Sleep(cluster.ReconnectInterval + h.Version().nodeUpDelay() + 1*time.Second) + time.Sleep(cluster.ReconnectInterval + 1*time.Second) if h.State() != NodeUp { t.Fatal("Host should be NodeUp but not. Failed to reconnect.") diff --git a/events.go b/events.go index 93b001acc..a22c81255 100644 --- a/events.go +++ b/events.go @@ -227,9 +227,6 @@ func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { return } - if d := host.Version().nodeUpDelay(); d > 0 { - time.Sleep(d) - } s.startPoolFill(host) } diff --git a/filters.go b/filters.go index 312bd0d1a..bc1a6026a 100644 --- a/filters.go +++ b/filters.go @@ -24,31 +24,66 @@ package gocql -import "fmt" +import ( + "fmt" + "net" +) // HostFilter interface is used when a host is discovered via server sent events. type HostFilter interface { // Called when a new host is discovered, returning true will cause the host // to be added to the pools. - Accept(host *HostInfo) bool + Accept(host Host) bool +} + +// Host interface is provided to enable testing of custom implementations of the HostFilter interface. +type Host interface { + Peer() net.IP + ConnectAddress() net.IP + BroadcastAddress() net.IP + ListenAddress() net.IP + RPCAddress() net.IP + PreferredIP() net.IP + DataCenter() string + Rack() string + HostID() string + WorkLoad() string + Graph() bool + DSEVersion() string + Partitioner() string + ClusterName() string + Version() CassVersion + Tokens() []string + Port() int + IsUp() bool + String() string +} + +// Since cassVersion is an unexported type, the CassVersion interface is introduced +// to allow better testability and increase test coverage. +type CassVersion interface { + Set(v string) error + UnmarshalCQL(info TypeInfo, data []byte) error + AtLeast(major, minor, patch int) bool + String() string } // HostFilterFunc converts a func(host HostInfo) bool into a HostFilter -type HostFilterFunc func(host *HostInfo) bool +type HostFilterFunc func(host Host) bool -func (fn HostFilterFunc) Accept(host *HostInfo) bool { +func (fn HostFilterFunc) Accept(host Host) bool { return fn(host) } // AcceptAllFilter will accept all hosts func AcceptAllFilter() HostFilter { - return HostFilterFunc(func(host *HostInfo) bool { + return HostFilterFunc(func(host Host) bool { return true }) } func DenyAllFilter() HostFilter { - return HostFilterFunc(func(host *HostInfo) bool { + return HostFilterFunc(func(host Host) bool { return false }) } @@ -56,7 +91,7 @@ func DenyAllFilter() HostFilter { // DataCenterHostFilter filters all hosts such that they are in the same data center // as the supplied data center. func DataCenterHostFilter(dataCenter string) HostFilter { - return HostFilterFunc(func(host *HostInfo) bool { + return HostFilterFunc(func(host Host) bool { return host.DataCenter() == dataCenter }) } @@ -81,7 +116,7 @@ func WhiteListHostFilter(hosts ...string) HostFilter { m[host.ConnectAddress().String()] = true } - return HostFilterFunc(func(host *HostInfo) bool { + return HostFilterFunc(func(host Host) bool { return m[host.ConnectAddress().String()] }) } diff --git a/filters_test.go b/filters_test.go index a1abec207..08f7a48fb 100644 --- a/filters_test.go +++ b/filters_test.go @@ -27,6 +27,8 @@ package gocql import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) func TestFilter_WhiteList(t *testing.T) { @@ -121,3 +123,71 @@ func TestFilter_DataCenter(t *testing.T) { } } } + +// mockHost is a custom implementation of the Host interface for testing. +type mockHost struct { + peer net.IP + connectAddress net.IP + dataCenter string + version CassVersion +} + +func (m *mockHost) Peer() net.IP { return m.peer } +func (m *mockHost) ConnectAddress() net.IP { return m.connectAddress } +func (m *mockHost) BroadcastAddress() net.IP { return nil } +func (m *mockHost) ListenAddress() net.IP { return nil } +func (m *mockHost) RPCAddress() net.IP { return nil } +func (m *mockHost) PreferredIP() net.IP { return nil } +func (m *mockHost) DataCenter() string { return m.dataCenter } +func (m *mockHost) Rack() string { return "" } +func (m *mockHost) HostID() string { return "" } +func (m *mockHost) WorkLoad() string { return "" } +func (m *mockHost) Graph() bool { return false } +func (m *mockHost) DSEVersion() string { return "" } +func (m *mockHost) Partitioner() string { return "" } +func (m *mockHost) ClusterName() string { return "" } +func (m *mockHost) Version() CassVersion { return m.version } +func (m *mockHost) Tokens() []string { return nil } +func (m *mockHost) Port() int { return 9042 } +func (m *mockHost) IsUp() bool { return true } +func (m *mockHost) String() string { return "mockHost" } + +// mockCassVersion is a fake CassVersion implementation for testing. +type mockCassVersion struct { + versionString string +} + +func (m *mockCassVersion) Set(v string) error { m.versionString = v; return nil } +func (m *mockCassVersion) UnmarshalCQL(info TypeInfo, data []byte) error { return nil } +func (m *mockCassVersion) AtLeast(major, minor, patch int) bool { return true } +func (m *mockCassVersion) String() string { return m.versionString } + +// Test custom Host implementation +func TestHostImplementation(t *testing.T) { + mockVersion := &mockCassVersion{versionString: "3.11.4"} + mockHost := &mockHost{ + peer: net.ParseIP("192.168.1.1"), + connectAddress: net.ParseIP("10.0.0.1"), + dataCenter: "datacenter1", + version: mockVersion, + } + + assert.Equal(t, "datacenter1", mockHost.DataCenter(), "DataCenter() should return the correct value") + assert.Equal(t, net.ParseIP("10.0.0.1"), mockHost.ConnectAddress(), "ConnectAddress() should return the correct IP") + assert.Equal(t, "3.11.4", mockHost.Version().String(), "Version() should return the correct Cassandra version") + assert.True(t, mockHost.IsUp(), "IsUp() should return true") + assert.Equal(t, "mockHost", mockHost.String(), "String() should return 'mockHost'") +} + +// Test CassVersion interface implementation +func TestHostCassVersion(t *testing.T) { + mockVersion := &mockCassVersion{versionString: "4.0.0"} + + // Test setting version + err := mockVersion.Set("4.0.1") + assert.NoError(t, err, "Set() should not return an error") + assert.Equal(t, "4.0.1", mockVersion.String(), "String() should return the updated version") + + // Test AtLeast method + assert.True(t, mockVersion.AtLeast(4, 0, 0), "AtLeast() should return true for matching version") +} diff --git a/host_source.go b/host_source.go index adcf1a729..ad0c51b83 100644 --- a/host_source.go +++ b/host_source.go @@ -146,15 +146,6 @@ func (c cassVersion) String() string { return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) } -func (c cassVersion) nodeUpDelay() time.Duration { - if c.Major >= 2 && c.Minor >= 2 { - // CASSANDRA-8236 - return 0 - } - - return 10 * time.Second -} - type HostInfo struct { // TODO(zariel): reduce locking maybe, not all values will change, but to ensure // that we are thread safe use a mutex to access all fields. @@ -341,10 +332,10 @@ func (h *HostInfo) ClusterName() string { return h.clusterName } -func (h *HostInfo) Version() cassVersion { +func (h *HostInfo) Version() CassVersion { h.mu.RLock() defer h.mu.RUnlock() - return h.version + return &h.version } func (h *HostInfo) State() nodeState { diff --git a/integration_test.go b/integration_test.go index 61ffbf504..24ef3f33e 100644 --- a/integration_test.go +++ b/integration_test.go @@ -114,7 +114,7 @@ func TestHostFilterDiscovery(t *testing.T) { // we'll filter out the second host filtered := clusterHosts[1] cluster.Hosts = clusterHosts[:1] - cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool { + cluster.HostFilter = HostFilterFunc(func(host Host) bool { if host.ConnectAddress().String() == filtered { return false } @@ -138,7 +138,7 @@ func TestHostFilterInitial(t *testing.T) { cluster.PoolConfig.HostSelectionPolicy = rr // we'll filter out the second host filtered := clusterHosts[1] - cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool { + cluster.HostFilter = HostFilterFunc(func(host Host) bool { if host.ConnectAddress().String() == filtered { return false }