Skip to content

[Tables] Add audience keyword argument support #40487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/tables/azure-data-tables/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* Added to support customized encoding and decoding in entity CRUD operations.
* Added to support Entity property in Tuple and Enum types.
* Added to support flatten Entity metadata in entity deserialization by passing kwarg `flatten_result_entity` when creating clients.
* Added support for configuring custom audiences for `TokenCredential` authentication when initializing a `TableClient` or `TableServiceClient`. ([#40487](https://github.com/Azure/azure-sdk-for-python/pull/40487))

### Bugs Fixed
* Fixed duplicate odata tag bug in encoder when Entity property has "@odata.type" provided.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,14 @@ def _configure_credential(
Union[AzureNamedKeyCredential, AzureSasCredential, TokenCredential, SharedKeyCredentialPolicy]
],
cosmos_endpoint: bool = False,
audience: Optional[str] = None,
) -> Optional[Union[BearerTokenChallengePolicy, AzureSasCredentialPolicy, SharedKeyCredentialPolicy]]:
if hasattr(credential, "get_token"):
credential = cast(TokenCredential, credential)
scope = COSMOS_OAUTH_SCOPE if cosmos_endpoint else STORAGE_OAUTH_SCOPE
if audience:
scope = audience.rstrip("/") + "/.default"
else:
scope = COSMOS_OAUTH_SCOPE if cosmos_endpoint else STORAGE_OAUTH_SCOPE
return BearerTokenChallengePolicy(credential, scope)
if isinstance(credential, SharedKeyCredentialPolicy):
return credential
Expand Down
12 changes: 9 additions & 3 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
*,
credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, TokenCredential]] = None,
api_version: Optional[str] = None,
audience: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -83,6 +84,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
~azure.core.credentials.AzureNamedKeyCredential or
~azure.core.credentials.AzureSasCredential or
~azure.core.credentials.TokenCredential or None
:keyword audience: Optional audience to use for Microsoft Entra ID authentication. If not specified,
the public cloud audience will be used.
:paramtype audience: str or None
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02".
:paramtype api_version: str or None
Expand Down Expand Up @@ -129,7 +133,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
}
self._hosts = _hosts

self._policies = self._configure_policies(hosts=self._hosts, **kwargs)
self._policies = self._configure_policies(audience=audience, hosts=self._hosts, **kwargs)
if self._cosmos_endpoint:
self._policies.insert(0, CosmosPatchTransformPolicy())

Expand Down Expand Up @@ -222,8 +226,10 @@ def _format_url(self, hostname):
"""
return f"{self.scheme}://{hostname}{self._query_str}"

def _configure_policies(self, **kwargs):
credential_policy = _configure_credential(self.credential, self._cosmos_endpoint)
def _configure_policies(self, *, audience: Optional[str] = None, **kwargs: Any) -> List[Any]:
credential_policy = _configure_credential(
self.credential, cosmos_endpoint=self._cosmos_endpoint, audience=audience
)
return [
RequestIdPolicy(**kwargs),
StorageHeadersPolicy(**kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
table_name: str,
*,
credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, TokenCredential]] = None,
audience: Optional[str] = None,
api_version: Optional[str] = None,
encoder_map: Optional[EncoderMapType] = None,
decoder_map: Optional[DecoderMapType] = None,
Expand All @@ -91,6 +92,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
~azure.core.credentials.AzureNamedKeyCredential or
~azure.core.credentials.AzureSasCredential or
~azure.core.credentials.TokenCredential or None
:keyword audience: Optional audience to use for Microsoft Entra ID authentication. If not specified,
the public cloud audience will be used.
:paramtype audience: str or None
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02".
:paramtype api_version: str or None
Expand All @@ -112,7 +116,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
self.table_name: str = table_name
self.encoder = TableEntityEncoder(convert_map=encoder_map)
self.decoder = TableEntityDecoder(convert_map=decoder_map, flatten_result_entity=flatten_result_entity)
super(TableClient, self).__init__(endpoint, credential=credential, api_version=api_version, **kwargs)
super(TableClient, self).__init__(
endpoint, credential=credential, api_version=api_version, audience=audience, **kwargs
)

@classmethod
def from_connection_string(cls, conn_str: str, table_name: str, **kwargs: Any) -> "TableClient":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,14 @@ def _configure_credential(
Union[AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential, SharedKeyCredentialPolicy]
],
cosmos_endpoint: bool = False,
audience: Optional[str] = None,
) -> Optional[Union[AsyncBearerTokenChallengePolicy, AzureSasCredentialPolicy, SharedKeyCredentialPolicy]]:
if hasattr(credential, "get_token"):
credential = cast(AsyncTokenCredential, credential)
scope = COSMOS_OAUTH_SCOPE if cosmos_endpoint else STORAGE_OAUTH_SCOPE
if audience:
scope = audience.rstrip("/") + "/.default"
else:
scope = COSMOS_OAUTH_SCOPE if cosmos_endpoint else STORAGE_OAUTH_SCOPE
return AsyncBearerTokenChallengePolicy(credential, scope)
if isinstance(credential, SharedKeyCredentialPolicy):
return credential
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
endpoint: str,
*,
credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, AsyncTokenCredential]] = None,
audience: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs: Any,
) -> None:
Expand All @@ -70,6 +71,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
~azure.core.credentials.AzureNamedKeyCredential or
~azure.core.credentials.AzureSasCredential or
~azure.core.credentials_async.AsyncTokenCredential or None
:keyword audience: Optional audience to use for Microsoft Entra ID authentication. If not specified,
the public cloud audience will be used.
:paramtype audience: str or None
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02".
:paramtype api_version: str or None
Expand Down Expand Up @@ -116,7 +120,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
}
self._hosts = _hosts

self._policies = self._configure_policies(hosts=self._hosts, **kwargs)
self._policies = self._configure_policies(audience=audience, hosts=self._hosts, **kwargs)
if self._cosmos_endpoint:
self._policies.insert(0, CosmosPatchTransformPolicy())

Expand Down Expand Up @@ -215,8 +219,8 @@ def _format_url(self, hostname):
"""
return f"{self.scheme}://{hostname}{self._query_str}"

def _configure_policies(self, **kwargs):
credential_policy = _configure_credential(self.credential, self._cosmos_endpoint)
def _configure_policies(self, *, audience: Optional[str] = None, **kwargs: Any) -> List[Any]:
credential_policy = _configure_credential(self.credential, self._cosmos_endpoint, audience=audience)
return [
RequestIdPolicy(**kwargs),
StorageHeadersPolicy(**kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
table_name: str,
*,
credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, AsyncTokenCredential]] = None,
audience: Optional[str] = None,
api_version: Optional[str] = None,
encoder_map: Optional[EncoderMapType] = None,
decoder_map: Optional[DecoderMapType] = None,
Expand All @@ -92,6 +93,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
~azure.core.credentials.AzureNamedKeyCredential or
~azure.core.credentials.AzureSasCredential or
~azure.core.credentials_async.AsyncTokenCredential or None
:keyword audience: Optional audience to use for Microsoft Entra ID authentication. If not specified,
the public cloud audience will be used.
:paramtype audience: str or None
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02".
:paramtype api_version: str or None
Expand All @@ -113,7 +117,9 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
self.table_name: str = table_name
self.encoder = TableEntityEncoder(convert_map=encoder_map)
self.decoder = TableEntityDecoder(convert_map=decoder_map, flatten_result_entity=flatten_result_entity)
super(TableClient, self).__init__(endpoint, credential=credential, api_version=api_version, **kwargs)
super(TableClient, self).__init__(
endpoint, credential=credential, api_version=api_version, audience=audience, **kwargs
)

@classmethod
def from_connection_string(cls, conn_str: str, table_name: str, **kwargs: Any) -> "TableClient":
Expand Down
19 changes: 19 additions & 0 deletions sdk/tables/azure-data-tables/tests/test_table_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from datetime import datetime, timedelta
from devtools_testutils import AzureRecordedTestCase, recorded_by_proxy
from unittest.mock import patch

from azure.data.tables import (
TableServiceClient,
Expand Down Expand Up @@ -903,6 +904,24 @@ def test_create_service_with_token(self):
assert service.credential == token_credential
assert not hasattr(service.credential, "account_key")

@pytest.mark.parametrize("client_class", SERVICES)
def test_create_service_client_with_custom_audience(self, client_class):
url = self.account_url(self.tables_storage_account_name, "table")
token_credential = self.get_token_credential()
custom_audience = "https://foo.bar"
expected_scope = custom_audience + "/.default"

# Test with patching to verify BearerTokenChallengePolicy is created with the proper scope.
with patch("azure.data.tables._authentication.BearerTokenChallengePolicy") as mock_policy:
client_class(
url,
credential=token_credential,
table_name="foo",
audience=custom_audience,
)

mock_policy.assert_called_with(token_credential, expected_scope)

def test_create_client_with_api_version(self):
url = self.account_url(self.tables_storage_account_name, "table")
client = TableServiceClient(url, credential=self.credential)
Expand Down
20 changes: 20 additions & 0 deletions sdk/tables/azure-data-tables/tests/test_table_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datetime import datetime, timedelta
from devtools_testutils import AzureRecordedTestCase
from devtools_testutils.aio import recorded_by_proxy_async
from unittest.mock import patch

from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError, ClientAuthenticationError
Expand Down Expand Up @@ -670,6 +671,25 @@ async def test_create_service_with_token_async(self):
assert service.credential == token_credential
assert not hasattr(service.credential, "account_key")

@pytest.mark.asyncio
@pytest.mark.parametrize("client_class", SERVICES)
def test_create_service_client_with_custom_audience(self, client_class):
url = self.account_url(self.tables_storage_account_name, "table")
token_credential = self.get_token_credential()
custom_audience = "https://foo.bar"
expected_scope = custom_audience + "/.default"

# Test with patching to verify AsyncBearerTokenChallengePolicy is created with the proper scope.
with patch("azure.data.tables.aio._authentication_async.AsyncBearerTokenChallengePolicy") as mock_policy:
client_class(
url,
credential=token_credential,
table_name="foo",
audience=custom_audience,
)

mock_policy.assert_called_with(token_credential, expected_scope)

@pytest.mark.skip("HTTP prefix does not raise an error")
@pytest.mark.asyncio
async def test_create_service_with_token_and_http(self):
Expand Down
Loading