diff --git a/v2/common/websocket/client.go b/v2/common/websocket/client.go index f2fadf0d3..182261da0 100644 --- a/v2/common/websocket/client.go +++ b/v2/common/websocket/client.go @@ -229,7 +229,7 @@ func (c *client) wait(timeout time.Duration) { // handleReconnect waits for reconnect signal and starts reconnect func (c *client) handleReconnect() { - for _ = range c.reconnectSignal { + for range c.reconnectSignal { c.debug("reconnect: received signal") b := &backoff.Backoff{ @@ -337,6 +337,8 @@ func NewConnection( return nil, err } + ctx, cancel := context.WithCancel(context.Background()) + wsConn := &connection{ conn: underlyingWsConn, connectionMu: sync.Mutex{}, @@ -344,6 +346,8 @@ func NewConnection( initUnderlyingWsConnFn: initUnderlyingWsConnFn, keepaliveTimeout: keepaliveTimeout, isKeepAliveNeeded: isKeepAliveNeeded, + ctx: ctx, + cancel: cancel, } if isKeepAliveNeeded { @@ -362,6 +366,8 @@ type connection struct { initUnderlyingWsConnFn func() (*websocket.Conn, error) keepaliveTimeout time.Duration isKeepAliveNeeded bool + ctx context.Context + cancel context.CancelFunc } type Connection interface { @@ -379,7 +385,11 @@ func (c *connection) WriteMessage(messageType int, data []byte) error { // ReadMessage wrapper for conn.ReadMessage func (c *connection) ReadMessage() (int, []byte, error) { - return c.conn.ReadMessage() + msgType, msg, err := c.conn.ReadMessage() + if err != nil { + c.cancel() + } + return msgType, msg, err } // RestoreConnection recreates ws connection with the same underlying connection callback and keepalive timeout @@ -389,8 +399,6 @@ func (c *connection) RestoreConnection() (Connection, error) { // keepAlive handles ping-pong for connection func (c *connection) keepAlive(timeout time.Duration) { - ticker := time.NewTicker(timeout) - c.updateLastResponse() c.conn.SetPongHandler(func(msg string) error { @@ -399,17 +407,23 @@ func (c *connection) keepAlive(timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() - for { - err := c.ping() - if err != nil { - return - } - <-ticker.C - if c.isLastResponseOutdated(timeout) { - c.close() + for { + select { + case <-c.ctx.Done(): return + case <-ticker.C: + err := c.ping() + if err != nil { + return + } + + if c.isLastResponseOutdated(timeout) { + c.close() + return + } } } }() @@ -442,10 +456,5 @@ func (c *connection) ping() error { defer c.connectionMu.Unlock() deadline := time.Now().Add(KeepAlivePingDeadline) - err := c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline) - if err != nil { - return err - } - - return nil + return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline) } diff --git a/v2/delivery/websocket.go b/v2/delivery/websocket.go index 6d0f1f0b8..0fd84642c 100644 --- a/v2/delivery/websocket.go +++ b/v2/delivery/websocket.go @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/futures/websocket.go b/v2/futures/websocket.go index 7d54b7404..82d02dc48 100644 --- a/v2/futures/websocket.go +++ b/v2/futures/websocket.go @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/options/websocket.go b/v2/options/websocket.go index 6d3d4d54c..a6453abd4 100644 --- a/v2/options/websocket.go +++ b/v2/options/websocket.go @@ -56,7 +56,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don // closed by the client. defer close(doneC) if WebsocketKeepalive { - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a // separate goroutine because ReadMessage is a blocking @@ -84,9 +84,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -106,12 +104,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }() diff --git a/v2/websocket.go b/v2/websocket.go index 864443787..5ff2b6c81 100644 --- a/v2/websocket.go +++ b/v2/websocket.go @@ -58,7 +58,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don if WebsocketKeepalive { // This function overwrites the default ping frame handler // sent by the websocket API server - keepAlive(c, WebsocketTimeout) + keepAlive(doneC, c, WebsocketTimeout) } // Wait for the stopC channel to be closed. We do that in a @@ -87,9 +87,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don return } -func keepAlive(c *websocket.Conn, timeout time.Duration) { - ticker := time.NewTicker(timeout) - +func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) { lastResponse := time.Now() c.SetPingHandler(func(pingData string) error { @@ -109,12 +107,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) { }) go func() { + ticker := time.NewTicker(timeout) defer ticker.Stop() + for { - <-ticker.C - if time.Since(lastResponse) > timeout { - c.Close() + select { + case <-done: return + case <-ticker.C: + if time.Since(lastResponse) > timeout { + c.Close() + return + } } } }()