Skip to content

Commit 988c0d7

Browse files
author
christianbagley-viz
committed
Implementing gorilla#740
1 parent 78cf1bc commit 988c0d7

File tree

5 files changed

+131
-5
lines changed

5 files changed

+131
-5
lines changed

client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
303303
return nil, nil, err
304304
}
305305
if proxyURL != nil {
306-
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
306+
proxyDialer := &netDialerFunc{fn: netDial}
307+
modifyProxyDialer(ctx, d, proxyURL, proxyDialer)
308+
dialer, err := proxy_FromURL(proxyURL, proxyDialer)
307309
if err != nil {
308310
return nil, nil, err
309311
}

client_server_httpsproxy_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//go:build go1.15
2+
// +build go1.15
3+
4+
package websocket
5+
6+
import (
7+
"crypto/tls"
8+
"net/http"
9+
"net/url"
10+
"testing"
11+
)
12+
13+
func TestHttpsProxy(t *testing.T) {
14+
15+
sTLS := newTLSServer(t)
16+
defer sTLS.Close()
17+
s := newServer(t)
18+
defer s.Close()
19+
20+
surlTLS, _ := url.Parse(sTLS.Server.URL)
21+
22+
cstDialer := cstDialer // make local copy for modification on next line.
23+
cstDialer.Proxy = http.ProxyURL(surlTLS)
24+
25+
connect := false
26+
origHandler := sTLS.Server.Config.Handler
27+
28+
// Capture the request Host header.
29+
sTLS.Server.Config.Handler = http.HandlerFunc(
30+
func(w http.ResponseWriter, r *http.Request) {
31+
if r.Method == "CONNECT" {
32+
connect = true
33+
w.WriteHeader(http.StatusOK)
34+
return
35+
}
36+
37+
if !connect {
38+
t.Log("connect not received")
39+
http.Error(w, "connect not received", http.StatusMethodNotAllowed)
40+
return
41+
}
42+
origHandler.ServeHTTP(w, r)
43+
})
44+
45+
cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)}
46+
ws, _, err := cstDialer.Dial(s.URL, nil)
47+
if err != nil {
48+
t.Fatalf("Dial: %v", err)
49+
}
50+
defer ws.Close()
51+
sendRecv(t, ws)
52+
}

proxy.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,37 @@ import (
1414
"strings"
1515
)
1616

17-
type netDialerFunc func(network, addr string) (net.Conn, error)
17+
// proxyDialerEx extends the generated proxy_Dialer
18+
type proxyDialerEx interface {
19+
proxy_Dialer
20+
// UsesTLS indicates whether we expect to dial to a TLS proxy
21+
UsesTLS() bool
22+
}
23+
24+
type netDialerFunc struct {
25+
fn func(network, addr string) (net.Conn, error)
26+
usesTLS bool
27+
}
1828

19-
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
20-
return fn(network, addr)
29+
func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) {
30+
return ndf.fn(network, addr)
31+
}
32+
33+
func (ndf *netDialerFunc) UsesTLS() bool {
34+
return ndf.usesTLS
2135
}
2236

2337
func init() {
2438
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
25-
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
39+
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: false}, nil
2640
})
41+
registerDialerHttps()
2742
}
2843

2944
type httpProxyDialer struct {
3045
proxyURL *url.URL
3146
forwardDial func(network, addr string) (net.Conn, error)
47+
usesTLS bool
3248
}
3349

3450
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
@@ -75,3 +91,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
7591
}
7692
return conn, nil
7793
}
94+
95+
func (hpd *httpProxyDialer) UsesTLS() bool {
96+
return hpd.usesTLS
97+
}

proxy_https.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//go:build go1.15
2+
// +build go1.15
3+
4+
package websocket
5+
6+
import (
7+
"context"
8+
"crypto/tls"
9+
"net"
10+
"net/url"
11+
)
12+
13+
func registerDialerHttps() {
14+
proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
15+
fwd := forwardDialer.Dial
16+
if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() {
17+
tlsDialer := &tls.Dialer{
18+
Config: &tls.Config{},
19+
NetDialer: &net.Dialer{},
20+
}
21+
fwd = tlsDialer.Dial
22+
}
23+
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil
24+
})
25+
}
26+
27+
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
28+
if proxyURL.Scheme == "https" {
29+
proxyDialer.usesTLS = true
30+
proxyDialer.fn = func(network, addr string) (net.Conn, error) {
31+
t := tls.Dialer{}
32+
t.Config = d.TLSClientConfig
33+
t.NetDialer = &net.Dialer{}
34+
return t.DialContext(ctx, network, addr)
35+
}
36+
}
37+
}

proxy_https_legacy.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//go:build !go1.15
2+
// +build !go1.15
3+
4+
package websocket
5+
6+
import (
7+
"context"
8+
"net/url"
9+
)
10+
11+
func registerDialerHttps() {
12+
}
13+
14+
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
15+
}

0 commit comments

Comments
 (0)