Skip to content

Commit 3feb791

Browse files
committed
chore: remove legacy JWT-based flow state handling
1 parent c675749 commit 3feb791

File tree

4 files changed

+38
-233
lines changed

4 files changed

+38
-233
lines changed

internal/api/context.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ const (
3333
oauthVerifierKey = contextKey("oauth_verifier")
3434
ssoProviderKey = contextKey("sso_provider")
3535
externalHostKey = contextKey("external_host")
36-
flowStateKey = contextKey("flow_state_id")
3736
oauthClientStateKey = contextKey("oauth_client_state_id")
3837
flowStateContextKey = contextKey("flow_state")
3938
)
@@ -128,18 +127,6 @@ func withInviteToken(ctx context.Context, token string) context.Context {
128127
return context.WithValue(ctx, inviteTokenKey, token)
129128
}
130129

131-
func withFlowStateID(ctx context.Context, FlowStateID string) context.Context {
132-
return context.WithValue(ctx, flowStateKey, FlowStateID)
133-
}
134-
135-
func getFlowStateID(ctx context.Context) string {
136-
obj := ctx.Value(flowStateKey)
137-
if obj == nil {
138-
return ""
139-
}
140-
return obj.(string)
141-
}
142-
143130
func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context {
144131
return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID)
145132
}

internal/api/external.go

Lines changed: 4 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010

1111
"github.com/fatih/structs"
1212
"github.com/gofrs/uuid"
13-
jwt "github.com/golang-jwt/jwt/v5"
1413
"github.com/sirupsen/logrus"
1514
"github.com/supabase/auth/internal/api/apierrors"
1615
"github.com/supabase/auth/internal/api/provider"
@@ -23,18 +22,6 @@ import (
2322
"golang.org/x/oauth2"
2423
)
2524

26-
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
27-
type ExternalProviderClaims struct {
28-
AuthMicroserviceClaims
29-
Provider string `json:"provider"`
30-
InviteToken string `json:"invite_token,omitempty"`
31-
Referrer string `json:"referrer,omitempty"`
32-
FlowStateID string `json:"flow_state_id"`
33-
OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
34-
LinkingTargetID string `json:"linking_target_id,omitempty"`
35-
EmailOptional bool `json:"email_optional,omitempty"`
36-
}
37-
3825
// ExternalProviderRedirect redirects the request to the oauth provider
3926
func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error {
4027
rurl, err := a.GetExternalProviderRedirectURL(w, r, nil)
@@ -203,20 +190,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
203190
providerAccessToken := data.token
204191
providerRefreshToken := data.refreshToken
205192

206-
// Get flow state from context (new UUID format) or load from FlowStateID (legacy JWT format)
207193
flowState := getFlowState(ctx)
208-
if flowState == nil {
209-
// Backward compatibility: load from FlowStateID for legacy JWT state
210-
// To be removed in subsequent release.
211-
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
212-
flowState, err = models.FindFlowStateByID(db, flowStateID)
213-
if models.IsNotFoundError(err) {
214-
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
215-
} else if err != nil {
216-
return apierrors.NewInternalServerError("Failed to find flow state").WithInternalError(err)
217-
}
218-
}
219-
}
220194

221195
targetUser := getTargetUser(ctx)
222196
inviteToken := getInviteToken(ctx)
@@ -545,13 +519,12 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
545519
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
546520
}
547521

548-
// Try to parse state as UUID first (new format)
549-
if stateUUID, err := uuid.FromString(state); err == nil {
550-
return a.loadExternalStateFromUUID(ctx, db, stateUUID)
522+
stateUUID, err := uuid.FromString(state)
523+
if err != nil {
524+
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth state parameter is invalid")
551525
}
552526

553-
// Fall back to JWT parsing for backward compatibility
554-
return a.loadExternalStateFromJWT(ctx, db, state)
527+
return a.loadExternalStateFromUUID(ctx, db, stateUUID)
555528
}
556529

557530
// loadExternalStateFromUUID loads OAuth state from a flow_state record (new UUID format)
@@ -598,75 +571,6 @@ func (a *API) loadExternalStateFromUUID(ctx context.Context, db *storage.Connect
598571
return withSignature(ctx, stateID.String()), nil
599572
}
600573

601-
// loadExternalStateFromJWT loads OAuth state from a JWT (legacy format for backward compatibility)
602-
func (a *API) loadExternalStateFromJWT(ctx context.Context, db *storage.Connection, state string) (context.Context, error) {
603-
config := a.config
604-
claims := ExternalProviderClaims{}
605-
p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods))
606-
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
607-
if kid, ok := token.Header["kid"]; ok {
608-
if kidStr, ok := kid.(string); ok {
609-
key, err := conf.FindPublicKeyByKid(kidStr, &config.JWT)
610-
if err != nil {
611-
return nil, err
612-
}
613-
614-
if key != nil {
615-
return key, nil
616-
}
617-
618-
// otherwise try to use fallback
619-
}
620-
}
621-
if alg, ok := token.Header["alg"]; ok {
622-
if alg == jwt.SigningMethodHS256.Name {
623-
// preserve backward compatibility for cases where the kid is not set or potentially invalid but the key can be decoded with the secret
624-
return []byte(config.JWT.Secret), nil
625-
}
626-
}
627-
628-
return nil, fmt.Errorf("unrecognized JWT kid %v for algorithm %v", token.Header["kid"], token.Header["alg"])
629-
})
630-
if err != nil {
631-
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
632-
}
633-
if claims.Provider == "" {
634-
return ctx, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
635-
}
636-
if claims.InviteToken != "" {
637-
ctx = withInviteToken(ctx, claims.InviteToken)
638-
}
639-
if claims.Referrer != "" {
640-
ctx = withExternalReferrer(ctx, claims.Referrer)
641-
}
642-
if claims.FlowStateID != "" {
643-
ctx = withFlowStateID(ctx, claims.FlowStateID)
644-
}
645-
if claims.OAuthClientStateID != "" {
646-
oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID)
647-
if err != nil {
648-
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)")
649-
}
650-
ctx = withOAuthClientStateID(ctx, oauthClientStateID)
651-
}
652-
if claims.LinkingTargetID != "" {
653-
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
654-
if err != nil {
655-
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
656-
}
657-
u, err := models.FindUserByID(db, linkingTargetUserID)
658-
if err != nil {
659-
if models.IsNotFoundError(err) {
660-
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found")
661-
}
662-
return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err)
663-
}
664-
ctx = withTargetUser(ctx, u)
665-
}
666-
ctx = withExternalProviderType(ctx, claims.Provider, claims.EmailOptional)
667-
return withSignature(ctx, state), nil
668-
}
669-
670574
// Provider returns a Provider interface for the given name.
671575
func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, conf.OAuthProviderConfiguration, error) {
672576
config := a.config

internal/api/external_oauth.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ type OAuthProviderData struct {
2626
code string
2727
}
2828

29-
// loadFlowState parses the `state` query parameter as a JWS payload,
30-
// extracting the provider requested
29+
// loadFlowState parses the `state` query parameter as a UUID,
30+
// loads the flow state from the database, and extracts the provider requested
3131
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
3232
ctx := r.Context()
3333
db := a.db.WithContext(ctx)

internal/api/external_test.go

Lines changed: 32 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@ import (
66
"net/http/httptest"
77
"net/url"
88
"testing"
9-
"time"
109

1110
"github.com/gofrs/uuid"
12-
jwt "github.com/golang-jwt/jwt/v5"
1311
"github.com/stretchr/testify/require"
1412
"github.com/stretchr/testify/suite"
1513
"github.com/supabase/auth/internal/conf"
1614
"github.com/supabase/auth/internal/models"
17-
"github.com/supabase/auth/internal/tokens"
1815
)
1916

2017
type ExternalTestSuite struct {
@@ -352,145 +349,62 @@ func setupGenericOAuthServer(ts *ExternalTestSuite, code string) *httptest.Serve
352349
return server
353350
}
354351

355-
// TestOAuthState_BackwardCompatibleJWT tests that the callback endpoint
356-
// still accepts the legacy JWT state format for backward compatibility during migration.
357-
func (ts *ExternalTestSuite) TestOAuthState_BackwardCompatibleJWT() {
352+
// TestOAuthState_UUIDFormat tests that the callback endpoint processes UUID state correctly.
353+
func (ts *ExternalTestSuite) TestOAuthState_UUIDFormat() {
358354
code := "authcode"
359355
server := setupGenericOAuthServer(ts, code)
360356
defer server.Close()
361357

362-
// Create a legacy JWT state token manually
363-
claims := &ExternalProviderClaims{
364-
AuthMicroserviceClaims: AuthMicroserviceClaims{
365-
RegisteredClaims: jwt.RegisteredClaims{
366-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
367-
IssuedAt: jwt.NewNumericDate(time.Now()),
368-
Issuer: ts.Config.JWT.Issuer,
369-
},
370-
},
371-
Provider: "github",
372-
Referrer: "https://example.com/admin",
373-
EmailOptional: false,
374-
}
358+
// Use the standard authorization flow which generates UUID state
359+
w := performAuthorizationRequest(ts, "github", "")
360+
ts.Require().Equal(http.StatusFound, w.Code)
361+
u, err := url.Parse(w.Header().Get("Location"))
362+
ts.Require().NoError(err)
375363

376-
jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims)
377-
require.NoError(ts.T(), err)
378-
require.NotEmpty(ts.T(), jwtState)
364+
state := u.Query().Get("state")
365+
ts.Require().NotEmpty(state)
366+
367+
stateUUID, err := uuid.FromString(state)
368+
require.NoError(ts.T(), err, "state should be a valid UUID")
369+
require.NotEqual(ts.T(), uuid.Nil, stateUUID)
379370

380371
testURL, err := url.Parse("http://localhost/callback")
381372
require.NoError(ts.T(), err)
382373
v := testURL.Query()
383374
v.Set("code", code)
384-
v.Set("state", jwtState)
375+
v.Set("state", state)
385376
testURL.RawQuery = v.Encode()
386377

387378
req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
388-
w := httptest.NewRecorder()
379+
w = httptest.NewRecorder()
389380
ts.API.handler.ServeHTTP(w, req)
390381

391382
ts.Require().Equal(http.StatusFound, w.Code)
392-
u, err := url.Parse(w.Header().Get("Location"))
393-
ts.Require().NoError(err, "redirect url parse failed")
394-
ts.Require().Equal("/admin", u.Path)
395-
396-
fragment, err := url.ParseQuery(u.Fragment)
383+
resultURL, err := url.Parse(w.Header().Get("Location"))
397384
ts.Require().NoError(err)
398-
ts.NotEmpty(fragment.Get("access_token"), "should have access_token")
399-
ts.NotEmpty(fragment.Get("refresh_token"), "should have refresh_token")
400385

401-
user, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
402-
require.NoError(ts.T(), err)
403-
require.NotNil(ts.T(), user)
386+
fragment, err := url.ParseQuery(resultURL.Fragment)
387+
ts.Require().NoError(err)
388+
ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token")
404389
}
405390

406-
// TestOAuthState_MigrationScenario tests that both UUID and JWT state formats
407-
// can be processed during the migration period.
408-
func (ts *ExternalTestSuite) TestOAuthState_MigrationScenario() {
391+
// TestOAuthState_InvalidFormat tests that non-UUID state parameters are rejected.
392+
func (ts *ExternalTestSuite) TestOAuthState_InvalidFormat() {
409393
code := "authcode"
410394
server := setupGenericOAuthServer(ts, code)
411395
defer server.Close()
412396

413-
ts.Run("NewUUIDFormat", func() {
414-
// Use the standard authorization flow which now generates UUID state
415-
w := performAuthorizationRequest(ts, "github", "")
416-
ts.Require().Equal(http.StatusFound, w.Code)
417-
u, err := url.Parse(w.Header().Get("Location"))
418-
ts.Require().NoError(err)
419-
420-
state := u.Query().Get("state")
421-
ts.Require().NotEmpty(state)
422-
423-
// Verify state is a valid UUID
424-
stateUUID, err := uuid.FromString(state)
425-
require.NoError(ts.T(), err, "state should be a valid UUID")
426-
require.NotEqual(ts.T(), uuid.Nil, stateUUID)
427-
428-
// Complete the callback
429-
testURL, err := url.Parse("http://localhost/callback")
430-
require.NoError(ts.T(), err)
431-
v := testURL.Query()
432-
v.Set("code", code)
433-
v.Set("state", state)
434-
testURL.RawQuery = v.Encode()
435-
436-
req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
437-
w = httptest.NewRecorder()
438-
ts.API.handler.ServeHTTP(w, req)
439-
440-
ts.Require().Equal(http.StatusFound, w.Code)
441-
resultURL, err := url.Parse(w.Header().Get("Location"))
442-
ts.Require().NoError(err)
443-
444-
fragment, err := url.ParseQuery(resultURL.Fragment)
445-
ts.Require().NoError(err)
446-
ts.NotEmpty(fragment.Get("access_token"), "UUID state should result in access_token")
447-
})
448-
449-
// Clean up user for next test
450-
user, _ := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
451-
if user != nil {
452-
require.NoError(ts.T(), ts.API.db.Destroy(user))
453-
}
454-
455-
ts.Run("LegacyJWTFormat", func() {
456-
// Create a legacy JWT state
457-
claims := &ExternalProviderClaims{
458-
AuthMicroserviceClaims: AuthMicroserviceClaims{
459-
RegisteredClaims: jwt.RegisteredClaims{
460-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
461-
IssuedAt: jwt.NewNumericDate(time.Now()),
462-
Issuer: ts.Config.JWT.Issuer,
463-
},
464-
},
465-
Provider: "github",
466-
Referrer: "https://example.com/admin",
467-
}
468-
469-
jwtState, err := tokens.SignJWT(&ts.Config.JWT, claims)
470-
require.NoError(ts.T(), err)
471-
472-
// Verify state is NOT a UUID (it's a JWT)
473-
_, uuidErr := uuid.FromString(jwtState)
474-
require.Error(ts.T(), uuidErr, "JWT state should not be parseable as UUID")
475-
476-
// Complete the callback with JWT state
477-
testURL, err := url.Parse("http://localhost/callback")
478-
require.NoError(ts.T(), err)
479-
v := testURL.Query()
480-
v.Set("code", code)
481-
v.Set("state", jwtState)
482-
testURL.RawQuery = v.Encode()
483-
484-
req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
485-
w := httptest.NewRecorder()
486-
ts.API.handler.ServeHTTP(w, req)
397+
testURL, err := url.Parse("http://localhost/callback")
398+
require.NoError(ts.T(), err)
399+
v := testURL.Query()
400+
v.Set("code", code)
401+
v.Set("state", "not-a-valid-uuid")
402+
testURL.RawQuery = v.Encode()
487403

488-
ts.Require().Equal(http.StatusFound, w.Code)
489-
resultURL, err := url.Parse(w.Header().Get("Location"))
490-
ts.Require().NoError(err)
404+
req := httptest.NewRequest(http.MethodGet, testURL.String(), nil)
405+
w := httptest.NewRecorder()
406+
ts.API.handler.ServeHTTP(w, req)
491407

492-
fragment, err := url.ParseQuery(resultURL.Fragment)
493-
ts.Require().NoError(err)
494-
ts.NotEmpty(fragment.Get("access_token"), "JWT state should also result in access_token")
495-
})
408+
// Should redirect to site URL with error since state is invalid
409+
ts.Require().Equal(http.StatusSeeOther, w.Code)
496410
}

0 commit comments

Comments
 (0)