Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 94 additions & 141 deletions binarylog/binarylog_end2end_test.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you are here, would you mind changing the clientConn method on test as follows:

  • In test.startServer, call stubserver.Start() instead of StartServer()
    • This also creates a ClientConn to the newly created stub server
  • In test.clientConn, return stubserver.CC
  • Or, optionally, get rid of the test.clientConn method entirely and use test.ss.CC directly, if possible.
  • Either way is fine

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍
I got rid of test.clientConn (and test.tearDown): d212ea6

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/binarylog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
iblog "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/grpctest"
Expand Down Expand Up @@ -130,131 +129,24 @@ func payloadToID(p *testpb.Payload) int32 {
return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
}

type testServer struct {
testgrpc.UnimplementedTestServiceServer
te *test
}

// UnaryCall echoes the request payload back to the client. Any metadata
// on the incoming context is echoed as the response header, and
// testTrailerMetadata is attached as the trailer. A payload that decodes
// to errorID yields a plain (non-status) error.
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		if err := grpc.SendHeader(ctx, md); err != nil {
			return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
		}
		if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
			return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
		}
	}
	if id := payloadToID(in.Payload); id == errorID {
		return nil, fmt.Errorf("got error id: %v", id)
	}
	return &testpb.SimpleResponse{Payload: in.Payload}, nil
}

// FullDuplexCall echoes every request message back on the stream until
// the client closes its send side. Incoming metadata, when present, is
// sent back as the header and testTrailerMetadata as the trailer. A
// payload decoding to errorID aborts the stream with a plain error.
func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
		if err := stream.SendHeader(md); err != nil {
			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
		}
		stream.SetTrailer(testTrailerMetadata)
	}
	for {
		req, err := stream.Recv()
		if err == io.EOF {
			// Client finished sending; end the stream cleanly.
			return nil
		}
		if err != nil {
			return err
		}
		if id := payloadToID(req.Payload); id == errorID {
			return fmt.Errorf("got error id: %v", id)
		}
		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: req.Payload}); err != nil {
			return err
		}
	}
}

// StreamingInputCall drains the client stream and then replies with a
// single response carrying an aggregated payload size of zero. Incoming
// metadata, when present, is echoed as the header and
// testTrailerMetadata is attached as the trailer. A payload decoding to
// errorID aborts the stream with a plain error.
func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
		if err := stream.SendHeader(md); err != nil {
			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
		}
		stream.SetTrailer(testTrailerMetadata)
	}
	for {
		req, err := stream.Recv()
		if err == io.EOF {
			// All client messages consumed; reply and close.
			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
		}
		if err != nil {
			return err
		}
		if id := payloadToID(req.Payload); id == errorID {
			return fmt.Errorf("got error id: %v", id)
		}
	}
}

// StreamingOutputCall sends five copies of the request payload back on
// the stream. Incoming metadata, when present, is echoed as the header
// and testTrailerMetadata is attached as the trailer. A payload decoding
// to errorID aborts the stream with a plain error before anything is
// sent.
func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
		if err := stream.SendHeader(md); err != nil {
			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
		}
		stream.SetTrailer(testTrailerMetadata)
	}
	if id := payloadToID(in.Payload); id == errorID {
		return fmt.Errorf("got error id: %v", id)
	}
	// Echo the payload a fixed number of times.
	const respCount = 5
	for i := 0; i < respCount; i++ {
		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
			return err
		}
	}
	return nil
}

// test is an end-to-end test. It should be created with the newTest
// func, modified as needed, and then started with its startServer method.
// It should be cleaned up with the tearDown method.
type test struct {
t *testing.T

testService testgrpc.TestServiceServer // nil means none
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
// ss and srvAddr are set once startServer is called.
ss *stubserver.StubServer
srvAddr string // Server IP without port.
srvIP net.IP
srvPort int

cc *grpc.ClientConn // nil until requested via clientConn

// Fields for client address. Set by the service handler.
clientAddrMu sync.Mutex
clientIP net.IP
clientPort int
}

// tearDown closes the cached client connection, if one was created, and
// stops the server.
func (te *test) tearDown() {
	if cc := te.cc; cc != nil {
		te.cc = nil
		cc.Close()
	}
	te.srv.Stop()
}

// newTest returns a new test using the provided testing.T and
// environment. It is returned with default values. Tests should
// modify it before calling its startServer and clientConn methods.
Expand Down Expand Up @@ -284,8 +176,7 @@ func (lw *listenerWrapper) Accept() (net.Conn, error) {

// startServer starts a gRPC server listening. Callers should defer a
// call to te.tearDown to clean up.
func (te *test) startServer(ts testgrpc.TestServiceServer) {
te.testService = ts
func (te *test) startServer() {
lis, err := net.Listen("tcp", "localhost:0")

lis = &listenerWrapper{
Expand All @@ -296,33 +187,96 @@ func (te *test) startServer(ts testgrpc.TestServiceServer) {
if err != nil {
te.t.Fatalf("Failed to listen: %v", err)
}
var opts []grpc.ServerOption
s := grpc.NewServer(opts...)
te.srv = s
if te.testService != nil {
testgrpc.RegisterTestServiceServer(s, te.testService)
}

go s.Serve(lis)
te.ss = &stubserver.StubServer{
Listener: lis,
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if err := grpc.SendHeader(ctx, md); err != nil {
return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
}
if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
}
}
if id := payloadToID(in.Payload); id == errorID {
return nil, fmt.Errorf("got error id: %v", id)
}
return &testpb.SimpleResponse{Payload: in.Payload}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
}
stream.SetTrailer(testTrailerMetadata)
}
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
if id := payloadToID(in.Payload); id == errorID {
return fmt.Errorf("got error id: %v", id)
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
return err
}
}
},
StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
}
stream.SetTrailer(testTrailerMetadata)
}
for {
in, err := stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
}
if err != nil {
return err
}
if id := payloadToID(in.Payload); id == errorID {
return fmt.Errorf("got error id: %v", id)
}
}
},
StreamingOutputCallF: func(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
}
stream.SetTrailer(testTrailerMetadata)
}
if id := payloadToID(in.Payload); id == errorID {
return fmt.Errorf("got error id: %v", id)
}
for i := 0; i < 5; i++ {
if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
return err
}
}
return nil
},
}
if err := te.ss.Start(nil); err != nil {
te.t.Fatalf("Failed to start server: %v", err)
}
te.srvAddr = lis.Addr().String()
te.srvIP = lis.Addr().(*net.TCPAddr).IP
te.srvPort = lis.Addr().(*net.TCPAddr).Port
}

// clientConn returns a client connection to te.srvAddr, creating and
// caching it on first use. It must be called only after startServer has
// set te.srvAddr. Creation failure aborts the test via te.t.Fatalf.
func (te *test) clientConn() *grpc.ClientConn {
	if te.cc != nil {
		return te.cc
	}
	// grpc.WithBlock is a Dial-only option and has no effect with
	// grpc.NewClient (which always connects lazily), so it is not passed.
	opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}

	var err error
	te.cc, err = grpc.NewClient(te.srvAddr, opts...)
	if err != nil {
		te.t.Fatalf("grpc.NewClient(%q) = %v", te.srvAddr, err)
	}
	return te.cc
}

type rpcType int

const (
Expand All @@ -345,7 +299,7 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
req *testpb.SimpleRequest
err error
)
tc := testgrpc.NewTestServiceClient(te.clientConn())
tc := testgrpc.NewTestServiceClient(te.ss.CC)
if c.success {
req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
} else {
Expand All @@ -365,7 +319,7 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []prot
resps []proto.Message
err error
)
tc := testgrpc.NewTestServiceClient(te.clientConn())
tc := testgrpc.NewTestServiceClient(te.ss.CC)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, testMetadata)
Expand Down Expand Up @@ -414,7 +368,7 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, proto.Message
resp *testpb.StreamingInputCallResponse
err error
)
tc := testgrpc.NewTestServiceClient(te.clientConn())
tc := testgrpc.NewTestServiceClient(te.ss.CC)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, testMetadata)
Expand Down Expand Up @@ -447,7 +401,7 @@ func (te *test) doServerStreamCall(c *rpcConfig) (proto.Message, []proto.Message
err error
)

tc := testgrpc.NewTestServiceClient(te.clientConn())
tc := testgrpc.NewTestServiceClient(te.ss.CC)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, testMetadata)
Expand Down Expand Up @@ -796,8 +750,8 @@ func (ed *expectedData) toServerLogEntries() []*binlogpb.GrpcLogEntry {

func runRPCs(t *testing.T, cc *rpcConfig) *expectedData {
te := newTest(t)
te.startServer(&testServer{te: te})
defer te.tearDown()
te.startServer()
defer te.ss.Stop()

expect := &expectedData{
te: te,
Expand Down Expand Up @@ -830,8 +784,7 @@ func runRPCs(t *testing.T, cc *rpcConfig) *expectedData {
if cc.success != (expect.err == nil) {
t.Fatalf("cc.success: %v, got error: %v", cc.success, expect.err)
}
te.cc.Close()
te.srv.GracefulStop() // Wait for the server to stop.
te.ss.S.GracefulStop() // Wait for the server to stop.

return expect
}
Expand Down
Loading