Skip to content

Commit 06342ad

Browse files
committed
Optimize TLS and common request code logic
1 parent 918721d commit 06342ad

1 file changed

Lines changed: 42 additions & 61 deletions

File tree

mitm.go

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,12 @@ const preboundConnKey contextKey = "prebound-net-conn"
3636
type reqContextKey struct{}
3737

3838
type ReqContext struct {
39-
// used for https proxy and sock5 (connect) proxy
40-
Connect bool
4139
Hostport string
4240
Request *http.Request
4341
}
4442

45-
func AppendToRequestContext(ctx context.Context, connect bool, hostport string, request *http.Request) context.Context {
43+
func AppendToRequestContext(ctx context.Context, hostport string, request *http.Request) context.Context {
4644
reqCtx := ReqContext{
47-
Connect: connect,
4845
Hostport: hostport,
4946
Request: request,
5047
}
@@ -263,17 +260,15 @@ func (r *mitmProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
263260
if err != nil {
264261
return
265262
}
266-
var connect bool
267263
if req.Method == http.MethodConnect {
268-
connect = true
269264
request = nil
270265
conn.Write(HttpResponseConnectionEstablished)
271266
} else if req.URL != nil && len(req.URL.Scheme) == 0 {
272267
// directly access proxy server and url scheme is empty
273268
err = ErrInvalidProxyRequest
274269
return
275270
}
276-
err = r.Serve(AppendToRequestContext(req.Context(), connect, hostport, request), conn)
271+
err = r.Serve(AppendToRequestContext(req.Context(), hostport, request), conn)
277272
}
278273

279274
func (r *mitmProxyHandler) ServeSOCKS5(ctx context.Context, conn net.Conn) error {
@@ -294,7 +289,7 @@ func (r *mitmProxyHandler) ServeSOCKS5(ctx context.Context, conn net.Conn) error
294289
if hostport, err = r.handleSocks5Request(ctx, conn); err != nil {
295290
return err
296291
}
297-
err = r.Serve(AppendToRequestContext(ctx, true, hostport, nil), conn)
292+
err = r.Serve(AppendToRequestContext(ctx, hostport, nil), conn)
298293
return err
299294
}
300295

@@ -317,25 +312,7 @@ func (r *mitmProxyHandler) Serve(ctx context.Context, conn net.Conn) error {
317312
md.Set(metadata.ConnectionSourceAddrPort, getAddrPortFromConn(conn))
318313
ctx = metadata.AppendToContext(ctx, md)
319314

320-
if reqCtx.Connect {
321-
// handle https/ws/wss request
322-
return r.handleConnectRequest(ctx, conn)
323-
} else {
324-
// handle common http request(include h2c upgrade)
325-
dstConn, err := r.proxyDialer.DialTCPContext(ctx, reqCtx.Hostport)
326-
if err != nil {
327-
return err
328-
}
329-
defer dstConn.Close()
330-
if !r.disableHTTP2 {
331-
rw := newFakeHttpResponseWriter(conn)
332-
earlyDone, err := r.handleH2CRequest(ctx, rw, reqCtx.Request, dstConn)
333-
if err != nil || earlyDone {
334-
return err
335-
}
336-
}
337-
return r.relayConnForHTTP(ctx, conn, dstConn)
338-
}
315+
return r.handleTunnelRequest(ctx, conn, reqCtx.Request != nil)
339316
}
340317

341318
func (r *mitmProxyHandler) shouldPassthroughRequest(hostport string) bool {
@@ -364,7 +341,7 @@ func (r *mitmProxyHandler) passthroughTunnel(ctx context.Context, srcConn net.Co
364341
return err
365342
}
366343
// only write the request for none-CONNECT request
367-
if !reqCtx.Connect && reqCtx.Request != nil {
344+
if reqCtx.Request != nil {
368345
// we should copy the request to dst connection firstly
369346
// TODO: if upload large file, this will cause performance problem
370347
if err = reqCtx.Request.Write(dstConn); err != nil {
@@ -492,24 +469,26 @@ func isTLSRequest(data []byte) bool {
492469
return data[0] == 0x16 && data[1] == 0x3 && (data[2] >= 0x1 && data[2] <= 0x3) && data[5] == 0x1
493470
}
494471

495-
func (r *mitmProxyHandler) handleConnectRequest(ctx context.Context, conn net.Conn) (err error) {
496-
bufConn := newBufConn(conn)
497-
data, err := bufConn.Peek(6)
498-
if err != nil {
499-
return err
500-
}
501-
if len(data) < 6 {
502-
return ErrShortTLSPacket
472+
func (r *mitmProxyHandler) handleTunnelRequest(ctx context.Context, conn net.Conn, consumedRequest bool) (err error) {
473+
var data []byte
474+
475+
if !consumedRequest {
476+
bufConn := newBufConn(conn)
477+
data, err = bufConn.Peek(6)
478+
if err != nil {
479+
return err
480+
}
481+
conn = bufConn
503482
}
504483

505484
var srcConn, dstConn net.Conn
506485
// Check if the common http/websocket request with tls
507-
if isTLSRequest(data) {
486+
if len(data) >= 6 && isTLSRequest(data) {
508487
clientHelloInfoCh := make(chan *tls.ClientHelloInfo, 1)
509488
tlsConnCh := make(chan net.Conn, 1)
510489
tlsConfigCh := make(chan *tls.Config, 1)
511490
errCh := make(chan error, 1)
512-
tlsConn := tls.Server(bufConn, &tls.Config{
491+
tlsConn := tls.Server(conn, &tls.Config{
513492
GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
514493
clientHelloInfoCh <- chi
515494
select {
@@ -571,7 +550,7 @@ func (r *mitmProxyHandler) handleConnectRequest(ctx context.Context, conn net.Co
571550
if err != nil {
572551
return
573552
}
574-
srcConn = bufConn
553+
srcConn = conn
575554
}
576555
defer dstConn.Close()
577556

@@ -619,45 +598,47 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
619598
}
620599

621600
func (r *mitmProxyHandler) distinguishHTTPRequest(ctx context.Context, srcConn, dstConn net.Conn) (newCtx context.Context, earlyDone bool, upgrade bool, retErr error) {
622-
newCtx = ctx
623601
reqCtx, _ := FromRequestContext(ctx)
624602

625603
// Read the http request for https/wss via tls tunnel
626-
if reqCtx.Connect {
627-
fakerw := newFakeHttpResponseWriter(srcConn)
604+
fakerw := newFakeHttpResponseWriter(srcConn)
605+
request := reqCtx.Request
606+
607+
// Need to read the request
608+
if request == nil {
628609
_, rw, err := fakerw.Hijack()
629610
if err != nil {
630611
retErr = err
631612
return
632613
}
633-
request, err := http.ReadRequest(rw.Reader)
614+
request, err = http.ReadRequest(rw.Reader)
634615
if err != nil {
635616
retErr = err
636617
return
637618
}
619+
}
638620

639-
if !r.disableHTTP2 {
640-
// If it's a SOCKS proxy, then the request might be h2c.
641-
earlyDone, retErr = r.handleH2CRequest(ctx, fakerw, request, dstConn)
642-
if retErr != nil || earlyDone {
643-
return
644-
}
621+
if !r.disableHTTP2 {
622+
// If it's a SOCKS proxy, then the request might be h2c.
623+
earlyDone, retErr = r.handleH2CRequest(ctx, fakerw, request, dstConn)
624+
if retErr != nil || earlyDone {
625+
return
645626
}
627+
}
646628

647-
// The request url scheme can be either http or https and we don't care for HTTP1 transport
648-
// Because the inner Dial and DialTLS functions were overwritten and replaced with custom net.Conn
649-
request.URL.Scheme = "http"
650-
request.URL.Host = request.Host
629+
// The request url scheme can be either http or https and we don't care for HTTP1 transport
630+
// Because the inner Dial and DialTLS functions were overwritten and replaced with custom net.Conn
631+
request.URL.Scheme = "http"
632+
request.URL.Host = request.Host
651633

652-
if upgrade = isWSUpgrade(request.Header); upgrade {
653-
request.URL.Scheme = "ws"
654-
}
655-
656-
// patch the new request to the request context
657-
reqCtx.Request = request
658-
newCtx = AppendToRequestContext(ctx, reqCtx.Connect, reqCtx.Hostport, reqCtx.Request)
634+
if upgrade = isWSUpgrade(request.Header); upgrade {
635+
request.URL.Scheme = "ws"
659636
}
660637

638+
// patch the new request to the request context
639+
reqCtx.Request = request
640+
newCtx = AppendToRequestContext(ctx, reqCtx.Hostport, reqCtx.Request)
641+
661642
return
662643
}
663644

@@ -771,7 +752,7 @@ func (r *mitmProxyHandler) serveHTTP2Handler(ctx context.Context, dstConn net.Co
771752
if req.URL.Host == "" {
772753
req.URL.Host = req.Host
773754
}
774-
ctx = AppendToRequestContext(ctx, reqCtx.Connect, reqCtx.Hostport, req)
755+
ctx = AppendToRequestContext(ctx, reqCtx.Hostport, req)
775756
response, err := r.roundTripWithContext(ctx, dstConn)
776757
if err != nil {
777758
r.handleError(ErrorContext{

0 commit comments

Comments
 (0)