diff --git a/changes/20251024170139.feature b/changes/20251024170139.feature new file mode 100644 index 0000000000..6ed67c7b4e --- /dev/null +++ b/changes/20251024170139.feature @@ -0,0 +1 @@ +:sparkles: [headers] Added utilities for header sanitisation diff --git a/changes/20251024170358.bugfix b/changes/20251024170358.bugfix new file mode 100644 index 0000000000..58e5565235 --- /dev/null +++ b/changes/20251024170358.bugfix @@ -0,0 +1 @@ +:bug: [headers] Fix header search due to normalisation of the name by go diff --git a/utils/http/header_client.go b/utils/http/header_client.go index b7f9debf5a..b1e2fb0d4a 100644 --- a/utils/http/header_client.go +++ b/utils/http/header_client.go @@ -15,12 +15,12 @@ import ( type ClientWithHeaders struct { client IClient - headers headers.Headers + headers *headers.Headers } func newClientWithHeaders(underlyingClient IClient, headerValues ...string) (c *ClientWithHeaders, err error) { c = &ClientWithHeaders{ - headers: make(headers.Headers), + headers: headers.NewHeaders(), } if underlyingClient == nil { @@ -123,7 +123,7 @@ func (c *ClientWithHeaders) Close() error { func (c *ClientWithHeaders) AppendHeader(key, value string) { if c.headers == nil { - c.headers = make(headers.Headers) + c.headers = headers.NewHeaders() } c.headers.AppendHeader(key, value) } @@ -132,9 +132,9 @@ func (c *ClientWithHeaders) RemoveHeader(key string) { if c.headers == nil { return } - delete(c.headers, key) + c.headers.RemoveHeader(key) } func (c *ClientWithHeaders) ClearHeaders() { - c.headers = make(headers.Headers) + c.headers = headers.NewHeaders() } diff --git a/utils/http/header_client_test.go b/utils/http/header_client_test.go index b3357ec24a..44bce2ca05 100644 --- a/utils/http/header_client_test.go +++ b/utils/http/header_client_test.go @@ -223,7 +223,9 @@ func TestClientWithHeadersWithDifferentBodies(t *testing.T) { clientStruct.AppendHeader("hello", "world") require.NotEmpty(t, clientStruct.headers) - assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, clientStruct.headers["hello"]) + header := clientStruct.headers.GetHeader("hello") + require.NotNil(t, header) + assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, *header) clientStruct.RemoveHeader("hello") assert.Empty(t, clientStruct.headers) diff --git a/utils/http/headers/headers.go b/utils/http/headers/headers.go index 19d20acb61..1c0781616c 100644 --- a/utils/http/headers/headers.go +++ b/utils/http/headers/headers.go @@ -6,11 +6,13 @@ import ( "net/http" "strings" + mapset "github.com/deckarep/golang-set/v2" "github.com/go-http-utils/headers" "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/encoding/base64" + "github.com/ARM-software/golang-utils/utils/field" "github.com/ARM-software/golang-utils/utils/http/headers/useragent" "github.com/ARM-software/golang-utils/utils/http/schemes" "github.com/ARM-software/golang-utils/utils/reflection" @@ -146,6 +148,9 @@ var ( headers.XRatelimitRemaining, headers.XRatelimitReset, } + // NormalisedSafeHeaders returns a normalised list of safe headers + NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell + ) type Header struct { @@ -167,17 +172,34 @@ func (hs Headers) AppendHeader(key, value string) { } func (hs Headers) Append(h *Header) { - hs[h.Key] = *h + hs[headers.Normalize(h.Key)] = *h //nolint:misspell } func (hs Headers) Get(key string) string { - h, found := hs[key] + h, found := hs.get(key) if !found { return "" } return h.Value } +func (hs Headers) GetHeader(key string) (header *Header) { + header, _ = hs.get(key) + return +} + +func (hs Headers) get(key string) (header *Header, found bool) { + h, found := hs[key] + if !found { + h, found = hs[headers.Normalize(key)] //nolint:misspell + if !found { + return + } + } + header = &h + return +} + func (hs Headers) Has(h *Header) bool { if h == nil { return false @@ -186,10 +208,30 @@ func (hs Headers) Has(h *Header) bool { } func (hs Headers) HasHeader(key string) bool { - _, found := hs[key] + _, found := hs.get(key) return found } +func (hs Headers) FromRequest(r *http.Request) { + if reflection.IsEmpty(r) { + return + } + hs.FromGoHTTPHeaders(&r.Header) +} + +func (hs Headers) FromGoHTTPHeaders(headers *http.Header) { + for key, value := range field.Optional[http.Header](headers, http.Header{}) { + hs.AppendHeader(key, value[0]) + } +} + +func (hs Headers) FromResponse(resp *http.Response) { + if reflection.IsEmpty(resp) { + return + } + hs.FromGoHTTPHeaders(&resp.Header) +} + func (hs Headers) Empty() bool { return len(hs) == 0 } @@ -210,10 +252,77 @@ func (hs Headers) AppendToRequest(r *http.Request) { } } +func (hs Headers) RemoveHeader(key string) { + delete(hs, key) + delete(hs, headers.Normalize(key)) //nolint:misspell +} + +func (hs Headers) RemoveHeaders(key ...string) { + for i := range key { + hs.RemoveHeader(key[i]) + } +} + +func (hs Headers) Clone() *Headers { + clone := make(Headers, len(hs)) + for k, v := range hs { + clone[k] = v + } + return &clone +} + +// DisallowList returns the headers minus any header defined in the disallow list. +func (hs Headers) DisallowList(key ...string) *Headers { + clone := hs.Clone() + clone.RemoveHeaders(key...) + return clone +} + +// AllowList return only safe headers and headers defined in the allow list. +func (hs Headers) AllowList(key ...string) *Headers { + clone := hs.Clone() + clone.Sanitise(key...) + return clone +} + +// Sanitise sanitises headers so no personal data is retained. +// It is possible to provide an allowed list of extra headers which would also be retained. +func (hs Headers) Sanitise(allowList ...string) { + allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...) + allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) //nolint:misspell + var headersToRemove []string + for key := range hs { + if !allowedHeaders.Contains(headers.Normalize(key)) { //nolint:misspell + headersToRemove = append(headersToRemove, key) + } + } + hs.RemoveHeaders(headersToRemove...) +} + func NewHeaders() *Headers { return &Headers{} } +// FromRequest returns request's headers +func FromRequest(r *http.Request) *Headers { + if reflection.IsEmpty(r) { + return nil + } + h := NewHeaders() + h.FromRequest(r) + return h +} + +// FromResponse returns response's headers +func FromResponse(resp *http.Response) *Headers { + if reflection.IsEmpty(resp) { + return nil + } + h := NewHeaders() + h.FromResponse(resp) + return h +} + // ParseAuthorizationHeader fetches the `Authorization` header and parses it. func ParseAuthorizationHeader(r *http.Request) (string, string, error) { return ParseAuthorisationValue(FetchWebsocketAuthorisation(r)) @@ -414,17 +523,8 @@ func CreateLinkHeader(link, relation, contentType string) string { // SanitiseHeaders sanitises a collection of request headers not to include any with personal data func SanitiseHeaders(requestHeader *http.Header) *Headers { - if requestHeader == nil { - return nil - } - aHeaders := NewHeaders() - for i := range SafeHeaders { - safeHeader := SafeHeaders[i] - rHeader := requestHeader.Get(safeHeader) - if !reflection.IsEmpty(rHeader) { - aHeaders.AppendHeader(safeHeader, rHeader) - } - } - - return aHeaders + hs := NewHeaders() + hs.FromGoHTTPHeaders(requestHeader) + hs.Sanitise() + return hs } diff --git a/utils/http/headers/headers_test.go b/utils/http/headers/headers_test.go index e3149d1efa..0d6b2d17f5 100644 --- a/utils/http/headers/headers_test.go +++ b/utils/http/headers/headers_test.go @@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) { }) } +func TestFromToRequestResponse(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, faker.URL(), nil) + request.Header.Add(headers.Authorization, faker.Password()) + request.Header.Add(HeaderWebsocketProtocol, faker.Password()) + h := FromRequest(request) + h.AppendHeader(headers.Accept, "1.0.0") + h.AppendHeader(headers.AcceptEncoding, "gzip") + r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil) + assert.Empty(t, r2.Header) + h.AppendToRequest(r2) + assert.NotEmpty(t, r2.Header) + h2 := FromRequest(r2) + assert.True(t, h2.HasHeader(headers.Authorization)) + assert.True(t, h2.HasHeader(headers.AcceptEncoding)) + assert.True(t, h2.HasHeader(headers.Accept)) + assert.True(t, h2.HasHeader(HeaderWebsocketProtocol)) + + response := httptest.NewRecorder() + response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io") + response.Header().Set(headers.Authorization, faker.Password()) + h3 := FromResponse(response.Result()) + h3.AppendHeader(headers.Accept, "1.0.0") + h3.AppendHeader(headers.AcceptEncoding, "gzip") + response2 := httptest.NewRecorder() + h3.AppendToResponse(response2) + h4 := FromResponse(response2.Result()) + assert.True(t, h4.HasHeader(headers.Authorization)) + assert.True(t, h4.HasHeader(headers.AcceptEncoding)) + assert.True(t, h4.HasHeader(headers.Accept)) + assert.True(t, h4.HasHeader(HeaderWebsocketProtocol)) +} + func TestAddProductInformationToUserAgent(t *testing.T) { r, err := http.NewRequest(http.MethodGet, faker.URL(), nil) require.NoError(t, err) @@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) { assert.Equal(t, location, w.Header().Get(headers.ContentLocation)) } +func TestGetHeaders(t *testing.T) { + header := NewHeaders() + test := faker.Word() + header.AppendHeader(HeaderWebsocketProtocol, test) + assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell + assert.True(t, header.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell + assert.Empty(t, header.Get(headers.ContentLocation)) + assert.False(t, header.HasHeader(headers.ContentLocation)) + assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) //nolint:misspell +} + func TestSanitiseHeaders(t *testing.T) { header := &http.Header{} t.Run("empty", func(t *testing.T) { @@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) { assert.False(t, actual.HasHeader( HeaderWebsocketProtocol)) }) + t.Run("allow/disallow list", func(t *testing.T) { + h := NewHeaders() + h.AppendHeader(headers.Authorization, faker.Password()) + h.AppendHeader(HeaderWebsocketProtocol, faker.Password()) + h.AppendHeader(headers.Accept, "1.0.0") + h.AppendHeader(headers.AcceptEncoding, "gzip") + h1 := h.Clone() + h1.Sanitise() + assert.True(t, h1.HasHeader(headers.Accept)) + assert.True(t, h1.HasHeader(headers.AcceptEncoding)) + assert.False(t, h1.HasHeader(HeaderWebsocketProtocol)) + assert.False(t, h1.HasHeader(headers.Authorization)) + assert.True(t, h.HasHeader(headers.Accept)) + assert.True(t, h.HasHeader(headers.AcceptEncoding)) + assert.True(t, h.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h.HasHeader(headers.Authorization)) + h11 := h.AllowList(headers.Authorization) + assert.True(t, h11.HasHeader(headers.Accept)) + assert.True(t, h11.HasHeader(headers.AcceptEncoding)) + assert.False(t, h11.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h11.HasHeader(headers.Authorization)) + h2 := h.Clone() + h2.Sanitise(headers.Authorization) + h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept) + assert.False(t, h2.HasHeader(headers.Accept)) + assert.False(t, h2.HasHeader(headers.AcceptEncoding)) + assert.False(t, h2.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h2.HasHeader(headers.Authorization)) + h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept) + assert.False(t, h22.HasHeader(headers.Accept)) + assert.False(t, h22.HasHeader(headers.AcceptEncoding)) + assert.True(t, h22.HasHeader(HeaderWebsocketProtocol)) + assert.True(t, h22.HasHeader(headers.Authorization)) + }) }