Skip to content

Commit 7a70e2e

Browse files
committed
{channel,server}_test: add tests for message limits.
Adjust unit test to accomodate for altered internal interfaces. Add unit tests to exercise the new message size limit options. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent 370ca78 commit 7a70e2e

File tree

2 files changed

+247
-32
lines changed

2 files changed

+247
-32
lines changed

channel_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ import (
3131
func TestReadWriteMessage(t *testing.T) {
3232
var (
3333
w, r = net.Pipe()
34-
ch = newChannel(w)
35-
rch = newChannel(r)
34+
ch = newChannel(w, 0)
35+
rch = newChannel(r, 0)
3636
messages = [][]byte{
3737
[]byte("hello"),
3838
[]byte("this is a test"),
@@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
9090
func TestMessageOversize(t *testing.T) {
9191
var (
9292
w, _ = net.Pipe()
93-
wch = newChannel(w)
93+
wch = newChannel(w, 0)
9494
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
9595
errs = make(chan error, 1)
9696
)

server_test.go

Lines changed: 244 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package ttrpc
1919
import (
2020
"bytes"
2121
"context"
22+
"crypto/md5"
2223
"errors"
2324
"fmt"
2425
"net"
@@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
6162
}
6263

6364
// testingServer is what would be implemented by the user of this package.
64-
type testingServer struct{}
65+
type testingServer struct {
66+
echoOnce bool
67+
}
6568

6669
func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
67-
tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
70+
tp := &internal.TestPayload{}
71+
if s.echoOnce {
72+
tp.Foo = req.Foo
73+
} else {
74+
tp.Foo = strings.Repeat(req.Foo, 2)
75+
}
6876
if dl, ok := ctx.Deadline(); ok {
6977
tp.Deadline = dl.UnixNano()
7078
}
@@ -329,38 +337,238 @@ func TestImmediateServerShutdown(t *testing.T) {
329337
}
330338

331339
func TestOversizeCall(t *testing.T) {
332-
var (
333-
ctx = context.Background()
334-
server = mustServer(t)(NewServer())
335-
addr, listener = newTestListener(t)
336-
errs = make(chan error, 1)
337-
client, cleanup = newTestClient(t, addr)
338-
)
339-
defer cleanup()
340-
defer listener.Close()
341-
go func() {
342-
errs <- server.Serve(ctx, listener)
343-
}()
340+
type testCase struct {
341+
name string
342+
echoOnce bool
343+
clientLimit int
344+
serverLimit int
345+
requestSize int
346+
clientFail bool
347+
sendFail bool
348+
serverFail bool
349+
}
350+
351+
overhead := getWireMessageOverhead(t)
352+
353+
clientOpts := func(tc *testCase) []ClientOpts {
354+
if tc.clientLimit == 0 {
355+
return nil
356+
}
357+
return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)}
358+
}
359+
serverOpts := func(tc *testCase) []ServerOpt {
360+
if tc.serverLimit == 0 {
361+
return nil
362+
}
363+
return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)}
364+
}
365+
366+
runTest := func(t *testing.T, tc *testCase) {
367+
var (
368+
ctx = context.Background()
369+
server = mustServer(t)(NewServer(serverOpts(tc)...))
370+
addr, listener = newTestListener(t)
371+
errs = make(chan error, 1)
372+
client, cleanup = newTestClient(t, addr, clientOpts(tc)...)
373+
)
374+
defer cleanup()
375+
defer listener.Close()
376+
go func() {
377+
errs <- server.Serve(ctx, listener)
378+
}()
379+
380+
registerTestingService(server, &testingServer{echoOnce: tc.echoOnce})
381+
382+
req := &internal.TestPayload{
383+
Foo: strings.Repeat("a", tc.requestSize),
384+
}
385+
rsp := &internal.TestPayload{}
386+
387+
err := client.Call(ctx, serviceName, "Test", req, rsp)
388+
if tc.clientFail {
389+
if err == nil {
390+
t.Fatalf("expected error from oversized message")
391+
} else if status, ok := status.FromError(err); !ok {
392+
t.Fatalf("expected status present in error: %v", err)
393+
} else if status.Code() != codes.ResourceExhausted {
394+
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
395+
}
396+
if tc.sendFail {
397+
var msgLenErr *OversizedMessageErr
398+
if !errors.As(err, &msgLenErr) {
399+
t.Fatalf("failed to retrieve client send OversizedMessageErr")
400+
}
401+
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
402+
if rejLen == 0 {
403+
t.Fatalf("zero rejected length in client send oversized message error")
404+
}
405+
if maxLen == 0 {
406+
t.Fatalf("zero maximum length in client send oversized message error")
407+
}
408+
if rejLen <= maxLen {
409+
t.Fatalf("client send oversized message error rejected < max. length (%d < %d)",
410+
rejLen, maxLen)
411+
}
412+
}
413+
} else if tc.serverFail {
414+
if err == nil {
415+
t.Fatalf("expected error from server-side oversized message")
416+
} else {
417+
if status, ok := status.FromError(err); !ok {
418+
t.Fatalf("expected status present in error: %v", err)
419+
} else if status.Code() != codes.ResourceExhausted {
420+
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
421+
}
422+
if msgLenErr, ok := OversizedMessageFromError(err); !ok {
423+
t.Fatalf("failed to retrieve oversized message error")
424+
} else {
425+
rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength()
426+
if rejLen == 0 {
427+
t.Fatalf("zero rejected length in oversized message error")
428+
}
429+
if maxLen == 0 {
430+
t.Fatalf("zero maximum length in oversized message error")
431+
}
432+
if rejLen <= maxLen {
433+
t.Fatalf("oversized message error rejected < max. length (%d < %d)",
434+
rejLen, maxLen)
435+
}
436+
}
437+
}
438+
} else {
439+
if err != nil {
440+
t.Fatalf("expected success, got error %v", err)
441+
}
442+
}
344443

345-
registerTestingService(server, &testingServer{})
444+
if err := server.Shutdown(ctx); err != nil {
445+
t.Fatal(err)
446+
}
447+
if err := <-errs; err != ErrServerClosed {
448+
t.Fatal(err)
449+
}
450+
}
346451

347-
tp := &internal.TestPayload{
348-
Foo: strings.Repeat("a", 1+messageLengthMax),
452+
for _, tc := range []*testCase{
453+
{
454+
name: "default limits, fitting request and response",
455+
echoOnce: true,
456+
clientLimit: 0,
457+
serverLimit: 0,
458+
requestSize: DefaultMessageLengthLimit - overhead,
459+
},
460+
{
461+
name: "default limits, only recv side check",
462+
clientLimit: 0,
463+
serverLimit: 0,
464+
requestSize: DefaultMessageLengthLimit - overhead,
465+
serverFail: true,
466+
},
467+
468+
{
469+
name: "default limits, oversized request",
470+
echoOnce: true,
471+
clientLimit: 0,
472+
serverLimit: 0,
473+
requestSize: DefaultMessageLengthLimit,
474+
clientFail: true,
475+
},
476+
{
477+
name: "default limits, oversized response",
478+
clientLimit: 0,
479+
serverLimit: 0,
480+
requestSize: DefaultMessageLengthLimit / 2,
481+
serverFail: true,
482+
},
483+
{
484+
name: "8K limits, 4K request and response",
485+
echoOnce: true,
486+
clientLimit: 8 * 1024,
487+
serverLimit: 8 * 1024,
488+
requestSize: 4 * 1024,
489+
},
490+
{
491+
name: "4K limits, barely fitting cc. 4K request and response",
492+
echoOnce: true,
493+
clientLimit: 4 * 1024,
494+
serverLimit: 4 * 1024,
495+
requestSize: 4*1024 - overhead,
496+
},
497+
{
498+
name: "4K limits, oversized request on client side",
499+
echoOnce: true,
500+
clientLimit: 4 * 1024,
501+
serverLimit: 4 * 1024,
502+
requestSize: 4 * 1024,
503+
clientFail: true,
504+
sendFail: true,
505+
},
506+
{
507+
name: "4K limits, oversized request on server side",
508+
echoOnce: true,
509+
clientLimit: 4*1024 + overhead,
510+
serverLimit: 4 * 1024,
511+
requestSize: 4 * 1024,
512+
serverFail: true,
513+
},
514+
{
515+
name: "4K limits, oversized response on client side",
516+
clientLimit: 4*1024 + overhead,
517+
serverLimit: 4 * 1024,
518+
requestSize: 8*1024 + overhead,
519+
clientFail: true,
520+
},
521+
{
522+
name: "4K limits, oversized response on server side",
523+
clientLimit: 4*1024 + overhead,
524+
serverLimit: 4 * 1024,
525+
requestSize: 4 * 1024,
526+
serverFail: true,
527+
},
528+
{
529+
name: "too small limits, adjusted to minimum accepted limit",
530+
echoOnce: true,
531+
clientLimit: 4,
532+
serverLimit: 4,
533+
requestSize: 4*1024 - overhead,
534+
},
535+
{
536+
name: "maximum allowed protocol limit",
537+
echoOnce: true,
538+
clientLimit: MaxMessageLengthLimit,
539+
serverLimit: MaxMessageLengthLimit,
540+
requestSize: MaxMessageLengthLimit - overhead,
541+
},
542+
} {
543+
t.Run(tc.name, func(t *testing.T) {
544+
runTest(t, tc)
545+
})
349546
}
350-
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
351-
t.Fatalf("expected error from oversized message")
352-
} else if status, ok := status.FromError(err); !ok {
353-
t.Fatalf("expected status present in error: %v", err)
354-
} else if status.Code() != codes.ResourceExhausted {
355-
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
547+
}
548+
549+
func getWireMessageOverhead(t *testing.T) int {
550+
emptyReq, err := codec{}.Marshal(&Request{
551+
Service: serviceName,
552+
Method: "Test",
553+
})
554+
if err != nil {
555+
t.Fatalf("failed to marshal empty request: %v", err)
356556
}
357557

358-
if err := server.Shutdown(ctx); err != nil {
359-
t.Fatal(err)
558+
emptyRsp, err := codec{}.Marshal(&Response{
559+
Status: status.New(codes.OK, "").Proto(),
560+
})
561+
if err != nil {
562+
t.Fatalf("failed to marshal empty response: %v", err)
360563
}
361-
if err := <-errs; err != ErrServerClosed {
362-
t.Fatal(err)
564+
565+
reqLen := len(emptyReq)
566+
rspLen := len(emptyRsp)
567+
if reqLen > rspLen {
568+
return reqLen + messageHeaderLength
363569
}
570+
571+
return rspLen + messageHeaderLength
364572
}
365573

366574
func TestClientEOF(t *testing.T) {
@@ -581,13 +789,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
581789
}
582790

583791
func newTestListener(t testing.TB) (string, net.Listener) {
584-
var prefix string
792+
var (
793+
name = t.Name()
794+
prefix string
795+
)
585796

586797
// Abstracts sockets are only available on Linux.
587798
if runtime.GOOS == "linux" {
588799
prefix = "\x00"
800+
} else {
801+
if split := strings.SplitN(name, "/", 2); len(split) == 2 {
802+
name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1])))
803+
}
589804
}
590-
addr := prefix + t.Name()
805+
addr := prefix + name
591806
listener, err := net.Listen("unix", addr)
592807
if err != nil {
593808
t.Fatal(err)

0 commit comments

Comments
 (0)