Skip to content

Commit f0a8c8e

Browse files
authored
[headers] Added utilities for header sanitisation (#734)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description - fix some bugs - improve the sanitisation ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent cf4da6d commit f0a8c8e

File tree

6 files changed

+204
-22
lines changed

6 files changed

+204
-22
lines changed

changes/20251024170139.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: [headers] Added utilities for header sanitisation

changes/20251024170358.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:bug: [headers] Fix header search due to normalisation of the name by go

utils/http/header_client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ import (
1515

1616
type ClientWithHeaders struct {
1717
client IClient
18-
headers headers.Headers
18+
headers *headers.Headers
1919
}
2020

2121
func newClientWithHeaders(underlyingClient IClient, headerValues ...string) (c *ClientWithHeaders, err error) {
2222
c = &ClientWithHeaders{
23-
headers: make(headers.Headers),
23+
headers: headers.NewHeaders(),
2424
}
2525

2626
if underlyingClient == nil {
@@ -123,7 +123,7 @@ func (c *ClientWithHeaders) Close() error {
123123

124124
func (c *ClientWithHeaders) AppendHeader(key, value string) {
125125
if c.headers == nil {
126-
c.headers = make(headers.Headers)
126+
c.headers = headers.NewHeaders()
127127
}
128128
c.headers.AppendHeader(key, value)
129129
}
@@ -132,9 +132,9 @@ func (c *ClientWithHeaders) RemoveHeader(key string) {
132132
if c.headers == nil {
133133
return
134134
}
135-
delete(c.headers, key)
135+
c.headers.RemoveHeader(key)
136136
}
137137

138138
func (c *ClientWithHeaders) ClearHeaders() {
139-
c.headers = make(headers.Headers)
139+
c.headers = headers.NewHeaders()
140140
}

utils/http/header_client_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ func TestClientWithHeadersWithDifferentBodies(t *testing.T) {
223223

224224
clientStruct.AppendHeader("hello", "world")
225225
require.NotEmpty(t, clientStruct.headers)
226-
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, clientStruct.headers["hello"])
226+
header := clientStruct.headers.GetHeader("hello")
227+
require.NotNil(t, header)
228+
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, *header)
227229

228230
clientStruct.RemoveHeader("hello")
229231
assert.Empty(t, clientStruct.headers)

utils/http/headers/headers.go

Lines changed: 116 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ import (
66
"net/http"
77
"strings"
88

9+
mapset "github.com/deckarep/golang-set/v2"
910
"github.com/go-http-utils/headers"
1011

1112
"github.com/ARM-software/golang-utils/utils/collection"
1213
"github.com/ARM-software/golang-utils/utils/commonerrors"
1314
"github.com/ARM-software/golang-utils/utils/encoding/base64"
15+
"github.com/ARM-software/golang-utils/utils/field"
1416
"github.com/ARM-software/golang-utils/utils/http/headers/useragent"
1517
"github.com/ARM-software/golang-utils/utils/http/schemes"
1618
"github.com/ARM-software/golang-utils/utils/reflection"
@@ -146,6 +148,9 @@ var (
146148
headers.XRatelimitRemaining,
147149
headers.XRatelimitReset,
148150
}
151+
// NormalisedSafeHeaders returns a normalised list of safe headers
152+
NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell
153+
149154
)
150155

151156
type Header struct {
@@ -167,17 +172,34 @@ func (hs Headers) AppendHeader(key, value string) {
167172
}
168173

169174
func (hs Headers) Append(h *Header) {
170-
hs[h.Key] = *h
175+
hs[headers.Normalize(h.Key)] = *h //nolint:misspell
171176
}
172177

173178
func (hs Headers) Get(key string) string {
174-
h, found := hs[key]
179+
h, found := hs.get(key)
175180
if !found {
176181
return ""
177182
}
178183
return h.Value
179184
}
180185

186+
func (hs Headers) GetHeader(key string) (header *Header) {
187+
header, _ = hs.get(key)
188+
return
189+
}
190+
191+
func (hs Headers) get(key string) (header *Header, found bool) {
192+
h, found := hs[key]
193+
if !found {
194+
h, found = hs[headers.Normalize(key)] //nolint:misspell
195+
if !found {
196+
return
197+
}
198+
}
199+
header = &h
200+
return
201+
}
202+
181203
func (hs Headers) Has(h *Header) bool {
182204
if h == nil {
183205
return false
@@ -186,10 +208,30 @@ func (hs Headers) Has(h *Header) bool {
186208
}
187209

188210
func (hs Headers) HasHeader(key string) bool {
189-
_, found := hs[key]
211+
_, found := hs.get(key)
190212
return found
191213
}
192214

215+
func (hs Headers) FromRequest(r *http.Request) {
216+
if reflection.IsEmpty(r) {
217+
return
218+
}
219+
hs.FromGoHTTPHeaders(&r.Header)
220+
}
221+
222+
func (hs Headers) FromGoHTTPHeaders(headers *http.Header) {
223+
for key, value := range field.Optional[http.Header](headers, http.Header{}) {
224+
hs.AppendHeader(key, value[0])
225+
}
226+
}
227+
228+
func (hs Headers) FromResponse(resp *http.Response) {
229+
if reflection.IsEmpty(resp) {
230+
return
231+
}
232+
hs.FromGoHTTPHeaders(&resp.Header)
233+
}
234+
193235
func (hs Headers) Empty() bool {
194236
return len(hs) == 0
195237
}
@@ -210,10 +252,77 @@ func (hs Headers) AppendToRequest(r *http.Request) {
210252
}
211253
}
212254

255+
func (hs Headers) RemoveHeader(key string) {
256+
delete(hs, key)
257+
delete(hs, headers.Normalize(key)) //nolint:misspell
258+
}
259+
260+
func (hs Headers) RemoveHeaders(key ...string) {
261+
for i := range key {
262+
hs.RemoveHeader(key[i])
263+
}
264+
}
265+
266+
func (hs Headers) Clone() *Headers {
267+
clone := make(Headers, len(hs))
268+
for k, v := range hs {
269+
clone[k] = v
270+
}
271+
return &clone
272+
}
273+
274+
// DisallowList returns the headers minus any header defined in the disallow list.
275+
func (hs Headers) DisallowList(key ...string) *Headers {
276+
clone := hs.Clone()
277+
clone.RemoveHeaders(key...)
278+
return clone
279+
}
280+
281+
// AllowList return only safe headers and headers defined in the allow list.
282+
func (hs Headers) AllowList(key ...string) *Headers {
283+
clone := hs.Clone()
284+
clone.Sanitise(key...)
285+
return clone
286+
}
287+
288+
// Sanitise sanitises headers so no personal data is retained.
289+
// It is possible to provide an allowed list of extra headers which would also be retained.
290+
func (hs Headers) Sanitise(allowList ...string) {
291+
allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...)
292+
allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) //nolint:misspell
293+
var headersToRemove []string
294+
for key := range hs {
295+
if !allowedHeaders.Contains(headers.Normalize(key)) { //nolint:misspell
296+
headersToRemove = append(headersToRemove, key)
297+
}
298+
}
299+
hs.RemoveHeaders(headersToRemove...)
300+
}
301+
213302
func NewHeaders() *Headers {
214303
return &Headers{}
215304
}
216305

306+
// FromRequest returns request's headers
307+
func FromRequest(r *http.Request) *Headers {
308+
if reflection.IsEmpty(r) {
309+
return nil
310+
}
311+
h := NewHeaders()
312+
h.FromRequest(r)
313+
return h
314+
}
315+
316+
// FromResponse returns response's headers
317+
func FromResponse(resp *http.Response) *Headers {
318+
if reflection.IsEmpty(resp) {
319+
return nil
320+
}
321+
h := NewHeaders()
322+
h.FromResponse(resp)
323+
return h
324+
}
325+
217326
// ParseAuthorizationHeader fetches the `Authorization` header and parses it.
218327
func ParseAuthorizationHeader(r *http.Request) (string, string, error) {
219328
return ParseAuthorisationValue(FetchWebsocketAuthorisation(r))
@@ -414,17 +523,8 @@ func CreateLinkHeader(link, relation, contentType string) string {
414523

415524
// SanitiseHeaders sanitises a collection of request headers not to include any with personal data
416525
func SanitiseHeaders(requestHeader *http.Header) *Headers {
417-
if requestHeader == nil {
418-
return nil
419-
}
420-
aHeaders := NewHeaders()
421-
for i := range SafeHeaders {
422-
safeHeader := SafeHeaders[i]
423-
rHeader := requestHeader.Get(safeHeader)
424-
if !reflection.IsEmpty(rHeader) {
425-
aHeaders.AppendHeader(safeHeader, rHeader)
426-
}
427-
}
428-
429-
return aHeaders
526+
hs := NewHeaders()
527+
hs.FromGoHTTPHeaders(requestHeader)
528+
hs.Sanitise()
529+
return hs
430530
}

utils/http/headers/headers_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) {
144144
})
145145
}
146146

147+
func TestFromToRequestResponse(t *testing.T) {
148+
request := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
149+
request.Header.Add(headers.Authorization, faker.Password())
150+
request.Header.Add(HeaderWebsocketProtocol, faker.Password())
151+
h := FromRequest(request)
152+
h.AppendHeader(headers.Accept, "1.0.0")
153+
h.AppendHeader(headers.AcceptEncoding, "gzip")
154+
r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
155+
assert.Empty(t, r2.Header)
156+
h.AppendToRequest(r2)
157+
assert.NotEmpty(t, r2.Header)
158+
h2 := FromRequest(r2)
159+
assert.True(t, h2.HasHeader(headers.Authorization))
160+
assert.True(t, h2.HasHeader(headers.AcceptEncoding))
161+
assert.True(t, h2.HasHeader(headers.Accept))
162+
assert.True(t, h2.HasHeader(HeaderWebsocketProtocol))
163+
164+
response := httptest.NewRecorder()
165+
response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io")
166+
response.Header().Set(headers.Authorization, faker.Password())
167+
h3 := FromResponse(response.Result())
168+
h3.AppendHeader(headers.Accept, "1.0.0")
169+
h3.AppendHeader(headers.AcceptEncoding, "gzip")
170+
response2 := httptest.NewRecorder()
171+
h3.AppendToResponse(response2)
172+
h4 := FromResponse(response2.Result())
173+
assert.True(t, h4.HasHeader(headers.Authorization))
174+
assert.True(t, h4.HasHeader(headers.AcceptEncoding))
175+
assert.True(t, h4.HasHeader(headers.Accept))
176+
assert.True(t, h4.HasHeader(HeaderWebsocketProtocol))
177+
}
178+
147179
func TestAddProductInformationToUserAgent(t *testing.T) {
148180
r, err := http.NewRequest(http.MethodGet, faker.URL(), nil)
149181
require.NoError(t, err)
@@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) {
165197
assert.Equal(t, location, w.Header().Get(headers.ContentLocation))
166198
}
167199

200+
func TestGetHeaders(t *testing.T) {
201+
header := NewHeaders()
202+
test := faker.Word()
203+
header.AppendHeader(HeaderWebsocketProtocol, test)
204+
assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
205+
assert.True(t, header.HasHeader(HeaderWebsocketProtocol))
206+
assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
207+
assert.Empty(t, header.Get(headers.ContentLocation))
208+
assert.False(t, header.HasHeader(headers.ContentLocation))
209+
assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) //nolint:misspell
210+
}
211+
168212
func TestSanitiseHeaders(t *testing.T) {
169213
header := &http.Header{}
170214
t.Run("empty", func(t *testing.T) {
@@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) {
197241
assert.False(t, actual.HasHeader(
198242
HeaderWebsocketProtocol))
199243
})
244+
t.Run("allow/disallow list", func(t *testing.T) {
245+
h := NewHeaders()
246+
h.AppendHeader(headers.Authorization, faker.Password())
247+
h.AppendHeader(HeaderWebsocketProtocol, faker.Password())
248+
h.AppendHeader(headers.Accept, "1.0.0")
249+
h.AppendHeader(headers.AcceptEncoding, "gzip")
250+
h1 := h.Clone()
251+
h1.Sanitise()
252+
assert.True(t, h1.HasHeader(headers.Accept))
253+
assert.True(t, h1.HasHeader(headers.AcceptEncoding))
254+
assert.False(t, h1.HasHeader(HeaderWebsocketProtocol))
255+
assert.False(t, h1.HasHeader(headers.Authorization))
256+
assert.True(t, h.HasHeader(headers.Accept))
257+
assert.True(t, h.HasHeader(headers.AcceptEncoding))
258+
assert.True(t, h.HasHeader(HeaderWebsocketProtocol))
259+
assert.True(t, h.HasHeader(headers.Authorization))
260+
h11 := h.AllowList(headers.Authorization)
261+
assert.True(t, h11.HasHeader(headers.Accept))
262+
assert.True(t, h11.HasHeader(headers.AcceptEncoding))
263+
assert.False(t, h11.HasHeader(HeaderWebsocketProtocol))
264+
assert.True(t, h11.HasHeader(headers.Authorization))
265+
h2 := h.Clone()
266+
h2.Sanitise(headers.Authorization)
267+
h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept)
268+
assert.False(t, h2.HasHeader(headers.Accept))
269+
assert.False(t, h2.HasHeader(headers.AcceptEncoding))
270+
assert.False(t, h2.HasHeader(HeaderWebsocketProtocol))
271+
assert.True(t, h2.HasHeader(headers.Authorization))
272+
h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept)
273+
assert.False(t, h22.HasHeader(headers.Accept))
274+
assert.False(t, h22.HasHeader(headers.AcceptEncoding))
275+
assert.True(t, h22.HasHeader(HeaderWebsocketProtocol))
276+
assert.True(t, h22.HasHeader(headers.Authorization))
277+
})
200278

201279
}

0 commit comments

Comments
 (0)