Skip to content

Commit 0a83bce

Browse files
committed
add Proxy.AddHTTPHostRouteFunc to match routes by hosts dynamically.
This is similar to Proxy.AddSNIRouteFunc. Signed-off-by: Nitin Jain <[email protected]>
1 parent 91f8614 commit 0a83bce

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

Diff for: http.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,34 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
3838
//
3939
// The ipPort is any valid net.Listen TCP address.
4040
func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) {
41-
p.addRoute(ipPort, httpHostMatch{match, dest})
41+
p.addRoute(ipPort, httpHostMatch{matcher: match, target: dest})
42+
}
43+
44+
// HTTPHostTargetFunc is the func callback used by Proxy.AddHTTPHostRouteFunc.
45+
type HTTPHostTargetFunc func(ctx context.Context, httpHost string) (t Target, ok bool)
46+
47+
// AddHTTPHostRouteFunc adds a route to ipPort that matches an HTTP request and calls
48+
// fn to map it to a target.
49+
func (p *Proxy) AddHTTPHostRouteFunc(ipPort string, fn HTTPHostTargetFunc) {
50+
p.addRoute(ipPort, httpHostMatch{targetFunc: fn})
4251
}
4352

4453
type httpHostMatch struct {
4554
matcher Matcher
4655
target Target
56+
57+
// Alternatively, if targetFunc is non-nil, it's used instead:
58+
targetFunc HTTPHostTargetFunc
4759
}
4860

4961
func (m httpHostMatch) match(br *bufio.Reader) (Target, string) {
5062
hh := httpHostHeader(br)
63+
if m.targetFunc != nil {
64+
if t, ok := m.targetFunc(context.TODO(), hh); ok {
65+
return t, hh
66+
}
67+
return nil, ""
68+
}
5169
if m.matcher(context.TODO(), hh) {
5270
return m.target, hh
5371
}

Diff for: tcpproxy_test.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func TestMatchHTTPHost(t *testing.T) {
7272
}
7373
t.Run(name, func(t *testing.T) {
7474
br := bufio.NewReader(tt.r)
75-
r := httpHostMatch{equals(tt.host), noopTarget{}}
75+
r := httpHostMatch{matcher: equals(tt.host), target: noopTarget{}}
7676
m, name := r.match(br)
7777
got := m != nil
7878
if got != tt.want {
@@ -247,6 +247,50 @@ func TestProxyHTTP(t *testing.T) {
247247
}
248248
}
249249

250+
func TestProxyHTTPFunc(t *testing.T) {
251+
front := newLocalListener(t)
252+
defer front.Close()
253+
254+
backFoo := newLocalListener(t)
255+
defer backFoo.Close()
256+
backBar := newLocalListener(t)
257+
defer backBar.Close()
258+
259+
p := testProxy(t, front)
260+
p.AddHTTPHostRouteFunc(testFrontAddr, func(ctx context.Context, httpHost string) (_ Target, ok bool) {
261+
if httpHost == "bar.com" {
262+
return To(backBar.Addr().String()), true
263+
}
264+
t.Fatalf("failed to match %q", httpHost)
265+
return nil, false
266+
})
267+
if err := p.Start(); err != nil {
268+
t.Fatal(err)
269+
}
270+
271+
toFront, err := net.Dial("tcp", front.Addr().String())
272+
if err != nil {
273+
t.Fatal(err)
274+
}
275+
defer toFront.Close()
276+
277+
const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
278+
io.WriteString(toFront, msg)
279+
280+
fromProxy, err := backBar.Accept()
281+
if err != nil {
282+
t.Fatal(err)
283+
}
284+
285+
buf := make([]byte, len(msg))
286+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
287+
t.Fatal(err)
288+
}
289+
if string(buf) != msg {
290+
t.Fatalf("got %q; want %q", buf, msg)
291+
}
292+
}
293+
250294
func TestProxySNI(t *testing.T) {
251295
front := newLocalListener(t)
252296
defer front.Close()

0 commit comments

Comments
 (0)