Skip to content

Commit f225950

Browse files
committed
Implement context propagation through to SessionDetail
1 parent 1975143 commit f225950

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

gateway/auth_manager.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gateway
22

33
import (
4+
"context"
45
"encoding/base64"
56
"encoding/json"
67
"strings"
@@ -21,6 +22,7 @@ type SessionHandler interface {
2122
UpdateSession(keyName string, session *user.SessionState, resetTTLTo int64, hashed bool) error
2223
RemoveSession(orgID string, keyName string, hashed bool) bool
2324
SessionDetail(orgID string, keyName string, hashed bool) (user.SessionState, bool)
25+
SessionDetailContext(ctx context.Context, orgID string, keyName string, hashed bool) (user.SessionState, bool)
2426
KeyExpired(newSession *user.SessionState) bool
2527
Sessions(filter string) []string
2628
ResetQuota(string, *user.SessionState, bool)
@@ -215,6 +217,84 @@ func (b *DefaultSessionManager) SessionDetail(orgID string, keyName string, hash
215217
return session.Clone(), true
216218
}
217219

220+
// SessionDetailContext returns the session detail using the storage engine with context support for cancellation
221+
func (b *DefaultSessionManager) SessionDetailContext(ctx context.Context, orgID string, keyName string, hashed bool) (user.SessionState, bool) {
222+
select {
223+
case <-ctx.Done():
224+
log.WithFields(logrus.Fields{
225+
"prefix": "auth-mgr",
226+
"inbound-key": b.Gw.obfuscateKey(keyName),
227+
}).Debug("Context cancelled before session detail fetch")
228+
return user.SessionState{}, false
229+
default:
230+
}
231+
232+
var jsonKeyVal string
233+
var err error
234+
keyId := keyName
235+
236+
// get session by key
237+
if hashed {
238+
jsonKeyVal, err = b.store.GetRawKey(b.store.GetKeyPrefix() + keyName)
239+
} else {
240+
if storage.TokenOrg(keyName) != orgID {
241+
// try to get legacy and new format key at once
242+
toSearchList := []string{}
243+
if !b.Gw.GetConfig().DisableKeyActionsByUsername {
244+
toSearchList = append(toSearchList, b.Gw.generateToken(orgID, keyName))
245+
}
246+
247+
toSearchList = append(toSearchList, keyName)
248+
for _, fallback := range b.Gw.GetConfig().HashKeyFunctionFallback {
249+
if !b.Gw.GetConfig().DisableKeyActionsByUsername {
250+
toSearchList = append(toSearchList, b.Gw.generateToken(orgID, keyName, fallback))
251+
}
252+
}
253+
254+
var jsonKeyValList []string
255+
256+
jsonKeyValList, err = b.store.GetMultiKey(toSearchList)
257+
// pick the 1st non empty from the returned list
258+
for idx, val := range jsonKeyValList {
259+
if val != "" {
260+
jsonKeyVal = val
261+
keyId = toSearchList[idx]
262+
break
263+
}
264+
}
265+
} else {
266+
// key is not an imported one
267+
jsonKeyVal, err = b.store.GetKey(keyName)
268+
}
269+
}
270+
271+
select {
272+
case <-ctx.Done():
273+
log.WithFields(logrus.Fields{
274+
"prefix": "auth-mgr",
275+
"inbound-key": b.Gw.obfuscateKey(keyName),
276+
}).Debug("Context cancelled during session detail fetch")
277+
return user.SessionState{}, false
278+
default:
279+
}
280+
281+
if err != nil {
282+
log.WithFields(logrus.Fields{
283+
"prefix": "auth-mgr",
284+
"inbound-key": b.Gw.obfuscateKey(keyName),
285+
"err": err,
286+
}).Debug("Could not get session detail, key not found")
287+
return user.SessionState{}, false
288+
}
289+
session := &user.SessionState{}
290+
if err := json.Unmarshal([]byte(jsonKeyVal), &session); err != nil {
291+
log.Error("Couldn't unmarshal session object (may be cache miss): ", err)
292+
return user.SessionState{}, false
293+
}
294+
session.KeyID = keyId
295+
return session.Clone(), true
296+
}
297+
218298
func (b *DefaultSessionManager) Stop() {}
219299

220300
// Sessions returns all sessions in the key store that match a filter key (a prefix)

gateway/middleware.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ func (t *BaseMiddleware) fetchOrgSessionWithTimeout(orgID string) (user.SessionS
368368
}
369369
}()
370370

371-
session, found := t.Spec.OrgSessionManager.SessionDetail(orgID, orgID, false)
371+
session, found := t.Spec.OrgSessionManager.SessionDetailContext(timeoutCtx, orgID, orgID, false)
372372
if found && t.Spec.GlobalConfig.EnforceOrgDataAge {
373373
t.Logger().Debug("Setting data expiry: ", orgID)
374374
t.Gw.ExpiryCache.Set(session.OrgID, session.DataExpires, cache.DefaultExpiration)

gateway/middleware_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gateway
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io/ioutil"
@@ -40,6 +41,17 @@ func (m mockStore) SessionDetail(orgID string, keyName string, hashed bool) (use
4041
return sess.Clone(), !m.DetailNotFound
4142
}
4243

44+
func (m mockStore) SessionDetailContext(ctx context.Context, orgID string, keyName string, hashed bool) (user.SessionState, bool) {
45+
if m.Delay > 0 {
46+
select {
47+
case <-time.After(m.Delay):
48+
case <-ctx.Done():
49+
return user.SessionState{}, false
50+
}
51+
}
52+
return sess.Clone(), !m.DetailNotFound
53+
}
54+
4355
func TestBaseMiddleware_OrgSessionExpiry(t *testing.T) {
4456
ts := StartTest(nil)
4557
defer ts.Close()

0 commit comments

Comments
 (0)