diff --git a/app/components/UI/Predict/controllers/PredictController-method-action-types.ts b/app/components/UI/Predict/controllers/PredictController-method-action-types.ts index 611cc69c66c..5dc2a92767b 100644 --- a/app/components/UI/Predict/controllers/PredictController-method-action-types.ts +++ b/app/components/UI/Predict/controllers/PredictController-method-action-types.ts @@ -248,11 +248,21 @@ export type PredictControllerPrepareWithdrawAction = { handler: PredictController['prepareWithdraw']; }; +export type PredictControllerBeforePublishAction = { + type: `PredictController:beforePublish`; + handler: PredictController['beforePublish']; +}; + export type PredictControllerBeforeSignAction = { type: `PredictController:beforeSign`; handler: PredictController['beforeSign']; }; +export type PredictControllerPublishAction = { + type: `PredictController:publish`; + handler: PredictController['publish']; +}; + export type PredictControllerClearWithdrawTransactionAction = { type: `PredictController:clearWithdrawTransaction`; handler: PredictController['clearWithdrawTransaction']; @@ -300,5 +310,7 @@ export type PredictControllerMethodActions = | PredictControllerGetAccountStateAction | PredictControllerGetBalanceAction | PredictControllerPrepareWithdrawAction + | PredictControllerBeforePublishAction | PredictControllerBeforeSignAction + | PredictControllerPublishAction | PredictControllerClearWithdrawTransactionAction; diff --git a/app/components/UI/Predict/controllers/PredictController.test.ts b/app/components/UI/Predict/controllers/PredictController.test.ts index 6b0122eb258..08562cacbcc 100644 --- a/app/components/UI/Predict/controllers/PredictController.test.ts +++ b/app/components/UI/Predict/controllers/PredictController.test.ts @@ -6139,6 +6139,40 @@ describe('PredictController', () => { }); }); + describe('beforePublish', () => { + it('passes through by default', async () => { + await withController(async ({ controller }) => { + const result = await controller.beforePublish({ + transactionMeta: { + id: 'tx-1', + txParams: { + from: MOCK_ADDRESS, + }, + } as TransactionMeta, + }); + + expect(result).toBe(true); + }); + }); + }); + + describe('publish', () => { + it('passes through by default', async () => { + await withController(async ({ controller }) => { + const result = await controller.publish({ + transactionMeta: { + id: 'tx-1', + txParams: { + from: MOCK_ADDRESS, + }, + } as TransactionMeta, + }); + + expect(result).toEqual({ transactionHash: undefined }); + }); + }); + }); + describe('beforeSign', () => { const mockTransactionMeta = { id: 'tx-1', diff --git a/app/components/UI/Predict/controllers/PredictController.ts b/app/components/UI/Predict/controllers/PredictController.ts index 4396fb809e1..cab0aa6762a 100644 --- a/app/components/UI/Predict/controllers/PredictController.ts +++ b/app/components/UI/Predict/controllers/PredictController.ts @@ -345,6 +345,7 @@ export interface PredictControllerOptions { } const MESSENGER_EXPOSED_METHODS = [ + 'beforePublish', 'beforeSign', 'claimWithConfirmation', 'clearActiveOrder', @@ -370,6 +371,7 @@ const MESSENGER_EXPOSED_METHODS = [ 'onPlaceOrderSuccess', 'placeOrder', 'prepareWithdraw', + 'publish', 'previewOrder', 'refreshEligibility', 'selectPaymentToken', @@ -2619,6 +2621,12 @@ export class PredictController extends BaseController< } } + public async beforePublish(_request: { + transactionMeta: TransactionMeta; + }): Promise { + return true; + } + public async beforeSign(request: { transactionMeta: TransactionMeta; }): Promise< @@ -2722,6 +2730,12 @@ export class PredictController extends BaseController< }; } + public async publish(_request: { + transactionMeta: TransactionMeta; + }): Promise<{ transactionHash?: string }> { + return { transactionHash: undefined }; + } + public clearWithdrawTransaction(): void { this.update((state) => { state.withdrawTransaction = null; @@ -2730,6 +2744,7 @@ export class PredictController extends BaseController< } export type { + PredictControllerBeforePublishAction, PredictControllerBeforeSignAction, PredictControllerClaimWithConfirmationAction, PredictControllerClearActiveOrderAction, @@ -2754,6 +2769,7 @@ export type { PredictControllerPlaceOrderAction, PredictControllerPrepareWithdrawAction, PredictControllerPreviewOrderAction, + PredictControllerPublishAction, PredictControllerRefreshEligibilityAction, PredictControllerSelectPaymentTokenAction, PredictControllerSetSelectedPaymentTokenAction, diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts index 7d77de82cfe..d811d27d661 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.test.ts @@ -92,10 +92,20 @@ const MOCK_TRANSACTION_META = { * with the default mock. * @returns A mock NetworkController. */ +type ControllerMock = NetworkController & { + beforePublish: jest.Mock; + beforeSign: jest.Mock; + publish: jest.Mock; +}; + function buildControllerMock( - partialMock?: Partial, -): NetworkController { - const defaultControllerMocks = {}; + partialMock?: Partial, +): ControllerMock { + const defaultControllerMocks = { + beforePublish: jest.fn().mockResolvedValue(true), + beforeSign: jest.fn(), + publish: jest.fn().mockResolvedValue({ transactionHash: undefined }), + }; // @ts-expect-error Incomplete mock, just includes properties used by code-under-test. return { @@ -112,22 +122,43 @@ function buildInitRequestMock( TransactionControllerInitMessenger > > { + const { + predictControllerMock: providedPredictControllerMock, + ...requestOverrides + } = initRequestProperties; + const predictControllerMock = + (providedPredictControllerMock as ControllerMock | undefined) ?? + buildControllerMock(); const initMessenger = new ExtendedMessenger({ namespace: MOCK_ANY_NAMESPACE, }); const baseControllerMessenger = new ExtendedMessenger({ namespace: MOCK_ANY_NAMESPACE, }); + (initMessenger as unknown as { call: jest.Mock }).call = jest.fn( + (actionType: string, params: unknown) => { + if (actionType === 'PredictController:beforePublish') { + return predictControllerMock.beforePublish(params); + } + + if (actionType === 'PredictController:publish') { + return predictControllerMock.publish(params); + } + + throw new Error(`Unexpected init messenger action: ${actionType}`); + }, + ); + const requestMock = { ...buildMessengerClientInitRequestMock(baseControllerMessenger), initMessenger: initMessenger as unknown as TransactionControllerInitMessenger, controllerMessenger: baseControllerMessenger as unknown as TransactionControllerMessenger, - ...initRequestProperties, + ...requestOverrides, }; - if (!initRequestProperties.getMessengerClient) { + if (!requestOverrides.getMessengerClient) { requestMock.getMessengerClient.mockReturnValue(buildControllerMock()); } @@ -180,9 +211,11 @@ describe('Transaction Controller Init', () => { ): TransactionControllerOptions[T] { const requestMock = buildInitRequestMock(initRequestProperties); - requestMock.getMessengerClient.mockReturnValue( - buildControllerMock(dependencyProperties), - ); + if (!initRequestProperties.getMessengerClient) { + requestMock.getMessengerClient.mockReturnValue( + buildControllerMock(dependencyProperties), + ); + } TransactionControllerInit(requestMock); @@ -320,6 +353,25 @@ describe('Transaction Controller Init', () => { expect(optionFn?.()).toBe(false); }); + describe('beforePublish hook', () => { + it('delegates to PredictController beforePublish', async () => { + const predictControllerMock = buildControllerMock(); + const hooks = testConstructorOption( + 'hooks', + {}, + { + predictControllerMock, + }, + ); + + await hooks?.beforePublish?.(MOCK_TRANSACTION_META); + + expect(predictControllerMock.beforePublish).toHaveBeenCalledWith({ + transactionMeta: MOCK_TRANSACTION_META, + }); + }); + }); + describe('publish hook', () => { it('calls submitSmartTransactionHook', async () => { const hooks = testConstructorOption('hooks'); @@ -347,6 +399,53 @@ describe('Transaction Controller Init', () => { expect(payHookMock).toHaveBeenCalledTimes(1); }); + it('calls Predict publish before pay and smart transaction hooks', async () => { + const predictControllerMock = buildControllerMock(); + const hooks = testConstructorOption( + 'hooks', + {}, + { + predictControllerMock, + }, + ); + + await hooks?.publish?.(MOCK_TRANSACTION_META); + + expect(predictControllerMock.publish).toHaveBeenCalledWith({ + transactionMeta: MOCK_TRANSACTION_META, + }); + expect(payHookMock).toHaveBeenCalledTimes(1); + expect( + (predictControllerMock.publish as jest.Mock).mock + .invocationCallOrder[0], + ).toBeLessThan(payHookMock.mock.invocationCallOrder[0]); + expect( + (predictControllerMock.publish as jest.Mock).mock + .invocationCallOrder[0], + ).toBeLessThan( + submitSmartTransactionHookMock.mock.invocationCallOrder[0], + ); + }); + + it('short-circuits publish when Predict returns a transaction hash', async () => { + const predictControllerMock = buildControllerMock({ + publish: jest.fn().mockResolvedValue({ transactionHash: '0xpredict' }), + } as unknown as Partial); + const hooks = testConstructorOption( + 'hooks', + {}, + { + predictControllerMock, + }, + ); + + const result = await hooks?.publish?.(MOCK_TRANSACTION_META); + + expect(result).toEqual({ transactionHash: '0xpredict' }); + expect(payHookMock).not.toHaveBeenCalled(); + expect(submitSmartTransactionHookMock).not.toHaveBeenCalled(); + }); + it('passes isSmartTransaction returning false to pay hook when stxDisabled is true', async () => { selectMetaMaskPayFlagsMock.mockReturnValue({ attemptsMax: 2, diff --git a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts index 64bc20e25cb..e4757003edd 100644 --- a/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts +++ b/app/core/Engine/controllers/transaction-controller/transaction-controller-init.ts @@ -136,6 +136,8 @@ export const TransactionControllerInit: MessengerClientInitFunction< transactions: _request.transactions as PublishBatchHookTransaction[], }), + beforePublish: (transactionMeta: TransactionMeta) => + beforePublish(transactionMeta, initMessenger), beforeSign: (_request: { transactionMeta: TransactionMeta }) => beforeSign(_request, request), }, @@ -226,6 +228,15 @@ async function publishHook({ initMessenger: TransactionControllerInitMessenger; signedTransactionInHex: Hex; }): Promise<{ transactionHash?: string }> { + const { transactionHash: predictTransactionHash } = await initMessenger.call( + 'PredictController:publish', + { transactionMeta }, + ); + + if (predictTransactionHash) { + return { transactionHash: predictTransactionHash }; + } + const state = getState(); const { shouldUseSmartTransaction, featureFlags } = @@ -441,6 +452,15 @@ function getControllers( }; } +function beforePublish( + transactionMeta: TransactionMeta, + initMessenger: TransactionControllerInitMessenger, +) { + return initMessenger.call('PredictController:beforePublish', { + transactionMeta, + }); +} + function beforeSign( hookRequest: { transactionMeta: TransactionMeta }, request: MessengerClientInitRequest< diff --git a/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts b/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts index ca6bb67d57f..324fbdd8ccc 100644 --- a/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts +++ b/app/core/Engine/messengers/transaction-controller-messenger/transaction-controller-messenger.ts @@ -55,6 +55,10 @@ import { MessengerActions, MessengerEvents, } from '@metamask/messenger'; +import type { + PredictControllerBeforePublishAction, + PredictControllerPublishAction, +} from '../../../../components/UI/Predict/controllers/PredictController-method-action-types'; export function getTransactionControllerMessenger( rootMessenger: RootMessenger, @@ -114,7 +118,9 @@ type InitMessengerActions = | TransactionPayControllerGetDelegationTransactionAction | TransactionPayControllerGetStateAction | TransactionPayControllerGetStrategyAction - | AnalyticsControllerActions; + | AnalyticsControllerActions + | PredictControllerBeforePublishAction + | PredictControllerPublishAction; type InitMessengerEvents = | BridgeStatusControllerEvents @@ -173,6 +179,8 @@ export function getTransactionControllerInitMessenger( 'TransactionPayController:getState', 'TransactionPayController:getStrategy', 'AnalyticsController:trackEvent', + 'PredictController:beforePublish', + 'PredictController:publish', ], events: [ 'BridgeStatusController:stateChange',