Skip to content

Corehttp auth flows #40084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 26, 2025
Merged
1 change: 1 addition & 0 deletions sdk/core/corehttp/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. [#38346](https://github.com/Azure/azure-sdk-for-python/pull/38346)
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.
- Added `model` attribute to `HttpResponseError` to allow accessing error attributes based on a known model. [#39636](https://github.com/Azure/azure-sdk-for-python/pull/39636)
- Added `auth_flows` support in `BearerTokenCredentialPolicy`. [#40084](https://github.com/Azure/azure-sdk-for-python/pull/40084)

### Breaking Changes

Expand Down
28 changes: 23 additions & 5 deletions sdk/core/corehttp/corehttp/runtime/policies/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# -------------------------------------------------------------------------
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union

from ...credentials import TokenRequestOptions
from ...rest import HttpResponse, HttpRequest
Expand All @@ -32,15 +32,23 @@ class _BearerTokenCredentialPolicyBase:
:param credential: The credential.
:type credential: ~corehttp.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword auth_flows: A list of authentication flows to use for the credential.
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
"""

# pylint: disable=unused-argument
def __init__(
self, credential: "TokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
self,
credential: "TokenCredential",
*scopes: str,
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
**kwargs: Any,
) -> None:
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessTokenInfo"] = None
self._auth_flows = auth_flows

@staticmethod
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
Expand Down Expand Up @@ -83,20 +91,30 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
:param credential: The credential.
:type credential: ~corehttp.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword auth_flows: A list of authentication flows to use for the credential.
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
"""

def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
def on_request(
self,
request: PipelineRequest[HTTPRequestType],
*,
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
) -> None:
"""Called before the policy sends a request.

The base implementation authorizes the request with a bearer token.

:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
:keyword auth_flows: A list of authentication flows to use for the credential.
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
"""
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token_info(*self._scopes)
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
self._token = self._credential.get_token_info(*self._scopes, options=options)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
Expand Down Expand Up @@ -124,7 +142,7 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
:return: The pipeline response object
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
self.on_request(request)
self.on_request(request, auth_flows=self._auth_flows)
try:
response = self.next.send(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# -------------------------------------------------------------------------
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar, Union

from ...credentials import AccessTokenInfo, TokenRequestOptions
from ..pipeline import PipelineRequest, PipelineResponse
Expand All @@ -29,28 +29,43 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
:param credential: The credential.
:type credential: ~corehttp.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword auth_flows: A list of authentication flows to use for the credential.
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
"""

# pylint: disable=unused-argument
def __init__(
self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
self,
credential: "AsyncTokenCredential",
*scopes: str,
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
**kwargs: Any,
) -> None:
super().__init__()
self._credential = credential
self._lock_instance = None
self._scopes = scopes
self._token: Optional[AccessTokenInfo] = None
self._auth_flows = auth_flows

@property
def _lock(self):
if self._lock_instance is None:
self._lock_instance = get_running_async_lock()
return self._lock_instance

async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
async def on_request(
self,
request: PipelineRequest[HTTPRequestType],
*,
auth_flows: Optional[list[dict[str, Union[str, list[dict[str, str]]]]]] = None,
) -> None:
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:keyword auth_flows: A list of authentication flows to use for the credential.
:paramtype auth_flows: list[dict[str, Union[str, list[dict[str, str]]]]]
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
"""
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
Expand All @@ -59,7 +74,8 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token:
self._token = await await_result(self._credential.get_token_info, *self._scopes)
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
self._token = await await_result(self._credential.get_token_info, *self._scopes, options=options)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
Expand Down Expand Up @@ -91,7 +107,7 @@ async def send(
:return: The pipeline response object
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
await await_result(self.on_request, request)
await await_result(self.on_request, request, auth_flows=self._auth_flows)
try:
response = await self.next.send(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def send(self, request):
pipeline = AsyncPipeline(transport=transport, policies=[policy])
await pipeline.run(HttpRequest("GET", "https://localhost"))

policy.on_request.assert_called_once_with(policy.request)
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
policy.on_response.assert_called_once_with(policy.request, policy.response)

# the policy should call on_exception when next.send() raises
Expand Down Expand Up @@ -275,7 +275,7 @@ def get_completed_future(result=None):
@pytest.mark.asyncio
async def test_async_token_credential_inheritance():
class TestTokenCredential(AsyncTokenCredential):
async def get_token_info(self, *scopes, options=None):
async def get_token_info(self, *scopes, options={}):
return "TOKEN"

cred = TestTokenCredential()
Expand Down Expand Up @@ -319,3 +319,18 @@ async def test_need_new_token():
# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
assert policy._need_new_token


@pytest.mark.asyncio
async def test_send_with_auth_flows():
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
credential = Mock(
spec_set=["get_token_info"],
get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo("***", int(time.time()) + 3600))),
)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))

pipeline = AsyncPipeline(transport=transport, policies=[policy])
await pipeline.run(HttpRequest("GET", "https://localhost"))
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})
20 changes: 17 additions & 3 deletions sdk/core/corehttp/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_bearer_policy_default_context():

pipeline.run(HttpRequest("GET", "https://localhost"))

credential.get_token_info.assert_called_once_with(expected_scope)
credential.get_token_info.assert_called_once_with(expected_scope, options={})


def test_bearer_policy_context_unmodified_by_default():
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_bearer_policy_cannot_complete_challenge():

assert response.http_response is expected_response
assert transport.send.call_count == 1
credential.get_token_info.assert_called_once_with(expected_scope)
credential.get_token_info.assert_called_once_with(expected_scope, options={})


def test_bearer_policy_calls_sansio_methods():
Expand All @@ -221,7 +221,7 @@ def send(self, request):
pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", "https://localhost"))

policy.on_request.assert_called_once_with(policy.request)
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
policy.on_response.assert_called_once_with(policy.request, policy.response)

# the policy should call on_exception when next.send() raises
Expand Down Expand Up @@ -415,3 +415,17 @@ def test_need_new_token():
# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
assert policy._need_new_token


def test_send_with_auth_flows():
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
credential = Mock(
spec_set=["get_token_info"],
get_token_info=Mock(return_value=AccessTokenInfo("***", int(time.time()) + 3600)),
)
policy = BearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
transport = Mock(send=Mock(return_value=Mock(status_code=200)))

pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", "https://localhost"))
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})
Loading