Skip to content

Commit 55525fe

Browse files
committed
ConnectAddress() was refactored
1 parent 37030fb commit 55525fe

File tree

3 files changed

+42
-23
lines changed

3 files changed

+42
-23
lines changed

conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,11 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
16971697
}
16981698

16991699
for _, row := range rows {
1700-
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
1700+
h, err := newHostInfo(c.host.ConnectAddress(), c.session.cfg.Port)
1701+
if err != nil {
1702+
goto cont
1703+
}
1704+
host, err := c.session.hostInfoFromMap(row, h)
17011705
if err != nil {
17021706
goto cont
17031707
}

control.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
146146

147147
// Check if host is a literal IP address
148148
if ip := net.ParseIP(host); ip != nil {
149-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
149+
h, err := newHostInfo(ip, port)
150+
if err != nil {
151+
return nil, err
152+
}
153+
hosts = append(hosts, h)
150154
return hosts, nil
151155
}
152156

@@ -172,7 +176,12 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
172176
}
173177

174178
for _, ip := range ips {
175-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
179+
h, err := newHostInfo(ip, port)
180+
if err != nil {
181+
return nil, err
182+
}
183+
184+
hosts = append(hosts, h)
176185
}
177186

178187
return hosts, nil

host_source.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ type HostInfo struct {
159159
tokens []string
160160
}
161161

162+
func newHostInfo(addr net.IP, port int) (*HostInfo, error) {
163+
if !validIpAddr(addr) {
164+
return nil, errors.New("invalid host address")
165+
}
166+
host := &HostInfo{}
167+
host.hostname = addr.String()
168+
host.port = port
169+
170+
host.connectAddress = addr
171+
return host, nil
172+
}
173+
162174
func (h *HostInfo) Equal(host *HostInfo) bool {
163175
if h == host {
164176
// prevent rlock reentry
@@ -191,14 +203,12 @@ func (h *HostInfo) connectAddressLocked() (net.IP, string) {
191203
} else if validIpAddr(h.rpcAddress) {
192204
return h.rpcAddress, "rpc_adress"
193205
} else if validIpAddr(h.preferredIP) {
194-
// where does perferred_ip get set?
195206
return h.preferredIP, "preferred_ip"
196207
} else if validIpAddr(h.broadcastAddress) {
197208
return h.broadcastAddress, "broadcast_address"
198-
} else if validIpAddr(h.peer) {
199-
return h.peer, "peer"
200209
}
201-
return net.IPv4zero, "invalid"
210+
return h.peer, "peer"
211+
202212
}
203213

204214
// nodeToNodeAddress returns address broadcasted between node to nodes.
@@ -218,24 +228,13 @@ func (h *HostInfo) nodeToNodeAddress() net.IP {
218228
}
219229

220230
// Returns the address that should be used to connect to the host.
221-
// If you wish to override this, use an AddressTranslator or
222-
// use a HostFilter to SetConnectAddress()
231+
// If you wish to override this, use an AddressTranslator
223232
func (h *HostInfo) ConnectAddress() net.IP {
224233
h.mu.RLock()
225234
defer h.mu.RUnlock()
226235

227-
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
228-
return addr
229-
}
230-
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
231-
}
232-
233-
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
234-
// TODO(zariel): should this not be exported?
235-
h.mu.Lock()
236-
defer h.mu.Unlock()
237-
h.connectAddress = address
238-
return h
236+
addr, _ := h.connectAddressLocked()
237+
return addr
239238
}
240239

241240
func (h *HostInfo) BroadcastAddress() net.IP {
@@ -469,6 +468,10 @@ func checkSystemSchema(control *controlConn) (bool, error) {
469468
return true, nil
470469
}
471470

471+
func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]interface{}) (*HostInfo, error) {
472+
return s.hostInfoFromMap(row, &HostInfo{connectAddress: addr, port: port})
473+
}
474+
472475
// Given a map that represents a row from either system.local or system.peers
473476
// return as much information as we can in *HostInfo
474477
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
@@ -584,6 +587,9 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
584587
}
585588

586589
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
590+
if !validIpAddr(ip) {
591+
return nil, fmt.Errorf("invalid host address (before translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), host.port, ip.String(), port)
592+
}
587593
host.connectAddress = ip
588594
host.port = port
589595

@@ -601,7 +607,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor
601607
return nil, errors.New("query returned 0 rows")
602608
}
603609

604-
host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort})
610+
host, err := s.newHostInfoFromMap(connectAddress, defaultPort, rows[0])
605611
if err != nil {
606612
return nil, err
607613
}
@@ -652,7 +658,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er
652658

653659
for _, row := range rows {
654660
// extract all available info about the peer
655-
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
661+
host, err := r.session.newHostInfoFromMap(nil, r.session.cfg.Port, row)
656662
if err != nil {
657663
return nil, err
658664
} else if !isValidPeer(host) {

0 commit comments

Comments
 (0)