Skip to content

Commit bba783a

Browse files
committed
enhance: Add base context for h2 stream connections
1 parent 56cefbe commit bba783a

3 files changed

Lines changed: 40 additions & 12 deletions

File tree

examples/dumper/main.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,15 @@ func main() {
132132
)
133133
}
134134

135+
ctx, cancel := context.WithCancel(context.Background())
136+
135137
handler, err := mitmpgo.NewMitmProxyHandler(
136138
mitmpgo.WithCACertPath(caCertPath),
137139
mitmpgo.WithCAKeyPath(caKeyPath),
138140
mitmpgo.WithHTTPInterceptor(httpInterceptor),
139141
mitmpgo.WithWebsocketInterceptor(websocketInterceptor),
140142
mitmpgo.WithErrorHandler(errHandler),
143+
mitmpgo.WithStreamBaseContext(ctx),
141144
// mitmpgo.WithClientCert("127.0.0.1", mitmpgo.ClientCert{
142145
// CertPath: "certs/client.crt",
143146
// KeyPath: "certs/client.key",
@@ -158,6 +161,7 @@ func main() {
158161

159162
listenAddr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
160163
var closeFn func()
164+
161165
switch mitmMode {
162166
case "socks5":
163167
ln, err := net.Listen("tcp", listenAddr)
@@ -172,14 +176,15 @@ func main() {
172176
return
173177
}
174178
go func() {
175-
handler.ServeSOCKS5(context.Background(), conn)
179+
handler.ServeSOCKS5(ctx, conn)
176180
}()
177181
}
178182
}()
179183
default:
180184
server := &http.Server{
181-
Addr: listenAddr,
182-
Handler: handler,
185+
Addr: listenAddr,
186+
Handler: handler,
187+
BaseContext: func(l net.Listener) context.Context { return ctx },
183188
}
184189
closeFn = func() { server.Close() }
185190
go func() {
@@ -197,7 +202,8 @@ func main() {
197202
handler.Cleanup()
198203
slog.Info("exit")
199204
closeFn()
200-
time.Sleep(time.Millisecond * 100)
205+
cancel()
206+
time.Sleep(time.Millisecond * 500)
201207
}
202208

203209
func httpInterceptor(ctx context.Context, req *http.Request, invoker mitmpgo.HTTPDelegatedInvoker) (*http.Response, error) {

mitm.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ func (r *mitmProxyHandler) handleTunnelRequest(ctx context.Context, consumedRequ
660660
// and finally we only need to get the [http.Request] and process the [http.ResponseWriter].
661661
// Early process http2
662662
if state.NegotiatedProtocol == http2.NextProtoTLS {
663-
newCtx, cancel := context.WithCancel(context.Background())
663+
newCtx, cancel := context.WithCancel(r.streamBaseCtx)
664664
go func() {
665665
connCtx.local.waitClose()
666666
cancel()
@@ -673,7 +673,7 @@ func (r *mitmProxyHandler) handleTunnelRequest(ctx context.Context, consumedRequ
673673
}
674674
}
675675

676-
ctx, earlyDone, isWsUpgrade, err := r.distinguishHTTPRequest(ctx, srcConn, dstConn, tlsRequest)
676+
ctx, earlyDone, isWsUpgrade, err := r.distinguishHTTPRequest(ctx, srcConn, tlsRequest)
677677
if err != nil || earlyDone {
678678
return
679679
}
@@ -691,7 +691,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
691691
return false, err
692692
}
693693
connCtx := ctx.Value(connContextKey).(*biConnContext)
694-
newCtx, cancel := context.WithCancel(context.Background())
694+
newCtx, cancel := context.WithCancel(r.streamBaseCtx)
695695
go func() {
696696
connCtx.local.waitClose()
697697
cancel()
@@ -711,7 +711,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
711711
return false, err
712712
}
713713
connCtx := ctx.Value(connContextKey).(*biConnContext)
714-
newCtx, cancel := context.WithCancel(context.Background())
714+
newCtx, cancel := context.WithCancel(r.streamBaseCtx)
715715
go func() {
716716
connCtx.local.waitClose()
717717
cancel()
@@ -727,7 +727,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
727727
return false, nil
728728
}
729729

730-
func (r *mitmProxyHandler) distinguishHTTPRequest(ctx context.Context, srcConn, dstConn net.Conn, tlsRequest bool) (newCtx context.Context, earlyDone bool, upgrade bool, retErr error) {
730+
func (r *mitmProxyHandler) distinguishHTTPRequest(ctx context.Context, srcConn net.Conn, tlsRequest bool) (newCtx context.Context, earlyDone bool, upgrade bool, retErr error) {
731731
reqCtx, _ := FromRequestContext(ctx)
732732

733733
// Read the http request for https/wss via tls tunnel
@@ -859,13 +859,18 @@ func (r *mitmProxyHandler) relayConnForWS(ctx context.Context, srcConn, dstConn
859859
}
860860

861861
errCh := make(chan error, 2)
862-
relayWSMessage := func(dir WSDirection, src, dst *websocket.Conn) {
862+
relayWSMessage := func(ctx context.Context, dir WSDirection, src, dst *websocket.Conn) {
863863
defer func() {
864864
if fw != nil {
865865
fw.close()
866866
}
867867
}()
868868
for {
869+
select {
870+
case <-ctx.Done():
871+
return
872+
default:
873+
}
869874
msgType, buffer, err := readBufferFromWSConn(src)
870875
if err != nil {
871876
errCh <- err
@@ -885,8 +890,8 @@ func (r *mitmProxyHandler) relayConnForWS(ctx context.Context, srcConn, dstConn
885890
}
886891
}
887892
}
888-
go relayWSMessage(Send, wsSrcConn, wsDstConn)
889-
go relayWSMessage(Receive, wsDstConn, wsSrcConn)
893+
go relayWSMessage(ctx, Send, wsSrcConn, wsDstConn)
894+
go relayWSMessage(ctx, Receive, wsDstConn, wsSrcConn)
890895
err = <-errCh
891896
cancel(err)
892897
return

option.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mitmpgo
22

33
import (
4+
"context"
45
"crypto/x509"
56
"net"
67
"time"
@@ -23,6 +24,8 @@ func (f OptionFunc) apply(o *options) { f(o) }
2324

2425
// options holds all configuration parameters for the MITM proxy handler.
2526
type options struct {
27+
streamBaseCtx context.Context // Stream Base Context for h2 connection
28+
2629
proxy string // Upstream proxy URL (e.g., "http://127.0.0.1:8080")
2730
caCertPath string // Path to the CA certificate file for TLS interception
2831
caKeyPath string // Path to the CA private key file for TLS interception
@@ -60,13 +63,27 @@ func newOptions(opt ...Option) *options {
6063
options := &options{
6164
dialer: &net.Dialer{Timeout: 15 * time.Second},
6265
wsMaxFramesPerForward: 2048,
66+
streamBaseCtx: context.Background(),
6367
}
6468
for _, o := range opt {
6569
o.apply(options)
6670
}
6771
return options
6872
}
6973

74+
// WithStreamBaseContext configures h2 connection stream base context.
75+
//
76+
// Example:
77+
//
78+
// handler, err := NewMitmProxyHandler(
79+
// WithStreamBaseContext(context.Background()),
80+
// )
81+
func WithStreamBaseContext(baseCtx context.Context) Option {
82+
return OptionFunc(func(o *options) {
83+
o.streamBaseCtx = baseCtx
84+
})
85+
}
86+
7087
// WithProxy configures an upstream proxy server for outbound connections.
7188
//
7289
// The proxy parameter should be a URL in one of these formats:

0 commit comments

Comments
 (0)