diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dc2b3498e5..c4cd032fd80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ * [ENHANCEMENT] Querier: Support query limits in parquet queryable. #6870 * [ENHANCEMENT] Ring: Add zone label to ring_members metric. #6900 * [ENHANCEMENT] Ingester: Add new metric `cortex_ingester_push_errors_total` to track reasons for ingester request failures. #6901 +* [ENHANCEMENT] API: add request ID injection to context to enable tracking requests across downstream services. #6895 * [BUGFIX] Ingester: Avoid error or early throttling when READONLY ingesters are present in the ring #6517 * [BUGFIX] Ingester: Fix labelset data race condition. #6573 * [BUGFIX] Compactor: Cleaner should not put deletion marker for blocks with no-compact marker. #6576 diff --git a/docs/configuration/config-file-reference.md b/docs/configuration/config-file-reference.md index 2963a87348c..c0e8f66a453 100644 --- a/docs/configuration/config-file-reference.md +++ b/docs/configuration/config-file-reference.md @@ -102,6 +102,10 @@ api: # CLI flag: -api.http-request-headers-to-log [http_request_headers_to_log: | default = []] + # HTTP header that can be used as request id + # CLI flag: -api.request-id-header + [request_id_header: | default = ""] + # Regex for CORS origin. It is fully anchored. Example: # 'https?://(domain1|domain2)\.com' # CLI flag: -server.cors-origin diff --git a/go.mod b/go.mod index ea2dbcc0670..642d2f65d05 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3 github.com/cespare/xxhash/v2 v2.3.0 github.com/google/go-cmp v0.7.0 + github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 github.com/oklog/ulid/v2 v2.1.1 @@ -170,7 +171,6 @@ require ( github.com/google/btree v1.1.3 // indirect github.com/google/pprof v0.0.0-20250607225305-033d6d78b36a // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect diff --git a/pkg/api/api.go b/pkg/api/api.go index ec02f72e760..1c68c426d8b 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -71,6 +71,10 @@ type Config struct { // Allows and is used to configure the addition of HTTP Header fields to logs HTTPRequestHeadersToLog flagext.StringSlice `yaml:"http_request_headers_to_log"` + // HTTP header that can be used as request id. It will always be included in logs + // If it's not provided, or this header is empty, then random requestId will be generated + RequestIdHeader string `yaml:"request_id_header"` + // This sets the Origin header value corsRegexString string `yaml:"cors_origin"` @@ -87,6 +91,7 @@ var ( func (cfg *Config) RegisterFlags(f *flag.FlagSet) { f.BoolVar(&cfg.ResponseCompression, "api.response-compression-enabled", false, "Use GZIP compression for API responses. Some endpoints serve large YAML or JSON blobs which can benefit from compression.") f.Var(&cfg.HTTPRequestHeadersToLog, "api.http-request-headers-to-log", "Which HTTP Request headers to add to logs") + f.StringVar(&cfg.RequestIdHeader, "api.request-id-header", "", "HTTP header that can be used as request id") f.BoolVar(&cfg.buildInfoEnabled, "api.build-info-enabled", false, "If enabled, build Info API will be served by query frontend or querier.") f.StringVar(&cfg.QuerierDefaultCodec, "api.querier-default-codec", "json", "Choose default codec for querier response serialization. Supports 'json' and 'protobuf'.") cfg.RegisterFlagsWithPrefix("", f) @@ -169,8 +174,9 @@ func New(cfg Config, serverCfg server.Config, s *server.Server, logger log.Logge if cfg.HTTPAuthMiddleware == nil { api.AuthMiddleware = middleware.AuthenticateUser } - if len(cfg.HTTPRequestHeadersToLog) > 0 { - api.HTTPHeaderMiddleware = &HTTPHeaderMiddleware{TargetHeaders: cfg.HTTPRequestHeadersToLog} + api.HTTPHeaderMiddleware = &HTTPHeaderMiddleware{ + TargetHeaders: cfg.HTTPRequestHeadersToLog, + RequestIdHeader: cfg.RequestIdHeader, } return api, nil diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index c25ca27234b..df2ec239f03 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -89,6 +89,7 @@ func TestNewApiWithHeaderLogging(t *testing.T) { } +// HTTPHeaderMiddleware should be added even if no headers are specified to log because it also handles request ID injection. func TestNewApiWithoutHeaderLogging(t *testing.T) { cfg := Config{ HTTPRequestHeadersToLog: []string{}, @@ -102,7 +103,8 @@ func TestNewApiWithoutHeaderLogging(t *testing.T) { api, err := New(cfg, serverCfg, server, &FakeLogger{}) require.NoError(t, err) - require.Nil(t, api.HTTPHeaderMiddleware) + require.NotNil(t, api.HTTPHeaderMiddleware) + require.Empty(t, api.HTTPHeaderMiddleware.TargetHeaders) } diff --git a/pkg/api/middlewares.go b/pkg/api/middlewares.go index 8ddefaa2c66..dcb9c298169 100644 --- a/pkg/api/middlewares.go +++ b/pkg/api/middlewares.go @@ -1,40 +1,51 @@ package api import ( - "context" "net/http" - util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/google/uuid" + + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) // HTTPHeaderMiddleware adds specified HTTPHeaders to the request context type HTTPHeaderMiddleware struct { - TargetHeaders []string + TargetHeaders []string + RequestIdHeader string } -// InjectTargetHeadersIntoHTTPRequest injects specified HTTPHeaders into the request context -func (h HTTPHeaderMiddleware) InjectTargetHeadersIntoHTTPRequest(r *http.Request) context.Context { - headerMap := make(map[string]string) +// injectRequestContext injects request related metadata into the request context +func (h HTTPHeaderMiddleware) injectRequestContext(r *http.Request) *http.Request { + requestContextMap := make(map[string]string) - // Check to make sure that Headers have not already been injected - checkMapInContext := util_log.HeaderMapFromContext(r.Context()) + // Check to make sure that request context have not already been injected + checkMapInContext := requestmeta.MapFromContext(r.Context()) if checkMapInContext != nil { - return r.Context() + return r } for _, target := range h.TargetHeaders { contents := r.Header.Get(target) if contents != "" { - headerMap[target] = contents + requestContextMap[target] = contents } } - return util_log.ContextWithHeaderMap(r.Context(), headerMap) + requestContextMap[requestmeta.LoggingHeadersKey] = requestmeta.LoggingHeaderKeysToString(h.TargetHeaders) + + reqId := r.Header.Get(h.RequestIdHeader) + if reqId == "" { + reqId = uuid.NewString() + } + requestContextMap[requestmeta.RequestIdKey] = reqId + + ctx := requestmeta.ContextWithRequestMetadataMap(r.Context(), requestContextMap) + return r.WithContext(ctx) } // Wrap implements Middleware func (h HTTPHeaderMiddleware) Wrap(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := h.InjectTargetHeadersIntoHTTPRequest(r) - next.ServeHTTP(w, r.WithContext(ctx)) + r = h.injectRequestContext(r) + next.ServeHTTP(w, r) }) } diff --git a/pkg/api/middlewares_test.go b/pkg/api/middlewares_test.go index dbf8719ad48..691d3b23584 100644 --- a/pkg/api/middlewares_test.go +++ b/pkg/api/middlewares_test.go @@ -7,12 +7,11 @@ import ( "github.com/stretchr/testify/require" - util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) -var HTTPTestMiddleware = HTTPHeaderMiddleware{TargetHeaders: []string{"TestHeader1", "TestHeader2", "Test3"}} - func TestHeaderInjection(t *testing.T) { + middleware := HTTPHeaderMiddleware{TargetHeaders: []string{"TestHeader1", "TestHeader2", "Test3"}} ctx := context.Background() h := http.Header{} contentsMap := make(map[string]string) @@ -32,12 +31,12 @@ func TestHeaderInjection(t *testing.T) { } req = req.WithContext(ctx) - ctx = HTTPTestMiddleware.InjectTargetHeadersIntoHTTPRequest(req) + req = middleware.injectRequestContext(req) - headerMap := util_log.HeaderMapFromContext(ctx) + headerMap := requestmeta.MapFromContext(req.Context()) require.NotNil(t, headerMap) - for _, header := range HTTPTestMiddleware.TargetHeaders { + for _, header := range middleware.TargetHeaders { require.Equal(t, contentsMap[header], headerMap[header]) } for header, contents := range contentsMap { @@ -46,6 +45,7 @@ func TestHeaderInjection(t *testing.T) { } func TestExistingHeaderInContextIsNotOverridden(t *testing.T) { + middleware := HTTPHeaderMiddleware{TargetHeaders: []string{"TestHeader1", "TestHeader2", "Test3"}} ctx := context.Background() h := http.Header{} @@ -58,7 +58,7 @@ func TestExistingHeaderInContextIsNotOverridden(t *testing.T) { h.Add("TestHeader2", "Fail2") h.Add("Test3", "Fail3") - ctx = util_log.ContextWithHeaderMap(ctx, contentsMap) + ctx = requestmeta.ContextWithRequestMetadataMap(ctx, contentsMap) req := &http.Request{ Method: "GET", RequestURI: "/HTTPHeaderTest", @@ -67,8 +67,77 @@ func TestExistingHeaderInContextIsNotOverridden(t *testing.T) { } req = req.WithContext(ctx) - ctx = HTTPTestMiddleware.InjectTargetHeadersIntoHTTPRequest(req) + req = middleware.injectRequestContext(req) + + require.Equal(t, contentsMap, requestmeta.MapFromContext(req.Context())) + +} + +func TestRequestIdInjection(t *testing.T) { + middleware := HTTPHeaderMiddleware{ + RequestIdHeader: "X-Request-ID", + } + + req := &http.Request{ + Method: "GET", + RequestURI: "/test", + Body: http.NoBody, + Header: http.Header{}, + } + req = req.WithContext(context.Background()) + req = middleware.injectRequestContext(req) + + requestID := requestmeta.RequestIdFromContext(req.Context()) + require.NotEmpty(t, requestID, "Request ID should be generated if not provided") +} + +func TestRequestIdFromHeaderIsUsed(t *testing.T) { + const providedID = "my-test-id-123" + + middleware := HTTPHeaderMiddleware{ + RequestIdHeader: "X-Request-ID", + } + + h := http.Header{} + h.Add("X-Request-ID", providedID) - require.Equal(t, contentsMap, util_log.HeaderMapFromContext(ctx)) + req := &http.Request{ + Method: "GET", + RequestURI: "/test", + Body: http.NoBody, + Header: h, + } + req = req.WithContext(context.Background()) + req = middleware.injectRequestContext(req) + + requestID := requestmeta.RequestIdFromContext(req.Context()) + require.Equal(t, providedID, requestID, "Request ID from header should be used") +} + +func TestTargetHeaderAndRequestIdHeaderOverlap(t *testing.T) { + const headerKey = "X-Request-ID" + const providedID = "overlap-id-456" + + middleware := HTTPHeaderMiddleware{ + TargetHeaders: []string{headerKey, "Other-Header"}, + RequestIdHeader: headerKey, + } + + h := http.Header{} + h.Add(headerKey, providedID) + h.Add("Other-Header", "some-value") + + req := &http.Request{ + Method: "GET", + RequestURI: "/test", + Body: http.NoBody, + Header: h, + } + req = req.WithContext(context.Background()) + req = middleware.injectRequestContext(req) + ctxMap := requestmeta.MapFromContext(req.Context()) + requestID := requestmeta.RequestIdFromContext(req.Context()) + require.Equal(t, providedID, ctxMap[headerKey], "Header value should be correctly stored") + require.Equal(t, providedID, requestID, "Request ID should come from the overlapping header") } diff --git a/pkg/cortex/cortex.go b/pkg/cortex/cortex.go index b141adc8127..d301a27d80a 100644 --- a/pkg/cortex/cortex.go +++ b/pkg/cortex/cortex.go @@ -392,10 +392,8 @@ func (t *Cortex) setupThanosTracing() { // setupGRPCHeaderForwarding appends a gRPC middleware used to enable the propagation of // HTTP Headers through child gRPC calls func (t *Cortex) setupGRPCHeaderForwarding() { - if len(t.Cfg.API.HTTPRequestHeadersToLog) > 0 { - t.Cfg.Server.GRPCMiddleware = append(t.Cfg.Server.GRPCMiddleware, grpcutil.HTTPHeaderPropagationServerInterceptor) - t.Cfg.Server.GRPCStreamMiddleware = append(t.Cfg.Server.GRPCStreamMiddleware, grpcutil.HTTPHeaderPropagationStreamServerInterceptor) - } + t.Cfg.Server.GRPCMiddleware = append(t.Cfg.Server.GRPCMiddleware, grpcutil.HTTPHeaderPropagationServerInterceptor) + t.Cfg.Server.GRPCStreamMiddleware = append(t.Cfg.Server.GRPCStreamMiddleware, grpcutil.HTTPHeaderPropagationStreamServerInterceptor) } func (t *Cortex) setupRequestSigning() { diff --git a/pkg/cortex/modules.go b/pkg/cortex/modules.go index a13f35e6a9d..64654808965 100644 --- a/pkg/cortex/modules.go +++ b/pkg/cortex/modules.go @@ -402,9 +402,7 @@ func (t *Cortex) initQuerier() (serv services.Service, err error) { // request context. internalQuerierRouter = t.API.AuthMiddleware.Wrap(internalQuerierRouter) - if len(t.Cfg.API.HTTPRequestHeadersToLog) > 0 { - internalQuerierRouter = t.API.HTTPHeaderMiddleware.Wrap(internalQuerierRouter) - } + internalQuerierRouter = t.API.HTTPHeaderMiddleware.Wrap(internalQuerierRouter) } // If neither frontend address or scheduler address is configured, no worker is needed. diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go index 0fc11c19d19..9c10a675306 100644 --- a/pkg/distributor/distributor.go +++ b/pkg/distributor/distributor.go @@ -40,6 +40,7 @@ import ( "github.com/cortexproject/cortex/pkg/util/limiter" util_log "github.com/cortexproject/cortex/pkg/util/log" util_math "github.com/cortexproject/cortex/pkg/util/math" + "github.com/cortexproject/cortex/pkg/util/requestmeta" "github.com/cortexproject/cortex/pkg/util/services" "github.com/cortexproject/cortex/pkg/util/validation" ) @@ -892,9 +893,9 @@ func (d *Distributor) doBatch(ctx context.Context, req *cortexpb.WriteRequest, s if sp := opentracing.SpanFromContext(ctx); sp != nil { localCtx = opentracing.ContextWithSpan(localCtx, sp) } - // Get any HTTP headers that are supposed to be added to logs and add to localCtx for later use - if headerMap := util_log.HeaderMapFromContext(ctx); headerMap != nil { - localCtx = util_log.ContextWithHeaderMap(localCtx, headerMap) + // Get any HTTP request metadata that are supposed to be added to logs and add to localCtx for later use + if requestContextMap := requestmeta.MapFromContext(ctx); requestContextMap != nil { + localCtx = requestmeta.ContextWithRequestMetadataMap(localCtx, requestContextMap) } // Get clientIP(s) from Context and add it to localCtx source := util.GetSourceIPsFromOutgoingCtx(ctx) diff --git a/pkg/querier/tripperware/roundtrip.go b/pkg/querier/tripperware/roundtrip.go index b9be569d6d9..1afa1ff2ee0 100644 --- a/pkg/querier/tripperware/roundtrip.go +++ b/pkg/querier/tripperware/roundtrip.go @@ -35,7 +35,7 @@ import ( "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" "github.com/cortexproject/cortex/pkg/util/limiter" - util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) const ( @@ -261,8 +261,8 @@ func (q roundTripper) Do(ctx context.Context, r Request) (Response, error) { return nil, err } - if headerMap := util_log.HeaderMapFromContext(ctx); headerMap != nil { - util_log.InjectHeadersIntoHTTPRequest(headerMap, request) + if requestMetadataMap := requestmeta.MapFromContext(ctx); requestMetadataMap != nil { + requestmeta.InjectMetadataIntoHTTPRequestHeaders(requestMetadataMap, request) } if err := user.InjectOrgIDIntoHTTPRequest(ctx, request); err != nil { diff --git a/pkg/querier/worker/frontend_processor.go b/pkg/querier/worker/frontend_processor.go index 17bd031acfb..88f7f311393 100644 --- a/pkg/querier/worker/frontend_processor.go +++ b/pkg/querier/worker/frontend_processor.go @@ -17,6 +17,7 @@ import ( querier_stats "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/util/backoff" util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) var ( @@ -129,18 +130,12 @@ func (fp *frontendProcessor) runRequest(ctx context.Context, request *httpgrpc.H for _, h := range request.Headers { headers[h.Key] = h.Values[0] } - headerMap := make(map[string]string, 0) - // Remove non-existent header. - for _, header := range fp.targetHeaders { - if v, ok := headers[textproto.CanonicalMIMEHeaderKey(header)]; ok { - headerMap[header] = v - } - } + ctx = requestmeta.ContextWithRequestMetadataMapFromHeaders(ctx, headers, fp.targetHeaders) + orgID, ok := headers[textproto.CanonicalMIMEHeaderKey(user.OrgIDHeaderName)] if ok { ctx = user.InjectOrgID(ctx, orgID) } - ctx = util_log.ContextWithHeaderMap(ctx, headerMap) logger := util_log.WithContext(ctx, fp.log) if statsEnabled { level.Info(logger).Log("msg", "started running request") diff --git a/pkg/querier/worker/scheduler_processor.go b/pkg/querier/worker/scheduler_processor.go index 0d149210284..10fd96ab230 100644 --- a/pkg/querier/worker/scheduler_processor.go +++ b/pkg/querier/worker/scheduler_processor.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "net/textproto" "time" "github.com/go-kit/log" @@ -28,6 +27,7 @@ import ( "github.com/cortexproject/cortex/pkg/util/httpgrpcutil" util_log "github.com/cortexproject/cortex/pkg/util/log" cortexmiddleware "github.com/cortexproject/cortex/pkg/util/middleware" + "github.com/cortexproject/cortex/pkg/util/requestmeta" "github.com/cortexproject/cortex/pkg/util/services" ) @@ -141,14 +141,7 @@ func (sp *schedulerProcessor) querierLoop(c schedulerpb.SchedulerForQuerier_Quer for _, h := range request.HttpRequest.Headers { headers[h.Key] = h.Values[0] } - headerMap := make(map[string]string, 0) - // Remove non-existent header. - for _, header := range sp.targetHeaders { - if v, ok := headers[textproto.CanonicalMIMEHeaderKey(header)]; ok { - headerMap[header] = v - } - } - ctx = util_log.ContextWithHeaderMap(ctx, headerMap) + ctx = requestmeta.ContextWithRequestMetadataMapFromHeaders(ctx, headers, sp.targetHeaders) tracer := opentracing.GlobalTracer() // Ignore errors here. If we cannot get parent span, we just don't create new one. diff --git a/pkg/ruler/compat.go b/pkg/ruler/compat.go index 862bcc54706..173f8941edf 100644 --- a/pkg/ruler/compat.go +++ b/pkg/ruler/compat.go @@ -4,10 +4,12 @@ import ( "context" "errors" "fmt" + "time" "github.com/go-kit/log" "github.com/go-kit/log/level" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/prometheus/model/exemplar" "github.com/prometheus/prometheus/model/histogram" @@ -27,6 +29,7 @@ import ( "github.com/cortexproject/cortex/pkg/ring/client" util_log "github.com/cortexproject/cortex/pkg/util/log" promql_util "github.com/cortexproject/cortex/pkg/util/promql" + "github.com/cortexproject/cortex/pkg/util/requestmeta" "github.com/cortexproject/cortex/pkg/util/validation" ) @@ -183,6 +186,9 @@ func EngineQueryFunc(engine promql.QueryEngine, frontendClient *frontendClient, } } + // Add request ID to the context so that it can be used in logs and metrics for split queries. + ctx = requestmeta.ContextWithRequestId(ctx, uuid.NewString()) + if frontendClient != nil { v, err := frontendClient.InstantQuery(ctx, qs, t) if err != nil { diff --git a/pkg/util/grpcutil/grpc_interceptors_test.go b/pkg/util/grpcutil/grpc_interceptors_test.go index 6a0011c9a90..81788d22d7d 100644 --- a/pkg/util/grpcutil/grpc_interceptors_test.go +++ b/pkg/util/grpcutil/grpc_interceptors_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" - util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) func TestHTTPHeaderPropagationClientInterceptor(t *testing.T) { @@ -18,14 +18,14 @@ func TestHTTPHeaderPropagationClientInterceptor(t *testing.T) { contentsMap["TestHeader1"] = "RequestID" contentsMap["TestHeader2"] = "ContentsOfTestHeader2" contentsMap["Test3"] = "SomeInformation" - ctx = util_log.ContextWithHeaderMap(ctx, contentsMap) + ctx = requestmeta.ContextWithRequestMetadataMap(ctx, contentsMap) - ctx = injectForwardedHeadersIntoMetadata(ctx) + ctx = injectForwardedRequestMetadata(ctx) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) - headers := md[util_log.HeaderPropagationStringForRequestLogging] + headers := md[requestmeta.PropagationStringForRequestMetadata] assert.Equal(t, 6, len(headers)) assert.Contains(t, headers, "TestHeader1") assert.Contains(t, headers, "TestHeader2") @@ -37,20 +37,20 @@ func TestHTTPHeaderPropagationClientInterceptor(t *testing.T) { func TestExistingValuesInMetadataForHTTPPropagationClientInterceptor(t *testing.T) { ctx := context.Background() - ctx = metadata.AppendToOutgoingContext(ctx, util_log.HeaderPropagationStringForRequestLogging, "testabc123") + ctx = metadata.AppendToOutgoingContext(ctx, requestmeta.PropagationStringForRequestMetadata, "testabc123") contentsMap := make(map[string]string) contentsMap["TestHeader1"] = "RequestID" contentsMap["TestHeader2"] = "ContentsOfTestHeader2" contentsMap["Test3"] = "SomeInformation" - ctx = util_log.ContextWithHeaderMap(ctx, contentsMap) + ctx = requestmeta.ContextWithRequestMetadataMap(ctx, contentsMap) - ctx = injectForwardedHeadersIntoMetadata(ctx) + ctx = injectForwardedRequestMetadata(ctx) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) - contents := md[util_log.HeaderPropagationStringForRequestLogging] + contents := md[requestmeta.PropagationStringForRequestMetadata] assert.Contains(t, contents, "testabc123") assert.Equal(t, 1, len(contents)) } @@ -63,14 +63,14 @@ func TestGRPCHeaderInjectionForHTTPPropagationServerInterceptor(t *testing.T) { testMap["TestHeader2"] = "Results2" ctx = metadata.NewOutgoingContext(ctx, nil) - ctx = util_log.ContextWithHeaderMap(ctx, testMap) - ctx = injectForwardedHeadersIntoMetadata(ctx) + ctx = requestmeta.ContextWithRequestMetadataMap(ctx, testMap) + ctx = injectForwardedRequestMetadata(ctx) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) - ctx = util_log.ContextWithHeaderMapFromMetadata(ctx, md) + ctx = requestmeta.ContextWithRequestMetadataMapFromMetadata(ctx, md) - headersMap := util_log.HeaderMapFromContext(ctx) + headersMap := requestmeta.MapFromContext(ctx) require.NotNil(t, headersMap) assert.Equal(t, 2, len(headersMap)) @@ -82,11 +82,11 @@ func TestGRPCHeaderInjectionForHTTPPropagationServerInterceptor(t *testing.T) { func TestGRPCHeaderDifferentLengthsForHTTPPropagationServerInterceptor(t *testing.T) { ctx := context.Background() - ctx = metadata.AppendToOutgoingContext(ctx, util_log.HeaderPropagationStringForRequestLogging, "Test123") - ctx = metadata.AppendToOutgoingContext(ctx, util_log.HeaderPropagationStringForRequestLogging, "Results") - ctx = metadata.AppendToOutgoingContext(ctx, util_log.HeaderPropagationStringForRequestLogging, "Results2") + ctx = metadata.AppendToOutgoingContext(ctx, requestmeta.PropagationStringForRequestMetadata, "Test123") + ctx = metadata.AppendToOutgoingContext(ctx, requestmeta.PropagationStringForRequestMetadata, "Results") + ctx = metadata.AppendToOutgoingContext(ctx, requestmeta.PropagationStringForRequestMetadata, "Results2") - ctx = extractForwardedHeadersFromMetadata(ctx) + ctx = extractForwardedRequestMetadataFromMetadata(ctx) - assert.Nil(t, util_log.HeaderMapFromContext(ctx)) + assert.Nil(t, requestmeta.MapFromContext(ctx)) } diff --git a/pkg/util/grpcutil/util.go b/pkg/util/grpcutil/util.go index 8da1c6916e7..41ab05a350b 100644 --- a/pkg/util/grpcutil/util.go +++ b/pkg/util/grpcutil/util.go @@ -8,7 +8,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) type wrappedServerStream struct { @@ -34,49 +34,50 @@ func IsGRPCContextCanceled(err error) bool { // HTTPHeaderPropagationServerInterceptor allows for propagation of HTTP Request headers across gRPC calls - works // alongside HTTPHeaderPropagationClientInterceptor func HTTPHeaderPropagationServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - ctx = extractForwardedHeadersFromMetadata(ctx) + ctx = extractForwardedRequestMetadataFromMetadata(ctx) h, err := handler(ctx, req) return h, err } // HTTPHeaderPropagationStreamServerInterceptor does the same as HTTPHeaderPropagationServerInterceptor but for streams func HTTPHeaderPropagationStreamServerInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := extractForwardedRequestMetadataFromMetadata(ss.Context()) return handler(srv, wrappedServerStream{ - ctx: extractForwardedHeadersFromMetadata(ss.Context()), + ctx: ctx, ServerStream: ss, }) } -// extractForwardedHeadersFromMetadata implements HTTPHeaderPropagationServerInterceptor by placing forwarded +// extractForwardedRequestMetadataFromMetadata implements HTTPHeaderPropagationServerInterceptor by placing forwarded // headers into incoming context -func extractForwardedHeadersFromMetadata(ctx context.Context) context.Context { +func extractForwardedRequestMetadataFromMetadata(ctx context.Context) context.Context { md, ok := metadata.FromIncomingContext(ctx) if !ok { return ctx } - return util_log.ContextWithHeaderMapFromMetadata(ctx, md) + return requestmeta.ContextWithRequestMetadataMapFromMetadata(ctx, md) } // HTTPHeaderPropagationClientInterceptor allows for propagation of HTTP Request headers across gRPC calls - works // alongside HTTPHeaderPropagationServerInterceptor func HTTPHeaderPropagationClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = injectForwardedHeadersIntoMetadata(ctx) + ctx = injectForwardedRequestMetadata(ctx) return invoker(ctx, method, req, reply, cc, opts...) } // HTTPHeaderPropagationStreamClientInterceptor does the same as HTTPHeaderPropagationClientInterceptor but for streams func HTTPHeaderPropagationStreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - ctx = injectForwardedHeadersIntoMetadata(ctx) + ctx = injectForwardedRequestMetadata(ctx) return streamer(ctx, desc, cc, method, opts...) } -// injectForwardedHeadersIntoMetadata implements HTTPHeaderPropagationClientInterceptor and HTTPHeaderPropagationStreamClientInterceptor +// injectForwardedRequestMetadata implements HTTPHeaderPropagationClientInterceptor and HTTPHeaderPropagationStreamClientInterceptor // by inserting headers that are supposed to be forwarded into metadata of the request -func injectForwardedHeadersIntoMetadata(ctx context.Context) context.Context { - headerMap := util_log.HeaderMapFromContext(ctx) - if headerMap == nil { +func injectForwardedRequestMetadata(ctx context.Context) context.Context { + requestMetadataMap := requestmeta.MapFromContext(ctx) + if requestMetadataMap == nil { return ctx } md, ok := metadata.FromOutgoingContext(ctx) @@ -85,13 +86,13 @@ func injectForwardedHeadersIntoMetadata(ctx context.Context) context.Context { } newCtx := ctx - if _, ok := md[util_log.HeaderPropagationStringForRequestLogging]; !ok { + if _, ok := md[requestmeta.PropagationStringForRequestMetadata]; !ok { var mdContent []string - for header, content := range headerMap { - mdContent = append(mdContent, header, content) + for requestMetadata, content := range requestMetadataMap { + mdContent = append(mdContent, requestMetadata, content) } md = md.Copy() - md[util_log.HeaderPropagationStringForRequestLogging] = mdContent + md[requestmeta.PropagationStringForRequestMetadata] = mdContent newCtx = metadata.NewOutgoingContext(ctx, md) } return newCtx diff --git a/pkg/util/log/log.go b/pkg/util/log/log.go index 1db95b0b074..79b93b3c576 100644 --- a/pkg/util/log/log.go +++ b/pkg/util/log/log.go @@ -1,9 +1,7 @@ package log import ( - "context" "fmt" - "net/http" "os" "github.com/go-kit/log" @@ -12,15 +10,6 @@ import ( "github.com/prometheus/common/promslog" "github.com/weaveworks/common/logging" "github.com/weaveworks/common/server" - "google.golang.org/grpc/metadata" -) - -type contextKey int - -const ( - headerMapContextKey contextKey = 0 - - HeaderPropagationStringForRequestLogging string = "x-http-header-forwarding-logging" ) var ( @@ -126,36 +115,3 @@ func CheckFatal(location string, err error) { os.Exit(1) } } - -func HeaderMapFromContext(ctx context.Context) map[string]string { - headerMap, ok := ctx.Value(headerMapContextKey).(map[string]string) - if !ok { - return nil - } - return headerMap -} - -func ContextWithHeaderMap(ctx context.Context, headerMap map[string]string) context.Context { - return context.WithValue(ctx, headerMapContextKey, headerMap) -} - -// InjectHeadersIntoHTTPRequest injects the logging header map from the context into the request headers. -func InjectHeadersIntoHTTPRequest(headerMap map[string]string, request *http.Request) { - for header, contents := range headerMap { - request.Header.Add(header, contents) - } -} - -func ContextWithHeaderMapFromMetadata(ctx context.Context, md metadata.MD) context.Context { - headersSlice, ok := md[HeaderPropagationStringForRequestLogging] - if !ok || len(headersSlice)%2 == 1 { - return ctx - } - - headerMap := make(map[string]string) - for i := 0; i < len(headersSlice); i += 2 { - headerMap[headersSlice[i]] = headersSlice[i+1] - } - - return ContextWithHeaderMap(ctx, headerMap) -} diff --git a/pkg/util/log/log_test.go b/pkg/util/log/log_test.go index 0401d4ce086..cb4700afac8 100644 --- a/pkg/util/log/log_test.go +++ b/pkg/util/log/log_test.go @@ -1,73 +1,15 @@ package log import ( - "context" "io" - "net/http" "os" "testing" "github.com/go-kit/log/level" "github.com/stretchr/testify/require" "github.com/weaveworks/common/server" - "google.golang.org/grpc/metadata" ) -func TestHeaderMapFromMetadata(t *testing.T) { - md := metadata.New(nil) - md.Append(HeaderPropagationStringForRequestLogging, "TestHeader1", "SomeInformation", "TestHeader2", "ContentsOfTestHeader2") - - ctx := context.Background() - - ctx = ContextWithHeaderMapFromMetadata(ctx, md) - - headerMap := HeaderMapFromContext(ctx) - - require.Contains(t, headerMap, "TestHeader1") - require.Contains(t, headerMap, "TestHeader2") - require.Equal(t, "SomeInformation", headerMap["TestHeader1"]) - require.Equal(t, "ContentsOfTestHeader2", headerMap["TestHeader2"]) -} - -func TestHeaderMapFromMetadataWithImproperLength(t *testing.T) { - md := metadata.New(nil) - md.Append(HeaderPropagationStringForRequestLogging, "TestHeader1", "SomeInformation", "TestHeader2", "ContentsOfTestHeader2", "Test3") - - ctx := context.Background() - - ctx = ContextWithHeaderMapFromMetadata(ctx, md) - - headerMap := HeaderMapFromContext(ctx) - require.Nil(t, headerMap) -} - -func TestInjectHeadersIntoHTTPRequest(t *testing.T) { - contentsMap := make(map[string]string) - contentsMap["TestHeader1"] = "RequestID" - contentsMap["TestHeader2"] = "ContentsOfTestHeader2" - - h := http.Header{} - req := &http.Request{ - Method: "GET", - RequestURI: "/HTTPHeaderTest", - Body: http.NoBody, - Header: h, - } - InjectHeadersIntoHTTPRequest(contentsMap, req) - - header1 := req.Header.Values("TestHeader1") - header2 := req.Header.Values("TestHeader2") - - require.NotNil(t, header1) - require.NotNil(t, header2) - require.Equal(t, 1, len(header1)) - require.Equal(t, 1, len(header2)) - - require.Equal(t, "RequestID", header1[0]) - require.Equal(t, "ContentsOfTestHeader2", header2[0]) - -} - func TestInitLogger(t *testing.T) { stderr := os.Stderr r, w, err := os.Pipe() @@ -85,8 +27,8 @@ func TestInitLogger(t *testing.T) { require.NoError(t, w.Close()) logs, err := io.ReadAll(r) require.NoError(t, err) - require.Contains(t, string(logs), "caller=log_test.go:82 level=debug hello=world") - require.Contains(t, string(logs), "caller=log_test.go:83 level=debug msg=\"hello world\"") + require.Contains(t, string(logs), "caller=log_test.go:24 level=debug hello=world") + require.Contains(t, string(logs), "caller=log_test.go:25 level=debug msg=\"hello world\"") } func BenchmarkDisallowedLogLevels(b *testing.B) { diff --git a/pkg/util/log/wrappers.go b/pkg/util/log/wrappers.go index 1394b7b0b7b..9a706a570e5 100644 --- a/pkg/util/log/wrappers.go +++ b/pkg/util/log/wrappers.go @@ -9,6 +9,7 @@ import ( "go.opentelemetry.io/otel/trace" "github.com/cortexproject/cortex/pkg/tenant" + "github.com/cortexproject/cortex/pkg/util/requestmeta" ) // WithUserID returns a Logger that has information about the current user in @@ -64,7 +65,7 @@ func WithSourceIPs(sourceIPs string, l log.Logger) log.Logger { // HeadersFromContext enables the logging of specified HTTP Headers that have been added to a context func HeadersFromContext(ctx context.Context, l log.Logger) log.Logger { - headerContentsMap := HeaderMapFromContext(ctx) + headerContentsMap := requestmeta.LoggingHeadersAndRequestIdFromContext(ctx) for header, contents := range headerContentsMap { l = log.With(l, header, contents) } diff --git a/pkg/util/requestmeta/context.go b/pkg/util/requestmeta/context.go new file mode 100644 index 00000000000..2efae506d96 --- /dev/null +++ b/pkg/util/requestmeta/context.go @@ -0,0 +1,75 @@ +package requestmeta + +import ( + "context" + "net/http" + "net/textproto" + + "google.golang.org/grpc/metadata" +) + +type contextKey int + +const ( + requestMetadataContextKey contextKey = 0 + PropagationStringForRequestMetadata string = "x-request-metadata-propagation-string" + // HeaderPropagationStringForRequestLogging is used for backwards compatibility + HeaderPropagationStringForRequestLogging string = "x-http-header-forwarding-logging" +) + +func ContextWithRequestMetadataMap(ctx context.Context, requestContextMap map[string]string) context.Context { + return context.WithValue(ctx, requestMetadataContextKey, requestContextMap) +} + +func MapFromContext(ctx context.Context) map[string]string { + requestContextMap, ok := ctx.Value(requestMetadataContextKey).(map[string]string) + if !ok { + return nil + } + return requestContextMap +} + +// ContextWithRequestMetadataMapFromHeaders adds MetadataContext headers to context and Removes non-existent headers. +// targetHeaders is passed for backwards compatibility, otherwise header keys should be in header itself. +func ContextWithRequestMetadataMapFromHeaders(ctx context.Context, headers map[string]string, targetHeaders []string) context.Context { + headerMap := make(map[string]string) + loggingHeaders := headers[textproto.CanonicalMIMEHeaderKey(LoggingHeadersKey)] + headerKeys := targetHeaders + if loggingHeaders != "" { + headerKeys = LoggingHeaderKeysFromString(loggingHeaders) + headerKeys = append(headerKeys, LoggingHeadersKey) + } + headerKeys = append(headerKeys, RequestIdKey) + for _, header := range headerKeys { + if v, ok := headers[textproto.CanonicalMIMEHeaderKey(header)]; ok { + headerMap[header] = v + } + } + return ContextWithRequestMetadataMap(ctx, headerMap) +} + +func InjectMetadataIntoHTTPRequestHeaders(requestMetadataMap map[string]string, request *http.Request) { + for key, contents := range requestMetadataMap { + request.Header.Add(key, contents) + } +} + +func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.MD) context.Context { + headersSlice, ok := md[PropagationStringForRequestMetadata] + + // we want to check old key if no data + if !ok { + headersSlice, ok = md[HeaderPropagationStringForRequestLogging] + } + + if !ok || len(headersSlice)%2 == 1 { + return ctx + } + + requestMetadataMap := make(map[string]string) + for i := 0; i < len(headersSlice); i += 2 { + requestMetadataMap[headersSlice[i]] = headersSlice[i+1] + } + + return ContextWithRequestMetadataMap(ctx, requestMetadataMap) +} diff --git a/pkg/util/requestmeta/context_test.go b/pkg/util/requestmeta/context_test.go new file mode 100644 index 00000000000..23a0d3b4dab --- /dev/null +++ b/pkg/util/requestmeta/context_test.go @@ -0,0 +1,113 @@ +package requestmeta + +import ( + "context" + "net/http" + "net/textproto" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" +) + +func TestRequestMetadataMapFromMetadata(t *testing.T) { + md := metadata.New(nil) + md.Append(PropagationStringForRequestMetadata, "TestHeader1", "SomeInformation", "TestHeader2", "ContentsOfTestHeader2") + + ctx := context.Background() + + ctx = ContextWithRequestMetadataMapFromMetadata(ctx, md) + + requestMetadataMap := MapFromContext(ctx) + + require.Contains(t, requestMetadataMap, "TestHeader1") + require.Contains(t, requestMetadataMap, "TestHeader2") + require.Equal(t, "SomeInformation", requestMetadataMap["TestHeader1"]) + require.Equal(t, "ContentsOfTestHeader2", requestMetadataMap["TestHeader2"]) +} + +func TestRequestMetadataMapFromMetadataWithImproperLength(t *testing.T) { + md := metadata.New(nil) + md.Append(PropagationStringForRequestMetadata, "TestHeader1", "SomeInformation", "TestHeader2", "ContentsOfTestHeader2", "Test3") + + ctx := context.Background() + + ctx = ContextWithRequestMetadataMapFromMetadata(ctx, md) + + requestMetadataMap := MapFromContext(ctx) + require.Nil(t, requestMetadataMap) +} + +func TestContextWithRequestMetadataMapFromHeaders_WithLoggingHeaders(t *testing.T) { + headers := map[string]string{ + textproto.CanonicalMIMEHeaderKey("X-Request-ID"): "1234", + textproto.CanonicalMIMEHeaderKey("X-User-ID"): "user5678", + textproto.CanonicalMIMEHeaderKey(LoggingHeadersKey): "X-Request-ID,X-User-ID", + } + + ctx := context.Background() + ctx = ContextWithRequestMetadataMapFromHeaders(ctx, headers, nil) + + requestMetadataMap := MapFromContext(ctx) + + require.Contains(t, requestMetadataMap, "X-Request-ID") + require.Contains(t, requestMetadataMap, "X-User-ID") + require.Equal(t, "1234", requestMetadataMap["X-Request-ID"]) + require.Equal(t, "user5678", requestMetadataMap["X-User-ID"]) +} + +func TestContextWithRequestMetadataMapFromHeaders_BackwardCompatibleTargetHeaders(t *testing.T) { + headers := map[string]string{ + textproto.CanonicalMIMEHeaderKey("X-Legacy-Header"): "legacy-value", + } + + ctx := context.Background() + ctx = ContextWithRequestMetadataMapFromHeaders(ctx, headers, []string{"X-Legacy-Header"}) + + requestMetadataMap := MapFromContext(ctx) + + require.Contains(t, requestMetadataMap, "X-Legacy-Header") + require.Equal(t, "legacy-value", requestMetadataMap["X-Legacy-Header"]) +} + +func TestContextWithRequestMetadataMapFromHeaders_OnlyMatchingKeysUsed(t *testing.T) { + headers := map[string]string{ + textproto.CanonicalMIMEHeaderKey("X-Some-Header"): "value1", + textproto.CanonicalMIMEHeaderKey("Unused-Header"): "value2", + textproto.CanonicalMIMEHeaderKey(LoggingHeadersKey): "X-Some-Header", + } + + ctx := context.Background() + ctx = ContextWithRequestMetadataMapFromHeaders(ctx, headers, nil) + + requestMetadataMap := MapFromContext(ctx) + + require.Equal(t, "value1", requestMetadataMap["X-Some-Header"]) +} + +func TestInjectMetadataIntoHTTPRequestHeaders(t *testing.T) { + contentsMap := make(map[string]string) + contentsMap["TestHeader1"] = "RequestID" + contentsMap["TestHeader2"] = "ContentsOfTestHeader2" + + h := http.Header{} + req := &http.Request{ + Method: "GET", + RequestURI: "/HTTPHeaderTest", + Body: http.NoBody, + Header: h, + } + InjectMetadataIntoHTTPRequestHeaders(contentsMap, req) + + header1 := req.Header.Values("TestHeader1") + header2 := req.Header.Values("TestHeader2") + + require.NotNil(t, header1) + require.NotNil(t, header2) + require.Equal(t, 1, len(header1)) + require.Equal(t, 1, len(header2)) + + require.Equal(t, "RequestID", header1[0]) + require.Equal(t, "ContentsOfTestHeader2", header2[0]) + +} diff --git a/pkg/util/requestmeta/id.go b/pkg/util/requestmeta/id.go new file mode 100644 index 00000000000..01b34e430a1 --- /dev/null +++ b/pkg/util/requestmeta/id.go @@ -0,0 +1,22 @@ +package requestmeta + +import "context" + +const RequestIdKey = "x-cortex-request-id" + +func RequestIdFromContext(ctx context.Context) string { + metadataMap := MapFromContext(ctx) + if metadataMap == nil { + return "" + } + return metadataMap[RequestIdKey] +} + +func ContextWithRequestId(ctx context.Context, reqId string) context.Context { + metadataMap := MapFromContext(ctx) + if metadataMap == nil { + metadataMap = make(map[string]string) + } + metadataMap[RequestIdKey] = reqId + return ContextWithRequestMetadataMap(ctx, metadataMap) +} diff --git a/pkg/util/requestmeta/logging_headers.go b/pkg/util/requestmeta/logging_headers.go new file mode 100644 index 00000000000..08b62e8d7ec --- /dev/null +++ b/pkg/util/requestmeta/logging_headers.go @@ -0,0 +1,52 @@ +package requestmeta + +import ( + "context" + "strings" +) + +const ( + LoggingHeadersKey = "x-request-logging-headers-key" + loggingHeadersDelimiter = "," +) + +func LoggingHeaderKeysToString(targetHeaders []string) string { + return strings.Join(targetHeaders, loggingHeadersDelimiter) +} + +func LoggingHeaderKeysFromString(headerKeysString string) []string { + return strings.Split(headerKeysString, loggingHeadersDelimiter) +} + +func LoggingHeadersFromContext(ctx context.Context) map[string]string { + metadataMap := MapFromContext(ctx) + if metadataMap == nil { + return nil + } + loggingHeadersString := metadataMap[LoggingHeadersKey] + if loggingHeadersString == "" { + // Backward compatibility: if no specific headers are listed, return all metadata + return metadataMap + } + + result := make(map[string]string) + for _, header := range LoggingHeaderKeysFromString(loggingHeadersString) { + if v, ok := metadataMap[header]; ok { + result[header] = v + } + } + return result +} + +func LoggingHeadersAndRequestIdFromContext(ctx context.Context) map[string]string { + metadataMap := MapFromContext(ctx) + if metadataMap == nil { + return nil + } + + loggingHeaders := LoggingHeadersFromContext(ctx) + reqId := RequestIdFromContext(ctx) + loggingHeaders[RequestIdKey] = reqId + + return loggingHeaders +}