Skip to content

Commit 9e19b2d

Browse files
committed
{channel,server}_test: add tests for message limits.
Adjust unit test to accomodate for altered internal interfaces. Add unit tests to exercise the new message size limit options. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent e1f03b3 commit 9e19b2d

File tree

2 files changed

+247
-32
lines changed

2 files changed

+247
-32
lines changed

channel_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ import (
3131
func TestReadWriteMessage(t *testing.T) {
3232
var (
3333
w, r = net.Pipe()
34-
ch = newChannel(w)
35-
rch = newChannel(r)
34+
ch = newChannel(w, 0)
35+
rch = newChannel(r, 0)
3636
messages = [][]byte{
3737
[]byte("hello"),
3838
[]byte("this is a test"),
@@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
9090
func TestMessageOversize(t *testing.T) {
9191
var (
9292
w, _ = net.Pipe()
93-
wch = newChannel(w)
93+
wch = newChannel(w, 0)
9494
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
9595
errs = make(chan error, 1)
9696
)

server_test.go

Lines changed: 244 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package ttrpc
1919
import (
2020
"bytes"
2121
"context"
22+
"crypto/md5"
2223
"errors"
2324
"fmt"
2425
"net"
@@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
6162
}
6263

6364
// testingServer is what would be implemented by the user of this package.
64-
type testingServer struct{}
65+
type testingServer struct {
66+
echoOnce bool
67+
}
6568

6669
func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
67-
tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
70+
tp := &internal.TestPayload{}
71+
if s.echoOnce {
72+
tp.Foo = req.Foo
73+
} else {
74+
tp.Foo = strings.Repeat(req.Foo, 2)
75+
}
6876
if dl, ok := ctx.Deadline(); ok {
6977
tp.Deadline = dl.UnixNano()
7078
}
@@ -330,38 +338,238 @@ func TestImmediateServerShutdown(t *testing.T) {
330338
}
331339

332340
func TestOversizeCall(t *testing.T) {
333-
var (
334-
ctx = context.Background()
335-
server = mustServer(t)(NewServer())
336-
addr, listener = newTestListener(t)
337-
errs = make(chan error, 1)
338-
client, cleanup = newTestClient(t, addr)
339-
)
340-
defer cleanup()
341-
defer listener.Close()
342-
go func() {
343-
errs <- server.Serve(ctx, listener)
344-
}()
341+
type testCase struct {
342+
name string
343+
echoOnce bool
344+
clientLimit int
345+
serverLimit int
346+
requestSize int
347+
clientFail bool
348+
sendFail bool
349+
serverFail bool
350+
}
351+
352+
overhead := getWireMessageOverhead(t)
353+
354+
clientOpts := func(tc *testCase) []ClientOpts {
355+
if tc.clientLimit == 0 {
356+
return nil
357+
}
358+
return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)}
359+
}
360+
serverOpts := func(tc *testCase) []ServerOpt {
361+
if tc.serverLimit == 0 {
362+
return nil
363+
}
364+
return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)}
365+
}
345366

346-
registerTestingService(server, &testingServer{})
367+
runTest := func(t *testing.T, tc *testCase) {
368+
var (
369+
ctx = context.Background()
370+
server = mustServer(t)(NewServer(serverOpts(tc)...))
371+
addr, listener = newTestListener(t)
372+
errs = make(chan error, 1)
373+
client, cleanup = newTestClient(t, addr, clientOpts(tc)...)
374+
)
375+
defer cleanup()
376+
defer listener.Close()
377+
go func() {
378+
errs <- server.Serve(ctx, listener)
379+
}()
380+
381+
registerTestingService(server, &testingServer{echoOnce: tc.echoOnce})
347382

348-
tp := &internal.TestPayload{
349-
Foo: strings.Repeat("a", 1+messageLengthMax),
383+
req := &internal.TestPayload{
384+
Foo: strings.Repeat("a", tc.requestSize),
385+
}
386+
rsp := &internal.TestPayload{}
387+
388+
err := client.Call(ctx, serviceName, "Test", req, rsp)
389+
if tc.clientFail {
390+
if err == nil {
391+
t.Fatalf("expected error from oversized message")
392+
} else if status, ok := status.FromError(err); !ok {
393+
t.Fatalf("expected status present in error: %v", err)
394+
} else if status.Code() != codes.ResourceExhausted {
395+
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
396+
}
397+
if tc.sendFail {
398+
var msgLenErr *OversizedMessageErr
399+
if !errors.As(err, &msgLenErr) {
400+
t.Fatalf("failed to retrieve client send OversizedMessageErr")
401+
}
402+
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
403+
if rejLen == 0 {
404+
t.Fatalf("zero rejected length in client send oversized message error")
405+
}
406+
if maxLen == 0 {
407+
t.Fatalf("zero maximum length in client send oversized message error")
408+
}
409+
if rejLen <= maxLen {
410+
t.Fatalf("client send oversized message error rejected < max. length (%d < %d)",
411+
rejLen, maxLen)
412+
}
413+
}
414+
} else if tc.serverFail {
415+
if err == nil {
416+
t.Fatalf("expected error from server-side oversized message")
417+
} else {
418+
if status, ok := status.FromError(err); !ok {
419+
t.Fatalf("expected status present in error: %v", err)
420+
} else if status.Code() != codes.ResourceExhausted {
421+
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
422+
}
423+
if msgLenErr, ok := OversizedMessageFromError(err); !ok {
424+
t.Fatalf("failed to retrieve oversized message error")
425+
} else {
426+
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
427+
if rejLen == 0 {
428+
t.Fatalf("zero rejected length in oversized message error")
429+
}
430+
if maxLen == 0 {
431+
t.Fatalf("zero maximum length in oversized message error")
432+
}
433+
if rejLen <= maxLen {
434+
t.Fatalf("oversized message error rejected < max. length (%d < %d)",
435+
rejLen, maxLen)
436+
}
437+
}
438+
}
439+
} else {
440+
if err != nil {
441+
t.Fatalf("expected success, got error %v", err)
442+
}
443+
}
444+
445+
if err := server.Shutdown(ctx); err != nil {
446+
t.Fatal(err)
447+
}
448+
if err := <-errs; err != ErrServerClosed {
449+
t.Fatal(err)
450+
}
350451
}
351-
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
352-
t.Fatalf("expected error from oversized message")
353-
} else if status, ok := status.FromError(err); !ok {
354-
t.Fatalf("expected status present in error: %v", err)
355-
} else if status.Code() != codes.ResourceExhausted {
356-
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
452+
453+
for _, tc := range []*testCase{
454+
{
455+
name: "default limits, fitting request and response",
456+
echoOnce: true,
457+
clientLimit: 0,
458+
serverLimit: 0,
459+
requestSize: DefaultMessageLengthLimit - overhead,
460+
},
461+
{
462+
name: "default limits, only recv side check",
463+
clientLimit: 0,
464+
serverLimit: 0,
465+
requestSize: DefaultMessageLengthLimit - overhead,
466+
serverFail: true,
467+
},
468+
469+
{
470+
name: "default limits, oversized request",
471+
echoOnce: true,
472+
clientLimit: 0,
473+
serverLimit: 0,
474+
requestSize: DefaultMessageLengthLimit,
475+
clientFail: true,
476+
},
477+
{
478+
name: "default limits, oversized response",
479+
clientLimit: 0,
480+
serverLimit: 0,
481+
requestSize: DefaultMessageLengthLimit / 2,
482+
serverFail: true,
483+
},
484+
{
485+
name: "8K limits, 4K request and response",
486+
echoOnce: true,
487+
clientLimit: 8 * 1024,
488+
serverLimit: 8 * 1024,
489+
requestSize: 4 * 1024,
490+
},
491+
{
492+
name: "4K limits, barely fitting cc. 4K request and response",
493+
echoOnce: true,
494+
clientLimit: 4 * 1024,
495+
serverLimit: 4 * 1024,
496+
requestSize: 4*1024 - overhead,
497+
},
498+
{
499+
name: "4K limits, oversized request on client side",
500+
echoOnce: true,
501+
clientLimit: 4 * 1024,
502+
serverLimit: 4 * 1024,
503+
requestSize: 4 * 1024,
504+
clientFail: true,
505+
sendFail: true,
506+
},
507+
{
508+
name: "4K limits, oversized request on server side",
509+
echoOnce: true,
510+
clientLimit: 4*1024 + overhead,
511+
serverLimit: 4 * 1024,
512+
requestSize: 4 * 1024,
513+
serverFail: true,
514+
},
515+
{
516+
name: "4K limits, oversized response on client side",
517+
clientLimit: 4*1024 + overhead,
518+
serverLimit: 4 * 1024,
519+
requestSize: 8*1024 + overhead,
520+
clientFail: true,
521+
},
522+
{
523+
name: "4K limits, oversized response on server side",
524+
clientLimit: 4*1024 + overhead,
525+
serverLimit: 4 * 1024,
526+
requestSize: 4 * 1024,
527+
serverFail: true,
528+
},
529+
{
530+
name: "too small limits, adjusted to minimum accepted limit",
531+
echoOnce: true,
532+
clientLimit: 4,
533+
serverLimit: 4,
534+
requestSize: 4*1024 - overhead,
535+
},
536+
{
537+
name: "maximum allowed protocol limit",
538+
echoOnce: true,
539+
clientLimit: MaxMessageLengthLimit,
540+
serverLimit: MaxMessageLengthLimit,
541+
requestSize: MaxMessageLengthLimit - overhead,
542+
},
543+
} {
544+
t.Run(tc.name, func(t *testing.T) {
545+
runTest(t, tc)
546+
})
357547
}
548+
}
358549

359-
if err := server.Shutdown(ctx); err != nil {
360-
t.Fatal(err)
550+
func getWireMessageOverhead(t *testing.T) int {
551+
emptyReq, err := codec{}.Marshal(&Request{
552+
Service: serviceName,
553+
Method: "Test",
554+
})
555+
if err != nil {
556+
t.Fatalf("failed to marshal empty request: %v", err)
361557
}
362-
if err := <-errs; err != ErrServerClosed {
363-
t.Fatal(err)
558+
559+
emptyRsp, err := codec{}.Marshal(&Response{
560+
Status: status.New(codes.OK, "").Proto(),
561+
})
562+
if err != nil {
563+
t.Fatalf("failed to marshal empty response: %v", err)
564+
}
565+
566+
reqLen := len(emptyReq)
567+
rspLen := len(emptyRsp)
568+
if reqLen > rspLen {
569+
return reqLen + messageHeaderLength
364570
}
571+
572+
return rspLen + messageHeaderLength
365573
}
366574

367575
func TestClientEOF(t *testing.T) {
@@ -582,13 +790,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
582790
}
583791

584792
func newTestListener(t testing.TB) (string, net.Listener) {
585-
var prefix string
793+
var (
794+
name = t.Name()
795+
prefix string
796+
)
586797

587798
// Abstracts sockets are only available on Linux.
588799
if runtime.GOOS == "linux" {
589800
prefix = "\x00"
801+
} else {
802+
if split := strings.SplitN(name, "/", 2); len(split) == 2 {
803+
name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1])))
804+
}
590805
}
591-
addr := prefix + t.Name()
806+
addr := prefix + name
592807
listener, err := net.Listen("unix", addr)
593808
if err != nil {
594809
t.Fatal(err)

0 commit comments

Comments
 (0)