diff --git a/CHANGELOG.md b/CHANGELOG.md index 37ae55e3e..ed8c5664f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactoring hostpool package test and Expose HostInfo creation (CASSGO-59) +- Unable to discover cluster nodes with an empty rack name (CASSGO-6) + ### Fixed - Cassandra version unmarshal fix (CASSGO-49) diff --git a/cluster_test.go b/cluster_test.go index adc21fd05..04451c055 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -25,6 +25,7 @@ package gocql import ( + "errors" "net" "reflect" "testing" @@ -80,3 +81,59 @@ func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) { assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr)) assertEqual(t, "translated port", 5432, newPort) } + +func TestEmptyRack(t *testing.T) { + s := &Session{} + host := &HostInfo{} + + row := make(map[string]interface{}) + + row["preferred_ip"] = "172.3.0.2" + row["rpc_address"] = "172.3.0.2" + row["host_id"] = UUIDFromTime(time.Now()) + row["data_center"] = "dc1" + row["tokens"] = []string{"t1", "t2"} + row["rack"] = "rack1" + + validHost, err := s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + row["rack"] = "" + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr := new(string) + *strPtr = "rack" + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr = new(string) + strPtr = nil + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if isValidPeer(validHost) { + t.Fatal(errors.New("expected invalid host")) + } +} diff --git a/helpers.go b/helpers.go index 823c10689..f29e5c1f7 100644 --- a/helpers.go +++ b/helpers.go @@ -332,23 +332,29 @@ func (iter *Iter) RowData() (RowData, error) { values := make([]interface{}, 0, len(iter.Columns())) for _, column := range iter.Columns() { - if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val, err := column.TypeInfo.NewWithError() - if err != nil { - iter.err = err - return RowData{}, err - } + if column.Name == "rack" && column.Keyspace == "system" && (column.Table == "peers_v2" || column.Table == "peers") { + var strPtr = new(string) columns = append(columns, column.Name) - values = append(values, val) + values = append(values, &strPtr) } else { - for i, elem := range c.Elems { - columns = append(columns, TupleColumnName(column.Name, i)) - val, err := elem.NewWithError() + if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { + val, err := column.TypeInfo.NewWithError() if err != nil { iter.err = err return RowData{}, err } + columns = append(columns, column.Name) values = append(values, val) + } else { + for i, elem := range c.Elems { + columns = append(columns, TupleColumnName(column.Name, i)) + val, err := elem.NewWithError() + if err != nil { + iter.err = err + return RowData{}, err + } + values = append(values, val) + } } } } diff --git a/host_source.go b/host_source.go index adcf1a729..c013eab95 100644 --- a/host_source.go +++ b/host_source.go @@ -179,6 +179,7 @@ type HostInfo struct { state nodeState schemaVersion string tokens []string + isRackNil bool } // NewHostInfo creates HostInfo with provided connectAddress and port. @@ -511,9 +512,18 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* return nil, fmt.Errorf(assertErrorMsg, "data_center") } case "rack": - host.rack, ok = value.(string) + rack, ok := value.(*string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "rack") + host.rack, ok = value.(string) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "rack") + } + } else { + if rack != nil { + host.rack = *rack + } else { + host.isRackNil = true + } } case "host_id": hostId, ok := value.(UUID) @@ -703,7 +713,7 @@ func isValidPeer(host *HostInfo) bool { return !(len(host.RPCAddress()) == 0 || host.hostId == "" || host.dataCenter == "" || - host.rack == "" || + host.isRackNil || len(host.tokens) == 0) }