Skip to content

Commit 12406bf

Browse files
committed
ConnectAddress() was refactored
1 parent 37030fb commit 12406bf

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,11 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
16971697
}
16981698

16991699
for _, row := range rows {
1700-
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
1700+
h, err := newHostInfo(c.host.ConnectAddress(), c.session.cfg.Port)
1701+
if err != nil {
1702+
goto cont
1703+
}
1704+
host, err := c.session.hostInfoFromMap(row, h)
17011705
if err != nil {
17021706
goto cont
17031707
}

control.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
146146

147147
// Check if host is a literal IP address
148148
if ip := net.ParseIP(host); ip != nil {
149-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
149+
h, err := newHostInfo(ip, port)
150+
if err != nil {
151+
return nil, err
152+
}
153+
hosts = append(hosts, h)
150154
return hosts, nil
151155
}
152156

@@ -172,7 +176,12 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
172176
}
173177

174178
for _, ip := range ips {
175-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
179+
h, err := newHostInfo(ip, port)
180+
if err != nil {
181+
return nil, err
182+
}
183+
184+
hosts = append(hosts, h)
176185
}
177186

178187
return hosts, nil

host_source.go

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ type HostInfo struct {
159159
tokens []string
160160
}
161161

162+
func newHostInfo(addr net.IP, port int) (*HostInfo, error) {
163+
if !validIpAddr(addr) {
164+
return nil, errors.New("invalid host address")
165+
}
166+
host := &HostInfo{}
167+
host.hostname = addr.String()
168+
host.port = port
169+
170+
host.connectAddress = addr
171+
return host, nil
172+
}
173+
162174
func (h *HostInfo) Equal(host *HostInfo) bool {
163175
if h == host {
164176
// prevent rlock reentry
@@ -191,14 +203,12 @@ func (h *HostInfo) connectAddressLocked() (net.IP, string) {
191203
} else if validIpAddr(h.rpcAddress) {
192204
return h.rpcAddress, "rpc_adress"
193205
} else if validIpAddr(h.preferredIP) {
194-
// where does perferred_ip get set?
195206
return h.preferredIP, "preferred_ip"
196207
} else if validIpAddr(h.broadcastAddress) {
197208
return h.broadcastAddress, "broadcast_address"
198-
} else if validIpAddr(h.peer) {
199-
return h.peer, "peer"
200209
}
201-
return net.IPv4zero, "invalid"
210+
return h.peer, "peer"
211+
202212
}
203213

204214
// nodeToNodeAddress returns address broadcasted between node to nodes.
@@ -224,18 +234,20 @@ func (h *HostInfo) ConnectAddress() net.IP {
224234
h.mu.RLock()
225235
defer h.mu.RUnlock()
226236

227-
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
228-
return addr
229-
}
230-
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
237+
addr, _ := h.connectAddressLocked()
238+
return addr
231239
}
232240

233-
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
234-
// TODO(zariel): should this not be exported?
241+
func (h *HostInfo) SetConnectAddress(address net.IP) (*HostInfo, error) {
242+
if !validIpAddr(address) {
243+
return nil, errors.New("invalid address")
244+
}
245+
235246
h.mu.Lock()
236247
defer h.mu.Unlock()
248+
237249
h.connectAddress = address
238-
return h
250+
return h, nil
239251
}
240252

241253
func (h *HostInfo) BroadcastAddress() net.IP {
@@ -469,6 +481,10 @@ func checkSystemSchema(control *controlConn) (bool, error) {
469481
return true, nil
470482
}
471483

484+
func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]interface{}) (*HostInfo, error) {
485+
return s.hostInfoFromMap(row, &HostInfo{connectAddress: addr, port: port})
486+
}
487+
472488
// Given a map that represents a row from either system.local or system.peers
473489
// return as much information as we can in *HostInfo
474490
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
@@ -584,6 +600,9 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
584600
}
585601

586602
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
603+
if !validIpAddr(ip) {
604+
return nil, fmt.Errorf("invalid host address (before translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), host.port, ip.String(), port)
605+
}
587606
host.connectAddress = ip
588607
host.port = port
589608

@@ -601,7 +620,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor
601620
return nil, errors.New("query returned 0 rows")
602621
}
603622

604-
host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort})
623+
host, err := s.newHostInfoFromMap(connectAddress, defaultPort, rows[0])
605624
if err != nil {
606625
return nil, err
607626
}
@@ -652,7 +671,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er
652671

653672
for _, row := range rows {
654673
// extract all available info about the peer
655-
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
674+
host, err := r.session.newHostInfoFromMap(nil, r.session.cfg.Port, row)
656675
if err != nil {
657676
return nil, err
658677
} else if !isValidPeer(host) {

0 commit comments

Comments
 (0)