Skip to content

Commit 4370b70

Browse files
Merge pull request #383 from sonicfromnewyoke/fix/ws-race-and-leak-fixes
fix: ws race and leak fixes
2 parents 364496d + 6be3404 commit 4370b70

14 files changed

Lines changed: 547 additions & 192 deletions

rpc/sendAndConfirmTransaction/sendAndConfirmTransaction.go

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,18 @@ func WaitForConfirmation(
120120
timeout = &t
121121
}
122122

123-
for {
124-
select {
125-
case <-ctx.Done():
126-
return false, ctx.Err()
127-
case <-time.After(*timeout):
123+
timeoutCtx, cancel := context.WithTimeout(ctx, *timeout)
124+
defer cancel()
125+
126+
got, err := sub.Recv(timeoutCtx)
127+
if err != nil {
128+
if timeoutCtx.Err() == context.DeadlineExceeded {
128129
return false, ErrTimeout
129-
case resp, ok := <-sub.Response():
130-
if !ok {
131-
return false, fmt.Errorf("subscription closed")
132-
}
133-
if resp.Value.Err != nil {
134-
// The transaction was confirmed, but it failed while executing (one of the instructions failed).
135-
return true, fmt.Errorf("confirmed transaction with execution error: %v", resp.Value.Err)
136-
} else {
137-
// Success! Confirmed! And there was no error while executing the transaction.
138-
return true, nil
139-
}
140-
case err := <-sub.Err():
141-
return false, err
142130
}
131+
return false, err
132+
}
133+
if got.Value.Err != nil {
134+
return true, fmt.Errorf("confirmed transaction with execution error: %v", got.Value.Err)
143135
}
136+
return true, nil
144137
}

rpc/ws/accountSubscribe.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,6 @@ func (sw *AccountSubscription) Err() <-chan error {
101101
return sw.sub.err
102102
}
103103

104-
func (sw *AccountSubscription) Response() <-chan *AccountResult {
105-
typedChan := make(chan *AccountResult, 1)
106-
go func(ch chan *AccountResult) {
107-
// TODO: will this subscription yield more than one result?
108-
d, ok := <-sw.sub.stream
109-
if !ok {
110-
return
111-
}
112-
ch <- d.(*AccountResult)
113-
}(typedChan)
114-
return typedChan
115-
}
116-
117104
func (sw *AccountSubscription) Unsubscribe() {
118105
sw.sub.Unsubscribe()
119106
}

rpc/ws/blockSubscribe.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,6 @@ func (sw *BlockSubscription) Err() <-chan error {
167167
return sw.sub.err
168168
}
169169

170-
func (sw *BlockSubscription) Response() <-chan *BlockResult {
171-
typedChan := make(chan *BlockResult, 1)
172-
go func(ch chan *BlockResult) {
173-
// TODO: will this subscription yield more than one result?
174-
d, ok := <-sw.sub.stream
175-
if !ok {
176-
return
177-
}
178-
ch <- d.(*BlockResult)
179-
}(typedChan)
180-
return typedChan
181-
}
182-
183170
func (sw *BlockSubscription) Unsubscribe() {
184171
sw.sub.Unsubscribe()
185172
}

rpc/ws/client.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ type Client struct {
4242
conn *websocket.Conn
4343
connCtx context.Context
4444
connCtxCancel context.CancelFunc
45+
wg sync.WaitGroup
4546
lock sync.RWMutex
4647
subscriptionByRequestID map[uint64]*Subscription
4748
subscriptionByWSSubID map[uint64]*Subscription
48-
reconnectOnErr bool
4949
shortID bool
5050
}
5151

@@ -105,10 +105,13 @@ func ConnectWithOptions(ctx context.Context, rpcEndpoint string, opt *Options) (
105105
}
106106

107107
c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
108+
c.wg.Add(2)
108109
go func() {
110+
defer c.wg.Done()
109111
c.conn.SetReadDeadline(time.Now().Add(pongWait))
110112
c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
111113
ticker := time.NewTicker(pingPeriod)
114+
defer ticker.Stop()
112115
for {
113116
select {
114117
case <-c.connCtx.Done():
@@ -132,22 +135,27 @@ func (c *Client) sendPing() {
132135
}
133136
}
134137

138+
// Close cancels the connection context, closes the underlying websocket
139+
// connection, and waits for background goroutines to finish.
135140
func (c *Client) Close() {
136141
c.lock.Lock()
137-
defer c.lock.Unlock()
138142
c.connCtxCancel()
139143
c.conn.Close()
144+
c.lock.Unlock()
145+
c.wg.Wait()
140146
}
141147

142148
func (c *Client) receiveMessages() {
149+
defer c.wg.Done()
150+
defer c.closeAllSubscription(ErrSubscriptionClosed)
151+
143152
for {
144153
select {
145154
case <-c.connCtx.Done():
146155
return
147156
default:
148157
_, message, err := c.conn.ReadMessage()
149158
if err != nil {
150-
c.closeAllSubscription(err)
151159
return
152160
}
153161
c.handleMessage(message)
@@ -218,7 +226,6 @@ func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
218226
zap.Uint64("request_id", requestID),
219227
zap.Int("subscription_count", len(c.subscriptionByWSSubID)),
220228
)
221-
return
222229
}
223230

224231
func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
@@ -239,34 +246,43 @@ func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
239246
// Decode the message using the subscription-provided decoderFunc.
240247
result, err := sub.decoderFunc(message)
241248
if err != nil {
242-
fmt.Println("*****************************")
243249
c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err))
244250
return
245251
}
246252

247253
// this cannot be blocking or else
248254
// we will no read any other message
249255
if len(sub.stream) >= cap(sub.stream) {
250-
zlog.Warn("closing ws client subscription... not consuming fast en ought",
256+
zlog.Warn("closing ws client subscription... not consuming fast enough",
251257
zap.Uint64("request_id", sub.req.ID),
252258
)
253259
c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream)))
254260
return
255261
}
256262

257-
sub.mutex.Lock()
258-
defer sub.mutex.Unlock()
263+
sub.mu.Lock()
259264
if !sub.closed {
260265
sub.stream <- result
261266
}
267+
sub.mu.Unlock()
262268
}
263269

264270
func (c *Client) closeAllSubscription(err error) {
265271
c.lock.Lock()
266272
defer c.lock.Unlock()
267273

268274
for _, sub := range c.subscriptionByRequestID {
269-
sub.err <- err
275+
sub.mu.Lock()
276+
if !sub.closed {
277+
select {
278+
case sub.err <- err:
279+
default:
280+
}
281+
sub.closed = true
282+
close(sub.stream)
283+
close(sub.err)
284+
}
285+
sub.mu.Unlock()
270286
}
271287

272288
c.subscriptionByRequestID = map[uint64]*Subscription{}
@@ -282,12 +298,21 @@ func (c *Client) closeSubscription(reqID uint64, err error) {
282298
return
283299
}
284300

285-
sub.err <- err
301+
sub.mu.Lock()
302+
if !sub.closed {
303+
select {
304+
case sub.err <- err:
305+
default:
306+
}
307+
sub.closed = true
308+
close(sub.stream)
309+
close(sub.err)
310+
}
311+
sub.mu.Unlock()
286312

287-
err = c.unsubscribe(sub.subID, sub.unsubscribeMethod)
288-
if err != nil {
313+
if e := c.unsubscribe(sub.subID, sub.unsubscribeMethod); e != nil {
289314
zlog.Warn("unable to send rpc unsubscribe call",
290-
zap.Error(err),
315+
zap.Error(e),
291316
)
292317
}
293318

0 commit comments

Comments
 (0)