1
- // +build !android
2
-
3
1
/* SPDX-License-Identifier: MIT
4
2
*
5
3
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
@@ -18,55 +16,59 @@ import (
18
16
"golang.org/x/sys/unix"
19
17
)
20
18
21
- type IPv4Source struct {
19
+ type ipv4Source struct {
22
20
Src [4 ]byte
23
21
Ifindex int32
24
22
}
25
23
26
- type IPv6Source struct {
24
+ type ipv6Source struct {
27
25
src [16 ]byte
28
- //ifindex belongs in dst.ZoneId
26
+ // ifindex belongs in dst.ZoneId
29
27
}
30
28
31
- type NativeEndpoint struct {
29
+ type LinuxSocketEndpoint struct {
32
30
sync.Mutex
33
31
dst [unsafe .Sizeof (unix.SockaddrInet6 {})]byte
34
- src [unsafe .Sizeof (IPv6Source {})]byte
32
+ src [unsafe .Sizeof (ipv6Source {})]byte
35
33
isV6 bool
36
34
}
37
35
38
- func (endpoint * NativeEndpoint ) Src4 () * IPv4Source { return endpoint .src4 () }
39
- func (endpoint * NativeEndpoint ) Dst4 () * unix.SockaddrInet4 { return endpoint .dst4 () }
40
- func (endpoint * NativeEndpoint ) IsV6 () bool { return endpoint .isV6 }
36
+ func (endpoint * LinuxSocketEndpoint ) Src4 () * ipv4Source { return endpoint .src4 () }
37
+ func (endpoint * LinuxSocketEndpoint ) Dst4 () * unix.SockaddrInet4 { return endpoint .dst4 () }
38
+ func (endpoint * LinuxSocketEndpoint ) IsV6 () bool { return endpoint .isV6 }
41
39
42
- func (endpoint * NativeEndpoint ) src4 () * IPv4Source {
43
- return (* IPv4Source )(unsafe .Pointer (& endpoint .src [0 ]))
40
+ func (endpoint * LinuxSocketEndpoint ) src4 () * ipv4Source {
41
+ return (* ipv4Source )(unsafe .Pointer (& endpoint .src [0 ]))
44
42
}
45
43
46
- func (endpoint * NativeEndpoint ) src6 () * IPv6Source {
47
- return (* IPv6Source )(unsafe .Pointer (& endpoint .src [0 ]))
44
+ func (endpoint * LinuxSocketEndpoint ) src6 () * ipv6Source {
45
+ return (* ipv6Source )(unsafe .Pointer (& endpoint .src [0 ]))
48
46
}
49
47
50
- func (endpoint * NativeEndpoint ) dst4 () * unix.SockaddrInet4 {
48
+ func (endpoint * LinuxSocketEndpoint ) dst4 () * unix.SockaddrInet4 {
51
49
return (* unix .SockaddrInet4 )(unsafe .Pointer (& endpoint .dst [0 ]))
52
50
}
53
51
54
- func (endpoint * NativeEndpoint ) dst6 () * unix.SockaddrInet6 {
52
+ func (endpoint * LinuxSocketEndpoint ) dst6 () * unix.SockaddrInet6 {
55
53
return (* unix .SockaddrInet6 )(unsafe .Pointer (& endpoint .dst [0 ]))
56
54
}
57
55
58
- type nativeBind struct {
56
+ // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
57
+ type LinuxSocketBind struct {
59
58
sock4 int
60
59
sock6 int
61
60
lastMark uint32
62
61
closing sync.RWMutex
63
62
}
64
63
65
- var _ Endpoint = (* NativeEndpoint )(nil )
66
- var _ Bind = (* nativeBind )(nil )
64
+ func NewLinuxSocketBind () Bind { return & LinuxSocketBind {sock4 : - 1 , sock6 : - 1 } }
65
+ func NewDefaultBind () Bind { return NewLinuxSocketBind () }
66
+
67
+ var _ Endpoint = (* LinuxSocketEndpoint )(nil )
68
+ var _ Bind = (* LinuxSocketBind )(nil )
67
69
68
- func CreateEndpoint (s string ) (Endpoint , error ) {
69
- var end NativeEndpoint
70
+ func ( * LinuxSocketBind ) ParseEndpoint (s string ) (Endpoint , error ) {
71
+ var end LinuxSocketEndpoint
70
72
addr , err := parseEndpoint (s )
71
73
if err != nil {
72
74
return nil , err
@@ -97,14 +99,18 @@ func CreateEndpoint(s string) (Endpoint, error) {
97
99
return & end , nil
98
100
}
99
101
100
- return nil , errors .New ("Invalid IP address" )
102
+ return nil , errors .New ("invalid IP address" )
101
103
}
102
104
103
- func createBind ( port uint16 ) (Bind , uint16 , error ) {
105
+ func ( bind * LinuxSocketBind ) Open ( port uint16 ) (uint16 , error ) {
104
106
var err error
105
- var bind nativeBind
106
107
var newPort uint16
107
108
var tries int
109
+
110
+ if bind .sock4 != - 1 || bind .sock6 != - 1 {
111
+ return 0 , ErrBindAlreadyOpen
112
+ }
113
+
108
114
originalPort := port
109
115
110
116
again:
@@ -113,7 +119,7 @@ again:
113
119
bind .sock6 , newPort , err = create6 (port )
114
120
if err != nil {
115
121
if err != syscall .EAFNOSUPPORT {
116
- return nil , 0 , err
122
+ return 0 , err
117
123
}
118
124
} else {
119
125
port = newPort
@@ -129,24 +135,19 @@ again:
129
135
}
130
136
if err != syscall .EAFNOSUPPORT {
131
137
unix .Close (bind .sock6 )
132
- return nil , 0 , err
138
+ return 0 , err
133
139
}
134
140
} else {
135
141
port = newPort
136
142
}
137
143
138
144
if bind .sock4 == - 1 && bind .sock6 == - 1 {
139
- return nil , 0 , errors . New ( "ipv4 and ipv6 not supported" )
145
+ return 0 , syscall . EAFNOSUPPORT
140
146
}
141
-
142
- return & bind , port , nil
143
- }
144
-
145
- func (bind * nativeBind ) LastMark () uint32 {
146
- return bind .lastMark
147
+ return port , nil
147
148
}
148
149
149
- func (bind * nativeBind ) SetMark (value uint32 ) error {
150
+ func (bind * LinuxSocketBind ) SetMark (value uint32 ) error {
150
151
bind .closing .RLock ()
151
152
defer bind .closing .RUnlock ()
152
153
@@ -180,7 +181,7 @@ func (bind *nativeBind) SetMark(value uint32) error {
180
181
return nil
181
182
}
182
183
183
- func (bind * nativeBind ) Close () error {
184
+ func (bind * LinuxSocketBind ) Close () error {
184
185
var err1 , err2 error
185
186
bind .closing .RLock ()
186
187
if bind .sock6 != - 1 {
@@ -207,11 +208,11 @@ func (bind *nativeBind) Close() error {
207
208
return err2
208
209
}
209
210
210
- func (bind * nativeBind ) ReceiveIPv6 (buff []byte ) (int , Endpoint , error ) {
211
+ func (bind * LinuxSocketBind ) ReceiveIPv6 (buff []byte ) (int , Endpoint , error ) {
211
212
bind .closing .RLock ()
212
213
defer bind .closing .RUnlock ()
213
214
214
- var end NativeEndpoint
215
+ var end LinuxSocketEndpoint
215
216
if bind .sock6 == - 1 {
216
217
return 0 , nil , net .ErrClosed
217
218
}
@@ -223,11 +224,11 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
223
224
return n , & end , err
224
225
}
225
226
226
- func (bind * nativeBind ) ReceiveIPv4 (buff []byte ) (int , Endpoint , error ) {
227
+ func (bind * LinuxSocketBind ) ReceiveIPv4 (buff []byte ) (int , Endpoint , error ) {
227
228
bind .closing .RLock ()
228
229
defer bind .closing .RUnlock ()
229
230
230
- var end NativeEndpoint
231
+ var end LinuxSocketEndpoint
231
232
if bind .sock4 == - 1 {
232
233
return 0 , nil , net .ErrClosed
233
234
}
@@ -239,11 +240,14 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
239
240
return n , & end , err
240
241
}
241
242
242
- func (bind * nativeBind ) Send (buff []byte , end Endpoint ) error {
243
+ func (bind * LinuxSocketBind ) Send (buff []byte , end Endpoint ) error {
243
244
bind .closing .RLock ()
244
245
defer bind .closing .RUnlock ()
245
246
246
- nend := end .(* NativeEndpoint )
247
+ nend , ok := end .(* LinuxSocketEndpoint )
248
+ if ! ok {
249
+ return ErrWrongEndpointType
250
+ }
247
251
if ! nend .isV6 {
248
252
if bind .sock4 == - 1 {
249
253
return net .ErrClosed
@@ -257,7 +261,7 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
257
261
}
258
262
}
259
263
260
- func (end * NativeEndpoint ) SrcIP () net.IP {
264
+ func (end * LinuxSocketEndpoint ) SrcIP () net.IP {
261
265
if ! end .isV6 {
262
266
return net .IPv4 (
263
267
end .src4 ().Src [0 ],
@@ -270,7 +274,7 @@ func (end *NativeEndpoint) SrcIP() net.IP {
270
274
}
271
275
}
272
276
273
- func (end * NativeEndpoint ) DstIP () net.IP {
277
+ func (end * LinuxSocketEndpoint ) DstIP () net.IP {
274
278
if ! end .isV6 {
275
279
return net .IPv4 (
276
280
end .dst4 ().Addr [0 ],
@@ -283,19 +287,19 @@ func (end *NativeEndpoint) DstIP() net.IP {
283
287
}
284
288
}
285
289
286
- func (end * NativeEndpoint ) DstToBytes () []byte {
290
+ func (end * LinuxSocketEndpoint ) DstToBytes () []byte {
287
291
if ! end .isV6 {
288
292
return (* [unsafe .Offsetof (end .dst4 ().Addr ) + unsafe .Sizeof (end .dst4 ().Addr )]byte )(unsafe .Pointer (end .dst4 ()))[:]
289
293
} else {
290
294
return (* [unsafe .Offsetof (end .dst6 ().Addr ) + unsafe .Sizeof (end .dst6 ().Addr )]byte )(unsafe .Pointer (end .dst6 ()))[:]
291
295
}
292
296
}
293
297
294
- func (end * NativeEndpoint ) SrcToString () string {
298
+ func (end * LinuxSocketEndpoint ) SrcToString () string {
295
299
return end .SrcIP ().String ()
296
300
}
297
301
298
- func (end * NativeEndpoint ) DstToString () string {
302
+ func (end * LinuxSocketEndpoint ) DstToString () string {
299
303
var udpAddr net.UDPAddr
300
304
udpAddr .IP = end .DstIP ()
301
305
if ! end .isV6 {
@@ -306,13 +310,13 @@ func (end *NativeEndpoint) DstToString() string {
306
310
return udpAddr .String ()
307
311
}
308
312
309
- func (end * NativeEndpoint ) ClearDst () {
313
+ func (end * LinuxSocketEndpoint ) ClearDst () {
310
314
for i := range end .dst {
311
315
end .dst [i ] = 0
312
316
}
313
317
}
314
318
315
- func (end * NativeEndpoint ) ClearSrc () {
319
+ func (end * LinuxSocketEndpoint ) ClearSrc () {
316
320
for i := range end .src {
317
321
end .src [i ] = 0
318
322
}
@@ -427,7 +431,7 @@ func create6(port uint16) (int, uint16, error) {
427
431
return fd , uint16 (addr .Port ), err
428
432
}
429
433
430
- func send4 (sock int , end * NativeEndpoint , buff []byte ) error {
434
+ func send4 (sock int , end * LinuxSocketEndpoint , buff []byte ) error {
431
435
432
436
// construct message header
433
437
@@ -467,7 +471,7 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
467
471
return err
468
472
}
469
473
470
- func send6 (sock int , end * NativeEndpoint , buff []byte ) error {
474
+ func send6 (sock int , end * LinuxSocketEndpoint , buff []byte ) error {
471
475
472
476
// construct message header
473
477
@@ -511,7 +515,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
511
515
return err
512
516
}
513
517
514
- func receive4 (sock int , buff []byte , end * NativeEndpoint ) (int , error ) {
518
+ func receive4 (sock int , buff []byte , end * LinuxSocketEndpoint ) (int , error ) {
515
519
516
520
// construct message header
517
521
@@ -543,7 +547,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
543
547
return size , nil
544
548
}
545
549
546
- func receive6 (sock int , buff []byte , end * NativeEndpoint ) (int , error ) {
550
+ func receive6 (sock int , buff []byte , end * LinuxSocketEndpoint ) (int , error ) {
547
551
548
552
// construct message header
549
553
0 commit comments