Skip to content

Commit 74ca1dc

Browse files
committed
add Proxy.AddSNIRouteFunc to do lookups by SNI dynamically
1 parent 4e04b92 commit 74ca1dc

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

Diff for: sni.go

+22-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,16 @@ func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
5757
cfg.acmeTargets = append(cfg.acmeTargets, dest)
5858
}
5959

60-
p.addRoute(ipPort, sniMatch{matcher, dest})
60+
p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest})
61+
}
62+
63+
// SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc.
64+
type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool)
65+
66+
// AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls
67+
// fn to map its nap to a target.
68+
func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) {
69+
p.addRoute(ipPort, sniMatch{targetFunc: fn})
6170
}
6271

6372
// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
@@ -71,10 +80,22 @@ func (p *Proxy) AddStopACMESearch(ipPort string) {
7180
type sniMatch struct {
7281
matcher Matcher
7382
target Target
83+
84+
// Alternatively, if targetFunc is non-nil, it's used instead:
85+
targetFunc SNITargetFunc
7486
}
7587

7688
func (m sniMatch) match(br *bufio.Reader) (Target, string) {
7789
sni := clientHelloServerName(br)
90+
if sni == "" {
91+
return nil, ""
92+
}
93+
if m.targetFunc != nil {
94+
if t, ok := m.targetFunc(context.TODO(), sni); ok {
95+
return t, sni
96+
}
97+
return nil, ""
98+
}
7899
if m.matcher(context.TODO(), sni) {
79100
return m.target, sni
80101
}

Diff for: tcpproxy_test.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package tcpproxy
1717
import (
1818
"bufio"
1919
"bytes"
20+
"context"
2021
"crypto/rand"
2122
"crypto/rsa"
2223
"crypto/tls"
@@ -287,6 +288,49 @@ func TestProxySNI(t *testing.T) {
287288
}
288289
}
289290

291+
func TestAddSNIRouteFunc(t *testing.T) {
292+
front := newLocalListener(t)
293+
defer front.Close()
294+
295+
backFoo := newLocalListener(t)
296+
defer backFoo.Close()
297+
backBar := newLocalListener(t)
298+
defer backBar.Close()
299+
300+
p := testProxy(t, front)
301+
p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) {
302+
if sniName == "bar.com" {
303+
return To(backBar.Addr().String()), true
304+
}
305+
t.Fatalf("failed to match %q", sniName)
306+
return nil, false
307+
})
308+
if err := p.Start(); err != nil {
309+
t.Fatal(err)
310+
}
311+
312+
toFront, err := net.Dial("tcp", front.Addr().String())
313+
if err != nil {
314+
t.Fatal(err)
315+
}
316+
defer toFront.Close()
317+
318+
msg := clientHelloRecord(t, "bar.com")
319+
io.WriteString(toFront, msg)
320+
321+
fromProxy, err := backBar.Accept()
322+
if err != nil {
323+
t.Fatal(err)
324+
}
325+
326+
buf := make([]byte, len(msg))
327+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
328+
t.Fatal(err)
329+
}
330+
if string(buf) != msg {
331+
t.Fatalf("got %q; want %q", buf, msg)
332+
}
333+
}
290334
func TestProxyPROXYOut(t *testing.T) {
291335
front := newLocalListener(t)
292336
defer front.Close()
@@ -362,7 +406,7 @@ func (t *tlsServer) Close() {
362406
// cert creates a well-formed, but completely insecure self-signed
363407
// cert for domain.
364408
func cert(t *testing.T, domain string) tls.Certificate {
365-
private, err := rsa.GenerateKey(rand.Reader, 512)
409+
private, err := rsa.GenerateKey(rand.Reader, 1024)
366410
if err != nil {
367411
t.Fatal(err)
368412
}

0 commit comments

Comments
 (0)