diff --git a/azure-quantum/azure/quantum/_authentication/_default.py b/azure-quantum/azure/quantum/_authentication/_default.py index 014754be6..32111de83 100644 --- a/azure-quantum/azure/quantum/_authentication/_default.py +++ b/azure-quantum/azure/quantum/_authentication/_default.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import sys import logging import re from typing import Optional @@ -16,10 +17,13 @@ InteractiveBrowserCredential, DeviceCodeCredential, _internal as AzureIdentityInternals, + TokenCachePersistenceOptions, + SharedTokenCacheCredential, + _persistent_cache as AzureIdentityPersistentCache ) +from azure.quantum._constants import ConnectionConstants from ._chained import _ChainedTokenCredential from ._token import _TokenFileCredential -from azure.quantum._constants import ConnectionConstants _LOGGER = logging.getLogger(__name__) WWW_AUTHENTICATE_REGEX = re.compile( @@ -61,7 +65,7 @@ def __init__( client_id: Optional[str] = None, tenant_id: Optional[str] = None, authority: Optional[str] = None, - ): + ) -> None: if arm_endpoint is None: raise ValueError("arm_endpoint is mandatory parameter") if subscription_id is None: @@ -84,22 +88,90 @@ def _authority_or_default(self, authority: str, arm_endpoint: str): return ConnectionConstants.DOGFOOD_AUTHORITY return ConnectionConstants.AUTHORITY - def _initialize_credentials(self): + def _get_cache_options(self) -> Optional[TokenCachePersistenceOptions]: + """ + Returns a valid TokenCachePersistenceOptions + if the AzureIdentity Persistent Cache is accessible. + Returns None otherwise. + """ + cache_options = TokenCachePersistenceOptions( + allow_unencrypted_storage=False, + name="AzureQuantumSDK" + ) + try: + # pylint: disable=protected-access + cache = AzureIdentityPersistentCache._load_persistent_cache(cache_options) + try: + # Try to get the location of the cache for + # tracing purpose. + _LOGGER.info( + "Using Azure.Identity Token Cache at %s.", + cache._persistence.get_location() + ) + except: # pylint: disable=bare-except + _LOGGER.info("Using Azure.Identity Token Cache.") + return cache_options + except Exception as ex: # pylint: disable=broad-except + # Check if the cache issue on linux is due + # libsecret not functioning to provider better + # information to the user. + if sys.platform.startswith("linux"): + try: + # pylint: disable=import-outside-toplevel + from msal_extensions.libsecret import trial_run + trial_run() + except Exception as libsecret_ex: # pylint: disable=broad-except + _LOGGER.warning( + "libsecret dependencies are not installed or are unusable.\n" + "Please install the necessary dependencies as instructed in " + "https://github.com/AzureAD/microsoft-authentication-extensions-for-python/wiki/Encryption-on-Linux" # pylint: disable=line-too-long + "Exception:\n%s", + libsecret_ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + + _LOGGER.warning( + 'Error trying to access Azure.Identity Token Cache. ' + "Raised unexpected exception:\n%s", + ex, + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), + ) + return None + + def _initialize_credentials(self) -> None: self._discover_tenant_id_( arm_endpoint=self.arm_endpoint, subscription_id=self.subscription_id) + cache_options = self._get_cache_options() credentials = [] credentials.append(_TokenFileCredential()) credentials.append(EnvironmentCredential()) if self.client_id: credentials.append(ManagedIdentityCredential(client_id=self.client_id)) if self.authority and self.tenant_id: - credentials.append(VisualStudioCodeCredential(authority=self.authority, tenant_id=self.tenant_id)) + credentials.append(VisualStudioCodeCredential( + authority=self.authority, + tenant_id=self.tenant_id)) credentials.append(AzureCliCredential(tenant_id=self.tenant_id)) credentials.append(AzurePowerShellCredential(tenant_id=self.tenant_id)) - credentials.append(InteractiveBrowserCredential(authority=self.authority, tenant_id=self.tenant_id)) + # The SharedTokenCacheCredential is used when the token cache + # is available to attempt loading a token stored in the cache + # by the InteractiveBrowserCredential. + if cache_options: + credentials.append(SharedTokenCacheCredential( + authority=self.authority, + tenant_id=self.tenant_id, + cache_persistence_options=cache_options)) + credentials.append( + InteractiveBrowserCredential( + authority=self.authority, + tenant_id=self.tenant_id, + cache_persistence_options=cache_options)) if self.client_id: - credentials.append(DeviceCodeCredential(authority=self.authority, client_id=self.client_id, tenant_id=self.tenant_id)) + credentials.append(DeviceCodeCredential( + authority=self.authority, + client_id=self.client_id, + tenant_id=self.tenant_id)) self.credentials = credentials def get_token(self, *scopes: str, **kwargs) -> AccessToken: @@ -145,8 +217,7 @@ def _discover_tenant_id_(self, arm_endpoint:str, subscription_id:str): match = re.search(WWW_AUTHENTICATE_REGEX, www_authenticate) if match: self.tenant_id = match.group("tenant_id") - # pylint: disable=broad-exception-caught - except Exception as ex: + except Exception as ex: # pylint: disable=broad-exception-caught _LOGGER.error(ex) # apply default values