@@ -19,6 +19,7 @@ package ttrpc
1919import (
2020 "bytes"
2121 "context"
22+ "crypto/md5"
2223 "errors"
2324 "fmt"
2425 "net"
@@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
6162}
6263
6364// testingServer is what would be implemented by the user of this package.
64- type testingServer struct {}
65+ type testingServer struct {
66+ echoOnce bool
67+ }
6568
6669func (s * testingServer ) Test (ctx context.Context , req * internal.TestPayload ) (* internal.TestPayload , error ) {
67- tp := & internal.TestPayload {Foo : strings .Repeat (req .Foo , 2 )}
70+ tp := & internal.TestPayload {}
71+ if s .echoOnce {
72+ tp .Foo = req .Foo
73+ } else {
74+ tp .Foo = strings .Repeat (req .Foo , 2 )
75+ }
6876 if dl , ok := ctx .Deadline (); ok {
6977 tp .Deadline = dl .UnixNano ()
7078 }
@@ -299,38 +307,155 @@ func TestServerClose(t *testing.T) {
299307}
300308
301309func TestOversizeCall (t * testing.T ) {
302- var (
303- ctx = context .Background ()
304- server = mustServer (t )(NewServer ())
305- addr , listener = newTestListener (t )
306- errs = make (chan error , 1 )
307- client , cleanup = newTestClient (t , addr )
308- )
309- defer cleanup ()
310- defer listener .Close ()
311- go func () {
312- errs <- server .Serve (ctx , listener )
313- }()
310+ type testCase struct {
311+ name string
312+ echoOnce bool
313+ clientLimit int
314+ serverLimit int
315+ requestSize int
316+ clientFail bool
317+ serverFail bool
318+ }
319+
320+ overhead := getWireMessageOverhead (t )
321+
322+ runTest := func (t * testing.T , tc * testCase ) {
323+ var (
324+ ctx = context .Background ()
325+ server = mustServer (t )(NewServer (WithServerWireMessageLimit (tc .serverLimit )))
326+ addr , listener = newTestListener (t )
327+ errs = make (chan error , 1 )
328+ client , cleanup = newTestClient (t , addr , WithClientWireMessageLimit (tc .clientLimit ))
329+ )
330+ defer cleanup ()
331+ defer listener .Close ()
332+ go func () {
333+ errs <- server .Serve (ctx , listener )
334+ }()
335+
336+ registerTestingService (server , & testingServer {echoOnce : tc .echoOnce })
337+
338+ req := & internal.TestPayload {
339+ Foo : strings .Repeat ("a" , tc .requestSize ),
340+ }
341+ rsp := & internal.TestPayload {}
342+
343+ err := client .Call (ctx , serviceName , "Test" , req , rsp )
344+ if tc .clientFail {
345+ if err == nil {
346+ t .Fatalf ("expected error from oversized message" )
347+ } else if status , ok := status .FromError (err ); ! ok {
348+ t .Fatalf ("expected status present in error: %v" , err )
349+ } else if status .Code () != codes .ResourceExhausted {
350+ t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
351+ }
352+ } else if tc .serverFail {
353+ if err == nil {
354+ t .Fatalf ("expected error from server-side oversized message" )
355+ }
356+ } else {
357+ if err != nil {
358+ t .Fatalf ("expected success, got error %v" , err )
359+ }
360+ }
314361
315- registerTestingService (server , & testingServer {})
362+ if err := server .Shutdown (ctx ); err != nil {
363+ t .Fatal (err )
364+ }
365+ if err := <- errs ; err != ErrServerClosed {
366+ t .Fatal (err )
367+ }
368+ }
316369
317- tp := & internal.TestPayload {
318- Foo : strings .Repeat ("a" , 1 + messageLengthMax ),
370+ for _ , tc := range []* testCase {
371+ {
372+ name : "default limits, fitting request and response" ,
373+ echoOnce : true ,
374+ clientLimit : 0 ,
375+ serverLimit : 0 ,
376+ requestSize : DefaultMessageLengthLimit - overhead ,
377+ },
378+ {
379+ name : "default limits, oversized request" ,
380+ echoOnce : true ,
381+ clientLimit : 0 ,
382+ serverLimit : 0 ,
383+ requestSize : DefaultMessageLengthLimit ,
384+ clientFail : true ,
385+ },
386+ {
387+ name : "default limits, oversized response" ,
388+ clientLimit : 0 ,
389+ serverLimit : 0 ,
390+ requestSize : DefaultMessageLengthLimit / 2 ,
391+ serverFail : true ,
392+ },
393+ {
394+ name : "8K limits, fitting 4K request and response" ,
395+ echoOnce : true ,
396+ clientLimit : 8 * 1024 ,
397+ serverLimit : 8 * 1024 ,
398+ requestSize : 4 * 1024 ,
399+ },
400+ {
401+ name : "8K limits, fitting cc. 4K request and response" ,
402+ echoOnce : true ,
403+ clientLimit : 4 * 1024 ,
404+ serverLimit : 4 * 1024 ,
405+ requestSize : 4 * 1024 - overhead ,
406+ },
407+ {
408+ name : "4K limits, non-fitting 4K response" ,
409+ echoOnce : true ,
410+ clientLimit : 4 * 1024 + overhead ,
411+ serverLimit : 4 * 1024 ,
412+ requestSize : 4 * 1024 ,
413+ serverFail : true ,
414+ },
415+ {
416+ name : "too small limits, adjusted to minimum accepted limit" ,
417+ echoOnce : true ,
418+ clientLimit : 4 ,
419+ serverLimit : 4 ,
420+ requestSize : 4 * 1024 - overhead ,
421+ },
422+ {
423+ name : "maximum allowed protocol limit" ,
424+ echoOnce : true ,
425+ clientLimit : MaxMessageLengthLimit ,
426+ serverLimit : MaxMessageLengthLimit ,
427+ requestSize : MaxMessageLengthLimit - overhead ,
428+ },
429+ } {
430+ t .Run (tc .name , func (t * testing.T ) {
431+ runTest (t , tc )
432+ })
319433 }
320- if err := client .Call (ctx , serviceName , "Test" , tp , tp ); err == nil {
321- t .Fatalf ("expected error from oversized message" )
322- } else if status , ok := status .FromError (err ); ! ok {
323- t .Fatalf ("expected status present in error: %v" , err )
324- } else if status .Code () != codes .ResourceExhausted {
325- t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
434+ }
435+
436+ func getWireMessageOverhead (t * testing.T ) int {
437+ emptyReq , err := codec {}.Marshal (& Request {
438+ Service : serviceName ,
439+ Method : "Test" ,
440+ })
441+ if err != nil {
442+ t .Fatalf ("failed to marshal empty request: %v" , err )
326443 }
327444
328- if err := server .Shutdown (ctx ); err != nil {
329- t .Fatal (err )
445+ emptyRsp , err := codec {}.Marshal (& Response {
446+ Status : status .New (codes .OK , "" ).Proto (),
447+ })
448+ if err != nil {
449+ t .Fatalf ("failed to marshal empty response: %v" , err )
330450 }
331- if err := <- errs ; err != ErrServerClosed {
332- t .Fatal (err )
451+
452+ reqLen := len (emptyReq )
453+ rspLen := len (emptyRsp )
454+ if reqLen > rspLen {
455+ return reqLen + messageHeaderLength
333456 }
457+
458+ return rspLen + messageHeaderLength
334459}
335460
336461func TestClientEOF (t * testing.T ) {
@@ -551,13 +676,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
551676}
552677
553678func newTestListener (t testing.TB ) (string , net.Listener ) {
554- var prefix string
679+ var (
680+ name = t .Name ()
681+ prefix string
682+ )
555683
556684 // Abstracts sockets are only available on Linux.
557685 if runtime .GOOS == "linux" {
558686 prefix = "\x00 "
687+ } else {
688+ if split := strings .SplitN (name , "/" , 2 ); len (split ) == 2 {
689+ name = split [0 ] + "-" + fmt .Sprintf ("%x" , md5 .Sum ([]byte (split [1 ])))
690+ }
559691 }
560- addr := prefix + t . Name ()
692+ addr := prefix + name
561693 listener , err := net .Listen ("unix" , addr )
562694 if err != nil {
563695 t .Fatal (err )
0 commit comments