@@ -17,6 +17,7 @@ package tcpproxy
17
17
import (
18
18
"bufio"
19
19
"bytes"
20
+ "context"
20
21
"crypto/rand"
21
22
"crypto/rsa"
22
23
"crypto/tls"
@@ -287,6 +288,49 @@ func TestProxySNI(t *testing.T) {
287
288
}
288
289
}
289
290
291
+ func TestAddSNIRouteFunc (t * testing.T ) {
292
+ front := newLocalListener (t )
293
+ defer front .Close ()
294
+
295
+ backFoo := newLocalListener (t )
296
+ defer backFoo .Close ()
297
+ backBar := newLocalListener (t )
298
+ defer backBar .Close ()
299
+
300
+ p := testProxy (t , front )
301
+ p .AddSNIRouteFunc (testFrontAddr , func (ctx context.Context , sniName string ) (_ Target , ok bool ) {
302
+ if sniName == "bar.com" {
303
+ return To (backBar .Addr ().String ()), true
304
+ }
305
+ t .Fatalf ("failed to match %q" , sniName )
306
+ return nil , false
307
+ })
308
+ if err := p .Start (); err != nil {
309
+ t .Fatal (err )
310
+ }
311
+
312
+ toFront , err := net .Dial ("tcp" , front .Addr ().String ())
313
+ if err != nil {
314
+ t .Fatal (err )
315
+ }
316
+ defer toFront .Close ()
317
+
318
+ msg := clientHelloRecord (t , "bar.com" )
319
+ io .WriteString (toFront , msg )
320
+
321
+ fromProxy , err := backBar .Accept ()
322
+ if err != nil {
323
+ t .Fatal (err )
324
+ }
325
+
326
+ buf := make ([]byte , len (msg ))
327
+ if _ , err := io .ReadFull (fromProxy , buf ); err != nil {
328
+ t .Fatal (err )
329
+ }
330
+ if string (buf ) != msg {
331
+ t .Fatalf ("got %q; want %q" , buf , msg )
332
+ }
333
+ }
290
334
func TestProxyPROXYOut (t * testing.T ) {
291
335
front := newLocalListener (t )
292
336
defer front .Close ()
@@ -362,7 +406,7 @@ func (t *tlsServer) Close() {
362
406
// cert creates a well-formed, but completely insecure self-signed
363
407
// cert for domain.
364
408
func cert (t * testing.T , domain string ) tls.Certificate {
365
- private , err := rsa .GenerateKey (rand .Reader , 512 )
409
+ private , err := rsa .GenerateKey (rand .Reader , 1024 )
366
410
if err != nil {
367
411
t .Fatal (err )
368
412
}
0 commit comments