Skip to content

Commit ffe9660

Browse files
committed
ConnectAddress() was refactored
1 parent 37030fb commit ffe9660

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,11 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
16971697
}
16981698

16991699
for _, row := range rows {
1700-
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
1700+
h, err := newHostInfo(c.host.connectAddress, c.session.cfg.Port)
1701+
if err != nil {
1702+
goto cont
1703+
}
1704+
host, err := c.session.hostInfoFromMap(row, h)
17011705
if err != nil {
17021706
goto cont
17031707
}

control.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
146146

147147
// Check if host is a literal IP address
148148
if ip := net.ParseIP(host); ip != nil {
149-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
149+
h, err := newHostInfo(ip, port)
150+
if err != nil {
151+
return nil, err
152+
}
153+
hosts = append(hosts, h)
150154
return hosts, nil
151155
}
152156

@@ -172,7 +176,12 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
172176
}
173177

174178
for _, ip := range ips {
175-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
179+
h, err := newHostInfo(ip, port)
180+
if err != nil {
181+
return nil, err
182+
}
183+
184+
hosts = append(hosts, h)
176185
}
177186

178187
return hosts, nil

host_source.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ type HostInfo struct {
159159
tokens []string
160160
}
161161

162+
func newHostInfo(addr net.IP, port int) (*HostInfo, error) {
163+
host := &HostInfo{}
164+
host.hostname = addr.String()
165+
host.port = port
166+
if !validIpAddr(addr) {
167+
return nil, errors.New("invalid host address")
168+
}
169+
170+
host.connectAddress = addr
171+
return host, nil
172+
}
173+
162174
func (h *HostInfo) Equal(host *HostInfo) bool {
163175
if h == host {
164176
// prevent rlock reentry
@@ -224,18 +236,20 @@ func (h *HostInfo) ConnectAddress() net.IP {
224236
h.mu.RLock()
225237
defer h.mu.RUnlock()
226238

227-
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
228-
return addr
229-
}
230-
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
239+
addr, _ := h.connectAddressLocked()
240+
return addr
231241
}
232242

233-
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
234-
// TODO(zariel): should this not be exported?
243+
func (h *HostInfo) SetConnectAddress(address net.IP) (*HostInfo, error) {
244+
if !validIpAddr(address) {
245+
return nil, errors.New("invalid address")
246+
}
247+
235248
h.mu.Lock()
236249
defer h.mu.Unlock()
250+
237251
h.connectAddress = address
238-
return h
252+
return h, nil
239253
}
240254

241255
func (h *HostInfo) BroadcastAddress() net.IP {
@@ -584,6 +598,9 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
584598
}
585599

586600
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
601+
if !validIpAddr(ip) {
602+
return nil, errors.New("invalid connect address")
603+
}
587604
host.connectAddress = ip
588605
host.port = port
589606

0 commit comments

Comments
 (0)