diff --git a/validator/option.go b/validator/option.go index 12c1cc61..013b4a92 100644 --- a/validator/option.go +++ b/validator/option.go @@ -26,3 +26,13 @@ func WithCustomClaims(f func() CustomClaims) Option { v.customClaims = f } } + +// WithSkipIssuerURLVerification is an option which sets up the allowed +// clock skew for the token. Note that in order to use this +// the expected claims Time field MUST not be time.IsZero(). +// If this option is not used clock skew is not allowed. +func WithSkipIssuerURLVerification(skip bool) Option { + return func(v *Validator) { + v.skipIssuerURLVerification = skip + } +} diff --git a/validator/validator.go b/validator/validator.go index 2a302493..7e8a83cc 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -28,11 +28,12 @@ const ( // Validator to use with the jose v2 package. type Validator struct { - keyFunc func(context.Context) (interface{}, error) // Required. - signatureAlgorithm SignatureAlgorithm // Required. - expectedClaims jwt.Expected // Internal. - customClaims func() CustomClaims // Optional. - allowedClockSkew time.Duration // Optional. + keyFunc func(context.Context) (interface{}, error) // Required. + signatureAlgorithm SignatureAlgorithm // Required. + expectedClaims jwt.Expected // Internal. + customClaims func() CustomClaims // Optional. + allowedClockSkew time.Duration // Optional. + skipIssuerURLVerification bool // Optional. } // SignatureAlgorithm is a signature algorithm. @@ -83,6 +84,7 @@ func New( Issuer: issuerURL, Audience: audience, }, + skipIssuerURLVerification: false, } for _, opt := range opts { @@ -108,7 +110,7 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte return nil, fmt.Errorf("failed to deserialize token claims: %w", err) } - if err = validateClaimsWithLeeway(registeredClaims, v.expectedClaims, v.allowedClockSkew); err != nil { + if err = validateClaimsWithLeeway(registeredClaims, v.expectedClaims, v.allowedClockSkew, v.skipIssuerURLVerification); err != nil { return nil, fmt.Errorf("expected claims not validated: %w", err) } @@ -134,11 +136,11 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte return validatedClaims, nil } -func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error { +func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration, skipIssuerURLVerification bool) error { expectedClaims := expected expectedClaims.Time = time.Now() - if actualClaims.Issuer != expectedClaims.Issuer { + if !skipIssuerURLVerification && actualClaims.Issuer != expectedClaims.Issuer { return jwt.ErrInvalidIssuer } diff --git a/validator/validator_test.go b/validator/validator_test.go index 08feeb14..dca9220e 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -29,13 +29,14 @@ func TestValidator_ValidateToken(t *testing.T) { ) testCases := []struct { - name string - token string - keyFunc func(context.Context) (interface{}, error) - algorithm SignatureAlgorithm - customClaims func() CustomClaims - expectedError error - expectedClaims *ValidatedClaims + name string + token string + keyFunc func(context.Context) (interface{}, error) + algorithm SignatureAlgorithm + customClaims func() CustomClaims + expectedError error + expectedClaims *ValidatedClaims + skipIssuerURLVerification bool }{ { name: "it successfully validates a token", @@ -205,6 +206,25 @@ func TestValidator_ValidateToken(t *testing.T) { algorithm: HS256, expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer), }, + { + name: "it successfully validates a token when token issuer is invalid but skip issuer url verification is true", + token: "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwiaWF0IjoxNzI1MzUyNzQ0LCJleHAiOjE3NTY4ODg4MDAsImF1ZCI6Imh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyIsInN1YiI6IjEyMzQ1Njc4OTAifQ.-ruuyhRkx4T_1HZUQw3eKNWIhV3utPO_e7FagciLk50", + skipIssuerURLVerification: true, + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: "https://hacked-jwt-middleware.eu.auth0.com/", + Subject: subject, + Audience: []string{audience}, + Expiry: 1756888800, + NotBefore: 0, + IssuedAt: 1725352744, + }, + }, + }, } for _, testCase := range testCases { @@ -222,6 +242,8 @@ func TestValidator_ValidateToken(t *testing.T) { ) require.NoError(t, err) + validator.skipIssuerURLVerification = testCase.skipIssuerURLVerification + tokenClaims, err := validator.ValidateToken(context.Background(), testCase.token) if testCase.expectedError != nil { assert.EqualError(t, err, testCase.expectedError.Error())