@@ -19,6 +19,7 @@ package ttrpc
1919import (
2020 "bytes"
2121 "context"
22+ "crypto/md5"
2223 "errors"
2324 "fmt"
2425 "net"
@@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
6162}
6263
6364// testingServer is what would be implemented by the user of this package.
64- type testingServer struct {}
65+ type testingServer struct {
66+ echoOnce bool
67+ }
6568
6669func (s * testingServer ) Test (ctx context.Context , req * internal.TestPayload ) (* internal.TestPayload , error ) {
67- tp := & internal.TestPayload {Foo : strings .Repeat (req .Foo , 2 )}
70+ tp := & internal.TestPayload {}
71+ if s .echoOnce {
72+ tp .Foo = req .Foo
73+ } else {
74+ tp .Foo = strings .Repeat (req .Foo , 2 )
75+ }
6876 if dl , ok := ctx .Deadline (); ok {
6977 tp .Deadline = dl .UnixNano ()
7078 }
@@ -329,38 +337,238 @@ func TestImmediateServerShutdown(t *testing.T) {
329337}
330338
331339func TestOversizeCall (t * testing.T ) {
332- var (
333- ctx = context .Background ()
334- server = mustServer (t )(NewServer ())
335- addr , listener = newTestListener (t )
336- errs = make (chan error , 1 )
337- client , cleanup = newTestClient (t , addr )
338- )
339- defer cleanup ()
340- defer listener .Close ()
341- go func () {
342- errs <- server .Serve (ctx , listener )
343- }()
340+ type testCase struct {
341+ name string
342+ echoOnce bool
343+ clientLimit int
344+ serverLimit int
345+ requestSize int
346+ clientFail bool
347+ sendFail bool
348+ serverFail bool
349+ }
350+
351+ overhead := getWireMessageOverhead (t )
352+
353+ clientOpts := func (tc * testCase ) []ClientOpts {
354+ if tc .clientLimit == 0 {
355+ return nil
356+ }
357+ return []ClientOpts {WithClientWireMessageLimit (tc .clientLimit )}
358+ }
359+ serverOpts := func (tc * testCase ) []ServerOpt {
360+ if tc .serverLimit == 0 {
361+ return nil
362+ }
363+ return []ServerOpt {WithServerWireMessageLimit (tc .serverLimit )}
364+ }
365+
366+ runTest := func (t * testing.T , tc * testCase ) {
367+ var (
368+ ctx = context .Background ()
369+ server = mustServer (t )(NewServer (serverOpts (tc )... ))
370+ addr , listener = newTestListener (t )
371+ errs = make (chan error , 1 )
372+ client , cleanup = newTestClient (t , addr , clientOpts (tc )... )
373+ )
374+ defer cleanup ()
375+ defer listener .Close ()
376+ go func () {
377+ errs <- server .Serve (ctx , listener )
378+ }()
379+
380+ registerTestingService (server , & testingServer {echoOnce : tc .echoOnce })
381+
382+ req := & internal.TestPayload {
383+ Foo : strings .Repeat ("a" , tc .requestSize ),
384+ }
385+ rsp := & internal.TestPayload {}
386+
387+ err := client .Call (ctx , serviceName , "Test" , req , rsp )
388+ if tc .clientFail {
389+ if err == nil {
390+ t .Fatalf ("expected error from oversized message" )
391+ } else if status , ok := status .FromError (err ); ! ok {
392+ t .Fatalf ("expected status present in error: %v" , err )
393+ } else if status .Code () != codes .ResourceExhausted {
394+ t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
395+ }
396+ if tc .sendFail {
397+ var msgLenErr * OversizedMessageErr
398+ if ! errors .As (err , & msgLenErr ) {
399+ t .Fatalf ("failed to retrieve client send OversizedMessageErr" )
400+ }
401+ rejLen , maxLen := msgLenErr .RejectedLength (), msgLenErr .MaximumLength ()
402+ if rejLen == 0 {
403+ t .Fatalf ("zero rejected length in client send oversized message error" )
404+ }
405+ if maxLen == 0 {
406+ t .Fatalf ("zero maximum length in client send oversized message error" )
407+ }
408+ if rejLen <= maxLen {
409+ t .Fatalf ("client send oversized message error rejected < max. length (%d < %d)" ,
410+ rejLen , maxLen )
411+ }
412+ }
413+ } else if tc .serverFail {
414+ if err == nil {
415+ t .Fatalf ("expected error from server-side oversized message" )
416+ } else {
417+ if status , ok := status .FromError (err ); ! ok {
418+ t .Fatalf ("expected status present in error: %v" , err )
419+ } else if status .Code () != codes .ResourceExhausted {
420+ t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
421+ }
422+ if msgLenErr , ok := OversizedMessageFromError (err ); ! ok {
423+ t .Fatalf ("failed to retrieve oversized message error" )
424+ } else {
425+ rejLen , maxLen := msgLenErr .RejectedLength (), msgLenErr .MaximumLength ()
426+ if rejLen == 0 {
427+ t .Fatalf ("zero rejected length in oversized message error" )
428+ }
429+ if maxLen == 0 {
430+ t .Fatalf ("zero maximum length in oversized message error" )
431+ }
432+ if rejLen <= maxLen {
433+ t .Fatalf ("oversized message error rejected < max. length (%d < %d)" ,
434+ rejLen , maxLen )
435+ }
436+ }
437+ }
438+ } else {
439+ if err != nil {
440+ t .Fatalf ("expected success, got error %v" , err )
441+ }
442+ }
344443
345- registerTestingService (server , & testingServer {})
444+ if err := server .Shutdown (ctx ); err != nil {
445+ t .Fatal (err )
446+ }
447+ if err := <- errs ; err != ErrServerClosed {
448+ t .Fatal (err )
449+ }
450+ }
346451
347- tp := & internal.TestPayload {
348- Foo : strings .Repeat ("a" , 1 + messageLengthMax ),
452+ for _ , tc := range []* testCase {
453+ {
454+ name : "default limits, fitting request and response" ,
455+ echoOnce : true ,
456+ clientLimit : 0 ,
457+ serverLimit : 0 ,
458+ requestSize : DefaultMessageLengthLimit - overhead ,
459+ },
460+ {
461+ name : "default limits, only recv side check" ,
462+ clientLimit : 0 ,
463+ serverLimit : 0 ,
464+ requestSize : DefaultMessageLengthLimit - overhead ,
465+ serverFail : true ,
466+ },
467+
468+ {
469+ name : "default limits, oversized request" ,
470+ echoOnce : true ,
471+ clientLimit : 0 ,
472+ serverLimit : 0 ,
473+ requestSize : DefaultMessageLengthLimit ,
474+ clientFail : true ,
475+ },
476+ {
477+ name : "default limits, oversized response" ,
478+ clientLimit : 0 ,
479+ serverLimit : 0 ,
480+ requestSize : DefaultMessageLengthLimit / 2 ,
481+ serverFail : true ,
482+ },
483+ {
484+ name : "8K limits, 4K request and response" ,
485+ echoOnce : true ,
486+ clientLimit : 8 * 1024 ,
487+ serverLimit : 8 * 1024 ,
488+ requestSize : 4 * 1024 ,
489+ },
490+ {
491+ name : "4K limits, barely fitting cc. 4K request and response" ,
492+ echoOnce : true ,
493+ clientLimit : 4 * 1024 ,
494+ serverLimit : 4 * 1024 ,
495+ requestSize : 4 * 1024 - overhead ,
496+ },
497+ {
498+ name : "4K limits, oversized request on client side" ,
499+ echoOnce : true ,
500+ clientLimit : 4 * 1024 ,
501+ serverLimit : 4 * 1024 ,
502+ requestSize : 4 * 1024 ,
503+ clientFail : true ,
504+ sendFail : true ,
505+ },
506+ {
507+ name : "4K limits, oversized request on server side" ,
508+ echoOnce : true ,
509+ clientLimit : 4 * 1024 + overhead ,
510+ serverLimit : 4 * 1024 ,
511+ requestSize : 4 * 1024 ,
512+ serverFail : true ,
513+ },
514+ {
515+ name : "4K limits, oversized response on client side" ,
516+ clientLimit : 4 * 1024 + overhead ,
517+ serverLimit : 4 * 1024 ,
518+ requestSize : 8 * 1024 + overhead ,
519+ clientFail : true ,
520+ },
521+ {
522+ name : "4K limits, oversized response on server side" ,
523+ clientLimit : 4 * 1024 + overhead ,
524+ serverLimit : 4 * 1024 ,
525+ requestSize : 4 * 1024 ,
526+ serverFail : true ,
527+ },
528+ {
529+ name : "too small limits, adjusted to minimum accepted limit" ,
530+ echoOnce : true ,
531+ clientLimit : 4 ,
532+ serverLimit : 4 ,
533+ requestSize : 4 * 1024 - overhead ,
534+ },
535+ {
536+ name : "maximum allowed protocol limit" ,
537+ echoOnce : true ,
538+ clientLimit : MaxMessageLengthLimit ,
539+ serverLimit : MaxMessageLengthLimit ,
540+ requestSize : MaxMessageLengthLimit - overhead ,
541+ },
542+ } {
543+ t .Run (tc .name , func (t * testing.T ) {
544+ runTest (t , tc )
545+ })
349546 }
350- if err := client .Call (ctx , serviceName , "Test" , tp , tp ); err == nil {
351- t .Fatalf ("expected error from oversized message" )
352- } else if status , ok := status .FromError (err ); ! ok {
353- t .Fatalf ("expected status present in error: %v" , err )
354- } else if status .Code () != codes .ResourceExhausted {
355- t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
547+ }
548+
549+ func getWireMessageOverhead (t * testing.T ) int {
550+ emptyReq , err := codec {}.Marshal (& Request {
551+ Service : serviceName ,
552+ Method : "Test" ,
553+ })
554+ if err != nil {
555+ t .Fatalf ("failed to marshal empty request: %v" , err )
356556 }
357557
358- if err := server .Shutdown (ctx ); err != nil {
359- t .Fatal (err )
558+ emptyRsp , err := codec {}.Marshal (& Response {
559+ Status : status .New (codes .OK , "" ).Proto (),
560+ })
561+ if err != nil {
562+ t .Fatalf ("failed to marshal empty response: %v" , err )
360563 }
361- if err := <- errs ; err != ErrServerClosed {
362- t .Fatal (err )
564+
565+ reqLen := len (emptyReq )
566+ rspLen := len (emptyRsp )
567+ if reqLen > rspLen {
568+ return reqLen + messageHeaderLength
363569 }
570+
571+ return rspLen + messageHeaderLength
364572}
365573
366574func TestClientEOF (t * testing.T ) {
@@ -581,13 +789,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
581789}
582790
583791func newTestListener (t testing.TB ) (string , net.Listener ) {
584- var prefix string
792+ var (
793+ name = t .Name ()
794+ prefix string
795+ )
585796
586797 // Abstracts sockets are only available on Linux.
587798 if runtime .GOOS == "linux" {
588799 prefix = "\x00 "
800+ } else {
801+ if split := strings .SplitN (name , "/" , 2 ); len (split ) == 2 {
802+ name = split [0 ] + "-" + fmt .Sprintf ("%x" , md5 .Sum ([]byte (split [1 ])))
803+ }
589804 }
590- addr := prefix + t . Name ()
805+ addr := prefix + name
591806 listener , err := net .Listen ("unix" , addr )
592807 if err != nil {
593808 t .Fatal (err )
0 commit comments