@@ -169,38 +169,90 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
169
169
}
170
170
}
171
171
172
- func TestProxyAlwaysMatch (t * testing.T ) {
173
- front := newLocalListener (t )
174
- defer front .Close ()
175
- back := newLocalListener (t )
176
- defer back .Close ()
172
+ func testRouteToBackendWithExpected (t * testing.T , toFront net.Conn , back net.Listener , msg string , expected string ) {
173
+ io .WriteString (toFront , msg )
174
+ fromProxy , err := back .Accept ()
175
+ if err != nil {
176
+ t .Fatal (err )
177
+ }
177
178
178
- p := testProxy (t , front )
179
- p .AddRoute (testFrontAddr , To (back .Addr ().String ()))
180
- if err := p .Start (); err != nil {
179
+ buf := make ([]byte , len (expected ))
180
+ if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
181
181
t .Fatal (err )
182
182
}
183
+ if string (buf ) != expected {
184
+ t .Fatalf ("got %q; want %q" , buf , expected )
185
+ }
186
+ }
183
187
188
+ func testRouteToBackend (t * testing.T , front net.Listener , back net.Listener , msg string ) {
184
189
toFront , err := net .Dial ("tcp" , front .Addr ().String ())
185
190
if err != nil {
186
191
t .Fatal (err )
187
192
}
188
193
defer toFront .Close ()
189
194
190
- fromProxy , err := back .Accept ()
195
+ testRouteToBackendWithExpected (t , toFront , back , msg , msg )
196
+ }
197
+
198
+ // test the backend is not receiving traffic
199
+ func testNotRouteToBackend (t * testing.T , front net.Listener , back net.Listener , msg string ) <- chan bool {
200
+ done := make (chan bool )
201
+ toFront , err := net .Dial ("tcp" , front .Addr ().String ())
191
202
if err != nil {
192
203
t .Fatal (err )
193
204
}
194
- const msg = "message"
195
- io .WriteString (toFront , msg )
205
+ defer toFront .Close ()
196
206
197
- buf := make ([]byte , len (msg ))
198
- if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
207
+ timeC := time .NewTimer (10 * time .Millisecond ).C
208
+ acceptC := make (chan struct {})
209
+ go func () {
210
+ io .WriteString (toFront , msg )
211
+ fromProxy , err := back .Accept ()
212
+ acceptC <- struct {}{}
213
+ {
214
+ if err == nil {
215
+ buf := make ([]byte , len (msg ))
216
+ if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
217
+ t .Fatal (err )
218
+ }
219
+ t .Fatalf ("Expect backend to not receive message, but found %s" , string (buf ))
220
+ }
221
+ err , ok := err .(net.Error )
222
+ if ! ok || ! err .Timeout () {
223
+ t .Fatalf ("Expect backend to timeout, but found err: %v" , err )
224
+ }
225
+ }
226
+ }()
227
+ go func () {
228
+ select {
229
+ case <- timeC :
230
+ {
231
+ done <- true
232
+ }
233
+ case <- acceptC :
234
+ {
235
+ t .Fatal ("Expect backend to not receive message" )
236
+ done <- true
237
+ }
238
+ }
239
+ }()
240
+ return done
241
+ }
242
+
243
+ func TestProxyAlwaysMatch (t * testing.T ) {
244
+ front := newLocalListener (t )
245
+ defer front .Close ()
246
+ back := newLocalListener (t )
247
+ defer back .Close ()
248
+
249
+ p := testProxy (t , front )
250
+ p .AddRoute (testFrontAddr , To (back .Addr ().String ()))
251
+ if err := p .Start (); err != nil {
199
252
t .Fatal (err )
200
253
}
201
- if string (buf ) != msg {
202
- t .Fatalf ("got %q; want %q" , buf , msg )
203
- }
254
+
255
+ testRouteToBackend (t , front , back , "message" )
204
256
}
205
257
206
258
func TestProxyHTTP (t * testing.T ) {
@@ -219,27 +271,9 @@ func TestProxyHTTP(t *testing.T) {
219
271
t .Fatal (err )
220
272
}
221
273
222
- toFront , err := net .Dial ("tcp" , front .Addr ().String ())
223
- if err != nil {
224
- t .Fatal (err )
225
- }
226
- defer toFront .Close ()
227
-
228
- const msg = "GET / HTTP/1.1\r \n Host: bar.com\r \n \r \n "
229
- io .WriteString (toFront , msg )
230
-
231
- fromProxy , err := backBar .Accept ()
232
- if err != nil {
233
- t .Fatal (err )
234
- }
235
-
236
- buf := make ([]byte , len (msg ))
237
- if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
238
- t .Fatal (err )
239
- }
240
- if string (buf ) != msg {
241
- t .Fatalf ("got %q; want %q" , buf , msg )
242
- }
274
+ testRouteToBackend (t , front , backBar , "GET / HTTP/1.1\r \n Host: bar.com\r \n \r \n " )
275
+ <- testNotRouteToBackend (t , front , backBar , "GET / HTTP/1.1\r \n Host: boo.com\r \n \r \n " )
276
+ testRouteToBackend (t , front , backFoo , "GET / HTTP/1.1\r \n Host: foo.com\r \n \r \n " )
243
277
}
244
278
245
279
func TestProxySNI (t * testing.T ) {
@@ -258,27 +292,32 @@ func TestProxySNI(t *testing.T) {
258
292
t .Fatal (err )
259
293
}
260
294
261
- toFront , err := net .Dial ("tcp" , front .Addr ().String ())
262
- if err != nil {
263
- t .Fatal (err )
264
- }
265
- defer toFront .Close ()
295
+ testRouteToBackend (t , front , backBar , clientHelloRecord (t , "bar.com" ))
296
+ <- testNotRouteToBackend (t , front , backBar , clientHelloRecord (t , "foo.com" ))
297
+ testRouteToBackend (t , front , backFoo , clientHelloRecord (t , "foo.com" ))
298
+ }
266
299
267
- msg := clientHelloRecord (t , "bar.com" )
268
- io .WriteString (toFront , msg )
300
+ func TestProxyRemoveRoute (t * testing.T ) {
301
+ front := newLocalListener (t )
302
+ defer front .Close ()
303
+ p := testProxy (t , front )
269
304
270
- fromProxy , err := backBar .Accept ()
271
- if err != nil {
272
- t .Fatal (err )
273
- }
305
+ // NOTE: Needs to register testFrontAddr before server starts
306
+ p .AddSNIRoute (testFrontAddr , "unused.com" , noopTarget {})
274
307
275
- buf := make ([]byte , len (msg ))
276
- if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
308
+ if err := p .Start (); err != nil {
277
309
t .Fatal (err )
278
310
}
279
- if string (buf ) != msg {
280
- t .Fatalf ("got %q; want %q" , buf , msg )
281
- }
311
+
312
+ backBar := newLocalListener (t )
313
+ defer backBar .Close ()
314
+ routeID := p .AddSNIRoute (testFrontAddr , "bar.com" , To (backBar .Addr ().String ()))
315
+
316
+ msg := clientHelloRecord (t , "bar.com" )
317
+ testRouteToBackend (t , front , backBar , msg )
318
+
319
+ p .RemoveRoute (testFrontAddr , routeID )
320
+ <- testNotRouteToBackend (t , front , backBar , msg )
282
321
}
283
322
284
323
func TestProxyPROXYOut (t * testing.T ) {
@@ -301,23 +340,8 @@ func TestProxyPROXYOut(t *testing.T) {
301
340
t .Fatal (err )
302
341
}
303
342
304
- io .WriteString (toFront , "foo" )
305
- toFront .Close ()
306
-
307
- fromProxy , err := back .Accept ()
308
- if err != nil {
309
- t .Fatal (err )
310
- }
311
-
312
- bs , err := ioutil .ReadAll (fromProxy )
313
- if err != nil {
314
- t .Fatal (err )
315
- }
316
-
317
343
want := fmt .Sprintf ("PROXY TCP4 %s %d %s %d\r \n foo" , toFront .LocalAddr ().(* net.TCPAddr ).IP , toFront .LocalAddr ().(* net.TCPAddr ).Port , toFront .RemoteAddr ().(* net.TCPAddr ).IP , toFront .RemoteAddr ().(* net.TCPAddr ).Port )
318
- if string (bs ) != want {
319
- t .Fatalf ("got %q; want %q" , bs , want )
320
- }
344
+ testRouteToBackendWithExpected (t , toFront , back , "foo" , want )
321
345
}
322
346
323
347
type tlsServer struct {
0 commit comments