@@ -660,7 +660,7 @@ func (r *mitmProxyHandler) handleTunnelRequest(ctx context.Context, consumedRequ
660660 // and finally we only need to get the [http.Request] and process the [http.ResponseWriter].
661661 // Early process http2
662662 if state .NegotiatedProtocol == http2 .NextProtoTLS {
663- newCtx , cancel := context .WithCancel (context . Background () )
663+ newCtx , cancel := context .WithCancel (r . streamBaseCtx )
664664 go func () {
665665 connCtx .local .waitClose ()
666666 cancel ()
@@ -673,7 +673,7 @@ func (r *mitmProxyHandler) handleTunnelRequest(ctx context.Context, consumedRequ
673673 }
674674 }
675675
676- ctx , earlyDone , isWsUpgrade , err := r .distinguishHTTPRequest (ctx , srcConn , dstConn , tlsRequest )
676+ ctx , earlyDone , isWsUpgrade , err := r .distinguishHTTPRequest (ctx , srcConn , tlsRequest )
677677 if err != nil || earlyDone {
678678 return
679679 }
@@ -691,7 +691,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
691691 return false , err
692692 }
693693 connCtx := ctx .Value (connContextKey ).(* biConnContext )
694- newCtx , cancel := context .WithCancel (context . Background () )
694+ newCtx , cancel := context .WithCancel (r . streamBaseCtx )
695695 go func () {
696696 connCtx .local .waitClose ()
697697 cancel ()
@@ -711,7 +711,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
711711 return false , err
712712 }
713713 connCtx := ctx .Value (connContextKey ).(* biConnContext )
714- newCtx , cancel := context .WithCancel (context . Background () )
714+ newCtx , cancel := context .WithCancel (r . streamBaseCtx )
715715 go func () {
716716 connCtx .local .waitClose ()
717717 cancel ()
@@ -727,7 +727,7 @@ func (r *mitmProxyHandler) handleH2CRequest(ctx context.Context, rw http.Respons
727727 return false , nil
728728}
729729
730- func (r * mitmProxyHandler ) distinguishHTTPRequest (ctx context.Context , srcConn , dstConn net.Conn , tlsRequest bool ) (newCtx context.Context , earlyDone bool , upgrade bool , retErr error ) {
730+ func (r * mitmProxyHandler ) distinguishHTTPRequest (ctx context.Context , srcConn net.Conn , tlsRequest bool ) (newCtx context.Context , earlyDone bool , upgrade bool , retErr error ) {
731731 reqCtx , _ := FromRequestContext (ctx )
732732
733733 // Read the http request for https/wss via tls tunnel
@@ -859,13 +859,18 @@ func (r *mitmProxyHandler) relayConnForWS(ctx context.Context, srcConn, dstConn
859859 }
860860
861861 errCh := make (chan error , 2 )
862- relayWSMessage := func (dir WSDirection , src , dst * websocket.Conn ) {
862+ relayWSMessage := func (ctx context. Context , dir WSDirection , src , dst * websocket.Conn ) {
863863 defer func () {
864864 if fw != nil {
865865 fw .close ()
866866 }
867867 }()
868868 for {
869+ select {
870+ case <- ctx .Done ():
871+ return
872+ default :
873+ }
869874 msgType , buffer , err := readBufferFromWSConn (src )
870875 if err != nil {
871876 errCh <- err
@@ -885,8 +890,8 @@ func (r *mitmProxyHandler) relayConnForWS(ctx context.Context, srcConn, dstConn
885890 }
886891 }
887892 }
888- go relayWSMessage (Send , wsSrcConn , wsDstConn )
889- go relayWSMessage (Receive , wsDstConn , wsSrcConn )
893+ go relayWSMessage (ctx , Send , wsSrcConn , wsDstConn )
894+ go relayWSMessage (ctx , Receive , wsDstConn , wsSrcConn )
890895 err = <- errCh
891896 cancel (err )
892897 return
0 commit comments