@@ -19,6 +19,7 @@ package ttrpc
19
19
import (
20
20
"bytes"
21
21
"context"
22
+ "crypto/md5"
22
23
"errors"
23
24
"fmt"
24
25
"net"
@@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
61
62
}
62
63
63
64
// testingServer is what would be implemented by the user of this package.
64
- type testingServer struct {}
65
+ type testingServer struct {
66
+ echoOnce bool
67
+ }
65
68
66
69
func (s * testingServer ) Test (ctx context.Context , req * internal.TestPayload ) (* internal.TestPayload , error ) {
67
- tp := & internal.TestPayload {Foo : strings .Repeat (req .Foo , 2 )}
70
+ tp := & internal.TestPayload {}
71
+ if s .echoOnce {
72
+ tp .Foo = req .Foo
73
+ } else {
74
+ tp .Foo = strings .Repeat (req .Foo , 2 )
75
+ }
68
76
if dl , ok := ctx .Deadline (); ok {
69
77
tp .Deadline = dl .UnixNano ()
70
78
}
@@ -330,38 +338,238 @@ func TestImmediateServerShutdown(t *testing.T) {
330
338
}
331
339
332
340
func TestOversizeCall (t * testing.T ) {
333
- var (
334
- ctx = context .Background ()
335
- server = mustServer (t )(NewServer ())
336
- addr , listener = newTestListener (t )
337
- errs = make (chan error , 1 )
338
- client , cleanup = newTestClient (t , addr )
339
- )
340
- defer cleanup ()
341
- defer listener .Close ()
342
- go func () {
343
- errs <- server .Serve (ctx , listener )
344
- }()
341
+ type testCase struct {
342
+ name string
343
+ echoOnce bool
344
+ clientLimit int
345
+ serverLimit int
346
+ requestSize int
347
+ clientFail bool
348
+ sendFail bool
349
+ serverFail bool
350
+ }
351
+
352
+ overhead := getWireMessageOverhead (t )
353
+
354
+ clientOpts := func (tc * testCase ) []ClientOpts {
355
+ if tc .clientLimit == 0 {
356
+ return nil
357
+ }
358
+ return []ClientOpts {WithClientWireMessageLimit (tc .clientLimit )}
359
+ }
360
+ serverOpts := func (tc * testCase ) []ServerOpt {
361
+ if tc .serverLimit == 0 {
362
+ return nil
363
+ }
364
+ return []ServerOpt {WithServerWireMessageLimit (tc .serverLimit )}
365
+ }
345
366
346
- registerTestingService (server , & testingServer {})
367
+ runTest := func (t * testing.T , tc * testCase ) {
368
+ var (
369
+ ctx = context .Background ()
370
+ server = mustServer (t )(NewServer (serverOpts (tc )... ))
371
+ addr , listener = newTestListener (t )
372
+ errs = make (chan error , 1 )
373
+ client , cleanup = newTestClient (t , addr , clientOpts (tc )... )
374
+ )
375
+ defer cleanup ()
376
+ defer listener .Close ()
377
+ go func () {
378
+ errs <- server .Serve (ctx , listener )
379
+ }()
380
+
381
+ registerTestingService (server , & testingServer {echoOnce : tc .echoOnce })
347
382
348
- tp := & internal.TestPayload {
349
- Foo : strings .Repeat ("a" , 1 + messageLengthMax ),
383
+ req := & internal.TestPayload {
384
+ Foo : strings .Repeat ("a" , tc .requestSize ),
385
+ }
386
+ rsp := & internal.TestPayload {}
387
+
388
+ err := client .Call (ctx , serviceName , "Test" , req , rsp )
389
+ if tc .clientFail {
390
+ if err == nil {
391
+ t .Fatalf ("expected error from oversized message" )
392
+ } else if status , ok := status .FromError (err ); ! ok {
393
+ t .Fatalf ("expected status present in error: %v" , err )
394
+ } else if status .Code () != codes .ResourceExhausted {
395
+ t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
396
+ }
397
+ if tc .sendFail {
398
+ var msgLenErr * OversizedMessageErr
399
+ if ! errors .As (err , & msgLenErr ) {
400
+ t .Fatalf ("failed to retrieve client send OversizedMessageErr" )
401
+ }
402
+ rejLen , maxLen := msgLenErr .RejectedLength (), msgLenErr .MaximumLength ()
403
+ if rejLen == 0 {
404
+ t .Fatalf ("zero rejected length in client send oversized message error" )
405
+ }
406
+ if maxLen == 0 {
407
+ t .Fatalf ("zero maximum length in client send oversized message error" )
408
+ }
409
+ if rejLen <= maxLen {
410
+ t .Fatalf ("client send oversized message error rejected < max. length (%d < %d)" ,
411
+ rejLen , maxLen )
412
+ }
413
+ }
414
+ } else if tc .serverFail {
415
+ if err == nil {
416
+ t .Fatalf ("expected error from server-side oversized message" )
417
+ } else {
418
+ if status , ok := status .FromError (err ); ! ok {
419
+ t .Fatalf ("expected status present in error: %v" , err )
420
+ } else if status .Code () != codes .ResourceExhausted {
421
+ t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
422
+ }
423
+ if msgLenErr , ok := OversizedMessageFromError (err ); ! ok {
424
+ t .Fatalf ("failed to retrieve oversized message error" )
425
+ } else {
426
+ rejLen , maxLen := msgLenErr .RejectedLength (), msgLenErr .MaximumLength ()
427
+ if rejLen == 0 {
428
+ t .Fatalf ("zero rejected length in oversized message error" )
429
+ }
430
+ if maxLen == 0 {
431
+ t .Fatalf ("zero maximum length in oversized message error" )
432
+ }
433
+ if rejLen <= maxLen {
434
+ t .Fatalf ("oversized message error rejected < max. length (%d < %d)" ,
435
+ rejLen , maxLen )
436
+ }
437
+ }
438
+ }
439
+ } else {
440
+ if err != nil {
441
+ t .Fatalf ("expected success, got error %v" , err )
442
+ }
443
+ }
444
+
445
+ if err := server .Shutdown (ctx ); err != nil {
446
+ t .Fatal (err )
447
+ }
448
+ if err := <- errs ; err != ErrServerClosed {
449
+ t .Fatal (err )
450
+ }
350
451
}
351
- if err := client .Call (ctx , serviceName , "Test" , tp , tp ); err == nil {
352
- t .Fatalf ("expected error from oversized message" )
353
- } else if status , ok := status .FromError (err ); ! ok {
354
- t .Fatalf ("expected status present in error: %v" , err )
355
- } else if status .Code () != codes .ResourceExhausted {
356
- t .Fatalf ("expected code: %v != %v" , status .Code (), codes .ResourceExhausted )
452
+
453
+ for _ , tc := range []* testCase {
454
+ {
455
+ name : "default limits, fitting request and response" ,
456
+ echoOnce : true ,
457
+ clientLimit : 0 ,
458
+ serverLimit : 0 ,
459
+ requestSize : DefaultMessageLengthLimit - overhead ,
460
+ },
461
+ {
462
+ name : "default limits, only recv side check" ,
463
+ clientLimit : 0 ,
464
+ serverLimit : 0 ,
465
+ requestSize : DefaultMessageLengthLimit - overhead ,
466
+ serverFail : true ,
467
+ },
468
+
469
+ {
470
+ name : "default limits, oversized request" ,
471
+ echoOnce : true ,
472
+ clientLimit : 0 ,
473
+ serverLimit : 0 ,
474
+ requestSize : DefaultMessageLengthLimit ,
475
+ clientFail : true ,
476
+ },
477
+ {
478
+ name : "default limits, oversized response" ,
479
+ clientLimit : 0 ,
480
+ serverLimit : 0 ,
481
+ requestSize : DefaultMessageLengthLimit / 2 ,
482
+ serverFail : true ,
483
+ },
484
+ {
485
+ name : "8K limits, 4K request and response" ,
486
+ echoOnce : true ,
487
+ clientLimit : 8 * 1024 ,
488
+ serverLimit : 8 * 1024 ,
489
+ requestSize : 4 * 1024 ,
490
+ },
491
+ {
492
+ name : "4K limits, barely fitting cc. 4K request and response" ,
493
+ echoOnce : true ,
494
+ clientLimit : 4 * 1024 ,
495
+ serverLimit : 4 * 1024 ,
496
+ requestSize : 4 * 1024 - overhead ,
497
+ },
498
+ {
499
+ name : "4K limits, oversized request on client side" ,
500
+ echoOnce : true ,
501
+ clientLimit : 4 * 1024 ,
502
+ serverLimit : 4 * 1024 ,
503
+ requestSize : 4 * 1024 ,
504
+ clientFail : true ,
505
+ sendFail : true ,
506
+ },
507
+ {
508
+ name : "4K limits, oversized request on server side" ,
509
+ echoOnce : true ,
510
+ clientLimit : 4 * 1024 + overhead ,
511
+ serverLimit : 4 * 1024 ,
512
+ requestSize : 4 * 1024 ,
513
+ serverFail : true ,
514
+ },
515
+ {
516
+ name : "4K limits, oversized response on client side" ,
517
+ clientLimit : 4 * 1024 + overhead ,
518
+ serverLimit : 4 * 1024 ,
519
+ requestSize : 8 * 1024 + overhead ,
520
+ clientFail : true ,
521
+ },
522
+ {
523
+ name : "4K limits, oversized response on server side" ,
524
+ clientLimit : 4 * 1024 + overhead ,
525
+ serverLimit : 4 * 1024 ,
526
+ requestSize : 4 * 1024 ,
527
+ serverFail : true ,
528
+ },
529
+ {
530
+ name : "too small limits, adjusted to minimum accepted limit" ,
531
+ echoOnce : true ,
532
+ clientLimit : 4 ,
533
+ serverLimit : 4 ,
534
+ requestSize : 4 * 1024 - overhead ,
535
+ },
536
+ {
537
+ name : "maximum allowed protocol limit" ,
538
+ echoOnce : true ,
539
+ clientLimit : MaxMessageLengthLimit ,
540
+ serverLimit : MaxMessageLengthLimit ,
541
+ requestSize : MaxMessageLengthLimit - overhead ,
542
+ },
543
+ } {
544
+ t .Run (tc .name , func (t * testing.T ) {
545
+ runTest (t , tc )
546
+ })
357
547
}
548
+ }
358
549
359
- if err := server .Shutdown (ctx ); err != nil {
360
- t .Fatal (err )
550
+ func getWireMessageOverhead (t * testing.T ) int {
551
+ emptyReq , err := codec {}.Marshal (& Request {
552
+ Service : serviceName ,
553
+ Method : "Test" ,
554
+ })
555
+ if err != nil {
556
+ t .Fatalf ("failed to marshal empty request: %v" , err )
361
557
}
362
- if err := <- errs ; err != ErrServerClosed {
363
- t .Fatal (err )
558
+
559
+ emptyRsp , err := codec {}.Marshal (& Response {
560
+ Status : status .New (codes .OK , "" ).Proto (),
561
+ })
562
+ if err != nil {
563
+ t .Fatalf ("failed to marshal empty response: %v" , err )
564
+ }
565
+
566
+ reqLen := len (emptyReq )
567
+ rspLen := len (emptyRsp )
568
+ if reqLen > rspLen {
569
+ return reqLen + messageHeaderLength
364
570
}
571
+
572
+ return rspLen + messageHeaderLength
365
573
}
366
574
367
575
func TestClientEOF (t * testing.T ) {
@@ -582,13 +790,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
582
790
}
583
791
584
792
func newTestListener (t testing.TB ) (string , net.Listener ) {
585
- var prefix string
793
+ var (
794
+ name = t .Name ()
795
+ prefix string
796
+ )
586
797
587
798
// Abstracts sockets are only available on Linux.
588
799
if runtime .GOOS == "linux" {
589
800
prefix = "\x00 "
801
+ } else {
802
+ if split := strings .SplitN (name , "/" , 2 ); len (split ) == 2 {
803
+ name = split [0 ] + "-" + fmt .Sprintf ("%x" , md5 .Sum ([]byte (split [1 ])))
804
+ }
590
805
}
591
- addr := prefix + t . Name ()
806
+ addr := prefix + name
592
807
listener , err := net .Listen ("unix" , addr )
593
808
if err != nil {
594
809
t .Fatal (err )
0 commit comments