Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core auth flows #40084

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- The `DistributedTracingPolicy` and `distributed_trace`/`distributed_trace_async` decorators now uses the OpenTelemetry tracer if it is available and native tracing is enabled.
- SDK clients can define an `_instrumentation_config` class variable to configure the OpenTelemetry tracer used in method span creation. Possible configuration options are `library_name`, `library_version`, `schema_url`, and `attributes`.
- `DistributedTracingPolicy` now accepts a `instrumentation_config` keyword argument to configure the OpenTelemetry tracer used in HTTP span creation.
- Added `auth_flows` support in `BearerTokenCredentialPolicy`.

### Breaking Changes

Expand Down
20 changes: 16 additions & 4 deletions sdk/core/corehttp/corehttp/runtime/policies/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,22 @@ class _BearerTokenCredentialPolicyBase:
:param credential: The credential.
:type credential: ~corehttp.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword list[dict[str, str]] auth_flows: A list of authentication flows to use for the credential.
"""

# pylint: disable=unused-argument
def __init__(
self, credential: "TokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
self,
credential: "TokenCredential",
*scopes: str,
auth_flows: Optional[list[dict[str, str]]] = None,
**kwargs: Any,
) -> None:
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessTokenInfo"] = None
self._auth_flows = auth_flows

@staticmethod
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
Expand Down Expand Up @@ -83,20 +90,25 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[H
:param credential: The credential.
:type credential: ~corehttp.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword list[dict[str, str]] auth_flows: A list of authentication flows to use for the credential.
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
"""

def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
def on_request(
self, request: PipelineRequest[HTTPRequestType], *, auth_flows: Optional[list[dict[str, str]]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to take in auth_flows as a parameter here? Or can we just access it in this method via self._auth_flows?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the parameter because we will support per-operation auth_flow. :)

) -> None:
"""Called before the policy sends a request.

The base implementation authorizes the request with a bearer token.

:param ~corehttp.runtime.pipeline.PipelineRequest request: the request
:keyword list[dict[str, str]] auth_flows: A list of authentication flows to use for the credential.
"""
self._enforce_https(request)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token_info(*self._scopes)
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add auth_flows to the TypedDict TokenRequestOptions class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is optional. I would hold until the GA at least. :)

self._token = self._credential.get_token_info(*self._scopes, options=options)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
Expand Down Expand Up @@ -124,7 +136,7 @@ def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HT
:return: The pipeline response object
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
self.on_request(request)
self.on_request(request, auth_flows=self._auth_flows)
try:
response = self.next.send(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,38 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
:param credential: The credential.
:type credential: ~corehttp.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword list[dict[str, str]] auth_flows: A list of authentication flows to use for the credential.
"""

# pylint: disable=unused-argument
def __init__(
self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any # pylint: disable=unused-argument
self,
credential: "AsyncTokenCredential",
*scopes: str,
auth_flows: Optional[list[dict[str, str]]] = None,
**kwargs: Any,
) -> None:
super().__init__()
self._credential = credential
self._lock_instance = None
self._scopes = scopes
self._token: Optional[AccessTokenInfo] = None
self._auth_flows = auth_flows

@property
def _lock(self):
if self._lock_instance is None:
self._lock_instance = get_running_async_lock()
return self._lock_instance

async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
async def on_request(
self, request: PipelineRequest[HTTPRequestType], *, auth_flows: Optional[list[dict[str, str]]] = None
) -> None:
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~corehttp.runtime.pipeline.PipelineRequest
:keyword list[dict[str, str]] auth_flows: A list of authentication flows to use for the credential.
:raises: :class:`~corehttp.exceptions.ServiceRequestError`
"""
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
Expand All @@ -59,7 +69,8 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token:
self._token = await await_result(self._credential.get_token_info, *self._scopes)
options: TokenRequestOptions = {"auth_flows": auth_flows} if auth_flows else {} # type: ignore
self._token = await await_result(self._credential.get_token_info, *self._scopes, options=options)
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessTokenInfo, self._token).token

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
Expand Down Expand Up @@ -91,7 +102,7 @@ async def send(
:return: The pipeline response object
:rtype: ~corehttp.runtime.pipeline.PipelineResponse
"""
await await_result(self.on_request, request)
await await_result(self.on_request, request, auth_flows=self._auth_flows)
try:
response = await self.next.send(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def send(self, request):
pipeline = AsyncPipeline(transport=transport, policies=[policy])
await pipeline.run(HttpRequest("GET", "https://localhost"))

policy.on_request.assert_called_once_with(policy.request)
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
policy.on_response.assert_called_once_with(policy.request, policy.response)

# the policy should call on_exception when next.send() raises
Expand Down Expand Up @@ -275,7 +275,7 @@ def get_completed_future(result=None):
@pytest.mark.asyncio
async def test_async_token_credential_inheritance():
class TestTokenCredential(AsyncTokenCredential):
async def get_token_info(self, *scopes, options=None):
async def get_token_info(self, *scopes, options={}):
return "TOKEN"

cred = TestTokenCredential()
Expand Down Expand Up @@ -319,3 +319,18 @@ async def test_need_new_token():
# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
assert policy._need_new_token


@pytest.mark.asyncio
async def test_send_with_auth_flows():
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
credential = Mock(
spec_set=["get_token_info"],
get_token_info=Mock(return_value=get_completed_future(AccessTokenInfo("***", int(time.time()) + 3600))),
)
policy = AsyncBearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))

pipeline = AsyncPipeline(transport=transport, policies=[policy])
await pipeline.run(HttpRequest("GET", "https://localhost"))
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})
20 changes: 17 additions & 3 deletions sdk/core/corehttp/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_bearer_policy_default_context():

pipeline.run(HttpRequest("GET", "https://localhost"))

credential.get_token_info.assert_called_once_with(expected_scope)
credential.get_token_info.assert_called_once_with(expected_scope, options={})


def test_bearer_policy_context_unmodified_by_default():
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_bearer_policy_cannot_complete_challenge():

assert response.http_response is expected_response
assert transport.send.call_count == 1
credential.get_token_info.assert_called_once_with(expected_scope)
credential.get_token_info.assert_called_once_with(expected_scope, options={})


def test_bearer_policy_calls_sansio_methods():
Expand All @@ -221,7 +221,7 @@ def send(self, request):
pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", "https://localhost"))

policy.on_request.assert_called_once_with(policy.request)
policy.on_request.assert_called_once_with(policy.request, auth_flows=None)
policy.on_response.assert_called_once_with(policy.request, policy.response)

# the policy should call on_exception when next.send() raises
Expand Down Expand Up @@ -415,3 +415,17 @@ def test_need_new_token():
# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessTokenInfo("", now + 1200, refresh_on=now - 1)
assert policy._need_new_token


def test_send_with_auth_flows():
auth_flows = [{"type": "flow1"}, {"type": "flow2"}]
credential = Mock(
spec_set=["get_token_info"],
get_token_info=Mock(return_value=AccessTokenInfo("***", int(time.time()) + 3600)),
)
policy = BearerTokenCredentialPolicy(credential, "scope", auth_flows=auth_flows)
transport = Mock(send=Mock(return_value=Mock(status_code=200)))

pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", "https://localhost"))
policy._credential.get_token_info.assert_called_with("scope", options={"auth_flows": auth_flows})