Skip to content

Commit 62762b1

Browse files
committed
conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent c69481f commit 62762b1

16 files changed

+161
-151
lines changed

Diff for: conn/conn_linux.go renamed to conn/bind_linux.go

+56-52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
// +build !android
2-
31
/* SPDX-License-Identifier: MIT
42
*
53
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
@@ -18,55 +16,59 @@ import (
1816
"golang.org/x/sys/unix"
1917
)
2018

21-
type IPv4Source struct {
19+
type ipv4Source struct {
2220
Src [4]byte
2321
Ifindex int32
2422
}
2523

26-
type IPv6Source struct {
24+
type ipv6Source struct {
2725
src [16]byte
28-
//ifindex belongs in dst.ZoneId
26+
// ifindex belongs in dst.ZoneId
2927
}
3028

31-
type NativeEndpoint struct {
29+
type LinuxSocketEndpoint struct {
3230
sync.Mutex
3331
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
34-
src [unsafe.Sizeof(IPv6Source{})]byte
32+
src [unsafe.Sizeof(ipv6Source{})]byte
3533
isV6 bool
3634
}
3735

38-
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
39-
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
40-
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
36+
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
37+
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
38+
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
4139

42-
func (endpoint *NativeEndpoint) src4() *IPv4Source {
43-
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
40+
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
41+
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
4442
}
4543

46-
func (endpoint *NativeEndpoint) src6() *IPv6Source {
47-
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
44+
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
45+
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
4846
}
4947

50-
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
48+
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
5149
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
5250
}
5351

54-
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
52+
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
5553
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
5654
}
5755

58-
type nativeBind struct {
56+
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
57+
type LinuxSocketBind struct {
5958
sock4 int
6059
sock6 int
6160
lastMark uint32
6261
closing sync.RWMutex
6362
}
6463

65-
var _ Endpoint = (*NativeEndpoint)(nil)
66-
var _ Bind = (*nativeBind)(nil)
64+
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
65+
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
66+
67+
var _ Endpoint = (*LinuxSocketEndpoint)(nil)
68+
var _ Bind = (*LinuxSocketBind)(nil)
6769

68-
func CreateEndpoint(s string) (Endpoint, error) {
69-
var end NativeEndpoint
70+
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
71+
var end LinuxSocketEndpoint
7072
addr, err := parseEndpoint(s)
7173
if err != nil {
7274
return nil, err
@@ -97,14 +99,18 @@ func CreateEndpoint(s string) (Endpoint, error) {
9799
return &end, nil
98100
}
99101

100-
return nil, errors.New("Invalid IP address")
102+
return nil, errors.New("invalid IP address")
101103
}
102104

103-
func createBind(port uint16) (Bind, uint16, error) {
105+
func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
104106
var err error
105-
var bind nativeBind
106107
var newPort uint16
107108
var tries int
109+
110+
if bind.sock4 != -1 || bind.sock6 != -1 {
111+
return 0, ErrBindAlreadyOpen
112+
}
113+
108114
originalPort := port
109115

110116
again:
@@ -113,7 +119,7 @@ again:
113119
bind.sock6, newPort, err = create6(port)
114120
if err != nil {
115121
if err != syscall.EAFNOSUPPORT {
116-
return nil, 0, err
122+
return 0, err
117123
}
118124
} else {
119125
port = newPort
@@ -129,24 +135,19 @@ again:
129135
}
130136
if err != syscall.EAFNOSUPPORT {
131137
unix.Close(bind.sock6)
132-
return nil, 0, err
138+
return 0, err
133139
}
134140
} else {
135141
port = newPort
136142
}
137143

138144
if bind.sock4 == -1 && bind.sock6 == -1 {
139-
return nil, 0, errors.New("ipv4 and ipv6 not supported")
145+
return 0, syscall.EAFNOSUPPORT
140146
}
141-
142-
return &bind, port, nil
143-
}
144-
145-
func (bind *nativeBind) LastMark() uint32 {
146-
return bind.lastMark
147+
return port, nil
147148
}
148149

149-
func (bind *nativeBind) SetMark(value uint32) error {
150+
func (bind *LinuxSocketBind) SetMark(value uint32) error {
150151
bind.closing.RLock()
151152
defer bind.closing.RUnlock()
152153

@@ -180,7 +181,7 @@ func (bind *nativeBind) SetMark(value uint32) error {
180181
return nil
181182
}
182183

183-
func (bind *nativeBind) Close() error {
184+
func (bind *LinuxSocketBind) Close() error {
184185
var err1, err2 error
185186
bind.closing.RLock()
186187
if bind.sock6 != -1 {
@@ -207,11 +208,11 @@ func (bind *nativeBind) Close() error {
207208
return err2
208209
}
209210

210-
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
211+
func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
211212
bind.closing.RLock()
212213
defer bind.closing.RUnlock()
213214

214-
var end NativeEndpoint
215+
var end LinuxSocketEndpoint
215216
if bind.sock6 == -1 {
216217
return 0, nil, net.ErrClosed
217218
}
@@ -223,11 +224,11 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
223224
return n, &end, err
224225
}
225226

226-
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
227+
func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
227228
bind.closing.RLock()
228229
defer bind.closing.RUnlock()
229230

230-
var end NativeEndpoint
231+
var end LinuxSocketEndpoint
231232
if bind.sock4 == -1 {
232233
return 0, nil, net.ErrClosed
233234
}
@@ -239,11 +240,14 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
239240
return n, &end, err
240241
}
241242

242-
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
243+
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
243244
bind.closing.RLock()
244245
defer bind.closing.RUnlock()
245246

246-
nend := end.(*NativeEndpoint)
247+
nend, ok := end.(*LinuxSocketEndpoint)
248+
if !ok {
249+
return ErrWrongEndpointType
250+
}
247251
if !nend.isV6 {
248252
if bind.sock4 == -1 {
249253
return net.ErrClosed
@@ -257,7 +261,7 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
257261
}
258262
}
259263

260-
func (end *NativeEndpoint) SrcIP() net.IP {
264+
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
261265
if !end.isV6 {
262266
return net.IPv4(
263267
end.src4().Src[0],
@@ -270,7 +274,7 @@ func (end *NativeEndpoint) SrcIP() net.IP {
270274
}
271275
}
272276

273-
func (end *NativeEndpoint) DstIP() net.IP {
277+
func (end *LinuxSocketEndpoint) DstIP() net.IP {
274278
if !end.isV6 {
275279
return net.IPv4(
276280
end.dst4().Addr[0],
@@ -283,19 +287,19 @@ func (end *NativeEndpoint) DstIP() net.IP {
283287
}
284288
}
285289

286-
func (end *NativeEndpoint) DstToBytes() []byte {
290+
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
287291
if !end.isV6 {
288292
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
289293
} else {
290294
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
291295
}
292296
}
293297

294-
func (end *NativeEndpoint) SrcToString() string {
298+
func (end *LinuxSocketEndpoint) SrcToString() string {
295299
return end.SrcIP().String()
296300
}
297301

298-
func (end *NativeEndpoint) DstToString() string {
302+
func (end *LinuxSocketEndpoint) DstToString() string {
299303
var udpAddr net.UDPAddr
300304
udpAddr.IP = end.DstIP()
301305
if !end.isV6 {
@@ -306,13 +310,13 @@ func (end *NativeEndpoint) DstToString() string {
306310
return udpAddr.String()
307311
}
308312

309-
func (end *NativeEndpoint) ClearDst() {
313+
func (end *LinuxSocketEndpoint) ClearDst() {
310314
for i := range end.dst {
311315
end.dst[i] = 0
312316
}
313317
}
314318

315-
func (end *NativeEndpoint) ClearSrc() {
319+
func (end *LinuxSocketEndpoint) ClearSrc() {
316320
for i := range end.src {
317321
end.src[i] = 0
318322
}
@@ -427,7 +431,7 @@ func create6(port uint16) (int, uint16, error) {
427431
return fd, uint16(addr.Port), err
428432
}
429433

430-
func send4(sock int, end *NativeEndpoint, buff []byte) error {
434+
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
431435

432436
// construct message header
433437

@@ -467,7 +471,7 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
467471
return err
468472
}
469473

470-
func send6(sock int, end *NativeEndpoint, buff []byte) error {
474+
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
471475

472476
// construct message header
473477

@@ -511,7 +515,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
511515
return err
512516
}
513517

514-
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
518+
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
515519

516520
// construct message header
517521

@@ -543,7 +547,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
543547
return size, nil
544548
}
545549

546-
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
550+
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
547551

548552
// construct message header
549553

0 commit comments

Comments
 (0)