diff --git a/host_source.go b/host_source.go index adcf1a729..ce764959e 100644 --- a/host_source.go +++ b/host_source.go @@ -201,7 +201,7 @@ func (h *HostInfo) Equal(host *HostInfo) bool { return true } - return h.ConnectAddress().Equal(host.ConnectAddress()) + return h.HostID() == host.HostID() && h.ConnectAddressAndPort() == host.ConnectAddressAndPort() } func (h *HostInfo) Peer() net.IP { diff --git a/policies.go b/policies.go index ed0b02f3e..523ab2fbd 100644 --- a/policies.go +++ b/policies.go @@ -32,7 +32,6 @@ import ( "fmt" "math" "math/rand" - "net" "sync" "sync/atomic" "time" @@ -82,7 +81,7 @@ func (c *cowHostList) add(host *HostInfo) bool { return true } -func (c *cowHostList) remove(ip net.IP) bool { +func (c *cowHostList) remove(host *HostInfo) bool { c.mu.Lock() l := c.get() size := len(l) @@ -94,7 +93,7 @@ func (c *cowHostList) remove(ip net.IP) bool { found := false newL := make([]*HostInfo, 0, size) for i := 0; i < len(l); i++ { - if !l[i].ConnectAddress().Equal(ip) { + if !l[i].Equal(host) { newL = append(newL, l[i]) } else { found = true @@ -351,7 +350,7 @@ func (r *roundRobinHostPolicy) AddHost(host *HostInfo) { } func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) { - r.hosts.remove(host.ConnectAddress()) + r.hosts.remove(host) } func (r *roundRobinHostPolicy) HostUp(host *HostInfo) { @@ -517,7 +516,7 @@ func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) { func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) { t.mu.Lock() - if t.hosts.remove(host.ConnectAddress()) { + if t.hosts.remove(host) { meta := t.getMetadataForUpdate() meta.resetTokenRing(t.partitioner, t.hosts.get(), t.logger) t.updateReplicas(meta, t.getKeyspaceName()) @@ -720,9 +719,9 @@ func (d *dcAwareRR) AddHost(host *HostInfo) { func (d *dcAwareRR) RemoveHost(host *HostInfo) { if d.IsLocal(host) { - d.localHosts.remove(host.ConnectAddress()) + d.localHosts.remove(host) } else { - d.remoteHosts.remove(host.ConnectAddress()) + d.remoteHosts.remove(host) } } @@ -826,7 +825,7 @@ func (d *rackAwareRR) AddHost(host *HostInfo) { func (d *rackAwareRR) RemoveHost(host *HostInfo) { dist := d.HostTier(host) - d.hosts[dist].remove(host.ConnectAddress()) + d.hosts[dist].remove(host) } func (d *rackAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } diff --git a/policies_test.go b/policies_test.go index e8bda8908..8a94dee55 100644 --- a/policies_test.go +++ b/policies_test.go @@ -65,6 +65,32 @@ func TestRoundRobbin(t *testing.T) { } } +func TestRoundRobbinSameConnectAddress(t *testing.T) { + policy := RoundRobinHostPolicy() + + hosts := [...]*HostInfo{ + {hostId: "0", connectAddress: net.IPv4(0, 0, 0, 1), port: 9042}, + {hostId: "1", connectAddress: net.IPv4(0, 0, 0, 1), port: 9043}, + } + + for _, host := range hosts { + policy.AddHost(host) + } + + got := make(map[string]bool) + it := policy.Pick(nil) + for h := it(); h != nil; h = it() { + id := h.Info().hostId + if got[id] { + t.Fatalf("got duplicate host: %v", id) + } + got[id] = true + } + if len(got) != len(hosts) { + t.Fatalf("expected %d hosts got %d", len(hosts), len(got)) + } +} + // Tests of the token-aware host selection policy implementation with a // round-robin host selection policy fallback. func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) {