@@ -29,6 +29,7 @@ package rpc
2929import (
3030 "context"
3131 "encoding/json"
32+ "os"
3233 "reflect"
3334 "strconv"
3435 "strings"
@@ -40,6 +41,42 @@ import (
4041 "golang.org/x/time/rate"
4142)
4243
// JSON-RPC error codes for batch-level failures, chosen from the
// implementation-defined server-error range (-32000 to -32099).
const (
	errcodeTimeout          = -32002
	errcodeResponseTooLarge = -32003
)

// Human-readable messages paired with the batch-level failure modes above.
const (
	errMsgTimeout           = "request timed out"
	errMsgResponseTooLarge  = "response too large"
	errMsgBatchTooLarge     = "batch too large"
)
54+
var (
	batchRequestLimit    = 0 // limit on total number of requests in a batch; 0 disables the limit
	batchResponseMaxSize = 0 // limit on the size of a batch response; 0 disables the limit
)

// intFromEnv reads the environment variable name and parses it as a
// non-negative integer. It returns def when the variable is unset or empty,
// and panics when the value is not a valid non-negative integer.
func intFromEnv(name string, def int) int {
	val := os.Getenv(name)
	if val == "" {
		return def
	}
	n, err := strconv.Atoi(val)
	if err != nil || n < 0 {
		panic(name + " must be a non-negative integer")
	}
	return n
}

func init() {
	// Read batchRequestLimit and batchResponseMaxSize from the environment
	// variables RPC_BATCH_REQUEST_LIMIT and RPC_BATCH_RESPONSE_MAX_SIZE.
	// Invalid values panic at startup rather than being silently ignored.
	batchRequestLimit = intFromEnv("RPC_BATCH_REQUEST_LIMIT", batchRequestLimit)
	batchResponseMaxSize = intFromEnv("RPC_BATCH_RESPONSE_MAX_SIZE", batchResponseMaxSize)
}
79+
4380// handler handles JSON-RPC messages. There is one handler per connection. Note that
4481// handler is not safe for concurrent use. Message handling never blocks indefinitely
4582// because RPCs are processed on background goroutines launched by handler.
@@ -108,6 +145,75 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
108145 return h
109146}
110147
148+ // batchCallBuffer manages in progress call messages and their responses during a batch
149+ // call. Calls need to be synchronized between the processing and timeout-triggering
150+ // goroutines.
151+ type batchCallBuffer struct {
152+ mutex sync.Mutex
153+ calls []* jsonrpcMessage
154+ resp []* jsonrpcMessage
155+ wrote bool
156+ }
157+
158+ // nextCall returns the next unprocessed message.
159+ func (b * batchCallBuffer ) nextCall () * jsonrpcMessage {
160+ b .mutex .Lock ()
161+ defer b .mutex .Unlock ()
162+
163+ if len (b .calls ) == 0 {
164+ return nil
165+ }
166+ // The popping happens in `pushAnswer`. The in progress call is kept
167+ // so we can return an error for it in case of timeout.
168+ msg := b .calls [0 ]
169+ return msg
170+ }
171+
172+ // pushResponse adds the response to last call returned by nextCall.
173+ func (b * batchCallBuffer ) pushResponse (answer * jsonrpcMessage ) {
174+ b .mutex .Lock ()
175+ defer b .mutex .Unlock ()
176+
177+ if answer != nil {
178+ b .resp = append (b .resp , answer )
179+ }
180+ b .calls = b .calls [1 :]
181+ }
182+
183+ // write sends the responses.
184+ func (b * batchCallBuffer ) write (ctx context.Context , conn jsonWriter ) {
185+ b .mutex .Lock ()
186+ defer b .mutex .Unlock ()
187+
188+ b .doWrite (ctx , conn )
189+ }
190+
191+ // respondWithError sends the responses added so far. For the remaining unanswered call
192+ // messages, it responds with the given error.
193+ func (b * batchCallBuffer ) respondWithError (ctx context.Context , conn jsonWriter , err error ) {
194+ b .mutex .Lock ()
195+ defer b .mutex .Unlock ()
196+
197+ for _ , msg := range b .calls {
198+ if ! msg .isNotification () {
199+ b .resp = append (b .resp , msg .errorResponse (err ))
200+ }
201+ }
202+ b .doWrite (ctx , conn )
203+ }
204+
205+ // doWrite actually writes the response.
206+ // This assumes b.mutex is held.
207+ func (b * batchCallBuffer ) doWrite (ctx context.Context , conn jsonWriter ) {
208+ if b .wrote {
209+ return
210+ }
211+ b .wrote = true // can only write once
212+ if len (b .resp ) > 0 {
213+ conn .writeJSONSkipDeadline (ctx , b .resp , true )
214+ }
215+ }
216+
111217// addLimiter adds a rate limiter to the handler that will allow at most
112218// [refillRate] cpu to be used per second. At most [maxStored] cpu time will be
113219// stored for this limiter.
@@ -129,6 +235,13 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
129235 })
130236 return
131237 }
238+ // Apply limit on total number of requests.
239+ if batchRequestLimit != 0 && len (msgs ) > batchRequestLimit {
240+ h .startCallProc (func (cp * callProc ) {
241+ h .respondWithBatchTooLarge (cp , msgs )
242+ })
243+ return
244+ }
132245
133246 // Handle non-call messages first:
134247 calls := make ([]* jsonrpcMessage , 0 , len (msgs ))
@@ -140,24 +253,76 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
140253 if len (calls ) == 0 {
141254 return
142255 }
256+
143257 // Process calls on a goroutine because they may block indefinitely:
144258 h .startCallProc (func (cp * callProc ) {
145- answers := make ([]* jsonrpcMessage , 0 , len (msgs ))
146- for _ , msg := range calls {
147- if answer := h .handleCallMsg (cp , msg ); answer != nil {
148- answers = append (answers , answer )
259+ var (
260+ timer * time.Timer
261+ cancel context.CancelFunc
262+ callBuffer = & batchCallBuffer {calls : calls , resp : make ([]* jsonrpcMessage , 0 , len (calls ))}
263+ )
264+
265+ cp .ctx , cancel = context .WithCancel (cp .ctx )
266+ defer cancel ()
267+
268+ // Cancel the request context after timeout and send an error response. Since the
269+ // currently-running method might not return immediately on timeout, we must wait
270+ // for the timeout concurrently with processing the request.
271+ if timeout , ok := ContextRequestTimeout (cp .ctx ); ok {
272+ timer = time .AfterFunc (timeout , func () {
273+ cancel ()
274+ err := & internalServerError {errcodeTimeout , errMsgTimeout }
275+ callBuffer .respondWithError (cp .ctx , h .conn , err )
276+ })
277+ }
278+
279+ responseBytes := 0
280+ for {
281+ // No need to handle rest of calls if timed out.
282+ if cp .ctx .Err () != nil {
283+ break
284+ }
285+ msg := callBuffer .nextCall ()
286+ if msg == nil {
287+ break
288+ }
289+ resp := h .handleCallMsg (cp , msg )
290+ callBuffer .pushResponse (resp )
291+ if resp != nil && batchResponseMaxSize != 0 {
292+ responseBytes += len (resp .Result )
293+ if responseBytes > batchResponseMaxSize {
294+ err := & internalServerError {errcodeResponseTooLarge , errMsgResponseTooLarge }
295+ callBuffer .respondWithError (cp .ctx , h .conn , err )
296+ break
297+ }
149298 }
150299 }
151- h .addSubscriptions (cp .notifiers )
152- if len (answers ) > 0 {
153- h .conn .writeJSONSkipDeadline (cp .ctx , answers , h .deadlineContext > 0 )
300+ if timer != nil {
301+ timer .Stop ()
154302 }
303+
304+ h .addSubscriptions (cp .notifiers )
305+ callBuffer .write (cp .ctx , h .conn )
155306 for _ , n := range cp .notifiers {
156307 n .activate ()
157308 }
158309 })
159310}
160311
312+ func (h * handler ) respondWithBatchTooLarge (cp * callProc , batch []* jsonrpcMessage ) {
313+ resp := errorMessage (& invalidRequestError {errMsgBatchTooLarge })
314+ // Find the first call and add its "id" field to the error.
315+ // This is the best we can do, given that the protocol doesn't have a way
316+ // of reporting an error for the entire batch.
317+ for _ , msg := range batch {
318+ if msg .isCall () {
319+ resp .ID = msg .ID
320+ break
321+ }
322+ }
323+ h .conn .writeJSONSkipDeadline (cp .ctx , []* jsonrpcMessage {resp }, h .deadlineContext > 0 )
324+ }
325+
161326// handleMsg handles a single message.
162327func (h * handler ) handleMsg (msg * jsonrpcMessage ) {
163328 if ok := h .handleImmediate (msg ); ok {
0 commit comments