Skip to content

Commit 1be8ffa

Browse files
committed
Refactored the test files with helpers to test backend
``` func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool ```
1 parent 2b928d9 commit 1be8ffa

File tree

1 file changed

+93
-69
lines changed

1 file changed

+93
-69
lines changed

Diff for: tcpproxy_test.go

+93-69
Original file line numberDiff line numberDiff line change
@@ -169,38 +169,90 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
169169
}
170170
}
171171

172-
func TestProxyAlwaysMatch(t *testing.T) {
173-
front := newLocalListener(t)
174-
defer front.Close()
175-
back := newLocalListener(t)
176-
defer back.Close()
172+
func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) {
173+
io.WriteString(toFront, msg)
174+
fromProxy, err := back.Accept()
175+
if err != nil {
176+
t.Fatal(err)
177+
}
177178

178-
p := testProxy(t, front)
179-
p.AddRoute(testFrontAddr, To(back.Addr().String()))
180-
if err := p.Start(); err != nil {
179+
buf := make([]byte, len(expected))
180+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
181181
t.Fatal(err)
182182
}
183+
if string(buf) != expected {
184+
t.Fatalf("got %q; want %q", buf, expected)
185+
}
186+
}
183187

188+
func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) {
184189
toFront, err := net.Dial("tcp", front.Addr().String())
185190
if err != nil {
186191
t.Fatal(err)
187192
}
188193
defer toFront.Close()
189194

190-
fromProxy, err := back.Accept()
195+
testRouteToBackendWithExpected(t, toFront, back, msg, msg)
196+
}
197+
198+
// test the backend is not receiving traffic
199+
func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool {
200+
done := make(chan bool)
201+
toFront, err := net.Dial("tcp", front.Addr().String())
191202
if err != nil {
192203
t.Fatal(err)
193204
}
194-
const msg = "message"
195-
io.WriteString(toFront, msg)
205+
defer toFront.Close()
196206

197-
buf := make([]byte, len(msg))
198-
if _, err := io.ReadFull(fromProxy, buf); err != nil {
207+
timeC := time.NewTimer(10 * time.Millisecond).C
208+
acceptC := make(chan struct{})
209+
go func() {
210+
io.WriteString(toFront, msg)
211+
fromProxy, err := back.Accept()
212+
acceptC <- struct{}{}
213+
{
214+
if err == nil {
215+
buf := make([]byte, len(msg))
216+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
217+
t.Fatal(err)
218+
}
219+
t.Fatalf("Expect backend to not receive message, but found %s", string(buf))
220+
}
221+
err, ok := err.(net.Error)
222+
if !ok || !err.Timeout() {
223+
t.Fatalf("Expect backend to timeout, but found err: %v", err)
224+
}
225+
}
226+
}()
227+
go func() {
228+
select {
229+
case <-timeC:
230+
{
231+
done <- true
232+
}
233+
case <-acceptC:
234+
{
235+
t.Fatal("Expect backend to not receive message")
236+
done <- true
237+
}
238+
}
239+
}()
240+
return done
241+
}
242+
243+
func TestProxyAlwaysMatch(t *testing.T) {
244+
front := newLocalListener(t)
245+
defer front.Close()
246+
back := newLocalListener(t)
247+
defer back.Close()
248+
249+
p := testProxy(t, front)
250+
p.AddRoute(testFrontAddr, To(back.Addr().String()))
251+
if err := p.Start(); err != nil {
199252
t.Fatal(err)
200253
}
201-
if string(buf) != msg {
202-
t.Fatalf("got %q; want %q", buf, msg)
203-
}
254+
255+
testRouteToBackend(t, front, back, "message")
204256
}
205257

206258
func TestProxyHTTP(t *testing.T) {
@@ -219,27 +271,9 @@ func TestProxyHTTP(t *testing.T) {
219271
t.Fatal(err)
220272
}
221273

222-
toFront, err := net.Dial("tcp", front.Addr().String())
223-
if err != nil {
224-
t.Fatal(err)
225-
}
226-
defer toFront.Close()
227-
228-
const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
229-
io.WriteString(toFront, msg)
230-
231-
fromProxy, err := backBar.Accept()
232-
if err != nil {
233-
t.Fatal(err)
234-
}
235-
236-
buf := make([]byte, len(msg))
237-
if _, err := io.ReadFull(fromProxy, buf); err != nil {
238-
t.Fatal(err)
239-
}
240-
if string(buf) != msg {
241-
t.Fatalf("got %q; want %q", buf, msg)
242-
}
274+
testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n")
275+
<-testNotRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: boo.com\r\n\r\n")
276+
testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n")
243277
}
244278

245279
func TestProxySNI(t *testing.T) {
@@ -258,27 +292,32 @@ func TestProxySNI(t *testing.T) {
258292
t.Fatal(err)
259293
}
260294

261-
toFront, err := net.Dial("tcp", front.Addr().String())
262-
if err != nil {
263-
t.Fatal(err)
264-
}
265-
defer toFront.Close()
295+
testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com"))
296+
<-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com"))
297+
testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com"))
298+
}
266299

267-
msg := clientHelloRecord(t, "bar.com")
268-
io.WriteString(toFront, msg)
300+
func TestProxyRemoveRoute(t *testing.T) {
301+
front := newLocalListener(t)
302+
defer front.Close()
303+
p := testProxy(t, front)
269304

270-
fromProxy, err := backBar.Accept()
271-
if err != nil {
272-
t.Fatal(err)
273-
}
305+
// NOTE: Needs to register testFrontAddr before server starts
306+
p.AddSNIRoute(testFrontAddr, "unused.com", noopTarget{})
274307

275-
buf := make([]byte, len(msg))
276-
if _, err := io.ReadFull(fromProxy, buf); err != nil {
308+
if err := p.Start(); err != nil {
277309
t.Fatal(err)
278310
}
279-
if string(buf) != msg {
280-
t.Fatalf("got %q; want %q", buf, msg)
281-
}
311+
312+
backBar := newLocalListener(t)
313+
defer backBar.Close()
314+
routeID := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
315+
316+
msg := clientHelloRecord(t, "bar.com")
317+
testRouteToBackend(t, front, backBar, msg)
318+
319+
p.RemoveRoute(testFrontAddr, routeID)
320+
<-testNotRouteToBackend(t, front, backBar, msg)
282321
}
283322

284323
func TestProxyPROXYOut(t *testing.T) {
@@ -301,23 +340,8 @@ func TestProxyPROXYOut(t *testing.T) {
301340
t.Fatal(err)
302341
}
303342

304-
io.WriteString(toFront, "foo")
305-
toFront.Close()
306-
307-
fromProxy, err := back.Accept()
308-
if err != nil {
309-
t.Fatal(err)
310-
}
311-
312-
bs, err := ioutil.ReadAll(fromProxy)
313-
if err != nil {
314-
t.Fatal(err)
315-
}
316-
317343
want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port)
318-
if string(bs) != want {
319-
t.Fatalf("got %q; want %q", bs, want)
320-
}
344+
testRouteToBackendWithExpected(t, toFront, back, "foo", want)
321345
}
322346

323347
type tlsServer struct {

0 commit comments

Comments
 (0)