From 71908b3189cc38af8a961305e58dc3bb638580c8 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 10:54:53 -0800 Subject: [PATCH 01/11] added support for azure managed identity auth --- python/delta_sharing/_internal_auth.py | 103 ++++++++++++++++++++++++ python/delta_sharing/tests/test_auth.py | 78 +++++++++++++++++- 2 files changed, 179 insertions(+), 2 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 3b1d35103..694c308af 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -158,6 +158,7 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: raise RuntimeError( "'expires_in' field must be an integer or a string convertible to integer" ) + #print(json_node['access_token']) return OAuthClientCredentials( json_node['access_token'], expires_in, @@ -198,16 +199,109 @@ def get_expiration_time(self) -> Optional[str]: return None +class AzureManagedIdentityAuthProvider(AuthCredentialProvider): + def __init__(self, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + + def managed_identity_token(self) -> OAuthClientCredentials: + # Azure IMDS endpoint to get the access token. + # This interface allows any client application running on the VM to acquire an access token via HTTP REST calls. + # For more details, see: + # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + url = "http://169.254.169.254/metadata/identity/oauth2/token" + + resource = "https://management.azure.com/" + # Headers required to access Azure Instance Metadata Service + headers = {"Metadata": "true"} + + # Parameters to specify the resource and API version + params = { + "api-version": "2019-08-01", + "resource": resource + } + + # Make the GET request to fetch the token + response = requests.get(url, headers=headers, params=params) + + # Check if the request was successful + if response.status_code == 200: + # Return the access token + return self.parse_oauth_token_response(response.text) + + else: + # Handle errors + raise Exception(f"Failed to obtain token: {response.status_code} - {response.text}") + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from OAuth token endpoint") + # Parsing the response according to azure managed identity spec + # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + if 'expires_in' not in json_node: + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + try: + expires_in = int(json_node['expires_in']) # Convert to int if it's a string + except ValueError: + raise RuntimeError( + "'expires_in' field must be an integer or a string convertible to integer" + ) + return OAuthClientCredentials( + json_node['access_token'], + expires_in, + int(datetime.now().timestamp()) + ) + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.managed_identity_token() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def is_expired(self) -> bool: + return False + + def get_expiration_time(self) -> Optional[str]: + return None + class AuthCredentialProviderFactory: __oauth_auth_provider_cache : Dict[ DeltaSharingProfile, OAuthClientCredentialsAuthProvider] = {} + __managed_identity_provider_cache : Dict[ + DeltaSharingProfile, + AzureManagedIdentityAuthProvider] = {} + @staticmethod def create_auth_credential_provider(profile: DeltaSharingProfile): if profile.share_credentials_version == 2: if profile.type == "oauth_client_credentials": return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "experimental_managed_identity": + return AuthCredentialProviderFactory.__experimental_managed_identity(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) elif (profile.share_credentials_version == 1 and @@ -251,3 +345,12 @@ def __auth_bearer_token(profile): @staticmethod def __auth_basic(profile): return BasicAuthProvider(profile.endpoint, profile.username, profile.password) + + @staticmethod + def __experimental_managed_identity(profile): + if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache: + return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] + + provider = AzureManagedIdentityAuthProvider() + AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider + return provider \ No newline at end of file diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 81dbd1ba9..3f57b3843 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -14,13 +14,13 @@ # limitations under the License. # -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from datetime import datetime, timedelta from delta_sharing._internal_auth import (OAuthClient, BasicAuthProvider, AuthCredentialProviderFactory, OAuthClientCredentialsAuthProvider, - OAuthClientCredentials) + OAuthClientCredentials, AzureManagedIdentityAuthProvider) from requests import Session import requests from delta_sharing._internal_auth import BearerTokenAuthProvider @@ -292,3 +292,77 @@ def test_oauth_auth_provider_with_different_profiles(): provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) assert provider1 != provider2 + +def test_azure_managed_identity_auth_provider_initialization(): + provider = AzureManagedIdentityAuthProvider() + assert provider.current_token is None + + +def test_azure_managed_identity_auth_provider_add_auth_header(): + provider = AzureManagedIdentityAuthProvider() + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp())) + provider.current_token = token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {token.access_token}"} + ) + + +def test_azure_managed_identity_auth_provider_refresh_token(): + provider = AzureManagedIdentityAuthProvider() + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials( + "new-token", 3600, int(datetime.now().timestamp())) + provider.current_token = expired_token + + with patch.object(provider, 'managed_identity_token', return_value=new_token): + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {new_token.access_token}"} + ) + + +def test_azure_managed_identity_auth_provider_needs_refresh(): + provider = AzureManagedIdentityAuthProvider() + + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + assert provider.needs_refresh(expired_token) + + token_expiring_soon = OAuthClientCredentials( + "expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + assert provider.needs_refresh(token_expiring_soon) + + valid_token = OAuthClientCredentials( + "valid-token", 600 + 10, int(datetime.now().timestamp())) + assert not provider.needs_refresh(valid_token) + + +def test_azure_managed_identity_auth_provider_is_expired(): + provider = AzureManagedIdentityAuthProvider() + assert not provider.is_expired() + + +def test_azure_managed_identity_auth_provider_get_expiration_time(): + provider = AzureManagedIdentityAuthProvider() + assert provider.get_expiration_time() is None + + +def test_factory_creation_managed_identity(): + profile_managed_identity = DeltaSharingProfile( + share_credentials_version=2, + type="experimental_managed_identity", + endpoint="https://localhost/delta-sharing/" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_managed_identity) + assert isinstance(provider, AzureManagedIdentityAuthProvider) \ No newline at end of file From 28713d9c2f868a13513071713e1f475146b71d36 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 11:06:01 -0800 Subject: [PATCH 02/11] cleanup --- python/delta_sharing/_internal_auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 694c308af..28d43806c 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -158,7 +158,6 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: raise RuntimeError( "'expires_in' field must be an integer or a string convertible to integer" ) - #print(json_node['access_token']) return OAuthClientCredentials( json_node['access_token'], expires_in, From 75a5975ce2f64b8fb7a7ba0b1ee0bc3b1299d08a Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 12:09:27 -0800 Subject: [PATCH 03/11] slight refactoring with more tests --- python/delta_sharing/_internal_auth.py | 27 ++++-- python/delta_sharing/tests/test_auth.py | 38 ++++---- .../tests/test_managed_identity_client.py | 92 +++++++++++++++++++ .../delta_sharing/tests/test_oauth_client.py | 2 + 4 files changed, 134 insertions(+), 25 deletions(-) create mode 100644 python/delta_sharing/tests/test_managed_identity_client.py diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 28d43806c..05ed6f16c 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -198,12 +198,9 @@ def get_expiration_time(self) -> Optional[str]: return None -class AzureManagedIdentityAuthProvider(AuthCredentialProvider): - def __init__(self, auth_config: AuthConfig = AuthConfig()): - self.auth_config = auth_config - self.current_token: Optional[OAuthClientCredentials] = None - self.lock = threading.RLock() - +class AzureManagedIdentityClient: + def __init__(self): + None def managed_identity_token(self) -> OAuthClientCredentials: # Azure IMDS endpoint to get the access token. @@ -224,6 +221,7 @@ def managed_identity_token(self) -> OAuthClientCredentials: # Make the GET request to fetch the token response = requests.get(url, headers=headers, params=params) + response.raise_for_status() # Check if the request was successful if response.status_code == 200: @@ -236,7 +234,7 @@ def managed_identity_token(self) -> OAuthClientCredentials: def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: if not response: - raise RuntimeError("Empty response from OAuth token endpoint") + raise RuntimeError("Empty response from azure managed identity endpoint") # Parsing the response according to azure managed identity spec # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http json_node = json.loads(response) @@ -256,6 +254,15 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: int(datetime.now().timestamp()) ) + +class AzureManagedIdentityAuthProvider(AuthCredentialProvider): + def __init__(self, managed_identity_client: AzureManagedIdentityClient, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.managed_identity_client = managed_identity_client + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def add_auth_header(self,session: requests.Session) -> None: token = self.maybe_refresh_token() @@ -270,7 +277,7 @@ def maybe_refresh_token(self) -> OAuthClientCredentials: with self.lock: if self.current_token and not self.needs_refresh(self.current_token): return self.current_token - new_token = self.managed_identity_token() + new_token = self.managed_identity_client.managed_identity_token() self.current_token = new_token return new_token @@ -285,6 +292,7 @@ def is_expired(self) -> bool: def get_expiration_time(self) -> Optional[str]: return None + class AuthCredentialProviderFactory: __oauth_auth_provider_cache : Dict[ DeltaSharingProfile, @@ -350,6 +358,7 @@ def __experimental_managed_identity(profile): if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache: return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] - provider = AzureManagedIdentityAuthProvider() + managed_identity_client = AzureManagedIdentityClient() + provider = AzureManagedIdentityAuthProvider(managed_identity_client = managed_identity_client) AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider return provider \ No newline at end of file diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 3f57b3843..cd66701bb 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -20,13 +20,13 @@ BasicAuthProvider, AuthCredentialProviderFactory, OAuthClientCredentialsAuthProvider, - OAuthClientCredentials, AzureManagedIdentityAuthProvider) + OAuthClientCredentials, AzureManagedIdentityAuthProvider, + AzureManagedIdentityClient) from requests import Session import requests from delta_sharing._internal_auth import BearerTokenAuthProvider from delta_sharing.protocol import DeltaSharingProfile - def test_bearer_token_auth_provider_initialization(): token = "test-token" expiration_time = "2021-11-12T00:12:29.0Z" @@ -293,13 +293,18 @@ def test_oauth_auth_provider_with_different_profiles(): assert provider1 != provider2 + + + + def test_azure_managed_identity_auth_provider_initialization(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) assert provider.current_token is None - def test_azure_managed_identity_auth_provider_add_auth_header(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() @@ -312,9 +317,9 @@ def test_azure_managed_identity_auth_provider_add_auth_header(): {"Authorization": f"Bearer {token.access_token}"} ) - def test_azure_managed_identity_auth_provider_refresh_token(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_managed_identity_client) mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() @@ -324,16 +329,18 @@ def test_azure_managed_identity_auth_provider_refresh_token(): "new-token", 3600, int(datetime.now().timestamp())) provider.current_token = expired_token - with patch.object(provider, 'managed_identity_token', return_value=new_token): - provider.add_auth_header(mock_session) + mock_managed_identity_client.managed_identity_token.return_value = new_token + + provider.add_auth_header(mock_session) mock_session.headers.update.assert_called_once_with( {"Authorization": f"Bearer {new_token.access_token}"} ) - + mock_managed_identity_client.managed_identity_token.assert_called_once() def test_azure_managed_identity_auth_provider_needs_refresh(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) expired_token = OAuthClientCredentials( "expired-token", 1, int(datetime.now().timestamp()) - 3600) @@ -347,17 +354,16 @@ def test_azure_managed_identity_auth_provider_needs_refresh(): "valid-token", 600 + 10, int(datetime.now().timestamp())) assert not provider.needs_refresh(valid_token) - def test_azure_managed_identity_auth_provider_is_expired(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) assert not provider.is_expired() - def test_azure_managed_identity_auth_provider_get_expiration_time(): - provider = AzureManagedIdentityAuthProvider() + mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) assert provider.get_expiration_time() is None - def test_factory_creation_managed_identity(): profile_managed_identity = DeltaSharingProfile( share_credentials_version=2, diff --git a/python/delta_sharing/tests/test_managed_identity_client.py b/python/delta_sharing/tests/test_managed_identity_client.py new file mode 100644 index 000000000..2fa60b4e8 --- /dev/null +++ b/python/delta_sharing/tests/test_managed_identity_client.py @@ -0,0 +1,92 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import requests +from requests.models import Response +from unittest.mock import patch +from datetime import datetime +from delta_sharing._internal_auth import OAuthClient + +from datetime import datetime +from unittest.mock import patch + +import pytest +from requests import Response +from delta_sharing._internal_auth import AzureManagedIdentityClient + +class MockServer: + def __init__(self): + self.url = "http://169.254.169.254/metadata/identity/oauth2/token" + self.responses = [] + + def add_response(self, status_code, json_data): + response = Response() + response.status_code = status_code + response._content = json_data.encode('utf-8') + self.responses.append(response) + + def get_response(self): + return self.responses.pop(0) + + +@pytest.fixture +def mock_server(): + server = MockServer() + yield server + + +@pytest.mark.parametrize("response_data, expected_expires_in, expected_access_token", [ + ( + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}', + 3600, + "test-access-token" + ), + ( + '{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}', + 3600, + "test-access-token" + ) +]) +def test_managed_identity_client_should_parse_token_response_correctly(mock_server, + response_data, + expected_expires_in, + expected_access_token): + mock_server.add_response(200, response_data) + + with patch('requests.get') as mock_get: + mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response() + client = AzureManagedIdentityClient() + + start = datetime.now().timestamp() + token = client.managed_identity_token() + end = datetime.now().timestamp() + + assert token.access_token == expected_access_token + assert token.expires_in == expected_expires_in + assert int(start) <= token.creation_timestamp + assert token.creation_timestamp <= int(end) + + +def test_managed_identity_client_should_handle_500_internal_server_error(mock_server): + mock_server.add_response(500, 'Internal Server Error') + + with patch('requests.get') as mock_get: + mock_get.side_effect = lambda *args, **kwargs: mock_server.get_response() + client = AzureManagedIdentityClient() + try: + client.managed_identity_token() + except Exception as e: + assert e.response.status_code == 500 \ No newline at end of file diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index d40ab8e82..042a889a1 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -99,3 +99,5 @@ def test_oauth_client_should_handle_401_unauthorized_response(mock_server): oauth_client.client_credentials() except requests.HTTPError as e: assert e.response.status_code == 401 + + From 34e1ab65cf980e42231b65c85e79284b511c7111 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 12:40:49 -0800 Subject: [PATCH 04/11] fixed style --- python/delta_sharing/_internal_auth.py | 6 +-- python/delta_sharing/tests/test_auth.py | 37 ++++++++++--------- .../tests/test_managed_identity_client.py | 15 ++++---- .../delta_sharing/tests/test_oauth_client.py | 1 - 4 files changed, 31 insertions(+), 28 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 05ed6f16c..7cc77b391 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -204,7 +204,8 @@ def __init__(self): def managed_identity_token(self) -> OAuthClientCredentials: # Azure IMDS endpoint to get the access token. - # This interface allows any client application running on the VM to acquire an access token via HTTP REST calls. + # This interface allows any client application running on the Azure VM + # to acquire an access token via HTTP REST calls. # For more details, see: # https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http url = "http://169.254.169.254/metadata/identity/oauth2/token" @@ -262,7 +263,6 @@ def __init__(self, managed_identity_client: AzureManagedIdentityClient, auth_co self.current_token: Optional[OAuthClientCredentials] = None self.lock = threading.RLock() - def add_auth_header(self,session: requests.Session) -> None: token = self.maybe_refresh_token() @@ -359,6 +359,6 @@ def __experimental_managed_identity(profile): return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] managed_identity_client = AzureManagedIdentityClient() - provider = AzureManagedIdentityAuthProvider(managed_identity_client = managed_identity_client) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider return provider \ No newline at end of file diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index cd66701bb..619a09e62 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -27,6 +27,7 @@ from delta_sharing._internal_auth import BearerTokenAuthProvider from delta_sharing.protocol import DeltaSharingProfile + def test_bearer_token_auth_provider_initialization(): token = "test-token" expiration_time = "2021-11-12T00:12:29.0Z" @@ -294,17 +295,15 @@ def test_oauth_auth_provider_with_different_profiles(): assert provider1 != provider2 - - - def test_azure_managed_identity_auth_provider_initialization(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) assert provider.current_token is None + def test_azure_managed_identity_auth_provider_add_auth_header(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() @@ -317,9 +316,10 @@ def test_azure_managed_identity_auth_provider_add_auth_header(): {"Authorization": f"Bearer {token.access_token}"} ) + def test_azure_managed_identity_auth_provider_refresh_token(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() @@ -329,18 +329,19 @@ def test_azure_managed_identity_auth_provider_refresh_token(): "new-token", 3600, int(datetime.now().timestamp())) provider.current_token = expired_token - mock_managed_identity_client.managed_identity_token.return_value = new_token + mock_client.managed_identity_token.return_value = new_token provider.add_auth_header(mock_session) mock_session.headers.update.assert_called_once_with( {"Authorization": f"Bearer {new_token.access_token}"} ) - mock_managed_identity_client.managed_identity_token.assert_called_once() + mock_client.managed_identity_token.assert_called_once() + def test_azure_managed_identity_auth_provider_needs_refresh(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) expired_token = OAuthClientCredentials( "expired-token", 1, int(datetime.now().timestamp()) - 3600) @@ -355,15 +356,17 @@ def test_azure_managed_identity_auth_provider_needs_refresh(): assert not provider.needs_refresh(valid_token) def test_azure_managed_identity_auth_provider_is_expired(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) assert not provider.is_expired() + def test_azure_managed_identity_auth_provider_get_expiration_time(): - mock_managed_identity_client = MagicMock(spec=AzureManagedIdentityClient) - provider = AzureManagedIdentityAuthProvider(managed_identity_client = mock_managed_identity_client) + mock_client = MagicMock(spec=AzureManagedIdentityClient) + provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) assert provider.get_expiration_time() is None + def test_factory_creation_managed_identity(): profile_managed_identity = DeltaSharingProfile( share_credentials_version=2, diff --git a/python/delta_sharing/tests/test_managed_identity_client.py b/python/delta_sharing/tests/test_managed_identity_client.py index 2fa60b4e8..0a5080c15 100644 --- a/python/delta_sharing/tests/test_managed_identity_client.py +++ b/python/delta_sharing/tests/test_managed_identity_client.py @@ -27,6 +27,7 @@ from requests import Response from delta_sharing._internal_auth import AzureManagedIdentityClient + class MockServer: def __init__(self): self.url = "http://169.254.169.254/metadata/identity/oauth2/token" @@ -50,14 +51,14 @@ def mock_server(): @pytest.mark.parametrize("response_data, expected_expires_in, expected_access_token", [ ( - '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}', - 3600, - "test-access-token" + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}', + 3600, + "test-access-token" ), ( - '{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}', - 3600, - "test-access-token" + '{"access_token": "test-access-token", "expires_in": "3600", "token_type": "bearer"}', + 3600, + "test-access-token" ) ]) def test_managed_identity_client_should_parse_token_response_correctly(mock_server, @@ -89,4 +90,4 @@ def test_managed_identity_client_should_handle_500_internal_server_error(mock_se try: client.managed_identity_token() except Exception as e: - assert e.response.status_code == 500 \ No newline at end of file + assert e.response.status_code == 500 diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 042a889a1..2e1825f60 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -100,4 +100,3 @@ def test_oauth_client_should_handle_401_unauthorized_response(mock_server): except requests.HTTPError as e: assert e.response.status_code == 401 - From aa91941d72575758a197b9b04e09b28d9b9d6176 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 12:51:12 -0800 Subject: [PATCH 05/11] fixed style --- python/delta_sharing/_internal_auth.py | 6 ++++-- python/delta_sharing/tests/test_auth.py | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 7cc77b391..1ab144030 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -257,7 +257,9 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: class AzureManagedIdentityAuthProvider(AuthCredentialProvider): - def __init__(self, managed_identity_client: AzureManagedIdentityClient, auth_config: AuthConfig = AuthConfig()): + def __init__(self, + managed_identity_client: AzureManagedIdentityClient, + auth_config: AuthConfig = AuthConfig()): self.auth_config = auth_config self.managed_identity_client = managed_identity_client self.current_token: Optional[OAuthClientCredentials] = None @@ -361,4 +363,4 @@ def __experimental_managed_identity(profile): managed_identity_client = AzureManagedIdentityClient() provider = AzureManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider - return provider \ No newline at end of file + return provider diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 619a09e62..5ea43d9b0 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -355,6 +355,7 @@ def test_azure_managed_identity_auth_provider_needs_refresh(): "valid-token", 600 + 10, int(datetime.now().timestamp())) assert not provider.needs_refresh(valid_token) + def test_azure_managed_identity_auth_provider_is_expired(): mock_client = MagicMock(spec=AzureManagedIdentityClient) provider = AzureManagedIdentityAuthProvider(managed_identity_client=mock_client) @@ -373,5 +374,7 @@ def test_factory_creation_managed_identity(): type="experimental_managed_identity", endpoint="https://localhost/delta-sharing/" ) - provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_managed_identity) - assert isinstance(provider, AzureManagedIdentityAuthProvider) \ No newline at end of file + provider = AuthCredentialProviderFactory.create_auth_credential_provider( + profile_managed_identity + ) + assert isinstance(provider, AzureManagedIdentityAuthProvider) From 0ac7f0362d2d3f7dfdad71f781aa741e2b03c165 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 12:59:36 -0800 Subject: [PATCH 06/11] fixed style --- python/delta_sharing/tests/test_oauth_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 2e1825f60..d40ab8e82 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -99,4 +99,3 @@ def test_oauth_client_should_handle_401_unauthorized_response(mock_server): oauth_client.client_credentials() except requests.HTTPError as e: assert e.response.status_code == 401 - From bb313cb1cfdf7978942ba1db25f70f2f07fd7809 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 10 Jan 2025 13:09:40 -0800 Subject: [PATCH 07/11] removed unused imports --- python/delta_sharing/tests/test_auth.py | 2 +- python/delta_sharing/tests/test_managed_identity_client.py | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 5ea43d9b0..945915e86 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -14,7 +14,7 @@ # limitations under the License. # -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from datetime import datetime, timedelta from delta_sharing._internal_auth import (OAuthClient, BasicAuthProvider, diff --git a/python/delta_sharing/tests/test_managed_identity_client.py b/python/delta_sharing/tests/test_managed_identity_client.py index 0a5080c15..fd36f6a20 100644 --- a/python/delta_sharing/tests/test_managed_identity_client.py +++ b/python/delta_sharing/tests/test_managed_identity_client.py @@ -13,13 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import pytest -import requests -from requests.models import Response -from unittest.mock import patch -from datetime import datetime -from delta_sharing._internal_auth import OAuthClient - from datetime import datetime from unittest.mock import patch From 596f80d63191a08d2f07b001c2b2c3ff986a425f Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 16 Jan 2025 13:48:30 -0800 Subject: [PATCH 08/11] updated the missing file --- python/delta_sharing/protocol.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index a99acfb14..f3a46290f 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -99,6 +99,12 @@ def from_json(json) -> "DeltaSharingProfile": bearer_token=json["bearerToken"], expiration_time=json.get("expirationTime") ) + elif type == "experimental_managed_identity": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint + ) elif type == "basic": return DeltaSharingProfile( share_credentials_version=share_credentials_version, From c1b45c841a98b87ee2a69d1ec63b9ded345e6b54 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 14 Mar 2025 16:19:24 -0700 Subject: [PATCH 09/11] managed-identity azure and gcp prototype --- python/delta_sharing/_internal_auth.py | 79 ++++++++++++++++++++++++-- python/delta_sharing/protocol.py | 7 ++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 1ab144030..e980b5c3c 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -198,7 +198,15 @@ def get_expiration_time(self) -> Optional[str]: return None -class AzureManagedIdentityClient: +class OidcManagedIdentityClient: + def managed_identity_token(self) -> OAuthClientCredentials: + None + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + None + + +class AzureManagedIdentityClient(OidcManagedIdentityClient): def __init__(self): None @@ -256,9 +264,66 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: ) -class AzureManagedIdentityAuthProvider(AuthCredentialProvider): +class OAuthClientCredentials: + def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): + self.access_token = access_token + self.expires_in = expires_in + self.creation_timestamp = creation_timestamp + +class GCPManagedIdentityOIDCClient(OidcManagedIdentityClient): + def __init__(self, audience: str): + if not audience: + raise ValueError("Audience must be specified for OIDC token request.") + self.audience = audience + + def managed_identity_token(self) -> OAuthClientCredentials: + """ + Fetches an OIDC token from the GCP metadata server for the specified audience + and returns it as an OAuthClientCredentials object. + """ + url = f"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity?audience={self.audience}&format=full" + headers = {"Metadata-Flavor": "Google"} + + response = requests.get(url, headers=headers) + response.raise_for_status() + + token = response.text # JWT OIDC Token + return self.parse_oidc_token(token) + + def parse_oidc_token(self, token: str) -> OAuthClientCredentials: + """ + Decodes the OIDC token and extracts useful claims such as expiration time. + Returns an OAuthClientCredentials object. + """ + if not token: + raise RuntimeError("Empty OIDC token received from GCP managed identity endpoint.") + + # JWT tokens are base64 encoded and consist of three parts: header, payload, and signature + try: + payload_encoded = token.split(".")[1] # Extract payload (second part of JWT) + payload_decoded = json.loads(base64.urlsafe_b64decode(payload_encoded + "==").decode("utf-8")) + except Exception as e: + raise RuntimeError(f"Failed to decode OIDC token payload: {str(e)}") + + if 'exp' not in payload_decoded: + raise RuntimeError("Missing 'exp' field in OIDC token payload.") + + # Calculate expiration time + expiration_time = int(payload_decoded['exp']) + current_timestamp = int(datetime.datetime.utcnow().timestamp()) + expires_in = expiration_time - current_timestamp # Time left in seconds + + return OAuthClientCredentials( + access_token=token, + expires_in=expires_in, + creation_timestamp=current_timestamp + ) + + + +class ManagedIdentityAuthProvider(AuthCredentialProvider): def __init__(self, - managed_identity_client: AzureManagedIdentityClient, + managed_identity_client: OidcManagedIdentityClient, auth_config: AuthConfig = AuthConfig()): self.auth_config = auth_config self.managed_identity_client = managed_identity_client @@ -360,7 +425,11 @@ def __experimental_managed_identity(profile): if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache: return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] - managed_identity_client = AzureManagedIdentityClient() - provider = AzureManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) + if profile.cloud_provider == "azure" + managed_identity_client = AzureManagedIdentityClient() + elif profile.cloud_provider == "gcp" + managed_identity_client = GCPManagedIdentityOIDCClient(audience="my-audience") + + provider = ManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider return provider diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index f3a46290f..85611469a 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -37,6 +37,10 @@ class DeltaSharingProfile: password: Optional[str] = None scope: Optional[str] = None + # this is only applicable to the experimental_managed_identity auth mode + # for the experimental_managed_identity specifies in which cloud the compute is running + cloud_provider: Option[str] = None + def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: raise ValueError( @@ -103,7 +107,8 @@ def from_json(json) -> "DeltaSharingProfile": return DeltaSharingProfile( share_credentials_version=share_credentials_version, type=type, - endpoint=endpoint + endpoint=endpoint, + cloud_provider=json.get("cloud_provider") ) elif type == "basic": return DeltaSharingProfile( From 64e72e6bc9fd90f5513ff551fbb015e24a8a6e53 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Sun, 16 Mar 2025 11:55:17 -0700 Subject: [PATCH 10/11] added audience and fixed cloud-provider --- python/delta_sharing/_internal_auth.py | 2 +- python/delta_sharing/protocol.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index e980b5c3c..3c2ace5a9 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -428,7 +428,7 @@ def __experimental_managed_identity(profile): if profile.cloud_provider == "azure" managed_identity_client = AzureManagedIdentityClient() elif profile.cloud_provider == "gcp" - managed_identity_client = GCPManagedIdentityOIDCClient(audience="my-audience") + managed_identity_client = GCPManagedIdentityOIDCClient(audience=profile.audience) provider = ManagedIdentityAuthProvider(managed_identity_client=managed_identity_client) AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] = provider diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index 85611469a..e7d060b51 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -36,10 +36,11 @@ class DeltaSharingProfile: username: Optional[str] = None password: Optional[str] = None scope: Optional[str] = None + audience: Optional[str] = None # this is only applicable to the experimental_managed_identity auth mode # for the experimental_managed_identity specifies in which cloud the compute is running - cloud_provider: Option[str] = None + cloud_provider: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -108,7 +109,8 @@ def from_json(json) -> "DeltaSharingProfile": share_credentials_version=share_credentials_version, type=type, endpoint=endpoint, - cloud_provider=json.get("cloud_provider") + cloud_provider=json.get("cloud_provider"), + audience=json.get("audience") ) elif type == "basic": return DeltaSharingProfile( From 886d4e77aa2af8e986af3ff15c1f8139377a678b Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Sun, 16 Mar 2025 12:11:39 -0700 Subject: [PATCH 11/11] fix --- python/delta_sharing/_internal_auth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 3c2ace5a9..408e577fc 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -310,7 +310,7 @@ def parse_oidc_token(self, token: str) -> OAuthClientCredentials: # Calculate expiration time expiration_time = int(payload_decoded['exp']) - current_timestamp = int(datetime.datetime.utcnow().timestamp()) + current_timestamp = int(datetime.now().timestamp()) expires_in = expiration_time - current_timestamp # Time left in seconds return OAuthClientCredentials( @@ -367,7 +367,7 @@ class AuthCredentialProviderFactory: __managed_identity_provider_cache : Dict[ DeltaSharingProfile, - AzureManagedIdentityAuthProvider] = {} + ManagedIdentityAuthProvider] = {} @staticmethod def create_auth_credential_provider(profile: DeltaSharingProfile): @@ -425,9 +425,9 @@ def __experimental_managed_identity(profile): if profile in AuthCredentialProviderFactory.__managed_identity_provider_cache: return AuthCredentialProviderFactory.__managed_identity_provider_cache[profile] - if profile.cloud_provider == "azure" + if profile.cloud_provider == "azure": managed_identity_client = AzureManagedIdentityClient() - elif profile.cloud_provider == "gcp" + elif profile.cloud_provider == "gcp": managed_identity_client = GCPManagedIdentityOIDCClient(audience=profile.audience) provider = ManagedIdentityAuthProvider(managed_identity_client=managed_identity_client)