Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ func TestAuthenticateTrustedJWT(t *testing.T) {
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
}

claimEmail := map[string]interface{}{"email": wantUserEmail}
claimEmail := map[string]any{"email": wantUserEmail}
builder := jwt.Signed(signer).Claims(claims).Claims(claimEmail)
token, err := builder.Serialize()
require.NoError(t, err, "Error serializing token using compact serialization format")
Expand Down Expand Up @@ -1162,7 +1162,7 @@ func TestAuthenticateTrustedJWT(t *testing.T) {
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
}

claimEmail := map[string]interface{}{"email": "emily@"}
claimEmail := map[string]any{"email": "emily@"}
builder := jwt.Signed(signer).Claims(claims).Claims(claimEmail)
token, err := builder.Serialize()
require.NoError(t, err, "Error serializing token using compact serialization format")
Expand Down Expand Up @@ -1197,7 +1197,7 @@ func TestAuthenticateTrustedJWT(t *testing.T) {
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
}

claimEmail := map[string]interface{}{"email": "layla@"}
claimEmail := map[string]any{"email": "layla@"}
builder := jwt.Signed(signer).Claims(claims).Claims(claimEmail)
token, err := builder.Serialize()
require.NoError(t, err, "Error serializing token using compact serialization format")
Expand Down
2 changes: 1 addition & 1 deletion auth/auth_time_sensitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestAuthenticationSpeed(t *testing.T) {
assert.True(t, user.Authenticate("goIsKewl"))

start := time.Now()
for i := 0; i < 1000; i++ {
for range 1000 {
assert.True(t, user.Authenticate("goIsKewl"))
}
durationPerAuth := time.Since(start) / 1000
Expand Down
7 changes: 3 additions & 4 deletions auth/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package auth

import (
"net/http"
"slices"
"strings"
)

Expand Down Expand Up @@ -46,10 +47,8 @@ func MatchedOrigin(allowOrigins []string, rqOrigins []string) string {
}
}
}
for _, av := range allowOrigins {
if av == "*" {
return "*"
}
if slices.Contains(allowOrigins, "*") {
return "*"
}
return ""
}
10 changes: 3 additions & 7 deletions auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"errors"
"fmt"
"net/url"
"slices"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/couchbase/sync_gateway/base"
Expand Down Expand Up @@ -86,12 +87,7 @@ func (j JWTConfigCommon) ValidFor(ctx context.Context, issuer string, audiences
if *j.ClientID == "" {
return true
}
for _, aud := range audiences {
if aud == *j.ClientID {
return true
}
}
return false
return slices.Contains(audiences, *j.ClientID)
}

var ErrNoMatchingProvider = errors.New("no matching OIDC/JWT provider")
Expand Down Expand Up @@ -209,7 +205,7 @@ func (l *LocalJWTAuthProvider) verifyToken(ctx context.Context, token string, _
}
base.DebugfCtx(ctx, base.KeyAuth, "Local JWT ID Token successfully parsed and verified (iss: %v; sub: %v)", base.UD(idToken.Issuer), base.UD(idToken.Subject))

var claims map[string]interface{}
var claims map[string]any
if err := idToken.Claims(&claims); err != nil {
base.WarnfCtx(ctx, "Failed to unmarshal ID token claims: %v", err)
}
Expand Down
20 changes: 10 additions & 10 deletions auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,21 @@ func TestJWTVerifyToken(t *testing.T) {

t.Run("valid RSA", test(baseProvider, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
}), ""))

t.Run("valid EC", test(baseProvider, CreateTestJWT(t, jose.ES256, testECKeypair, JWTHeaders{
"kid": testECJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
}), ""))

t.Run("valid + expiry check", test(providerWithExpiryCheck, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
"exp": time.Now().Add(time.Hour).Unix(),
Expand All @@ -139,7 +139,7 @@ func TestJWTVerifyToken(t *testing.T) {

t.Run("valid but expired", test(providerWithExpiryCheck, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
// 2000-01-01T00:00:00Z
Expand All @@ -148,7 +148,7 @@ func TestJWTVerifyToken(t *testing.T) {

t.Run("valid but issued in the future", test(providerWithExpiryCheck, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
// 3000-01-01T00:00:00Z
Expand All @@ -158,7 +158,7 @@ func TestJWTVerifyToken(t *testing.T) {

invalidSignature := CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
})
Expand All @@ -167,21 +167,21 @@ func TestJWTVerifyToken(t *testing.T) {

t.Run("valid JWT signed with an unknown key", test(baseProvider, CreateTestJWT(t, jose.RS256, testExtraKeypair, JWTHeaders{
"kid": testExtraJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
}), anyError))

t.Run("valid JWT signed with a mismatching KID", test(baseProvider, CreateTestJWT(t, jose.RS256, testExtraKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
}), anyError))

t.Run("valid RSA signed with key with use=enc", test(baseProvider, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testEncRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": testIssuer,
"aud": []string{testClientID},
}), anyError))
Expand All @@ -195,7 +195,7 @@ func TestJWTVerifyToken(t *testing.T) {

t.Run("valid RSA with invalid issuer", test(baseProvider, CreateTestJWT(t, jose.RS256, testRSAKeypair, JWTHeaders{
"kid": testRSAJWK.KeyID,
}, map[string]interface{}{
}, map[string]any{
"iss": "nonsense",
"aud": []string{testClientID},
}), "id token issued by a different provider"))
Expand Down
4 changes: 2 additions & 2 deletions auth/jwt_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (

// These are not in jwt_test.go to allow use in tests from other packages.

type JWTHeaders map[jose.HeaderKey]interface{}
type JWTHeaders map[jose.HeaderKey]any

// CreateTestJWT creates and signs a valid JWT with the given headers and claims.
// The key must be valid for use with gopkg.in/square/go-jose.v2 (https://pkg.go.dev/gopkg.in/square/go-jose.v2#readme-supported-key-types),
// and the alg must match the key.
func CreateTestJWT(t *testing.T, alg jose.SignatureAlgorithm, key interface{}, headers JWTHeaders, claims map[string]interface{}) string {
func CreateTestJWT(t *testing.T, alg jose.SignatureAlgorithm, key any, headers JWTHeaders, claims map[string]any) string {
t.Helper()

signerOpts := new(jose.SignerOptions)
Expand Down
13 changes: 6 additions & 7 deletions auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"net/http"
"net/url"
"reflect"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -220,11 +221,9 @@ func (opm OIDCProviderMap) GetProviderForIssuer(ctx context.Context, issuer stri
clientID := base.ValDefault(provider.ClientID, "")
if provider.Issuer == issuer && clientID != "" {
// Iterate over the audiences looking for a match
for _, aud := range audiences {
if clientID == aud {
base.DebugfCtx(ctx, base.KeyAuth, "Provider matches, returning")
return provider
}
if slices.Contains(audiences, clientID) {
base.DebugfCtx(ctx, base.KeyAuth, "Provider matches, returning")
return provider
}
}
}
Expand Down Expand Up @@ -436,7 +435,7 @@ func getJWTUsername(provider JWTConfigCommon, identity *Identity) (username stri
}

// formatUsername returns the string representation of the given username value.
func formatUsername(value interface{}) (string, error) {
func formatUsername(value any) (string, error) {
switch valueType := value.(type) {
case string:
return valueType, nil
Expand Down Expand Up @@ -725,7 +724,7 @@ func (op *OIDCProvider) stopDiscoverySync() {
// amount of time in seconds that the fetched responses are allowed to be used again (from the
// time when a request is made). The second value (ok) is true if max-age exists, and false if not.
func cacheControlMaxAge(header http.Header) (maxAge time.Duration, ok bool, err error) {
for _, field := range strings.Split(header.Get("Cache-Control"), ",") {
for field := range strings.SplitSeq(header.Get("Cache-Control"), ",") {
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
k := strings.ToLower(strings.TrimSpace(parts[0]))
if k != "max-age" {
Expand Down
32 changes: 16 additions & 16 deletions auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ func TestIsStandardDiscovery(t *testing.T) {
func TestFormatUsername(t *testing.T) {
tests := []struct {
name string
username interface{}
username any
usernameExpected string
errorExpected error
}{{
Expand Down Expand Up @@ -1014,7 +1014,7 @@ func TestJWTRolesChannels(t *testing.T) {
)
type simulatedLogin struct {
explicitRoles, explicitChannels []string
claims map[string]interface{}
claims map[string]any
expectedRoles, expectedChannels []string
}
type testCase struct {
Expand All @@ -1037,7 +1037,7 @@ func TestJWTRolesChannels(t *testing.T) {
rolesClaimName: "roles",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
Expand All @@ -1048,7 +1048,7 @@ func TestJWTRolesChannels(t *testing.T) {
rolesClaimName: "roles",
logins: []simulatedLogin{
{
claims: map[string]interface{}{},
claims: map[string]any{},
expectedRoles: []string{},
expectedChannels: []string{"!"},
},
Expand All @@ -1060,7 +1060,7 @@ func TestJWTRolesChannels(t *testing.T) {
logins: []simulatedLogin{
{
explicitRoles: []string{"bar"},
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo", "bar"},
expectedChannels: []string{"!"},
},
Expand All @@ -1071,12 +1071,12 @@ func TestJWTRolesChannels(t *testing.T) {
rolesClaimName: "roles",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
{
claims: map[string]interface{}{"roles": []string{}},
claims: map[string]any{"roles": []string{}},
expectedRoles: []string{},
expectedChannels: []string{"!"},
},
Expand All @@ -1088,7 +1088,7 @@ func TestJWTRolesChannels(t *testing.T) {
logins: []simulatedLogin{
{
explicitRoles: []string{"foo"},
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
Expand All @@ -1100,13 +1100,13 @@ func TestJWTRolesChannels(t *testing.T) {
logins: []simulatedLogin{
{
explicitRoles: []string{"foo"},
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
{
explicitRoles: []string{"foo"},
claims: map[string]interface{}{"roles": []string{}},
claims: map[string]any{"roles": []string{}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
Expand All @@ -1117,12 +1117,12 @@ func TestJWTRolesChannels(t *testing.T) {
rolesClaimName: "roles",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"roles": []string{"foo"}},
claims: map[string]any{"roles": []string{"foo"}},
expectedRoles: []string{"foo"},
expectedChannels: []string{"!"},
},
{
claims: map[string]interface{}{"roles": []string{"bar"}},
claims: map[string]any{"roles": []string{"bar"}},
expectedRoles: []string{"bar"},
expectedChannels: []string{"!"},
},
Expand All @@ -1133,7 +1133,7 @@ func TestJWTRolesChannels(t *testing.T) {
channelsClaimName: "channels",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"channels": []string{"foo"}},
claims: map[string]any{"channels": []string{"foo"}},
expectedRoles: []string{},
expectedChannels: []string{"!", "foo"},
},
Expand All @@ -1145,7 +1145,7 @@ func TestJWTRolesChannels(t *testing.T) {
logins: []simulatedLogin{
{
explicitChannels: []string{"bar"},
claims: map[string]interface{}{"channels": []string{"foo"}},
claims: map[string]any{"channels": []string{"foo"}},
expectedRoles: []string{},
expectedChannels: []string{"!", "foo", "bar"},
},
Expand All @@ -1157,7 +1157,7 @@ func TestJWTRolesChannels(t *testing.T) {
channelsClaimName: "channels",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"roles": []string{"rFoo"}, "channels": []string{"cBar"}},
claims: map[string]any{"roles": []string{"rFoo"}, "channels": []string{"cBar"}},
expectedRoles: []string{"rFoo"},
expectedChannels: []string{"!", "cBar"},
},
Expand All @@ -1169,7 +1169,7 @@ func TestJWTRolesChannels(t *testing.T) {
channelsClaimName: "channels",
logins: []simulatedLogin{
{
claims: map[string]interface{}{"roles": []string{"rFoo"}, "channels": []string{"cBar"}},
claims: map[string]any{"roles": []string{"rFoo"}, "channels": []string{"cBar"}},
explicitRoles: []string{"rBaz"},
explicitChannels: []string{"cQux"},
expectedRoles: []string{"rFoo", "rBaz"},
Expand Down
Loading