diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json index 236fcc9d15..993ff36acf 100644 --- a/packages/snaps-controllers/coverage.json +++ b/packages/snaps-controllers/coverage.json @@ -1,6 +1,6 @@ { - "branches": 93.51, - "functions": 98.14, - "lines": 98.48, - "statements": 98.31 + "branches": 93.61, + "functions": 98.16, + "lines": 98.5, + "statements": 98.33 } diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx index 27b5673624..f3a85abb8c 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx +++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx @@ -4702,7 +4702,7 @@ describe('SnapController', () => { }, }); - expect(rootMessenger.call).toHaveBeenCalledTimes(4); + expect(rootMessenger.call).toHaveBeenCalledTimes(5); expect(rootMessenger.call).toHaveBeenCalledWith( 'ExecutionService:handleRpcRequest', MOCK_SNAP_ID, @@ -9214,46 +9214,6 @@ describe('SnapController', () => { }); }); - describe('getRegistryMetadata', () => { - it('returns the metadata for a verified snap', async () => { - const registry = new MockSnapsRegistry(); - const rootMessenger = getControllerMessenger(registry); - const messenger = getSnapControllerMessenger(rootMessenger); - registry.getMetadata.mockReturnValue({ - name: 'Mock Snap', - }); - - const snapController = getSnapController( - getSnapControllerOptions({ - messenger, - }), - ); - - expect( - await snapController.getRegistryMetadata(MOCK_SNAP_ID), - ).toStrictEqual({ - name: 'Mock Snap', - }); - - snapController.destroy(); - }); - - it('returns null for a non-verified snap', async () => { - const registry = new MockSnapsRegistry(); - const rootMessenger = getControllerMessenger(registry); - const messenger = getSnapControllerMessenger(rootMessenger); - const snapController = getSnapController( - getSnapControllerOptions({ - messenger, - }), - ); - - expect(await snapController.getRegistryMetadata(MOCK_SNAP_ID)).toBeNull(); - - snapController.destroy(); - }); - }); - describe('clearState', () => { it('clears the state and terminates running snaps', async () => { const rootMessenger = getControllerMessenger(); @@ -9416,6 +9376,195 @@ describe('SnapController', () => { snapController.destroy(); }); + + it('should track event for allowed handler', async () => { + const mockTrackEvent = jest.fn(); + const rootMessenger = getControllerMessenger(); + const executionEnvironmentStub = new ExecutionEnvironmentStub( + getNodeEESMessenger(rootMessenger), + ) as unknown as NodeThreadExecutionService; + + const [snapController] = getSnapControllerWithEES( + getSnapControllerWithEESOptions({ + rootMessenger, + trackEvent: mockTrackEvent, + state: { + snaps: getPersistedSnapsState(), + }, + }), + executionEnvironmentStub, + ); + + const snap = snapController.getExpect(MOCK_SNAP_ID); + await snapController.startSnap(snap.id); + + await snapController.handleRequest({ + snapId: snap.id, + origin: MOCK_ORIGIN, + handler: HandlerType.OnRpcRequest, + request: { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 1, + }, + }); + + expect(mockTrackEvent).toHaveBeenCalledTimes(1); + expect(mockTrackEvent).toHaveBeenCalledWith({ + event: 'SnapExportUsed', + category: 'Snaps', + properties: { + export: 'onRpcRequest', + origin: 'https://example.com', + // eslint-disable-next-line @typescript-eslint/naming-convention + snap_category: null, + // eslint-disable-next-line @typescript-eslint/naming-convention + snap_id: 'npm:@metamask/example-snap', + success: true, + }, + }); + snapController.destroy(); + }); + + it('should not track event for disallowed handler', async () => { + const mockTrackEvent = jest.fn(); + const rootMessenger = getControllerMessenger(); + + rootMessenger.registerActionHandler( + 'PermissionController:getPermissions', + () => ({ + [SnapEndowments.Cronjob]: { + caveats: [ + { type: SnapCaveatType.SnapCronjob, value: '* * * * *' }, + ], + date: 1664187844588, + id: 'izn0WGUO8cvq_jqvLQuQP', + invoker: MOCK_SNAP_ID, + parentCapability: SnapEndowments.Cronjob, + }, + }), + ); + + const executionEnvironmentStub = new ExecutionEnvironmentStub( + getNodeEESMessenger(rootMessenger), + ) as unknown as NodeThreadExecutionService; + + const [snapController] = getSnapControllerWithEES( + getSnapControllerWithEESOptions({ + environmentEndowmentPermissions: ['endowment:cronjob'], + rootMessenger, + trackEvent: mockTrackEvent, + state: { + snaps: getPersistedSnapsState(), + }, + }), + executionEnvironmentStub, + ); + + const snap = snapController.getExpect(MOCK_SNAP_ID); + await snapController.startSnap(snap.id); + + await snapController.handleRequest({ + snapId: snap.id, + origin: MOCK_ORIGIN, + handler: HandlerType.OnCronjob, + request: { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 1, + }, + }); + + expect(mockTrackEvent).not.toHaveBeenCalled(); + snapController.destroy(); + }); + + it('should properly handle error when MetaMetrics hook throws an error', async () => { + const log = jest.spyOn(console, 'error').mockImplementation(); + const error = new Error('MetaMetrics hook error'); + const mockTrackEvent = jest.fn().mockImplementation(() => { + throw error; + }); + const rootMessenger = getControllerMessenger(); + const executionEnvironmentStub = new ExecutionEnvironmentStub( + getNodeEESMessenger(rootMessenger), + ) as unknown as NodeThreadExecutionService; + + const [snapController] = getSnapControllerWithEES( + getSnapControllerWithEESOptions({ + rootMessenger, + trackEvent: mockTrackEvent, + state: { + snaps: getPersistedSnapsState(), + }, + }), + executionEnvironmentStub, + ); + + const snap = snapController.getExpect(MOCK_SNAP_ID); + await snapController.startSnap(snap.id); + + await snapController.handleRequest({ + snapId: snap.id, + origin: MOCK_ORIGIN, + handler: HandlerType.OnRpcRequest, + request: { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 1, + }, + }); + + expect(mockTrackEvent).toHaveBeenCalled(); + expect(log).toHaveBeenCalledWith( + expect.stringContaining( + 'Error when calling MetaMetrics hook for snap', + ), + ); + snapController.destroy(); + }); + + it('should not track event for preinstalled snap', async () => { + const mockTrackEvent = jest.fn(); + const rootMessenger = getControllerMessenger(); + const executionEnvironmentStub = new ExecutionEnvironmentStub( + getNodeEESMessenger(rootMessenger), + ) as unknown as NodeThreadExecutionService; + + const [snapController] = getSnapControllerWithEES( + getSnapControllerWithEESOptions({ + rootMessenger, + trackEvent: mockTrackEvent, + state: { + snaps: getPersistedSnapsState( + getPersistedSnapObject({ preinstalled: true }), + ), + }, + }), + executionEnvironmentStub, + ); + + const snap = snapController.getExpect(MOCK_SNAP_ID); + await snapController.startSnap(snap.id); + + await snapController.handleRequest({ + snapId: snap.id, + origin: MOCK_ORIGIN, + handler: HandlerType.OnRpcRequest, + request: { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 1, + }, + }); + + expect(mockTrackEvent).not.toHaveBeenCalled(); + snapController.destroy(); + }); }); it('handles a transaction insight request', async () => { @@ -10388,35 +10537,6 @@ describe('SnapController', () => { }); }); - describe('SnapController:getRegistryMetadata', () => { - it('calls SnapController.getRegistryMetadata()', async () => { - const registry = new MockSnapsRegistry(); - const rootMessenger = getControllerMessenger(registry); - const messenger = getSnapControllerMessenger(rootMessenger); - - registry.getMetadata.mockReturnValue({ - name: 'Mock Snap', - }); - - const snapController = getSnapController( - getSnapControllerOptions({ - messenger, - }), - ); - - expect( - await messenger.call( - 'SnapController:getRegistryMetadata', - MOCK_SNAP_ID, - ), - ).toStrictEqual({ - name: 'Mock Snap', - }); - - snapController.destroy(); - }); - }); - describe('SnapController:disconnectOrigin', () => { it('calls SnapController.removeSnapFromSubject()', () => { const messenger = getSnapControllerMessenger(); diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts index 12755c9d7f..08b1d7b70e 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.ts +++ b/packages/snaps-controllers/src/snaps/SnapController.ts @@ -146,7 +146,6 @@ import type { GetResult, ResolveVersion, SnapsRegistryInfo, - SnapsRegistryMetadata, SnapsRegistryRequest, Update, } from './registry'; @@ -176,7 +175,9 @@ import { hasTimedOut, permissionsDiff, setDiff, + throttleTracking, withTimeout, + isTrackableHandler, } from '../utils'; export const controllerName = 'SnapController'; @@ -442,11 +443,6 @@ export type InstallSnaps = { handler: SnapController['installSnaps']; }; -export type GetRegistryMetadata = { - type: `${typeof controllerName}:getRegistryMetadata`; - handler: SnapController['getRegistryMetadata']; -}; - export type DisconnectOrigin = { type: `${typeof controllerName}:disconnectOrigin`; handler: SnapController['removeSnapFromSubject']; @@ -484,7 +480,6 @@ export type SnapControllerActions = | GetRunnableSnaps | IncrementActiveReferences | DecrementActiveReferences - | GetRegistryMetadata | DisconnectOrigin | RevokeDynamicPermissions | GetSnapFile @@ -775,6 +770,11 @@ type SnapControllerArgs = { * object to fall back to the default cryptographic functions. */ clientCryptography?: CryptographicFunctions; + + /** + * MetaMetrics event tracking hook. + */ + trackEvent: TrackEventHook; }; type AddSnapArgs = { @@ -795,6 +795,14 @@ type SetSnapArgs = Omit & { hideSnapBranding?: boolean; }; +type TrackingEventPayload = { + event: string; + category: string; + properties: Record; +}; + +type TrackEventHook = (event: TrackingEventPayload) => void; + const defaultState: SnapControllerState = { snaps: {}, snapStates: {}, @@ -880,6 +888,10 @@ export class SnapController extends BaseController< readonly #preinstalledSnaps: PreinstalledSnap[] | null; + readonly #trackEvent: TrackEventHook; + + readonly #trackSnapExport: ReturnType; + constructor({ closeAllConnections, messenger, @@ -898,6 +910,7 @@ export class SnapController extends BaseController< getMnemonicSeed, getFeatureFlags = () => ({}), clientCryptography, + trackEvent, }: SnapControllerArgs) { super({ messenger, @@ -960,6 +973,7 @@ export class SnapController extends BaseController< this._onOutboundResponse = this._onOutboundResponse.bind(this); this.#rollbackSnapshots = new Map(); this.#snapsRuntimeData = new Map(); + this.#trackEvent = trackEvent; this.#pollForLastRequestStatus(); @@ -1025,6 +1039,28 @@ export class SnapController extends BaseController< Object.values(this.state?.snaps ?? {}).forEach((snap) => this.#setupRuntime(snap.id), ); + + this.#trackSnapExport = throttleTracking( + (snapId: SnapId, handler: string, success: boolean, origin: string) => { + const snapMetadata = this.messagingSystem.call( + 'SnapsRegistry:getMetadata', + snapId, + ); + this.#trackEvent({ + event: 'SnapExportUsed', + category: 'Snaps', + properties: { + // eslint-disable-next-line @typescript-eslint/naming-convention + snap_id: snapId, + export: handler, + // eslint-disable-next-line @typescript-eslint/naming-convention + snap_category: snapMetadata?.category ?? null, + success, + origin, + }, + }); + }, + ); } /** @@ -1180,11 +1216,6 @@ export class SnapController extends BaseController< (...args) => this.decrementActiveReferences(...args), ); - this.messagingSystem.registerActionHandler( - `${controllerName}:getRegistryMetadata`, - async (...args) => this.getRegistryMetadata(...args), - ); - this.messagingSystem.registerActionHandler( `${controllerName}:disconnectOrigin`, (...args) => this.removeSnapFromSubject(...args), @@ -2949,20 +2980,6 @@ export class SnapController extends BaseController< ); } - /** - * Get metadata for the given snap ID. - * - * @param snapId - The ID of the snap to get metadata for. - * @returns The metadata for the given snap ID, or `null` if the snap is not - * verified. - */ - async getRegistryMetadata( - snapId: SnapId, - ): Promise { - this.#assertCanUsePlatform(); - return await this.messagingSystem.call('SnapsRegistry:getMetadata', snapId); - } - /** * Returns a promise representing the complete installation of the requested snap. * If the snap is already being installed, the previously pending promise will be returned. @@ -3582,12 +3599,25 @@ export class SnapController extends BaseController< result, ); - this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id); + this.#recordSnapRpcRequestFinish( + snapId, + transformedRequest.id, + handlerType, + origin, + true, + ); return transformedResult; } catch (error) { // We flag the RPC request as finished early since termination may affect pending requests - this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id); + this.#recordSnapRpcRequestFinish( + snapId, + transformedRequest.id, + handlerType, + origin, + false, + ); + const [jsonRpcError, handled] = unwrapError(error); if (!handled) { @@ -3878,7 +3908,13 @@ export class SnapController extends BaseController< runtime.lastRequest = null; } - #recordSnapRpcRequestFinish(snapId: SnapId, requestId: unknown) { + #recordSnapRpcRequestFinish( + snapId: SnapId, + requestId: unknown, + handlerType: HandlerType, + origin: string, + success: boolean, + ) { const runtime = this.#getRuntimeExpect(snapId); runtime.pendingInboundRequests = runtime.pendingInboundRequests.filter( (request) => request.requestId !== requestId, @@ -3887,6 +3923,20 @@ export class SnapController extends BaseController< if (runtime.pendingInboundRequests.length === 0) { runtime.lastRequest = Date.now(); } + + const snap = this.get(snapId); + + if (isTrackableHandler(handlerType) && !snap?.preinstalled) { + try { + this.#trackSnapExport(snapId, handlerType, success, origin); + } catch (error) { + logError( + `Error when calling MetaMetrics hook for snap "${snap?.id}": ${getErrorMessage( + error, + )}`, + ); + } + } } /** diff --git a/packages/snaps-controllers/src/snaps/registry/json.test.ts b/packages/snaps-controllers/src/snaps/registry/json.test.ts index c2adc34c69..ce245c4b72 100644 --- a/packages/snaps-controllers/src/snaps/registry/json.test.ts +++ b/packages/snaps-controllers/src/snaps/registry/json.test.ts @@ -470,10 +470,8 @@ describe('JsonSnapsRegistry', () => { .mockResponseOnce(JSON.stringify(MOCK_SIGNATURE_FILE)); const { messenger } = getRegistry(); - const result = await messenger.call( - 'SnapsRegistry:getMetadata', - MOCK_SNAP_ID, - ); + await messenger.call('SnapsRegistry:update'); + const result = messenger.call('SnapsRegistry:getMetadata', MOCK_SNAP_ID); expect(result).toStrictEqual({ name: 'Mock Snap', @@ -486,7 +484,8 @@ describe('JsonSnapsRegistry', () => { .mockResponseOnce(JSON.stringify(MOCK_SIGNATURE_FILE)); const { messenger } = getRegistry(); - const result = await messenger.call('SnapsRegistry:getMetadata', 'foo'); + await messenger.call('SnapsRegistry:update'); + const result = messenger.call('SnapsRegistry:getMetadata', 'foo'); expect(result).toBeNull(); }); diff --git a/packages/snaps-controllers/src/snaps/registry/json.ts b/packages/snaps-controllers/src/snaps/registry/json.ts index 1ec2baf62c..86f7828ac5 100644 --- a/packages/snaps-controllers/src/snaps/registry/json.ts +++ b/packages/snaps-controllers/src/snaps/registry/json.ts @@ -166,7 +166,7 @@ export class JsonSnapsRegistry extends BaseController< this.messagingSystem.registerActionHandler( 'SnapsRegistry:getMetadata', - async (...args) => this.#getMetadata(...args), + (...args) => this.#getMetadata(...args), ); this.messagingSystem.registerActionHandler( @@ -346,15 +346,14 @@ export class JsonSnapsRegistry extends BaseController< } /** - * Get metadata for the given snap ID. + * Get metadata for the given snap ID, if available, without updating registry. * * @param snapId - The ID of the snap to get metadata for. * @returns The metadata for the given snap ID, or `null` if the snap is not * verified. */ - async #getMetadata(snapId: string): Promise { - const database = await this.#getDatabase(); - return database?.verifiedSnaps[snapId]?.metadata ?? null; + #getMetadata(snapId: string): SnapsRegistryMetadata | null { + return this.state?.database?.verifiedSnaps[snapId]?.metadata ?? null; } /** diff --git a/packages/snaps-controllers/src/snaps/registry/registry.ts b/packages/snaps-controllers/src/snaps/registry/registry.ts index 21e50dd98b..07c5009677 100644 --- a/packages/snaps-controllers/src/snaps/registry/registry.ts +++ b/packages/snaps-controllers/src/snaps/registry/registry.ts @@ -49,5 +49,5 @@ export type SnapsRegistry = { * @returns The metadata for the given snap ID, or `null` if the snap is not * verified. */ - getMetadata(snapId: SnapId): Promise; + getMetadata(snapId: SnapId): SnapsRegistryMetadata | null; }; diff --git a/packages/snaps-controllers/src/test-utils/controller.ts b/packages/snaps-controllers/src/test-utils/controller.ts index dfbfc27f2f..3cff5c0499 100644 --- a/packages/snaps-controllers/src/test-utils/controller.ts +++ b/packages/snaps-controllers/src/test-utils/controller.ts @@ -593,6 +593,7 @@ export const getSnapControllerOptions = ( Promise.resolve(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES), clientCryptography: {}, encryptor: getSnapControllerEncryptor(), + trackEvent: jest.fn(), ...opts, } as SnapControllerConstructorParams; @@ -626,6 +627,7 @@ export const getSnapControllerWithEESOptions = ({ Promise.resolve(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES), encryptor: getSnapControllerEncryptor(), fetchFunction: jest.fn(), + trackEvent: jest.fn(), ...options, } as SnapControllerConstructorParams & { rootMessenger: ReturnType; diff --git a/packages/snaps-controllers/src/test-utils/registry.ts b/packages/snaps-controllers/src/test-utils/registry.ts index 20e3c859ee..c66b931977 100644 --- a/packages/snaps-controllers/src/test-utils/registry.ts +++ b/packages/snaps-controllers/src/test-utils/registry.ts @@ -18,7 +18,7 @@ export class MockSnapsRegistry implements SnapsRegistry { throw new Error('The snap is not on the allowlist.'); }); - getMetadata = jest.fn().mockResolvedValue(null); + getMetadata = jest.fn().mockReturnValue(null); update = jest.fn(); } diff --git a/packages/snaps-controllers/src/utils.test.ts b/packages/snaps-controllers/src/utils.test.ts index 7899c2a773..cc0a320bc9 100644 --- a/packages/snaps-controllers/src/utils.test.ts +++ b/packages/snaps-controllers/src/utils.test.ts @@ -1,4 +1,4 @@ -import { VirtualFile } from '@metamask/snaps-utils'; +import { HandlerType, VirtualFile } from '@metamask/snaps-utils'; import { getMockSnapFiles, getSnapManifest, @@ -20,6 +20,8 @@ import { getSnapFiles, permissionsDiff, setDiff, + throttleTracking, + TRACKABLE_HANDLERS, } from './utils'; import { SnapEndowments } from '../../snaps-rpc-methods/src/endowments'; @@ -221,3 +223,161 @@ describe('debouncePersistState', () => { expect(fn).toHaveBeenNthCalledWith(4, MOCK_LOCAL_SNAP_ID, {}, false); }); }); + +describe('TRACKABLE_HANDLERS', () => { + it('should contain the expected handler types', () => { + expect(TRACKABLE_HANDLERS).toStrictEqual([ + HandlerType.OnHomePage, + HandlerType.OnInstall, + HandlerType.OnNameLookup, + HandlerType.OnRpcRequest, + HandlerType.OnSignature, + HandlerType.OnTransaction, + HandlerType.OnUpdate, + ]); + }); + + it('should be a readonly array', () => { + expect(Object.isFrozen(TRACKABLE_HANDLERS)).toBe(true); + }); + + it('should contain unique values', () => { + const uniqueValues = new Set(TRACKABLE_HANDLERS); + expect(uniqueValues.size).toBe(TRACKABLE_HANDLERS.length); + }); +}); + +describe('throttleTracking', () => { + beforeAll(() => { + jest.useFakeTimers(); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + it('throttles tracking calls based on unique combinations of snapId, handler, and origin', () => { + const fn = jest.fn(); + const throttled = throttleTracking(fn, 1000); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnRpcRequest, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin2'); + + expect(fn).toHaveBeenCalledTimes(3); + expect(fn).toHaveBeenNthCalledWith( + 1, + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin1', + ); + expect(fn).toHaveBeenNthCalledWith( + 2, + MOCK_SNAP_ID, + HandlerType.OnRpcRequest, + true, + 'origin1', + ); + expect(fn).toHaveBeenNthCalledWith( + 3, + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin2', + ); + + jest.advanceTimersByTime(500); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnRpcRequest, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin2'); + + expect(fn).toHaveBeenCalledTimes(3); + + jest.advanceTimersByTime(500); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnRpcRequest, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin2'); + + expect(fn).toHaveBeenCalledTimes(3); + + jest.advanceTimersByTime(1000); + + expect(fn).toHaveBeenCalledTimes(6); + expect(fn).toHaveBeenNthCalledWith( + 4, + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin1', + ); + expect(fn).toHaveBeenNthCalledWith( + 5, + MOCK_SNAP_ID, + HandlerType.OnRpcRequest, + true, + 'origin1', + ); + expect(fn).toHaveBeenNthCalledWith( + 6, + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin2', + ); + + jest.advanceTimersByTime(5000); + expect(fn).toHaveBeenCalledTimes(6); + }); + + it('uses default timeout of 60000ms when no timeout is specified', async () => { + const fn = jest.fn(); + const throttled = throttleTracking(fn); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + expect(fn).toHaveBeenCalledTimes(1); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + expect(fn).toHaveBeenCalledTimes(1); + + jest.advanceTimersByTime(60000); + expect(fn).toHaveBeenCalledTimes(2); + }); + + it('should execute the last throttled call after timeout', () => { + const mockFn = jest.fn(); + const throttled = throttleTracking(mockFn, 1000); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + expect(mockFn).toHaveBeenCalledTimes(1); + expect(mockFn).toHaveBeenLastCalledWith( + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin1', + ); + + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + throttled(MOCK_SNAP_ID, HandlerType.OnHomePage, true, 'origin1'); + + expect(mockFn).toHaveBeenCalledTimes(1); + + jest.advanceTimersByTime(500); + + expect(mockFn).toHaveBeenCalledTimes(1); + + jest.advanceTimersByTime(500); + + expect(mockFn).toHaveBeenCalledTimes(2); + expect(mockFn).toHaveBeenLastCalledWith( + MOCK_SNAP_ID, + HandlerType.OnHomePage, + true, + 'origin1', + ); + }); +}); diff --git a/packages/snaps-controllers/src/utils.ts b/packages/snaps-controllers/src/utils.ts index 2292c0462f..1aec363637 100644 --- a/packages/snaps-controllers/src/utils.ts +++ b/packages/snaps-controllers/src/utils.ts @@ -7,6 +7,7 @@ import { getValidatedLocalizationFiles, validateAuxiliaryFiles, validateFetchedSnap, + HandlerType, } from '@metamask/snaps-utils'; import type { Json } from '@metamask/utils'; import deepEqual from 'fast-deep-equal'; @@ -375,3 +376,97 @@ export function debouncePersistState( ); }; } + +/** + * Handlers allowed for tracking. + */ +export const TRACKABLE_HANDLERS = Object.freeze([ + HandlerType.OnHomePage, + HandlerType.OnInstall, + HandlerType.OnNameLookup, + HandlerType.OnRpcRequest, + HandlerType.OnSignature, + HandlerType.OnTransaction, + HandlerType.OnUpdate, +] as const); + +/** + * A union type representing all possible trackable handler types. + */ +export type TrackableHandler = (typeof TRACKABLE_HANDLERS)[number]; + +/** + * Throttles event tracking calls per unique combination of parameters. + * + * @param fn - The tracking function to throttle. + * @param timeout - The timeout in milliseconds. Defaults to 60000 (1 minute). + * @returns The throttled function. + */ +export function throttleTracking( + fn: ( + snapId: SnapId, + handler: TrackableHandler, + success: boolean, + origin: string, + ) => void, + timeout = 60000, +) { + const previousCalls = new Map(); + const pendingCalls = new Map< + string, + { + args: [SnapId, TrackableHandler, boolean, string]; + timer: ReturnType | null; + } + >(); + + return ( + snapId: SnapId, + handler: TrackableHandler, + success: boolean, + origin: string, + ): void => { + const key = `${snapId}${handler}${success}${origin}`; + const now = Date.now(); + const lastCall = previousCalls.get(key) ?? 0; + const args: [SnapId, TrackableHandler, boolean, string] = [ + snapId, + handler, + success, + origin, + ]; + + if (now - lastCall >= timeout) { + previousCalls.set(key, now); + fn(...args); + return; + } + + const pending = pendingCalls.get(key); + if (pending?.timer) { + clearTimeout(pending.timer); + } + + previousCalls.set(key, now); + + pendingCalls.set(key, { + args, + timer: setTimeout(() => { + fn(...args); + pendingCalls.delete(key); + }, timeout), + }); + }; +} + +/** + * Whether the handler type if allowed for tracking. + * + * @param handler Type of a handler. + * @returns True if handler is allowed for tracking, false otherwise. + */ +export function isTrackableHandler( + handler: HandlerType, +): handler is TrackableHandler { + return TRACKABLE_HANDLERS.includes(handler as TrackableHandler); +}