|
| 1 | +# ------------------------------------ |
| 2 | +# Copyright (c) Microsoft Corporation. |
| 3 | +# Licensed under the MIT License. |
| 4 | +# ------------------------------------ |
| 5 | +"""Policy implementing Key Vault's challenge authentication protocol. |
| 6 | +
|
| 7 | +Normally the protocol is only used for the client's first service request, upon which: |
| 8 | +1. The challenge authentication policy sends a copy of the request, without authorization or content. |
| 9 | +2. Key Vault responds 401 with a header (the 'challenge') detailing how the client should authenticate such a request. |
| 10 | +3. The policy authenticates according to the challenge and sends the original request with authorization. |
| 11 | +
|
| 12 | +The policy caches the challenge and thus knows how to authenticate future requests. However, authentication |
| 13 | +requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the |
| 14 | +protocol again. |
| 15 | +""" |
| 16 | + |
| 17 | +from copy import deepcopy |
| 18 | +import sys |
| 19 | +import time |
| 20 | +from typing import Any, Callable, cast, Optional, overload, TypeVar, Union |
| 21 | +from urllib.parse import urlparse |
| 22 | + |
| 23 | +from typing_extensions import ParamSpec |
| 24 | + |
| 25 | +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions |
| 26 | +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider |
| 27 | +from azure.core.pipeline import PipelineRequest, PipelineResponse |
| 28 | +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy |
| 29 | +from azure.core.rest import AsyncHttpResponse, HttpRequest |
| 30 | + |
| 31 | +from .http_challenge import HttpChallenge |
| 32 | +from . import http_challenge_cache as ChallengeCache |
| 33 | +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge |
| 34 | + |
| 35 | +if sys.version_info < (3, 9): |
| 36 | + from typing import Awaitable |
| 37 | +else: |
| 38 | + from collections.abc import Awaitable |
| 39 | + |
| 40 | + |
| 41 | +P = ParamSpec("P") |
| 42 | +T = TypeVar("T") |
| 43 | + |
| 44 | + |
| 45 | +@overload |
| 46 | +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... |
| 47 | + |
| 48 | + |
| 49 | +@overload |
| 50 | +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... |
| 51 | + |
| 52 | + |
| 53 | +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: |
| 54 | + """If func returns an awaitable, await it. |
| 55 | +
|
| 56 | + :param func: The function to run. |
| 57 | + :type func: callable |
| 58 | + :param args: The positional arguments to pass to the function. |
| 59 | + :type args: list |
| 60 | + :rtype: any |
| 61 | + :return: The result of the function |
| 62 | + """ |
| 63 | + result = func(*args, **kwargs) |
| 64 | + if isinstance(result, Awaitable): |
| 65 | + return await result |
| 66 | + return result |
| 67 | + |
| 68 | + |
| 69 | +class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): |
| 70 | + """Policy for handling HTTP authentication challenges. |
| 71 | +
|
| 72 | + :param credential: An object which can provide an access token for the vault, such as a credential from |
| 73 | + :mod:`azure.identity.aio` |
| 74 | + :type credential: ~azure.core.credentials_async.AsyncTokenProvider |
| 75 | + """ |
| 76 | + |
| 77 | + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: |
| 78 | + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request |
| 79 | + super().__init__(credential, *scopes, enable_cae=True, **kwargs) |
| 80 | + self._credential: AsyncTokenProvider = credential |
| 81 | + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None |
| 82 | + self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) |
| 83 | + self._request_copy: Optional[HttpRequest] = None |
| 84 | + |
| 85 | + async def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: |
| 86 | + """Authorize request with a bearer token and send it to the next policy. |
| 87 | +
|
| 88 | + We implement this method to account for the valid scenario where a Key Vault authentication challenge is |
| 89 | + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to |
| 90 | + the caller, but we should handle that second challenge as well (and only return any third 401 response). |
| 91 | +
|
| 92 | + :param request: The pipeline request object |
| 93 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 94 | + :return: The pipeline response object |
| 95 | + :rtype: ~azure.core.pipeline.PipelineResponse |
| 96 | + """ |
| 97 | + await await_result(self.on_request, request) |
| 98 | + response: PipelineResponse[HttpRequest, AsyncHttpResponse] |
| 99 | + try: |
| 100 | + response = await self.next.send(request) |
| 101 | + except Exception: # pylint:disable=broad-except |
| 102 | + await await_result(self.on_exception, request) |
| 103 | + raise |
| 104 | + await await_result(self.on_response, request, response) |
| 105 | + |
| 106 | + if response.http_response.status_code == 401: |
| 107 | + return await self.handle_challenge_flow(request, response) |
| 108 | + return response |
| 109 | + |
| 110 | + async def handle_challenge_flow( |
| 111 | + self, |
| 112 | + request: PipelineRequest[HttpRequest], |
| 113 | + response: PipelineResponse[HttpRequest, AsyncHttpResponse], |
| 114 | + consecutive_challenge: bool = False, |
| 115 | + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: |
| 116 | + """Handle the challenge flow of Key Vault and CAE authentication. |
| 117 | +
|
| 118 | + :param request: The pipeline request object |
| 119 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 120 | + :param response: The pipeline response object |
| 121 | + :type response: ~azure.core.pipeline.PipelineResponse |
| 122 | + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. |
| 123 | + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. |
| 124 | + True if the preceding challenge was a Key Vault challenge; False otherwise. |
| 125 | +
|
| 126 | + :return: The pipeline response object |
| 127 | + :rtype: ~azure.core.pipeline.PipelineResponse |
| 128 | + """ |
| 129 | + self._token = None # any cached token is invalid |
| 130 | + if "WWW-Authenticate" in response.http_response.headers: |
| 131 | + # If the previous challenge was a KV challenge and this one is too, return the 401 |
| 132 | + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) |
| 133 | + if consecutive_challenge and not claims_challenge: |
| 134 | + return response |
| 135 | + |
| 136 | + request_authorized = await self.on_challenge(request, response) |
| 137 | + if request_authorized: |
| 138 | + # if we receive a challenge response, we retrieve a new token |
| 139 | + # which matches the new target. In this case, we don't want to remove |
| 140 | + # token from the request so clear the 'insecure_domain_change' tag |
| 141 | + request.context.options.pop("insecure_domain_change", False) |
| 142 | + try: |
| 143 | + response = await self.next.send(request) |
| 144 | + except Exception: # pylint:disable=broad-except |
| 145 | + await await_result(self.on_exception, request) |
| 146 | + raise |
| 147 | + |
| 148 | + # If consecutive_challenge == True, this could be a third consecutive 401 |
| 149 | + if response.http_response.status_code == 401 and not consecutive_challenge: |
| 150 | + # If the previous challenge wasn't from CAE, we can try this function one more time |
| 151 | + if not claims_challenge: |
| 152 | + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) |
| 153 | + await await_result(self.on_response, request, response) |
| 154 | + return response |
| 155 | + |
| 156 | + async def on_request(self, request: PipelineRequest) -> None: |
| 157 | + _enforce_tls(request) |
| 158 | + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) |
| 159 | + if challenge: |
| 160 | + # Note that if the vault has moved to a new tenant since our last request for it, this request will fail. |
| 161 | + if self._need_new_token(): |
| 162 | + # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource |
| 163 | + scope = challenge.get_scope() or challenge.get_resource() + "/.default" |
| 164 | + await self._request_kv_token(scope, challenge) |
| 165 | + |
| 166 | + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token |
| 167 | + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" |
| 168 | + return |
| 169 | + |
| 170 | + # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, |
| 171 | + # saving it for later. Key Vault will reject the request as unauthorized and respond with a challenge. |
| 172 | + # on_challenge will parse that challenge, use the original request including the body, authorize the |
| 173 | + # request, and tell super to send it again. |
| 174 | + if request.http_request.content: |
| 175 | + self._request_copy = request.http_request |
| 176 | + bodiless_request = HttpRequest( |
| 177 | + method=request.http_request.method, |
| 178 | + url=request.http_request.url, |
| 179 | + headers=deepcopy(request.http_request.headers), |
| 180 | + ) |
| 181 | + bodiless_request.headers["Content-Length"] = "0" |
| 182 | + request.http_request = bodiless_request |
| 183 | + |
| 184 | + async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: |
| 185 | + try: |
| 186 | + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary |
| 187 | + old_scope: Optional[str] = None |
| 188 | + old_tenant: Optional[str] = None |
| 189 | + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) |
| 190 | + if cached_challenge: |
| 191 | + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" |
| 192 | + old_tenant = cached_challenge.tenant_id |
| 193 | + |
| 194 | + challenge = _update_challenge(request, response) |
| 195 | + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary |
| 196 | + if challenge.claims and old_scope: |
| 197 | + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access |
| 198 | + challenge.tenant_id = old_tenant |
| 199 | + # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource |
| 200 | + scope = challenge.get_scope() or challenge.get_resource() + "/.default" |
| 201 | + except ValueError: |
| 202 | + return False |
| 203 | + |
| 204 | + if self._verify_challenge_resource: |
| 205 | + resource_domain = urlparse(scope).netloc |
| 206 | + if not resource_domain: |
| 207 | + raise ValueError(f"The challenge contains invalid scope '{scope}'.") |
| 208 | + |
| 209 | + request_domain = urlparse(request.http_request.url).netloc |
| 210 | + if not request_domain.lower().endswith(f".{resource_domain.lower()}"): |
| 211 | + raise ValueError( |
| 212 | + f"The challenge resource '{resource_domain}' does not match the requested domain. Pass " |
| 213 | + "`verify_challenge_resource=False` to your client's constructor to disable this verification. " |
| 214 | + "See https://aka.ms/azsdk/blog/vault-uri for more information." |
| 215 | + ) |
| 216 | + |
| 217 | + # If we had created a request copy in on_request, use it now to send along the original body content |
| 218 | + if self._request_copy: |
| 219 | + request.http_request = self._request_copy |
| 220 | + |
| 221 | + # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication |
| 222 | + # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 |
| 223 | + if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): |
| 224 | + await self.authorize_request(request, scope, claims=challenge.claims) |
| 225 | + else: |
| 226 | + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) |
| 227 | + |
| 228 | + return True |
| 229 | + |
| 230 | + def _need_new_token(self) -> bool: |
| 231 | + now = time.time() |
| 232 | + refresh_on = getattr(self._token, "refresh_on", None) |
| 233 | + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 |
| 234 | + |
| 235 | + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: |
| 236 | + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. |
| 237 | +
|
| 238 | + :param str scope: The scope for which to request a token. |
| 239 | + :param challenge: The challenge for the request being made. |
| 240 | + :type challenge: HttpChallenge |
| 241 | + """ |
| 242 | + # Exclude tenant for AD FS authentication |
| 243 | + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") |
| 244 | + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs |
| 245 | + if hasattr(self._credential, "get_token_info"): |
| 246 | + options: TokenRequestOptions = {"enable_cae": True} |
| 247 | + if challenge.tenant_id and not exclude_tenant: |
| 248 | + options["tenant_id"] = challenge.tenant_id |
| 249 | + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) |
| 250 | + else: |
| 251 | + if exclude_tenant: |
| 252 | + self._token = await self._credential.get_token(scope, enable_cae=True) |
| 253 | + else: |
| 254 | + self._token = await cast(AsyncTokenCredential, self._credential).get_token( |
| 255 | + scope, tenant_id=challenge.tenant_id, enable_cae=True |
| 256 | + ) |
0 commit comments