Skip to content

Commit c78df76

Browse files
refactor: simplify logic
1 parent 6103588 commit c78df76

File tree

4 files changed

+154
-85
lines changed

4 files changed

+154
-85
lines changed

access_request.go

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ type AccessRequest struct {
77
GrantTypes Arguments `json:"grantTypes" gorethink:"grantTypes"`
88
HandledGrantType Arguments `json:"handledGrantType" gorethink:"handledGrantType"`
99

10-
RefreshTokenRequestedScope Arguments
11-
RefreshTokenGrantedScope Arguments
12-
1310
Request
1411
}
1512

@@ -27,26 +24,21 @@ func (a *AccessRequest) GetGrantTypes() Arguments {
2724
return a.GrantTypes
2825
}
2926

30-
func (a *AccessRequest) GetRefreshTokenRequestedScopes() (scopes Arguments) {
31-
if a.RefreshTokenRequestedScope == nil {
32-
return a.RequestedScope
33-
}
34-
35-
return a.RefreshTokenRequestedScope
27+
func (a *AccessRequest) SetGrantedScopes(scopes Arguments) {
28+
a.GrantedScope = scopes
3629
}
3730

38-
func (a *AccessRequest) SetRefreshTokenRequestedScopes(scopes Arguments) {
39-
a.RefreshTokenRequestedScope = scopes
40-
}
31+
func (a *AccessRequest) SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester {
32+
r := a.Sanitize(nil).(*Request)
4133

42-
func (a *AccessRequest) GetRefreshTokenGrantedScopes() (scopes Arguments) {
43-
if a.RefreshTokenGrantedScope == nil {
44-
return a.GrantedScope
34+
ar := &AccessRequest{
35+
Request: *r,
4536
}
4637

47-
return a.RefreshTokenGrantedScope
48-
}
38+
ar.SetID(requester.GetID())
39+
40+
ar.SetRequestedScopes(requester.GetRequestedScopes())
41+
ar.SetGrantedScopes(requester.GetGrantedScopes())
4942

50-
func (a *AccessRequest) SetRefreshTokenGrantedScopes(scopes Arguments) {
51-
a.RefreshTokenGrantedScope = scopes
43+
return ar
5244
}

handler/oauth2/flow_refresh.go

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,6 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex
9797
request.SetRequestedScopes(fosite.RemoveEmpty(strings.Split(scope, " ")))
9898
}
9999

100-
// If a new refresh token is issued, the refresh token scope MUST be identical to that of the refresh token included
101-
// by the client in the request.
102-
if rtRequest, ok := request.(fosite.RefreshTokenAccessRequester); ok {
103-
rtRequest.SetRefreshTokenRequestedScopes(originalRequest.GetRequestedScopes())
104-
rtRequest.SetRefreshTokenGrantedScopes(originalRequest.GetGrantedScopes())
105-
}
106-
107100
request.SetRequestedAudience(originalRequest.GetRequestedAudience())
108101

109102
strategy := c.Config.GetScopeStrategy(ctx)
@@ -167,30 +160,28 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
167160
err = c.handleRefreshTokenEndpointStorageError(ctx, err)
168161
}()
169162

170-
ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
163+
originalRequest, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
171164
if err != nil {
172165
return err
173-
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
166+
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, originalRequest.GetID()); err != nil {
174167
return err
175168
}
176169

177-
if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil {
170+
if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, originalRequest.GetID(), signature); err != nil {
178171
return err
179172
}
180173

181174
storeReq := requester.Sanitize([]string{})
182-
storeReq.SetID(ts.GetID())
175+
storeReq.SetID(originalRequest.GetID())
183176

184177
if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
185178
return err
186179
}
187180

188181
if rtRequest, ok := requester.(fosite.RefreshTokenAccessRequester); ok {
189-
rtStoreReq := requester.Sanitize([]string{}).(*fosite.Request)
190-
rtStoreReq.SetID(ts.GetID())
182+
rtStoreReq := rtRequest.SanitizeRestoreRefreshTokenOriginalRequester(originalRequest)
191183

192-
rtStoreReq.RequestedScope = rtRequest.GetRefreshTokenRequestedScopes()
193-
rtStoreReq.GrantedScope = rtRequest.GetRefreshTokenGrantedScopes()
184+
rtStoreReq.SetSession(requester.GetSession().Clone())
194185

195186
if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, rtStoreReq); err != nil {
196187
return err

integration/refresh_token_grant_test.go

Lines changed: 134 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package integration_test
55

66
import (
7+
"context"
78
"encoding/json"
89
"net/http"
910
"net/http/httptest"
@@ -12,15 +13,14 @@ import (
1213
"testing"
1314
"time"
1415

15-
"github.com/ory/fosite/internal/gen"
16-
1716
"github.com/stretchr/testify/assert"
1817
"github.com/stretchr/testify/require"
1918
"golang.org/x/oauth2"
2019

2120
"github.com/ory/fosite"
2221
"github.com/ory/fosite/compose"
2322
"github.com/ory/fosite/handler/openid"
23+
"github.com/ory/fosite/internal/gen"
2424
"github.com/ory/fosite/token/jwt"
2525
)
2626

@@ -266,7 +266,9 @@ func TestRefreshTokenFlow(t *testing.T) {
266266
}
267267
}
268268

269-
func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {
269+
func TestRefreshTokenFlowScopeParameter(t *testing.T) {
270+
ctx := context.Background()
271+
270272
session := &defaultSession{
271273
DefaultSession: &openid.DefaultSession{
272274
Claims: &jwt.IDTokenClaims{
@@ -278,17 +280,16 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {
278280
},
279281
}
280282
fc := new(fosite.Config)
281-
fc.RefreshTokenLifespan = -1
282283
fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-")
283284
f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey())
284285
ts := mockServer(t, f, session)
285286
defer ts.Close()
286287

287288
fc.ScopeStrategy = fosite.ExactScopeStrategy
288289

289-
oauthClient := newOAuth2Client(ts)
290-
oauthClient.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"}
291-
oauthClient.ClientID = "grant-all-requested-scopes-client"
290+
client := newOAuth2Client(ts)
291+
client.Scopes = []string{"openid", "offline", "offline_access", "foo", "bar"}
292+
client.ClientID = "grant-all-requested-scopes-client"
292293

293294
state := "1234567890"
294295

@@ -304,53 +305,146 @@ func TestRefreshTokenFlowScopeNarrowing(t *testing.T) {
304305

305306
fositeStore.Clients["grant-all-requested-scopes-client"] = testRefreshingClient
306307

307-
resp, err := http.Get(oauthClient.AuthCodeURL(state))
308+
s := compose.NewOAuth2HMACStrategy(fc)
309+
310+
originalScopes := fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar"}
311+
312+
testCases := []struct {
313+
name string
314+
scopes fosite.Arguments
315+
expected fosite.Arguments
316+
err string
317+
}{
318+
{
319+
"ShouldGrantOriginalScopesWhenOmitted",
320+
nil,
321+
originalScopes,
322+
"",
323+
},
324+
{
325+
"ShouldNarrowScopesWhenIncluded",
326+
fosite.Arguments{"openid", "offline_access", "foo"},
327+
fosite.Arguments{"openid", "offline_access", "foo"},
328+
"",
329+
},
330+
{
331+
"ShouldGrantOriginalScopesWhenOmittedAfterNarrowing",
332+
nil,
333+
originalScopes,
334+
"",
335+
},
336+
{
337+
"ShouldGrantOriginalScopesExplicitlyRequested",
338+
originalScopes,
339+
originalScopes,
340+
"",
341+
},
342+
{
343+
"ShouldErrorWhenBroadeningScopesAllowedByClientButNotOriginallyGranted",
344+
fosite.Arguments{"openid", "offline", "offline_access", "foo", "bar", "baz"},
345+
nil,
346+
"The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.",
347+
},
348+
}
349+
350+
type step struct {
351+
OAuth2 *oauth2.Token
352+
SessionAT, SessionRT fosite.Requester
353+
}
354+
355+
entries := make([]step, len(testCases)+1)
356+
357+
resp, err := http.Get(client.AuthCodeURL(state))
308358
require.NoError(t, err)
309359
require.Equal(t, http.StatusOK, resp.StatusCode)
310360

311-
token, err := oauthClient.Exchange(oauth2.NoContext, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", oauthClient.ClientID))
361+
entries[0].OAuth2, err = client.Exchange(ctx, resp.Request.URL.Query().Get("code"), oauth2.SetAuthURLParam("client_id", client.ClientID))
362+
312363
require.NoError(t, err)
313-
require.NotEmpty(t, token.AccessToken)
314-
require.NotEmpty(t, token.RefreshToken)
364+
require.NotEmpty(t, entries[0].OAuth2.AccessToken)
365+
require.NotEmpty(t, entries[0].OAuth2.RefreshToken)
315366

316-
assert.Equal(t, "openid offline offline_access foo bar", token.Extra("scope"))
367+
assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope"))
317368

318-
token1Refresh, err := doRefresh(oauthClient, token, nil)
369+
entries[0].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[0].OAuth2.AccessToken), nil)
319370
require.NoError(t, err)
320-
require.NotEmpty(t, token1Refresh.AccessToken)
321-
require.NotEmpty(t, token1Refresh.RefreshToken)
322-
323-
assert.Equal(t, "openid offline offline_access foo bar", token1Refresh.Extra("scope"))
324371

325-
token2Refresh, err := doRefresh(oauthClient, token1Refresh, []string{"openid", "offline_access", "foo"})
372+
entries[0].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[0].OAuth2.RefreshToken), nil)
326373
require.NoError(t, err)
327-
require.NotEmpty(t, token2Refresh.AccessToken)
328-
require.NotEmpty(t, token2Refresh.RefreshToken)
329374

330-
assert.Equal(t, "openid offline_access foo", token2Refresh.Extra("scope"))
375+
assert.ElementsMatch(t, entries[0].SessionAT.GetRequestedScopes(), originalScopes)
376+
assert.ElementsMatch(t, entries[0].SessionRT.GetRequestedScopes(), originalScopes)
377+
assert.ElementsMatch(t, entries[0].SessionAT.GetGrantedScopes(), originalScopes)
378+
assert.ElementsMatch(t, entries[0].SessionRT.GetGrantedScopes(), originalScopes)
379+
assert.Equal(t, strings.Join(originalScopes, " "), entries[0].OAuth2.Extra("scope"))
331380

332-
token3Refresh, err := doRefresh(oauthClient, token2Refresh, []string{"openid", "offline", "offline_access", "foo", "bar"})
333-
require.NoError(t, err)
334-
require.NotEmpty(t, token3Refresh.AccessToken)
335-
require.NotEmpty(t, token3Refresh.RefreshToken)
381+
for i, tc := range testCases {
382+
t.Run(tc.name, func(t *testing.T) {
383+
time.Sleep(time.Second)
336384

337-
assert.Equal(t, "openid offline offline_access foo bar", token3Refresh.Extra("scope"))
385+
idx := i + 1
338386

339-
token4Refresh, err := doRefresh(oauthClient, token3Refresh, []string{"openid", "offline", "offline_access", "foo", "bar", "baz"})
340-
require.Error(t, err)
341-
require.Nil(t, token4Refresh)
342-
require.Contains(t, err.Error(), "The requested scope is invalid, unknown, or malformed. The requested scope 'baz' was not originally granted by the resource owner.")
343-
}
387+
opts := []oauth2.AuthCodeOption{
388+
oauth2.SetAuthURLParam("refresh_token", entries[i].OAuth2.RefreshToken),
389+
oauth2.SetAuthURLParam("grant_type", "refresh_token"),
390+
}
344391

345-
func doRefresh(client *oauth2.Config, t *oauth2.Token, scopes []string) (token *oauth2.Token, err error) {
346-
opts := []oauth2.AuthCodeOption{
347-
oauth2.SetAuthURLParam("refresh_token", t.RefreshToken),
348-
oauth2.SetAuthURLParam("grant_type", "refresh_token"),
349-
}
392+
if len(tc.scopes) != 0 {
393+
opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(tc.scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID))
394+
}
350395

351-
if len(scopes) != 0 {
352-
opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(scopes, " ")), oauth2.SetAuthURLParam("client_id", client.ClientID))
353-
}
396+
entries[idx].OAuth2, err = client.Exchange(ctx, "", opts...)
397+
if len(tc.err) != 0 {
398+
require.Error(t, err)
399+
require.Nil(t, entries[idx].OAuth2)
400+
require.Contains(t, err.Error(), tc.err)
401+
402+
return
403+
}
354404

355-
return client.Exchange(oauth2.NoContext, "", opts...)
405+
require.NoError(t, err)
406+
require.NotEmpty(t, entries[idx].OAuth2.AccessToken)
407+
require.NotEmpty(t, entries[idx].OAuth2.RefreshToken)
408+
409+
entries[idx].SessionAT, err = fositeStore.GetAccessTokenSession(ctx, s.AccessTokenSignature(ctx, entries[idx].OAuth2.AccessToken), nil)
410+
require.NoError(t, err)
411+
412+
entries[idx].SessionRT, err = fositeStore.GetRefreshTokenSession(ctx, s.RefreshTokenSignature(ctx, entries[idx].OAuth2.RefreshToken), nil)
413+
require.NoError(t, err)
414+
415+
if len(tc.scopes) != 0 {
416+
assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), tc.scopes)
417+
assert.Equal(t, strings.Join(tc.expected, " "), entries[idx].OAuth2.Extra("scope"))
418+
} else {
419+
assert.ElementsMatch(t, entries[idx].SessionAT.GetRequestedScopes(), originalScopes)
420+
assert.Equal(t, strings.Join(originalScopes, " "), entries[idx].OAuth2.Extra("scope"))
421+
}
422+
assert.ElementsMatch(t, entries[idx].SessionAT.GetGrantedScopes(), tc.expected)
423+
assert.ElementsMatch(t, entries[idx].SessionRT.GetRequestedScopes(), originalScopes)
424+
assert.ElementsMatch(t, entries[idx].SessionRT.GetGrantedScopes(), originalScopes)
425+
426+
var (
427+
j int
428+
entry step
429+
)
430+
431+
assert.Equal(t, entries[idx].SessionAT.GetID(), entries[idx].SessionRT.GetID())
432+
433+
for j, entry = range entries {
434+
if j == idx {
435+
break
436+
}
437+
438+
assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionAT.GetID())
439+
assert.Equal(t, entries[idx].SessionAT.GetID(), entry.SessionRT.GetID())
440+
assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionAT.GetID())
441+
assert.Equal(t, entries[idx].SessionRT.GetID(), entry.SessionRT.GetID())
442+
443+
assert.Greater(t, entries[idx].SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix(), entry.SessionAT.GetSession().GetExpiresAt(fosite.AccessToken).Unix())
444+
assert.Greater(t, entries[idx].SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix(), entry.SessionRT.GetSession().GetExpiresAt(fosite.RefreshToken).Unix())
445+
assert.Greater(t, entries[idx].SessionAT.GetRequestedAt().Unix(), entry.SessionAT.GetRequestedAt().Unix())
446+
assert.Greater(t, entries[idx].SessionRT.GetRequestedAt().Unix(), entry.SessionRT.GetRequestedAt().Unix())
447+
}
448+
})
449+
}
356450
}

oauth2.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,17 +248,9 @@ type Requester interface {
248248
// RefreshTokenAccessRequester is an extended AccessRequester implementation that allows preserving
249249
// the original Requester.
250250
type RefreshTokenAccessRequester interface {
251-
// GetRefreshTokenRequestedScopes returns the request's scopes specifically for the refresh token.
252-
GetRefreshTokenRequestedScopes() (scopes Arguments)
253-
254-
// SetRefreshTokenRequestedScopes sets the request's scopes specifically for the refresh token.
255-
SetRefreshTokenRequestedScopes(scopes Arguments)
256-
257-
// GetRefreshTokenGrantedScopes returns all granted scopes specifically for the refresh token.
258-
GetRefreshTokenGrantedScopes() (scopes Arguments)
259-
260-
// SetRefreshTokenGrantedScopes sets all granted scopes specifically for the refresh token.
261-
SetRefreshTokenGrantedScopes(scopes Arguments)
251+
// SanitizeRestoreRefreshTokenOriginalRequester returns a sanitized copy of this Requester and mutates the relevant
252+
// values from the provided Requester which is the original refresh token session Requester.
253+
SanitizeRestoreRefreshTokenOriginalRequester(requester Requester) Requester
262254

263255
AccessRequester
264256
}

0 commit comments

Comments
 (0)