Skip to content

Commit ac1bf76

Browse files
committed
Add _internal dir for auth support
1 parent d545ba6 commit ac1bf76

File tree

5 files changed

+856
-0
lines changed

5 files changed

+856
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from collections import namedtuple
6+
7+
from urllib.parse import urlparse
8+
9+
from .challenge_auth_policy import ChallengeAuthPolicy
10+
from .http_challenge import HttpChallenge
11+
from . import http_challenge_cache
12+
13+
HttpChallengeCache = http_challenge_cache # to avoid aliasing pylint error (C4745)
14+
15+
__all__ = [
16+
"ChallengeAuthPolicy",
17+
"HttpChallenge",
18+
"HttpChallengeCache",
19+
]
20+
21+
_VaultId = namedtuple("_VaultId", ["vault_url", "collection", "name", "version"])
22+
23+
24+
def parse_vault_id(url: str) -> "_VaultId":
25+
try:
26+
parsed_uri = urlparse(url)
27+
except Exception as exc: # pylint: disable=broad-except
28+
raise ValueError(f"'{url}' is not a valid url") from exc
29+
if not (parsed_uri.scheme and parsed_uri.hostname):
30+
raise ValueError(f"'{url}' is not a valid url")
31+
32+
path = list(filter(None, parsed_uri.path.split("/")))
33+
34+
if len(path) < 2 or len(path) > 3:
35+
raise ValueError(f"'{url}' is not a valid vault url")
36+
37+
return _VaultId(
38+
vault_url=f"{parsed_uri.scheme}://{parsed_uri.hostname}",
39+
collection=path[0],
40+
name=path[1],
41+
version=path[2] if len(path) == 3 else None,
42+
)
43+
44+
45+
try:
46+
# pylint:disable=unused-import
47+
from .async_challenge_auth_policy import AsyncChallengeAuthPolicy
48+
49+
__all__.extend(["AsyncChallengeAuthPolicy"])
50+
except (SyntaxError, ImportError):
51+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
"""Policy implementing Key Vault's challenge authentication protocol.
6+
7+
Normally the protocol is only used for the client's first service request, upon which:
8+
1. The challenge authentication policy sends a copy of the request, without authorization or content.
9+
2. Key Vault responds 401 with a header (the 'challenge') detailing how the client should authenticate such a request.
10+
3. The policy authenticates according to the challenge and sends the original request with authorization.
11+
12+
The policy caches the challenge and thus knows how to authenticate future requests. However, authentication
13+
requirements can change. For example, a vault may move to a new tenant. In such a case the policy will attempt the
14+
protocol again.
15+
"""
16+
17+
from copy import deepcopy
18+
import sys
19+
import time
20+
from typing import Any, Callable, cast, Optional, overload, TypeVar, Union
21+
from urllib.parse import urlparse
22+
23+
from typing_extensions import ParamSpec
24+
25+
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
26+
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
27+
from azure.core.pipeline import PipelineRequest, PipelineResponse
28+
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
29+
from azure.core.rest import AsyncHttpResponse, HttpRequest
30+
31+
from .http_challenge import HttpChallenge
32+
from . import http_challenge_cache as ChallengeCache
33+
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge
34+
35+
if sys.version_info < (3, 9):
36+
from typing import Awaitable
37+
else:
38+
from collections.abc import Awaitable
39+
40+
41+
P = ParamSpec("P")
42+
T = TypeVar("T")
43+
44+
45+
@overload
46+
async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ...
47+
48+
49+
@overload
50+
async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
51+
52+
53+
async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
54+
"""If func returns an awaitable, await it.
55+
56+
:param func: The function to run.
57+
:type func: callable
58+
:param args: The positional arguments to pass to the function.
59+
:type args: list
60+
:rtype: any
61+
:return: The result of the function
62+
"""
63+
result = func(*args, **kwargs)
64+
if isinstance(result, Awaitable):
65+
return await result
66+
return result
67+
68+
69+
class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
70+
"""Policy for handling HTTP authentication challenges.
71+
72+
:param credential: An object which can provide an access token for the vault, such as a credential from
73+
:mod:`azure.identity.aio`
74+
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
75+
"""
76+
77+
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
78+
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
79+
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
80+
self._credential: AsyncTokenProvider = credential
81+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
82+
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
83+
self._request_copy: Optional[HttpRequest] = None
84+
85+
async def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
86+
"""Authorize request with a bearer token and send it to the next policy.
87+
88+
We implement this method to account for the valid scenario where a Key Vault authentication challenge is
89+
immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to
90+
the caller, but we should handle that second challenge as well (and only return any third 401 response).
91+
92+
:param request: The pipeline request object
93+
:type request: ~azure.core.pipeline.PipelineRequest
94+
:return: The pipeline response object
95+
:rtype: ~azure.core.pipeline.PipelineResponse
96+
"""
97+
await await_result(self.on_request, request)
98+
response: PipelineResponse[HttpRequest, AsyncHttpResponse]
99+
try:
100+
response = await self.next.send(request)
101+
except Exception: # pylint:disable=broad-except
102+
await await_result(self.on_exception, request)
103+
raise
104+
await await_result(self.on_response, request, response)
105+
106+
if response.http_response.status_code == 401:
107+
return await self.handle_challenge_flow(request, response)
108+
return response
109+
110+
async def handle_challenge_flow(
111+
self,
112+
request: PipelineRequest[HttpRequest],
113+
response: PipelineResponse[HttpRequest, AsyncHttpResponse],
114+
consecutive_challenge: bool = False,
115+
) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
116+
"""Handle the challenge flow of Key Vault and CAE authentication.
117+
118+
:param request: The pipeline request object
119+
:type request: ~azure.core.pipeline.PipelineRequest
120+
:param response: The pipeline response object
121+
:type response: ~azure.core.pipeline.PipelineResponse
122+
:param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge.
123+
Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge.
124+
True if the preceding challenge was a Key Vault challenge; False otherwise.
125+
126+
:return: The pipeline response object
127+
:rtype: ~azure.core.pipeline.PipelineResponse
128+
"""
129+
self._token = None # any cached token is invalid
130+
if "WWW-Authenticate" in response.http_response.headers:
131+
# If the previous challenge was a KV challenge and this one is too, return the 401
132+
claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"])
133+
if consecutive_challenge and not claims_challenge:
134+
return response
135+
136+
request_authorized = await self.on_challenge(request, response)
137+
if request_authorized:
138+
# if we receive a challenge response, we retrieve a new token
139+
# which matches the new target. In this case, we don't want to remove
140+
# token from the request so clear the 'insecure_domain_change' tag
141+
request.context.options.pop("insecure_domain_change", False)
142+
try:
143+
response = await self.next.send(request)
144+
except Exception: # pylint:disable=broad-except
145+
await await_result(self.on_exception, request)
146+
raise
147+
148+
# If consecutive_challenge == True, this could be a third consecutive 401
149+
if response.http_response.status_code == 401 and not consecutive_challenge:
150+
# If the previous challenge wasn't from CAE, we can try this function one more time
151+
if not claims_challenge:
152+
return await self.handle_challenge_flow(request, response, consecutive_challenge=True)
153+
await await_result(self.on_response, request, response)
154+
return response
155+
156+
async def on_request(self, request: PipelineRequest) -> None:
157+
_enforce_tls(request)
158+
challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
159+
if challenge:
160+
# Note that if the vault has moved to a new tenant since our last request for it, this request will fail.
161+
if self._need_new_token():
162+
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
163+
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
164+
await self._request_kv_token(scope, challenge)
165+
166+
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
167+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
168+
return
169+
170+
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
171+
# saving it for later. Key Vault will reject the request as unauthorized and respond with a challenge.
172+
# on_challenge will parse that challenge, use the original request including the body, authorize the
173+
# request, and tell super to send it again.
174+
if request.http_request.content:
175+
self._request_copy = request.http_request
176+
bodiless_request = HttpRequest(
177+
method=request.http_request.method,
178+
url=request.http_request.url,
179+
headers=deepcopy(request.http_request.headers),
180+
)
181+
bodiless_request.headers["Content-Length"] = "0"
182+
request.http_request = bodiless_request
183+
184+
async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
185+
try:
186+
# CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary
187+
old_scope: Optional[str] = None
188+
old_tenant: Optional[str] = None
189+
cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
190+
if cached_challenge:
191+
old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default"
192+
old_tenant = cached_challenge.tenant_id
193+
194+
challenge = _update_challenge(request, response)
195+
# CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary
196+
if challenge.claims and old_scope:
197+
challenge._parameters["scope"] = old_scope # pylint:disable=protected-access
198+
challenge.tenant_id = old_tenant
199+
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
200+
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
201+
except ValueError:
202+
return False
203+
204+
if self._verify_challenge_resource:
205+
resource_domain = urlparse(scope).netloc
206+
if not resource_domain:
207+
raise ValueError(f"The challenge contains invalid scope '{scope}'.")
208+
209+
request_domain = urlparse(request.http_request.url).netloc
210+
if not request_domain.lower().endswith(f".{resource_domain.lower()}"):
211+
raise ValueError(
212+
f"The challenge resource '{resource_domain}' does not match the requested domain. Pass "
213+
"`verify_challenge_resource=False` to your client's constructor to disable this verification. "
214+
"See https://aka.ms/azsdk/blog/vault-uri for more information."
215+
)
216+
217+
# If we had created a request copy in on_request, use it now to send along the original body content
218+
if self._request_copy:
219+
request.http_request = self._request_copy
220+
221+
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
222+
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
223+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
224+
await self.authorize_request(request, scope, claims=challenge.claims)
225+
else:
226+
await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id)
227+
228+
return True
229+
230+
def _need_new_token(self) -> bool:
231+
now = time.time()
232+
refresh_on = getattr(self._token, "refresh_on", None)
233+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
234+
235+
async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None:
236+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
237+
238+
:param str scope: The scope for which to request a token.
239+
:param challenge: The challenge for the request being made.
240+
:type challenge: HttpChallenge
241+
"""
242+
# Exclude tenant for AD FS authentication
243+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
244+
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
245+
if hasattr(self._credential, "get_token_info"):
246+
options: TokenRequestOptions = {"enable_cae": True}
247+
if challenge.tenant_id and not exclude_tenant:
248+
options["tenant_id"] = challenge.tenant_id
249+
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
250+
else:
251+
if exclude_tenant:
252+
self._token = await self._credential.get_token(scope, enable_cae=True)
253+
else:
254+
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
255+
scope, tenant_id=challenge.tenant_id, enable_cae=True
256+
)

0 commit comments

Comments
 (0)