Skip to content

Commit 3c7ec55

Browse files
bradfitzcrawshaw
authored andcommitted
wgcfg: clean up IP type/method signatures
Signed-off-by: Brad Fitzpatrick <[email protected]>
1 parent f49cb12 commit 3c7ec55

File tree

3 files changed

+59
-69
lines changed

3 files changed

+59
-69
lines changed

Diff for: wgcfg/ip.go

+30-28
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,47 @@ type IP struct {
1616

1717
func (ip IP) String() string { return net.IP(ip.Addr[:]).String() }
1818

19-
func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) }
20-
func (ip *IP) Is6() bool { return !ip.Is4() }
21-
func (ip *IP) Is4() bool {
19+
// IP converts ip into a standard library net.IP.
20+
func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) }
21+
22+
// Is6 reports whether ip is an IPv6 address.
23+
func (ip IP) Is6() bool { return !ip.Is4() }
24+
25+
// Is4 reports whether ip is an IPv4 address.
26+
func (ip IP) Is4() bool {
2227
return ip.Addr[0] == 0 && ip.Addr[1] == 0 &&
2328
ip.Addr[2] == 0 && ip.Addr[3] == 0 &&
2429
ip.Addr[4] == 0 && ip.Addr[5] == 0 &&
2530
ip.Addr[6] == 0 && ip.Addr[7] == 0 &&
2631
ip.Addr[8] == 0 && ip.Addr[9] == 0 &&
2732
ip.Addr[10] == 0xff && ip.Addr[11] == 0xff
2833
}
29-
func (ip *IP) To4() []byte {
34+
35+
// To4 returns either a 4 byte slice for an IPv4 address, or nil if
36+
// it's not IPv4.
37+
func (ip IP) To4() []byte {
3038
if ip.Is4() {
3139
return ip.Addr[12:16]
3240
} else {
3341
return nil
3442
}
3543
}
36-
func (ip *IP) Equal(x *IP) bool {
37-
if ip == nil || x == nil {
38-
return false
39-
}
40-
// TODO: this isn't hard, write a more efficient implementation.
41-
return ip.IP().Equal(x.IP())
44+
45+
// Equal reports whether ip == x.
46+
func (ip IP) Equal(x IP) bool {
47+
return ip == x
4248
}
4349

4450
func (ip IP) MarshalText() ([]byte, error) {
4551
return []byte(ip.String()), nil
4652
}
4753

4854
func (ip *IP) UnmarshalText(text []byte) error {
49-
parsedIP := ParseIP(string(text))
50-
if parsedIP == nil {
51-
return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text))
55+
parsedIP, ok := ParseIP(string(text))
56+
if !ok {
57+
return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text)
5258
}
53-
*ip = *parsedIP
59+
*ip = parsedIP
5460
return nil
5561
}
5662

@@ -66,15 +72,14 @@ func IPv4(b0, b1, b2, b3 byte) (ip IP) {
6672
// ParseIP parses the string representation of an address into an IP.
6773
//
6874
// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0".
69-
// If the string is not a valid IP address, ParseIP returns nil.
70-
func ParseIP(s string) *IP {
75+
// The ok result reports whether s was a valid IP and ip is valid.
76+
func ParseIP(s string) (ip IP, ok bool) {
7177
netIP := net.ParseIP(s)
7278
if netIP == nil {
73-
return nil
79+
return IP{}, false
7480
}
75-
ip := new(IP)
7681
copy(ip.Addr[:], netIP.To16())
77-
return ip
82+
return ip, true
7883
}
7984

8085
// CIDR is a compact IP address and subnet mask.
@@ -85,12 +90,12 @@ type CIDR struct {
8590

8691
// ParseCIDR parses CIDR notation into a CIDR type.
8792
// Typical CIDR strings look like "192.168.1.0/24".
88-
func ParseCIDR(s string) (cidr *CIDR, err error) {
93+
func ParseCIDR(s string) (CIDR, error) {
8994
netIP, netAddr, err := net.ParseCIDR(s)
9095
if err != nil {
91-
return nil, err
96+
return CIDR{}, err
9297
}
93-
cidr = new(CIDR)
98+
var cidr CIDR
9499
copy(cidr.IP.Addr[:], netIP.To16())
95100
ones, _ := netAddr.Mask.Size()
96101
cidr.Mask = uint8(ones)
@@ -100,18 +105,15 @@ func ParseCIDR(s string) (cidr *CIDR, err error) {
100105

101106
func (r CIDR) String() string { return r.IPNet().String() }
102107

103-
func (r *CIDR) IPNet() *net.IPNet {
108+
func (r CIDR) IPNet() *net.IPNet {
104109
bits := 128
105110
if r.IP.Is4() {
106111
bits = 32
107112
}
108113
return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)}
109114
}
110115

111-
func (r *CIDR) Contains(ip *IP) bool {
112-
if r == nil || ip == nil {
113-
return false
114-
}
116+
func (r CIDR) Contains(ip IP) bool {
115117
c := int8(r.Mask)
116118
i := 0
117119
if r.IP.Is4() {
@@ -145,6 +147,6 @@ func (r *CIDR) UnmarshalText(text []byte) error {
145147
if err != nil {
146148
return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err)
147149
}
148-
*r = *cidr
150+
*r = cidr
149151
return nil
150152
}

Diff for: wgcfg/ip_test.go

+23-35
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,24 @@ import (
1111
"github.com/tailscale/wireguard-go/wgcfg"
1212
)
1313

14+
func parseIP(t testing.TB, ipStr string) wgcfg.IP {
15+
t.Helper()
16+
ip, ok := wgcfg.ParseIP(ipStr)
17+
if !ok {
18+
t.Fatalf("failed to parse IP: %q", ipStr)
19+
}
20+
return ip
21+
}
22+
1423
func TestCIDRContains(t *testing.T) {
1524
t.Run("home router test", func(t *testing.T) {
1625
r, err := wgcfg.ParseCIDR("192.168.0.0/24")
1726
if err != nil {
1827
t.Fatal(err)
1928
}
20-
ip := wgcfg.ParseIP("192.168.0.1")
21-
if ip == nil {
22-
t.Fatalf("address failed to parse")
23-
}
29+
ip := parseIP(t, "192.168.0.1")
2430
if !r.Contains(ip) {
25-
t.Fatalf("'%s' should contain '%s'", r, ip)
31+
t.Fatalf("%q should contain %q", r, ip)
2632
}
2733
})
2834

@@ -31,12 +37,9 @@ func TestCIDRContains(t *testing.T) {
3137
if err != nil {
3238
t.Fatal(err)
3339
}
34-
ip := wgcfg.ParseIP("192.168.0.4")
35-
if ip == nil {
36-
t.Fatalf("address failed to parse")
37-
}
40+
ip := parseIP(t, "192.168.0.4")
3841
if r.Contains(ip) {
39-
t.Fatalf("'%s' should not contain '%s'", r, ip)
42+
t.Fatalf("%q should not contain %q", r, ip)
4043
}
4144
})
4245

@@ -45,12 +48,9 @@ func TestCIDRContains(t *testing.T) {
4548
if err != nil {
4649
t.Fatal(err)
4750
}
48-
ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334")
49-
if ip == nil {
50-
t.Fatalf("address failed to parse")
51-
}
51+
ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334")
5252
if r.Contains(ip) {
53-
t.Fatalf("'%s' should not contain '%s'", r, ip)
53+
t.Fatalf("%q should not contain %q", r, ip)
5454
}
5555
})
5656

@@ -59,12 +59,9 @@ func TestCIDRContains(t *testing.T) {
5959
if err != nil {
6060
t.Fatal(err)
6161
}
62-
ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
63-
if ip == nil {
64-
t.Fatalf("ParseIP returned nil pointer")
65-
}
62+
ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001")
6663
if !r.Contains(ip) {
67-
t.Fatalf("'%s' should not contain '%s'", r, ip)
64+
t.Fatalf("%q should not contain %q", r, ip)
6865
}
6966
})
7067

@@ -73,12 +70,9 @@ func TestCIDRContains(t *testing.T) {
7370
if err != nil {
7471
t.Fatal(err)
7572
}
76-
ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4")
77-
if ip == nil {
78-
t.Fatalf("ParseIP returned nil pointer")
79-
}
73+
ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4")
8074
if r.Contains(ip) {
81-
t.Fatalf("'%s' should not contain '%s'", r, ip)
75+
t.Fatalf("%q should not contain %q", r, ip)
8276
}
8377
})
8478
}
@@ -89,12 +83,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) {
8983
if err != nil {
9084
b.Fatal(err)
9185
}
92-
ip := wgcfg.ParseIP("1.2.3.4")
93-
if ip == nil {
94-
b.Fatalf("ParseIP returned nil pointer")
95-
}
96-
86+
ip := parseIP(b, "1.2.3.4")
9787
b.ResetTimer()
88+
9889
for i := 0; i < b.N; i++ {
9990
r.Contains(ip)
10091
}
@@ -105,12 +96,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) {
10596
if err != nil {
10697
b.Fatal(err)
10798
}
108-
ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
109-
if ip == nil {
110-
b.Fatalf("ParseIP returned nil pointer")
111-
}
112-
99+
ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001")
113100
b.ResetTimer()
101+
114102
for i := 0; i < b.N; i++ {
115103
r.Contains(ip)
116104
}

Diff for: wgcfg/parser.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -219,19 +219,19 @@ func FromWgQuick(s string, name string) (*Config, error) {
219219
if err != nil {
220220
return nil, err
221221
}
222-
conf.Addresses = append(conf.Addresses, *a)
222+
conf.Addresses = append(conf.Addresses, a)
223223
}
224224
case "dns":
225225
addresses, err := splitList(val)
226226
if err != nil {
227227
return nil, err
228228
}
229229
for _, address := range addresses {
230-
a := ParseIP(address)
231-
if a == nil {
230+
a, ok := ParseIP(address)
231+
if !ok {
232232
return nil, &ParseError{"Invalid IP address", address}
233233
}
234-
conf.DNS = append(conf.DNS, *a)
234+
conf.DNS = append(conf.DNS, a)
235235
}
236236
default:
237237
return nil, &ParseError{"Invalid key for [Interface] section", key}
@@ -260,7 +260,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
260260
if err != nil {
261261
return nil, err
262262
}
263-
peer.AllowedIPs = append(peer.AllowedIPs, *a)
263+
peer.AllowedIPs = append(peer.AllowedIPs, a)
264264
}
265265
case "persistentkeepalive":
266266
p, err := parsePersistentKeepalive(val)
@@ -373,7 +373,7 @@ func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) {
373373
if err != nil {
374374
return nil, err
375375
}
376-
peer.AllowedIPs = append(peer.AllowedIPs, *a)
376+
peer.AllowedIPs = append(peer.AllowedIPs, a)
377377
case "persistent_keepalive_interval":
378378
p, err := parsePersistentKeepalive(val)
379379
if err != nil {

0 commit comments

Comments
 (0)