Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Query.SetKeyspace(), Query.WithNowInSeconds(), Batch.SetKeyspace(), Batch.WithNowInSeconds() (CASSGO-1)

### Changed
- Make HostFilter interface easier to test (CASSGO-71)

- Move lz4 compressor to lz4 package within the gocql module (CASSGO-32)

Expand Down
8 changes: 7 additions & 1 deletion cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,12 @@ func TestCAS(t *testing.T) {
t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

// TODO: This test failing with error due to "function dateof" on cassandra side.
// It was officially removed in version 5.0.0. The recommended replacement for dateOf is the toTimestamp function.
// As we are not testing against deprecated cassandra versions, it makes sense to update tests to keep them up to date
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this was done in #1828 so once that is merged and this PR is rebased it should be fixed

// === RUN TestCAS
// cassandra_test.go:487: insert: Unknown function dateof called
// --- FAIL: TestCAS (0.97s)
insertBatch := session.Batch(LoggedBatch)
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
Expand Down Expand Up @@ -951,7 +957,7 @@ func TestReconnection(t *testing.T) {
t.Fatal("Host should be NodeDown but not.")
}

time.Sleep(cluster.ReconnectInterval + h.Version().nodeUpDelay() + 1*time.Second)
time.Sleep(cluster.ReconnectInterval + 1*time.Second)

if h.State() != NodeUp {
t.Fatal("Host should be NodeUp but not. Failed to reconnect.")
Expand Down
3 changes: 0 additions & 3 deletions events.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,6 @@ func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
return
}

if d := host.Version().nodeUpDelay(); d > 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove this?

time.Sleep(d)
}
s.startPoolFill(host)
}

Expand Down
51 changes: 43 additions & 8 deletions filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,74 @@

package gocql

import "fmt"
import (
"fmt"
"net"
)

// HostFilter interface is used when a host is discovered via server sent events.
type HostFilter interface {
// Called when a new host is discovered, returning true will cause the host
// to be added to the pools.
Accept(host *HostInfo) bool
Accept(host Host) bool
}

// Host interface is provided to enable testing of custom implementations of the HostFilter interface.
type Host interface {
Peer() net.IP
ConnectAddress() net.IP
BroadcastAddress() net.IP
ListenAddress() net.IP
RPCAddress() net.IP
PreferredIP() net.IP
DataCenter() string
Rack() string
HostID() string
WorkLoad() string
Graph() bool
DSEVersion() string
Partitioner() string
ClusterName() string
Version() CassVersion
Tokens() []string
Port() int
IsUp() bool
String() string
}

// Since cassVersion is an unexported type, the CassVersion interface is introduced
// to allow better testability and increase test coverage.
type CassVersion interface {
Set(v string) error
UnmarshalCQL(info TypeInfo, data []byte) error
AtLeast(major, minor, patch int) bool
String() string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs some changes, just making an interface out of the methods of the private type feels wrong here, we shouldn't expose Set or Unmarshal method and we should definitely expose a way to get the actual version fields.

Ideally CassVersion would be:

type CassVersion interface {
	Major() int
        Minor() int
        Patch() int
        Qualifier() string
	String() string
}

The methods on cassVersion can be converted into global private functions.

func (c *cassVersion) Set(v string) error 
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error 
func (c *cassVersion) unmarshal(data []byte) error {
func (c cassVersion) Before(major, minor, patch int) bool 
func (c cassVersion) AtLeast(major, minor, patch int) bool
func parseCassVersion(v string) (CassVersion, error) {
        c := &cassVersion{}
	if v == "" {
		return c, nil
	}

	return c.unmarshalCQL(nil, []byte(v))
}

func (c *cassVersion) unmarshalCQL(info TypeInfo, data []byte) error {
	return c.unmarshal(data)
}

// can be kept as is
func (c *cassVersion) unmarshal(data []byte) error

// these two should just be global, private and have `CassVersion` parameters instead
func before(v1 CassVersion, v2 CassVersion) bool 
func atLeast(v1 CassVersion, v2 CassVersion) bool

}

// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it also needs to be updated to func(host Host) bool

type HostFilterFunc func(host *HostInfo) bool
type HostFilterFunc func(host Host) bool

func (fn HostFilterFunc) Accept(host *HostInfo) bool {
func (fn HostFilterFunc) Accept(host Host) bool {
return fn(host)
}

// AcceptAllFilter will accept all hosts
func AcceptAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return HostFilterFunc(func(host Host) bool {
return true
})
}

func DenyAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return HostFilterFunc(func(host Host) bool {
return false
})
}

// DataCenterHostFilter filters all hosts such that they are in the same data center
// as the supplied data center.
func DataCenterHostFilter(dataCenter string) HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return HostFilterFunc(func(host Host) bool {
return host.DataCenter() == dataCenter
})
}
Expand All @@ -81,7 +116,7 @@ func WhiteListHostFilter(hosts ...string) HostFilter {
m[host.ConnectAddress().String()] = true
}

return HostFilterFunc(func(host *HostInfo) bool {
return HostFilterFunc(func(host Host) bool {
return m[host.ConnectAddress().String()]
})
}
70 changes: 70 additions & 0 deletions filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ package gocql
import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestFilter_WhiteList(t *testing.T) {
Expand Down Expand Up @@ -121,3 +123,71 @@ func TestFilter_DataCenter(t *testing.T) {
}
}
}

// mockHost is a custom implementation of the Host interface for testing.
type mockHost struct {
peer net.IP
connectAddress net.IP
dataCenter string
version CassVersion
}

func (m *mockHost) Peer() net.IP { return m.peer }
func (m *mockHost) ConnectAddress() net.IP { return m.connectAddress }
func (m *mockHost) BroadcastAddress() net.IP { return nil }
func (m *mockHost) ListenAddress() net.IP { return nil }
func (m *mockHost) RPCAddress() net.IP { return nil }
func (m *mockHost) PreferredIP() net.IP { return nil }
func (m *mockHost) DataCenter() string { return m.dataCenter }
func (m *mockHost) Rack() string { return "" }
func (m *mockHost) HostID() string { return "" }
func (m *mockHost) WorkLoad() string { return "" }
func (m *mockHost) Graph() bool { return false }
func (m *mockHost) DSEVersion() string { return "" }
func (m *mockHost) Partitioner() string { return "" }
func (m *mockHost) ClusterName() string { return "" }
func (m *mockHost) Version() CassVersion { return m.version }
func (m *mockHost) Tokens() []string { return nil }
func (m *mockHost) Port() int { return 9042 }
func (m *mockHost) IsUp() bool { return true }
func (m *mockHost) String() string { return "mockHost" }

// mockCassVersion is a fake CassVersion implementation for testing.
type mockCassVersion struct {
versionString string
}

func (m *mockCassVersion) Set(v string) error { m.versionString = v; return nil }
func (m *mockCassVersion) UnmarshalCQL(info TypeInfo, data []byte) error { return nil }
func (m *mockCassVersion) AtLeast(major, minor, patch int) bool { return true }
func (m *mockCassVersion) String() string { return m.versionString }

// Test custom Host implementation
func TestHostImplementation(t *testing.T) {
mockVersion := &mockCassVersion{versionString: "3.11.4"}
mockHost := &mockHost{
peer: net.ParseIP("192.168.1.1"),
connectAddress: net.ParseIP("10.0.0.1"),
dataCenter: "datacenter1",
version: mockVersion,
}

assert.Equal(t, "datacenter1", mockHost.DataCenter(), "DataCenter() should return the correct value")
assert.Equal(t, net.ParseIP("10.0.0.1"), mockHost.ConnectAddress(), "ConnectAddress() should return the correct IP")
assert.Equal(t, "3.11.4", mockHost.Version().String(), "Version() should return the correct Cassandra version")
assert.True(t, mockHost.IsUp(), "IsUp() should return true")
assert.Equal(t, "mockHost", mockHost.String(), "String() should return 'mockHost'")
}

// Test CassVersion interface implementation
func TestHostCassVersion(t *testing.T) {
mockVersion := &mockCassVersion{versionString: "4.0.0"}

// Test setting version
err := mockVersion.Set("4.0.1")
assert.NoError(t, err, "Set() should not return an error")
assert.Equal(t, "4.0.1", mockVersion.String(), "String() should return the updated version")

// Test AtLeast method
assert.True(t, mockVersion.AtLeast(4, 0, 0), "AtLeast() should return true for matching version")
}
13 changes: 2 additions & 11 deletions host_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,6 @@ func (c cassVersion) String() string {
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
}

func (c cassVersion) nodeUpDelay() time.Duration {
if c.Major >= 2 && c.Minor >= 2 {
// CASSANDRA-8236
return 0
}

return 10 * time.Second
}

type HostInfo struct {
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
// that we are thread safe use a mutex to access all fields.
Expand Down Expand Up @@ -341,10 +332,10 @@ func (h *HostInfo) ClusterName() string {
return h.clusterName
}

func (h *HostInfo) Version() cassVersion {
func (h *HostInfo) Version() CassVersion {
h.mu.RLock()
defer h.mu.RUnlock()
return h.version
return &h.version
}

func (h *HostInfo) State() nodeState {
Expand Down
4 changes: 2 additions & 2 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestHostFilterDiscovery(t *testing.T) {
// we'll filter out the second host
filtered := clusterHosts[1]
cluster.Hosts = clusterHosts[:1]
cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool {
cluster.HostFilter = HostFilterFunc(func(host Host) bool {
if host.ConnectAddress().String() == filtered {
return false
}
Expand All @@ -138,7 +138,7 @@ func TestHostFilterInitial(t *testing.T) {
cluster.PoolConfig.HostSelectionPolicy = rr
// we'll filter out the second host
filtered := clusterHosts[1]
cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool {
cluster.HostFilter = HostFilterFunc(func(host Host) bool {
if host.ConnectAddress().String() == filtered {
return false
}
Expand Down
Loading