@@ -42,10 +42,10 @@ type Client struct {
4242 conn * websocket.Conn
4343 connCtx context.Context
4444 connCtxCancel context.CancelFunc
45+ wg sync.WaitGroup
4546 lock sync.RWMutex
4647 subscriptionByRequestID map [uint64 ]* Subscription
4748 subscriptionByWSSubID map [uint64 ]* Subscription
48- reconnectOnErr bool
4949 shortID bool
5050}
5151
@@ -105,10 +105,13 @@ func ConnectWithOptions(ctx context.Context, rpcEndpoint string, opt *Options) (
105105 }
106106
107107 c .connCtx , c .connCtxCancel = context .WithCancel (context .Background ())
108+ c .wg .Add (2 )
108109 go func () {
110+ defer c .wg .Done ()
109111 c .conn .SetReadDeadline (time .Now ().Add (pongWait ))
110112 c .conn .SetPongHandler (func (string ) error { c .conn .SetReadDeadline (time .Now ().Add (pongWait )); return nil })
111113 ticker := time .NewTicker (pingPeriod )
114+ defer ticker .Stop ()
112115 for {
113116 select {
114117 case <- c .connCtx .Done ():
@@ -132,22 +135,27 @@ func (c *Client) sendPing() {
132135 }
133136}
134137
138+ // Close cancels the connection context, closes the underlying websocket
139+ // connection, and waits for background goroutines to finish.
135140func (c * Client ) Close () {
136141 c .lock .Lock ()
137- defer c .lock .Unlock ()
138142 c .connCtxCancel ()
139143 c .conn .Close ()
144+ c .lock .Unlock ()
145+ c .wg .Wait ()
140146}
141147
142148func (c * Client ) receiveMessages () {
149+ defer c .wg .Done ()
150+ defer c .closeAllSubscription (ErrSubscriptionClosed )
151+
143152 for {
144153 select {
145154 case <- c .connCtx .Done ():
146155 return
147156 default :
148157 _ , message , err := c .conn .ReadMessage ()
149158 if err != nil {
150- c .closeAllSubscription (err )
151159 return
152160 }
153161 c .handleMessage (message )
@@ -218,7 +226,6 @@ func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
218226 zap .Uint64 ("request_id" , requestID ),
219227 zap .Int ("subscription_count" , len (c .subscriptionByWSSubID )),
220228 )
221- return
222229}
223230
224231func (c * Client ) handleSubscriptionMessage (subID uint64 , message []byte ) {
@@ -239,34 +246,43 @@ func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
239246 // Decode the message using the subscription-provided decoderFunc.
240247 result , err := sub .decoderFunc (message )
241248 if err != nil {
242- fmt .Println ("*****************************" )
243249 c .closeSubscription (sub .req .ID , fmt .Errorf ("unable to decode client response: %w" , err ))
244250 return
245251 }
246252
247253 // this cannot be blocking or else
248254 // we will no read any other message
249255 if len (sub .stream ) >= cap (sub .stream ) {
250- zlog .Warn ("closing ws client subscription... not consuming fast en ought " ,
256+ zlog .Warn ("closing ws client subscription... not consuming fast enough " ,
251257 zap .Uint64 ("request_id" , sub .req .ID ),
252258 )
253259 c .closeSubscription (sub .req .ID , fmt .Errorf ("reached channel max capacity %d" , len (sub .stream )))
254260 return
255261 }
256262
257- sub .mutex .Lock ()
258- defer sub .mutex .Unlock ()
263+ sub .mu .Lock ()
259264 if ! sub .closed {
260265 sub .stream <- result
261266 }
267+ sub .mu .Unlock ()
262268}
263269
264270func (c * Client ) closeAllSubscription (err error ) {
265271 c .lock .Lock ()
266272 defer c .lock .Unlock ()
267273
268274 for _ , sub := range c .subscriptionByRequestID {
269- sub .err <- err
275+ sub .mu .Lock ()
276+ if ! sub .closed {
277+ select {
278+ case sub .err <- err :
279+ default :
280+ }
281+ sub .closed = true
282+ close (sub .stream )
283+ close (sub .err )
284+ }
285+ sub .mu .Unlock ()
270286 }
271287
272288 c .subscriptionByRequestID = map [uint64 ]* Subscription {}
@@ -282,12 +298,21 @@ func (c *Client) closeSubscription(reqID uint64, err error) {
282298 return
283299 }
284300
285- sub .err <- err
301+ sub .mu .Lock ()
302+ if ! sub .closed {
303+ select {
304+ case sub .err <- err :
305+ default :
306+ }
307+ sub .closed = true
308+ close (sub .stream )
309+ close (sub .err )
310+ }
311+ sub .mu .Unlock ()
286312
287- err = c .unsubscribe (sub .subID , sub .unsubscribeMethod )
288- if err != nil {
313+ if e := c .unsubscribe (sub .subID , sub .unsubscribeMethod ); e != nil {
289314 zlog .Warn ("unable to send rpc unsubscribe call" ,
290- zap .Error (err ),
315+ zap .Error (e ),
291316 )
292317 }
293318
0 commit comments