From fcc2f965e4d1ea5eafdd04b51ca389c627bbfb81 Mon Sep 17 00:00:00 2001 From: Pierre Mdawar Date: Sat, 12 Apr 2025 13:29:37 +0300 Subject: [PATCH 1/4] fix: websocket keepalive goroutine leak --- v2/delivery/websocket.go | 18 +++++++++++------- v2/futures/websocket.go | 18 +++++++++++------- v2/options/websocket.go | 18 +++++++++++------- v2/websocket.go | 18 +++++++++++------- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/v2/delivery/websocket.go b/v2/delivery/websocket.go index 6d0f1f0b8..0fd84642c 100644 --- a/v2/delivery/websocket.go +++ b/v2/delivery/websocket.go @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/futures/websocket.go b/v2/futures/websocket.go index 7d54b7404..82d02dc48 100644 --- a/v2/futures/websocket.go +++ b/v2/futures/websocket.go @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/options/websocket.go b/v2/options/websocket.go index 6d3d4d54c..a6453abd4 100644 --- a/v2/options/websocket.go +++ b/v2/options/websocket.go @@ -56,7 +56,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -84,9 +84,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -106,12 +104,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/websocket.go b/v2/websocket.go index 864443787..5ff2b6c81 100644 --- a/v2/websocket.go +++ b/v2/websocket.go @@ -58,7 +58,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don if WebsocketKeepalive { // This function overwrites the default ping frame handler // sent by the websocket API server - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a @@ -87,9 +87,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -109,12 +107,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() From 051e4600794062e2455637503d40130ab0da4bc6 Mon Sep 17 00:00:00 2001 From: Pierre Mdawar Date: Sat, 12 Apr 2025 17:16:50 +0300 Subject: [PATCH 2/4] refactor: remove unnecessary assignment and move ticker variable --- v2/common/websocket/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/v2/common/websocket/client.go b/v2/common/websocket/client.go index f2fadf0d3..fbf7dd00c 100644 --- a/v2/common/websocket/client.go +++ b/v2/common/websocket/client.go @@ -229,7 +229,7 @@ func (c *client) wait(timeout time.Duration) { // handleReconnect waits for reconnect signal and starts reconnect func (c *client) handleReconnect() { - for _ = range c.reconnectSignal { + for range c.reconnectSignal { c.debug("reconnect: received signal") b := &backoff.Backoff{ @@ -389,8 +389,6 @@ func (c *connection) RestoreConnection() (Connection, error) { // keepAlive handles ping-pong for connection func (c *connection) keepAlive(timeout time.Duration) { - ticker := time.NewTicker(timeout) - c.updateLastResponse() c.conn.SetPongHandler(func(msg string) error { @@ -399,7 +397,9 @@ func (c *connection) keepAlive(timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { err := c.ping() if err != nil { From f054ff6dd2ef7b94dbae8bd263004d67976903ef Mon Sep 17 00:00:00 2001 From: Pierre Mdawar Date: Sat, 12 Apr 2025 21:17:29 +0300 Subject: [PATCH 3/4] fix: stop the keepalive goroutine on websocket client read errors --- v2/common/websocket/client.go | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/v2/common/websocket/client.go b/v2/common/websocket/client.go index fbf7dd00c..f1daf4761 100644 --- a/v2/common/websocket/client.go +++ b/v2/common/websocket/client.go @@ -344,6 +344,7 @@ func NewConnection( initUnderlyingWsConnFn: initUnderlyingWsConnFn, keepaliveTimeout: keepaliveTimeout, isKeepAliveNeeded: isKeepAliveNeeded, + done: make(chan struct{}), } if isKeepAliveNeeded { @@ -362,6 +363,7 @@ type connection struct { initUnderlyingWsConnFn func() (*websocket.Conn, error) keepaliveTimeout time.Duration isKeepAliveNeeded bool + done chan struct{} } type Connection interface { @@ -379,7 +381,11 @@ func (c *connection) WriteMessage(messageType int, data []byte) error { // ReadMessage wrapper for conn.ReadMessage func (c *connection) ReadMessage() (int, []byte, error) { - return c.conn.ReadMessage() + msgType, msg, err := c.conn.ReadMessage() + if err != nil { + close(c.done) + } + return msgType, msg, err } // RestoreConnection recreates ws connection with the same underlying connection callback and keepalive timeout @@ -401,15 +407,19 @@ func (c *connection) keepAlive(timeout time.Duration) { defer ticker.Stop() for { - err := c.ping() - if err != nil { - return - } - - <-ticker.C - if c.isLastResponseOutdated(timeout) { - c.close() + select { + case <-c.done: return + case <-ticker.C: + err := c.ping() + if err != nil { + return + } + + if c.isLastResponseOutdated(timeout) { + c.close() + return + } } } }() @@ -442,10 +452,5 @@ func (c *connection) ping() error { defer c.connectionMu.Unlock() deadline := time.Now().Add(KeepAlivePingDeadline) - err := c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline) - if err != nil { - return err - } - - return nil + return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline) } From b100c39d37db907ecbc9252dcebbfbcfccf8c0e0 Mon Sep 17 00:00:00 2001 From: Pierre Mdawar Date: Tue, 15 Apr 2025 06:57:05 +0300 Subject: [PATCH 4/4] fix: use a context to stop the keepalive goroutine --- v2/common/websocket/client.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/v2/common/websocket/client.go b/v2/common/websocket/client.go index f1daf4761..182261da0 100644 --- a/v2/common/websocket/client.go +++ b/v2/common/websocket/client.go @@ -337,6 +337,8 @@ func NewConnection( return nil, err } + ctx, cancel := context.WithCancel(context.Background()) + wsConn := &connection{ conn: underlyingWsConn, connectionMu: sync.Mutex{}, @@ -344,7 +346,8 @@ func NewConnection( initUnderlyingWsConnFn: initUnderlyingWsConnFn, keepaliveTimeout: keepaliveTimeout, isKeepAliveNeeded: isKeepAliveNeeded, - done: make(chan struct{}), + ctx: ctx, + cancel: cancel, } if isKeepAliveNeeded { @@ -363,7 +366,8 @@ type connection struct { initUnderlyingWsConnFn func() (*websocket.Conn, error) keepaliveTimeout time.Duration isKeepAliveNeeded bool - done chan struct{} + ctx context.Context + cancel context.CancelFunc } type Connection interface { @@ -383,7 +387,7 @@ func (c *connection) WriteMessage(messageType int, data []byte) error { func (c *connection) ReadMessage() (int, []byte, error) { msgType, msg, err := c.conn.ReadMessage() if err != nil { - close(c.done) + c.cancel() } return msgType, msg, err } @@ -408,7 +412,7 @@ func (c *connection) keepAlive(timeout time.Duration) { for { select { - case <-c.done: + case <-c.ctx.Done(): return case <-ticker.C: err := c.ping()