diff --git a/lib/msal-browser/src/app/PublicClientNext.ts b/lib/msal-browser/src/app/PublicClientNext.ts index fd11e940cd..a3c8634993 100644 --- a/lib/msal-browser/src/app/PublicClientNext.ts +++ b/lib/msal-browser/src/app/PublicClientNext.ts @@ -281,6 +281,18 @@ export class PublicClientNext implements IPublicClientApplication { return this.controller.getAllAccounts(accountFilter); } + /** + * Returns all the accounts in the cache that match the optional filter. If no filter is provided, all accounts are returned. + * @param accountFilter - (Optional) filter to narrow down the accounts returned + * @returns Array of AccountInfo objects in cache + */ + async getAllAccountsAsync(accountFilter?: AccountFilter): Promise { + if (typeof this.controller.getAllAccountsAsync === "function") { + return this.controller.getAllAccountsAsync(accountFilter); + } + return Promise.resolve([]); + } + /** * Event handler function which allows users to fire events after the PublicClientApplication object * has loaded during redirect flows. This should be invoked on all page loads involved in redirect diff --git a/lib/msal-browser/src/cache/AsyncAccountManager.ts b/lib/msal-browser/src/cache/AsyncAccountManager.ts new file mode 100644 index 0000000000..79f3466df1 --- /dev/null +++ b/lib/msal-browser/src/cache/AsyncAccountManager.ts @@ -0,0 +1,192 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { AccountInfo, AccountFilter, Logger } from "@azure/msal-common/browser"; +import { WorkerCacheManager } from "./WorkerCacheManager.js"; +/** + * Returns all the accounts in the cache that match the optional filter. If no filter is provided, all accounts are returned. + * @param accountFilter - (Optional) filter to narrow down the accounts returned + * @returns Array of AccountInfo objects in cache + */ +export async function getAllAccounts( + logger: Logger, + workerStorage: WorkerCacheManager, + accountFilter?: AccountFilter +): Promise { + logger.verbose("getAllAccounts called"); + return workerStorage.getAllAccounts(accountFilter); +} + +// /** +// * Returns the first account found in the cache that matches the account filter passed in. +// * @param accountFilter +// * @returns The first account found in the cache matching the provided filter or null if no account could be found. +// */ +// export function getAccount( +// accountFilter: AccountFilter, +// logger: Logger, +// workerStorage: WorkerCacheManager +// ): AccountInfo | null { +// logger.trace("getAccount called"); +// if (Object.keys(accountFilter).length === 0) { +// logger.warning("getAccount: No accountFilter provided"); +// return null; +// } + +/* + * const account: AccountInfo | null = + * workerStorage.getAccountInfoFilteredBy(accountFilter); + */ + +/* + * if (account) { + * logger.verbose( + * "getAccount: Account matching provided filter found, returning" + * ); + * return account; + * } else { + * logger.verbose("getAccount: No matching account found, returning null"); + * return null; + * } + * } + */ + +// /** +// * Returns the signed in account matching username. +// * (the account object is created at the time of successful login) +// * or null when no matching account is found. +// * This API is provided for convenience but getAccountById should be used for best reliability +// * @param username +// * @returns The account object stored in MSAL +// */ +// export function getAccountByUsername( +// username: string, +// logger: Logger, +// workerStorage: WorkerCacheManager +// ): AccountInfo | null { +// logger.trace("getAccountByUsername called"); +// if (!username) { +// logger.warning("getAccountByUsername: No username provided"); +// return null; +// } + +/* + * const account = workerStorage.getAccountInfoFilteredBy({ + * username, + * }); + * if (account) { + * logger.verbose( + * "getAccountByUsername: Account matching username found, returning" + * ); + * logger.verbosePii( + * `getAccountByUsername: Returning signed-in accounts matching username: ${username}` + * ); + * return account; + * } else { + * logger.verbose( + * "getAccountByUsername: No matching account found, returning null" + * ); + * return null; + * } + * } + */ + +// /** +// * Returns the signed in account matching homeAccountId. +// * (the account object is created at the time of successful login) +// * or null when no matching account is found +// * @param homeAccountId +// * @returns The account object stored in MSAL +// */ +// export function getAccountByHomeId( +// homeAccountId: string, +// logger: Logger, +// workerStorage: WorkerCacheManager +// ): AccountInfo | null { +// logger.trace("getAccountByHomeId called"); +// if (!homeAccountId) { +// logger.warning("getAccountByHomeId: No homeAccountId provided"); +// return null; +// } + +/* + * const account = workerStorage.getAccountInfoFilteredBy({ + * homeAccountId, + * }); + * if (account) { + * logger.verbose( + * "getAccountByHomeId: Account matching homeAccountId found, returning" + * ); + * logger.verbosePii( + * `getAccountByHomeId: Returning signed-in accounts matching homeAccountId: ${homeAccountId}` + * ); + * return account; + * } else { + * logger.verbose( + * "getAccountByHomeId: No matching account found, returning null" + * ); + * return null; + * } + * } + */ + +// /** +// * Returns the signed in account matching localAccountId. +// * (the account object is created at the time of successful login) +// * or null when no matching account is found +// * @param localAccountId +// * @returns The account object stored in MSAL +// */ +// export function getAccountByLocalId( +// localAccountId: string, +// logger: Logger, +// workerStorage: WorkerCacheManager +// ): AccountInfo | null { +// logger.trace("getAccountByLocalId called"); +// if (!localAccountId) { +// logger.warning("getAccountByLocalId: No localAccountId provided"); +// return null; +// } + +/* + * const account = workerStorage.getAccountInfoFilteredBy({ + * localAccountId, + * }); + * if (account) { + * logger.verbose( + * "getAccountByLocalId: Account matching localAccountId found, returning" + * ); + * logger.verbosePii( + * `getAccountByLocalId: Returning signed-in accounts matching localAccountId: ${localAccountId}` + * ); + * return account; + * } else { + * logger.verbose( + * "getAccountByLocalId: No matching account found, returning null" + * ); + * return null; + * } + * } + */ + +// /** +// * Sets the account to use as the active account. If no account is passed to the acquireToken APIs, then MSAL will use this active account. +// * @param account +// */ +// export function setActiveAccount( +// account: AccountInfo | null, +// workerStorage: WorkerCacheManager +// ): void { +// workerStorage.setActiveAccount(account); +// } + +// /** +// * Gets the currently active account +// */ +// export function getActiveAccount( +// workerStorage: WorkerCacheManager +// ): AccountInfo | null { +// return workerStorage.getActiveAccount(); +// } diff --git a/lib/msal-browser/src/cache/AsyncMemoryStorage.ts b/lib/msal-browser/src/cache/AsyncMemoryStorage.ts index b28a22f66d..fc70fb2211 100644 --- a/lib/msal-browser/src/cache/AsyncMemoryStorage.ts +++ b/lib/msal-browser/src/cache/AsyncMemoryStorage.ts @@ -3,7 +3,7 @@ * Licensed under the MIT License. */ -import { Logger } from "@azure/msal-common/browser"; +import { IPerformanceClient, Logger, PerformanceEvents } from "@azure/msal-common/browser"; import { BrowserAuthError, BrowserAuthErrorCodes, @@ -12,21 +12,36 @@ import { DatabaseStorage } from "./DatabaseStorage.js"; import { IAsyncStorage } from "./IAsyncStorage.js"; import { MemoryStorage } from "./MemoryStorage.js"; +const BROADCAST_CHANNEL_NAME = "msal.broadcast.cache"; // TODO: Dedupe + /** * This class allows MSAL to store artifacts asynchronously using the DatabaseStorage IndexedDB wrapper, * backed up with the more volatile MemoryStorage object for cases in which IndexedDB may be unavailable. */ export class AsyncMemoryStorage implements IAsyncStorage { + private clientId: string | undefined; private inMemoryCache: MemoryStorage; private indexedDBCache: DatabaseStorage; private logger: Logger; + private initialized: boolean = false; + private performanceClient: IPerformanceClient | undefined; + private broadcast: BroadcastChannel; - constructor(logger: Logger) { + constructor(logger: Logger, clientId?: string, performanceClient?: IPerformanceClient) { this.inMemoryCache = new MemoryStorage(); this.indexedDBCache = new DatabaseStorage(); + this.clientId = clientId; this.logger = logger; + this.performanceClient = performanceClient; + this.broadcast = new BroadcastChannel(BROADCAST_CHANNEL_NAME); } + async initialize(): Promise { + this.initialized = true; + // Register listener for cache updates in other tabs + this.broadcast.addEventListener("message", this.updateCache.bind(this)); + }; + private handleDatabaseAccessError(error: unknown): void { if ( error instanceof BrowserAuthError && @@ -153,4 +168,39 @@ export class AsyncMemoryStorage implements IAsyncStorage { return false; } } + + private updateCache(event: MessageEvent): void { + this.logger.trace("Updating internal cache from broadcast event"); + const perfMeasurement = this.performanceClient?.startMeasurement( + PerformanceEvents.LocalStorageUpdated + ); + perfMeasurement?.add({ isBackground: true }); + + const { key, value, context } = event.data; + if (!key) { + this.logger.error("Broadcast event missing key"); + perfMeasurement?.end({ success: false, errorCode: "noKey" }); + return; + } + + if (context && context !== this.clientId) { + this.logger.trace( + `Ignoring broadcast event from clientId: ${context}` + ); + perfMeasurement?.end({ + success: false, + errorCode: "contextMismatch", + }); + return; + } + + if (!value) { + this.inMemoryCache.removeItem(key); + this.logger.verbose("Removed item from internal cache"); + } else { + this.inMemoryCache.setItem(key, value); + this.logger.verbose("Updated item in internal cache"); + } + perfMeasurement?.end({ success: true }); + } } diff --git a/lib/msal-browser/src/cache/CacheHelpers.ts b/lib/msal-browser/src/cache/CacheHelpers.ts index b8439c627c..db579fa890 100644 --- a/lib/msal-browser/src/cache/CacheHelpers.ts +++ b/lib/msal-browser/src/cache/CacheHelpers.ts @@ -6,6 +6,7 @@ import { TokenKeys } from "@azure/msal-common/browser"; import { StaticCacheKeys } from "../utils/BrowserConstants.js"; import { IWindowStorage } from "./IWindowStorage.js"; +import { IAsyncStorage } from "./IAsyncStorage.js"; /** * Returns a list of cache keys for all known accounts @@ -21,6 +22,15 @@ export function getAccountKeys(storage: IWindowStorage): Array { return []; } +export async function getAccountKeysAsync(storage: IAsyncStorage): Promise> { + const accountKeys = await storage.getItem(StaticCacheKeys.ACCOUNT_KEYS); + if (accountKeys) { + return JSON.parse(accountKeys); + } + + return []; +} + /** * Returns a list of cache keys for all known tokens * @param clientId @@ -50,3 +60,33 @@ export function getTokenKeys( refreshToken: [], }; } + +/** + * Returns a list of cache keys for all known tokens + * @param clientId + * @param storage + * @returns + */ +export async function getTokenKeysAsync( + clientId: string, + storage: IAsyncStorage +): Promise { + const item = await storage.getItem(`${StaticCacheKeys.TOKEN_KEYS}.${clientId}`); + if (item) { + const tokenKeys = JSON.parse(item); + if ( + tokenKeys && + tokenKeys.hasOwnProperty("idToken") && + tokenKeys.hasOwnProperty("accessToken") && + tokenKeys.hasOwnProperty("refreshToken") + ) { + return tokenKeys as TokenKeys; + } + } + + return { + idToken: [], + accessToken: [], + refreshToken: [], + }; +} diff --git a/lib/msal-browser/src/cache/DatabaseStorage.ts b/lib/msal-browser/src/cache/DatabaseStorage.ts index e2c53e5faf..a8ca776a8e 100644 --- a/lib/msal-browser/src/cache/DatabaseStorage.ts +++ b/lib/msal-browser/src/cache/DatabaseStorage.ts @@ -43,12 +43,22 @@ export class DatabaseStorage implements IAsyncStorage { this.dbOpen = false; } + async initialize(): Promise { + } + /** * Opens IndexedDB instance. */ async open(): Promise { return new Promise((resolve, reject) => { - const openDB = window.indexedDB.open(this.dbName, this.version); + + let openDB = null; + if(typeof window !== "undefined" && window.indexedDB) { + openDB = window.indexedDB.open(this.dbName, this.version); + } else { + openDB = self.indexedDB.open(this.dbName, this.version); + } + openDB.addEventListener( "upgradeneeded", (e: IDBVersionChangeEvent) => { diff --git a/lib/msal-browser/src/cache/IAsyncStorage.ts b/lib/msal-browser/src/cache/IAsyncStorage.ts index e96a97927e..47250c34e8 100644 --- a/lib/msal-browser/src/cache/IAsyncStorage.ts +++ b/lib/msal-browser/src/cache/IAsyncStorage.ts @@ -4,6 +4,11 @@ */ export interface IAsyncStorage { + /** + * Async initializer + */ + + initialize(): Promise; /** * Get the item from the asynchronous storage object matching the given key. * @param key diff --git a/lib/msal-browser/src/cache/WorkerCacheManager.ts b/lib/msal-browser/src/cache/WorkerCacheManager.ts new file mode 100644 index 0000000000..8750595de1 --- /dev/null +++ b/lib/msal-browser/src/cache/WorkerCacheManager.ts @@ -0,0 +1,1340 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { + AccessTokenEntity, + AccountEntity, + AccountInfo, + ActiveAccountFilters, + AppMetadataEntity, + AsyncCacheManager, + AuthenticationScheme, + AuthorityMetadataEntity, + CacheError, + CacheHelpers, + CacheRecord, + ClientAuthErrorCodes, + CommonAuthorizationUrlRequest, + Constants, + createClientAuthError, + CredentialType, + DEFAULT_CRYPTO_IMPLEMENTATION, + ICrypto, + IdTokenEntity, + invokeAsync, + IPerformanceClient, + Logger, + PerformanceEvents, + PersistentCacheKeys, + RefreshTokenEntity, + ServerTelemetryEntity, + StaticAuthorityOptions, + StoreInCache, + StringUtils, + ThrottlingEntity, + TimeUtils, + TokenKeys, +} from "@azure/msal-common/browser"; +import { CacheOptions } from "../config/Configuration.js"; +import { + BrowserAuthErrorCodes, + createBrowserAuthError, +} from "../error/BrowserAuthError.js"; +import { + BrowserCacheLocation, + InMemoryCacheKeys, + INTERACTION_TYPE, + StaticCacheKeys, + TemporaryCacheKeys, +} from "../utils/BrowserConstants.js"; +import { MemoryStorage } from "./MemoryStorage.js"; +import { NativeTokenRequest } from "../broker/nativeBroker/NativeRequest.js"; +import { AuthenticationResult } from "../response/AuthenticationResult.js"; +import { SilentRequest } from "../request/SilentRequest.js"; +import { SsoSilentRequest } from "../request/SsoSilentRequest.js"; +import { RedirectRequest } from "../request/RedirectRequest.js"; +import { PopupRequest } from "../request/PopupRequest.js"; +import { base64Decode } from "../encode/Base64Decode.js"; +import { base64Encode } from "../encode/Base64Encode.js"; +import { getAccountKeysAsync, getTokenKeysAsync } from "./CacheHelpers.js"; +import { EventType } from "../event/EventType.js"; +import { EventHandler } from "../event/EventHandler.js"; +import { AsyncMemoryStorage } from "./AsyncMemoryStorage.js"; +import { IAsyncStorage } from "./IAsyncStorage.js"; + +/** + * This class implements the cache storage interface for MSAL through browser local or session storage. + * Cookies are only used if storeAuthStateInCookie is true, and are only used for + * parameters such as state and nonce, generally. + */ +export class WorkerCacheManager extends AsyncCacheManager { + // Cache configuration, either set by user or default values. + protected cacheConfig: Required; + // Window storage object (either local or sessionStorage) + protected workerStorage: IAsyncStorage; + // Internal in-memory storage object used for data used by msal that does not need to persist across page loads + protected internalStorage: MemoryStorage; + // Temporary cache + protected temporaryCacheStorage: IAsyncStorage; + // Logger instance + protected logger: Logger; + // Telemetry perf client + protected performanceClient: IPerformanceClient; + // Event Handler + private eventHandler: EventHandler; + + constructor( + clientId: string, + cacheConfig: Required, + cryptoImpl: ICrypto, + logger: Logger, + performanceClient: IPerformanceClient, + eventHandler: EventHandler, + staticAuthorityOptions?: StaticAuthorityOptions + ) { + super(clientId, cryptoImpl, logger, staticAuthorityOptions); + this.cacheConfig = cacheConfig; + this.logger = logger; + this.internalStorage = new MemoryStorage(); + this.performanceClient = performanceClient; + + this.workerStorage = new AsyncMemoryStorage(this.logger, this.clientId, this.performanceClient); + + this.temporaryCacheStorage = new AsyncMemoryStorage(this.logger, this.clientId, this.performanceClient); + + this.eventHandler = eventHandler; + } + + async initialize(): Promise { + await this.workerStorage.initialize(); + } + + /** + * Parses passed value as JSON object, JSON.parse() will throw an error. + * @param input + */ + protected validateAndParseJson(jsonValue: string): object | null { + try { + const parsedJson = JSON.parse(jsonValue); + /** + * There are edge cases in which JSON.parse will successfully parse a non-valid JSON object + * (e.g. JSON.parse will parse an escaped string into an unescaped string), so adding a type check + * of the parsed value is necessary in order to be certain that the string represents a valid JSON object. + * + */ + return parsedJson && typeof parsedJson === "object" + ? parsedJson + : null; + } catch (error) { + return null; + } + } + + /** + * Reads account from cache, deserializes it into an account entity and returns it. + * If account is not found from the key, returns null and removes key from map. + * @param accountKey + * @returns + */ + async getAccount(accountKey: string): Promise { + this.logger.trace("WorkerCacheManager.getAccount called"); + const serializedAccount = await this.workerStorage.getItem(accountKey); + if (!serializedAccount) { + await this.removeAccountKeyFromMap(accountKey); + return null; + } + + const parsedAccount = this.validateAndParseJson(serializedAccount); + if (!parsedAccount || !AccountEntity.isAccountEntity(parsedAccount)) { + await this.removeAccountKeyFromMap(accountKey); + return null; + } + + return AsyncCacheManager.toObject( + new AccountEntity(), + parsedAccount + ); + } + + /** + * set account entity in the platform cache + * @param account + */ + async setAccount( + account: AccountEntity, + ): Promise { + this.logger.trace("WorkerCacheManager.setAccount called"); + const key = account.generateAccountKey(); + await invokeAsync( + this.workerStorage.setItem.bind(this.workerStorage), + PerformanceEvents.SetUserData, + this.logger, + this.performanceClient + )(key, JSON.stringify(account)); + const wasAdded = await this.addAccountKeyToMap(key); + + /** + * @deprecated - Remove this in next major version in favor of more consistent LOGIN event + */ + if ( + this.cacheConfig.cacheLocation === + BrowserCacheLocation.LocalStorage && + wasAdded + ) { + this.eventHandler.emitEvent( + EventType.ACCOUNT_ADDED, + undefined, + account.getAccountInfo() + ); + } + } + + /** + * Returns the array of account keys currently cached + * @returns + */ + async getAccountKeys(): Promise> { + return getAccountKeysAsync(this.workerStorage); + } + + /** + * Add a new account to the key map + * @param key + */ + async addAccountKeyToMap(key: string): Promise { + this.logger.trace("WorkerCacheManager.addAccountKeyToMap called"); + this.logger.tracePii( + `WorkerCacheManager.addAccountKeyToMap called with key: ${key}` + ); + const accountKeys = await this.getAccountKeys(); + if (accountKeys.indexOf(key) === -1) { + // Only add key if it does not already exist in the map + accountKeys.push(key); + await this.workerStorage.setItem( + StaticCacheKeys.ACCOUNT_KEYS, + JSON.stringify(accountKeys) + ); + this.logger.verbose( + "WorkerCacheManager.addAccountKeyToMap account key added" + ); + return true; + } else { + this.logger.verbose( + "WorkerCacheManager.addAccountKeyToMap account key already exists in map" + ); + return false; + } + } + + /** + * Remove an account from the key map + * @param key + */ + async removeAccountKeyFromMap(key: string): Promise { + this.logger.trace("WorkerCacheManager.removeAccountKeyFromMap called"); + this.logger.tracePii( + `WorkerCacheManager.removeAccountKeyFromMap called with key: ${key}` + ); + const accountKeys = await this.getAccountKeys(); + const removalIndex = accountKeys.indexOf(key); + if (removalIndex > -1) { + accountKeys.splice(removalIndex, 1); + await this.workerStorage.setItem( + StaticCacheKeys.ACCOUNT_KEYS, + JSON.stringify(accountKeys) + ); + this.logger.trace( + "WorkerCacheManager.removeAccountKeyFromMap account key removed" + ); + } else { + this.logger.trace( + "WorkerCacheManager.removeAccountKeyFromMap key not found in existing map" + ); + } + } + + /** + * Extends inherited removeAccount function to include removal of the account key from the map + * @param key + */ + async removeAccount(key: string): Promise { + void await super.removeAccount(key); + await this.removeAccountKeyFromMap(key); + } + + /** + * Removes credentials associated with the provided account + * @param account + */ + async removeAccountContext(account: AccountEntity): Promise { + await super.removeAccountContext(account); + + /** + * @deprecated - Remove this in next major version in favor of more consistent LOGOUT event + */ + if ( + this.cacheConfig.cacheLocation === BrowserCacheLocation.LocalStorage + ) { + this.eventHandler.emitEvent( + EventType.ACCOUNT_REMOVED, + undefined, + account.getAccountInfo() + ); + } + } + + /** + * Removes given idToken from the cache and from the key map + * @param key + */ + async removeIdToken(key: string): Promise { + await super.removeIdToken(key); + await this.removeTokenKey(key, CredentialType.ID_TOKEN); + } + + /** + * Removes given accessToken from the cache and from the key map + * @param key + */ + async removeAccessToken(key: string): Promise { + void await super.removeAccessToken(key); + await this.removeTokenKey(key, CredentialType.ACCESS_TOKEN); + } + + /** + * Removes given refreshToken from the cache and from the key map + * @param key + */ + async removeRefreshToken(key: string): Promise { + await super.removeRefreshToken(key); + await this.removeTokenKey(key, CredentialType.REFRESH_TOKEN); + } + + /** + * Gets the keys for the cached tokens associated with this clientId + * @returns + */ + async getTokenKeys(): Promise { + return getTokenKeysAsync(this.clientId, this.workerStorage); + } + + /** + * Adds the given key to the token key map + * @param key + * @param type + */ + async addTokenKey(key: string, type: CredentialType): Promise { + this.logger.trace("WorkerCacheManager addTokenKey called"); + const tokenKeys = await this.getTokenKeys(); + + switch (type) { + case CredentialType.ID_TOKEN: + if (tokenKeys.idToken.indexOf(key) === -1) { + this.logger.info( + "WorkerCacheManager: addTokenKey - idToken added to map" + ); + tokenKeys.idToken.push(key); + } + break; + case CredentialType.ACCESS_TOKEN: + if (tokenKeys.accessToken.indexOf(key) === -1) { + this.logger.info( + "WorkerCacheManager: addTokenKey - accessToken added to map" + ); + tokenKeys.accessToken.push(key); + } + break; + case CredentialType.REFRESH_TOKEN: + if (tokenKeys.refreshToken.indexOf(key) === -1) { + this.logger.info( + "WorkerCacheManager: addTokenKey - refreshToken added to map" + ); + tokenKeys.refreshToken.push(key); + } + break; + default: + this.logger.error( + `WorkerCacheManager:addTokenKey - CredentialType provided invalid. CredentialType: ${type}` + ); + throw createClientAuthError( + ClientAuthErrorCodes.unexpectedCredentialType + ); + } + + await this.workerStorage.setItem( + `${StaticCacheKeys.TOKEN_KEYS}.${this.clientId}`, + JSON.stringify(tokenKeys) + ); + } + + /** + * Removes the given key from the token key map + * @param key + * @param type + */ + async removeTokenKey(key: string, type: CredentialType): Promise { + this.logger.trace("WorkerCacheManager removeTokenKey called"); + const tokenKeys = await this.getTokenKeys(); + + switch (type) { + case CredentialType.ID_TOKEN: { + this.logger.infoPii( + `WorkerCacheManager: removeTokenKey - attempting to remove idToken with key: ${key} from map` + ); + const idRemoval = tokenKeys.idToken.indexOf(key); + if (idRemoval > -1) { + this.logger.info( + "WorkerCacheManager: removeTokenKey - idToken removed from map" + ); + tokenKeys.idToken.splice(idRemoval, 1); + } else { + this.logger.info( + "WorkerCacheManager: removeTokenKey - idToken does not exist in map. Either it was previously removed or it was never added." + ); + } + break; + } + case CredentialType.ACCESS_TOKEN: { + this.logger.infoPii( + `WorkerCacheManager: removeTokenKey - attempting to remove accessToken with key: ${key} from map` + ); + const accessRemoval = tokenKeys.accessToken.indexOf(key); + if (accessRemoval > -1) { + this.logger.info( + "WorkerCacheManager: removeTokenKey - accessToken removed from map" + ); + tokenKeys.accessToken.splice(accessRemoval, 1); + } else { + this.logger.info( + "WorkerCacheManager: removeTokenKey - accessToken does not exist in map. Either it was previously removed or it was never added." + ); + } + break; + } + case CredentialType.REFRESH_TOKEN: { + this.logger.infoPii( + `WorkerCacheManager: removeTokenKey - attempting to remove refreshToken with key: ${key} from map` + ); + const refreshRemoval = tokenKeys.refreshToken.indexOf(key); + if (refreshRemoval > -1) { + this.logger.info( + "WorkerCacheManager: removeTokenKey - refreshToken removed from map" + ); + tokenKeys.refreshToken.splice(refreshRemoval, 1); + } else { + this.logger.info( + "WorkerCacheManager: removeTokenKey - refreshToken does not exist in map. Either it was previously removed or it was never added." + ); + } + break; + } + default: + this.logger.error( + `WorkerCacheManager:removeTokenKey - CredentialType provided invalid. CredentialType: ${type}` + ); + throw createClientAuthError( + ClientAuthErrorCodes.unexpectedCredentialType + ); + } + + await this.workerStorage.setItem( + `${StaticCacheKeys.TOKEN_KEYS}.${this.clientId}`, + JSON.stringify(tokenKeys) + ); + } + + /** + * generates idToken entity from a string + * @param idTokenKey + */ + async getIdTokenCredential(idTokenKey: string): Promise { + const value = await this.workerStorage.getItem(idTokenKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getIdTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(idTokenKey, CredentialType.ID_TOKEN); + return null; + } + + const parsedIdToken = this.validateAndParseJson(value); + if (!parsedIdToken || !CacheHelpers.isIdTokenEntity(parsedIdToken)) { + this.logger.trace( + "WorkerCacheManager.getIdTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(idTokenKey, CredentialType.ID_TOKEN); + return null; + } + + this.logger.trace( + "WorkerCacheManager.getIdTokenCredential: cache hit" + ); + return parsedIdToken as IdTokenEntity; + } + + /** + * set IdToken credential to the platform cache + * @param idToken + */ + async setIdTokenCredential( + idToken: IdTokenEntity, + ): Promise { + this.logger.trace("WorkerCacheManager.setIdTokenCredential called"); + const idTokenKey = CacheHelpers.generateCredentialKey(idToken); + + await invokeAsync( + this.workerStorage.setItem.bind(this.workerStorage), + PerformanceEvents.SetUserData, + this.logger, + this.performanceClient + )(idTokenKey, JSON.stringify(idToken)); + + await this.addTokenKey(idTokenKey, CredentialType.ID_TOKEN); + } + + /** + * generates accessToken entity from a string + * @param key + */ + async getAccessTokenCredential(accessTokenKey: string): Promise { + const value = await this.workerStorage.getItem(accessTokenKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getAccessTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(accessTokenKey, CredentialType.ACCESS_TOKEN); + return null; + } + const parsedAccessToken = this.validateAndParseJson(value); + if ( + !parsedAccessToken || + !CacheHelpers.isAccessTokenEntity(parsedAccessToken) + ) { + this.logger.trace( + "WorkerCacheManager.getAccessTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(accessTokenKey, CredentialType.ACCESS_TOKEN); + return null; + } + + this.logger.trace( + "WorkerCacheManager.getAccessTokenCredential: cache hit" + ); + return parsedAccessToken as AccessTokenEntity; + } + + /** + * set accessToken credential to the platform cache + * @param accessToken + */ + async setAccessTokenCredential( + accessToken: AccessTokenEntity, + ): Promise { + this.logger.trace( + "WorkerCacheManager.setAccessTokenCredential called" + ); + const accessTokenKey = CacheHelpers.generateCredentialKey(accessToken); + await invokeAsync( + this.workerStorage.setItem.bind(this.workerStorage), + PerformanceEvents.SetUserData, + this.logger, + this.performanceClient + )(accessTokenKey, JSON.stringify(accessToken)); + + await this.addTokenKey(accessTokenKey, CredentialType.ACCESS_TOKEN); + } + + /** + * generates refreshToken entity from a string + * @param refreshTokenKey + */ + async getRefreshTokenCredential( + refreshTokenKey: string + ): Promise { + const value = await this.workerStorage.getItem(refreshTokenKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getRefreshTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(refreshTokenKey, CredentialType.REFRESH_TOKEN); + return null; + } + const parsedRefreshToken = this.validateAndParseJson(value); + if ( + !parsedRefreshToken || + !CacheHelpers.isRefreshTokenEntity(parsedRefreshToken) + ) { + this.logger.trace( + "WorkerCacheManager.getRefreshTokenCredential: called, no cache hit" + ); + await this.removeTokenKey(refreshTokenKey, CredentialType.REFRESH_TOKEN); + return null; + } + + this.logger.trace( + "WorkerCacheManager.getRefreshTokenCredential: cache hit" + ); + return parsedRefreshToken as RefreshTokenEntity; + } + + /** + * set refreshToken credential to the platform cache + * @param refreshToken + */ + async setRefreshTokenCredential( + refreshToken: RefreshTokenEntity, + ): Promise { + this.logger.trace( + "WorkerCacheManager.setRefreshTokenCredential called" + ); + const refreshTokenKey = + CacheHelpers.generateCredentialKey(refreshToken); + await invokeAsync( + this.workerStorage.setItem.bind(this.workerStorage), + PerformanceEvents.SetUserData, + this.logger, + this.performanceClient + )(refreshTokenKey, JSON.stringify(refreshToken)); + + await this.addTokenKey(refreshTokenKey, CredentialType.REFRESH_TOKEN); + } + + /** + * fetch appMetadata entity from the platform cache + * @param appMetadataKey + */ + async getAppMetadata(appMetadataKey: string): Promise { + const value = await this.workerStorage.getItem(appMetadataKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getAppMetadata: called, no cache hit" + ); + return null; + } + + const parsedMetadata = this.validateAndParseJson(value); + if ( + !parsedMetadata || + !CacheHelpers.isAppMetadataEntity(appMetadataKey, parsedMetadata) + ) { + this.logger.trace( + "WorkerCacheManager.getAppMetadata: called, no cache hit" + ); + return null; + } + + this.logger.trace("WorkerCacheManager.getAppMetadata: cache hit"); + return parsedMetadata as AppMetadataEntity; + } + + /** + * set appMetadata entity to the platform cache + * @param appMetadata + */ + async setAppMetadata(appMetadata: AppMetadataEntity): Promise { + this.logger.trace("WorkerCacheManager.setAppMetadata called"); + const appMetadataKey = CacheHelpers.generateAppMetadataKey(appMetadata); + await this.workerStorage.setItem( + appMetadataKey, + JSON.stringify(appMetadata) + ); + } + + /** + * fetch server telemetry entity from the platform cache + * @param serverTelemetryKey + */ + async getServerTelemetry( + serverTelemetryKey: string + ): Promise { + const value = await this.workerStorage.getItem(serverTelemetryKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getServerTelemetry: called, no cache hit" + ); + return null; + } + const parsedEntity = this.validateAndParseJson(value); + if ( + !parsedEntity || + !CacheHelpers.isServerTelemetryEntity( + serverTelemetryKey, + parsedEntity + ) + ) { + this.logger.trace( + "WorkerCacheManager.getServerTelemetry: called, no cache hit" + ); + return null; + } + + this.logger.trace("WorkerCacheManager.getServerTelemetry: cache hit"); + return parsedEntity as ServerTelemetryEntity; + } + + /** + * set server telemetry entity to the platform cache + * @param serverTelemetryKey + * @param serverTelemetry + */ + async setServerTelemetry( + serverTelemetryKey: string, + serverTelemetry: ServerTelemetryEntity + ): Promise { + this.logger.trace("WorkerCacheManager.setServerTelemetry called"); + await this.workerStorage.setItem( + serverTelemetryKey, + JSON.stringify(serverTelemetry) + ); + } + + /** + * + */ + async getAuthorityMetadata(key: string): Promise { + const value = this.internalStorage.getItem(key); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getAuthorityMetadata: called, no cache hit" + ); + return null; + } + const parsedMetadata = this.validateAndParseJson(value); + if ( + parsedMetadata && + CacheHelpers.isAuthorityMetadataEntity(key, parsedMetadata) + ) { + this.logger.trace( + "WorkerCacheManager.getAuthorityMetadata: cache hit" + ); + return parsedMetadata as AuthorityMetadataEntity; + } + return null; + } + + /** + * + */ + async getAuthorityMetadataKeys(): Promise> { + const allKeys = this.internalStorage.getKeys(); + return allKeys.filter((key) => { + return this.isAuthorityMetadata(key); + }); + } + + /** + * Sets wrapper metadata in memory + * @param wrapperSKU + * @param wrapperVersion + */ + setWrapperMetadata(wrapperSKU: string, wrapperVersion: string): void { + this.internalStorage.setItem(InMemoryCacheKeys.WRAPPER_SKU, wrapperSKU); + this.internalStorage.setItem( + InMemoryCacheKeys.WRAPPER_VER, + wrapperVersion + ); + } + + /** + * Returns wrapper metadata from in-memory storage + */ + getWrapperMetadata(): [string, string] { + const sku = + this.internalStorage.getItem(InMemoryCacheKeys.WRAPPER_SKU) || + Constants.EMPTY_STRING; + const version = + this.internalStorage.getItem(InMemoryCacheKeys.WRAPPER_VER) || + Constants.EMPTY_STRING; + return [sku, version]; + } + + /** + * + * @param entity + */ + async setAuthorityMetadata(key: string, entity: AuthorityMetadataEntity): Promise { + this.logger.trace("WorkerCacheManager.setAuthorityMetadata called"); + this.internalStorage.setItem(key, JSON.stringify(entity)); + } + + /** + * Gets the active account + */ + async getActiveAccount(): Promise { + const activeAccountKeyFilters = this.generateCacheKey( + PersistentCacheKeys.ACTIVE_ACCOUNT_FILTERS + ); + const activeAccountValueFilters = await this.workerStorage.getItem( + activeAccountKeyFilters + ); + if (!activeAccountValueFilters) { + this.logger.trace( + "WorkerCacheManager.getActiveAccount: No active account filters found" + ); + return null; + } + const activeAccountValueObj = this.validateAndParseJson( + activeAccountValueFilters + ) as AccountInfo; + if (activeAccountValueObj) { + this.logger.trace( + "WorkerCacheManager.getActiveAccount: Active account filters schema found" + ); + return this.getAccountInfoFilteredBy({ + homeAccountId: activeAccountValueObj.homeAccountId, + localAccountId: activeAccountValueObj.localAccountId, + tenantId: activeAccountValueObj.tenantId, + }); + } + this.logger.trace( + "WorkerCacheManager.getActiveAccount: No active account found" + ); + return null; + } + + /** + * Sets the active account's localAccountId in cache + * @param account + */ + async setActiveAccount(account: AccountInfo | null): Promise { + const activeAccountKey = this.generateCacheKey( + PersistentCacheKeys.ACTIVE_ACCOUNT_FILTERS + ); + if (account) { + this.logger.verbose("setActiveAccount: Active account set"); + const activeAccountValue: ActiveAccountFilters = { + homeAccountId: account.homeAccountId, + localAccountId: account.localAccountId, + tenantId: account.tenantId, + }; + await this.workerStorage.setItem( + activeAccountKey, + JSON.stringify(activeAccountValue) + ); + } else { + this.logger.verbose( + "setActiveAccount: No account passed, active account not set" + ); + await this.workerStorage.removeItem(activeAccountKey); + } + this.eventHandler.emitEvent(EventType.ACTIVE_ACCOUNT_CHANGED); + } + + /** + * fetch throttling entity from the platform cache + * @param throttlingCacheKey + */ + async getThrottlingCache(throttlingCacheKey: string): Promise { + const value = await this.workerStorage.getItem(throttlingCacheKey); + if (!value) { + this.logger.trace( + "WorkerCacheManager.getThrottlingCache: called, no cache hit" + ); + return null; + } + + const parsedThrottlingCache = this.validateAndParseJson(value); + if ( + !parsedThrottlingCache || + !CacheHelpers.isThrottlingEntity( + throttlingCacheKey, + parsedThrottlingCache + ) + ) { + this.logger.trace( + "WorkerCacheManager.getThrottlingCache: called, no cache hit" + ); + return null; + } + + this.logger.trace("WorkerCacheManager.getThrottlingCache: cache hit"); + return parsedThrottlingCache as ThrottlingEntity; + } + + /** + * set throttling entity to the platform cache + * @param throttlingCacheKey + * @param throttlingCache + */ + async setThrottlingCache( + throttlingCacheKey: string, + throttlingCache: ThrottlingEntity + ): Promise { + this.logger.trace("WorkerCacheManager.setThrottlingCache called"); + await this.workerStorage.setItem( + throttlingCacheKey, + JSON.stringify(throttlingCache) + ); + } + + /** + * Gets cache item with given key. + * Will retrieve from cookies if storeAuthStateInCookie is set to true. + * @param key + */ + async getTemporaryCache(cacheKey: string, generateKey?: boolean): Promise { + const key = generateKey ? this.generateCacheKey(cacheKey) : cacheKey; + + const value = await this.temporaryCacheStorage.getItem(key); + if (!value) { + // If temp cache item not found in session/memory, check local storage for items set by old versions + if ( + this.cacheConfig.cacheLocation === + BrowserCacheLocation.LocalStorage + ) { + const item = await this.workerStorage.getItem(key); + if (item) { + this.logger.trace( + "WorkerCacheManager.getTemporaryCache: Temporary cache item found in local storage" + ); + return item; + } + } + this.logger.trace( + "WorkerCacheManager.getTemporaryCache: No cache item found in local storage" + ); + return null; + } + this.logger.trace( + "WorkerCacheManager.getTemporaryCache: Temporary cache item returned" + ); + return value; + } + + /** + * Sets the cache item with the key and value given. + * Stores in cookie if storeAuthStateInCookie is set to true. + * This can cause cookie overflow if used incorrectly. + * @param key + * @param value + */ + async setTemporaryCache( + cacheKey: string, + value: string, + generateKey?: boolean + ): Promise { + const key = generateKey ? this.generateCacheKey(cacheKey) : cacheKey; + + await this.temporaryCacheStorage.setItem(key, value); + } + + /** + * Removes the cache item with the given key. + * @param key + */ + async removeItem(key: string): Promise { + await this.workerStorage.removeItem(key); + } + + /** + * Removes the temporary cache item with the given key. + * Will also clear the cookie item if storeAuthStateInCookie is set to true. + * @param key + */ + async removeTemporaryItem(key: string): Promise { + await this.temporaryCacheStorage.removeItem(key); + } + + /** + * Gets all keys in window. + */ + async getKeys(): Promise { + return this.workerStorage.getKeys(); + } + + /** + * Clears all cache entries created by MSAL. + */ + async clear(): Promise { + // Removes all accounts and their credentials + await this.removeAllAccounts(); + await this.removeAppMetadata(); + + // Remove temp storage first to make sure any cookies are cleared + const tempKeys = await this.temporaryCacheStorage.getKeys(); + for (const cacheKey of tempKeys) { + if ( + cacheKey.indexOf(Constants.CACHE_PREFIX) !== -1 || + cacheKey.indexOf(this.clientId) !== -1 + ) { + await this.removeTemporaryItem(cacheKey); + } + } + + // Removes all remaining MSAL cache items + const workerKeys = await this.workerStorage.getKeys(); + for (const cacheKey of workerKeys) { + if ( + cacheKey.indexOf(Constants.CACHE_PREFIX) !== -1 || + cacheKey.indexOf(this.clientId) !== -1 + ) { + await this.workerStorage.removeItem(cacheKey); + } + } + + this.internalStorage.clear(); + } + + /** + * Clears all access tokes that have claims prior to saving the current one + * @param performanceClient {IPerformanceClient} + * @param correlationId {string} correlation id + * @returns + */ + async clearTokensAndKeysWithClaims( + performanceClient: IPerformanceClient, + correlationId: string + ): Promise { + performanceClient.addQueueMeasurement( + PerformanceEvents.ClearTokensAndKeysWithClaims, + correlationId + ); + + const tokenKeys = await this.getTokenKeys(); + + const removedAccessTokens: Array> = []; + for (const key of tokenKeys.accessToken) { + // if the access token has claims in its key, remove the token key and the token + const credential = await this.getAccessTokenCredential(key); + if ( + credential?.requestedClaimsHash && + key.includes(credential.requestedClaimsHash.toLowerCase()) + ) { + removedAccessTokens.push(this.removeAccessToken(key)); + } + } + await Promise.all(removedAccessTokens); + + // warn if any access tokens are removed + if (removedAccessTokens.length > 0) { + this.logger.warning( + `${removedAccessTokens.length} access tokens with claims in the cache keys have been removed from the cache.` + ); + } + } + + /** + * Prepend msal. to each key; Skip for any JSON object as Key (defined schemas do not need the key appended: AccessToken Keys or the upcoming schema) + * @param key + * @param addInstanceId + */ + generateCacheKey(key: string): string { + const generatedKey = this.validateAndParseJson(key); + if (!generatedKey) { + if (StringUtils.startsWith(key, Constants.CACHE_PREFIX)) { + return key; + } + return `${Constants.CACHE_PREFIX}.${this.clientId}.${key}`; + } + + return JSON.stringify(key); + } + + /** + * Reset all temporary cache items + * @param state + */ + async resetRequestCache(): Promise { + this.logger.trace("WorkerCacheManager.resetRequestCache called"); + + await this.removeTemporaryItem( + this.generateCacheKey(TemporaryCacheKeys.REQUEST_PARAMS) + ); + await this.removeTemporaryItem( + this.generateCacheKey(TemporaryCacheKeys.VERIFIER) + ); + await this.removeTemporaryItem( + this.generateCacheKey(TemporaryCacheKeys.ORIGIN_URI) + ); + await this.removeTemporaryItem( + this.generateCacheKey(TemporaryCacheKeys.URL_HASH) + ); + await this.removeTemporaryItem( + this.generateCacheKey(TemporaryCacheKeys.NATIVE_REQUEST) + ); + await this.setInteractionInProgress(false); + } + + async cacheAuthorizeRequest( + authCodeRequest: CommonAuthorizationUrlRequest, + codeVerifier?: string + ): Promise { + this.logger.trace("WorkerCacheManager.cacheAuthorizeRequest called"); + + const encodedValue = base64Encode(JSON.stringify(authCodeRequest)); + await this.setTemporaryCache( + TemporaryCacheKeys.REQUEST_PARAMS, + encodedValue, + true + ); + + if (codeVerifier) { + const encodedVerifier = base64Encode(codeVerifier); + await this.setTemporaryCache( + TemporaryCacheKeys.VERIFIER, + encodedVerifier, + true + ); + } + } + + /** + * Gets the token exchange parameters from the cache. Throws an error if nothing is found. + */ + async getCachedRequest(): Promise<[CommonAuthorizationUrlRequest, string]> { + this.logger.trace("WorkerCacheManager.getCachedRequest called"); + // Get token request from cache and parse as TokenExchangeParameters. + const encodedTokenRequest = await this.getTemporaryCache( + TemporaryCacheKeys.REQUEST_PARAMS, + true + ); + if (!encodedTokenRequest) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.noTokenRequestCacheError + ); + } + const encodedVerifier = await this.getTemporaryCache( + TemporaryCacheKeys.VERIFIER, + true + ); + + let parsedRequest: CommonAuthorizationUrlRequest; + let verifier = ""; + try { + parsedRequest = JSON.parse(base64Decode(encodedTokenRequest)); + if (encodedVerifier) { + verifier = base64Decode(encodedVerifier); + } + } catch (e) { + this.logger.errorPii(`Attempted to parse: ${encodedTokenRequest}`); + this.logger.error( + `Parsing cached token request threw with error: ${e}` + ); + throw createBrowserAuthError( + BrowserAuthErrorCodes.unableToParseTokenRequestCacheError + ); + } + + return [parsedRequest, verifier]; + } + + /** + * Gets cached native request for redirect flows + */ + async getCachedNativeRequest(): Promise { + this.logger.trace("WorkerCacheManager.getCachedNativeRequest called"); + const cachedRequest = await this.getTemporaryCache( + TemporaryCacheKeys.NATIVE_REQUEST, + true + ); + if (!cachedRequest) { + this.logger.trace( + "WorkerCacheManager.getCachedNativeRequest: No cached native request found" + ); + return null; + } + + const parsedRequest = this.validateAndParseJson( + cachedRequest + ) as NativeTokenRequest; + if (!parsedRequest) { + this.logger.error( + "WorkerCacheManager.getCachedNativeRequest: Unable to parse native request" + ); + return null; + } + + return parsedRequest; + } + + async isInteractionInProgress(matchClientId?: boolean): Promise { + const interaction = await this.getInteractionInProgress(); + const clientId = interaction?.clientId; + + if (matchClientId) { + return clientId === this.clientId; + } else { + return !!clientId; + } + } + + async getInteractionInProgress(): Promise<{ + clientId: string; + type: INTERACTION_TYPE; + } | null> { + const key = `${Constants.CACHE_PREFIX}.${TemporaryCacheKeys.INTERACTION_STATUS_KEY}`; + const value = await this.getTemporaryCache(key, false); + return value ? JSON.parse(value) : null; + } + + async setInteractionInProgress( + inProgress: boolean, + type: INTERACTION_TYPE = INTERACTION_TYPE.SIGNIN + ): Promise { + // Ensure we don't overwrite interaction in progress for a different clientId + const key = `${Constants.CACHE_PREFIX}.${TemporaryCacheKeys.INTERACTION_STATUS_KEY}`; + if (inProgress) { + if (await this.getInteractionInProgress()) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.interactionInProgress + ); + } else { + // No interaction is in progress + await this.setTemporaryCache( + key, + JSON.stringify({ clientId: this.clientId, type }), + false + ); + } + } else if ( + !inProgress && + (await this.getInteractionInProgress())?.clientId === this.clientId + ) { + await this.removeTemporaryItem(key); + } + } + + /** + * Builds credential entities from AuthenticationResult object and saves the resulting credentials to the cache + * @param result + * @param request + */ + async hydrateCache( + result: AuthenticationResult, + request: + | SilentRequest + | SsoSilentRequest + | RedirectRequest + | PopupRequest + ): Promise { + const idTokenEntity = CacheHelpers.createIdTokenEntity( + result.account?.homeAccountId, + result.account?.environment, + result.idToken, + this.clientId, + result.tenantId + ); + + let claimsHash; + if (request.claims) { + claimsHash = await this.cryptoImpl.hashString(request.claims); + } + + /** + * meta data for cache stores time in seconds from epoch + * AuthenticationResult returns expiresOn and extExpiresOn in milliseconds (as a Date object which is in ms) + * We need to map these for the cache when building tokens from AuthenticationResult + * + * The next MSAL VFuture should map these both to same value if possible + */ + + const accessTokenEntity = CacheHelpers.createAccessTokenEntity( + result.account?.homeAccountId, + result.account.environment, + result.accessToken, + this.clientId, + result.tenantId, + result.scopes.join(" "), + // Access token expiresOn stored in seconds, converting from AuthenticationResult expiresOn stored as Date + result.expiresOn + ? TimeUtils.toSecondsFromDate(result.expiresOn) + : 0, + result.extExpiresOn + ? TimeUtils.toSecondsFromDate(result.extExpiresOn) + : 0, + base64Decode, + undefined, // refreshOn + result.tokenType as AuthenticationScheme, + undefined, // userAssertionHash + request.sshKid, + request.claims, + claimsHash + ); + + let refreshToken: RefreshTokenEntity | null = null; + + if(result.refreshToken) { + refreshToken = CacheHelpers.createRefreshTokenEntity( + result.account?.homeAccountId, + result.account?.environment, + result.refreshToken, + this.clientId, + result.familyId, + ); + } + + const cacheRecord = { + idToken: idTokenEntity, + accessToken: accessTokenEntity, + refreshToken: refreshToken + }; + return this.saveCacheRecord(cacheRecord, result.correlationId); + } + + /** + * saves a cache record + * @param cacheRecord {CacheRecord} + * @param storeInCache {?StoreInCache} + * @param correlationId {?string} correlation id + */ + async saveCacheRecord( + cacheRecord: CacheRecord, + correlationId: string, + storeInCache?: StoreInCache + ): Promise { + try { + await super.saveCacheRecord( + cacheRecord, + correlationId, + storeInCache + ); + } catch (e) { + if ( + e instanceof CacheError && + this.performanceClient && + correlationId + ) { + try { + const tokenKeys = await this.getTokenKeys(); + + this.performanceClient.addFields( + { + cacheRtCount: tokenKeys.refreshToken.length, + cacheIdCount: tokenKeys.idToken.length, + cacheAtCount: tokenKeys.accessToken.length, + }, + correlationId + ); + } catch (e) {} + } + + throw e; + } + } +} + +export const DEFAULT_WORKER_CACHE_MANAGER = ( + clientId: string, + logger: Logger, + performanceClient: IPerformanceClient, + eventHandler: EventHandler +): WorkerCacheManager => { + const cacheOptions: Required = { + cacheLocation: BrowserCacheLocation.MemoryStorage, + temporaryCacheLocation: BrowserCacheLocation.MemoryStorage, + storeAuthStateInCookie: false, + secureCookies: false, + cacheMigrationEnabled: false, + claimsBasedCachingEnabled: false, + }; + return new WorkerCacheManager( + clientId, + cacheOptions, + DEFAULT_CRYPTO_IMPLEMENTATION, + logger, + performanceClient, + eventHandler + ); +}; diff --git a/lib/msal-browser/src/config/Configuration.ts b/lib/msal-browser/src/config/Configuration.ts index 8f5175eeab..6f10f4dea2 100644 --- a/lib/msal-browser/src/config/Configuration.ts +++ b/lib/msal-browser/src/config/Configuration.ts @@ -275,7 +275,8 @@ export function buildConfiguration( system: userInputSystem, telemetry: userInputTelemetry, }: Configuration, - isBrowserEnvironment: boolean + isBrowserEnvironment: boolean, + isWebWorkerEnvironment?: boolean ): BrowserConfiguration { // Default auth options for browser const DEFAULT_AUTH_OPTIONS: InternalAuthOptions = { @@ -337,7 +338,7 @@ export function buildConfiguration( const DEFAULT_BROWSER_SYSTEM_OPTIONS: Required = { ...DEFAULT_SYSTEM_OPTIONS, loggerOptions: DEFAULT_LOGGER_OPTIONS, - networkClient: isBrowserEnvironment + networkClient: (isBrowserEnvironment || isWebWorkerEnvironment) ? new FetchClient() : StubbedNetworkModule, navigationClient: new NavigationClient(), diff --git a/lib/msal-browser/src/controllers/ControllerFactory.ts b/lib/msal-browser/src/controllers/ControllerFactory.ts index dc647fb06f..2e7db19d44 100644 --- a/lib/msal-browser/src/controllers/ControllerFactory.ts +++ b/lib/msal-browser/src/controllers/ControllerFactory.ts @@ -10,6 +10,8 @@ import { Configuration } from "../config/Configuration.js"; import { StandardController } from "./StandardController.js"; import { NestedAppAuthController } from "./NestedAppAuthController.js"; import { InitializeApplicationRequest } from "../request/InitializeApplicationRequest.js"; +import { WorkerOperatingContext } from "../operatingcontext/WorkerOperatingContext.js"; +import { WorkerController } from "./WorkerController.js"; export async function createV3Controller( config: Configuration, @@ -24,19 +26,30 @@ export async function createV3Controller( export async function createController( config: Configuration ): Promise { - const standard = new StandardOperatingContext(config); - const nestedApp = new NestedAppOperatingContext(config); + if (typeof window !== "undefined") { + const standard = new StandardOperatingContext(config); + const nestedApp = new NestedAppOperatingContext(config); + + const operatingContexts = [standard.initialize(), nestedApp.initialize()]; - const operatingContexts = [standard.initialize(), nestedApp.initialize()]; + await Promise.all(operatingContexts); - await Promise.all(operatingContexts); + if (nestedApp.isAvailable() && config.auth.supportsNestedAppAuth) { + return NestedAppAuthController.createController(nestedApp); + } else if (standard.isAvailable()) { + return StandardController.createController(standard); + } else { + // Since neither of the actual operating contexts are available keep the UnknownOperatingContextController + return null; + } + } - if (nestedApp.isAvailable() && config.auth.supportsNestedAppAuth) { - return NestedAppAuthController.createController(nestedApp); - } else if (standard.isAvailable()) { - return StandardController.createController(standard); + if (typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope) { + const workerOperatingContext = new WorkerOperatingContext(config); + await workerOperatingContext.initialize(); + return WorkerController.createController(workerOperatingContext); } else { - // Since neither of the actual operating contexts are available keep the UnknownOperatingContextController return null; } + } diff --git a/lib/msal-browser/src/controllers/IController.ts b/lib/msal-browser/src/controllers/IController.ts index 856f4abc38..95a4d211a5 100644 --- a/lib/msal-browser/src/controllers/IController.ts +++ b/lib/msal-browser/src/controllers/IController.ts @@ -74,6 +74,10 @@ export interface IController { getAllAccounts(accountFilter?: AccountFilter): AccountInfo[]; + getAllAccountsAsync?( + accountFilter?: AccountFilter + ): Promise; + handleRedirectPromise(hash?: string): Promise; loginPopup(request?: PopupRequest): Promise; diff --git a/lib/msal-browser/src/controllers/WorkerController.ts b/lib/msal-browser/src/controllers/WorkerController.ts new file mode 100644 index 0000000000..eea8706b95 --- /dev/null +++ b/lib/msal-browser/src/controllers/WorkerController.ts @@ -0,0 +1,969 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { + CommonAuthorizationUrlRequest, + CommonSilentFlowRequest, + PerformanceCallbackFunction, + AccountInfo, + Logger, + ICrypto, + IPerformanceClient, + AccountFilter, + BaseAuthRequest, + PerformanceEvents, + invokeAsync, + Constants, + AccountEntity, + InProgressPerformanceEvent, + AuthError, + getRequestThumbprint, + ClientAuthErrorCodes, + createClientAuthError, + INetworkModule, +} from "@azure/msal-common/browser"; +import { ITokenCache } from "../cache/ITokenCache.js"; +import { BrowserConfiguration } from "../config/Configuration.js"; +import { INavigationClient } from "../navigation/INavigationClient.js"; +import { AuthorizationCodeRequest } from "../request/AuthorizationCodeRequest.js"; +import { EndSessionPopupRequest } from "../request/EndSessionPopupRequest.js"; +import { EndSessionRequest } from "../request/EndSessionRequest.js"; +import { PopupRequest } from "../request/PopupRequest.js"; +import { RedirectRequest } from "../request/RedirectRequest.js"; +import { SilentRequest } from "../request/SilentRequest.js"; +import { SsoSilentRequest } from "../request/SsoSilentRequest.js"; +import { AuthenticationResult } from "../response/AuthenticationResult.js"; +import { ApiId, CacheLookupPolicy, InteractionType, WrapperSKU } from "../utils/BrowserConstants.js"; +import { IController } from "./IController.js"; +import { CryptoOps } from "../crypto/CryptoOps.js"; +import { + blockAPICallsBeforeInitialize, + blockNonBrowserEnvironment, +} from "../utils/BrowserUtils.js"; +import { EventCallbackFunction } from "../event/EventMessage.js"; +import { ClearCacheRequest } from "../request/ClearCacheRequest.js"; +import { EventType } from "../event/EventType.js"; +import { EventHandler } from "../event/EventHandler.js"; +import { BaseOperatingContext } from "../operatingcontext/BaseOperatingContext.js"; +import { InitializeApplicationRequest } from "../request/InitializeApplicationRequest.js"; +import { createNewGuid } from "../crypto/WorkerCrypto.js"; +import { WorkerOperatingContext } from "../operatingcontext/WorkerOperatingContext.js"; +import { NativeMessageHandler } from "../broker/nativeBroker/NativeMessageHandler.js"; +import { DEFAULT_WORKER_CACHE_MANAGER, WorkerCacheManager } from "../cache/WorkerCacheManager.js"; +import * as AsyncAccountManager from "../cache/AsyncAccountManager.js"; +import { BrowserAuthErrorCodes, createBrowserAuthError } from "../error/BrowserAuthError.js"; +import { initializeSilentRequest } from "../request/RequestHelpers.js"; +// import { NativeAuthError, isFatalNativeAuthError } from "../error/NativeAuthError.js"; +import { SilentCacheClient } from "../interaction_client/SilentCacheClient.js"; +import { SilentRefreshClient } from "../interaction_client/SilentRefreshClient.js"; +import { FetchClient } from "../network/FetchClient.js"; + +function getAccountType( + account?: AccountInfo +): "AAD" | "MSA" | "B2C" | undefined { + const idTokenClaims = account?.idTokenClaims; + if (idTokenClaims?.tfp || idTokenClaims?.acr) { + return "B2C"; + } + + if (!idTokenClaims?.tid) { + return undefined; + } else if (idTokenClaims?.tid === "9188040d-6c67-4c5b-b112-36a304b66dad") { + return "MSA"; + } + return "AAD"; +} + +function preflightCheck( + initialized: boolean, + performanceEvent: InProgressPerformanceEvent +) { + try { + // Block token acquisition before initialize has been called + blockAPICallsBeforeInitialize(initialized); + } catch (e) { + performanceEvent.end({ success: false }, e); + throw e; + } +} + +/** + * WorkerController class + */ +export class WorkerController implements IController { + // OperatingContext + protected readonly operatingContext: WorkerOperatingContext; + + // Logger + protected logger: Logger; + + // Storage interface implementation + protected readonly workerStorage: WorkerCacheManager; + + // Network interface implementation + protected readonly networkClient: INetworkModule; + + // Input configuration by developer/user + protected readonly config: BrowserConfiguration; + + // Performance telemetry client + protected readonly performanceClient: IPerformanceClient; + + // Event handler + private readonly eventHandler: EventHandler; + + // Native Extension Provider + protected nativeExtensionProvider: NativeMessageHandler | undefined; + + // Crypto interface implementation + protected readonly workerCrypto: ICrypto; + + // Flag to indicate if in browser environment + protected isBrowserEnvironment: boolean; + + protected isWorkerEnvironment: boolean; + + // Navigation interface implementation + protected navigationClient: INavigationClient; + + // Flag representing whether or not the initialize API has been called and completed + protected initialized: boolean = false; + // Active requests + private activeSilentTokenRequests: Map< + string, + Promise + >; + + constructor(operatingContext: WorkerOperatingContext) { + this.operatingContext = operatingContext; + + this.isBrowserEnvironment = + this.operatingContext.isBrowserEnvironment(); + + this.isWorkerEnvironment = this.operatingContext.isWorkerEnvironment(); + + this.config = operatingContext.getConfig(); + + this.logger = operatingContext.getLogger(); + + this.networkClient = this.config.system.networkClient; + + // Initialize performance client + this.performanceClient = this.config.telemetry.client; + + // Initialize the crypto class. + this.workerCrypto = new CryptoOps(this.logger, this.performanceClient); + + this.eventHandler = new EventHandler(this.logger); + + this.navigationClient = this.config.system.navigationClient; + + // Initialize the browser storage class. + this.workerStorage = this.isWorkerEnvironment + ? new WorkerCacheManager( + this.config.auth.clientId, + this.config.cache, + this.workerCrypto, + this.logger, + this.performanceClient, + this.eventHandler, + undefined + ) + : DEFAULT_WORKER_CACHE_MANAGER( + this.config.auth.clientId, + this.logger, + this.performanceClient, + this.eventHandler + ); + + this.activeSilentTokenRequests = new Map(); + } + + // TODO: Dedupe with StandardController + static async createController( + operatingContext: BaseOperatingContext, + request?: InitializeApplicationRequest + ): Promise { + const controller = new WorkerController(operatingContext); + await controller.initialize(request); + return controller; + } + + // TODO: Dedupe with StandardController + /** + * Initializer function to perform async startup tasks such as connecting to WAM extension + * @param request {?InitializeApplicationRequest} correlation id + */ + async initialize(request?: InitializeApplicationRequest): Promise { + this.logger.trace("initialize WorkerOperatingContext called"); + if (this.initialized) { + this.logger.info( + "initialize has already been called, exiting early." + ); + return; + } + + if (this.isBrowserEnvironment) { + this.logger.info("in browser environment, exiting early."); + this.initialized = true; + this.eventHandler.emitEvent(EventType.INITIALIZE_END); + return; + } + + const initCorrelationId = + request?.correlationId || this.getRequestCorrelationId(); + const allowPlatformBroker = this.config.system.allowPlatformBroker; + const initMeasurement = this.performanceClient.startMeasurement( + PerformanceEvents.InitializeClientApplication, + initCorrelationId + ); + this.eventHandler.emitEvent(EventType.INITIALIZE_START); + + await invokeAsync( + this.workerStorage.initialize.bind(this.workerStorage), + PerformanceEvents.InitializeCache, + this.logger, + this.performanceClient, + initCorrelationId + )(); + + if (allowPlatformBroker) { + try { + this.nativeExtensionProvider = + await NativeMessageHandler.createProvider( + this.logger, + this.config.system.nativeBrokerHandshakeTimeout, + this.performanceClient + ); + } catch (e) { + this.logger.verbose(e as string); + } + } + + if (!this.config.cache.claimsBasedCachingEnabled) { + this.logger.verbose( + "Claims-based caching is disabled. Clearing the previous cache with claims" + ); + + await invokeAsync( + this.workerStorage.clearTokensAndKeysWithClaims.bind( + this.workerStorage + ), + PerformanceEvents.ClearTokensAndKeysWithClaims, + this.logger, + this.performanceClient, + initCorrelationId + )(this.performanceClient, initCorrelationId); + } + + /* + * this.config.system.asyncPopups && + * (await this.preGeneratePkceCodes(initCorrelationId)); + */ + this.initialized = true; + this.eventHandler.emitEvent(EventType.INITIALIZE_END); + initMeasurement.end({ + allowPlatformBroker: allowPlatformBroker, + success: true, + }); + } + + /** + * Returns all the accounts in the cache that match the optional filter. If no filter is provided, all accounts are returned. + * @param accountFilter - (Optional) filter to narrow down the accounts returned + * @returns Array of AccountInfo objects in cache + */ + async getAllAccountsAsync(accountFilter?: AccountFilter): Promise { + return AsyncAccountManager.getAllAccounts( + this.logger, + this.workerStorage, + accountFilter + ); + } + + /** + * Silently acquire an access token for a given set of scopes. Returns currently processing promise if parallel requests are made. + * + * @param {@link (SilentRequest:type)} + * @returns {Promise.} - a promise that is fulfilled when this function has completed, or rejected if an error was raised. Returns the {@link AuthResponse} object + */ + async acquireTokenSilent( + request: SilentRequest + ): Promise { + const correlationId = this.getRequestCorrelationId(request); + const atsMeasurement = this.performanceClient.startMeasurement( + PerformanceEvents.AcquireTokenSilent, + correlationId + ); + atsMeasurement.add({ + cacheLookupPolicy: request.cacheLookupPolicy, + scenarioId: request.scenarioId, + }); + + preflightCheck(this.initialized, atsMeasurement); + this.logger.verbose("acquireTokenSilent called", correlationId); + + const account = request.account || this.getActiveAccount(); + if (!account) { + throw createBrowserAuthError(BrowserAuthErrorCodes.noAccountError); + } + atsMeasurement.add({ accountType: getAccountType(account) }); + + return this.acquireTokenSilentDeduped(request, account, correlationId) + .then((result) => { + atsMeasurement.end({ + success: true, + fromCache: result.fromCache, + isNativeBroker: result.fromNativeBroker, + accessTokenSize: result.accessToken.length, + idTokenSize: result.idToken.length, + }); + return { + ...result, + state: request.state, + correlationId: correlationId, // Ensures PWB scenarios can correctly match request to response + }; + }) + .catch((error: Error) => { + if (error instanceof AuthError) { + // Ensures PWB scenarios can correctly match request to response + error.setCorrelationId(correlationId); + } + + atsMeasurement.end( + { + success: false, + }, + error + ); + throw error; + }); + } + + /** + * Checks if identical request is already in flight and returns reference to the existing promise or fires off a new one if this is the first + * @param request + * @param account + * @param correlationId + * @returns + */ + private async acquireTokenSilentDeduped( + request: SilentRequest, + account: AccountInfo, + correlationId: string + ): Promise { + const thumbprint = getRequestThumbprint( + this.config.auth.clientId, + { + ...request, + authority: request.authority || this.config.auth.authority, + correlationId: correlationId, + }, + account.homeAccountId + ); + const silentRequestKey = JSON.stringify(thumbprint); + + const inProgressRequest = + this.activeSilentTokenRequests.get(silentRequestKey); + + if (typeof inProgressRequest === "undefined") { + this.logger.verbose( + "acquireTokenSilent called for the first time, storing active request", + correlationId + ); + this.performanceClient.addFields({ deduped: false }, correlationId); + + const activeRequest = invokeAsync( + this.acquireTokenSilentAsync.bind(this), + PerformanceEvents.AcquireTokenSilentAsync, + this.logger, + this.performanceClient, + correlationId + )( + { + ...request, + correlationId, + }, + account + ); + this.activeSilentTokenRequests.set(silentRequestKey, activeRequest); + + return activeRequest.finally(() => { + this.activeSilentTokenRequests.delete(silentRequestKey); + }); + } else { + this.logger.verbose( + "acquireTokenSilent has been called previously, returning the result from the first call", + correlationId + ); + this.performanceClient.addFields({ deduped: true }, correlationId); + return inProgressRequest; + } + } + + /** + * Silently acquire an access token for a given set of scopes. Will use cached token if available, otherwise will attempt to acquire a new token from the network via refresh token. + * @param {@link (SilentRequest:type)} + * @param {@link (AccountInfo:type)} + * @returns {Promise.} - a promise that is fulfilled when this function has completed, or rejected if an error was raised. Returns the {@link AuthResponse} + */ + protected async acquireTokenSilentAsync( + request: SilentRequest & { correlationId: string }, + account: AccountInfo + ): Promise { + this.performanceClient.addQueueMeasurement( + PerformanceEvents.AcquireTokenSilentAsync, + request.correlationId + ); + + this.eventHandler.emitEvent( + EventType.ACQUIRE_TOKEN_START, + InteractionType.Silent, + request + ); + + if (request.correlationId) { + this.performanceClient.incrementFields( + { visibilityChangeCount: 0 }, + request.correlationId + ); + } + + const silentRequest = await invokeAsync( + initializeSilentRequest, + PerformanceEvents.InitializeSilentRequest, + this.logger, + this.performanceClient, + request.correlationId + )(request, account, this.config, this.performanceClient, this.logger); + const cacheLookupPolicy = + request.cacheLookupPolicy || CacheLookupPolicy.Default; + + const result = this.acquireTokenSilentNoIframe( + silentRequest, + cacheLookupPolicy + ).catch(async (refreshTokenError: AuthError) => { + // Error cannot be silently resolved since iframes are not allowed + throw refreshTokenError; + }); + + return result + .then((response) => { + this.eventHandler.emitEvent( + EventType.ACQUIRE_TOKEN_SUCCESS, + InteractionType.Silent, + response + ); + if (request.correlationId) { + this.performanceClient.addFields( + { + fromCache: response.fromCache, + isNativeBroker: response.fromNativeBroker, + }, + request.correlationId + ); + } + + return response; + }) + .catch((tokenRenewalError: Error) => { + this.eventHandler.emitEvent( + EventType.ACQUIRE_TOKEN_FAILURE, + InteractionType.Silent, + null, + tokenRenewalError + ); + throw tokenRenewalError; + }); + } + + /** + * AcquireTokenSilent without the iframe fallback. This is used to enable the correct fallbacks in cases where there's a potential for multiple silent requests to be made in parallel and prevent those requests from making concurrent iframe requests. + * @param silentRequest + * @param cacheLookupPolicy + * @returns + */ + private async acquireTokenSilentNoIframe( + silentRequest: CommonSilentFlowRequest, + cacheLookupPolicy: CacheLookupPolicy + ): Promise { + /* + * if the cache policy is set to access_token only, we should not be hitting the native layer yet + * if ( + * NativeMessageHandler.isPlatformBrokerAvailable( + * this.config, + * this.logger, + * this.nativeExtensionProvider, + * silentRequest.authenticationScheme + * ) && + * silentRequest.account.nativeAccountId + * ) { + * this.logger.verbose( + * "acquireTokenSilent - attempting to acquire token from native platform" + * ); + * return this.acquireTokenNative( + * silentRequest, + * ApiId.acquireTokenSilent_silentFlow, + * silentRequest.account.nativeAccountId, + * cacheLookupPolicy + * ).catch(async (e: AuthError) => { + * // If native token acquisition fails for availability reasons fallback to web flow + * if (e instanceof NativeAuthError && isFatalNativeAuthError(e)) { + * this.logger.verbose( + * "acquireTokenSilent - native platform unavailable, falling back to web flow" + * ); + * this.nativeExtensionProvider = undefined; // Prevent future requests from continuing to attempt + */ + + /* + * // Cache will not contain tokens, given that previous WAM requests succeeded. Skip cache and RT renewal and go straight to iframe renewal + * throw createClientAuthError( + * ClientAuthErrorCodes.tokenRefreshRequired + * ); + * } + * throw e; + * }); + * } else { + */ + this.logger.verbose( + "acquireTokenSilent - attempting to acquire token from web flow" + ); + // add logs to identify embedded cache retrieval + if (cacheLookupPolicy === CacheLookupPolicy.AccessToken) { + this.logger.verbose( + "acquireTokenSilent - cache lookup policy set to AccessToken, attempting to acquire token from local cache" + ); + } + return invokeAsync( + this.acquireTokenFromCache.bind(this), + PerformanceEvents.AcquireTokenFromCache, + this.logger, + this.performanceClient, + silentRequest.correlationId + )(silentRequest, cacheLookupPolicy).catch( + (cacheError: AuthError) => { + if (cacheLookupPolicy === CacheLookupPolicy.AccessToken) { + throw cacheError; + } + + this.eventHandler.emitEvent( + EventType.ACQUIRE_TOKEN_NETWORK_START, + InteractionType.Silent, + silentRequest + ); + + return invokeAsync( + this.acquireTokenByRefreshToken.bind(this), + PerformanceEvents.AcquireTokenByRefreshToken, + this.logger, + this.performanceClient, + silentRequest.correlationId + )(silentRequest, cacheLookupPolicy); + } + ); + // } + } + + /** + * Attempt to acquire an access token from the cache + * @param silentCacheClient SilentCacheClient + * @param commonRequest CommonSilentFlowRequest + * @param silentRequest SilentRequest + * @returns A promise that, when resolved, returns the access token + */ + protected async acquireTokenFromCache( + commonRequest: CommonSilentFlowRequest, + cacheLookupPolicy: CacheLookupPolicy + ): Promise { + this.performanceClient.addQueueMeasurement( + PerformanceEvents.AcquireTokenFromCache, + commonRequest.correlationId + ); + switch (cacheLookupPolicy) { + case CacheLookupPolicy.Default: + case CacheLookupPolicy.AccessToken: + case CacheLookupPolicy.AccessTokenAndRefreshToken: + const silentCacheClient = this.createSilentCacheClient( + commonRequest.correlationId + ); + return invokeAsync( + silentCacheClient.acquireToken.bind(silentCacheClient), + PerformanceEvents.SilentCacheClientAcquireToken, + this.logger, + this.performanceClient, + commonRequest.correlationId + )(commonRequest); + default: + throw createClientAuthError( + ClientAuthErrorCodes.tokenRefreshRequired + ); + } + + } + + /** + * Attempt to acquire an access token via a refresh token + * @param commonRequest CommonSilentFlowRequest + * @param cacheLookupPolicy CacheLookupPolicy + * @returns A promise that, when resolved, returns the access token + */ + public async acquireTokenByRefreshToken( + commonRequest: CommonSilentFlowRequest, + cacheLookupPolicy: CacheLookupPolicy + ): Promise { + this.performanceClient.addQueueMeasurement( + PerformanceEvents.AcquireTokenByRefreshToken, + commonRequest.correlationId + ); + switch (cacheLookupPolicy) { + case CacheLookupPolicy.Default: + case CacheLookupPolicy.AccessTokenAndRefreshToken: + case CacheLookupPolicy.RefreshToken: + case CacheLookupPolicy.RefreshTokenAndNetwork: + const silentRefreshClient = this.createSilentRefreshClient( + commonRequest.correlationId + ); + + return invokeAsync( + silentRefreshClient.acquireToken.bind(silentRefreshClient), + PerformanceEvents.SilentRefreshClientAcquireToken, + this.logger, + this.performanceClient, + commonRequest.correlationId + )(commonRequest); + default: + throw createClientAuthError( + ClientAuthErrorCodes.tokenRefreshRequired + ); + } + } + + /** + * Returns new instance of the Silent Cache Interaction Client + */ + protected createSilentCacheClient( + correlationId?: string + ): SilentCacheClient { + return new SilentCacheClient( + this.config, + this.workerStorage, + this.workerCrypto, + this.logger, + this.eventHandler, + this.navigationClient, + this.performanceClient, + this.nativeExtensionProvider, + correlationId + ); + } + + /** + * Returns new instance of the Silent Refresh Interaction Client + */ + protected createSilentRefreshClient( + correlationId?: string + ): SilentRefreshClient { + return new SilentRefreshClient( + this.config, + this.workerStorage, + this.workerCrypto, + this.logger, + this.eventHandler, + this.navigationClient, + this.performanceClient, + this.nativeExtensionProvider, + correlationId + ); + } + + // TODO: Dedupe with StandardController + /** + * Generates a correlation id for a request if none is provided. + * + * @protected + * @param {?Partial} [request] + * @returns {string} + */ + protected getRequestCorrelationId( + request?: Partial + ): string { + if (request?.correlationId) { + return request.correlationId; + } + + if (this.isWorkerEnvironment){ + return createNewGuid(); + } + + /* + * Included for fallback for non-browser environments, + * and to ensure this method always returns a string. + */ + return Constants.EMPTY_STRING; + } + + /** + * Hydrates the cache with the tokens from an AuthenticationResult + * @param result + * @param request + * @returns + */ + async hydrateCache( + result: AuthenticationResult, + request: + | SilentRequest + | SsoSilentRequest + | RedirectRequest + | PopupRequest + ): Promise { + this.logger.verbose("hydrateCache called"); + + // Account gets saved to browser storage regardless of native or not + const accountEntity = AccountEntity.createFromAccountInfo( + result.account, + result.cloudGraphHostName, + result.msGraphHost + ); + await this.workerStorage.setAccount( + accountEntity + ); + + /* + * if (result.fromNativeBroker) { + * this.logger.verbose( + * "Response was from native broker, storing in-memory" + * ); + * // Tokens from native broker are stored in-memory + * return this.nativeInternalStorage.hydrateCache(result, request); + * } else { + */ + return this.workerStorage.hydrateCache(result, request); + // } + } + + getBrowserStorage(): WorkerCacheManager { + return this.workerStorage; + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + getAccount(accountFilter: AccountFilter): AccountInfo | null { + return null; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + getAccountByHomeId(homeAccountId: string): AccountInfo | null { + return null; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + getAccountByLocalId(localAccountId: string): AccountInfo | null { + return null; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + getAccountByUsername(username: string): AccountInfo | null { + return null; + } + getAllAccounts(): AccountInfo[] { + return []; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + acquireTokenPopup(request: PopupRequest): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + acquireTokenRedirect(request: RedirectRequest): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return Promise.resolve(); + } + acquireTokenByCode( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + request: AuthorizationCodeRequest + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + acquireTokenNative( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + request: + | PopupRequest + | SilentRequest + | Partial< + Omit< + CommonAuthorizationUrlRequest, + | "responseMode" + | "earJwk" + | "codeChallenge" + | "codeChallengeMethod" + | "requestedClaimsHash" + | "platformBroker" + > + >, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + apiId: ApiId, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + accountId?: string | undefined + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + addEventCallback( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + callback: EventCallbackFunction, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + eventTypes?: Array + ): string | null { + return null; + } + removeEventCallback( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + callbackId: string + ): void {} + // eslint-disable-next-line @typescript-eslint/no-unused-vars + addPerformanceCallback(callback: PerformanceCallbackFunction): string { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return ""; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + removePerformanceCallback(callbackId: string): boolean { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return true; + } + enableAccountStorageEvents(): void { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } + disableAccountStorageEvents(): void { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } + + handleRedirectPromise( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + hash?: string | undefined + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + return Promise.resolve(null); + } + loginPopup( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + request?: PopupRequest | undefined + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + loginRedirect(request?: RedirectRequest | undefined): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + logout(logoutRequest?: EndSessionRequest | undefined): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + logoutRedirect( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + logoutRequest?: EndSessionRequest | undefined + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + logoutPopup( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + logoutRequest?: EndSessionPopupRequest | undefined + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + ssoSilent( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + request: Partial< + Omit< + CommonAuthorizationUrlRequest, + | "responseMode" + | "earJwk" + | "codeChallenge" + | "codeChallengeMethod" + | "requestedClaimsHash" + | "platformBroker" + > + > + ): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Promise; + } + getTokenCache(): ITokenCache { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as ITokenCache; + } + getLogger(): Logger { + return this.logger; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + setLogger(logger: Logger): void { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + setActiveAccount(account: AccountInfo | null): void { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } + getActiveAccount(): AccountInfo | null { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return null; + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + initializeWrapperLibrary(sku: WrapperSKU, version: string): void { + this.workerStorage.setWrapperMetadata(sku, version); + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars + setNavigationClient(navigationClient: INavigationClient): void { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } + getConfiguration(): BrowserConfiguration { + return this.config; + } + isBrowserEnv(): boolean { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return true; + } + getBrowserCrypto(): ICrypto { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as ICrypto; + } + getPerformanceClient(): IPerformanceClient { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as IPerformanceClient; + } + getRedirectResponse(): Map> { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + return {} as Map>; + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async clearCache(logoutRequest?: ClearCacheRequest): Promise { + blockAPICallsBeforeInitialize(this.initialized); + blockNonBrowserEnvironment(); + } +} diff --git a/lib/msal-browser/src/crypto/CryptoOps.ts b/lib/msal-browser/src/crypto/CryptoOps.ts index fc30dfd56d..c9c60ba8d7 100644 --- a/lib/msal-browser/src/crypto/CryptoOps.ts +++ b/lib/msal-browser/src/crypto/CryptoOps.ts @@ -20,6 +20,7 @@ import { } from "../encode/Base64Encode.js"; import { base64Decode } from "../encode/Base64Decode.js"; import * as BrowserCrypto from "./BrowserCrypto.js"; +import * as WorkerCrypto from "./WorkerCrypto.js"; import { createBrowserAuthError, BrowserAuthErrorCodes, @@ -32,13 +33,13 @@ export type CachedKeyPair = { requestMethod?: string; requestUri?: string; }; - /** * This class implements MSAL's crypto interface, which allows it to perform base64 encoding and decoding, generating cryptographically random GUIDs and * implementing Proof Key for Code Exchange specs for the OAuth Authorization Code Flow using PKCE (rfc here: https://tools.ietf.org/html/rfc7636). */ export class CryptoOps implements ICrypto { private logger: Logger; + private cryptoImpl: typeof BrowserCrypto | typeof WorkerCrypto; /** * CryptoOps can be used in contexts outside a PCA instance, @@ -56,8 +57,10 @@ export class CryptoOps implements ICrypto { skipValidateSubtleCrypto?: boolean ) { this.logger = logger; + this.cryptoImpl = (typeof window !== "undefined") ? BrowserCrypto : WorkerCrypto; + // Browser crypto needs to be validated first before any other classes can be set. - BrowserCrypto.validateCryptoAvailable( + this.cryptoImpl.validateCryptoAvailable( skipValidateSubtleCrypto ?? false ); this.cache = new AsyncMemoryStorage(this.logger); @@ -69,7 +72,7 @@ export class CryptoOps implements ICrypto { * @returns string (GUID) */ createNewGuid(): string { - return BrowserCrypto.createNewGuid(); + return this.cryptoImpl.createNewGuid(); } /** @@ -119,13 +122,13 @@ export class CryptoOps implements ICrypto { ); // Generate Keypair - const keyPair: CryptoKeyPair = await BrowserCrypto.generateKeyPair( + const keyPair: CryptoKeyPair = await this.cryptoImpl.generateKeyPair( CryptoOps.EXTRACTABLE, CryptoOps.POP_KEY_USAGES ); // Generate Thumbprint for Public Key - const publicKeyJwk: JsonWebKey = await BrowserCrypto.exportJwk( + const publicKeyJwk: JsonWebKey = await this.cryptoImpl.exportJwk( keyPair.publicKey ); @@ -140,12 +143,12 @@ export class CryptoOps implements ICrypto { const publicJwkHash = await this.hashString(publicJwkString); // Generate Thumbprint for Private Key - const privateKeyJwk: JsonWebKey = await BrowserCrypto.exportJwk( + const privateKeyJwk: JsonWebKey = await this.cryptoImpl.exportJwk( keyPair.privateKey ); // Re-import private key to make it unextractable const unextractablePrivateKey: CryptoKey = - await BrowserCrypto.importJwk(privateKeyJwk, false, ["sign"]); + await this.cryptoImpl.importJwk(privateKeyJwk, false, ["sign"]); // Store Keypair data in keystore await this.cache.setItem(publicJwkHash, { @@ -227,7 +230,7 @@ export class CryptoOps implements ICrypto { } // Get public key as JWK - const publicKeyJwk = await BrowserCrypto.exportJwk( + const publicKeyJwk = await this.cryptoImpl.exportJwk( cachedKeyPair.publicKey ); const publicKeyJwkString = getSortedObjectString(publicKeyJwk); @@ -254,7 +257,7 @@ export class CryptoOps implements ICrypto { // Sign token const encoder = new TextEncoder(); const tokenBuffer = encoder.encode(tokenString); - const signatureBuffer = await BrowserCrypto.sign( + const signatureBuffer = await this.cryptoImpl.sign( cachedKeyPair.privateKey, tokenBuffer ); @@ -276,7 +279,7 @@ export class CryptoOps implements ICrypto { * @param plainText */ async hashString(plainText: string): Promise { - return BrowserCrypto.hashString(plainText); + return this.cryptoImpl.hashString(plainText); } } diff --git a/lib/msal-browser/src/crypto/WorkerCrypto.ts b/lib/msal-browser/src/crypto/WorkerCrypto.ts new file mode 100644 index 0000000000..ef8f84d63b --- /dev/null +++ b/lib/msal-browser/src/crypto/WorkerCrypto.ts @@ -0,0 +1,434 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { + createBrowserAuthError, + BrowserAuthErrorCodes, +} from "../error/BrowserAuthError.js"; +import { + IPerformanceClient, + PerformanceEvents, +} from "@azure/msal-common/browser"; +import { KEY_FORMAT_JWK } from "../utils/BrowserConstants.js"; +import { base64Encode, urlEncodeArr } from "../encode/Base64Encode.js"; +import { base64Decode, base64DecToArr } from "../encode/Base64Decode.js"; + +/** + * This file defines functions used by the browser library to perform cryptography operations such as + * hashing and encoding. It also has helper functions to validate the availability of specific APIs. + */ + +/** + * See here for more info on RsaHashedKeyGenParams: https://developer.mozilla.org/en-US/docs/Web/API/RsaHashedKeyGenParams + */ +// Algorithms +const PKCS1_V15_KEYGEN_ALG = "RSASSA-PKCS1-v1_5"; +const AES_GCM = "AES-GCM"; +const HKDF = "HKDF"; +// SHA-256 hashing algorithm +const S256_HASH_ALG = "SHA-256"; +// MOD length for PoP tokens +const MODULUS_LENGTH = 2048; +// Public Exponent +const PUBLIC_EXPONENT: Uint8Array = new Uint8Array([0x01, 0x00, 0x01]); +// UUID hex digits +const UUID_CHARS = "0123456789abcdef"; +// Array to store UINT32 random value +const UINT32_ARR = new Uint32Array(1); + +// Key Format +const RAW = "raw"; +// Key Usages +const ENCRYPT = "encrypt"; +const DECRYPT = "decrypt"; +const DERIVE_KEY = "deriveKey"; + +// Suberror +const SUBTLE_SUBERROR = "crypto_subtle_undefined"; + +const keygenAlgorithmOptions: RsaHashedKeyGenParams = { + name: PKCS1_V15_KEYGEN_ALG, + hash: S256_HASH_ALG, + modulusLength: MODULUS_LENGTH, + publicExponent: PUBLIC_EXPONENT, +}; + +/** + * Check whether browser crypto is available. + */ +export function validateCryptoAvailable( + skipValidateSubtleCrypto: boolean +): void { + if (!self) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.nonBrowserEnvironment + ); + } + if (!self.crypto) { + throw createBrowserAuthError(BrowserAuthErrorCodes.cryptoNonExistent); + } + if (!skipValidateSubtleCrypto && !self.crypto.subtle) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.cryptoNonExistent, + SUBTLE_SUBERROR + ); + } +} + +/** + * Returns a sha-256 hash of the given dataString as an ArrayBuffer. + * @param dataString {string} data string + * @param performanceClient {?IPerformanceClient} + * @param correlationId {?string} correlation id + */ +export async function sha256Digest( + dataString: string, + performanceClient?: IPerformanceClient, + correlationId?: string +): Promise { + performanceClient?.addQueueMeasurement( + PerformanceEvents.Sha256Digest, + correlationId + ); + const encoder = new TextEncoder(); + const data = encoder.encode(dataString); + return self.crypto.subtle.digest( + S256_HASH_ALG, + data + ) as Promise; +} + +/** + * Populates buffer with cryptographically random values. + * @param dataBuffer + */ +export function getRandomValues(dataBuffer: Uint8Array): Uint8Array { + return self.crypto.getRandomValues(dataBuffer); +} + +/** + * Returns random Uint32 value. + * @returns {number} + */ +function getRandomUint32(): number { + self.crypto.getRandomValues(UINT32_ARR); + return UINT32_ARR[0]; +} + +/** + * Creates a UUID v7 from the current timestamp. + * Implementation relies on the system clock to guarantee increasing order of generated identifiers. + * @returns {number} + */ +export function createNewGuid(): string { + const currentTimestamp = Date.now(); + const baseRand = getRandomUint32() * 0x400 + (getRandomUint32() & 0x3ff); + + // Result byte array + const bytes = new Uint8Array(16); + // A 12-bit `rand_a` field value + const randA = Math.trunc(baseRand / 2 ** 30); + // The higher 30 bits of 62-bit `rand_b` field value + const randBHi = baseRand & (2 ** 30 - 1); + // The lower 32 bits of 62-bit `rand_b` field value + const randBLo = getRandomUint32(); + + bytes[0] = currentTimestamp / 2 ** 40; + bytes[1] = currentTimestamp / 2 ** 32; + bytes[2] = currentTimestamp / 2 ** 24; + bytes[3] = currentTimestamp / 2 ** 16; + bytes[4] = currentTimestamp / 2 ** 8; + bytes[5] = currentTimestamp; + bytes[6] = 0x70 | (randA >>> 8); + bytes[7] = randA; + bytes[8] = 0x80 | (randBHi >>> 24); + bytes[9] = randBHi >>> 16; + bytes[10] = randBHi >>> 8; + bytes[11] = randBHi; + bytes[12] = randBLo >>> 24; + bytes[13] = randBLo >>> 16; + bytes[14] = randBLo >>> 8; + bytes[15] = randBLo; + + let text = ""; + for (let i = 0; i < bytes.length; i++) { + text += UUID_CHARS.charAt(bytes[i] >>> 4); + text += UUID_CHARS.charAt(bytes[i] & 0xf); + if (i === 3 || i === 5 || i === 7 || i === 9) { + text += "-"; + } + } + return text; +} + +/** + * Generates a keypair based on current keygen algorithm config. + * @param extractable + * @param usages + */ +export async function generateKeyPair( + extractable: boolean, + usages: Array +): Promise { + return self.crypto.subtle.generateKey( + keygenAlgorithmOptions, + extractable, + usages + ) as Promise; +} + +/** + * Export key as Json Web Key (JWK) + * @param key + */ +export async function exportJwk(key: CryptoKey): Promise { + return self.crypto.subtle.exportKey( + KEY_FORMAT_JWK, + key + ) as Promise; +} + +/** + * Imports key as Json Web Key (JWK), can set extractable and usages. + * @param key + * @param extractable + * @param usages + */ +export async function importJwk( + key: JsonWebKey, + extractable: boolean, + usages: Array +): Promise { + return self.crypto.subtle.importKey( + KEY_FORMAT_JWK, + key, + keygenAlgorithmOptions, + extractable, + usages + ) as Promise; +} + +/** + * Signs given data with given key + * @param key + * @param data + */ +export async function sign( + key: CryptoKey, + data: ArrayBuffer +): Promise { + return self.crypto.subtle.sign( + keygenAlgorithmOptions, + key, + data + ) as Promise; +} + +/** + * Generates Base64 encoded jwk used in the Encrypted Authorize Response (EAR) flow + */ +export async function generateEarKey(): Promise { + const key = await generateBaseKey(); + const keyStr = urlEncodeArr(new Uint8Array(key)); + + const jwk = { + alg: "dir", + kty: "oct", + k: keyStr, + }; + + return base64Encode(JSON.stringify(jwk)); +} + +/** + * Parses earJwk for encryption key and returns CryptoKey object + * @param earJwk + * @returns + */ +export async function importEarKey(earJwk: string): Promise { + const b64DecodedJwk = base64Decode(earJwk); + const jwkJson = JSON.parse(b64DecodedJwk); + const rawKey = jwkJson.k; + const keyBuffer = base64DecToArr(rawKey); + + return self.crypto.subtle.importKey(RAW, keyBuffer, AES_GCM, false, [ + DECRYPT, + ]); +} + +/** + * Decrypt ear_jwe response returned in the Encrypted Authorize Response (EAR) flow + * @param earJwk + * @param earJwe + * @returns + */ +export async function decryptEarResponse( + earJwk: string, + earJwe: string +): Promise { + const earJweParts = earJwe.split("."); + if (earJweParts.length !== 5) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.failedToDecryptEarResponse, + "jwe_length" + ); + } + + const key = await importEarKey(earJwk).catch(() => { + throw createBrowserAuthError( + BrowserAuthErrorCodes.failedToDecryptEarResponse, + "import_key" + ); + }); + + try { + const header = new TextEncoder().encode(earJweParts[0]); + const iv = base64DecToArr(earJweParts[2]); + const ciphertext = base64DecToArr(earJweParts[3]); + const tag = base64DecToArr(earJweParts[4]); + const tagLengthBits = tag.byteLength * 8; + + // Concat ciphertext and tag + const encryptedData = new Uint8Array(ciphertext.length + tag.length); + encryptedData.set(ciphertext); + encryptedData.set(tag, ciphertext.length); + + const decryptedData = await self.crypto.subtle.decrypt( + { + name: AES_GCM, + iv: iv, + tagLength: tagLengthBits, + additionalData: header, + }, + key, + encryptedData + ); + + return new TextDecoder().decode(decryptedData); + } catch (e) { + throw createBrowserAuthError( + BrowserAuthErrorCodes.failedToDecryptEarResponse, + "decrypt" + ); + } +} + +/** + * Generates symmetric base encryption key. This may be stored as all encryption/decryption keys will be derived from this one. + */ +export async function generateBaseKey(): Promise { + const key = await self.crypto.subtle.generateKey( + { + name: AES_GCM, + length: 256, + }, + true, + [ENCRYPT, DECRYPT] + ); + return self.crypto.subtle.exportKey(RAW, key); +} + +/** + * Returns the raw key to be passed into the key derivation function + * @param baseKey + * @returns + */ +export async function generateHKDF(baseKey: ArrayBuffer): Promise { + return self.crypto.subtle.importKey(RAW, baseKey, HKDF, false, [ + DERIVE_KEY, + ]); +} + +/** + * Given a base key and a nonce generates a derived key to be used in encryption and decryption. + * Note: every time we encrypt a new key is derived + * @param baseKey + * @param nonce + * @returns + */ +async function deriveKey( + baseKey: CryptoKey, + nonce: ArrayBuffer, + context: string +): Promise { + return self.crypto.subtle.deriveKey( + { + name: HKDF, + salt: nonce, + hash: S256_HASH_ALG, + info: new TextEncoder().encode(context), + }, + baseKey, + { name: AES_GCM, length: 256 }, + false, + [ENCRYPT, DECRYPT] + ); +} + +/** + * Encrypt the given data given a base key. Returns encrypted data and a nonce that must be provided during decryption + * @param key + * @param rawData + */ +export async function encrypt( + baseKey: CryptoKey, + rawData: string, + context: string +): Promise<{ data: string; nonce: string }> { + const encodedData = new TextEncoder().encode(rawData); + // The nonce must never be reused with a given key. + const nonce = self.crypto.getRandomValues(new Uint8Array(16)); + const derivedKey = await deriveKey(baseKey, nonce, context); + const encryptedData = await self.crypto.subtle.encrypt( + { + name: AES_GCM, + iv: new Uint8Array(12), // New key is derived for every encrypt so we don't need a new nonce + }, + derivedKey, + encodedData + ); + + return { + data: urlEncodeArr(new Uint8Array(encryptedData)), + nonce: urlEncodeArr(nonce), + }; +} + +/** + * Decrypt data with the given key and nonce + * @param key + * @param nonce + * @param encryptedData + * @returns + */ +export async function decrypt( + baseKey: CryptoKey, + nonce: string, + context: string, + encryptedData: string +): Promise { + const encodedData = base64DecToArr(encryptedData); + const derivedKey = await deriveKey(baseKey, base64DecToArr(nonce), context); + const decryptedData = await self.crypto.subtle.decrypt( + { + name: AES_GCM, + iv: new Uint8Array(12), // New key is derived for every encrypt so we don't need a new nonce + }, + derivedKey, + encodedData + ); + + return new TextDecoder().decode(decryptedData); +} + +/** + * Returns the SHA-256 hash of an input string + * @param plainText + */ +export async function hashString(plainText: string): Promise { + const hashBuffer: ArrayBuffer = await sha256Digest(plainText); + const hashBytes = new Uint8Array(hashBuffer); + return urlEncodeArr(hashBytes); +} diff --git a/lib/msal-browser/src/interaction_client/BaseInteractionClient.ts b/lib/msal-browser/src/interaction_client/BaseInteractionClient.ts index b8ef74c55e..5220b46261 100644 --- a/lib/msal-browser/src/interaction_client/BaseInteractionClient.ts +++ b/lib/msal-browser/src/interaction_client/BaseInteractionClient.ts @@ -38,10 +38,11 @@ import { NativeMessageHandler } from "../broker/nativeBroker/NativeMessageHandle import { AuthenticationResult } from "../response/AuthenticationResult.js"; import { ClearCacheRequest } from "../request/ClearCacheRequest.js"; import { createNewGuid } from "../crypto/BrowserCrypto.js"; +import { WorkerCacheManager } from "../cache/WorkerCacheManager.js"; export abstract class BaseInteractionClient { protected config: BrowserConfiguration; - protected browserStorage: BrowserCacheManager; + protected browserStorage: BrowserCacheManager | WorkerCacheManager; protected browserCrypto: ICrypto; protected networkClient: INetworkModule; protected logger: Logger; @@ -53,7 +54,7 @@ export abstract class BaseInteractionClient { constructor( config: BrowserConfiguration, - storageImpl: BrowserCacheManager, + storageImpl: BrowserCacheManager | WorkerCacheManager, browserCrypto: ICrypto, logger: Logger, eventHandler: EventHandler, @@ -93,12 +94,12 @@ export abstract class BaseInteractionClient { if ( AccountEntity.accountInfoIsEqual( account, - this.browserStorage.getActiveAccount(), + await this.browserStorage.getActiveAccount(), false ) ) { this.logger.verbose("Setting active account to null"); - this.browserStorage.setActiveAccount(null); + await this.browserStorage.setActiveAccount(null); } // Clear given account. try { diff --git a/lib/msal-browser/src/operatingcontext/BaseOperatingContext.ts b/lib/msal-browser/src/operatingcontext/BaseOperatingContext.ts index 300259bc9a..0c89d978b9 100644 --- a/lib/msal-browser/src/operatingcontext/BaseOperatingContext.ts +++ b/lib/msal-browser/src/operatingcontext/BaseOperatingContext.ts @@ -28,6 +28,7 @@ export abstract class BaseOperatingContext { protected config: BrowserConfiguration; protected available: boolean; protected browserEnvironment: boolean; + protected workerEnvironment: boolean; protected static loggerCallback(level: LogLevel, message: string): void { switch (level) { @@ -61,7 +62,8 @@ export abstract class BaseOperatingContext { * This is to support server-side rendering environments. */ this.browserEnvironment = typeof window !== "undefined"; - this.config = buildConfiguration(config, this.browserEnvironment); + this.workerEnvironment = typeof WorkerGlobalScope !== "undefined" && self instanceof WorkerGlobalScope; + this.config = buildConfiguration(config, this.browserEnvironment, this.workerEnvironment); let sessionStorage: Storage | undefined; try { @@ -136,4 +138,8 @@ export abstract class BaseOperatingContext { isBrowserEnvironment(): boolean { return this.browserEnvironment; } + + isWorkerEnvironment(): boolean { + return this.workerEnvironment; + } } diff --git a/lib/msal-browser/src/operatingcontext/WorkerOperatingContext.ts b/lib/msal-browser/src/operatingcontext/WorkerOperatingContext.ts new file mode 100644 index 0000000000..d853aebbae --- /dev/null +++ b/lib/msal-browser/src/operatingcontext/WorkerOperatingContext.ts @@ -0,0 +1,46 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { BaseOperatingContext } from "./BaseOperatingContext.js"; + +export class WorkerOperatingContext extends BaseOperatingContext { + /* + * TODO: Once we have determine the bundling code return here to specify the name of the bundle + * containing the implementation for this operating context + */ + static readonly MODULE_NAME: string = ""; + + /** + * Unique identifier for the operating context + */ + static readonly ID: string = "WorkerOperatingContext"; + + /** + * Return the module name. Intended for use with import() to enable dynamic import + * of the implementation associated with this operating context + * @returns + */ + getModuleName(): string { + return WorkerOperatingContext.MODULE_NAME; + } + + /** + * Returns the unique identifier for this operating context + * @returns string + */ + getId(): string { + return WorkerOperatingContext.ID; + } + + /** + * Checks whether the operating context is available. + * Confirms that the code is running in a web worker. This is required. + * @returns Promise indicating whether this operating context is currently available. + */ + async initialize(): Promise { + this.available = typeof WorkerGlobalScope !== 'undefined' && self instanceof WorkerGlobalScope; + return this.available; + } +} diff --git a/lib/msal-browser/src/response/AuthenticationResult.ts b/lib/msal-browser/src/response/AuthenticationResult.ts index 71b5b1b51f..ba2276e14d 100644 --- a/lib/msal-browser/src/response/AuthenticationResult.ts +++ b/lib/msal-browser/src/response/AuthenticationResult.ts @@ -10,4 +10,5 @@ import { export type AuthenticationResult = CommonAuthenticationResult & { account: AccountInfo; + refreshToken?: string; }; diff --git a/lib/msal-browser/tsconfig.json b/lib/msal-browser/tsconfig.json index b0f6adbdd8..8f87d87307 100644 --- a/lib/msal-browser/tsconfig.json +++ b/lib/msal-browser/tsconfig.json @@ -7,7 +7,8 @@ "lib": [ "es2020", "dom", - "es2020.promise" + "es2020.promise", + "webworker" ], "outDir": "./dist", "allowUnusedLabels": false, diff --git a/lib/msal-common/src/authority/Authority.ts b/lib/msal-common/src/authority/Authority.ts index 8264bb3267..4665eaa906 100644 --- a/lib/msal-common/src/authority/Authority.ts +++ b/lib/msal-common/src/authority/Authority.ts @@ -58,6 +58,7 @@ import { IPerformanceClient } from "../telemetry/performance/IPerformanceClient. import { PerformanceEvents } from "../telemetry/performance/PerformanceEvent.js"; import { invokeAsync } from "../utils/FunctionWrappers.js"; import * as CacheHelpers from "../cache/utils/CacheHelpers.js"; +import { AsyncCacheManager } from "../exports-common.js"; /** * The authority class validates the authority URIs used by the user, and retrieves the OpenID Configuration Data from the @@ -72,7 +73,7 @@ export class Authority { // Network interface to make requests with. protected networkInterface: INetworkModule; // Cache Manager to cache network responses - protected cacheManager: ICacheManager; + protected cacheManager: ICacheManager | AsyncCacheManager; // Protocol mode to construct endpoints private authorityOptions: AuthorityOptions; // Authority metadata @@ -101,7 +102,7 @@ export class Authority { constructor( authority: string, networkInterface: INetworkModule, - cacheManager: ICacheManager, + cacheManager: ICacheManager | AsyncCacheManager, authorityOptions: AuthorityOptions, logger: Logger, correlationId: string, @@ -403,7 +404,7 @@ export class Authority { this.correlationId ); - const metadataEntity = this.getCurrentMetadataEntity(); + const metadataEntity = await this.getCurrentMetadataEntity(); const cloudDiscoverySource = await invokeAsync( this.updateCloudDiscoveryMetadata.bind(this), @@ -423,7 +424,7 @@ export class Authority { this.performanceClient, this.correlationId )(metadataEntity); - this.updateCachedMetadata(metadataEntity, cloudDiscoverySource, { + await this.updateCachedMetadata(metadataEntity, cloudDiscoverySource, { source: endpointSource, }); this.performanceClient?.addFields( @@ -440,9 +441,9 @@ export class Authority { * from the configured canonical authority * @returns */ - private getCurrentMetadataEntity(): AuthorityMetadataEntity { + private async getCurrentMetadataEntity(): Promise { let metadataEntity: AuthorityMetadataEntity | null = - this.cacheManager.getAuthorityMetadataByAlias(this.hostnameAndPort); + await this.cacheManager.getAuthorityMetadataByAlias(this.hostnameAndPort); if (!metadataEntity) { metadataEntity = { @@ -470,14 +471,14 @@ export class Authority { * @param cloudDiscoverySource * @param endpointMetadataResult */ - private updateCachedMetadata( + private async updateCachedMetadata( metadataEntity: AuthorityMetadataEntity, cloudDiscoverySource: AuthorityMetadataSource | null, endpointMetadataResult: { source: AuthorityMetadataSource; metadata?: OpenIdConfigResponse; } | null - ): void { + ): Promise { if ( cloudDiscoverySource !== AuthorityMetadataSource.CACHE && endpointMetadataResult?.source !== AuthorityMetadataSource.CACHE @@ -491,7 +492,7 @@ export class Authority { const cacheKey = this.cacheManager.generateAuthorityMetadataCacheKey( metadataEntity.preferred_cache ); - this.cacheManager.setAuthorityMetadata(cacheKey, metadataEntity); + await this.cacheManager.setAuthorityMetadata(cacheKey, metadataEntity); this.metadata = metadataEntity; } diff --git a/lib/msal-common/src/authority/AuthorityFactory.ts b/lib/msal-common/src/authority/AuthorityFactory.ts index abeb099e72..baef0b9ab2 100644 --- a/lib/msal-common/src/authority/AuthorityFactory.ts +++ b/lib/msal-common/src/authority/AuthorityFactory.ts @@ -15,6 +15,7 @@ import { Logger } from "../logger/Logger.js"; import { IPerformanceClient } from "../telemetry/performance/IPerformanceClient.js"; import { PerformanceEvents } from "../telemetry/performance/PerformanceEvent.js"; import { invokeAsync } from "../utils/FunctionWrappers.js"; +import { AsyncCacheManager } from "../exports-common.js"; /** * Create an authority object of the correct type based on the url @@ -30,7 +31,7 @@ import { invokeAsync } from "../utils/FunctionWrappers.js"; export async function createDiscoveredInstance( authorityUri: string, networkClient: INetworkModule, - cacheManager: ICacheManager, + cacheManager: ICacheManager | AsyncCacheManager, authorityOptions: AuthorityOptions, logger: Logger, correlationId: string, diff --git a/lib/msal-common/src/cache/AsyncCacheManager.ts b/lib/msal-common/src/cache/AsyncCacheManager.ts new file mode 100644 index 0000000000..f0043083ff --- /dev/null +++ b/lib/msal-common/src/cache/AsyncCacheManager.ts @@ -0,0 +1,1945 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { + AccountFilter, + CredentialFilter, + ValidCredentialType, + AppMetadataFilter, + AppMetadataCache, + TokenKeys, + TenantProfileFilter, +} from "./utils/CacheTypes.js"; +import { CacheRecord } from "./entities/CacheRecord.js"; +import { + CredentialType, + APP_METADATA, + THE_FAMILY_ID, + AUTHORITY_METADATA_CONSTANTS, + AuthenticationScheme, + Separators, +} from "../utils/Constants.js"; +import { CredentialEntity } from "./entities/CredentialEntity.js"; +import { generateCredentialKey } from "./utils/CacheHelpers.js"; +import { ScopeSet } from "../request/ScopeSet.js"; +import { AccountEntity } from "./entities/AccountEntity.js"; +import { AccessTokenEntity } from "./entities/AccessTokenEntity.js"; +import { IdTokenEntity } from "./entities/IdTokenEntity.js"; +import { RefreshTokenEntity } from "./entities/RefreshTokenEntity.js"; +import { + createClientAuthError, + ClientAuthErrorCodes, +} from "../error/ClientAuthError.js"; +import { + AccountInfo, + TenantProfile, + updateAccountTenantProfileData, +} from "../account/AccountInfo.js"; +import { AppMetadataEntity } from "./entities/AppMetadataEntity.js"; +import { ServerTelemetryEntity } from "./entities/ServerTelemetryEntity.js"; +import { ThrottlingEntity } from "./entities/ThrottlingEntity.js"; +import { extractTokenClaims } from "../account/AuthToken.js"; +import { ICrypto } from "../crypto/ICrypto.js"; +import { AuthorityMetadataEntity } from "./entities/AuthorityMetadataEntity.js"; +import { BaseAuthRequest } from "../request/BaseAuthRequest.js"; +import { Logger } from "../logger/Logger.js"; +import { name, version } from "../packageMetadata.js"; +import { StoreInCache } from "../request/StoreInCache.js"; +import { getAliasesFromStaticSources } from "../authority/AuthorityMetadata.js"; +import { StaticAuthorityOptions } from "../authority/AuthorityOptions.js"; +import { TokenClaims } from "../account/TokenClaims.js"; +import { IPerformanceClient } from "../telemetry/performance/IPerformanceClient.js"; +import { CacheError, CacheErrorCodes } from "../error/CacheError.js"; + +/** + * Interface class which implement cache storage functions used by MSAL to perform validity checks, and store tokens. + * @internal + */ +export abstract class AsyncCacheManager { + protected clientId: string; + protected cryptoImpl: ICrypto; + // Instance of logger for functions defined in the msal-common layer + private commonLogger: Logger; + private staticAuthorityOptions?: StaticAuthorityOptions; + + constructor( + clientId: string, + cryptoImpl: ICrypto, + logger: Logger, + staticAuthorityOptions?: StaticAuthorityOptions + ) { + this.clientId = clientId; + this.cryptoImpl = cryptoImpl; + this.commonLogger = logger.clone(name, version); + this.staticAuthorityOptions = staticAuthorityOptions; + } + + /** + * fetch the account entity from the platform cache + * @param accountKey + */ + abstract getAccount( + accountKey: string, + logger?: Logger + ): Promise; + + /** + * set account entity in the platform cache + * @param account + * @param correlationId + */ + abstract setAccount( + account: AccountEntity, + correlationId: string + ): Promise; + + /** + * fetch the idToken entity from the platform cache + * @param idTokenKey + */ + abstract getIdTokenCredential(idTokenKey: string): Promise; + + /** + * set idToken entity to the platform cache + * @param idToken + * @param correlationId + */ + abstract setIdTokenCredential( + idToken: IdTokenEntity, + correlationId: string + ): Promise; + + /** + * fetch the idToken entity from the platform cache + * @param accessTokenKey + */ + abstract getAccessTokenCredential( + accessTokenKey: string + ): Promise; + + /** + * set accessToken entity to the platform cache + * @param accessToken + * @param correlationId + */ + abstract setAccessTokenCredential( + accessToken: AccessTokenEntity, + correlationId: string + ): Promise; + + /** + * fetch the idToken entity from the platform cache + * @param refreshTokenKey + */ + abstract getRefreshTokenCredential( + refreshTokenKey: string + ): Promise; + + /** + * set refreshToken entity to the platform cache + * @param refreshToken + * @param correlationId + */ + abstract setRefreshTokenCredential( + refreshToken: RefreshTokenEntity, + correlationId: string + ): Promise; + + /** + * fetch appMetadata entity from the platform cache + * @param appMetadataKey + */ + abstract getAppMetadata(appMetadataKey: string): Promise; + + /** + * set appMetadata entity to the platform cache + * @param appMetadata + */ + abstract setAppMetadata(appMetadata: AppMetadataEntity): Promise; + + /** + * fetch server telemetry entity from the platform cache + * @param serverTelemetryKey + */ + abstract getServerTelemetry( + serverTelemetryKey: string + ): Promise; + + /** + * set server telemetry entity to the platform cache + * @param serverTelemetryKey + * @param serverTelemetry + */ + abstract setServerTelemetry( + serverTelemetryKey: string, + serverTelemetry: ServerTelemetryEntity + ): Promise; + + /** + * fetch cloud discovery metadata entity from the platform cache + * @param key + */ + abstract getAuthorityMetadata(key: string): Promise; + + /** + * + */ + abstract getAuthorityMetadataKeys(): Promise>; + + /** + * set cloud discovery metadata entity to the platform cache + * @param key + * @param value + */ + abstract setAuthorityMetadata( + key: string, + value: AuthorityMetadataEntity + ): Promise; + + /** + * fetch throttling entity from the platform cache + * @param throttlingCacheKey + */ + abstract getThrottlingCache( + throttlingCacheKey: string + ): Promise; + + /** + * set throttling entity to the platform cache + * @param throttlingCacheKey + * @param throttlingCache + */ + abstract setThrottlingCache( + throttlingCacheKey: string, + throttlingCache: ThrottlingEntity + ): Promise; + + /** + * Function to remove an item from cache given its key. + * @param key + */ + abstract removeItem(key: string): Promise; + + /** + * Function which retrieves all current keys from the cache. + */ + abstract getKeys(): Promise; + + /** + * Function which retrieves all account keys from the cache + */ + abstract getAccountKeys(): Promise; + + /** + * Function which retrieves all token keys from the cache + */ + abstract getTokenKeys(): Promise; + + /** + * Returns all the accounts in the cache that match the optional filter. If no filter is provided, all accounts are returned. + * @param accountFilter - (Optional) filter to narrow down the accounts returned + * @returns Array of AccountInfo objects in cache + */ + async getAllAccounts(accountFilter?: AccountFilter): Promise { + const filteredAccounts= await this.getAccountsFilteredBy(accountFilter || {}); + return this.buildTenantProfiles(filteredAccounts, accountFilter); + } + + /** + * Gets first tenanted AccountInfo object found based on provided filters + */ + async getAccountInfoFilteredBy(accountFilter: AccountFilter): Promise { + const allAccounts = await this.getAllAccounts(accountFilter); + if (allAccounts.length > 1) { + // If one or more accounts are found, prioritize accounts that have an ID token + const sortedAccounts = allAccounts.sort((account) => { + return account.idTokenClaims ? -1 : 1; + }); + return sortedAccounts[0]; + } else if (allAccounts.length === 1) { + // If only one account is found, return it regardless of whether a matching ID token was found + return allAccounts[0]; + } else { + return null; + } + } + + /** + * Returns a single matching + * @param accountFilter + * @returns + */ + async getBaseAccountInfo(accountFilter: AccountFilter): Promise { + const accountEntities = await this.getAccountsFilteredBy(accountFilter); + if (accountEntities.length > 0) { + return accountEntities[0].getAccountInfo(); + } else { + return null; + } + } + + /** + * Matches filtered account entities with cached ID tokens that match the tenant profile-specific account filters + * and builds the account info objects from the matching ID token's claims + * @param cachedAccounts + * @param accountFilter + * @returns Array of AccountInfo objects that match account and tenant profile filters + */ + private async buildTenantProfiles( + cachedAccounts: AccountEntity[], + accountFilter?: AccountFilter + ): Promise { + const profilePromises = cachedAccounts.map(accountEntity => + this.getTenantProfilesFromAccountEntity( + accountEntity, + accountFilter?.tenantId, + accountFilter + ) + ); + const profilesArrays = await Promise.all(profilePromises); + return profilesArrays.flat(); + } + + private async getTenantedAccountInfoByFilter( + accountInfo: AccountInfo, + tokenKeys: TokenKeys, + tenantProfile: TenantProfile, + tenantProfileFilter?: TenantProfileFilter + ): Promise { + let tenantedAccountInfo: AccountInfo | null = null; + let idTokenClaims: TokenClaims | undefined; + + if (tenantProfileFilter) { + if ( + !this.tenantProfileMatchesFilter( + tenantProfile, + tenantProfileFilter + ) + ) { + return null; + } + } + + const idToken = await this.getIdToken( + accountInfo, + tokenKeys, + tenantProfile.tenantId + ); + + if (idToken) { + idTokenClaims = extractTokenClaims( + idToken.secret, + this.cryptoImpl.base64Decode + ); + + if ( + !this.idTokenClaimsMatchTenantProfileFilter( + idTokenClaims, + tenantProfileFilter + ) + ) { + // ID token sourced claims don't match so this tenant profile is not a match + return null; + } + } + + // Expand tenant profile into account info based on matching tenant profile and if available matching ID token claims + tenantedAccountInfo = updateAccountTenantProfileData( + accountInfo, + tenantProfile, + idTokenClaims, + idToken?.secret + ); + + return tenantedAccountInfo; + } + + private async getTenantProfilesFromAccountEntity( + accountEntity: AccountEntity, + targetTenantId?: string, + tenantProfileFilter?: TenantProfileFilter + ): Promise { + const accountInfo = accountEntity.getAccountInfo(); + let searchTenantProfiles: Map = + accountInfo.tenantProfiles || new Map(); + const tokenKeys = await this.getTokenKeys(); + + // If a tenant ID was provided, only return the tenant profile for that tenant ID if it exists + if (targetTenantId) { + const tenantProfile = searchTenantProfiles.get(targetTenantId); + if (tenantProfile) { + // Reduce search field to just this tenant profile + searchTenantProfiles = new Map([ + [targetTenantId, tenantProfile], + ]); + } else { + // No tenant profile for search tenant ID, return empty array + return []; + } + } + + const tenantProfilePromises = Array.from(searchTenantProfiles.values()).map( + async (tenantProfile: TenantProfile) => { + return this.getTenantedAccountInfoByFilter( + accountInfo, + tokenKeys, + tenantProfile, + tenantProfileFilter + ); + } + ); + + const results = await Promise.all(tenantProfilePromises); + return results.filter((tenantedAccountInfo): tenantedAccountInfo is AccountInfo => !!tenantedAccountInfo); + } + + private tenantProfileMatchesFilter( + tenantProfile: TenantProfile, + tenantProfileFilter: TenantProfileFilter + ): boolean { + if ( + !!tenantProfileFilter.localAccountId && + !this.matchLocalAccountIdFromTenantProfile( + tenantProfile, + tenantProfileFilter.localAccountId + ) + ) { + return false; + } + + if ( + !!tenantProfileFilter.name && + !(tenantProfile.name === tenantProfileFilter.name) + ) { + return false; + } + + if ( + tenantProfileFilter.isHomeTenant !== undefined && + !(tenantProfile.isHomeTenant === tenantProfileFilter.isHomeTenant) + ) { + return false; + } + + return true; + } + + private idTokenClaimsMatchTenantProfileFilter( + idTokenClaims: TokenClaims, + tenantProfileFilter?: TenantProfileFilter + ): boolean { + // Tenant Profile filtering + if (tenantProfileFilter) { + if ( + !!tenantProfileFilter.localAccountId && + !this.matchLocalAccountIdFromTokenClaims( + idTokenClaims, + tenantProfileFilter.localAccountId + ) + ) { + return false; + } + + if ( + !!tenantProfileFilter.loginHint && + !this.matchLoginHintFromTokenClaims( + idTokenClaims, + tenantProfileFilter.loginHint + ) + ) { + return false; + } + + if ( + !!tenantProfileFilter.username && + !this.matchUsername( + idTokenClaims.preferred_username, + tenantProfileFilter.username + ) + ) { + return false; + } + + if ( + !!tenantProfileFilter.name && + !this.matchName(idTokenClaims, tenantProfileFilter.name) + ) { + return false; + } + + if ( + !!tenantProfileFilter.sid && + !this.matchSid(idTokenClaims, tenantProfileFilter.sid) + ) { + return false; + } + } + + return true; + } + + /** + * saves a cache record + * @param cacheRecord {CacheRecord} + * @param storeInCache {?StoreInCache} + * @param correlationId {?string} correlation id + */ + async saveCacheRecord( + cacheRecord: CacheRecord, + correlationId: string, + storeInCache?: StoreInCache + ): Promise { + if (!cacheRecord) { + throw createClientAuthError( + ClientAuthErrorCodes.invalidCacheRecord + ); + } + + try { + if (!!cacheRecord.account) { + await this.setAccount(cacheRecord.account, correlationId); + } + + if (!!cacheRecord.idToken && storeInCache?.idToken !== false) { + await this.setIdTokenCredential( + cacheRecord.idToken, + correlationId + ); + } + + if ( + !!cacheRecord.accessToken && + storeInCache?.accessToken !== false + ) { + await this.saveAccessToken( + cacheRecord.accessToken, + correlationId + ); + } + + if ( + !!cacheRecord.refreshToken && + storeInCache?.refreshToken !== false + ) { + await this.setRefreshTokenCredential( + cacheRecord.refreshToken, + correlationId + ); + } + + if (!!cacheRecord.appMetadata) { + await this.setAppMetadata(cacheRecord.appMetadata); + } + } catch (e: unknown) { + this.commonLogger?.error(`CacheManager.saveCacheRecord: failed`); + if (e instanceof Error) { + this.commonLogger?.errorPii( + `CacheManager.saveCacheRecord: ${e.message}`, + correlationId + ); + + if ( + e.name === "QuotaExceededError" || + e.name === "NS_ERROR_DOM_QUOTA_REACHED" || + e.message.includes("exceeded the quota") + ) { + this.commonLogger?.error( + `CacheManager.saveCacheRecord: exceeded storage quota`, + correlationId + ); + throw new CacheError( + CacheErrorCodes.cacheQuotaExceededErrorCode + ); + } else { + throw new CacheError(e.name, e.message); + } + } else { + this.commonLogger?.errorPii( + `CacheManager.saveCacheRecord: ${e}`, + correlationId + ); + throw new CacheError(CacheErrorCodes.cacheUnknownErrorCode); + } + } + } + + /** + * saves access token credential + * @param credential + */ + private async saveAccessToken( + credential: AccessTokenEntity, + correlationId: string + ): Promise { + const accessTokenFilter: CredentialFilter = { + clientId: credential.clientId, + credentialType: credential.credentialType, + environment: credential.environment, + homeAccountId: credential.homeAccountId, + realm: credential.realm, + tokenType: credential.tokenType, + requestedClaimsHash: credential.requestedClaimsHash, + }; + + const tokenKeys = await this.getTokenKeys(); + const currentScopes = ScopeSet.fromString(credential.target); + + const removedAccessTokens: Array> = []; + tokenKeys.accessToken.forEach(async (key) => { + if ( + !this.accessTokenKeyMatchesFilter(key, accessTokenFilter, false) + ) { + return; + } + + const tokenEntity = await this.getAccessTokenCredential(key); + + if ( + tokenEntity && + this.credentialMatchesFilter(tokenEntity, accessTokenFilter) + ) { + const tokenScopeSet = ScopeSet.fromString(tokenEntity.target); + if (tokenScopeSet.intersectingScopeSets(currentScopes)) { + removedAccessTokens.push(this.removeAccessToken(key)); + } + } + }); + await Promise.all(removedAccessTokens); + await this.setAccessTokenCredential(credential, correlationId); + } + + /** + * Retrieve account entities matching all provided tenant-agnostic filters; if no filter is set, get all account entities in the cache + * Not checking for casing as keys are all generated in lower case, remember to convert to lower case if object properties are compared + * @param accountFilter - An object containing Account properties to filter by + */ + async getAccountsFilteredBy(accountFilter: AccountFilter): Promise { + const allAccountKeys = await this.getAccountKeys(); + const accountPromises = allAccountKeys.map(async (cacheKey) => { + if (!this.isAccountKey(cacheKey, accountFilter.homeAccountId)) { + // Don't parse value if the key doesn't match the account filters + return null; + } + + const entity: AccountEntity | null = await this.getAccount( + cacheKey, + this.commonLogger + ); + + // Match base account fields + + if (!entity) { + return null; + } + + if ( + !!accountFilter.homeAccountId && + !this.matchHomeAccountId(entity, accountFilter.homeAccountId) + ) { + return null; + } + + if ( + !!accountFilter.username && + !this.matchUsername(entity.username, accountFilter.username) + ) { + return null; + } + + if ( + !!accountFilter.environment && + !(await this.matchEnvironment(entity, accountFilter.environment)) + ) { + return null; + } + + if ( + !!accountFilter.realm && + !this.matchRealm(entity, accountFilter.realm) + ) { + return null; + } + + if ( + !!accountFilter.nativeAccountId && + !this.matchNativeAccountId( + entity, + accountFilter.nativeAccountId + ) + ) { + return null; + } + + if ( + !!accountFilter.authorityType && + !this.matchAuthorityType(entity, accountFilter.authorityType) + ) { + return null; + } + + // If at least one tenant profile matches the tenant profile filter, add the account to the list of matching accounts + const tenantProfileFilter: TenantProfileFilter = { + localAccountId: accountFilter?.localAccountId, + name: accountFilter?.name, + }; + + const matchingTenantProfiles = entity.tenantProfiles?.filter( + (tenantProfile: TenantProfile) => { + return this.tenantProfileMatchesFilter( + tenantProfile, + tenantProfileFilter + ); + } + ); + + if (matchingTenantProfiles && matchingTenantProfiles.length === 0) { + // No tenant profile for this account matches filter, don't add to list of matching accounts + return null; + } + + return entity; + }); + + const results = await Promise.all(accountPromises); + return results.filter((entity): entity is AccountEntity => entity !== null); + } + + /** + * Returns true if the given key matches our account key schema. Also matches homeAccountId and/or tenantId if provided + * @param key + * @param homeAccountId + * @param tenantId + * @returns + */ + isAccountKey( + key: string, + homeAccountId?: string, + tenantId?: string + ): boolean { + if (key.split(Separators.CACHE_KEY_SEPARATOR).length < 3) { + // Account cache keys contain 3 items separated by '-' (each item may also contain '-') + return false; + } + + if ( + homeAccountId && + !key.toLowerCase().includes(homeAccountId.toLowerCase()) + ) { + return false; + } + + if (tenantId && !key.toLowerCase().includes(tenantId.toLowerCase())) { + return false; + } + + // Do not check environment as aliasing can cause false negatives + + return true; + } + + /** + * Returns true if the given key matches our credential key schema. + * @param key + */ + isCredentialKey(key: string): boolean { + if (key.split(Separators.CACHE_KEY_SEPARATOR).length < 6) { + // Credential cache keys contain 6 items separated by '-' (each item may also contain '-') + return false; + } + + const lowerCaseKey = key.toLowerCase(); + // Credential keys must indicate what credential type they represent + if ( + lowerCaseKey.indexOf(CredentialType.ID_TOKEN.toLowerCase()) === + -1 && + lowerCaseKey.indexOf(CredentialType.ACCESS_TOKEN.toLowerCase()) === + -1 && + lowerCaseKey.indexOf( + CredentialType.ACCESS_TOKEN_WITH_AUTH_SCHEME.toLowerCase() + ) === -1 && + lowerCaseKey.indexOf(CredentialType.REFRESH_TOKEN.toLowerCase()) === + -1 + ) { + return false; + } + + if ( + lowerCaseKey.indexOf(CredentialType.REFRESH_TOKEN.toLowerCase()) > + -1 + ) { + // Refresh tokens must contain the client id or family id + const clientIdValidation = `${CredentialType.REFRESH_TOKEN}${Separators.CACHE_KEY_SEPARATOR}${this.clientId}${Separators.CACHE_KEY_SEPARATOR}`; + const familyIdValidation = `${CredentialType.REFRESH_TOKEN}${Separators.CACHE_KEY_SEPARATOR}${THE_FAMILY_ID}${Separators.CACHE_KEY_SEPARATOR}`; + if ( + lowerCaseKey.indexOf(clientIdValidation.toLowerCase()) === -1 && + lowerCaseKey.indexOf(familyIdValidation.toLowerCase()) === -1 + ) { + return false; + } + } else if (lowerCaseKey.indexOf(this.clientId.toLowerCase()) === -1) { + // Tokens must contain the clientId + return false; + } + + return true; + } + + /** + * Returns whether or not the given credential entity matches the filter + * @param entity + * @param filter + * @returns + */ + credentialMatchesFilter( + entity: ValidCredentialType, + filter: CredentialFilter + ): boolean { + if (!!filter.clientId && !this.matchClientId(entity, filter.clientId)) { + return false; + } + + if ( + !!filter.userAssertionHash && + !this.matchUserAssertionHash(entity, filter.userAssertionHash) + ) { + return false; + } + + /* + * homeAccountId can be undefined, and we want to filter out cached items that have a homeAccountId of "" + * because we don't want a client_credential request to return a cached token that has a homeAccountId + */ + if ( + typeof filter.homeAccountId === "string" && + !this.matchHomeAccountId(entity, filter.homeAccountId) + ) { + return false; + } + + if ( + !!filter.environment && + !this.matchEnvironment(entity, filter.environment) + ) { + return false; + } + + if (!!filter.realm && !this.matchRealm(entity, filter.realm)) { + return false; + } + + if ( + !!filter.credentialType && + !this.matchCredentialType(entity, filter.credentialType) + ) { + return false; + } + + if (!!filter.familyId && !this.matchFamilyId(entity, filter.familyId)) { + return false; + } + + /* + * idTokens do not have "target", target specific refreshTokens do exist for some types of authentication + * Resource specific refresh tokens case will be added when the support is deemed necessary + */ + if (!!filter.target && !this.matchTarget(entity, filter.target)) { + return false; + } + + // If request OR cached entity has requested Claims Hash, check if they match + if (filter.requestedClaimsHash || entity.requestedClaimsHash) { + // Don't match if either is undefined or they are different + if (entity.requestedClaimsHash !== filter.requestedClaimsHash) { + return false; + } + } + + // Access Token with Auth Scheme specific matching + if ( + entity.credentialType === + CredentialType.ACCESS_TOKEN_WITH_AUTH_SCHEME + ) { + if ( + !!filter.tokenType && + !this.matchTokenType(entity, filter.tokenType) + ) { + return false; + } + + // KeyId (sshKid) in request must match cached SSH certificate keyId because SSH cert is bound to a specific key + if (filter.tokenType === AuthenticationScheme.SSH) { + if (filter.keyId && !this.matchKeyId(entity, filter.keyId)) { + return false; + } + } + } + + return true; + } + + /** + * retrieve appMetadata matching all provided filters; if no filter is set, get all appMetadata + * @param filter + */ + async getAppMetadataFilteredBy(filter: AppMetadataFilter): Promise { + const allCacheKeys = await this.getKeys(); + const matchingAppMetadata: AppMetadataCache = {}; + + allCacheKeys.forEach(async (cacheKey) => { + // don't parse any non-appMetadata type cache entities + if (!this.isAppMetadata(cacheKey)) { + return; + } + + // Attempt retrieval + const entity = await this.getAppMetadata(cacheKey); + + if (!entity) { + return; + } + + if ( + !!filter.environment && + !this.matchEnvironment(entity, filter.environment) + ) { + return; + } + + if ( + !!filter.clientId && + !this.matchClientId(entity, filter.clientId) + ) { + return; + } + + matchingAppMetadata[cacheKey] = entity; + }); + + return matchingAppMetadata; + } + + /** + * retrieve authorityMetadata that contains a matching alias + * @param filter + */ + async getAuthorityMetadataByAlias(host: string): Promise { + const allCacheKeys = await this.getAuthorityMetadataKeys(); + let matchedEntity = null; + + allCacheKeys.forEach(async (cacheKey) => { + // don't parse any non-authorityMetadata type cache entities + if ( + !this.isAuthorityMetadata(cacheKey) || + cacheKey.indexOf(this.clientId) === -1 + ) { + return; + } + + // Attempt retrieval + const entity = await this.getAuthorityMetadata(cacheKey); + + if (!entity) { + return; + } + + if (entity.aliases.indexOf(host) === -1) { + return; + } + + matchedEntity = entity; + }); + + return matchedEntity; + } + + /** + * Removes all accounts and related tokens from cache. + */ + async removeAllAccounts(): Promise { + const allAccountKeys = await this.getAccountKeys(); + const removedAccounts: Array> = []; + + allAccountKeys.forEach((cacheKey) => { + removedAccounts.push(this.removeAccount(cacheKey)); + }); + + await Promise.all(removedAccounts); + } + + /** + * Removes the account and related tokens for a given account key + * @param account + */ + async removeAccount(accountKey: string): Promise { + const account = await this.getAccount(accountKey, this.commonLogger); + if (!account) { + return; + } + await this.removeAccountContext(account); + await this.removeItem(accountKey); + } + + /** + * Removes credentials associated with the provided account + * @param account + */ + async removeAccountContext(account: AccountEntity): Promise { + const allTokenKeys = await this.getTokenKeys(); + const accountId = account.generateAccountId(); + const removedCredentials: Array> = []; + + allTokenKeys.idToken.forEach(async (key) => { + if (key.indexOf(accountId) === 0) { + await this.removeIdToken(key); + } + }); + + allTokenKeys.accessToken.forEach(async (key) => { + if (key.indexOf(accountId) === 0) { + removedCredentials.push(this.removeAccessToken(key)); + } + }); + + allTokenKeys.refreshToken.forEach(async (key) => { + if (key.indexOf(accountId) === 0) { + await this.removeRefreshToken(key); + } + }); + + await Promise.all(removedCredentials); + } + + /** + * returns a boolean if the given credential is removed + * @param credential + */ + async removeAccessToken(key: string): Promise { + const credential = await this.getAccessTokenCredential(key); + if (!credential) { + return; + } + + // Remove Token Binding Key from key store for PoP Tokens Credentials + if ( + credential.credentialType.toLowerCase() === + CredentialType.ACCESS_TOKEN_WITH_AUTH_SCHEME.toLowerCase() + ) { + if (credential.tokenType === AuthenticationScheme.POP) { + const accessTokenWithAuthSchemeEntity = + credential as AccessTokenEntity; + const kid = accessTokenWithAuthSchemeEntity.keyId; + + if (kid) { + try { + await this.cryptoImpl.removeTokenBindingKey(kid); + } catch (error) { + throw createClientAuthError( + ClientAuthErrorCodes.bindingKeyNotRemoved + ); + } + } + } + } + + return this.removeItem(key); + } + + /** + * Removes all app metadata objects from cache. + */ + async removeAppMetadata(): Promise { + const allCacheKeys = await this.getKeys(); + allCacheKeys.forEach(async (cacheKey) => { + if (this.isAppMetadata(cacheKey)) { + await this.removeItem(cacheKey); + } + }); + + return true; + } + + /** + * Retrieve AccountEntity from cache + * @param account + */ + async readAccountFromCache(account: AccountInfo): Promise { + const accountKey: string = + AccountEntity.generateAccountCacheKey(account); + return this.getAccount(accountKey, this.commonLogger); + } + + /** + * Retrieve IdTokenEntity from cache + * @param account {AccountInfo} + * @param tokenKeys {?TokenKeys} + * @param targetRealm {?string} + * @param performanceClient {?IPerformanceClient} + * @param correlationId {?string} + */ + async getIdToken( + account: AccountInfo, + tokenKeys?: TokenKeys, + targetRealm?: string, + performanceClient?: IPerformanceClient, + correlationId?: string + ): Promise { + this.commonLogger.trace("AsyncCacheManager - getIdToken called"); + const idTokenFilter: CredentialFilter = { + homeAccountId: account.homeAccountId, + environment: account.environment, + credentialType: CredentialType.ID_TOKEN, + clientId: this.clientId, + realm: targetRealm, + }; + + const idTokenMap: Map = await this.getIdTokensByFilter( + idTokenFilter, + tokenKeys + ); + + const numIdTokens = idTokenMap.size; + + if (numIdTokens < 1) { + this.commonLogger.info("CacheManager:getIdToken - No token found"); + return null; + } else if (numIdTokens > 1) { + let tokensToBeRemoved: Map = idTokenMap; + // Multiple tenant profiles and no tenant specified, pick home account + if (!targetRealm) { + const homeIdTokenMap: Map = new Map< + string, + IdTokenEntity + >(); + idTokenMap.forEach((idToken, key) => { + if (idToken.realm === account.tenantId) { + homeIdTokenMap.set(key, idToken); + } + }); + const numHomeIdTokens = homeIdTokenMap.size; + if (numHomeIdTokens < 1) { + this.commonLogger.info( + "CacheManager:getIdToken - Multiple ID tokens found for account but none match account entity tenant id, returning first result" + ); + return idTokenMap.values().next().value; + } else if (numHomeIdTokens === 1) { + this.commonLogger.info( + "CacheManager:getIdToken - Multiple ID tokens found for account, defaulting to home tenant profile" + ); + return homeIdTokenMap.values().next().value; + } else { + // Multiple ID tokens for home tenant profile, remove all and return null + tokensToBeRemoved = homeIdTokenMap; + } + } + // Multiple tokens for a single tenant profile, remove all and return null + this.commonLogger.info( + "CacheManager:getIdToken - Multiple matching ID tokens found, clearing them" + ); + tokensToBeRemoved.forEach(async (idToken, key) => { + await this.removeIdToken(key); + }); + if (performanceClient && correlationId) { + performanceClient.addFields( + { multiMatchedID: idTokenMap.size }, + correlationId + ); + } + return null; + } + + this.commonLogger.info("CacheManager:getIdToken - Returning ID token"); + return idTokenMap.values().next().value; + } + + /** + * Gets all idTokens matching the given filter + * @param filter + * @returns + */ + async getIdTokensByFilter( + filter: CredentialFilter, + tokenKeys?: TokenKeys + ): Promise> { + const idTokenKeys = + (tokenKeys && tokenKeys.idToken) || (await this.getTokenKeys()).idToken; + + const idTokens: Map = new Map(); + + const idTokenPromises = idTokenKeys.map(async (key) => { + if ( + !this.idTokenKeyMatchesFilter(key, { + clientId: this.clientId, + ...filter, + }) + ) { + return null; + } + const idToken = await this.getIdTokenCredential(key); + if (idToken && this.credentialMatchesFilter(idToken, filter)) { + return { key, idToken }; + } + return null; + }); + + const results = await Promise.all(idTokenPromises); + results.forEach((result) => { + if (result) { + idTokens.set(result.key, result.idToken); + } + }); + + return idTokens; + } + + /** + * Validate the cache key against filter before retrieving and parsing cache value + * @param key + * @param filter + * @returns + */ + idTokenKeyMatchesFilter( + inputKey: string, + filter: CredentialFilter + ): boolean { + const key = inputKey.toLowerCase(); + if ( + filter.clientId && + key.indexOf(filter.clientId.toLowerCase()) === -1 + ) { + return false; + } + + if ( + filter.homeAccountId && + key.indexOf(filter.homeAccountId.toLowerCase()) === -1 + ) { + return false; + } + + return true; + } + + /** + * Removes idToken from the cache + * @param key + */ + async removeIdToken(key: string): Promise { + await this.removeItem(key); + } + + /** + * Removes refresh token from the cache + * @param key + */ + async removeRefreshToken(key: string): Promise { + await this.removeItem(key); + } + + /** + * Retrieve AccessTokenEntity from cache + * @param account {AccountInfo} + * @param request {BaseAuthRequest} + * @param tokenKeys {?TokenKeys} + * @param performanceClient {?IPerformanceClient} + * @param correlationId {?string} + */ + async getAccessToken( + account: AccountInfo, + request: BaseAuthRequest, + tokenKeys?: TokenKeys, + targetRealm?: string, + performanceClient?: IPerformanceClient, + correlationId?: string + ): Promise { + this.commonLogger.trace("AsyncCacheManager - getAccessToken called"); + const scopes = ScopeSet.createSearchScopes(request.scopes); + const authScheme = + request.authenticationScheme || AuthenticationScheme.BEARER; + /* + * Distinguish between Bearer and PoP/SSH token cache types + * Cast to lowercase to handle "bearer" from ADFS + */ + const credentialType = + authScheme && + authScheme.toLowerCase() !== + AuthenticationScheme.BEARER.toLowerCase() + ? CredentialType.ACCESS_TOKEN_WITH_AUTH_SCHEME + : CredentialType.ACCESS_TOKEN; + + const accessTokenFilter: CredentialFilter = { + homeAccountId: account.homeAccountId, + environment: account.environment, + credentialType: credentialType, + clientId: this.clientId, + realm: targetRealm || account.tenantId, + target: scopes, + tokenType: authScheme, + keyId: request.sshKid, + requestedClaimsHash: request.requestedClaimsHash, + }; + + const accessTokenKeys = + (tokenKeys && tokenKeys.accessToken) || + (await this.getTokenKeys()).accessToken; + const accessTokens: AccessTokenEntity[] = []; + + const accessTokenPromises = accessTokenKeys.map(async (key) => { + // Validate key + if ( + this.accessTokenKeyMatchesFilter(key, accessTokenFilter, true) + ) { + const accessToken = await this.getAccessTokenCredential(key); + + // Validate value + if ( + accessToken && + this.credentialMatchesFilter(accessToken, accessTokenFilter) + ) { + return accessToken; + } + } + return null; + }); + + const accessTokenResults = await Promise.all(accessTokenPromises); + accessTokenResults.forEach((accessToken) => { + if (accessToken) { + accessTokens.push(accessToken); + } + }); + + const numAccessTokens = accessTokens.length; + if (numAccessTokens < 1) { + this.commonLogger.info( + "CacheManager:getAccessToken - No token found" + ); + return null; + } else if (numAccessTokens > 1) { + this.commonLogger.info( + "CacheManager:getAccessToken - Multiple access tokens found, clearing them" + ); + accessTokens.forEach((accessToken) => { + void this.removeAccessToken(generateCredentialKey(accessToken)); + }); + if (performanceClient && correlationId) { + performanceClient.addFields( + { multiMatchedAT: accessTokens.length }, + correlationId + ); + } + return null; + } + + this.commonLogger.info( + "CacheManager:getAccessToken - Returning access token" + ); + return accessTokens[0]; + } + + /** + * Validate the cache key against filter before retrieving and parsing cache value + * @param key + * @param filter + * @param keyMustContainAllScopes + * @returns + */ + accessTokenKeyMatchesFilter( + inputKey: string, + filter: CredentialFilter, + keyMustContainAllScopes: boolean + ): boolean { + const key = inputKey.toLowerCase(); + if ( + filter.clientId && + key.indexOf(filter.clientId.toLowerCase()) === -1 + ) { + return false; + } + + if ( + filter.homeAccountId && + key.indexOf(filter.homeAccountId.toLowerCase()) === -1 + ) { + return false; + } + + if (filter.realm && key.indexOf(filter.realm.toLowerCase()) === -1) { + return false; + } + + if ( + filter.requestedClaimsHash && + key.indexOf(filter.requestedClaimsHash.toLowerCase()) === -1 + ) { + return false; + } + + if (filter.target) { + const scopes = filter.target.asArray(); + for (let i = 0; i < scopes.length; i++) { + if ( + keyMustContainAllScopes && + !key.includes(scopes[i].toLowerCase()) + ) { + // When performing a cache lookup a missing scope would be a cache miss + return false; + } else if ( + !keyMustContainAllScopes && + key.includes(scopes[i].toLowerCase()) + ) { + // When performing a cache write, any token with a subset of requested scopes should be replaced + return true; + } + } + } + + return true; + } + + /** + * Gets all access tokens matching the filter + * @param filter + * @returns + */ + async getAccessTokensByFilter(filter: CredentialFilter): Promise { + const tokenKeys = await this.getTokenKeys(); + + const accessTokens: AccessTokenEntity[] = []; + tokenKeys.accessToken.forEach(async (key) => { + if (!this.accessTokenKeyMatchesFilter(key, filter, true)) { + return; + } + + const accessToken = await this.getAccessTokenCredential(key); + if ( + accessToken && + this.credentialMatchesFilter(accessToken, filter) + ) { + accessTokens.push(accessToken); + } + }); + + return accessTokens; + } + + /** + * Helper to retrieve the appropriate refresh token from cache + * @param account {AccountInfo} + * @param familyRT {boolean} + * @param tokenKeys {?TokenKeys} + * @param performanceClient {?IPerformanceClient} + * @param correlationId {?string} + */ + async getRefreshToken( + account: AccountInfo, + familyRT: boolean, + tokenKeys?: TokenKeys, + performanceClient?: IPerformanceClient, + correlationId?: string + ): Promise { + this.commonLogger.trace("AsyncCacheManager - getRefreshToken called"); + const id = familyRT ? THE_FAMILY_ID : undefined; + const refreshTokenFilter: CredentialFilter = { + homeAccountId: account.homeAccountId, + environment: account.environment, + credentialType: CredentialType.REFRESH_TOKEN, + clientId: this.clientId, + familyId: id, + }; + + let refreshTokenKeys: string[] = []; + if(tokenKeys && tokenKeys.refreshToken) { + refreshTokenKeys = tokenKeys.refreshToken; + } else { + refreshTokenKeys = await this.getTokenKeys().then((keys) => keys.refreshToken); + } + + const refreshTokenPromises = refreshTokenKeys.map(async (key) => { + // Validate key + if (this.refreshTokenKeyMatchesFilter(key, refreshTokenFilter)) { + const refreshToken = await this.getRefreshTokenCredential(key); + // Validate value + if ( + refreshToken && + this.credentialMatchesFilter( + refreshToken, + refreshTokenFilter + ) + ) { + return refreshToken; + } + } + return null; + }); + + const refreshTokenResults = await Promise.all(refreshTokenPromises); + const refreshTokens: RefreshTokenEntity[] = refreshTokenResults.filter( + (token): token is RefreshTokenEntity => token !== null + ); + + const numRefreshTokens = refreshTokens.length; + if (numRefreshTokens < 1) { + this.commonLogger.info( + "CacheManager:getRefreshToken - No refresh token found." + ); + return null; + } + // address the else case after remove functions address environment aliases + + if (numRefreshTokens > 1 && performanceClient && correlationId) { + performanceClient.addFields( + { multiMatchedRT: numRefreshTokens }, + correlationId + ); + } + + this.commonLogger.info( + "CacheManager:getRefreshToken - returning refresh token" + ); + return refreshTokens[0] as RefreshTokenEntity; + } + + /** + * Validate the cache key against filter before retrieving and parsing cache value + * @param key + * @param filter + */ + refreshTokenKeyMatchesFilter( + inputKey: string, + filter: CredentialFilter + ): boolean { + const key = inputKey.toLowerCase(); + if ( + filter.familyId && + key.indexOf(filter.familyId.toLowerCase()) === -1 + ) { + return false; + } + + // If familyId is used, clientId is not in the key + if ( + !filter.familyId && + filter.clientId && + key.indexOf(filter.clientId.toLowerCase()) === -1 + ) { + return false; + } + + if ( + filter.homeAccountId && + key.indexOf(filter.homeAccountId.toLowerCase()) === -1 + ) { + return false; + } + + return true; + } + + /** + * Retrieve AppMetadataEntity from cache + */ + async readAppMetadataFromCache(environment: string): Promise { + const appMetadataFilter: AppMetadataFilter = { + environment, + clientId: this.clientId, + }; + + const appMetadata: AppMetadataCache = + await this.getAppMetadataFilteredBy(appMetadataFilter); + const appMetadataEntries: AppMetadataEntity[] = Object.keys( + appMetadata + ).map((key) => appMetadata[key]); + + const numAppMetadata = appMetadataEntries.length; + if (numAppMetadata < 1) { + return null; + } else if (numAppMetadata > 1) { + throw createClientAuthError( + ClientAuthErrorCodes.multipleMatchingAppMetadata + ); + } + + return appMetadataEntries[0] as AppMetadataEntity; + } + + /** + * Return the family_id value associated with FOCI + * @param environment + * @param clientId + */ + async isAppMetadataFOCI(environment: string): Promise { + const appMetadata = await this.readAppMetadataFromCache(environment); + return !!(appMetadata && appMetadata.familyId === THE_FAMILY_ID); + } + + /** + * helper to match account ids + * @param value + * @param homeAccountId + */ + private matchHomeAccountId( + entity: AccountEntity | CredentialEntity, + homeAccountId: string + ): boolean { + return !!( + typeof entity.homeAccountId === "string" && + homeAccountId === entity.homeAccountId + ); + } + + /** + * helper to match account ids + * @param entity + * @param localAccountId + * @returns + */ + private matchLocalAccountIdFromTokenClaims( + tokenClaims: TokenClaims, + localAccountId: string + ): boolean { + const idTokenLocalAccountId = tokenClaims.oid || tokenClaims.sub; + return localAccountId === idTokenLocalAccountId; + } + + private matchLocalAccountIdFromTenantProfile( + tenantProfile: TenantProfile, + localAccountId: string + ): boolean { + return tenantProfile.localAccountId === localAccountId; + } + + /** + * helper to match names + * @param entity + * @param name + * @returns true if the downcased name properties are present and match in the filter and the entity + */ + private matchName(claims: TokenClaims, name: string): boolean { + return !!(name.toLowerCase() === claims.name?.toLowerCase()); + } + + /** + * helper to match usernames + * @param entity + * @param username + * @returns + */ + private matchUsername( + cachedUsername?: string, + filterUsername?: string + ): boolean { + return !!( + cachedUsername && + typeof cachedUsername === "string" && + filterUsername?.toLowerCase() === cachedUsername.toLowerCase() + ); + } + + /** + * helper to match assertion + * @param value + * @param oboAssertion + */ + private matchUserAssertionHash( + entity: CredentialEntity, + userAssertionHash: string + ): boolean { + return !!( + entity.userAssertionHash && + userAssertionHash === entity.userAssertionHash + ); + } + + /** + * helper to match environment + * @param value + * @param environment + */ + private async matchEnvironment( + entity: AccountEntity | CredentialEntity | AppMetadataEntity, + environment: string + ): Promise { + // Check static authority options first for cases where authority metadata has not been resolved and cached yet + if (this.staticAuthorityOptions) { + const staticAliases = getAliasesFromStaticSources( + this.staticAuthorityOptions, + this.commonLogger + ); + if ( + staticAliases.includes(environment) && + staticAliases.includes(entity.environment) + ) { + return true; + } + } + + // Query metadata cache if no static authority configuration has aliases that match environment + const cloudMetadata = await this.getAuthorityMetadataByAlias(environment); + if ( + cloudMetadata && + cloudMetadata.aliases.indexOf(entity.environment) > -1 + ) { + return true; + } + return false; + } + + /** + * helper to match credential type + * @param entity + * @param credentialType + */ + private matchCredentialType( + entity: CredentialEntity, + credentialType: string + ): boolean { + return ( + entity.credentialType && + credentialType.toLowerCase() === entity.credentialType.toLowerCase() + ); + } + + /** + * helper to match client ids + * @param entity + * @param clientId + */ + private matchClientId( + entity: CredentialEntity | AppMetadataEntity, + clientId: string + ): boolean { + return !!(entity.clientId && clientId === entity.clientId); + } + + /** + * helper to match family ids + * @param entity + * @param familyId + */ + private matchFamilyId( + entity: CredentialEntity | AppMetadataEntity, + familyId: string + ): boolean { + return !!(entity.familyId && familyId === entity.familyId); + } + + /** + * helper to match realm + * @param entity + * @param realm + */ + private matchRealm( + entity: AccountEntity | CredentialEntity, + realm: string + ): boolean { + return !!(entity.realm?.toLowerCase() === realm.toLowerCase()); + } + + /** + * helper to match nativeAccountId + * @param entity + * @param nativeAccountId + * @returns boolean indicating the match result + */ + private matchNativeAccountId( + entity: AccountEntity, + nativeAccountId: string + ): boolean { + return !!( + entity.nativeAccountId && nativeAccountId === entity.nativeAccountId + ); + } + + /** + * helper to match loginHint which can be either: + * 1. login_hint ID token claim + * 2. username in cached account object + * 3. upn in ID token claims + * @param entity + * @param loginHint + * @returns + */ + private matchLoginHintFromTokenClaims( + tokenClaims: TokenClaims, + loginHint: string + ): boolean { + if (tokenClaims.login_hint === loginHint) { + return true; + } + + if (tokenClaims.preferred_username === loginHint) { + return true; + } + + if (tokenClaims.upn === loginHint) { + return true; + } + + return false; + } + + /** + * Helper to match sid + * @param entity + * @param sid + * @returns true if the sid claim is present and matches the filter + */ + private matchSid(idTokenClaims: TokenClaims, sid: string): boolean { + return idTokenClaims.sid === sid; + } + + private matchAuthorityType( + entity: AccountEntity, + authorityType: string + ): boolean { + return !!( + entity.authorityType && + authorityType.toLowerCase() === entity.authorityType.toLowerCase() + ); + } + + /** + * Returns true if the target scopes are a subset of the current entity's scopes, false otherwise. + * @param entity + * @param target + */ + private matchTarget(entity: CredentialEntity, target: ScopeSet): boolean { + const isNotAccessTokenCredential = + entity.credentialType !== CredentialType.ACCESS_TOKEN && + entity.credentialType !== + CredentialType.ACCESS_TOKEN_WITH_AUTH_SCHEME; + + if (isNotAccessTokenCredential || !entity.target) { + return false; + } + + const entityScopeSet: ScopeSet = ScopeSet.fromString(entity.target); + + return entityScopeSet.containsScopeSet(target); + } + + /** + * Returns true if the credential's tokenType or Authentication Scheme matches the one in the request, false otherwise + * @param entity + * @param tokenType + */ + private matchTokenType( + entity: CredentialEntity, + tokenType: AuthenticationScheme + ): boolean { + return !!(entity.tokenType && entity.tokenType === tokenType); + } + + /** + * Returns true if the credential's keyId matches the one in the request, false otherwise + * @param entity + * @param keyId + */ + private matchKeyId(entity: CredentialEntity, keyId: string): boolean { + return !!(entity.keyId && entity.keyId === keyId); + } + + /** + * returns if a given cache entity is of the type appmetadata + * @param key + */ + private isAppMetadata(key: string): boolean { + return key.indexOf(APP_METADATA) !== -1; + } + + /** + * returns if a given cache entity is of the type authoritymetadata + * @param key + */ + protected isAuthorityMetadata(key: string): boolean { + return key.indexOf(AUTHORITY_METADATA_CONSTANTS.CACHE_KEY) !== -1; + } + + /** + * returns cache key used for cloud instance metadata + */ + generateAuthorityMetadataCacheKey(authority: string): string { + return `${AUTHORITY_METADATA_CONSTANTS.CACHE_KEY}-${this.clientId}-${authority}`; + } + + /** + * Helper to convert serialized data to object + * @param obj + * @param json + */ + static toObject(obj: T, json: object): T { + for (const propertyName in json) { + obj[propertyName] = json[propertyName]; + } + return obj; + } +} + +/** @internal */ +export class DefaultAsyncStorageClass extends AsyncCacheManager { + async setAccount(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAccount(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setIdTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getIdTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setAccessTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAccessTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setRefreshTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getRefreshTokenCredential(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setAppMetadata(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAppMetadata(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setServerTelemetry(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getServerTelemetry(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setAuthorityMetadata(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAuthorityMetadata(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAuthorityMetadataKeys(): Promise> { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async setThrottlingCache(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getThrottlingCache(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async removeItem(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getKeys(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getAccountKeys(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } + async getTokenKeys(): Promise { + throw createClientAuthError(ClientAuthErrorCodes.methodNotImplemented); + } +} diff --git a/lib/msal-common/src/client/BaseClient.ts b/lib/msal-common/src/client/BaseClient.ts index e60aa39943..94c6f6f19f 100644 --- a/lib/msal-common/src/client/BaseClient.ts +++ b/lib/msal-common/src/client/BaseClient.ts @@ -38,6 +38,7 @@ import { } from "../error/ClientAuthError.js"; import { NetworkError } from "../error/NetworkError.js"; import { invokeAsync } from "../utils/FunctionWrappers.js"; +import { AsyncCacheManager } from "../exports-common.js"; /** * Base application class which will construct requests to send to and handle responses from the Microsoft STS using the authorization code flow. @@ -54,7 +55,7 @@ export abstract class BaseClient { protected cryptoUtils: ICrypto; // Storage Interface - protected cacheManager: CacheManager; + protected cacheManager: CacheManager | AsyncCacheManager; // Network Interface protected networkClient: INetworkModule; @@ -168,7 +169,7 @@ export abstract class BaseClient { response.status !== 429 ) { // Telemetry data successfully logged by server, clear Telemetry cache - this.config.serverTelemetryManager.clearTelemetryCache(); + await this.config.serverTelemetryManager.clearTelemetryCache(); } return response; @@ -187,8 +188,7 @@ export abstract class BaseClient { options: NetworkRequestOptions, correlationId: string ): Promise> { - ThrottlingUtils.preProcess(this.cacheManager, thumbprint); - + await ThrottlingUtils.preProcess(this.cacheManager, thumbprint); let response; try { response = await invokeAsync( @@ -244,7 +244,7 @@ export abstract class BaseClient { } } - ThrottlingUtils.postProcess(this.cacheManager, thumbprint, response); + await ThrottlingUtils.postProcess(this.cacheManager, thumbprint, response); return response; } diff --git a/lib/msal-common/src/client/RefreshTokenClient.ts b/lib/msal-common/src/client/RefreshTokenClient.ts index 1da926717e..4981fa9eab 100644 --- a/lib/msal-common/src/client/RefreshTokenClient.ts +++ b/lib/msal-common/src/client/RefreshTokenClient.ts @@ -46,7 +46,7 @@ import { } from "../error/InteractionRequiredAuthError.js"; import { PerformanceEvents } from "../telemetry/performance/PerformanceEvent.js"; import { IPerformanceClient } from "../telemetry/performance/IPerformanceClient.js"; -import { invoke, invokeAsync } from "../utils/FunctionWrappers.js"; +import { invokeAsync } from "../utils/FunctionWrappers.js"; import { generateCredentialKey } from "../cache/utils/CacheHelpers.js"; import { ClientAssertion } from "../account/ClientCredentials.js"; import { getClientAssertion } from "../utils/ClientAssertionUtils.js"; @@ -74,43 +74,49 @@ export class RefreshTokenClient extends BaseClient { ); const reqTimestamp = TimeUtils.nowSeconds(); - const response = await invokeAsync( - this.executeTokenRequest.bind(this), - PerformanceEvents.RefreshTokenClientExecuteTokenRequest, - this.logger, - this.performanceClient, - request.correlationId - )(request, this.authority); - - // Retrieve requestId from response headers - const requestId = response.headers?.[HeaderNames.X_MS_REQUEST_ID]; - const responseHandler = new ResponseHandler( - this.config.authOptions.clientId, - this.cacheManager, - this.cryptoUtils, - this.logger, - this.config.serializableCache, - this.config.persistencePlugin - ); - responseHandler.validateTokenResponse(response.body); + let response: NetworkResponse; + try { + response = await invokeAsync( + this.executeTokenRequest.bind(this), + PerformanceEvents.RefreshTokenClientExecuteTokenRequest, + this.logger, + this.performanceClient, + request.correlationId + )(request, this.authority); + // Retrieve requestId from response headers + const requestId = response.headers?.[HeaderNames.X_MS_REQUEST_ID]; + const responseHandler = new ResponseHandler( + this.config.authOptions.clientId, + this.cacheManager, + this.cryptoUtils, + this.logger, + this.config.serializableCache, + this.config.persistencePlugin + ); + responseHandler.validateTokenResponse(response.body); - return invokeAsync( - responseHandler.handleServerTokenResponse.bind(responseHandler), - PerformanceEvents.HandleServerTokenResponse, - this.logger, - this.performanceClient, - request.correlationId - )( - response.body, - this.authority, - reqTimestamp, - request, - undefined, - undefined, - true, - request.forceCache, - requestId - ); + return await invokeAsync( + responseHandler.handleServerTokenResponse.bind(responseHandler), + PerformanceEvents.HandleServerTokenResponse, + this.logger, + this.performanceClient, + request.correlationId + )( + response.body, + this.authority, + reqTimestamp, + request, + undefined, + undefined, + true, + request.forceCache, + requestId + ); + } catch (e) { + console.log("Error: ", e); + return Promise.reject(e); + } + } /** @@ -140,7 +146,7 @@ export class RefreshTokenClient extends BaseClient { } // try checking if FOCI is enabled for the given application - const isFOCI = this.cacheManager.isAppMetadataFOCI( + const isFOCI = await this.cacheManager.isAppMetadataFOCI( request.account.environment ); @@ -203,8 +209,14 @@ export class RefreshTokenClient extends BaseClient { ); // fetches family RT or application RT based on FOCI value - const refreshToken = invoke( - this.cacheManager.getRefreshToken.bind(this.cacheManager), + const refreshToken = await invokeAsync( + async ( + account, + familyRT, + tokenKeys, + performanceClient, + correlationId + ) => this.cacheManager.getRefreshToken(account, familyRT, tokenKeys, performanceClient, correlationId), PerformanceEvents.CacheManagerGetRefreshToken, this.logger, this.performanceClient, @@ -274,7 +286,7 @@ export class RefreshTokenClient extends BaseClient { ); const badRefreshTokenKey = generateCredentialKey(refreshToken); - this.cacheManager.removeRefreshToken(badRefreshTokenKey); + await this.cacheManager.removeRefreshToken(badRefreshTokenKey); } } @@ -387,7 +399,7 @@ export class RefreshTokenClient extends BaseClient { RequestParameterBuilder.addThrottling(parameters); if (this.serverTelemetryManager && !isOidcProtocolMode(this.config)) { - RequestParameterBuilder.addServerTelemetry( + await RequestParameterBuilder.addServerTelemetry( parameters, this.serverTelemetryManager ); diff --git a/lib/msal-common/src/client/SilentFlowClient.ts b/lib/msal-common/src/client/SilentFlowClient.ts index 0801c88ee0..2a538bf377 100644 --- a/lib/msal-common/src/client/SilentFlowClient.ts +++ b/lib/msal-common/src/client/SilentFlowClient.ts @@ -70,8 +70,8 @@ export class SilentFlowClient extends BaseClient { const requestTenantId = request.account.tenantId || getTenantFromAuthorityString(request.authority); - const tokenKeys = this.cacheManager.getTokenKeys(); - const cachedAccessToken = this.cacheManager.getAccessToken( + const tokenKeys = await this.cacheManager.getTokenKeys(); + const cachedAccessToken = await this.cacheManager.getAccessToken( request.account, request, tokenKeys, @@ -117,9 +117,9 @@ export class SilentFlowClient extends BaseClient { const environment = request.authority || this.authority.getPreferredCache(); const cacheRecord: CacheRecord = { - account: this.cacheManager.readAccountFromCache(request.account), + account: await this.cacheManager.readAccountFromCache(request.account), accessToken: cachedAccessToken, - idToken: this.cacheManager.getIdToken( + idToken: await this.cacheManager.getIdToken( request.account, tokenKeys, requestTenantId, @@ -128,15 +128,14 @@ export class SilentFlowClient extends BaseClient { ), refreshToken: null, appMetadata: - this.cacheManager.readAppMetadataFromCache(environment), + await this.cacheManager.readAppMetadataFromCache(environment), }; this.setCacheOutcome(lastCacheOutcome, request.correlationId); if (this.config.serverTelemetryManager) { - this.config.serverTelemetryManager.incrementCacheHits(); + await this.config.serverTelemetryManager.incrementCacheHits(); } - return [ await invokeAsync( this.generateResultFromCacheRecord.bind(this), diff --git a/lib/msal-common/src/config/ClientConfiguration.ts b/lib/msal-common/src/config/ClientConfiguration.ts index 6d9975d652..f78cbd35d5 100644 --- a/lib/msal-common/src/config/ClientConfiguration.ts +++ b/lib/msal-common/src/config/ClientConfiguration.ts @@ -14,6 +14,7 @@ import { version } from "../packageMetadata.js"; import { Authority } from "../authority/Authority.js"; import { AzureCloudInstance } from "../authority/AuthorityOptions.js"; import { CacheManager, DefaultStorageClass } from "../cache/CacheManager.js"; +import { AsyncCacheManager } from "../cache/AsyncCacheManager.js"; import { ServerTelemetryManager } from "../telemetry/server/ServerTelemetryManager.js"; import { ICachePlugin } from "../cache/interface/ICachePlugin.js"; import { ISerializableTokenCache } from "../cache/interface/ISerializableTokenCache.js"; @@ -44,7 +45,7 @@ export type ClientConfiguration = { systemOptions?: SystemOptions; loggerOptions?: LoggerOptions; cacheOptions?: CacheOptions; - storageInterface?: CacheManager; + storageInterface?: CacheManager | AsyncCacheManager; networkInterface?: INetworkModule; cryptoInterface?: ICrypto; clientCredentials?: ClientCredentials; @@ -60,7 +61,7 @@ export type CommonClientConfiguration = { systemOptions: Required; loggerOptions: Required; cacheOptions: Required; - storageInterface: CacheManager; + storageInterface: CacheManager | AsyncCacheManager; networkInterface: INetworkModule; cryptoInterface: Required; libraryInfo: LibraryInfo; diff --git a/lib/msal-common/src/exports-common.ts b/lib/msal-common/src/exports-common.ts index 5700a50c40..bbbdc24f05 100644 --- a/lib/msal-common/src/exports-common.ts +++ b/lib/msal-common/src/exports-common.ts @@ -64,6 +64,7 @@ export { AuthorityType } from "./authority/AuthorityType.js"; export { ProtocolMode } from "./authority/ProtocolMode.js"; export { OIDCOptions } from "./authority/OIDCOptions.js"; export { CacheManager, DefaultStorageClass } from "./cache/CacheManager.js"; +export { AsyncCacheManager, DefaultAsyncStorageClass } from "./cache/AsyncCacheManager.js"; export { AccountCache, AccountFilter, diff --git a/lib/msal-common/src/network/ThrottlingUtils.ts b/lib/msal-common/src/network/ThrottlingUtils.ts index 8c5f8db35f..a8ba83c5ba 100644 --- a/lib/msal-common/src/network/ThrottlingUtils.ts +++ b/lib/msal-common/src/network/ThrottlingUtils.ts @@ -18,6 +18,7 @@ import { } from "./RequestThumbprint.js"; import { ThrottlingEntity } from "../cache/entities/ThrottlingEntity.js"; import { BaseAuthRequest } from "../request/BaseAuthRequest.js"; +import { AsyncCacheManager } from "../exports-common.js"; /** @internal */ export class ThrottlingUtils { @@ -36,16 +37,16 @@ export class ThrottlingUtils { * @param cacheManager * @param thumbprint */ - static preProcess( - cacheManager: CacheManager, + static async preProcess( + cacheManager: CacheManager | AsyncCacheManager, thumbprint: RequestThumbprint - ): void { + ): Promise { const key = ThrottlingUtils.generateThrottlingStorageKey(thumbprint); - const value = cacheManager.getThrottlingCache(key); + const value = await cacheManager.getThrottlingCache(key); if (value) { if (value.throttleTime < Date.now()) { - cacheManager.removeItem(key); + await cacheManager.removeItem(key); return; } throw new ServerError( @@ -62,11 +63,11 @@ export class ThrottlingUtils { * @param thumbprint * @param response */ - static postProcess( - cacheManager: CacheManager, + static async postProcess( + cacheManager: CacheManager | AsyncCacheManager, thumbprint: RequestThumbprint, response: NetworkResponse - ): void { + ): Promise { if ( ThrottlingUtils.checkResponseStatus(response) || ThrottlingUtils.checkResponseForRetryAfter(response) @@ -80,7 +81,7 @@ export class ThrottlingUtils { errorMessage: response.body.error_description, subError: response.body.suberror, }; - cacheManager.setThrottlingCache( + await cacheManager.setThrottlingCache( ThrottlingUtils.generateThrottlingStorageKey(thumbprint), thumbprintValue ); diff --git a/lib/msal-common/src/request/RequestParameterBuilder.ts b/lib/msal-common/src/request/RequestParameterBuilder.ts index 70555833b4..e0d2e24413 100644 --- a/lib/msal-common/src/request/RequestParameterBuilder.ts +++ b/lib/msal-common/src/request/RequestParameterBuilder.ts @@ -567,17 +567,17 @@ export function addSshJwk( * add server telemetry fields * @param serverTelemetryManager */ -export function addServerTelemetry( +export async function addServerTelemetry( parameters: Map, serverTelemetryManager: ServerTelemetryManager -): void { +): Promise { parameters.set( AADServerParamKeys.X_CLIENT_CURR_TELEM, - serverTelemetryManager.generateCurrentRequestHeaderValue() + await serverTelemetryManager.generateCurrentRequestHeaderValue() ); parameters.set( AADServerParamKeys.X_CLIENT_LAST_TELEM, - serverTelemetryManager.generateLastRequestHeaderValue() + await serverTelemetryManager.generateLastRequestHeaderValue() ); } diff --git a/lib/msal-common/src/response/AuthenticationResult.ts b/lib/msal-common/src/response/AuthenticationResult.ts index abcaeabecb..cb54c5edd7 100644 --- a/lib/msal-common/src/response/AuthenticationResult.ts +++ b/lib/msal-common/src/response/AuthenticationResult.ts @@ -31,6 +31,7 @@ export type AuthenticationResult = { idToken: string; idTokenClaims: object; accessToken: string; + refreshToken?: string; fromCache: boolean; expiresOn: Date | null; extExpiresOn?: Date; diff --git a/lib/msal-common/src/response/ResponseHandler.ts b/lib/msal-common/src/response/ResponseHandler.ts index 1a90b755f0..80e486a86a 100644 --- a/lib/msal-common/src/response/ResponseHandler.ts +++ b/lib/msal-common/src/response/ResponseHandler.ts @@ -52,6 +52,7 @@ import { } from "../account/AccountInfo.js"; import * as CacheHelpers from "../cache/utils/CacheHelpers.js"; import * as TimeUtils from "../utils/TimeUtils.js"; +import { AsyncCacheManager } from "../exports-common.js"; /** * Class that handles response parsing. @@ -59,7 +60,7 @@ import * as TimeUtils from "../utils/TimeUtils.js"; */ export class ResponseHandler { private clientId: string; - private cacheStorage: CacheManager; + private cacheStorage: CacheManager | AsyncCacheManager; private cryptoObj: ICrypto; private logger: Logger; private homeAccountIdentifier: string; @@ -69,7 +70,7 @@ export class ResponseHandler { constructor( clientId: string, - cacheStorage: CacheManager, + cacheStorage: CacheManager | AsyncCacheManager, cryptoObj: ICrypto, logger: Logger, serializableCache: ISerializableTokenCache | null, @@ -246,7 +247,7 @@ export class ResponseHandler { serverTokenResponse.key_id = serverTokenResponse.key_id || request.sshKid || undefined; - const cacheRecord = this.generateCacheRecord( + const cacheRecord = await this.generateCacheRecord( serverTokenResponse, authority, reqTimestamp, @@ -334,7 +335,7 @@ export class ResponseHandler { * @param idTokenObj * @param authority */ - private generateCacheRecord( + private async generateCacheRecord( serverTokenResponse: ServerAuthorizationTokenResponse, authority: Authority, reqTimestamp: number, @@ -342,7 +343,7 @@ export class ResponseHandler { idTokenClaims?: TokenClaims, userAssertionHash?: string, authCodePayload?: AuthorizationCodePayload - ): CacheRecord { + ): Promise { const env = authority.getPreferredCache(); if (!env) { throw createClientAuthError( @@ -364,7 +365,7 @@ export class ResponseHandler { claimsTenantId || "" ); - cachedAccount = buildAccountToCache( + cachedAccount = await buildAccountToCache( this.cacheStorage, authority, this.homeAccountIdentifier, @@ -498,6 +499,7 @@ export class ResponseHandler { requestId?: string ): Promise { let accessToken: string = Constants.EMPTY_STRING; + let refreshToken: string = Constants.EMPTY_STRING; // TODO: Remove once a better way to handle post-interaction worker cache hydration is in place let responseScopes: Array = []; let expiresOn: Date | null = null; let extExpiresOn: Date | undefined; @@ -549,6 +551,15 @@ export class ResponseHandler { } } + if (cacheRecord.refreshToken) { + refreshToken = cacheRecord.refreshToken.secret; + + // Access token expiresOn cached in seconds, converting to Date for AuthenticationResult + expiresOn = TimeUtils.toDateFromSeconds( + cacheRecord.refreshToken.expiresOn + ); + } + if (cacheRecord.appMetadata) { familyId = cacheRecord.appMetadata.familyId === THE_FAMILY_ID @@ -582,6 +593,7 @@ export class ResponseHandler { idToken: cacheRecord?.idToken?.secret || "", idTokenClaims: idTokenClaims || {}, accessToken: accessToken, + refreshToken: refreshToken, fromCache: fromTokenCache, expiresOn: expiresOn, extExpiresOn: extExpiresOn, @@ -605,8 +617,8 @@ export class ResponseHandler { } } -export function buildAccountToCache( - cacheStorage: CacheManager, +export async function buildAccountToCache( + cacheStorage: CacheManager | AsyncCacheManager, authority: Authority, homeAccountId: string, base64Decode: (input: string) => string, @@ -617,18 +629,18 @@ export function buildAccountToCache( authCodePayload?: AuthorizationCodePayload, nativeAccountId?: string, logger?: Logger -): AccountEntity { +): Promise { logger?.verbose("setCachedAccount called"); // Check if base account is already cached - const accountKeys = cacheStorage.getAccountKeys(); + const accountKeys = await cacheStorage.getAccountKeys(); const baseAccountKey = accountKeys.find((accountKey: string) => { return accountKey.startsWith(homeAccountId); }); let cachedAccount: AccountEntity | null = null; if (baseAccountKey) { - cachedAccount = cacheStorage.getAccount(baseAccountKey); + cachedAccount = await cacheStorage.getAccount(baseAccountKey); } const baseAccount = diff --git a/lib/msal-common/src/telemetry/server/ServerTelemetryManager.ts b/lib/msal-common/src/telemetry/server/ServerTelemetryManager.ts index 6de307b3a0..e8ea17b0d4 100644 --- a/lib/msal-common/src/telemetry/server/ServerTelemetryManager.ts +++ b/lib/msal-common/src/telemetry/server/ServerTelemetryManager.ts @@ -16,6 +16,7 @@ import { AuthError } from "../../error/AuthError.js"; import { ServerTelemetryRequest } from "./ServerTelemetryRequest.js"; import { ServerTelemetryEntity } from "../../cache/entities/ServerTelemetryEntity.js"; import { RegionDiscoveryMetadata } from "../../authority/RegionDiscoveryMetadata.js"; +import { AsyncCacheManager } from "../../exports-common.js"; const skuGroupSeparator = ","; const skuValueSeparator = "|"; @@ -82,7 +83,7 @@ function setSku(params: { /** @internal */ export class ServerTelemetryManager { - private cacheManager: CacheManager; + private cacheManager: CacheManager | AsyncCacheManager; private apiId: number; private correlationId: string; private telemetryCacheKey: string; @@ -95,7 +96,7 @@ export class ServerTelemetryManager { constructor( telemetryRequest: ServerTelemetryRequest, - cacheManager: CacheManager + cacheManager: CacheManager | AsyncCacheManager ) { this.cacheManager = cacheManager; this.apiId = telemetryRequest.apiId; @@ -112,10 +113,10 @@ export class ServerTelemetryManager { /** * API to add MSER Telemetry to request */ - generateCurrentRequestHeaderValue(): string { + async generateCurrentRequestHeaderValue(): Promise { const request = `${this.apiId}${SERVER_TELEM_CONSTANTS.VALUE_SEPARATOR}${this.cacheOutcome}`; const platformFieldsArr = [this.wrapperSKU, this.wrapperVer]; - const nativeBrokerErrorCode = this.getNativeBrokerErrorCode(); + const nativeBrokerErrorCode = await this.getNativeBrokerErrorCode(); if (nativeBrokerErrorCode?.length) { platformFieldsArr.push(`broker_error=${nativeBrokerErrorCode}`); } @@ -138,8 +139,8 @@ export class ServerTelemetryManager { /** * API to add MSER Telemetry for the last failed request */ - generateLastRequestHeaderValue(): string { - const lastRequests = this.getLastRequests(); + async generateLastRequestHeaderValue(): Promise { + const lastRequests = await this.getLastRequests(); const maxErrors = ServerTelemetryManager.maxErrorsToSend(lastRequests); const failedRequests = lastRequests.failedRequests @@ -172,8 +173,8 @@ export class ServerTelemetryManager { * API to cache token failures for MSER data capture * @param error */ - cacheFailedRequest(error: unknown): void { - const lastRequests = this.getLastRequests(); + async cacheFailedRequest(error: unknown): Promise { + const lastRequests = await this.getLastRequests(); if ( lastRequests.errors.length >= SERVER_TELEM_CONSTANTS.MAX_CACHED_ERRORS @@ -202,7 +203,7 @@ export class ServerTelemetryManager { lastRequests.errors.push(SERVER_TELEM_CONSTANTS.UNKNOWN_ERROR); } - this.cacheManager.setServerTelemetry( + await this.cacheManager.setServerTelemetry( this.telemetryCacheKey, lastRequests ); @@ -213,11 +214,11 @@ export class ServerTelemetryManager { /** * Update server telemetry cache entry by incrementing cache hit counter */ - incrementCacheHits(): number { - const lastRequests = this.getLastRequests(); + async incrementCacheHits(): Promise { + const lastRequests = await this.getLastRequests(); lastRequests.cacheHits += 1; - this.cacheManager.setServerTelemetry( + await this.cacheManager.setServerTelemetry( this.telemetryCacheKey, lastRequests ); @@ -227,13 +228,13 @@ export class ServerTelemetryManager { /** * Get the server telemetry entity from cache or initialize a new one */ - getLastRequests(): ServerTelemetryEntity { + async getLastRequests(): Promise { const initialValue: ServerTelemetryEntity = { failedRequests: [], errors: [], cacheHits: 0, }; - const lastRequests = this.cacheManager.getServerTelemetry( + const lastRequests = await this.cacheManager.getServerTelemetry( this.telemetryCacheKey ) as ServerTelemetryEntity; @@ -243,14 +244,14 @@ export class ServerTelemetryManager { /** * Remove server telemetry cache entry */ - clearTelemetryCache(): void { - const lastRequests = this.getLastRequests(); + async clearTelemetryCache(): Promise { + const lastRequests = await this.getLastRequests(); const numErrorsFlushed = ServerTelemetryManager.maxErrorsToSend(lastRequests); const errorCount = lastRequests.errors.length; if (numErrorsFlushed === errorCount) { // All errors were sent on last request, clear Telemetry cache - this.cacheManager.removeItem(this.telemetryCacheKey); + await this.cacheManager.removeItem(this.telemetryCacheKey); } else { // Partial data was flushed to server, construct a new telemetry cache item with errors that were not flushed const serverTelemEntity: ServerTelemetryEntity = { @@ -261,7 +262,7 @@ export class ServerTelemetryManager { cacheHits: 0, }; - this.cacheManager.setServerTelemetry( + await this.cacheManager.setServerTelemetry( this.telemetryCacheKey, serverTelemEntity ); @@ -346,23 +347,24 @@ export class ServerTelemetryManager { this.cacheOutcome = cacheOutcome; } - setNativeBrokerErrorCode(errorCode: string): void { - const lastRequests = this.getLastRequests(); + async setNativeBrokerErrorCode(errorCode: string): Promise { + const lastRequests = await this.getLastRequests(); lastRequests.nativeBrokerErrorCode = errorCode; - this.cacheManager.setServerTelemetry( + await this.cacheManager.setServerTelemetry( this.telemetryCacheKey, lastRequests ); } - getNativeBrokerErrorCode(): string | undefined { - return this.getLastRequests().nativeBrokerErrorCode; + async getNativeBrokerErrorCode(): Promise { + const lastRequests = await this.getLastRequests(); + return lastRequests.nativeBrokerErrorCode; } - clearNativeBrokerErrorCode(): void { - const lastRequests = this.getLastRequests(); + async clearNativeBrokerErrorCode(): Promise { + const lastRequests = await this.getLastRequests(); delete lastRequests.nativeBrokerErrorCode; - this.cacheManager.setServerTelemetry( + await this.cacheManager.setServerTelemetry( this.telemetryCacheKey, lastRequests ); diff --git a/package-lock.json b/package-lock.json index 77c54a6259..1d0c143fa2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -44259,9 +44259,9 @@ } }, "node_modules/vite": { - "version": "4.5.13", - "resolved": "https://identitydivision.pkgs.visualstudio.com/fac9d424-53d2-45c0-91b5-ef6ba7a6bf26/_packaging/dd15892d-fc68-4d1c-93a5-090f3b303f31/npm/registry/vite/-/vite-4.5.13.tgz", - "integrity": "sha1-d4U0qUcRLGxFXolzdzD65dRYopQ=", + "version": "4.5.14", + "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.14.tgz", + "integrity": "sha512-+v57oAaoYNnO3hIu5Z/tJRZjq5aHM2zDve9YZ8HngVHbhk66RStobhb1sqPMIPEleV6cNKYK4eGrAbE9Ulbl2g==", "dev": true, "license": "MIT", "peer": true, diff --git a/samples/msal-browser-samples/VanillaJSTestApp2.0/app/default/authConfig.js b/samples/msal-browser-samples/VanillaJSTestApp2.0/app/default/authConfig.js index afb8e0933f..dc43d5e2b2 100644 --- a/samples/msal-browser-samples/VanillaJSTestApp2.0/app/default/authConfig.js +++ b/samples/msal-browser-samples/VanillaJSTestApp2.0/app/default/authConfig.js @@ -1,9 +1,8 @@ // Config object to be passed to Msal on creation const msalConfig = { auth: { - clientId: "b5c2e510-4a17-4feb-b219-e55aa5b74144", - authority: - "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47", + clientId: "654736c7-9f4e-4158-9c22-54081d1896c6", + authority: "https://login.microsoftonline.com/common", }, cache: { cacheLocation: "sessionStorage", // This configures where your cache will be stored