From 1f2109f7edc694543bdf9dc19ffba000da13bbc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Tani=C3=A7a?= Date: Fri, 8 May 2026 06:25:19 -0600 Subject: [PATCH 1/2] feat(predict): add confirmation hook plumbing --- .../PredictController-method-action-types.ts | 12 ++ .../controllers/PredictController.test.ts | 34 ++++ .../Predict/controllers/PredictController.ts | 16 ++ .../transaction-controller-init.test.ts | 148 +++++++++++++++++- .../transaction-controller-init.ts | 64 ++++++++ 5 files changed, 268 insertions(+), 6 deletions(-) diff --git a/app/components/UI/Predict/controllers/PredictController-method-action-types.ts b/app/components/UI/Predict/controllers/PredictController-method-action-types.ts index 611cc69c66c..5dc2a92767b 100644 --- a/app/components/UI/Predict/controllers/PredictController-method-action-types.ts +++ b/app/components/UI/Predict/controllers/PredictController-method-action-types.ts @@ -248,11 +248,21 @@ export type PredictControllerPrepareWithdrawAction = { handler: PredictController['prepareWithdraw']; }; +export type PredictControllerBeforePublishAction = { + type: `PredictController:beforePublish`; + handler: PredictController['beforePublish']; +}; + export type PredictControllerBeforeSignAction = { type: `PredictController:beforeSign`; handler: PredictController['beforeSign']; }; +export type PredictControllerPublishAction = { + type: `PredictController:publish`; + handler: PredictController['publish']; +}; + export type PredictControllerClearWithdrawTransactionAction = { type: `PredictController:clearWithdrawTransaction`; handler: PredictController['clearWithdrawTransaction']; @@ -300,5 +310,7 @@ export type PredictControllerMethodActions = | PredictControllerGetAccountStateAction | PredictControllerGetBalanceAction | PredictControllerPrepareWithdrawAction + | PredictControllerBeforePublishAction | PredictControllerBeforeSignAction + | PredictControllerPublishAction | PredictControllerClearWithdrawTransactionAction; diff --git a/app/components/UI/Predict/controllers/PredictController.test.ts b/app/components/UI/Predict/controllers/PredictController.test.ts index 6b0122eb258..08562cacbcc 100644 --- a/app/components/UI/Predict/controllers/PredictController.test.ts +++ b/app/components/UI/Predict/controllers/PredictController.test.ts @@ -6139,6 +6139,40 @@ describe('PredictController', () => { }); }); + describe('beforePublish', () => { + it('passes through by default', async () => { + await withController(async ({ controller }) => { + const result = await controller.beforePublish({ + transactionMeta: { + id: 'tx-1', + txParams: { + from: MOCK_ADDRESS, + }, + } as TransactionMeta, + }); + + expect(result).toBe(true); + }); + }); + }); + + describe('publish', () => { + it('passes through by default', async () => { + await withController(async ({ controller }) => { + const result = await controller.publish({ + transactionMeta: { + id: 'tx-1', + txParams: { + from: MOCK_ADDRESS, + }, + } as TransactionMeta, + }); + + expect(result).toEqual({ transactionHash: undefined }); + }); + }); + }); + describe('beforeSign', () => { const mockTransactionMeta = { id: 'tx-1', diff --git a/app/components/UI/Predict/controllers/PredictController.ts b/app/components/UI/Predict/controllers/PredictController.ts index 4396fb809e1..96323c345cc 100644 --- a/app/components/UI/Predict/controllers/PredictController.ts +++ b/app/components/UI/Predict/controllers/PredictController.ts @@ -345,6 +345,7 @@ export interface PredictControllerOptions { } const MESSENGER_EXPOSED_METHODS = [ + 'beforePublish', 'beforeSign', 'claimWithConfirmation', 'clearActiveOrder', @@ -370,6 +371,7 @@ const MESSENGER_EXPOSED_METHODS = [ 'onPlaceOrderSuccess', 'placeOrder', 'prepareWithdraw', + 'publish', 'previewOrder', 'refreshEligibility', 'selectPaymentToken', @@ -2619,6 +2621,12 @@ export class PredictController extends BaseController< } } + public async beforePublish(_request: { + transactionMeta: TransactionMeta; + }): Promise { + return true; + } + public async beforeSign(request: { transactionMeta: TransactionMeta; }): Promise< @@ -2722,6 +2730,12 @@ export class PredictController extends BaseController< }; } + public async publish(_request: { + transactionMeta: TransactionMeta; + }): Promise<{ transactionHash?: string; isIntentComplete?: boolean }> { + return { transactionHash: undefined }; + } + public clearWithdrawTransaction(): void { this.update((state) => { state.withdrawTransaction = null; @@ -2730,6 +2744,7 @@ export class PredictController extends BaseController< } export type { + PredictControllerBeforePublishAction, PredictControllerBeforeSignAction, PredictControllerClaimWithConfirmationAction, PredictControllerClearActiveOrderAction, @@ -2754,6 +2769,7 @@ export type { PredictControllerPlaceOrderAction, PredictControllerPrepareWithdrawAction, PredictControllerPreviewOrderAction, + PredictControllerPublishAction, PredictControllerRefreshEligibilityAction, PredictControllerSelectPaymentTokenAction, PredictControllerSetSelectedPaymentTokenAction, diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts index 7d77de82cfe..8ab21bf6728 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts @@ -92,10 +92,20 @@ const MOCK_TRANSACTION_META = { * with the default mock. * @returns A mock NetworkController. */ +type ControllerMock = NetworkController & { + beforePublish: jest.Mock; + beforeSign: jest.Mock; + publish: jest.Mock; +}; + function buildControllerMock( - partialMock?: Partial, -): NetworkController { - const defaultControllerMocks = {}; + partialMock?: Partial, +): ControllerMock { + const defaultControllerMocks = { + beforePublish: jest.fn().mockResolvedValue(true), + beforeSign: jest.fn(), + publish: jest.fn().mockResolvedValue({ transactionHash: undefined }), + }; // @ts-expect-error Incomplete mock, just includes properties used by code-under-test. return { @@ -180,9 +190,11 @@ describe('Transaction Controller Init', () => { ): TransactionControllerOptions[T] { const requestMock = buildInitRequestMock(initRequestProperties); - requestMock.getMessengerClient.mockReturnValue( - buildControllerMock(dependencyProperties), - ); + if (!initRequestProperties.getMessengerClient) { + requestMock.getMessengerClient.mockReturnValue( + buildControllerMock(dependencyProperties), + ); + } TransactionControllerInit(requestMock); @@ -320,6 +332,29 @@ describe('Transaction Controller Init', () => { expect(optionFn?.()).toBe(false); }); + describe('beforePublish hook', () => { + it('delegates to PredictController beforePublish', async () => { + const predictControllerMock = buildControllerMock(); + const hooks = testConstructorOption( + 'hooks', + {}, + { + getMessengerClient: jest.fn((controllerName: string) => + controllerName === 'PredictController' + ? predictControllerMock + : buildControllerMock(), + ), + }, + ); + + await hooks?.beforePublish?.(MOCK_TRANSACTION_META); + + expect(predictControllerMock.beforePublish).toHaveBeenCalledWith({ + transactionMeta: MOCK_TRANSACTION_META, + }); + }); + }); + describe('publish hook', () => { it('calls submitSmartTransactionHook', async () => { const hooks = testConstructorOption('hooks'); @@ -347,6 +382,107 @@ describe('Transaction Controller Init', () => { expect(payHookMock).toHaveBeenCalledTimes(1); }); + it('calls Predict publish before pay and smart transaction hooks', async () => { + const predictControllerMock = buildControllerMock(); + const hooks = testConstructorOption( + 'hooks', + {}, + { + getMessengerClient: jest.fn((controllerName: string) => + controllerName === 'PredictController' + ? predictControllerMock + : buildControllerMock(), + ), + }, + ); + + await hooks?.publish?.(MOCK_TRANSACTION_META); + + expect(predictControllerMock.publish).toHaveBeenCalledWith({ + transactionMeta: MOCK_TRANSACTION_META, + }); + expect(payHookMock).toHaveBeenCalledTimes(1); + expect( + (predictControllerMock.publish as jest.Mock).mock + .invocationCallOrder[0], + ).toBeLessThan(payHookMock.mock.invocationCallOrder[0]); + expect( + (predictControllerMock.publish as jest.Mock).mock + .invocationCallOrder[0], + ).toBeLessThan( + submitSmartTransactionHookMock.mock.invocationCallOrder[0], + ); + }); + + it('short-circuits publish when Predict returns a transaction hash', async () => { + const predictControllerMock = buildControllerMock({ + publish: jest.fn().mockResolvedValue({ transactionHash: '0xpredict' }), + } as unknown as Partial); + const hooks = testConstructorOption( + 'hooks', + {}, + { + getMessengerClient: jest.fn((controllerName: string) => + controllerName === 'PredictController' + ? predictControllerMock + : buildControllerMock(), + ), + }, + ); + + const result = await hooks?.publish?.(MOCK_TRANSACTION_META); + + expect(result).toEqual({ transactionHash: '0xpredict' }); + expect(payHookMock).not.toHaveBeenCalled(); + expect(submitSmartTransactionHookMock).not.toHaveBeenCalled(); + }); + + it('marks the latest transaction intent complete when Predict publish completes an intent', async () => { + const predictControllerMock = buildControllerMock({ + publish: jest.fn().mockResolvedValue({ + transactionHash: '0xpredict', + isIntentComplete: true, + }), + } as unknown as Partial); + const getTransactionByIdMock = jest.requireMock( + '../../../../util/transactions', + ).getTransactionById; + getTransactionByIdMock.mockReturnValue({ ...MOCK_TRANSACTION_META }); + const hooks = testConstructorOption( + 'hooks', + {}, + { + getMessengerClient: jest.fn((controllerName: string) => + controllerName === 'PredictController' + ? predictControllerMock + : buildControllerMock(), + ), + }, + ); + + const result = await hooks?.publish?.(MOCK_TRANSACTION_META); + + const transactionControllerInstance = transactionControllerClassMock.mock + .instances[0] as unknown as { + updateTransaction: jest.Mock; + }; + + expect(result).toEqual({ transactionHash: '0xpredict' }); + expect(getTransactionByIdMock).toHaveBeenCalledWith( + MOCK_TRANSACTION_META.id, + transactionControllerInstance, + ); + expect( + transactionControllerInstance.updateTransaction, + ).toHaveBeenCalledWith( + expect.objectContaining({ + id: MOCK_TRANSACTION_META.id, + isIntentComplete: true, + }), + 'Predict claim relayer intent complete', + ); + }); + it('passes isSmartTransaction returning false to pay hook when stxDisabled is true', async () => { selectMetaMaskPayFlagsMock.mockReturnValue({ attemptsMax: 2, diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts index 64bc20e25cb..c71e6cb00c6 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts @@ -125,6 +125,7 @@ export const TransactionControllerInit: MessengerClientInitFunction< transactionController, smartTransactionsController, initMessenger, + request, signedTransactionInHex, }), publishBatch: async (_request: PublishBatchHookRequest) => @@ -136,6 +137,8 @@ export const TransactionControllerInit: MessengerClientInitFunction< transactions: _request.transactions as PublishBatchHookTransaction[], }), + beforePublish: (transactionMeta: TransactionMeta) => + beforePublish(transactionMeta, request), beforeSign: (_request: { transactionMeta: TransactionMeta }) => beforeSign(_request, request), }, @@ -216,6 +219,7 @@ async function publishHook({ transactionController, smartTransactionsController, initMessenger, + request, signedTransactionInHex, }: { transactionMeta: TransactionMeta; @@ -224,8 +228,22 @@ async function publishHook({ transactionController: TransactionController; smartTransactionsController: SmartTransactionsController; initMessenger: TransactionControllerInitMessenger; + request: MessengerClientInitRequest< + TransactionControllerMessenger, + TransactionControllerInitMessenger + >; signedTransactionInHex: Hex; }): Promise<{ transactionHash?: string }> { + const predictResult = await publishPredict({ + transactionMeta, + transactionController, + request, + }); + + if (predictResult.transactionHash) { + return { transactionHash: predictResult.transactionHash }; + } + const state = getState(); const { shouldUseSmartTransaction, featureFlags } = @@ -326,6 +344,41 @@ async function publishHook({ return { transactionHash: undefined }; } +async function publishPredict({ + transactionMeta, + transactionController, + request, +}: { + transactionMeta: TransactionMeta; + transactionController: TransactionController; + request: MessengerClientInitRequest< + TransactionControllerMessenger, + TransactionControllerInitMessenger + >; +}): Promise<{ transactionHash?: string; isIntentComplete?: boolean }> { + const predictController = request.getMessengerClient('PredictController'); + const result = await predictController.publish({ transactionMeta }); + + if (result.transactionHash && result.isIntentComplete) { + const latestMeta = getTransactionById( + transactionMeta.id, + transactionController, + ); + + if (latestMeta) { + transactionController.updateTransaction( + { + ...latestMeta, + isIntentComplete: true, + }, + 'Predict claim relayer intent complete', + ); + } + } + + return result; +} + function getSmartTransactionCommonParams(state: RootState, chainId: Hex) { const shouldUseSmartTransaction = selectShouldUseSmartTransaction( state, @@ -441,6 +494,17 @@ function getControllers( }; } +function beforePublish( + transactionMeta: TransactionMeta, + request: MessengerClientInitRequest< + TransactionControllerMessenger, + TransactionControllerInitMessenger + >, +) { + const predictController = request.getMessengerClient('PredictController'); + return predictController.beforePublish({ transactionMeta }); +} + function beforeSign( hookRequest: { transactionMeta: TransactionMeta }, request: MessengerClientInitRequest< From 3f9d6182134dd88a4adec93fb6a62a80f4e4422e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Tani=C3=A7a?= Date: Fri, 8 May 2026 09:34:03 -0600 Subject: [PATCH 2/2] fix(predict): address confirmation hook review --- .../Predict/controllers/PredictController.ts | 2 +- .../transaction-controller-init.test.ts | 89 ++++++------------- .../transaction-controller-init.ts | 66 +++----------- .../transaction-controller-messenger.ts | 10 ++- 4 files changed, 47 insertions(+), 120 deletions(-) diff --git a/app/components/UI/Predict/controllers/PredictController.ts b/app/components/UI/Predict/controllers/PredictController.ts index 96323c345cc..cab0aa6762a 100644 --- a/app/components/UI/Predict/controllers/PredictController.ts +++ b/app/components/UI/Predict/controllers/PredictController.ts @@ -2732,7 +2732,7 @@ export class PredictController extends BaseController< public async publish(_request: { transactionMeta: TransactionMeta; - }): Promise<{ transactionHash?: string; isIntentComplete?: boolean }> { + }): Promise<{ transactionHash?: string }> { return { transactionHash: undefined }; } diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts index 8ab21bf6728..d811d27d661 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts @@ -122,22 +122,43 @@ function buildInitRequestMock( TransactionControllerInitMessenger > > { + const { + predictControllerMock: providedPredictControllerMock, + ...requestOverrides + } = initRequestProperties; + const predictControllerMock = + (providedPredictControllerMock as ControllerMock | undefined) ?? + buildControllerMock(); const initMessenger = new ExtendedMessenger({ namespace: MOCK_ANY_NAMESPACE, }); const baseControllerMessenger = new ExtendedMessenger({ namespace: MOCK_ANY_NAMESPACE, }); + (initMessenger as unknown as { call: jest.Mock }).call = jest.fn( + (actionType: string, params: unknown) => { + if (actionType === 'PredictController:beforePublish') { + return predictControllerMock.beforePublish(params); + } + + if (actionType === 'PredictController:publish') { + return predictControllerMock.publish(params); + } + + throw new Error(`Unexpected init messenger action: ${actionType}`); + }, + ); + const requestMock = { ...buildMessengerClientInitRequestMock(baseControllerMessenger), initMessenger: initMessenger as unknown as TransactionControllerInitMessenger, controllerMessenger: baseControllerMessenger as unknown as TransactionControllerMessenger, - ...initRequestProperties, + ...requestOverrides, }; - if (!initRequestProperties.getMessengerClient) { + if (!requestOverrides.getMessengerClient) { requestMock.getMessengerClient.mockReturnValue(buildControllerMock()); } @@ -339,11 +360,7 @@ describe('Transaction Controller Init', () => { 'hooks', {}, { - getMessengerClient: jest.fn((controllerName: string) => - controllerName === 'PredictController' - ? predictControllerMock - : buildControllerMock(), - ), + predictControllerMock, }, ); @@ -388,11 +405,7 @@ describe('Transaction Controller Init', () => { 'hooks', {}, { - getMessengerClient: jest.fn((controllerName: string) => - controllerName === 'PredictController' - ? predictControllerMock - : buildControllerMock(), - ), + predictControllerMock, }, ); @@ -422,11 +435,7 @@ describe('Transaction Controller Init', () => { 'hooks', {}, { - getMessengerClient: jest.fn((controllerName: string) => - controllerName === 'PredictController' - ? predictControllerMock - : buildControllerMock(), - ), + predictControllerMock, }, ); @@ -437,52 +446,6 @@ describe('Transaction Controller Init', () => { expect(submitSmartTransactionHookMock).not.toHaveBeenCalled(); }); - it('marks the latest transaction intent complete when Predict publish completes an intent', async () => { - const predictControllerMock = buildControllerMock({ - publish: jest.fn().mockResolvedValue({ - transactionHash: '0xpredict', - isIntentComplete: true, - }), - } as unknown as Partial); - const getTransactionByIdMock = jest.requireMock( - '../../../../util/transactions', - ).getTransactionById; - getTransactionByIdMock.mockReturnValue({ ...MOCK_TRANSACTION_META }); - const hooks = testConstructorOption( - 'hooks', - {}, - { - getMessengerClient: jest.fn((controllerName: string) => - controllerName === 'PredictController' - ? predictControllerMock - : buildControllerMock(), - ), - }, - ); - - const result = await hooks?.publish?.(MOCK_TRANSACTION_META); - - const transactionControllerInstance = transactionControllerClassMock.mock - .instances[0] as unknown as { - updateTransaction: jest.Mock; - }; - - expect(result).toEqual({ transactionHash: '0xpredict' }); - expect(getTransactionByIdMock).toHaveBeenCalledWith( - MOCK_TRANSACTION_META.id, - transactionControllerInstance, - ); - expect( - transactionControllerInstance.updateTransaction, - ).toHaveBeenCalledWith( - expect.objectContaining({ - id: MOCK_TRANSACTION_META.id, - isIntentComplete: true, - }), - 'Predict claim relayer intent complete', - ); - }); - it('passes isSmartTransaction returning false to pay hook when stxDisabled is true', async () => { selectMetaMaskPayFlagsMock.mockReturnValue({ attemptsMax: 2, diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts index c71e6cb00c6..e4757003edd 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts @@ -125,7 +125,6 @@ export const TransactionControllerInit: MessengerClientInitFunction< transactionController, smartTransactionsController, initMessenger, - request, signedTransactionInHex, }), publishBatch: async (_request: PublishBatchHookRequest) => @@ -138,7 +137,7 @@ export const TransactionControllerInit: MessengerClientInitFunction< _request.transactions as PublishBatchHookTransaction[], }), beforePublish: (transactionMeta: TransactionMeta) => - beforePublish(transactionMeta, request), + beforePublish(transactionMeta, initMessenger), beforeSign: (_request: { transactionMeta: TransactionMeta }) => beforeSign(_request, request), }, @@ -219,7 +218,6 @@ async function publishHook({ transactionController, smartTransactionsController, initMessenger, - request, signedTransactionInHex, }: { transactionMeta: TransactionMeta; @@ -228,20 +226,15 @@ async function publishHook({ transactionController: TransactionController; smartTransactionsController: SmartTransactionsController; initMessenger: TransactionControllerInitMessenger; - request: MessengerClientInitRequest< - TransactionControllerMessenger, - TransactionControllerInitMessenger - >; signedTransactionInHex: Hex; }): Promise<{ transactionHash?: string }> { - const predictResult = await publishPredict({ - transactionMeta, - transactionController, - request, - }); + const { transactionHash: predictTransactionHash } = await initMessenger.call( + 'PredictController:publish', + { transactionMeta }, + ); - if (predictResult.transactionHash) { - return { transactionHash: predictResult.transactionHash }; + if (predictTransactionHash) { + return { transactionHash: predictTransactionHash }; } const state = getState(); @@ -344,41 +337,6 @@ async function publishHook({ return { transactionHash: undefined }; } -async function publishPredict({ - transactionMeta, - transactionController, - request, -}: { - transactionMeta: TransactionMeta; - transactionController: TransactionController; - request: MessengerClientInitRequest< - TransactionControllerMessenger, - TransactionControllerInitMessenger - >; -}): Promise<{ transactionHash?: string; isIntentComplete?: boolean }> { - const predictController = request.getMessengerClient('PredictController'); - const result = await predictController.publish({ transactionMeta }); - - if (result.transactionHash && result.isIntentComplete) { - const latestMeta = getTransactionById( - transactionMeta.id, - transactionController, - ); - - if (latestMeta) { - transactionController.updateTransaction( - { - ...latestMeta, - isIntentComplete: true, - }, - 'Predict claim relayer intent complete', - ); - } - } - - return result; -} - function getSmartTransactionCommonParams(state: RootState, chainId: Hex) { const shouldUseSmartTransaction = selectShouldUseSmartTransaction( state, @@ -496,13 +454,11 @@ function getControllers( function beforePublish( transactionMeta: TransactionMeta, - request: MessengerClientInitRequest< - TransactionControllerMessenger, - TransactionControllerInitMessenger - >, + initMessenger: TransactionControllerInitMessenger, ) { - const predictController = request.getMessengerClient('PredictController'); - return predictController.beforePublish({ transactionMeta }); + return initMessenger.call('PredictController:beforePublish', { + transactionMeta, + }); } function beforeSign( diff --git a/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts b/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts index ca6bb67d57f..324fbdd8ccc 100644 --- a/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts +++ b/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts @@ -55,6 +55,10 @@ import { MessengerActions, MessengerEvents, } from '@metamask/messenger'; +import type { + PredictControllerBeforePublishAction, + PredictControllerPublishAction, +} from '../../../../components/UI/Predict/controllers/PredictController-method-action-types'; export function getTransactionControllerMessenger( rootMessenger: RootMessenger, @@ -114,7 +118,9 @@ type InitMessengerActions = | TransactionPayControllerGetDelegationTransactionAction | TransactionPayControllerGetStateAction | TransactionPayControllerGetStrategyAction - | AnalyticsControllerActions; + | AnalyticsControllerActions + | PredictControllerBeforePublishAction + | PredictControllerPublishAction; type InitMessengerEvents = | BridgeStatusControllerEvents @@ -173,6 +179,8 @@ export function getTransactionControllerInitMessenger( 'TransactionPayController:getState', 'TransactionPayController:getStrategy', 'AnalyticsController:trackEvent', + 'PredictController:beforePublish', + 'PredictController:publish', ], events: [ 'BridgeStatusController:stateChange',