Skip to content

Commit 16bf9d5

Browse files
committed
Support websocket frames watcher in interceptor
1 parent e88cac6 commit 16bf9d5

9 files changed

Lines changed: 105 additions & 103 deletions

File tree

README.md

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,20 @@ handler, err := mitmpgo.NewMitmProxyHandler(
8888
### With WebSocket Interceptor
8989

9090
```go
91-
websocketInterceptor := func(ctx context.Context, dir metadata.WSDirection, msgType int, b *buf.Buffer, req *http.Request, invoker mitmpgo.WebsocketDelegatedInvoker) error {
91+
websocketInterceptor := func(ctx context.Context, req *http.Request, rsp *http.Response, fw mitmpgo.WebsocketFramesWatcher) {
9292
// Log WebSocket messages
93-
fmt.Printf("WS [%s] %s: %d bytes\n", dir, req.URL, b.Len())
94-
95-
// Forward the message
96-
return invoker.Invoke(msgType, b)
93+
log.Printf("WS url: %s", req.URL.String())
94+
95+
for frame := range fw.GetFrame() {
96+
dir := frame.Direction()
97+
msgType := frame.MessageType()
98+
dataBuf := frame.DataBuffer()
99+
log.Printf("---> %s %d %s", dir, msgType, dataBuf.String())
100+
if err := frame.Invoke(); err != nil {
101+
log.Printf("frame invoke error: %v", err)
102+
}
103+
frame.Release()
104+
}
97105
}
98106

99107
handler, err := mitmpgo.NewMitmProxyHandler(
@@ -163,6 +171,9 @@ mitmpgo.WithCertCachePool(2048, 30, 15)
163171
mitmpgo.WithDialer(&net.Dialer{
164172
Timeout: 30 * time.Second,
165173
})
174+
175+
// Maximum channel size of WebSocket frames
176+
mitmpgo.WithMaxWebsocketFramesPerForward(4096)
166177
```
167178

168179
### Interceptor Options
@@ -177,9 +188,6 @@ mitmpgo.WithWebsocketInterceptor(websocketInterceptor)
177188
// Chain multiple HTTP interceptors (executed in order)
178189
mitmpgo.WithChainHTTPInterceptor(interceptor1, interceptor2, interceptor3)
179190

180-
// Chain multiple WebSocket interceptors (executed in order)
181-
mitmpgo.WithChainWebsocketInterceptor(wsInterceptor1, wsInterceptor2)
182-
183191
// Set error handler
184192
mitmpgo.WithErrorHandler(func(ec mitmpgo.ErrorContext) {
185193
log.Printf("Error: %v", ec.Error)

buf/buffer.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ func (b *Buffer) String() string {
339339
if b == nil {
340340
return "<nil>"
341341
}
342+
if b.Len() == 0 {
343+
return ""
344+
}
342345
return unsafe.String(&b.Bytes()[0], b.Len())
343346
}
344347

bufconn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55
"io"
66
"net"
77

8-
"github.com/gorilla/websocket"
98
"github.com/josexy/mitmpgo/buf"
9+
"github.com/josexy/websocket"
1010
)
1111

1212
type bufConn struct {

examples/chain-interceptors/main.go

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ import (
77
"log/slog"
88
"net/http"
99
"os"
10-
"time"
1110

1211
"github.com/josexy/mitmpgo"
13-
"github.com/josexy/mitmpgo/buf"
1412
)
1513

1614
func main() {
@@ -29,7 +27,6 @@ func main() {
2927
mitmpgo.WithCACertPath(caCertPath),
3028
mitmpgo.WithCAKeyPath(caKeyPath),
3129
mitmpgo.WithChainHTTPInterceptor(httpInterceptor1, httpInterceptor2, httpInterceptor3),
32-
mitmpgo.WithChainWebsocketInterceptor(websocketInterceptor1, websocketInterceptor2),
3330
)
3431
if err != nil {
3532
panic(err)
@@ -68,26 +65,3 @@ func httpInterceptor3(ctx context.Context, req *http.Request, invoker mitmpgo.HT
6865
slog.Debug("httpInterceptor3 after", slog.String("status", rsp.Status), slog.String("protocol", rsp.Proto))
6966
return rsp, err
7067
}
71-
72-
func websocketInterceptor1(ctx context.Context, dir mitmpgo.WSDirection, msgType int, b *buf.Buffer, req *http.Request, invoker mitmpgo.WebsocketDelegatedInvoker) error {
73-
slog.Debug("websocketInterceptor1 before", slog.String("dir", dir.String()), slog.Int("msgType", msgType), slog.Int("len", b.Len()))
74-
if dir == mitmpgo.Send {
75-
b.WriteString("->" + time.Now().Format(time.DateTime))
76-
}
77-
err := invoker.Invoke(msgType, b)
78-
if err != nil {
79-
return err
80-
}
81-
slog.Debug("websocketInterceptor1 after")
82-
return nil
83-
}
84-
85-
func websocketInterceptor2(ctx context.Context, dir mitmpgo.WSDirection, msgType int, b *buf.Buffer, req *http.Request, invoker mitmpgo.WebsocketDelegatedInvoker) error {
86-
slog.Debug("websocketInterceptor2 before", slog.String("dir", dir.String()), slog.Int("msgType", msgType), slog.Int("len", b.Len()))
87-
err := invoker.Invoke(msgType, b)
88-
if err != nil {
89-
return err
90-
}
91-
slog.Debug("websocketInterceptor2 after")
92-
return err
93-
}

examples/dumper/main.go

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ import (
1313
"log/slog"
1414
"net"
1515
"net/http"
16+
"net/http/httputil"
1617
"os"
1718
"os/signal"
1819
"strings"
1920
"syscall"
2021
"time"
2122

2223
"github.com/josexy/mitmpgo"
23-
"github.com/josexy/mitmpgo/buf"
2424
"github.com/josexy/mitmpgo/metadata"
2525
)
2626

@@ -37,6 +37,9 @@ type bodyDecoder struct {
3737
}
3838

3939
func newBodyDecoder(r io.ReadCloser, encoding string, chunkType int) (io.ReadCloser, error) {
40+
if r == http.NoBody { // no body and no need to replace it
41+
return r, nil
42+
}
4043
if encoding == "" {
4144
return newChunkBodyReader(r, CHUNK_SIZE, chunkType), nil
4245
}
@@ -102,7 +105,7 @@ func (r *chunkBodyReader) Read(p []byte) (n int, err error) {
102105
// fmt.Printf("--> hex dump(chunk size/data size: %d/%d):\n%s\n", r.N, n, hex.Dump(p[:n]))
103106
}
104107
if err == io.EOF {
105-
fmt.Printf("<<-- [%d]full data dump (%d bytes):\n%s\n", r.chunkType, r.buf.Len(), r.buf.Bytes())
108+
fmt.Printf("<<-- [%d]full data dump (%d bytes):\n", r.chunkType, r.buf.Len())
106109
}
107110
return
108111
}
@@ -147,6 +150,7 @@ func main() {
147150
// mitmpgo.WithDisableProxy(),
148151
// mitmpgo.WithDisableHTTP2(),
149152
// mitmpgo.WithSkipVerifySSLFromServer(),
153+
// mitmpgo.WithMaxWebsocketFramesPerForward(4096),
150154
)
151155
if err != nil {
152156
panic(err)
@@ -207,7 +211,6 @@ func httpInterceptor(ctx context.Context, req *http.Request, invoker mitmpgo.HTT
207211
slog.String("host", req.Host),
208212
slog.String("proto", req.Proto),
209213
slog.String("method", req.Method),
210-
slog.Bool("tls", req.TLS != nil),
211214
slog.String("url", req.URL.String()),
212215
slog.Any("headers", map[string][]string(req.Header)),
213216
)
@@ -257,17 +260,28 @@ func httpInterceptor(ctx context.Context, req *http.Request, invoker mitmpgo.HTT
257260
return rsp, err
258261
}
259262

260-
func websocketInterceptor(ctx context.Context, dir mitmpgo.WSDirection, msgType int, b *buf.Buffer, req *http.Request, wdi mitmpgo.WebsocketDelegatedInvoker) error {
263+
func websocketInterceptor(ctx context.Context, req *http.Request, rsp *http.Response, fw mitmpgo.WebsocketFramesWatcher) {
261264
_md, _ := metadata.FromContext(ctx)
262265
md := _md.MD()
263-
slog.Debug("websocket",
266+
slog.Debug("request",
267+
slog.Bool("stream_body", md.StreamBody),
264268
slog.String("source", md.SourceAddr.String()),
265269
slog.String("destination", md.DestinationAddr.String()),
266270
slog.String("hostport", md.RequestHostport),
267-
slog.String("uri", req.URL.String()),
268-
slog.String("direction", dir.String()),
269-
slog.Int("msg_type", msgType),
271+
slog.String("host", req.Host),
272+
slog.String("proto", req.Proto),
273+
slog.String("method", req.Method),
274+
slog.String("url", req.URL.String()),
275+
slog.Int("status_code", rsp.StatusCode),
276+
slog.Any("request_headers", map[string][]string(req.Header)),
277+
slog.Any("response_headers", map[string][]string(rsp.Header)),
270278
)
279+
280+
data, _ := httputil.DumpRequest(req, false)
281+
fmt.Printf("%s\n", string(data))
282+
data, _ = httputil.DumpResponse(rsp, false)
283+
fmt.Printf("%s\n", string(data))
284+
271285
if md.TLSState != nil {
272286
slog.Debug("tls state",
273287
slog.String("server_name", md.TLSState.ServerName),
@@ -292,10 +306,25 @@ func websocketInterceptor(ctx context.Context, dir mitmpgo.WSDirection, msgType
292306
slog.Any("ip", md.ServerCertificate.IPAddresses),
293307
)
294308
}
295-
// if md.Direction == metadata.Receive {
296-
// b.WriteString(time.Now().String())
297-
// }
298-
return wdi.Invoke(msgType, b)
309+
310+
for {
311+
select {
312+
case <-ctx.Done():
313+
return
314+
case frame, ok := <-fw.GetFrame():
315+
if !ok {
316+
return
317+
}
318+
dir := frame.Direction()
319+
msgType := frame.MessageType()
320+
dataBuf := frame.DataBuffer()
321+
fmt.Printf("---> %s %d %s\n", dir, msgType, dataBuf.String())
322+
if err := frame.Invoke(); err != nil {
323+
slog.Error("frame invoke error", slog.String("error", err.Error()))
324+
}
325+
frame.Release()
326+
}
327+
}
299328
}
300329

301330
func getDecodedReader(r io.Reader, encoding string) (io.ReadCloser, error) {

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module github.com/josexy/mitmpgo
33
go 1.25.4
44

55
require (
6-
github.com/gorilla/websocket v1.5.3
7-
golang.org/x/net v0.48.0
6+
github.com/josexy/websocket v0.0.0-20260219083038-11b2ba10886b
7+
golang.org/x/net v0.50.0
88
)
99

10-
require golang.org/x/text v0.32.0 // indirect
10+
require golang.org/x/text v0.34.0 // indirect

go.sum

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
2-
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
3-
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
4-
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
5-
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
6-
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
1+
github.com/josexy/websocket v0.0.0-20260219083038-11b2ba10886b h1:ahmk86ulmaZRnXZFfLWPahG0FCcBw6Bco8NML+4UCNU=
2+
github.com/josexy/websocket v0.0.0-20260219083038-11b2ba10886b/go.mod h1:E1y5c7BPj8laYUGMExE3snEx4mLawJGeVztAKQQeNcg=
3+
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
4+
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
5+
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
6+
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=

interceptor.go

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ func (d WSDirection) String() string {
2828
}
2929
}
3030

31+
type WsFrame interface {
32+
Direction() WSDirection
33+
MessageType() int
34+
DataBuffer() *buf.Buffer
35+
36+
// Forward the websocket message and release the data buffer
37+
Invoke() error
38+
// MUST be called to release the data buffer
39+
Release()
40+
}
41+
3142
type (
3243
HTTPDelegatedInvoker interface {
3344
Invoke(request *http.Request) (*http.Response, error)
@@ -36,14 +47,17 @@ type (
3647
WebsocketDelegatedInvoker interface {
3748
Invoke(msgType int, dataPtr *buf.Buffer) error
3849
}
50+
WebsocketFramesWatcher interface {
51+
GetFrame() <-chan WsFrame
52+
}
3953
)
4054

4155
type (
4256
HTTPDelegatedInvokerFunc func(*http.Request) (*http.Response, error)
4357
WebsocketDelegatedInvokerFunc func(int, *buf.Buffer) error
4458

4559
HTTPInterceptor func(context.Context, *http.Request, HTTPDelegatedInvoker) (*http.Response, error)
46-
WebsocketInterceptor func(context.Context, WSDirection, int, *buf.Buffer, *http.Request, WebsocketDelegatedInvoker) error
60+
WebsocketInterceptor func(context.Context, *http.Request, *http.Response, WebsocketFramesWatcher)
4761
)
4862

4963
func (f HTTPDelegatedInvokerFunc) Invoke(r *http.Request) (*http.Response, error) { return f(r) }
@@ -69,18 +83,3 @@ func getChainHTTPInterceptor(interceptors []HTTPInterceptor, curr int, ctx conte
6983
return interceptors[curr+1](ctx, r, getChainHTTPInterceptor(interceptors, curr+1, ctx, finalInvoker))
7084
})
7185
}
72-
73-
func chainWebsocketInterceptors(interceptors []WebsocketInterceptor) WebsocketInterceptor {
74-
return func(ctx context.Context, d WSDirection, i int, b *buf.Buffer, r *http.Request, wdi WebsocketDelegatedInvoker) error {
75-
return interceptors[0](ctx, d, i, b, r, getChainWebsocketInterceptor(interceptors, 0, ctx, d, r, wdi))
76-
}
77-
}
78-
79-
func getChainWebsocketInterceptor(interceptors []WebsocketInterceptor, curr int, ctx context.Context, dir WSDirection, req *http.Request, finalInvoker WebsocketDelegatedInvoker) WebsocketDelegatedInvoker {
80-
if curr == len(interceptors)-1 {
81-
return finalInvoker
82-
}
83-
return WebsocketDelegatedInvokerFunc(func(i int, b *buf.Buffer) error {
84-
return interceptors[curr+1](ctx, dir, i, b, req, getChainWebsocketInterceptor(interceptors, curr+1, ctx, dir, req, finalInvoker))
85-
})
86-
}

option.go

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ type options struct {
3434
rootCAs []string // Paths to additional root CA certificate files
3535
dialer *net.Dialer // Custom dialer for outbound connections
3636

37+
wsMaxFramesPerForward int // Max frames channel size per single websocket forward
38+
3739
clientCerts map[string]ClientCert // Client certificate configuration
3840

3941
// Certificate cache pool configuration
@@ -50,14 +52,14 @@ type options struct {
5052
httpInt HTTPInterceptor
5153
wsInt WebsocketInterceptor
5254
chainHttpInts []HTTPInterceptor
53-
chainWsInts []WebsocketInterceptor
5455
}
5556

5657
// newOptions creates a new options instance with default values.
5758
// Default dialer timeout is 15 seconds.
5859
func newOptions(opt ...Option) *options {
5960
options := &options{
60-
dialer: &net.Dialer{Timeout: 15 * time.Second},
61+
dialer: &net.Dialer{Timeout: 15 * time.Second},
62+
wsMaxFramesPerForward: 2048,
6163
}
6264
for _, o := range opt {
6365
o.apply(options)
@@ -258,6 +260,22 @@ func WithCertCachePool(capacity, intervalSecond, expireSecond int) Option {
258260
})
259261
}
260262

263+
// WithMaxWebsocketFramesPerForward specifies the maximum channel size of frames that can be buffered
264+
// per single websocket forward.
265+
//
266+
// If not specified, default value(2048) is used.
267+
//
268+
// Example:
269+
//
270+
// handler, err := NewMitmProxyHandler(
271+
// WithMaxWebsocketFramesPerForward(2048),
272+
// )
273+
func WithMaxWebsocketFramesPerForward(maxFrames int) Option {
274+
return OptionFunc(func(o *options) {
275+
o.wsMaxFramesPerForward = maxFrames
276+
})
277+
}
278+
261279
// WithIncludeHosts specifies a whitelist of hosts that should be intercepted.
262280
// Only traffic to these hosts will be intercepted; all other traffic will pass through
263281
// without interception (passthrough mode).
@@ -436,32 +454,3 @@ func WithChainHTTPInterceptor(interceptors ...HTTPInterceptor) Option {
436454
o.chainHttpInts = append(o.chainHttpInts, interceptors...)
437455
})
438456
}
439-
440-
// WithChainWebsocketInterceptor chains multiple WebSocket interceptors together.
441-
// Interceptors are executed in the order they are provided, forming a middleware chain.
442-
// Each interceptor can inspect/modify the WebSocket message, call the next interceptor,
443-
// and handle the message forwarding. And The final interceptor will forwards the message.
444-
//
445-
// Example:
446-
//
447-
// loggingInterceptor := func(ctx context.Context, dir metadata.WSDirection, msgType int, data *buf.Buffer, req *http.Request, invoker WebsocketDelegatedInvoker) error {
448-
// log.Printf("[%s] Message: type=%d, size=%d", dir, msgType, data.Len())
449-
// return invoker.Invoke(msgType, data)
450-
// }
451-
//
452-
// filterInterceptor := func(ctx context.Context, dir metadata.WSDirection, msgType int, data *buf.Buffer, req *http.Request, invoker WebsocketDelegatedInvoker) error {
453-
// // Drop ping messages
454-
// if msgType == websocket.PingMessage {
455-
// return nil
456-
// }
457-
// return invoker.Invoke(msgType, data)
458-
// }
459-
//
460-
// handler, err := NewMitmProxyHandler(
461-
// WithChainWebsocketInterceptor(loggingInterceptor, filterInterceptor),
462-
// )
463-
func WithChainWebsocketInterceptor(interceptors ...WebsocketInterceptor) Option {
464-
return OptionFunc(func(o *options) {
465-
o.chainWsInts = append(o.chainWsInts, interceptors...)
466-
})
467-
}

0 commit comments

Comments
 (0)