Skip to content

Commit f5114b0

Browse files
authored
fix: fix how validateAddrPort(...) handles ipv6 literals (#66)
1 parent b11d1c6 commit f5114b0

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

server.go

+23-20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"crypto/tls"
99
"fmt"
1010
"net"
11+
"net/netip"
1112
"strings"
1213
"sync"
1314
"time"
@@ -75,35 +76,37 @@ func last(s string, b byte) int {
7576
return i
7677
}
7778

78-
// validateAddr will not only validate the addr, but if it's an ipv6 literal without
79-
// proper brackets, it will add them.
80-
func validateAddr(addr string) (string, error) {
79+
// validateAddrPort will not only validate the address+port, but if it's an ipv6
80+
// literal without proper brackets, it will add them.
81+
func validateAddrPort(addrPort string) (string, error) {
8182
const op = "gldap.parseAddr"
8283

83-
lastColon := last(addr, ':')
84+
lastColon := last(addrPort, ':')
8485
if lastColon < 0 {
85-
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
86+
return "", fmt.Errorf("%s: missing port in addr \"%s\": %w", op, addrPort, ErrInvalidParameter)
8687
}
87-
rawHost := addr[0:lastColon]
88-
rawPort := addr[lastColon+1:]
88+
rawHost := addrPort[0:lastColon]
89+
rawPort := addrPort[lastColon+1:]
8990
switch {
9091
case len(rawPort) == 0:
91-
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
92+
return "", fmt.Errorf("%s: missing port in addr \"%s\": %w", op, addrPort, ErrInvalidParameter)
9293
case len(rawHost) == 0:
9394
return fmt.Sprintf(":%s", rawPort), nil
94-
case addr[0] == '[' && addr[len(addr)-1] == ']':
95-
return "", fmt.Errorf("%s: missing port in ipv6 addr : %s : %w", op, addr, ErrInvalidParameter)
95+
case addrPort[0] == '[' && addrPort[len(addrPort)-1] == ']':
96+
return "", fmt.Errorf("%s: missing port in ipv6 addr : \"%s\": %w", op, addrPort, ErrInvalidParameter)
9697
}
9798
// ipv6 literal with proper brackets
9899
if rawHost[0] == '[' {
99100
// Expect the first ']' just before the last ':'.
100101
end := strings.IndexByte(rawHost, ']')
101102
if end < 0 {
102-
return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter)
103+
return "", fmt.Errorf("%s: missing ']' in ipv6 address \"%s\": %w", op, addrPort, ErrInvalidParameter)
103104
}
105+
// Note: netip.ParseAddr requires ipv6 addresses without brackets []
104106
trimmedIp := strings.Trim(rawHost, "[]")
105-
if net.ParseIP(trimmedIp) == nil {
106-
return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter)
107+
if _, err := netip.ParseAddr(trimmedIp); err != nil {
108+
// if net.ParseIP(trimmedIp) == nil {
109+
return "", fmt.Errorf("%s: invalid ipv6 address \"%s\": %w", op, rawHost, err)
107110
}
108111
// ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good
109112
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
@@ -123,16 +126,16 @@ func validateAddr(addr string) (string, error) {
123126

124127
lastColon = last(rawHost, ':')
125128
if lastColon >= 0 {
126-
// ipv6 literal without proper brackets
127-
ipv6Literal := fmt.Sprintf("[%s]", rawHost)
128-
if net.ParseIP(ipv6Literal) == nil {
129-
return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter)
129+
// ipv6 literal without proper brackets. Note: netip.ParseAddr requires
130+
// ipv6 addresses without brackets []
131+
if _, err := netip.ParseAddr(rawHost); err != nil {
132+
return "", fmt.Errorf("%s: invalid ipv6 address + port \"%s\": %w", op, addrPort, err)
130133
}
131-
return fmt.Sprintf("[%s]:%s", ipv6Literal, rawPort), nil
134+
return fmt.Sprintf("[%s]:%s", rawHost, rawPort), nil
132135
}
133136
// ipv4
134137
if net.ParseIP(rawHost) == nil {
135-
return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter)
138+
return "", fmt.Errorf("%s: invalid IP address \"%s\": %w", op, rawHost, ErrInvalidParameter)
136139
}
137140
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
138141
}
@@ -145,7 +148,7 @@ func (s *Server) Run(addr string, opt ...Option) error {
145148
opts := getConfigOpts(opt...)
146149

147150
var err error
148-
addr, err = validateAddr(addr)
151+
addr, err = validateAddrPort(addr)
149152
if err != nil {
150153
return fmt.Errorf("%s: %w", op, err)
151154
}

server_internal_test.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (*mockListener) Close() error {
127127
return errors.New("mockListener.Close error")
128128
}
129129

130-
func TestValidateAddr(t *testing.T) {
130+
func Test_validateAddrPort(t *testing.T) {
131131
tests := []struct {
132132
name string
133133
addr string
@@ -180,46 +180,44 @@ func TestValidateAddr(t *testing.T) {
180180
{
181181
name: "err-missing-port-ipv6",
182182
addr: "[::1]",
183-
wantErrContains: "missing port in ipv6 addr : [::1]",
183+
wantErrContains: "missing port in ipv6 addr : \"[::1]\"",
184184
wantErrIs: ErrInvalidParameter,
185185
},
186186
{
187187
name: "err-invalid-IPv4-address",
188188
addr: "0.0",
189-
wantErrContains: "missing port in addr 0.0",
189+
wantErrContains: "missing port in addr \"0.0\"",
190190
wantErrIs: ErrInvalidParameter,
191191
},
192192
{
193193
name: "err-invalid-IPv6-address-missing-start-bracket",
194194
addr: "::1]",
195-
wantErrContains: "invalid ipv6 address + port ::1]",
196-
wantErrIs: ErrInvalidParameter,
195+
wantErrContains: "invalid ipv6 address + port \"::1]\": ParseAddr(\":\"): each colon-separated field must have at least one digit (at \":\")",
197196
},
198197
{
199198
name: "err-invalid-IPv6-address-missing-final-bracket",
200199
addr: "[::1",
201-
wantErrContains: "missing ']' in ipv6 address [::1",
200+
wantErrContains: "missing ']' in ipv6 address \"[::1\"",
202201
wantErrIs: ErrInvalidParameter,
203202
},
204203
{
205204
name: "err-invalid-IPv6",
206205
addr: "2001:db8:3333:4444:5555:6666:7777:389",
207-
wantErrContains: "invalid ipv6 address + port 2001:db8:3333:4444:5555:6666:7777:389",
208-
wantErrIs: ErrInvalidParameter,
206+
wantErrContains: "invalid ipv6 address + port \"2001:db8:3333:4444:5555:6666:7777:389\": ParseAddr(\"2001:db8:3333:4444:5555:6666:7777\"): address string too short",
209207
},
210208
{
211209
name: "err-missing-port",
212210
addr: "invalid",
213211
expected: "",
214-
wantErrContains: "missing port in addr invalid",
212+
wantErrContains: "missing port in addr \"invalid\"",
215213
wantErrIs: ErrInvalidParameter,
216214
},
217215
}
218216

219217
for _, tc := range tests {
220218
tc := tc
221219
t.Run(tc.name, func(t *testing.T) {
222-
result, err := validateAddr(tc.addr)
220+
result, err := validateAddrPort(tc.addr)
223221
if tc.wantErrContains != "" {
224222
require.Error(t, err)
225223
assert.Empty(t, result)

0 commit comments

Comments
 (0)