Skip to content

Commit fc6da2d

Browse files
authored
improve allocation when extracting gRPC metadata (#7065)
Signed-off-by: yeya24 <[email protected]>
1 parent db252aa commit fc6da2d

File tree

7 files changed

+196
-34
lines changed

7 files changed

+196
-34
lines changed

pkg/storegateway/bucket_stores.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -786,16 +786,10 @@ func (u *BucketStores) getTokensToRetrieve(tokens uint64, dataType store.StoreDa
786786
}
787787

788788
func getUserIDFromGRPCContext(ctx context.Context) string {
789-
meta, ok := metadata.FromIncomingContext(ctx)
790-
if !ok {
789+
values := metadata.ValueFromIncomingContext(ctx, tsdb.TenantIDExternalLabel)
790+
if values == nil || len(values) != 1 {
791791
return ""
792792
}
793-
794-
values := meta.Get(tsdb.TenantIDExternalLabel)
795-
if len(values) != 1 {
796-
return ""
797-
}
798-
799793
return values[0]
800794
}
801795

pkg/util/extract_forwarded.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,8 @@ func GetSourceIPsFromOutgoingCtx(ctx context.Context) string {
2424

2525
// GetSourceIPsFromIncomingCtx extracts the source field from the GRPC context
2626
func GetSourceIPsFromIncomingCtx(ctx context.Context) string {
27-
md, ok := metadata.FromIncomingContext(ctx)
28-
if !ok {
29-
return ""
30-
}
31-
ipAddresses, ok := md[ipAddressesKey]
32-
if !ok {
27+
ipAddresses := metadata.ValueFromIncomingContext(ctx, ipAddressesKey)
28+
if ipAddresses == nil || len(ipAddresses) != 1 {
3329
return ""
3430
}
3531
return ipAddresses[0]

pkg/util/grpcclient/signing_handler.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,12 @@ func UnarySigningServerInterceptor(ctx context.Context, req any, _ *grpc.UnarySe
3333
return handler(ctx, req)
3434
}
3535

36-
md, ok := metadata.FromIncomingContext(ctx)
37-
38-
if !ok {
39-
return nil, ErrSignatureNotPresent
40-
}
41-
42-
sig, ok := md[reqSignHeaderName]
43-
44-
if !ok || len(sig) != 1 {
36+
sig := metadata.ValueFromIncomingContext(ctx, reqSignHeaderName)
37+
if sig == nil || len(sig) != 1 {
4538
return nil, ErrSignatureNotPresent
4639
}
4740

4841
valid, err := rs.VerifySign(ctx, sig[0])
49-
5042
if err != nil {
5143
return nil, err
5244
}

pkg/util/grpcutil/util.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,17 @@ func HTTPHeaderPropagationStreamServerInterceptor(srv any, ss grpc.ServerStream,
5151
// extractForwardedRequestMetadataFromMetadata implements HTTPHeaderPropagationServerInterceptor by placing forwarded
5252
// headers into incoming context
5353
func extractForwardedRequestMetadataFromMetadata(ctx context.Context) context.Context {
54-
md, ok := metadata.FromIncomingContext(ctx)
55-
if !ok {
54+
headersSlice := metadata.ValueFromIncomingContext(ctx, requestmeta.PropagationStringForRequestMetadata)
55+
if headersSlice == nil {
56+
// we want to check old key if no data
57+
headersSlice = metadata.ValueFromIncomingContext(ctx, requestmeta.HeaderPropagationStringForRequestLogging)
58+
}
59+
60+
if headersSlice == nil {
5661
return ctx
5762
}
58-
return requestmeta.ContextWithRequestMetadataMapFromMetadata(ctx, md)
63+
64+
return requestmeta.ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice)
5965
}
6066

6167
// HTTPHeaderPropagationClientInterceptor allows for propagation of HTTP Request headers across gRPC calls - works

pkg/util/grpcutil/util_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package grpcutil
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"google.golang.org/grpc/metadata"
9+
10+
"github.com/cortexproject/cortex/pkg/util/requestmeta"
11+
)
12+
13+
// TestExtractForwardedRequestMetadataFromMetadata tests the extractForwardedRequestMetadataFromMetadata function
14+
func TestExtractForwardedRequestMetadataFromMetadata(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
ctx context.Context
18+
expectedResult map[string]string
19+
}{
20+
{
21+
name: "context without metadata",
22+
ctx: context.Background(),
23+
expectedResult: nil,
24+
},
25+
{
26+
name: "context with new metadata key",
27+
ctx: func() context.Context {
28+
md := metadata.New(nil)
29+
md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2", "value2")
30+
return metadata.NewIncomingContext(context.Background(), md)
31+
}(),
32+
expectedResult: map[string]string{
33+
"header1": "value1",
34+
"header2": "value2",
35+
},
36+
},
37+
{
38+
name: "context with old metadata key",
39+
ctx: func() context.Context {
40+
md := metadata.New(nil)
41+
md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "header1", "value1", "header2", "value2")
42+
return metadata.NewIncomingContext(context.Background(), md)
43+
}(),
44+
expectedResult: map[string]string{
45+
"header1": "value1",
46+
"header2": "value2",
47+
},
48+
},
49+
{
50+
name: "context with both keys, new key takes precedence",
51+
ctx: func() context.Context {
52+
md := metadata.New(nil)
53+
md.Append(requestmeta.PropagationStringForRequestMetadata, "newheader", "newvalue")
54+
md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "oldheader", "oldvalue")
55+
return metadata.NewIncomingContext(context.Background(), md)
56+
}(),
57+
expectedResult: map[string]string{
58+
"newheader": "newvalue",
59+
},
60+
},
61+
{
62+
name: "context with odd number of metadata values",
63+
ctx: func() context.Context {
64+
md := metadata.New(nil)
65+
md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2")
66+
return metadata.NewIncomingContext(context.Background(), md)
67+
}(),
68+
expectedResult: nil,
69+
},
70+
}
71+
72+
for _, tt := range tests {
73+
t.Run(tt.name, func(t *testing.T) {
74+
result := extractForwardedRequestMetadataFromMetadata(tt.ctx)
75+
metadataMap := requestmeta.MapFromContext(result)
76+
77+
if tt.expectedResult == nil {
78+
assert.Nil(t, metadataMap)
79+
} else {
80+
assert.Equal(t, tt.expectedResult, metadataMap)
81+
}
82+
})
83+
}
84+
}

pkg/util/requestmeta/context.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ func InjectMetadataIntoHTTPRequestHeaders(requestMetadataMap map[string]string,
5555
}
5656
}
5757

58+
func ContextWithRequestMetadataMapFromHeaderSlice(ctx context.Context, headerSlice []string) context.Context {
59+
if len(headerSlice)%2 == 1 {
60+
return ctx
61+
}
62+
63+
requestMetadataMap := make(map[string]string, len(headerSlice)/2)
64+
for i := 0; i < len(headerSlice); i += 2 {
65+
requestMetadataMap[headerSlice[i]] = headerSlice[i+1]
66+
}
67+
68+
return ContextWithRequestMetadataMap(ctx, requestMetadataMap)
69+
}
70+
5871
func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.MD) context.Context {
5972
headersSlice, ok := md[PropagationStringForRequestMetadata]
6073

@@ -63,14 +76,9 @@ func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.
6376
headersSlice, ok = md[HeaderPropagationStringForRequestLogging]
6477
}
6578

66-
if !ok || len(headersSlice)%2 == 1 {
79+
if !ok {
6780
return ctx
6881
}
6982

70-
requestMetadataMap := make(map[string]string)
71-
for i := 0; i < len(headersSlice); i += 2 {
72-
requestMetadataMap[headersSlice[i]] = headersSlice[i+1]
73-
}
74-
75-
return ContextWithRequestMetadataMap(ctx, requestMetadataMap)
83+
return ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice)
7684
}

pkg/util/requestmeta/context_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,85 @@ func TestInjectMetadataIntoHTTPRequestHeaders(t *testing.T) {
111111
require.Equal(t, "ContentsOfTestHeader2", header2[0])
112112

113113
}
114+
115+
func TestContextWithRequestMetadataMapFromHeaderSlice(t *testing.T) {
116+
tests := []struct {
117+
name string
118+
headerSlice []string
119+
expectedResult map[string]string
120+
}{
121+
{
122+
name: "empty header slice",
123+
headerSlice: []string{},
124+
expectedResult: map[string]string{},
125+
},
126+
{
127+
name: "nil header slice",
128+
headerSlice: nil,
129+
expectedResult: map[string]string{},
130+
},
131+
{
132+
name: "odd number of elements",
133+
headerSlice: []string{"header1", "value1", "header2"},
134+
expectedResult: nil,
135+
},
136+
{
137+
name: "single key-value pair",
138+
headerSlice: []string{"header1", "value1"},
139+
expectedResult: map[string]string{
140+
"header1": "value1",
141+
},
142+
},
143+
{
144+
name: "multiple key-value pairs",
145+
headerSlice: []string{"header1", "value1", "header2", "value2", "header3", "value3"},
146+
expectedResult: map[string]string{
147+
"header1": "value1",
148+
"header2": "value2",
149+
"header3": "value3",
150+
},
151+
},
152+
{
153+
name: "duplicate keys (last value wins)",
154+
headerSlice: []string{"header1", "value1", "header1", "value2"},
155+
expectedResult: map[string]string{
156+
"header1": "value2",
157+
},
158+
},
159+
{
160+
name: "empty values",
161+
headerSlice: []string{"header1", "", "header2", "value2"},
162+
expectedResult: map[string]string{
163+
"header1": "",
164+
"header2": "value2",
165+
},
166+
},
167+
{
168+
name: "special characters in keys and values",
169+
headerSlice: []string{"header-1", "value with spaces", "header_2", "value-with-dashes"},
170+
expectedResult: map[string]string{
171+
"header-1": "value with spaces",
172+
"header_2": "value-with-dashes",
173+
},
174+
},
175+
}
176+
177+
for _, tt := range tests {
178+
t.Run(tt.name, func(t *testing.T) {
179+
ctx := context.Background()
180+
result := ContextWithRequestMetadataMapFromHeaderSlice(ctx, tt.headerSlice)
181+
metadataMap := MapFromContext(result)
182+
183+
if tt.expectedResult == nil {
184+
require.Nil(t, metadataMap)
185+
} else {
186+
require.NotNil(t, metadataMap)
187+
require.Equal(t, len(tt.expectedResult), len(metadataMap))
188+
for key, expectedValue := range tt.expectedResult {
189+
require.Contains(t, metadataMap, key)
190+
require.Equal(t, expectedValue, metadataMap[key])
191+
}
192+
}
193+
})
194+
}
195+
}

0 commit comments

Comments
 (0)