diff --git a/packages/assets-controllers/src/AccountTrackerController.test.ts b/packages/assets-controllers/src/AccountTrackerController.test.ts index 8d37b7cd5ce..bb3c4a16b5b 100644 --- a/packages/assets-controllers/src/AccountTrackerController.test.ts +++ b/packages/assets-controllers/src/AccountTrackerController.test.ts @@ -4,6 +4,7 @@ import type { InternalAccount } from '@metamask/keyring-internal-api'; import { type NetworkClientId, type NetworkClientConfiguration, + type BlockTracker, getDefaultNetworkControllerState, } from '@metamask/network-controller'; import { getDefaultPreferencesState } from '@metamask/preferences-controller'; @@ -12,6 +13,7 @@ import { type TransactionMeta, } from '@metamask/transaction-controller'; import BN from 'bn.js'; +import { EventEmitter } from 'events'; import { useFakeTimers, type SinonFakeTimers } from 'sinon'; import type { @@ -78,6 +80,29 @@ const { safelyExecuteWithTimeout } = jest.requireMock( ); const mockedSafelyExecuteWithTimeout = safelyExecuteWithTimeout as jest.Mock; +// Mock BlockTracker class for testing +class MockBlockTracker extends EventEmitter implements BlockTracker { + getCurrentBlock(): string | null { + return '0x1'; + } + + async getLatestBlock(): Promise { + return '0x1'; + } + + async checkForLatestBlock(): Promise { + return '0x1'; + } + + isRunning(): boolean { + return true; + } + + async destroy(): Promise { + // Mock implementation + } +} + describe('AccountTrackerController', () => { let clock: SinonFakeTimers; @@ -141,7 +166,7 @@ describe('AccountTrackerController', () => { triggerSelectedAccountChange(ACCOUNT_1); - expect(refreshSpy).toHaveBeenCalled(); + expect(refreshSpy).toHaveBeenCalledWith(['mainnet']); }, ); }); @@ -1248,117 +1273,295 @@ describe('AccountTrackerController', () => { }); }); - it('should call refresh every interval on polling', async () => { - const pollSpy = jest.spyOn( - AccountTrackerController.prototype, - '_executePoll', - ); + it('should call refresh when block tracker emits latest block', async () => { await withController( { - options: { interval: 100 }, isMultiAccountBalancesEnabled: true, selectedAccount: EMPTY_ACCOUNT, listAccounts: [], }, - async ({ controller }) => { - jest.spyOn(controller, 'refresh').mockResolvedValue(); - - controller.startPolling({ - networkClientIds: ['networkClientId1'], - queryAllAccounts: true, - }); - await advanceTime({ clock, duration: 1 }); - - expect(pollSpy).toHaveBeenCalledTimes(1); + async ({ controller, blockTracker }) => { + const refreshSpy = jest + .spyOn(controller, 'refresh') + .mockResolvedValue(); - await advanceTime({ clock, duration: 50 }); + // Start tracking for global network + controller.start(); - expect(pollSpy).toHaveBeenCalledTimes(1); + // Simulate a new block + blockTracker.emit('latest', '0x2'); - await advanceTime({ clock, duration: 50 }); + // Give async handler time to execute + await advanceTime({ clock, duration: 1 }); - expect(pollSpy).toHaveBeenCalledTimes(2); + expect(refreshSpy).toHaveBeenCalledWith(['mainnet']); }, ); }); - it('should call refresh every interval for each networkClientId being polled', async () => { + it('should call refresh for each networkClientId being polled', async () => { const networkClientId1 = 'networkClientId1'; const networkClientId2 = 'networkClientId2'; await withController( { - options: { interval: 100 }, isMultiAccountBalancesEnabled: true, selectedAccount: EMPTY_ACCOUNT, listAccounts: [], + networkClientById: { + [networkClientId1]: buildCustomNetworkClientConfiguration({ + chainId: '0x1', + }), + [networkClientId2]: buildCustomNetworkClientConfiguration({ + chainId: '0x89', + }), + }, }, async ({ controller }) => { const refreshSpy = jest .spyOn(controller, 'refresh') .mockResolvedValue(); - controller.startPolling({ - networkClientIds: [networkClientId1], - queryAllAccounts: true, - }); + // Start polling for first network + const pollToken1 = + controller.startPollingByNetworkClientId(networkClientId1); - await advanceTime({ clock, duration: 0 }); - expect(refreshSpy).toHaveBeenNthCalledWith(1, [networkClientId1], true); - expect(refreshSpy).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: 50 }); + // Should immediately call refresh + await advanceTime({ clock, duration: 1 }); + expect(refreshSpy).toHaveBeenCalledWith([networkClientId1]); expect(refreshSpy).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: 50 }); - expect(refreshSpy).toHaveBeenNthCalledWith(2, [networkClientId1], true); - expect(refreshSpy).toHaveBeenCalledTimes(2); - const pollToken = controller.startPolling({ - networkClientIds: [networkClientId2], - queryAllAccounts: true, - }); + // Start polling for second network + const pollToken2 = + controller.startPollingByNetworkClientId(networkClientId2); - await advanceTime({ clock, duration: 0 }); - expect(refreshSpy).toHaveBeenNthCalledWith(3, [networkClientId2], true); - expect(refreshSpy).toHaveBeenCalledTimes(3); - await advanceTime({ clock, duration: 100 }); - expect(refreshSpy).toHaveBeenNthCalledWith(4, [networkClientId1], true); - expect(refreshSpy).toHaveBeenNthCalledWith(5, [networkClientId2], true); - expect(refreshSpy).toHaveBeenCalledTimes(5); + await advanceTime({ clock, duration: 1 }); + expect(refreshSpy).toHaveBeenCalledWith([networkClientId2]); + expect(refreshSpy).toHaveBeenCalledTimes(2); - controller.stopPollingByPollingToken(pollToken); + // Stop polling for second network + controller.stopPollingByPollingToken(pollToken2); - await advanceTime({ clock, duration: 100 }); - expect(refreshSpy).toHaveBeenNthCalledWith(6, [networkClientId1], true); - expect(refreshSpy).toHaveBeenCalledTimes(6); + // Clear the spy to verify no more calls after stopping + refreshSpy.mockClear(); + // Stop all polling controller.stopAllPolling(); await advanceTime({ clock, duration: 100 }); - expect(refreshSpy).toHaveBeenCalledTimes(6); + // Should not be called anymore after stopping + expect(refreshSpy).not.toHaveBeenCalled(); }, ); }); - it('should not call polling twice', async () => { - await withController( - { - options: { interval: 100 }, - }, - async ({ controller }) => { - const refreshSpy = jest - .spyOn(controller, 'refresh') - .mockResolvedValue(); + it('should not start tracking blocks automatically', async () => { + await withController({}, async ({ controller, blockTracker }) => { + const refreshSpy = jest.spyOn(controller, 'refresh').mockResolvedValue(); - expect(refreshSpy).not.toHaveBeenCalled(); - controller.startPolling({ - networkClientIds: ['networkClientId1'], - queryAllAccounts: true, - }); + // Emit a block without starting the controller + blockTracker.emit('latest', '0x2'); - await advanceTime({ clock, duration: 1 }); - expect(refreshSpy).toHaveBeenCalledTimes(1); - }, - ); + await advanceTime({ clock, duration: 1 }); + + // Should not be called because we haven't called start() + expect(refreshSpy).not.toHaveBeenCalled(); + + // Now start and emit another block + controller.start(); + blockTracker.emit('latest', '0x3'); + + await advanceTime({ clock, duration: 1 }); + + // Now it should be called + expect(refreshSpy).toHaveBeenCalledTimes(1); + }); + }); + + describe('polling methods', () => { + it('should handle multiple polling tokens for the same network', async () => { + const networkClientId = 'networkClientId1'; + await withController( + { + networkClientById: { + [networkClientId]: buildCustomNetworkClientConfiguration({ + chainId: '0x1', + }), + }, + }, + async ({ controller }) => { + const refreshSpy = jest + .spyOn(controller, 'refresh') + .mockResolvedValue(); + + // Start polling multiple times for the same network + const token1 = + controller.startPollingByNetworkClientId(networkClientId); + const token2 = + controller.startPollingByNetworkClientId(networkClientId); + const token3 = + controller.startPollingByNetworkClientId(networkClientId); + + // Should only call refresh once for initial subscription + await advanceTime({ clock, duration: 1 }); + expect(refreshSpy).toHaveBeenCalledTimes(1); + expect(refreshSpy).toHaveBeenCalledWith([networkClientId]); + + // Stop one token - polling should continue + controller.stopPollingByPollingToken(token1); + refreshSpy.mockClear(); + + // Verify polling continues by checking that refresh would be called on block + // (In real usage, block events would trigger this) + + // Stop another token - polling should still continue + controller.stopPollingByPollingToken(token2); + + // Stop the last token - polling should stop + controller.stopPollingByPollingToken(token3); + + // Verify no more polling happens + refreshSpy.mockClear(); + await advanceTime({ clock, duration: 100 }); + expect(refreshSpy).not.toHaveBeenCalled(); + }, + ); + }); + + it('should throw error when stopping with undefined token', async () => { + await withController({}, async ({ controller }) => { + expect(() => { + controller.stopPollingByPollingToken(undefined); + }).toThrow('pollingToken required'); + }); + }); + + it('should handle stopAllPolling correctly', async () => { + const networkClientId1 = 'networkClientId1'; + const networkClientId2 = 'networkClientId2'; + + await withController( + { + networkClientById: { + [networkClientId1]: buildCustomNetworkClientConfiguration({ + chainId: '0x1', + }), + [networkClientId2]: buildCustomNetworkClientConfiguration({ + chainId: '0x89', + }), + }, + }, + async ({ controller }) => { + const refreshSpy = jest + .spyOn(controller, 'refresh') + .mockResolvedValue(); + + // Start global polling + controller.start(); + + // Start polling for multiple networks + controller.startPollingByNetworkClientId(networkClientId1); + controller.startPollingByNetworkClientId(networkClientId2); + + // Clear the initial refresh calls + await advanceTime({ clock, duration: 1 }); + refreshSpy.mockClear(); + + // Stop all polling + controller.stopAllPolling(); + + // Verify all polling has stopped + await advanceTime({ clock, duration: 100 }); + expect(refreshSpy).not.toHaveBeenCalled(); + }, + ); + }); + + it('should track multiple blocks and update balances on each block', async () => { + await withController( + { + isMultiAccountBalancesEnabled: true, + selectedAccount: ACCOUNT_1, + listAccounts: [ACCOUNT_1], + }, + async ({ controller, blockTracker }) => { + const refreshSpy = jest + .spyOn(controller, 'refresh') + .mockResolvedValue(); + + // Start tracking + controller.start(); + + // Simulate multiple blocks + blockTracker.emit('latest', '0x2'); + await advanceTime({ clock, duration: 1 }); + + blockTracker.emit('latest', '0x3'); + await advanceTime({ clock, duration: 1 }); + + blockTracker.emit('latest', '0x4'); + await advanceTime({ clock, duration: 1 }); + + // Should be called for each block + expect(refreshSpy).toHaveBeenCalledTimes(3); + expect(refreshSpy).toHaveBeenCalledWith(['mainnet']); + + // Stop tracking + controller.stop(); + + // Emit another block after stopping + blockTracker.emit('latest', '0x5'); + await advanceTime({ clock, duration: 1 }); + + // Should not be called after stopping + expect(refreshSpy).toHaveBeenCalledTimes(3); + }, + ); + }); + + it('should handle block events for specific network clients', async () => { + const networkClientId = 'polygon'; + + await withController( + { + networkClientById: { + [networkClientId]: buildCustomNetworkClientConfiguration({ + chainId: '0x89', + }), + }, + }, + async ({ controller }) => { + const refreshSpy = jest + .spyOn(controller, 'refresh') + .mockResolvedValue(); + + // Get the mock block tracker for the network client + const networkBlockTracker = new MockBlockTracker(); + const getNetworkClientById = jest.fn().mockReturnValue({ + configuration: { chainId: '0x89' }, + provider: {} as any, + blockTracker: networkBlockTracker, + }); + + // Start polling for specific network + controller.startPollingByNetworkClientId(networkClientId); + + await advanceTime({ clock, duration: 1 }); + + // Should be called once for the initial poll + expect(refreshSpy).toHaveBeenCalledTimes(1); + expect(refreshSpy).toHaveBeenCalledWith([networkClientId]); + + // Simulate a block on the network-specific tracker + networkBlockTracker.emit('latest', '0x100'); + await advanceTime({ clock, duration: 1 }); + + // Should be called again for the new block + expect(refreshSpy).toHaveBeenCalledTimes(2); + expect(refreshSpy).toHaveBeenLastCalledWith([networkClientId]); + }, + ); + }); }); describe('metadata', () => { @@ -1438,6 +1641,8 @@ type WithControllerCallback = ({ networkClientIds: NetworkClientId[], queryAllAccounts?: boolean, ) => Promise; + blockTracker: MockBlockTracker; + provider: FakeProvider; }) => Promise | ReturnValue; type WithControllerOptions = { @@ -1561,7 +1766,9 @@ async function withController( // eslint-disable-next-line @typescript-eslint/no-explicit-any }) as any; - return { ...network, provider }; + const blockTracker = new MockBlockTracker(); + + return { ...network, provider, blockTracker }; }, ); @@ -1577,6 +1784,30 @@ async function withController( const mockNetworkState = jest.fn().mockReturnValue({ ...getDefaultNetworkControllerState(), chainId: initialChainId, + selectedNetworkClientId: 'mainnet', + networkConfigurationsByChainId: { + [initialChainId]: { + rpcEndpoints: [ + { + networkClientId: 'mainnet', + }, + ], + defaultRpcEndpointIndex: 0, + }, + ...Object.fromEntries( + Object.entries(networkClientById).map(([clientId, config]) => [ + config.configuration.chainId, + { + rpcEndpoints: [ + { + networkClientId: clientId, + }, + ], + defaultRpcEndpointIndex: 0, + }, + ]), + ), + }, }); messenger.registerActionHandler( @@ -1604,8 +1835,23 @@ async function withController( messenger.publish('AccountsController:selectedEvmAccountChange', account); }; + // Create provider and blockTracker + const provider = new FakeProvider({ + stubs: [ + { + request: { + method: 'eth_chainId', + }, + response: { result: initialChainId }, + }, + ], + }); + const blockTracker = new MockBlockTracker(); + const controller = new AccountTrackerController({ messenger: accountTrackerMessenger, + provider, + blockTracker, getStakedBalanceForChain: jest.fn(), ...options, }); @@ -1625,6 +1871,8 @@ async function withController( messenger, triggerSelectedAccountChange, refresh, + blockTracker, + provider, }); } diff --git a/packages/assets-controllers/src/AccountTrackerController.ts b/packages/assets-controllers/src/AccountTrackerController.ts index d0e07e9a96a..31ee95b493d 100644 --- a/packages/assets-controllers/src/AccountTrackerController.ts +++ b/packages/assets-controllers/src/AccountTrackerController.ts @@ -10,6 +10,7 @@ import type { ControllerGetStateAction, RestrictedMessenger, } from '@metamask/base-controller'; +import { BaseController } from '@metamask/base-controller'; import { query, safelyExecuteWithTimeout, @@ -18,12 +19,12 @@ import { import EthQuery from '@metamask/eth-query'; import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { + BlockTracker, NetworkClient, NetworkClientId, NetworkControllerGetNetworkClientByIdAction, NetworkControllerGetStateAction, } from '@metamask/network-controller'; -import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { PreferencesControllerGetStateAction } from '@metamask/preferences-controller'; import type { TransactionControllerTransactionConfirmedEvent, @@ -33,6 +34,7 @@ import type { import { assert, type Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; import { cloneDeep, isEqual } from 'lodash'; +import { v4 as random } from 'uuid'; import { STAKING_CONTRACT_ADDRESS_BY_CHAINID, @@ -213,16 +215,10 @@ export type AccountTrackerControllerMessenger = RestrictedMessenger< AllowedEvents['type'] >; -/** The input to start polling for the {@link AccountTrackerController} */ -type AccountTrackerPollingInput = { - networkClientIds: NetworkClientId[]; - queryAllAccounts?: boolean; -}; - /** * Controller that tracks the network balances for all user accounts. */ -export class AccountTrackerController extends StaticIntervalPollingController()< +export class AccountTrackerController extends BaseController< typeof controllerName, AccountTrackerControllerState, AccountTrackerControllerMessenger @@ -237,30 +233,41 @@ export class AccountTrackerController extends StaticIntervalPollingController>(); + + readonly #listeners: Record< + NetworkClientId, + (blockNumber: string) => Promise + > = {}; + + readonly #blockTracker: BlockTracker; + + #currentBlockNumberByChainId: Record = {}; + /** * Creates an AccountTracker instance. * * @param options - The controller options. - * @param options.interval - Polling interval used to fetch new account balances. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messaging system. + * @param options.blockTracker - A block tracker, which emits events for each new block. * @param options.getStakedBalanceForChain - The function to get the staked native asset balance for a chain. * @param options.includeStakedAssets - Whether to include staked assets in the account balances. * @param options.accountsApiChainIds - Function that returns array of chainIds that should use Accounts-API strategy (if supported by API). * @param options.allowExternalServices - Disable external HTTP calls (privacy / offline mode). */ constructor({ - interval = 10000, state, messenger, + blockTracker, getStakedBalanceForChain, includeStakedAssets = false, accountsApiChainIds = () => [], allowExternalServices = () => true, }: { - interval?: number; state?: Partial; messenger: AccountTrackerControllerMessenger; + blockTracker: BlockTracker; getStakedBalanceForChain: AssetsContractController['getStakedBalanceForChain']; includeStakedAssets?: boolean; accountsApiChainIds?: () => ChainIdHex[]; @@ -286,6 +293,7 @@ export class AccountTrackerController extends StaticIntervalPollingController { @@ -340,6 +346,40 @@ export class AccountTrackerController extends StaticIntervalPollingController => { + await this.#updateForBlockByNetworkClientId(undefined, blockNumber); + }; + + /** + * Given a block, updates account balances for a specific network client + * + * @param networkClientId - optional network client ID to use instead of the globally selected network. + * @param blockNumber - the block number to update to. + * @fires 'block' The updated state, if all account updates are successful + */ + async #updateForBlockByNetworkClientId( + networkClientId: NetworkClientId | undefined, + blockNumber: string, + ): Promise { + const { chainId } = this.#getCorrectNetworkClient(networkClientId); + this.#currentBlockNumberByChainId[chainId] = blockNumber; + + try { + const networkClientIds = networkClientId + ? [networkClientId] + : this.#getNetworkClientIds(); + await this.refresh(networkClientIds); + } catch (err) { + console.error(err); + } + } + private syncAccounts(newChainIds: string[]) { const accountsByChainId = cloneDeep(this.state.accountsByChainId); const { selectedNetworkClientId } = this.messagingSystem.call( @@ -479,6 +519,24 @@ export class AccountTrackerController extends StaticIntervalPollingController { + this.#currentBlockNumberByChainId[this.#getCurrentChainId()] = + blockNumber; + }); + + // remove first to avoid double add + this.#blockTracker.removeListener('latest', this.#updateForBlock); + // add listener + this.#blockTracker.addListener('latest', this.#updateForBlock); + // fetch account balances + // eslint-disable-next-line @typescript-eslint/no-floating-promises + this.refresh([ + this.messagingSystem.call('NetworkController:getState') + .selectedNetworkClientId, + ]); + } + + /** + * Stops polling with global selected network + */ + stop(): void { + // remove listener + this.#blockTracker.removeListener('latest', this.#updateForBlock); + } + + /** + * Starts polling for a networkClientId * - * @param input - The input for the poll. - * @param input.networkClientIds - The network client IDs used to get balances. - * @param input.queryAllAccounts - Whether to query all accounts or just the selected account + * @param networkClientId - The networkClientId to start polling for + * @returns pollingToken */ - async _executePoll({ - networkClientIds, - queryAllAccounts = false, - }: AccountTrackerPollingInput): Promise { - // TODO: Either fix this lint violation or explain why it's necessary to ignore. + startPollingByNetworkClientId(networkClientId: NetworkClientId): string { + const pollToken = random(); + + const pollingTokenSet = this.#pollingTokenSets.get(networkClientId); + if (pollingTokenSet) { + pollingTokenSet.add(pollToken); + } else { + const set = new Set(); + set.add(pollToken); + this.#pollingTokenSets.set(networkClientId, set); + this.#subscribeWithNetworkClientId(networkClientId); + } + return pollToken; + } + + /** + * Stops polling for all networkClientIds + */ + stopAllPolling(): void { + this.stop(); + this.#pollingTokenSets.forEach((tokenSet, _networkClientId) => { + tokenSet.forEach((token) => { + this.stopPollingByPollingToken(token); + }); + }); + } + + /** + * Stops polling for a networkClientId + * + * @param pollingToken - The polling token to stop polling for + */ + stopPollingByPollingToken(pollingToken: string | undefined): void { + if (!pollingToken) { + throw new Error('pollingToken required'); + } + this.#pollingTokenSets.forEach((tokenSet, key) => { + if (tokenSet.has(pollingToken)) { + tokenSet.delete(pollingToken); + if (tokenSet.size === 0) { + this.#pollingTokenSets.delete(key); + this.#unsubscribeWithNetworkClientId(key); + } + } + }); + } + + /** + * Subscribes from the block tracker for the given networkClientId if not currently subscribed + * + * @param networkClientId - network client ID to fetch a block tracker with + */ + #subscribeWithNetworkClientId(networkClientId: NetworkClientId): void { + if (this.#listeners[networkClientId]) { + return; + } + const { blockTracker } = this.#getCorrectNetworkClient(networkClientId); + const updateForBlock = (blockNumber: string) => + this.#updateForBlockByNetworkClientId(networkClientId, blockNumber); + blockTracker.addListener('latest', updateForBlock); + + this.#listeners[networkClientId] = updateForBlock; + // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.refresh(networkClientIds, queryAllAccounts); + this.refresh([networkClientId]); + } + + /** + * Unsubscribes from the block tracker for the given networkClientId if currently subscribed + * + * @param networkClientId - The network client ID to fetch a block tracker with + */ + #unsubscribeWithNetworkClientId(networkClientId: NetworkClientId): void { + if (!this.#listeners[networkClientId]) { + return; + } + const { blockTracker } = this.#getCorrectNetworkClient(networkClientId); + blockTracker.removeListener('latest', this.#listeners[networkClientId]); + + delete this.#listeners[networkClientId]; } /** @@ -517,7 +680,7 @@ export class AccountTrackerController extends StaticIntervalPollingController