diff --git a/change/@azure-msal-node-e0a8f34f-53b7-465b-803a-921d02435808.json b/change/@azure-msal-node-e0a8f34f-53b7-465b-803a-921d02435808.json new file mode 100644 index 0000000000..c825030998 --- /dev/null +++ b/change/@azure-msal-node-e0a8f34f-53b7-465b-803a-921d02435808.json @@ -0,0 +1,7 @@ +{ + "type": "minor", + "comment": "Implemented Managed Identity Version 2 #7587", + "packageName": "@azure/msal-node", + "email": "rginsburg@microsoft.com", + "dependentChangeType": "patch" +} diff --git a/lib/msal-node/apiReview/msal-node.api.md b/lib/msal-node/apiReview/msal-node.api.md index 5fa0dc6173..59e6d11b84 100644 --- a/lib/msal-node/apiReview/msal-node.api.md +++ b/lib/msal-node/apiReview/msal-node.api.md @@ -392,7 +392,7 @@ export { LogLevel } export class ManagedIdentityApplication { constructor(configuration?: ManagedIdentityConfiguration); acquireToken(managedIdentityRequestParams: ManagedIdentityRequestParams): Promise; - getManagedIdentitySource(): ManagedIdentitySourceNames; + getManagedIdentitySource(): Promise; } // @public (undocumented) @@ -423,6 +423,7 @@ export const ManagedIdentitySourceNames: { readonly CLOUD_SHELL: "CloudShell"; readonly DEFAULT_TO_IMDS: "DefaultToImds"; readonly IMDS: "Imds"; + readonly IMDSV2: "ImdsV2"; readonly MACHINE_LEARNING: "MachineLearning"; readonly SERVICE_FABRIC: "ServiceFabric"; }; diff --git a/lib/msal-node/src/client/ManagedIdentityApplication.ts b/lib/msal-node/src/client/ManagedIdentityApplication.ts index c98ec75385..d8e0e930e4 100644 --- a/lib/msal-node/src/client/ManagedIdentityApplication.ts +++ b/lib/msal-node/src/client/ManagedIdentityApplication.ts @@ -178,7 +178,7 @@ export class ManagedIdentityApplication { */ if (managedIdentityRequest.claims) { const sourceName: ManagedIdentitySourceNames = - this.managedIdentityClient.getManagedIdentitySource(); + await this.managedIdentityClient.getManagedIdentitySource(); /* * Check if there is a cached token and if the Managed Identity source supports token revocation. @@ -257,10 +257,10 @@ export class ManagedIdentityApplication { * Determine the Managed Identity Source based on available environment variables. This API is consumed by Azure Identity SDK. * @returns ManagedIdentitySourceNames - The Managed Identity source's name */ - public getManagedIdentitySource(): ManagedIdentitySourceNames { + public async getManagedIdentitySource(): Promise { return ( ManagedIdentityClient.sourceName || - this.managedIdentityClient.getManagedIdentitySource() + (await this.managedIdentityClient.getManagedIdentitySource()) ); } } diff --git a/lib/msal-node/src/client/ManagedIdentityClient.ts b/lib/msal-node/src/client/ManagedIdentityClient.ts index 798606736a..9a60e3be84 100644 --- a/lib/msal-node/src/client/ManagedIdentityClient.ts +++ b/lib/msal-node/src/client/ManagedIdentityClient.ts @@ -25,6 +25,7 @@ import { NodeStorage } from "../cache/NodeStorage.js"; import { BaseManagedIdentitySource } from "./ManagedIdentitySources/BaseManagedIdentitySource.js"; import { ManagedIdentitySourceNames } from "../utils/Constants.js"; import { MachineLearning } from "./ManagedIdentitySources/MachineLearning.js"; +import { ImdsV2 } from "./ManagedIdentitySources/ImdsV2.js"; /* * Class to initialize a managed identity and identify the service. @@ -62,7 +63,7 @@ export class ManagedIdentityClient { ): Promise { if (!ManagedIdentityClient.identitySource) { ManagedIdentityClient.identitySource = - this.selectManagedIdentitySource( + await this.selectManagedIdentitySource( this.logger, this.nodeStorage, this.networkClient, @@ -91,10 +92,11 @@ export class ManagedIdentityClient { } /** - * Determine the Managed Identity Source based on available environment variables. This API is consumed by ManagedIdentityApplication's getManagedIdentitySource. + * Determine the Managed Identity Source based on available environment variables and probing an IMDS credential endpoint. + * This API is consumed by ManagedIdentityApplication's getManagedIdentitySource. * @returns ManagedIdentitySourceNames - The Managed Identity source's name */ - public getManagedIdentitySource(): ManagedIdentitySourceNames { + public async getManagedIdentitySource(): Promise { ManagedIdentityClient.sourceName = this.allEnvironmentVariablesAreDefined( ServiceFabric.getEnvironmentVariables() @@ -116,6 +118,11 @@ export class ManagedIdentityClient { AzureArc.getEnvironmentVariables() ) ? ManagedIdentitySourceNames.AZURE_ARC + : (await ImdsV2.isCredentialEndpointAvailable( + this.logger, + this.networkClient + )) + ? ManagedIdentitySourceNames.IMDSV2 : ManagedIdentitySourceNames.DEFAULT_TO_IMDS; return ManagedIdentityClient.sourceName; @@ -125,14 +132,14 @@ export class ManagedIdentityClient { * Tries to create a managed identity source for all sources * @returns the managed identity Source */ - private selectManagedIdentitySource( + private async selectManagedIdentitySource( logger: Logger, nodeStorage: NodeStorage, networkClient: INetworkModule, cryptoProvider: CryptoProvider, disableInternalRetries: boolean, managedIdentityId: ManagedIdentityId - ): BaseManagedIdentitySource { + ): Promise { const source = ServiceFabric.tryCreate( logger, @@ -172,6 +179,13 @@ export class ManagedIdentityClient { disableInternalRetries, managedIdentityId ) || + (await ImdsV2.tryCreate( + logger, + nodeStorage, + networkClient, + cryptoProvider, + disableInternalRetries + )) || Imds.tryCreate( logger, nodeStorage, diff --git a/lib/msal-node/src/client/ManagedIdentitySources/Imds.ts b/lib/msal-node/src/client/ManagedIdentitySources/Imds.ts index 69aed1f746..a70fbe82ca 100644 --- a/lib/msal-node/src/client/ManagedIdentitySources/Imds.ts +++ b/lib/msal-node/src/client/ManagedIdentitySources/Imds.ts @@ -21,9 +21,9 @@ import { ImdsRetryPolicy } from "../../retry/ImdsRetryPolicy.js"; // Documentation for IMDS is available at https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http +const DEFAULT_IMDS_BASE_ENDPOINT: string = `http://169.254.169.254`; const IMDS_TOKEN_PATH: string = "/metadata/identity/oauth2/token"; -const DEFAULT_IMDS_ENDPOINT: string = `http://169.254.169.254${IMDS_TOKEN_PATH}`; -const IMDS_API_VERSION: string = "2018-02-01"; +export const IMDS_API_VERSION: string = "2018-02-01"; // referenced in ImdsV2 /** * Original source of code: https://github.com/Azure/azure-sdk-for-net/blob/main/sdk/identity/Azure.Identity/src/ImdsManagedIdentitySource.cs @@ -78,41 +78,10 @@ export class Imds extends BaseManagedIdentitySource { cryptoProvider: CryptoProvider, disableInternalRetries: boolean ): Imds { - let validatedIdentityEndpoint: string; - - if ( - process.env[ - ManagedIdentityEnvironmentVariableNames - .AZURE_POD_IDENTITY_AUTHORITY_HOST - ] - ) { - logger.info( - `[Managed Identity] Environment variable ${ - ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST - } for ${ManagedIdentitySourceNames.IMDS} returned endpoint: ${ - process.env[ - ManagedIdentityEnvironmentVariableNames - .AZURE_POD_IDENTITY_AUTHORITY_HOST - ] - }` - ); - validatedIdentityEndpoint = Imds.getValidatedEnvVariableUrlString( - ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST, - `${ - process.env[ - ManagedIdentityEnvironmentVariableNames - .AZURE_POD_IDENTITY_AUTHORITY_HOST - ] - }${IMDS_TOKEN_PATH}`, - ManagedIdentitySourceNames.IMDS, - logger - ); - } else { - logger.info( - `[Managed Identity] Unable to find ${ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST} environment variable for ${ManagedIdentitySourceNames.IMDS}, using the default endpoint.` - ); - validatedIdentityEndpoint = DEFAULT_IMDS_ENDPOINT; - } + const validatedIdentityEndpoint: string = this.getValidatedEndpoint( + IMDS_TOKEN_PATH, + logger + ); return new Imds( logger, @@ -166,4 +135,44 @@ export class Imds extends BaseManagedIdentitySource { return request; } + + public static getValidatedEndpoint = ( + subPath: string, + logger: Logger + ): string => { + if ( + process.env[ + ManagedIdentityEnvironmentVariableNames + .AZURE_POD_IDENTITY_AUTHORITY_HOST + ] + ) { + logger.info( + `[Managed Identity] Environment variable ${ + ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST + } for ${ManagedIdentitySourceNames.IMDS} returned endpoint: ${ + process.env[ + ManagedIdentityEnvironmentVariableNames + .AZURE_POD_IDENTITY_AUTHORITY_HOST + ] + }` + ); + + return Imds.getValidatedEnvVariableUrlString( + ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST, + `${ + process.env[ + ManagedIdentityEnvironmentVariableNames + .AZURE_POD_IDENTITY_AUTHORITY_HOST + ] + }${subPath}`, + ManagedIdentitySourceNames.IMDS, + logger + ); + } else { + logger.info( + `[Managed Identity] Unable to find ${ManagedIdentityEnvironmentVariableNames.AZURE_POD_IDENTITY_AUTHORITY_HOST} environment variable for ${ManagedIdentitySourceNames.IMDS}, using the default endpoint.` + ); + return `${DEFAULT_IMDS_BASE_ENDPOINT}${subPath}`; + } + }; } diff --git a/lib/msal-node/src/client/ManagedIdentitySources/ImdsV2.ts b/lib/msal-node/src/client/ManagedIdentitySources/ImdsV2.ts new file mode 100644 index 0000000000..5ff9654f1a --- /dev/null +++ b/lib/msal-node/src/client/ManagedIdentitySources/ImdsV2.ts @@ -0,0 +1,244 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { + INetworkModule, + Logger, + NetworkResponse, +} from "@azure/msal-common/node"; +// import { Agent } from "https"; +import { ManagedIdentityId } from "../../config/ManagedIdentityId.js"; +import { ManagedIdentityRequestParameters } from "../../config/ManagedIdentityRequestParameters.js"; +import { BaseManagedIdentitySource } from "./BaseManagedIdentitySource.js"; +import { CryptoProvider } from "../../crypto/CryptoProvider.js"; +import { + HttpMethod, + ManagedIdentityHeaders, + ManagedIdentityIdType, + ManagedIdentityQueryParameters, +} from "../../utils/Constants.js"; +import { NodeStorage } from "../../cache/NodeStorage.js"; +import { Imds, IMDS_API_VERSION } from "./Imds.js"; +import { ShortLivedCredential } from "../../response/ShortLivedCredentialResponse.js"; +import { HttpClientWithRetries } from "../../network/HttpClientWithRetries.js"; +import { DefaultManagedIdentityRetryPolicy } from "../../retry/DefaultManagedIdentityRetryPolicy.js"; + +export const CREDENTIAL_PATH: string = + "/metadata/identity/credential?cred-api-version=1.0"; + +export interface CredentialEndpointProbeResponse { + error: string; + error_description: string; +} + +export class ImdsV2 extends BaseManagedIdentitySource { + private credentialEndpoint: string; + + constructor( + logger: Logger, + nodeStorage: NodeStorage, + networkClient: INetworkModule, + cryptoProvider: CryptoProvider, + disableInternalRetries: boolean, + credentialEndpoint: string + ) { + super( + logger, + nodeStorage, + networkClient, + cryptoProvider, + disableInternalRetries + ); + + this.credentialEndpoint = credentialEndpoint; + } + + public static async tryCreate( + logger: Logger, + nodeStorage: NodeStorage, + networkClient: INetworkModule, + cryptoProvider: CryptoProvider, + disableInternalRetries: boolean + ): Promise { + const validatedCredentialEndpoint: string = Imds.getValidatedEndpoint( + CREDENTIAL_PATH, + logger + ); + + if ( + !(await this.isCredentialEndpointAvailable( + logger, + networkClient, + validatedCredentialEndpoint + )) + ) { + return null; + } + + return new ImdsV2( + logger, + nodeStorage, + networkClient, + cryptoProvider, + disableInternalRetries, + validatedCredentialEndpoint + ); + } + + public static async isCredentialEndpointAvailable( + logger: Logger, + networkClient: INetworkModule, + credentialEndpoint?: string // only passed in from tryCreate in this class + ): Promise { + const validatedCredentialEndpoint: string = + credentialEndpoint || + Imds.getValidatedEndpoint(CREDENTIAL_PATH, logger); + + const networkClientWithRetry: INetworkModule = + new HttpClientWithRetries( + networkClient, + /* + * TODO: create probe credential endpoint retry policy that extends DefaultManagedIdentityRetryPolicy, + * that only retries on 400 and 500 + */ + new DefaultManagedIdentityRetryPolicy(), + logger + ); + + const response: NetworkResponse = + await networkClientWithRetry.sendPostRequestAsync( + validatedCredentialEndpoint, + { body: "." } + ); + + if (response.status !== 400) { + return false; + } + + /* + * Match "IMDS/" at start of "server" header string (`^IMDS\/`) + * Match the first three numbers with dots (`\d+.\d+.\d+.`) + * Capture the last number in a group (`(\d+)`) + * Ensure end of string (`$`) + * + * Example: + * [ + * "IMDS/150.870.65.1556", // index 0: full match + * "1556" // index 1: captured group (\d+) + * ] + */ + const versionMatch = response.headers["server"]?.match( + /^IMDS\/\d+\.\d+\.\d+\.(\d+)$/ + ); + return Boolean(versionMatch && parseInt(versionMatch[1], 10) > 1324); // .match can return null, so Boolean() is needed + } + + public createRequest( + resource: string, + managedIdentityId: ManagedIdentityId + ): ManagedIdentityRequestParameters { + const imdsRequest: ManagedIdentityRequestParameters = + new ManagedIdentityRequestParameters( + HttpMethod.POST, + this.credentialEndpoint + ); + + imdsRequest.headers[ManagedIdentityHeaders.METADATA_HEADER_NAME] = + "true"; + imdsRequest.headers[ + ManagedIdentityHeaders.CLIENT_REQUEST_ID_HEADER_NAME + ] = "1234567890"; // TODO: generate random request ID + + imdsRequest.queryParameters[ + ManagedIdentityQueryParameters.API_VERSION + ] = IMDS_API_VERSION; + imdsRequest.queryParameters[ManagedIdentityQueryParameters.RESOURCE] = + resource; + + if ( + managedIdentityId.idType !== ManagedIdentityIdType.SYSTEM_ASSIGNED + ) { + imdsRequest.queryParameters[ + this.getManagedIdentityUserAssignedIdQueryParameterKey( + managedIdentityId.idType, + true // indicates source is IMDS + ) + ] = managedIdentityId.id; + } + + /* + * TODO: add self-signed mTLS certificate functionality + * If Windows, check certificate store for mTLS certificate (no Linux support) + * Otherwise, check in-memory cache for mTLS certificate + * If not either of the above, create self-signed mTLS certificate + */ + /* + * const mTLSCertificatePem: string = "fake_cert"; + * const privateKeyPem: string = "fake_private_key"; + */ + const sha256HashOfPublicKey: string = "fake_sha256_hash_of_public_key"; + const x5C: string = "fake_x5c"; + imdsRequest.bodyParameters = { + cnf: JSON.stringify({ + jwk: { + kty: "RSA", + use: "sig", + alg: "RS256", + kid: sha256HashOfPublicKey, + x5c: [x5C], + }, + }), + latch_key: "false", + }; + + /* + * TODO: Request SLC via "/credential" endpoint instead of using this fake object. + * This will be complicated the current acquireTokenWithManagedIdentity function in + * BaseManagedIdentitySource is not built to handle this request. + */ + const shortLivedCredential: ShortLivedCredential = { + client_id: "fake_string", + credential: "fake_string", + expires_in: 3599, + identity_type: "fake_string", + refresh_in: 3599, + region: "fake_string", + regional_token_url: "fake_string", + tenant_id: "fake_string", + }; + + const estsRequest: ManagedIdentityRequestParameters = + new ManagedIdentityRequestParameters( + HttpMethod.POST, + `${shortLivedCredential.regional_token_url}/${shortLivedCredential.tenant_id}/oauth2/v2.0/token` + ); + + // TODO: define constants for these values + estsRequest.bodyParameters = { + grant_type: "client_credentials", + scope: "https://management.azure.com/.default", + client_id: shortLivedCredential.client_id, + client_assertion: shortLivedCredential.credential, + client_assertion_type: + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + }; + + /* + * TODO: + * 1. Re-work the HttpClient to handle the self-signed mTLS certificate + * 2. Add functionality to ManagedIdentityRequestParameters to handle the self-signed mTLS certificate + */ + /* + * const agent = new Agent({ + * cert: mTLSCertificatePem, + * key: privateKeyPem, + * ca: mTLSCertificatePem, + * }); + * estsRequest.agent = agent; + */ + + return estsRequest; + } +} diff --git a/lib/msal-node/src/response/ShortLivedCredentialResponse.ts b/lib/msal-node/src/response/ShortLivedCredentialResponse.ts new file mode 100644 index 0000000000..d077f357c6 --- /dev/null +++ b/lib/msal-node/src/response/ShortLivedCredentialResponse.ts @@ -0,0 +1,16 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +// TODO: Add documentation +export type ShortLivedCredential = { + client_id: string; + credential: string; + expires_in: number; + identity_type: string; + refresh_in: number; + region: string; + regional_token_url: string; + tenant_id: string; +}; diff --git a/lib/msal-node/src/utils/Constants.ts b/lib/msal-node/src/utils/Constants.ts index c539567fe6..fe7dc0e16a 100644 --- a/lib/msal-node/src/utils/Constants.ts +++ b/lib/msal-node/src/utils/Constants.ts @@ -20,6 +20,7 @@ export const ManagedIdentityHeaders = { METADATA_HEADER_NAME: "Metadata", APP_SERVICE_SECRET_HEADER_NAME: "X-IDENTITY-HEADER", ML_AND_SF_SECRET_HEADER_NAME: "secret", + CLIENT_REQUEST_ID_HEADER_NAME: "X-ms-Client-Request-id", } as const; export type ManagedIdentityHeaders = (typeof ManagedIdentityHeaders)[keyof typeof ManagedIdentityHeaders]; @@ -62,6 +63,7 @@ export const ManagedIdentitySourceNames = { CLOUD_SHELL: "CloudShell", DEFAULT_TO_IMDS: "DefaultToImds", IMDS: "Imds", + IMDSV2: "ImdsV2", MACHINE_LEARNING: "MachineLearning", SERVICE_FABRIC: "ServiceFabric", } as const; diff --git a/lib/msal-node/test/client/ManagedIdentitySources/AppService.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/AppService.spec.ts index ed9679f00a..82aa66c81f 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/AppService.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/AppService.spec.ts @@ -65,9 +65,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.APP_SERVICE - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.APP_SERVICE); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -100,9 +100,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedResourceIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.APP_SERVICE - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.APP_SERVICE); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -133,13 +133,13 @@ describe("Acquires a token successfully via an App Service Managed Identity", () describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.APP_SERVICE - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.APP_SERVICE); }); test("acquires a token", async () => { @@ -193,9 +193,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.APP_SERVICE - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.APP_SERVICE); let serverError: ServerError = new ServerError(); try { diff --git a/lib/msal-node/test/client/ManagedIdentitySources/AzureArc.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/AzureArc.spec.ts index 2b72f84435..c8f6ee345b 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/AzureArc.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/AzureArc.spec.ts @@ -99,13 +99,13 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () = // Azure Arc Managed Identities can only be system assigned describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.AZURE_ARC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.AZURE_ARC); }); test("acquires a token", async () => { @@ -135,7 +135,7 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () = // and accessSyncSpy still returns an error // (meaning either the himds file doesn't exists or its permissions don't allow it to be read) expect( - managedIdentityApplication.getManagedIdentitySource() + await managedIdentityApplication.getManagedIdentitySource() ).not.toBe(ManagedIdentitySourceNames.AZURE_ARC); // delete value cached from getManagedIdentitySource() directly above delete ManagedIdentityClient["sourceName"]; @@ -146,9 +146,9 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () = return undefined; }); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.AZURE_ARC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.AZURE_ARC); // returns undefined when the himds file exists and its permissions allow it to be read // otherwise, throws an error @@ -253,20 +253,20 @@ describe("Acquires a token successfully via an Azure Arc Managed Identity", () = describe("Errors", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.AZURE_ARC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.AZURE_ARC); }); test("throws an error if a user assigned managed identity is used", async () => { const userAssignedManagedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); expect( - userAssignedManagedIdentityApplication.getManagedIdentitySource() + await userAssignedManagedIdentityApplication.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.AZURE_ARC); await expect( diff --git a/lib/msal-node/test/client/ManagedIdentitySources/CloudShell.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/CloudShell.spec.ts index 534ae03a01..825a188e49 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/CloudShell.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/CloudShell.spec.ts @@ -54,13 +54,13 @@ describe("Acquires a token successfully via an App Service Managed Identity", () describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.CLOUD_SHELL - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.CLOUD_SHELL); }); test("acquires a token", async () => { @@ -101,9 +101,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () test("throws an error when a user assigned managed identity is used", async () => { const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.CLOUD_SHELL - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.CLOUD_SHELL); await expect( managedIdentityApplication.acquireToken( @@ -132,9 +132,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.CLOUD_SHELL - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.CLOUD_SHELL); let serverError: ServerError = new ServerError(); try { diff --git a/lib/msal-node/test/client/ManagedIdentitySources/DefaultManagedIdentityRetryPolicy.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/DefaultManagedIdentityRetryPolicy.spec.ts index 562b0085eb..e7c348d07d 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/DefaultManagedIdentityRetryPolicy.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/DefaultManagedIdentityRetryPolicy.spec.ts @@ -83,13 +83,13 @@ describe("Linear Retry Policy (App Service, Azure Arc, Cloud Shell, Machine Lear describe("User Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( userAssignedClientIdConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); }); test("returns a 500 error response from the network request, just the first time", async () => { @@ -142,13 +142,13 @@ describe("Linear Retry Policy (App Service, Azure Arc, Cloud Shell, Machine Lear describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); }); test("returns a 500 error response from the network request, just the first time, with no retry-after header", async () => { @@ -364,7 +364,7 @@ describe("Linear Retry Policy (App Service, Azure Arc, Cloud Shell, Machine Lear }, }); expect( - managedIdentityApplicationNoRetry.getManagedIdentitySource() + await managedIdentityApplicationNoRetry.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); const sendGetRequestAsyncSpy: jest.SpyInstance = jest diff --git a/lib/msal-node/test/client/ManagedIdentitySources/Imds.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/Imds.spec.ts index 48a1ff8e4f..061220b26c 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/Imds.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/Imds.spec.ts @@ -32,6 +32,7 @@ import { managedIdentityRequestParams, systemAssignedConfig, userAssignedResourceIdConfig, + mockCredentialEndpointProbeRequest, userAssignedObjectIdConfig, } from "../../test_kit/ManagedIdentityTestUtils.js"; import { @@ -63,10 +64,17 @@ import { NodeStorage } from "../../../src/cache/NodeStorage.js"; import { CacheKVStore } from "../../../src/cache/serializer/SerializerTypes.js"; import { ManagedIdentityUserAssignedIdQueryParameterNames } from "../../../src/client/ManagedIdentitySources/BaseManagedIdentitySource.js"; import { ImdsRetryPolicy } from "../../../src/retry/ImdsRetryPolicy.js"; +// import { ImdsV2 } from "../../../src/client/ManagedIdentitySources/ImdsV2.js"; describe("Acquires a token successfully via an IMDS Managed Identity", () => { // IMDS doesn't need environment variables because there is a default IMDS endpoint + beforeEach(() => { + mockCredentialEndpointProbeRequest( + HttpStatus.SUCCESS // TODO: change this to NOT_FOUND after implementing the credential endpoint probe retry policy + ); + }); + afterEach(() => { delete ManagedIdentityClient["identitySource"]; delete ManagedIdentityApplication["nodeStorage"]; @@ -91,9 +99,9 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.DEFAULT_TO_IMDS - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -121,9 +129,9 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { test("acquires a User Assigned Object Id token", async () => { const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedObjectIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.DEFAULT_TO_IMDS - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -143,9 +151,9 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedResourceIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.DEFAULT_TO_IMDS - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -176,13 +184,13 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.DEFAULT_TO_IMDS - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); }); test("acquires a token", async () => { @@ -223,7 +231,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { let uamiApplication: ManagedIdentityApplication; // user-assigned let samiApplication: ManagedIdentityApplication; // system-assigned - beforeEach(() => { + beforeEach(async () => { jest.spyOn( ImdsRetryPolicy, "MIN_EXPONENTIAL_BACKOFF_MS", @@ -263,14 +271,14 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { uamiApplication = new ManagedIdentityApplication( userAssignedClientIdConfig ); - expect(uamiApplication.getManagedIdentitySource()).toBe( + expect(await uamiApplication.getManagedIdentitySource()).toBe( ManagedIdentitySourceNames.DEFAULT_TO_IMDS ); samiApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(samiApplication.getManagedIdentitySource()).toBe( + expect(await samiApplication.getManagedIdentitySource()).toBe( ManagedIdentitySourceNames.DEFAULT_TO_IMDS ); }); @@ -618,11 +626,11 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { describe("Miscellaneous", () => { let systemAssignedManagedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { systemAssignedManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); expect( - systemAssignedManagedIdentityApplication.getManagedIdentitySource() + await systemAssignedManagedIdentityApplication.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); }); @@ -856,7 +864,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { }, }); expect( - userAssignedClientIdManagedIdentityApplicationResource1.getManagedIdentitySource() + await userAssignedClientIdManagedIdentityApplicationResource1.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); const userAssignedObjectIdManagedIdentityApplicationResource2: ManagedIdentityApplication = @@ -871,7 +879,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { }, }); expect( - userAssignedObjectIdManagedIdentityApplicationResource2.getManagedIdentitySource() + await userAssignedObjectIdManagedIdentityApplicationResource2.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); // ********** begin: return access tokens from a network request ********** @@ -910,7 +918,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { const systemAssignedManagedIdentityApplicationClone: ManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); expect( - systemAssignedManagedIdentityApplicationClone.getManagedIdentitySource() + await systemAssignedManagedIdentityApplicationClone.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); let cachedManagedIdentityResult: AuthenticationResult = await systemAssignedManagedIdentityApplicationClone.acquireToken( @@ -931,7 +939,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { }, }); expect( - userAssignedClientIdManagedIdentityApplicationResource1Clone.getManagedIdentitySource() + await userAssignedClientIdManagedIdentityApplicationResource1Clone.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); cachedManagedIdentityResult = await userAssignedClientIdManagedIdentityApplicationResource1Clone.acquireToken( @@ -954,7 +962,7 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { }, }); expect( - userAssignedObjectIdManagedIdentityApplicationResource2Clone.getManagedIdentitySource() + await userAssignedObjectIdManagedIdentityApplicationResource2Clone.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); cachedManagedIdentityResult = await userAssignedObjectIdManagedIdentityApplicationResource2Clone.acquireToken( @@ -991,11 +999,11 @@ describe("Acquires a token successfully via an IMDS Managed Identity", () => { describe("Errors", () => { let systemAssignedManagedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { systemAssignedManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); expect( - systemAssignedManagedIdentityApplication.getManagedIdentitySource() + await systemAssignedManagedIdentityApplication.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); }); diff --git a/lib/msal-node/test/client/ManagedIdentitySources/ImdsV2.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/ImdsV2.spec.ts new file mode 100644 index 0000000000..0c5a99b8db --- /dev/null +++ b/lib/msal-node/test/client/ManagedIdentitySources/ImdsV2.spec.ts @@ -0,0 +1,164 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +import { HttpStatus } from "@azure/msal-common"; +import { + credentialEndpointProbeResponse, + mockCredentialEndpointProbeRequest, + networkClient, +} from "../../test_kit/ManagedIdentityTestUtils.js"; +import { DefaultManagedIdentityRetryPolicy } from "../../../src/retry/DefaultManagedIdentityRetryPolicy.js"; +import { ONE_HUNDRED_TIMES_FASTER } from "../../test_kit/StringConstants.js"; +import { ManagedIdentityApplication } from "../../../src/client/ManagedIdentityApplication.js"; +import { ManagedIdentityClient } from "../../../src/client/ManagedIdentityClient.js"; +import { ManagedIdentitySourceNames } from "../../../src/utils/Constants.js"; + +describe("ImdsV2", () => { + beforeEach(() => { + jest.spyOn( + DefaultManagedIdentityRetryPolicy, + "DEFAULT_MANAGED_IDENTITY_RETRY_DELAY_MS", + "get" + ).mockReturnValue( + DefaultManagedIdentityRetryPolicy.DEFAULT_MANAGED_IDENTITY_RETRY_DELAY_MS * + ONE_HUNDRED_TIMES_FASTER + ); + }); + + afterEach(() => { + delete ManagedIdentityClient["identitySource"]; + delete ManagedIdentityClient["sourceName"]; + delete ManagedIdentityApplication["nodeStorage"]; + jest.restoreAllMocks(); + }); + + describe("isCredentialEndpointAvailable", () => { + describe("returns true", () => { + test("when probe response is 400 with valid version", async () => { + mockCredentialEndpointProbeRequest( + HttpStatus.BAD_REQUEST, + "IMDS/1.1.1.2222" + ); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.IMDSV2); + }); + + test("after retrying on retriable status code", async () => { + const sendPostRequestSpy: jest.SpyInstance = + mockCredentialEndpointProbeRequest( + HttpStatus.BAD_REQUEST, + "IMDS/1.1.1.2222" + ) + .mockReturnValueOnce( + Promise.resolve({ + headers: {}, + body: credentialEndpointProbeResponse, + status: HttpStatus.SERVER_ERROR, + }) + ) + // second retry, will trigger third retry + .mockReturnValueOnce( + Promise.resolve({ + headers: {}, + body: credentialEndpointProbeResponse, + status: HttpStatus.SERVER_ERROR, + }) + ); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.IMDSV2); + expect(sendPostRequestSpy).toHaveBeenCalledTimes(3); // initial request + 2 retries + }); + }); + + describe("returns false", () => { + test("when probe response is 400 with invalid version", async () => { + mockCredentialEndpointProbeRequest( + HttpStatus.BAD_REQUEST, + "IMDS/1.1.1.1111" + ); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); + }); + + test("when probe response is 400 but the server header is missing", async () => { + mockCredentialEndpointProbeRequest(HttpStatus.BAD_REQUEST); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); + }); + + test("when probe response is not 400 or 500", async () => { + const sendPostRequestSpy: jest.SpyInstance = + mockCredentialEndpointProbeRequest( + HttpStatus.SUCCESS // TODO: change this to NOT_FOUND after implementing the credential endpoint probe retry policy + ); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); + expect(sendPostRequestSpy).toHaveBeenCalledTimes(1); // initial request + 0 retries + }); + + test("after maximum retry attempts", async () => { + const sendPostRequestSpy: jest.SpyInstance = + mockCredentialEndpointProbeRequest(HttpStatus.SERVER_ERROR); + + const managedIdentityApplication = + new ManagedIdentityApplication({ + system: { + networkClient, + }, + }); + + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.DEFAULT_TO_IMDS); + expect(sendPostRequestSpy).toHaveBeenCalledTimes(4); // initial request + 3 retries + }); + }); + }); +}); diff --git a/lib/msal-node/test/client/ManagedIdentitySources/MachineLearning.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/MachineLearning.spec.ts index 7317ffbbf3..ac4708c8e3 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/MachineLearning.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/MachineLearning.spec.ts @@ -71,9 +71,9 @@ describe("Acquires a token successfully via an Machine Learning Managed Identity const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.MACHINE_LEARNING - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.MACHINE_LEARNING); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -98,13 +98,10 @@ describe("Acquires a token successfully via an Machine Learning Managed Identity ) ).toBe(true); expect( - url.get( - ManagedIdentityUserAssignedIdQueryParameterNames.MANAGED_IDENTITY_CLIENT_ID_2017 + url.has( + ManagedIdentityUserAssignedIdQueryParameterNames.MANAGED_IDENTITY_CLIENT_ID ) - ).toEqual( - userAssignedClientIdConfig.managedIdentityIdParams - ?.userAssignedClientId - ); + ).toBe(false); }); test("ensures that App Service is selected as the Managed Identity source when all App Service and Machine Learning environment variables are present", async () => { @@ -117,9 +114,9 @@ describe("Acquires a token successfully via an Machine Learning Managed Identity const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.APP_SERVICE - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.APP_SERVICE); delete process.env[ ManagedIdentityEnvironmentVariableNames.IDENTITY_ENDPOINT @@ -132,13 +129,13 @@ describe("Acquires a token successfully via an Machine Learning Managed Identity describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.MACHINE_LEARNING - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.MACHINE_LEARNING); }); test("acquires a token", async () => { @@ -221,9 +218,9 @@ describe("Acquires a token successfully via an Machine Learning Managed Identity const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.MACHINE_LEARNING - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.MACHINE_LEARNING); let serverError: ServerError = new ServerError(); try { diff --git a/lib/msal-node/test/client/ManagedIdentitySources/ServiceFabric.spec.ts b/lib/msal-node/test/client/ManagedIdentitySources/ServiceFabric.spec.ts index dff4109cdd..3d82bd0d9f 100644 --- a/lib/msal-node/test/client/ManagedIdentitySources/ServiceFabric.spec.ts +++ b/lib/msal-node/test/client/ManagedIdentitySources/ServiceFabric.spec.ts @@ -75,9 +75,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedClientIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -109,9 +109,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(userAssignedResourceIdConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); const networkManagedIdentityResult: AuthenticationResult = await managedIdentityApplication.acquireToken( @@ -142,13 +142,13 @@ describe("Acquires a token successfully via an App Service Managed Identity", () describe("System Assigned", () => { let managedIdentityApplication: ManagedIdentityApplication; - beforeEach(() => { + beforeEach(async () => { managedIdentityApplication = new ManagedIdentityApplication( systemAssignedConfig ); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); }); test("acquires a token", async () => { @@ -206,7 +206,7 @@ describe("Acquires a token successfully via an App Service Managed Identity", () clientCapabilities: providedCapabilities, }); expect( - managedIdentityApplication.getManagedIdentitySource() + await managedIdentityApplication.getManagedIdentitySource() ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); let networkManagedIdentityResult: AuthenticationResult = @@ -282,9 +282,9 @@ describe("Acquires a token successfully via an App Service Managed Identity", () const managedIdentityApplication: ManagedIdentityApplication = new ManagedIdentityApplication(systemAssignedConfig); - expect(managedIdentityApplication.getManagedIdentitySource()).toBe( - ManagedIdentitySourceNames.SERVICE_FABRIC - ); + expect( + await managedIdentityApplication.getManagedIdentitySource() + ).toBe(ManagedIdentitySourceNames.SERVICE_FABRIC); let serverError: ServerError = new ServerError(); try { diff --git a/lib/msal-node/test/test_kit/ManagedIdentityTestUtils.ts b/lib/msal-node/test/test_kit/ManagedIdentityTestUtils.ts index 0640e7930d..b1056bf14e 100644 --- a/lib/msal-node/test/test_kit/ManagedIdentityTestUtils.ts +++ b/lib/msal-node/test/test_kit/ManagedIdentityTestUtils.ts @@ -7,6 +7,8 @@ import { AuthenticationScheme, HttpStatus, INetworkModule, + Logger, + NetworkRequestOptions, NetworkResponse, TimeUtils, } from "@azure/msal-common"; @@ -20,6 +22,11 @@ import { import { ManagedIdentityTokenResponse } from "../../src/response/ManagedIdentityTokenResponse.js"; import { ManagedIdentityRequestParams } from "../../src/request/ManagedIdentityRequestParams.js"; import { ManagedIdentityConfiguration } from "../../src/config/Configuration.js"; +import { + CREDENTIAL_PATH, + CredentialEndpointProbeResponse, +} from "../../src/client/ManagedIdentitySources/ImdsV2.js"; +import { Imds } from "../../src/client/ManagedIdentitySources/Imds.js"; const EMPTY_HEADERS: Record = {}; @@ -150,3 +157,44 @@ export const systemAssignedConfig: ManagedIdentityConfiguration = { // managedIdentityIdParams will be omitted for system assigned }, }; + +export const credentialEndpointProbeResponse: CredentialEndpointProbeResponse = + { + error: "credential_endpoint_probe_error", + error_description: "credential_endpoint_probe_error_description", + }; + +export const mockCredentialEndpointProbeRequest = ( + status: number, + serverHeader?: string +) => { + const validatedCredentialEndpoint: string = Imds.getValidatedEndpoint( + CREDENTIAL_PATH, + new Logger({}) + ); + + const response: NetworkResponse = { + headers: serverHeader ? { server: serverHeader } : {}, + body: credentialEndpointProbeResponse, + status, + }; + + const sendPostRequestAsyncSpy: jest.SpyInstance = jest + .spyOn(networkClient, "sendPostRequestAsync") + .mockImplementation((( + url: string, + _options?: NetworkRequestOptions + ) => { + if (url === validatedCredentialEndpoint) { + return Promise.resolve(response); + } + throw new Error( + "An invalid url was used in the tests' post request" + ); + }) as typeof networkClient.sendPostRequestAsync); + // Type assertion is needed because sendPostRequestAsync is a generic method. + // Jest's mockImplementation does not infer generics, so we cast to the method's type + // to ensure the mock matches the original signature and TypeScript type checks correctly. + + return sendPostRequestAsyncSpy; +};