Skip to content

Commit b2c2b50

Browse files
committed
rpcclient: support canceling in-flight http requests
1 parent 582b999 commit b2c2b50

File tree

2 files changed

+113
-12
lines changed

2 files changed

+113
-12
lines changed

rpcclient/infrastructure.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ out:
766766
// handleSendPostMessage handles performing the passed HTTP request, reading the
767767
// result, unmarshalling it, and delivering the unmarshalled result to the
768768
// provided response channel.
769-
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
769+
func (c *Client) handleSendPostMessage(ctx context.Context, jReq *jsonRequest) {
770770
var (
771771
lastErr error
772772
backoff time.Duration
@@ -782,12 +782,17 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
782782
}
783783

784784
tries := 10
785+
retryloop:
785786
for i := 0; i < tries; i++ {
786787
var httpReq *http.Request
787788

788789
bodyReader := bytes.NewReader(jReq.marshalledJSON)
789-
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
790+
httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader)
790791
if err != nil {
792+
// We must observe the contract that shutdown returns ErrClientShutdown.
793+
if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) {
794+
err = ErrClientShutdown
795+
}
791796
jReq.responseChan <- &Response{result: nil, err: err}
792797
return
793798
}
@@ -812,6 +817,11 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
812817
break
813818
}
814819

820+
// We must observe the contract that shutdown returns ErrClientShutdown.
821+
if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) {
822+
err = ErrClientShutdown
823+
}
824+
815825
// Save the last error for the case where we backoff further,
816826
// retry and get an invalid response but no error. If this
817827
// happens the saved last error will be used to enrich the error
@@ -830,8 +840,13 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
830840
select {
831841
case <-time.After(backoff):
832842

833-
case <-c.shutdown:
834-
return
843+
case <-ctx.Done():
844+
err = ctx.Err()
845+
// maintain our contract: shutdown errors are ErrClientShutdown
846+
if errors.Is(context.Cause(ctx), ErrClientShutdown) {
847+
err = ErrClientShutdown
848+
}
849+
break retryloop
835850
}
836851
}
837852
if err != nil {
@@ -891,30 +906,28 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
891906
// in HTTP POST mode. It uses a buffered channel to serialize output messages
892907
// while allowing the sender to continue running asynchronously. It must be run
893908
// as a goroutine.
894-
func (c *Client) sendPostHandler() {
909+
func (c *Client) sendPostHandler(ctx context.Context) {
895910
out:
896911
for {
897912
// Send any messages ready for send until the shutdown channel
898913
// is closed.
899914
select {
900915
case jReq := <-c.sendPostChan:
901-
c.handleSendPostMessage(jReq)
916+
c.handleSendPostMessage(ctx, jReq)
902917

903-
case <-c.shutdown:
918+
case <-ctx.Done():
904919
break out
905920
}
906921
}
907922

923+
err := context.Cause(ctx)
908924
// Drain any wait channels before exiting so nothing is left waiting
909925
// around to send.
910926
cleanup:
911927
for {
912928
select {
913929
case jReq := <-c.sendPostChan:
914-
jReq.responseChan <- &Response{
915-
result: nil,
916-
err: ErrClientShutdown,
917-
}
930+
jReq.responseChan <- &Response{result: nil, err: err}
918931

919932
default:
920933
break cleanup
@@ -1178,8 +1191,13 @@ func (c *Client) start() {
11781191
// Start the I/O processing handlers depending on whether the client is
11791192
// in HTTP POST mode or the default websocket mode.
11801193
if c.config.HTTPPostMode {
1194+
ctx, cancel := context.WithCancelCause(context.Background())
11811195
c.wg.Add(1)
1182-
go c.sendPostHandler()
1196+
go c.sendPostHandler(ctx)
1197+
go func() {
1198+
<-c.shutdown
1199+
cancel(ErrClientShutdown)
1200+
}()
11831201
} else {
11841202
c.wg.Add(3)
11851203
go func() {

rpcclient/infrastructure_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package rpcclient
22

33
import (
4+
"io"
5+
"net"
46
"testing"
7+
"time"
58

9+
"github.com/stretchr/testify/assert"
610
"github.com/stretchr/testify/require"
711
)
812

@@ -108,3 +112,82 @@ func TestParseAddressString(t *testing.T) {
108112
})
109113
}
110114
}
115+
116+
// TestHTTPPostShutdownInterruptsPendingRequest ensures that a client operating
117+
// in HTTP POST mode can interrupt an in-flight request during shutdown.
118+
func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) {
119+
t.Parallel()
120+
121+
listener, err := net.Listen("tcp", "127.0.0.1:0")
122+
require.NoError(t, err)
123+
124+
requestAccepted := make(chan struct{})
125+
serverDone := make(chan struct{})
126+
127+
go func() {
128+
defer close(serverDone)
129+
130+
conn, err := listener.Accept()
131+
if err != nil {
132+
return
133+
}
134+
defer func() {
135+
err := conn.Close()
136+
assert.NoError(t, err)
137+
}()
138+
139+
close(requestAccepted)
140+
141+
_, _ = io.Copy(io.Discard, conn)
142+
}()
143+
144+
t.Cleanup(func() {
145+
err := listener.Close()
146+
require.NoError(t, err)
147+
<-serverDone
148+
})
149+
150+
connCfg := &ConnConfig{
151+
Host: listener.Addr().String(),
152+
User: "user",
153+
Pass: "pass",
154+
DisableTLS: true,
155+
HTTPPostMode: true,
156+
}
157+
158+
client, err := New(connCfg, nil)
159+
require.NoError(t, err)
160+
t.Cleanup(client.Shutdown)
161+
162+
future := client.GetBlockCountAsync()
163+
164+
select {
165+
case <-requestAccepted:
166+
case <-time.After(2 * time.Second):
167+
t.Fatalf("server did not accept client connection")
168+
}
169+
170+
select {
171+
case <-future:
172+
t.Fatalf("expected request to remain pending until shutdown")
173+
case <-time.After(100 * time.Millisecond):
174+
}
175+
176+
client.Shutdown()
177+
178+
waitDone := make(chan struct{})
179+
go func() {
180+
client.WaitForShutdown()
181+
close(waitDone)
182+
}()
183+
184+
select {
185+
case <-waitDone:
186+
case <-time.After(5 * time.Second):
187+
t.Fatalf("client shutdown did not complete")
188+
}
189+
190+
result, err := future.Receive()
191+
require.Zero(t, result)
192+
require.ErrorContains(t, err, ErrClientShutdown.Error())
193+
}

0 commit comments

Comments
 (0)