Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20251024170139.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: [headers] Added utilities for header sanitisation
1 change: 1 addition & 0 deletions changes/20251024170358.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:bug: [headers] Fix header search due to normalisation of the name by go
10 changes: 5 additions & 5 deletions utils/http/header_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (

type ClientWithHeaders struct {
client IClient
headers headers.Headers
headers *headers.Headers
}

func newClientWithHeaders(underlyingClient IClient, headerValues ...string) (c *ClientWithHeaders, err error) {
c = &ClientWithHeaders{
headers: make(headers.Headers),
headers: headers.NewHeaders(),
}

if underlyingClient == nil {
Expand Down Expand Up @@ -123,7 +123,7 @@ func (c *ClientWithHeaders) Close() error {

func (c *ClientWithHeaders) AppendHeader(key, value string) {
if c.headers == nil {
c.headers = make(headers.Headers)
c.headers = headers.NewHeaders()
}
c.headers.AppendHeader(key, value)
}
Expand All @@ -132,9 +132,9 @@ func (c *ClientWithHeaders) RemoveHeader(key string) {
if c.headers == nil {
return
}
delete(c.headers, key)
c.headers.RemoveHeader(key)
}

func (c *ClientWithHeaders) ClearHeaders() {
c.headers = make(headers.Headers)
c.headers = headers.NewHeaders()
}
4 changes: 3 additions & 1 deletion utils/http/header_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ func TestClientWithHeadersWithDifferentBodies(t *testing.T) {

clientStruct.AppendHeader("hello", "world")
require.NotEmpty(t, clientStruct.headers)
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, clientStruct.headers["hello"])
header := clientStruct.headers.GetHeader("hello")
require.NotNil(t, header)
assert.Equal(t, headers.Header{Key: "hello", Value: "world"}, *header)

clientStruct.RemoveHeader("hello")
assert.Empty(t, clientStruct.headers)
Expand Down
132 changes: 116 additions & 16 deletions utils/http/headers/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"net/http"
"strings"

mapset "github.com/deckarep/golang-set/v2"
"github.com/go-http-utils/headers"

"github.com/ARM-software/golang-utils/utils/collection"
"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/encoding/base64"
"github.com/ARM-software/golang-utils/utils/field"
"github.com/ARM-software/golang-utils/utils/http/headers/useragent"
"github.com/ARM-software/golang-utils/utils/http/schemes"
"github.com/ARM-software/golang-utils/utils/reflection"
Expand Down Expand Up @@ -146,6 +148,9 @@ var (
headers.XRatelimitRemaining,
headers.XRatelimitReset,
}
// NormalisedSafeHeaders returns a normalised list of safe headers
NormalisedSafeHeaders = collection.Map[string, string](SafeHeaders, headers.Normalize) //nolint:misspell

)

type Header struct {
Expand All @@ -167,17 +172,34 @@ func (hs Headers) AppendHeader(key, value string) {
}

func (hs Headers) Append(h *Header) {
hs[h.Key] = *h
hs[headers.Normalize(h.Key)] = *h //nolint:misspell
}

func (hs Headers) Get(key string) string {
h, found := hs[key]
h, found := hs.get(key)
if !found {
return ""
}
return h.Value
}

func (hs Headers) GetHeader(key string) (header *Header) {
header, _ = hs.get(key)
return
}

func (hs Headers) get(key string) (header *Header, found bool) {
h, found := hs[key]
if !found {
h, found = hs[headers.Normalize(key)] //nolint:misspell
if !found {
return
}
}
header = &h
return
}

func (hs Headers) Has(h *Header) bool {
if h == nil {
return false
Expand All @@ -186,10 +208,30 @@ func (hs Headers) Has(h *Header) bool {
}

func (hs Headers) HasHeader(key string) bool {
_, found := hs[key]
_, found := hs.get(key)
return found
}

func (hs Headers) FromRequest(r *http.Request) {
if reflection.IsEmpty(r) {
return
}
hs.FromGoHTTPHeaders(&r.Header)
}

func (hs Headers) FromGoHTTPHeaders(headers *http.Header) {
for key, value := range field.Optional[http.Header](headers, http.Header{}) {
hs.AppendHeader(key, value[0])
}
}

func (hs Headers) FromResponse(resp *http.Response) {
if reflection.IsEmpty(resp) {
return
}
hs.FromGoHTTPHeaders(&resp.Header)
}

func (hs Headers) Empty() bool {
return len(hs) == 0
}
Expand All @@ -210,10 +252,77 @@ func (hs Headers) AppendToRequest(r *http.Request) {
}
}

func (hs Headers) RemoveHeader(key string) {
delete(hs, key)
delete(hs, headers.Normalize(key)) //nolint:misspell
}

func (hs Headers) RemoveHeaders(key ...string) {
for i := range key {
hs.RemoveHeader(key[i])
}
}

func (hs Headers) Clone() *Headers {
clone := make(Headers, len(hs))
for k, v := range hs {
clone[k] = v
}
return &clone
}

// DisallowList returns the headers minus any header defined in the disallow list.
func (hs Headers) DisallowList(key ...string) *Headers {
clone := hs.Clone()
clone.RemoveHeaders(key...)
return clone
}

// AllowList return only safe headers and headers defined in the allow list.
func (hs Headers) AllowList(key ...string) *Headers {
clone := hs.Clone()
clone.Sanitise(key...)
return clone
}

// Sanitise sanitises headers so no personal data is retained.
// It is possible to provide an allowed list of extra headers which would also be retained.
func (hs Headers) Sanitise(allowList ...string) {
allowedHeaders := mapset.NewSet[string](NormalisedSafeHeaders...)
allowedHeaders.Append(collection.Map[string, string](allowList, headers.Normalize)...) //nolint:misspell
var headersToRemove []string
for key := range hs {
if !allowedHeaders.Contains(headers.Normalize(key)) { //nolint:misspell
headersToRemove = append(headersToRemove, key)
}
}
hs.RemoveHeaders(headersToRemove...)
}

func NewHeaders() *Headers {
return &Headers{}
}

// FromRequest returns request's headers
func FromRequest(r *http.Request) *Headers {
if reflection.IsEmpty(r) {
return nil
}
h := NewHeaders()
h.FromRequest(r)
return h
}

// FromResponse returns response's headers
func FromResponse(resp *http.Response) *Headers {
if reflection.IsEmpty(resp) {
return nil
}
h := NewHeaders()
h.FromResponse(resp)
return h
}

// ParseAuthorizationHeader fetches the `Authorization` header and parses it.
func ParseAuthorizationHeader(r *http.Request) (string, string, error) {
return ParseAuthorisationValue(FetchWebsocketAuthorisation(r))
Expand Down Expand Up @@ -414,17 +523,8 @@ func CreateLinkHeader(link, relation, contentType string) string {

// SanitiseHeaders sanitises a collection of request headers not to include any with personal data
func SanitiseHeaders(requestHeader *http.Header) *Headers {
if requestHeader == nil {
return nil
}
aHeaders := NewHeaders()
for i := range SafeHeaders {
safeHeader := SafeHeaders[i]
rHeader := requestHeader.Get(safeHeader)
if !reflection.IsEmpty(rHeader) {
aHeaders.AppendHeader(safeHeader, rHeader)
}
}

return aHeaders
hs := NewHeaders()
hs.FromGoHTTPHeaders(requestHeader)
hs.Sanitise()
return hs
}
78 changes: 78 additions & 0 deletions utils/http/headers/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,38 @@ func TestParseAuthorizationHeader(t *testing.T) {
})
}

func TestFromToRequestResponse(t *testing.T) {
request := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
request.Header.Add(headers.Authorization, faker.Password())
request.Header.Add(HeaderWebsocketProtocol, faker.Password())
h := FromRequest(request)
h.AppendHeader(headers.Accept, "1.0.0")
h.AppendHeader(headers.AcceptEncoding, "gzip")
r2 := httptest.NewRequest(http.MethodGet, faker.URL(), nil)
assert.Empty(t, r2.Header)
h.AppendToRequest(r2)
assert.NotEmpty(t, r2.Header)
h2 := FromRequest(r2)
assert.True(t, h2.HasHeader(headers.Authorization))
assert.True(t, h2.HasHeader(headers.AcceptEncoding))
assert.True(t, h2.HasHeader(headers.Accept))
assert.True(t, h2.HasHeader(HeaderWebsocketProtocol))

response := httptest.NewRecorder()
response.Header().Set(HeaderWebsocketProtocol, "base64.binary.k8s.io")
response.Header().Set(headers.Authorization, faker.Password())
h3 := FromResponse(response.Result())
h3.AppendHeader(headers.Accept, "1.0.0")
h3.AppendHeader(headers.AcceptEncoding, "gzip")
response2 := httptest.NewRecorder()
h3.AppendToResponse(response2)
h4 := FromResponse(response2.Result())
assert.True(t, h4.HasHeader(headers.Authorization))
assert.True(t, h4.HasHeader(headers.AcceptEncoding))
assert.True(t, h4.HasHeader(headers.Accept))
assert.True(t, h4.HasHeader(HeaderWebsocketProtocol))
}

func TestAddProductInformationToUserAgent(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, faker.URL(), nil)
require.NoError(t, err)
Expand All @@ -165,6 +197,18 @@ func TestSetLocationHeaders(t *testing.T) {
assert.Equal(t, location, w.Header().Get(headers.ContentLocation))
}

func TestGetHeaders(t *testing.T) {
header := NewHeaders()
test := faker.Word()
header.AppendHeader(HeaderWebsocketProtocol, test)
assert.Equal(t, test, header.Get(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
assert.True(t, header.HasHeader(HeaderWebsocketProtocol))
assert.True(t, header.HasHeader(headers.Normalize(HeaderWebsocketProtocol))) //nolint:misspell
assert.Empty(t, header.Get(headers.ContentLocation))
assert.False(t, header.HasHeader(headers.ContentLocation))
assert.False(t, header.HasHeader(headers.Normalize(headers.ContentLocation))) //nolint:misspell
}

func TestSanitiseHeaders(t *testing.T) {
header := &http.Header{}
t.Run("empty", func(t *testing.T) {
Expand Down Expand Up @@ -197,5 +241,39 @@ func TestSanitiseHeaders(t *testing.T) {
assert.False(t, actual.HasHeader(
HeaderWebsocketProtocol))
})
t.Run("allow/disallow list", func(t *testing.T) {
h := NewHeaders()
h.AppendHeader(headers.Authorization, faker.Password())
h.AppendHeader(HeaderWebsocketProtocol, faker.Password())
h.AppendHeader(headers.Accept, "1.0.0")
h.AppendHeader(headers.AcceptEncoding, "gzip")
h1 := h.Clone()
h1.Sanitise()
assert.True(t, h1.HasHeader(headers.Accept))
assert.True(t, h1.HasHeader(headers.AcceptEncoding))
assert.False(t, h1.HasHeader(HeaderWebsocketProtocol))
assert.False(t, h1.HasHeader(headers.Authorization))
assert.True(t, h.HasHeader(headers.Accept))
assert.True(t, h.HasHeader(headers.AcceptEncoding))
assert.True(t, h.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h.HasHeader(headers.Authorization))
h11 := h.AllowList(headers.Authorization)
assert.True(t, h11.HasHeader(headers.Accept))
assert.True(t, h11.HasHeader(headers.AcceptEncoding))
assert.False(t, h11.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h11.HasHeader(headers.Authorization))
h2 := h.Clone()
h2.Sanitise(headers.Authorization)
h2.RemoveHeaders(headers.AcceptEncoding, headers.Accept)
assert.False(t, h2.HasHeader(headers.Accept))
assert.False(t, h2.HasHeader(headers.AcceptEncoding))
assert.False(t, h2.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h2.HasHeader(headers.Authorization))
h22 := h.DisallowList(headers.AcceptEncoding, headers.Accept)
assert.False(t, h22.HasHeader(headers.Accept))
assert.False(t, h22.HasHeader(headers.AcceptEncoding))
assert.True(t, h22.HasHeader(HeaderWebsocketProtocol))
assert.True(t, h22.HasHeader(headers.Authorization))
})

}
Loading