diff --git a/backend/client/cmd/main.go b/backend/client/cmd/main.go index 4e45291..3a4ea96 100644 --- a/backend/client/cmd/main.go +++ b/backend/client/cmd/main.go @@ -11,13 +11,11 @@ import ( "github.com/joho/godotenv" "github.com/kelseyhightower/envconfig" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" at "mosaic-client.com/gen/audio_transcription" cb "mosaic-client.com/gen/conversation_briefing" fd "mosaic-client.com/gen/face_detection" "mosaic-client.com/internal/handler" - "mosaic-client.com/internal/middleware" + "mosaic-client.com/internal/observability" ) type Config struct { @@ -31,37 +29,15 @@ func main() { log.Fatalf("failed to load config values: %v", err) } - var handler slog.Handler - if cfg.ProdMode { - handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}) - } else { - handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) - } - logger := slog.New(handler).With("service", "client") + logger := observability.StructuredLogger(cfg.ProdMode) logger.Info("Starting Mosaic backend server...") - audioConn, err := grpc.NewClient("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - logger.Error("Unable to start audio gRPC client") - os.Exit(1) - } + gwClient := handler.NewClient("https://api.verturus.com", handler.DefaultRetryConfig) - faceConn, err := grpc.NewClient("localhost:40040", grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - logger.Error("Unable to start audio gRPC client") - os.Exit(1) - } - - briefingConn, err := grpc.NewClient("localhost:30030", grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - logger.Error("Unable to start briefing generation gRPC client") - os.Exit(1) - } - - atClient := at.NewAudioTranscriptionServiceClient(audioConn) - fdClient := fd.NewFaceDetectionServiceClient(faceConn) - cbClient := cb.NewConversationBriefingServiceClient(briefingConn) + atClient := at.NewAudioTranscriptionServiceClient(gwClient) + fdClient := fd.NewFaceDetectionServiceClient(gwClient) + cbClient := cb.NewConversationBriefingServiceClient(gwClient) go websocketServer(cfg, logger, atClient, fdClient, cbClient) @@ -74,14 +50,9 @@ func main() { <-sigChan logger.Info("Shutting down gracefully...") - err = audioConn.Close() - if err != nil { - logger.Warn("audio grpc connection not closed properly", "err", err) - } - - err = faceConn.Close() + err = gwClient.Close() if err != nil { - logger.Warn("face detection grpc connection not closed properly", "err", err) + logger.Warn("grpc-web connection not closed properly", "err", err) } logger.Debug("Closed gRPC connection") } @@ -107,7 +78,7 @@ func websocketServer( server := http.Server{ Addr: fmt.Sprintf(":%s", cfg.ServerPort), - Handler: middleware.Logging(router), + Handler: observability.HTTPLogger(router), } logger.Debug("[client] Server running on", "port", cfg.ServerPort) diff --git a/backend/client/cmd/makefile b/backend/client/cmd/makefile index e62aca0..dd28aca 100644 --- a/backend/client/cmd/makefile +++ b/backend/client/cmd/makefile @@ -1,15 +1,12 @@ test_all: unit integration -PKGS := ../internal/handler/... ../internal/middleware/... ../internal/service/... +PKGS := ../internal/handler/... ../internal/observability/... ../internal/service/... coverage: go test -tags=unit,integration ${PKGS} -v -coverprofile=coverage.out -covermode=atomic format: - go fmt . - go fmt ../internal/middleware/... . - go fmt ../internal/handler/... . - go fmt ../internal/service/... . + go fmt ${PKGS} integration: go test -tags integration ${PKGS} diff --git a/backend/client/internal/handler/grpc_web_client.go b/backend/client/internal/handler/grpc_web_client.go new file mode 100644 index 0000000..d3ef3a0 --- /dev/null +++ b/backend/client/internal/handler/grpc_web_client.go @@ -0,0 +1,201 @@ +package handler + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +type RetryConfig struct { + MaxAttempts int + InitialBackoff time.Duration + MaxBackoff time.Duration + BackoffMultiplier float64 + RetryableCodes []codes.Code +} + +var DefaultRetryConfig = RetryConfig{ + MaxAttempts: 4, + InitialBackoff: time.Second, + MaxBackoff: 10 * time.Second, + BackoffMultiplier: 2.0, + RetryableCodes: []codes.Code{codes.Unavailable}, +} + +var NoRetry = RetryConfig{MaxAttempts: 1} + +type Client struct { + baseURL string + httpClient *http.Client + retry RetryConfig +} + +func NewClient(baseURL string, retry RetryConfig) *Client { + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + httpClient: &http.Client{}, + retry: retry, + } +} + +func (c *Client) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + backoff := c.retry.InitialBackoff + var lastErr error + + for attempt := 0; attempt < c.retry.MaxAttempts; attempt++ { + if attempt > 0 { + select { + case <-ctx.Done(): + return status.FromContextError(ctx.Err()).Err() + case <-time.After(backoff): + } + backoff = time.Duration(float64(backoff) * c.retry.BackoffMultiplier) + if backoff > c.retry.MaxBackoff { + backoff = c.retry.MaxBackoff + } + } + + lastErr = c.invoke(ctx, method, args, reply) + if lastErr == nil { + return nil + } + if !c.isRetryable(lastErr) { + return lastErr + } + } + return lastErr +} + +func (c *Client) isRetryable(err error) bool { + s, ok := status.FromError(err) + if !ok { + return false + } + for _, code := range c.retry.RetryableCodes { + if s.Code() == code { + return true + } + } + return false +} + +func (c *Client) invoke(ctx context.Context, method string, args, reply interface{}) error { + reqMsg, ok := args.(proto.Message) + if !ok { + return status.Error(codes.Internal, "args must be proto.Message") + } + reqBytes, err := proto.Marshal(reqMsg) + if err != nil { + return status.Errorf(codes.Internal, "marshal request: %v", err) + } + + frame := make([]byte, 5+len(reqBytes)) + frame[0] = 0 + binary.BigEndian.PutUint32(frame[1:5], uint32(len(reqBytes))) + copy(frame[5:], reqBytes) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+method, bytes.NewReader(frame)) + if err != nil { + return status.Errorf(codes.Internal, "create request: %v", err) + } + req.Header.Set("Content-Type", "application/grpc-web+proto") + req.Header.Set("X-Grpc-Web", "1") + + resp, err := c.httpClient.Do(req) + if err != nil { + return status.Errorf(codes.Unavailable, "http request: %v", err) + } + defer func() { + err := resp.Body.Close() + if err != nil { + fmt.Printf("error closing resp body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return status.Errorf(codes.Internal, "http status: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return status.Errorf(codes.Internal, "read body: %v", err) + } + + var grpcStatus codes.Code + var grpcMessage string + dataRead := false + pos := 0 + + for pos < len(body) { + if pos+5 > len(body) { + return status.Error(codes.Internal, "truncated frame header") + } + frameType := body[pos] + frameLen := int(binary.BigEndian.Uint32(body[pos+1 : pos+5])) + pos += 5 + + if pos+frameLen > len(body) { + return status.Error(codes.Internal, "truncated frame body") + } + frameData := body[pos : pos+frameLen] + pos += frameLen + + switch frameType { + case 0x00: + replyMsg, ok := reply.(proto.Message) + if !ok { + return status.Error(codes.Internal, "reply must be proto.Message") + } + if err := proto.Unmarshal(frameData, replyMsg); err != nil { + return status.Errorf(codes.Internal, "unmarshal response: %v", err) + } + dataRead = true + case 0x80: + for _, line := range strings.Split(string(frameData), "\r\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + switch strings.TrimSpace(k) { + case "grpc-status": + var code int + _, err := fmt.Sscanf(strings.TrimSpace(v), "%d", &code) + if err != nil { + return status.Error(codes.Internal, "error scanning") + } + grpcStatus = codes.Code(code) + case "grpc-message": + grpcMessage, _ = url.PathUnescape(strings.TrimSpace(v)) + } + } + } + } + + if !dataRead && grpcStatus == codes.OK { + return status.Error(codes.Internal, "no data frame in response") + } + if grpcStatus != codes.OK { + return status.Error(grpcStatus, grpcMessage) + } + return nil +} + +func (c *Client) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, status.Error(codes.Unimplemented, "streaming not supported") +} + +func (c *Client) Close() error { + c.httpClient.CloseIdleConnections() + return nil +} diff --git a/backend/client/internal/handler/grpc_web_client_integration_test.go b/backend/client/internal/handler/grpc_web_client_integration_test.go new file mode 100644 index 0000000..3ba7a69 --- /dev/null +++ b/backend/client/internal/handler/grpc_web_client_integration_test.go @@ -0,0 +1,139 @@ +//go:build integration + +package handler_test + +import ( + "context" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + fd "mosaic-client.com/gen/face_detection" + "mosaic-client.com/internal/handler" + "mosaic-client.com/internal/test" +) + +// fastRetry is a retry config with short backoffs so tests run quickly. +var fastRetry = handler.RetryConfig{ + MaxAttempts: 3, + InitialBackoff: 20 * time.Millisecond, + MaxBackoff: 100 * time.Millisecond, + BackoffMultiplier: 2.0, + RetryableCodes: []codes.Code{codes.Unavailable}, +} + +func TestRetry(t *testing.T) { + t.Run("attempt count and final error", func(t *testing.T) { + tests := []struct { + name string + config handler.RetryConfig + responseCode codes.Code + wantAttempts int32 + wantCode codes.Code + }{ + { + name: "Unavailable exhausts MaxAttempts", + config: fastRetry, + responseCode: codes.Unavailable, + wantAttempts: int32(fastRetry.MaxAttempts), + wantCode: codes.Unavailable, + }, + { + name: "non-retryable NotFound returns on first attempt", + config: fastRetry, + responseCode: codes.NotFound, + wantAttempts: 1, + wantCode: codes.NotFound, + }, + { + name: "NoRetry config makes exactly one attempt", + config: handler.NoRetry, + responseCode: codes.Unavailable, + wantAttempts: 1, + wantCode: codes.Unavailable, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var attempt atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt.Add(1) + w.Header().Set("Content-Type", "application/grpc-web+proto") + w.WriteHeader(http.StatusOK) + w.Write(test.GrpcTrailerFrame(tc.responseCode, "error")) //nolint:errcheck + })) + defer srv.Close() + + client := handler.NewClient(srv.URL, tc.config) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + test.RequireGRPCCode(t, err, tc.wantCode) + assert.Equal(t, tc.wantAttempts, attempt.Load()) + }) + } + }) + + t.Run("succeeds once server recovers after two Unavailable responses", func(t *testing.T) { + var attempt atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempt.Add(1) + w.Header().Set("Content-Type", "application/grpc-web+proto") + w.WriteHeader(http.StatusOK) + if n < 3 { + w.Write(test.GrpcTrailerFrame(codes.Unavailable, "not ready")) //nolint:errcheck + } else { + w.Write(test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{FaceDetected: true})) //nolint:errcheck + } + })) + defer srv.Close() + + client := handler.NewClient(srv.URL, fastRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + require.NoError(t, err) + assert.True(t, reply.FaceDetected) + assert.Equal(t, int32(3), attempt.Load()) + }) + + t.Run("context cancelled during backoff returns before next attempt", func(t *testing.T) { + slowRetry := handler.RetryConfig{ + MaxAttempts: 3, + InitialBackoff: 300 * time.Millisecond, + MaxBackoff: 1 * time.Second, + BackoffMultiplier: 2.0, + RetryableCodes: []codes.Code{codes.Unavailable}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/grpc-web+proto") + w.WriteHeader(http.StatusOK) + w.Write(test.GrpcTrailerFrame(codes.Unavailable, "not ready")) //nolint:errcheck + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(80 * time.Millisecond) // cancel mid-way through the 300ms backoff + cancel() + }() + + client := handler.NewClient(srv.URL, slowRetry) + var reply fd.ProcessVisitorFacesResponse + + start := time.Now() + err := client.Invoke(ctx, "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + require.Error(t, err) + assert.Less(t, time.Since(start), 250*time.Millisecond, "context cancel must short-circuit the backoff") + }) +} diff --git a/backend/client/internal/handler/grpc_web_client_unit_test.go b/backend/client/internal/handler/grpc_web_client_unit_test.go new file mode 100644 index 0000000..a026c75 --- /dev/null +++ b/backend/client/internal/handler/grpc_web_client_unit_test.go @@ -0,0 +1,205 @@ +//go:build unit + +package handler_test + +import ( + "bytes" + "context" + "encoding/binary" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + fd "mosaic-client.com/gen/face_detection" + "mosaic-client.com/internal/handler" + "mosaic-client.com/internal/test" +) + +func TestInvoke(t *testing.T) { + t.Run("response error paths", func(t *testing.T) { + truncatedBodyFrame := func() []byte { + var buf bytes.Buffer + buf.WriteByte(0x00) + length := make([]byte, 4) + binary.BigEndian.PutUint32(length, 100) + buf.Write(length) + buf.Write([]byte{0x01, 0x02, 0x03}) // only 3 bytes of the declared 100 + return buf.Bytes() + } + + tests := []struct { + name string + body []byte + reply interface{} // nil → uses *fd.ProcessVisitorFacesResponse + wantCode codes.Code + wantMsg string + }{ + { + name: "truncated frame header", + body: []byte{0x00, 0x00, 0x01}, // 3 bytes, need 5 + wantCode: codes.Internal, + wantMsg: "truncated frame header", + }, + { + name: "truncated frame body", + body: truncatedBodyFrame(), + wantCode: codes.Internal, + wantMsg: "truncated frame body", + }, + { + name: "empty body - no data frame", + body: []byte{}, + wantCode: codes.Internal, + wantMsg: "no data frame in response", + }, + { + name: "OK trailer only - no data frame", + body: test.GrpcTrailerFrame(codes.OK, ""), + wantCode: codes.Internal, + wantMsg: "no data frame in response", + }, + { + name: "non-OK gRPC trailer", + body: test.GrpcTrailerFrame(codes.NotFound, "record not found"), + wantCode: codes.NotFound, + wantMsg: "record not found", + }, + { + name: "data frame + non-OK trailer - trailer wins", + body: append(test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{FaceDetected: true}), test.GrpcTrailerFrame(codes.Unavailable, "overloaded")...), + wantCode: codes.Unavailable, + wantMsg: "overloaded", + }, + { + name: "non-proto reply type", + body: test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{}), + reply: new(string), + wantCode: codes.Internal, + wantMsg: "reply must be proto.Message", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + srv := test.ServeGRPCWeb(tc.body) + defer srv.Close() + + client := handler.NewClient(srv.URL, handler.NoRetry) + reply := tc.reply + if reply == nil { + reply = &fd.ProcessVisitorFacesResponse{} + } + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, reply) + + test.RequireGRPCCode(t, err, tc.wantCode) + s, _ := status.FromError(err) + assert.Equal(t, tc.wantMsg, s.Message()) + }) + } + }) + + t.Run("data frame only - reply is populated", func(t *testing.T) { + srv := test.ServeGRPCWeb(test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{FaceDetected: true})) + defer srv.Close() + + client := handler.NewClient(srv.URL, handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + require.NoError(t, err) + assert.True(t, reply.FaceDetected) + }) + + t.Run("data frame + OK trailer - still succeeds", func(t *testing.T) { + body := append(test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{FaceDetected: true}), test.GrpcTrailerFrame(codes.OK, "")...) + srv := test.ServeGRPCWeb(body) + defer srv.Close() + + client := handler.NewClient(srv.URL, handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + require.NoError(t, err) + assert.True(t, reply.FaceDetected) + }) + + t.Run("trailing slash on baseURL is stripped", func(t *testing.T) { + srv := test.ServeGRPCWeb(test.GrpcDataFrame(t, &fd.ProcessVisitorFacesResponse{})) + defer srv.Close() + + client := handler.NewClient(srv.URL+"/", handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + require.NoError(t, err) + }) + + t.Run("non-proto args - returns Internal without hitting server", func(t *testing.T) { + client := handler.NewClient("http://localhost:9999", handler.NoRetry) + + err := client.Invoke(context.Background(), "/any/Method", "not-a-proto", &fd.ProcessVisitorFacesResponse{}) + + test.RequireGRPCCode(t, err, codes.Internal) + assert.ErrorContains(t, err, "args must be proto.Message") + }) + + t.Run("HTTP 500 - returns Internal", func(t *testing.T) { + srv := test.ServeGRPCWebStatus(http.StatusInternalServerError) + defer srv.Close() + + client := handler.NewClient(srv.URL, handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + test.RequireGRPCCode(t, err, codes.Internal) + assert.ErrorContains(t, err, "http status: 500") + }) + + t.Run("server down - returns Unavailable", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + srv.Close() + + client := handler.NewClient(srv.URL, handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(context.Background(), "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + + test.RequireGRPCCode(t, err, codes.Unavailable) + }) + + t.Run("cancelled context - returns error", func(t *testing.T) { + srv := test.ServeGRPCWeb(nil) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + client := handler.NewClient(srv.URL, handler.NoRetry) + var reply fd.ProcessVisitorFacesResponse + + err := client.Invoke(ctx, "/any/Method", &fd.ProcessVisitorFacesRequest{}, &reply) + require.Error(t, err) + }) +} + +func TestNewStreamReturn(t *testing.T) { + client := handler.NewClient("http://localhost:9999", handler.NoRetry) + + stream, err := client.NewStream(context.Background(), nil, "/any/Method") + + assert.Nil(t, stream) + test.RequireGRPCCode(t, err, codes.Unimplemented) +} + +func TestCloseReturnsNil(t *testing.T) { + client := handler.NewClient("http://localhost:9999", handler.NoRetry) + assert.NoError(t, client.Close()) +} diff --git a/backend/client/internal/middleware/logging.go b/backend/client/internal/observability/logging.go similarity index 64% rename from backend/client/internal/middleware/logging.go rename to backend/client/internal/observability/logging.go index fb4b912..5810dab 100644 --- a/backend/client/internal/middleware/logging.go +++ b/backend/client/internal/observability/logging.go @@ -1,10 +1,12 @@ -package middleware +package observability import ( "bufio" "log" + "log/slog" "net" "net/http" + "os" "time" ) @@ -29,8 +31,8 @@ func (w *WrappedWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hijacker.Hijack() } -// logging middleware to track status codes, the url path, and response latency -func Logging(next http.Handler) http.Handler { +// http logger to track status codes, the url path, and response latency +func HTTPLogger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -43,3 +45,14 @@ func Logging(next http.Handler) http.Handler { log.Println(wrapped.StatusCode, r.Method, r.URL.Path, time.Since(start)) }) } + +// returns the structured logger with appropriate log level based on prodMode +func StructuredLogger(prodMode bool) *slog.Logger { + level := slog.LevelDebug + if prodMode { + level = slog.LevelInfo + } + h := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}) + + return slog.New(h).With("service", "client") +} diff --git a/backend/client/internal/middleware/logging_unit_test.go b/backend/client/internal/observability/logging_unit_test.go similarity index 77% rename from backend/client/internal/middleware/logging_unit_test.go rename to backend/client/internal/observability/logging_unit_test.go index 58c6b5c..63287f1 100644 --- a/backend/client/internal/middleware/logging_unit_test.go +++ b/backend/client/internal/observability/logging_unit_test.go @@ -1,24 +1,26 @@ //go:build unit -package middleware_test +package observability_test import ( "bytes" + "context" "log" + "log/slog" "net/http" "net/http/httptest" "strings" "testing" "github.com/stretchr/testify/assert" - "mosaic-client.com/internal/middleware" + "mosaic-client.com/internal/observability" ) // unit tests for writeHeader function func TestWriteHeader(t *testing.T) { t.Run("Captures status code properly", func(t *testing.T) { recorder := httptest.NewRecorder() - wrapped := &middleware.WrappedWriter{ + wrapped := &observability.WrappedWriter{ ResponseWriter: recorder, StatusCode: http.StatusOK, } @@ -30,7 +32,7 @@ func TestWriteHeader(t *testing.T) { t.Run("Forwards to responsewriter", func(t *testing.T) { recorder := httptest.NewRecorder() - wrapped := &middleware.WrappedWriter{ + wrapped := &observability.WrappedWriter{ ResponseWriter: recorder, StatusCode: http.StatusOK, } @@ -42,7 +44,7 @@ func TestWriteHeader(t *testing.T) { t.Run("Starts at 200 status code", func(t *testing.T) { recorder := httptest.NewRecorder() - wrapped := &middleware.WrappedWriter{ + wrapped := &observability.WrappedWriter{ ResponseWriter: recorder, StatusCode: http.StatusOK, } @@ -60,7 +62,7 @@ func TestLogging(t *testing.T) { w.WriteHeader(http.StatusOK) }) - logging := middleware.Logging(mockHandler) + logging := observability.HTTPLogger(mockHandler) req := httptest.NewRequest(http.MethodGet, "/test", nil) recorder := httptest.NewRecorder() @@ -79,7 +81,7 @@ func TestLogging(t *testing.T) { w.WriteHeader(http.StatusOK) }) - logging := middleware.Logging(mockHandler) + logging := observability.HTTPLogger(mockHandler) req := httptest.NewRequest(http.MethodPost, "/api/products", nil) recorder := httptest.NewRecorder() @@ -99,7 +101,7 @@ func TestLogging(t *testing.T) { w.WriteHeader(http.StatusNotFound) }) - logging := middleware.Logging(mockHandler) + logging := observability.HTTPLogger(mockHandler) req := httptest.NewRequest(http.MethodGet, "/test-path", nil) recorder := httptest.NewRecorder() @@ -113,5 +115,20 @@ func TestLogging(t *testing.T) { assert.Contains(t, logOutput, "/test-path", "Log should contain request path") assert.True(t, strings.Contains(logOutput, "ns") || strings.Contains(logOutput, "µs") || strings.Contains(logOutput, "ms") || strings.Contains(logOutput, "s"), "Log should contain timing information") }) +} + +func TestStructuredLogger(t *testing.T) { + t.Run("prod mode set to false should enable debug level", func(t *testing.T) { + logger := observability.StructuredLogger(false) + + assert.True(t, logger.Enabled(context.Background(), slog.LevelDebug)) + }) + + t.Run("prod mode set to true should disable debug level", func(t *testing.T) { + logger := observability.StructuredLogger(true) + + assert.False(t, logger.Enabled(context.Background(), slog.LevelDebug)) + assert.True(t, logger.Enabled(context.Background(), slog.LevelInfo)) + }) } diff --git a/backend/client/internal/test/grpc_web_helpers.go b/backend/client/internal/test/grpc_web_helpers.go new file mode 100644 index 0000000..da40873 --- /dev/null +++ b/backend/client/internal/test/grpc_web_helpers.go @@ -0,0 +1,66 @@ +package test + +import ( + "encoding/binary" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +// grpcDataFrame encodes a proto message as a gRPC-web data frame (type 0x00). +func GrpcDataFrame(t *testing.T, msg proto.Message) []byte { + t.Helper() + b, err := proto.Marshal(msg) + require.NoError(t, err) + frame := make([]byte, 5+len(b)) + frame[0] = 0x00 + binary.BigEndian.PutUint32(frame[1:5], uint32(len(b))) + copy(frame[5:], b) + return frame +} + +// grpcTrailerFrame builds a gRPC-web trailer frame (type 0x80). +func GrpcTrailerFrame(code codes.Code, msg string) []byte { + trailer := fmt.Sprintf("grpc-status: %d", int(code)) + if msg != "" { + trailer += fmt.Sprintf("\r\ngrpc-message: %s", msg) + } + b := []byte(trailer) + frame := make([]byte, 5+len(b)) + frame[0] = 0x80 + binary.BigEndian.PutUint32(frame[1:5], uint32(len(b))) + copy(frame[5:], b) + return frame +} + +// serveGRPCWeb spins up a test server that always returns the given raw body with HTTP 200. +func ServeGRPCWeb(body []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/grpc-web+proto") + w.WriteHeader(http.StatusOK) + w.Write(body) //nolint:errcheck + })) +} + +// serveGRPCWebStatus spins up a test server that returns the given HTTP status with no body. +func ServeGRPCWebStatus(httpStatus int) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(httpStatus) + })) +} + +// requireGRPCCode asserts that err is a gRPC status error with the expected code. +func RequireGRPCCode(t *testing.T, err error, expected codes.Code) { + t.Helper() + require.Error(t, err) + s, ok := status.FromError(err) + require.True(t, ok, "error must be a gRPC status error") + assert.Equal(t, expected, s.Code()) +}