diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 487b46c95..1f02bdaf4 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -115,12 +115,18 @@ def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): class OAuthClient: def __init__( - self, token_endpoint: str, client_id: str, client_secret: str, scope: Optional[str] = None + self, + token_endpoint: str, + client_id: str, + client_secret: str, + scope: Optional[str] = None, + audience: Optional[str] = None, ): self.token_endpoint = token_endpoint self.client_id = client_id self.client_secret = client_secret self.scope = scope + self.audience = audience def client_credentials(self) -> OAuthClientCredentials: credentials = base64.b64encode( @@ -131,7 +137,9 @@ def client_credentials(self) -> OAuthClientCredentials: "authorization": f"Basic {credentials}", "content-type": "application/x-www-form-urlencoded", } - body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" + scope_chunk = f"&scope={self.scope}" if self.scope else "" + audience_chunk = f"&audience={self.audience}" if self.audience else "" + body = f"grant_type=client_credentials{scope_chunk}{audience_chunk}" response = requests.post(self.token_endpoint, headers=headers, data=body) response.raise_for_status() return self.parse_oauth_token_response(response.text) @@ -241,6 +249,7 @@ def __oauth_client_credentials(profile): client_id=profile.client_id, client_secret=profile.client_secret, scope=profile.scope, + audience=profile.audience, ) provider = OAuthClientCredentialsAuthProvider( oauth_client=oauth_client, auth_config=AuthConfig() diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index 71dee2d9d..25794e989 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -36,6 +36,7 @@ class DeltaSharingProfile: username: Optional[str] = None password: Optional[str] = None scope: Optional[str] = None + audience: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -90,6 +91,7 @@ def from_json(json) -> "DeltaSharingProfile": client_id=json["clientId"], client_secret=json["clientSecret"], scope=json.get("scope"), + audience=json.get("audience"), ) elif type == "bearer_token": return DeltaSharingProfile( diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 75650dd74..5b2dce4db 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -74,6 +74,7 @@ def test_oauth_client_credentials_auth_provider_exchange_token(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) mock_session = MagicMock(spec=Session) @@ -97,6 +98,7 @@ def test_oauth_client_credentials_auth_provider_reuse_token(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) mock_session = MagicMock(spec=Session) @@ -120,6 +122,7 @@ def test_oauth_client_credentials_auth_provider_refresh_token(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) mock_session = MagicMock(spec=Session) @@ -147,6 +150,7 @@ def test_oauth_client_credentials_auth_provider_needs_refresh(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) @@ -171,6 +175,7 @@ def test_oauth_client_credentials_auth_provider_is_expired(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) assert not provider.is_expired() @@ -183,6 +188,7 @@ def test_oauth_client_credentials_auth_provider_get_expiration_time(): profile.client_id = "client-id" profile.client_secret = "client-secret" profile.scope = None + profile.audience = None provider = OAuthClientCredentialsAuthProvider(oauth_client) assert provider.get_expiration_time() is None