Skip to content

Commit 219296a

Browse files
committed
device: test up/down using virtual conn
This prevents port clashing bugs. Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 7321491 commit 219296a

File tree

3 files changed

+155
-24
lines changed

3 files changed

+155
-24
lines changed

Diff for: conn/bindtest/bindtest.go

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/* SPDX-License-Identifier: MIT
2+
*
3+
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
4+
*/
5+
6+
package bindtest
7+
8+
import (
9+
"fmt"
10+
"math/rand"
11+
"net"
12+
"os"
13+
"strconv"
14+
15+
"golang.zx2c4.com/wireguard/conn"
16+
)
17+
18+
type ChannelBind struct {
19+
rx4, tx4 *chan []byte
20+
rx6, tx6 *chan []byte
21+
closeSignal chan bool
22+
source4, source6 ChannelEndpoint
23+
target4, target6 ChannelEndpoint
24+
}
25+
26+
type ChannelEndpoint uint16
27+
28+
var _ conn.Bind = (*ChannelBind)(nil)
29+
var _ conn.Endpoint = (*ChannelEndpoint)(nil)
30+
31+
func NewChannelBinds() [2]conn.Bind {
32+
arx4 := make(chan []byte, 8192)
33+
brx4 := make(chan []byte, 8192)
34+
arx6 := make(chan []byte, 8192)
35+
brx6 := make(chan []byte, 8192)
36+
var binds [2]ChannelBind
37+
binds[0].rx4 = &arx4
38+
binds[0].tx4 = &brx4
39+
binds[1].rx4 = &brx4
40+
binds[1].tx4 = &arx4
41+
binds[0].rx6 = &arx6
42+
binds[0].tx6 = &brx6
43+
binds[1].rx6 = &brx6
44+
binds[1].tx6 = &arx6
45+
binds[0].target4 = ChannelEndpoint(1)
46+
binds[1].target4 = ChannelEndpoint(2)
47+
binds[0].target6 = ChannelEndpoint(3)
48+
binds[1].target6 = ChannelEndpoint(4)
49+
binds[0].source4 = binds[1].target4
50+
binds[0].source6 = binds[1].target6
51+
binds[1].source4 = binds[0].target4
52+
binds[1].source6 = binds[0].target6
53+
return [2]conn.Bind{&binds[0], &binds[1]}
54+
}
55+
56+
func (c ChannelEndpoint) ClearSrc() {}
57+
58+
func (c ChannelEndpoint) SrcToString() string { return "" }
59+
60+
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
61+
62+
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
63+
64+
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
65+
66+
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
67+
68+
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
69+
c.closeSignal = make(chan bool)
70+
if rand.Uint32()&1 == 0 {
71+
return uint16(c.source4), nil
72+
} else {
73+
return uint16(c.source6), nil
74+
}
75+
}
76+
77+
func (c *ChannelBind) Close() error {
78+
if c.closeSignal != nil {
79+
select {
80+
case <-c.closeSignal:
81+
default:
82+
close(c.closeSignal)
83+
}
84+
}
85+
return nil
86+
}
87+
88+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
89+
90+
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
91+
select {
92+
case <-c.closeSignal:
93+
return 0, nil, net.ErrClosed
94+
case rx := <-*c.rx6:
95+
return copy(b, rx), c.target6, nil
96+
}
97+
}
98+
99+
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
100+
select {
101+
case <-c.closeSignal:
102+
return 0, nil, net.ErrClosed
103+
case rx := <-*c.rx4:
104+
return copy(b, rx), c.target4, nil
105+
}
106+
}
107+
108+
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
109+
select {
110+
case <-c.closeSignal:
111+
return net.ErrClosed
112+
default:
113+
bc := make([]byte, len(b))
114+
copy(bc, b)
115+
if ep.(ChannelEndpoint) == c.target4 {
116+
*c.tx4 <- bc
117+
} else if ep.(ChannelEndpoint) == c.target6 {
118+
*c.tx6 <- bc
119+
} else {
120+
return os.ErrInvalid
121+
}
122+
}
123+
return nil
124+
}
125+
126+
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
127+
_, port, err := net.SplitHostPort(s)
128+
if err != nil {
129+
return nil, err
130+
}
131+
i, err := strconv.ParseUint(port, 10, 16)
132+
if err != nil {
133+
return nil, err
134+
}
135+
return ChannelEndpoint(i), nil
136+
}

Diff for: device/device_test.go

+19-23
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ package device
88
import (
99
"bytes"
1010
"encoding/hex"
11-
"errors"
1211
"fmt"
1312
"io"
1413
"math/rand"
@@ -17,11 +16,11 @@ import (
1716
"runtime/pprof"
1817
"sync"
1918
"sync/atomic"
20-
"syscall"
2119
"testing"
2220
"time"
2321

2422
"golang.zx2c4.com/wireguard/conn"
23+
"golang.zx2c4.com/wireguard/conn/bindtest"
2524
"golang.zx2c4.com/wireguard/tun/tuntest"
2625
)
2726

@@ -148,8 +147,14 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
148147
}
149148

150149
// genTestPair creates a testPair.
151-
func genTestPair(tb testing.TB) (pair testPair) {
150+
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
152151
cfg, endpointCfg := genConfigs(tb)
152+
var binds [2]conn.Bind
153+
if realSocket {
154+
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
155+
} else {
156+
binds = bindtest.NewChannelBinds()
157+
}
153158
// Bring up a ChannelTun for each config.
154159
for i := range pair {
155160
p := &pair[i]
@@ -159,7 +164,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
159164
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
160165
level = LogLevelError
161166
}
162-
p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
167+
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
163168
if err := p.dev.IpcSet(cfg[i]); err != nil {
164169
tb.Errorf("failed to configure device %d: %v", i, err)
165170
p.dev.Close()
@@ -187,7 +192,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
187192

188193
func TestTwoDevicePing(t *testing.T) {
189194
goroutineLeakCheck(t)
190-
pair := genTestPair(t)
195+
pair := genTestPair(t, true)
191196
t.Run("ping 1.0.0.1", func(t *testing.T) {
192197
pair.Send(t, Ping, nil)
193198
})
@@ -198,11 +203,11 @@ func TestTwoDevicePing(t *testing.T) {
198203

199204
func TestUpDown(t *testing.T) {
200205
goroutineLeakCheck(t)
201-
const itrials = 20
202-
const otrials = 1
206+
const itrials = 50
207+
const otrials = 10
203208

204209
for n := 0; n < otrials; n++ {
205-
pair := genTestPair(t)
210+
pair := genTestPair(t, false)
206211
for i := range pair {
207212
for k := range pair[i].dev.peers.keyMap {
208213
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@@ -214,17 +219,8 @@ func TestUpDown(t *testing.T) {
214219
go func(d *Device) {
215220
defer wg.Done()
216221
for i := 0; i < itrials; i++ {
217-
start := time.Now()
218-
for {
219-
if err := d.Up(); err != nil {
220-
if errors.Is(err, syscall.EADDRINUSE) && time.Now().Sub(start) < time.Second*4 {
221-
// Some other test process is racing with us, so try again.
222-
time.Sleep(time.Millisecond * 10)
223-
continue
224-
}
225-
t.Errorf("failed up bring up device: %v", err)
226-
}
227-
break
222+
if err := d.Up(); err != nil {
223+
t.Errorf("failed up bring up device: %v", err)
228224
}
229225
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
230226
if err := d.Down(); err != nil {
@@ -245,7 +241,7 @@ func TestUpDown(t *testing.T) {
245241
// TestConcurrencySafety does other things concurrently with tunnel use.
246242
// It is intended to be used with the race detector to catch data races.
247243
func TestConcurrencySafety(t *testing.T) {
248-
pair := genTestPair(t)
244+
pair := genTestPair(t, true)
249245
done := make(chan struct{})
250246

251247
const warmupIters = 10
@@ -315,7 +311,7 @@ func TestConcurrencySafety(t *testing.T) {
315311
}
316312

317313
func BenchmarkLatency(b *testing.B) {
318-
pair := genTestPair(b)
314+
pair := genTestPair(b, true)
319315

320316
// Establish a connection.
321317
pair.Send(b, Ping, nil)
@@ -329,7 +325,7 @@ func BenchmarkLatency(b *testing.B) {
329325
}
330326

331327
func BenchmarkThroughput(b *testing.B) {
332-
pair := genTestPair(b)
328+
pair := genTestPair(b, true)
333329

334330
// Establish a connection.
335331
pair.Send(b, Ping, nil)
@@ -373,7 +369,7 @@ func BenchmarkThroughput(b *testing.B) {
373369
}
374370

375371
func BenchmarkUAPIGet(b *testing.B) {
376-
pair := genTestPair(b)
372+
pair := genTestPair(b, true)
377373
pair.Send(b, Ping, nil)
378374
pair.Send(b, Pong, nil)
379375
b.ReportAllocs()

Diff for: tun/tuntest/tuntest.go

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
7979
return pkt
8080
}
8181

82-
// TODO(crawshaw): find a reusable home for this. package devicetest?
8382
type ChannelTUN struct {
8483
Inbound chan []byte // incoming packets, closed on TUN close
8584
Outbound chan []byte // outbound packets, blocks forever on TUN close

0 commit comments

Comments
 (0)