diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 55a15fb06..846adb56c 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -106,6 +106,94 @@ def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): self.creation_timestamp = creation_timestamp +class ManagedIdentityAuthProvider(AuthCredentialProvider): + def __init__(self, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def get_managed_identity_token(self) -> OAuthClientCredentials: + # Azure IMDS endpoint to get the access token + url = "http://169.254.169.254/metadata/identity/oauth2/token" + + resource = "https://management.azure.com/" + # Headers required to access Azure Instance Metadata Service + headers = {"Metadata": "true"} + + # Parameters to specify the resource and API version + params = { + "api-version": "2019-08-01", + "resource": resource + } + + # Make the GET request to fetch the token + response = requests.get(url, headers=headers, params=params) + + # Check if the request was successful + if response.status_code == 200: + # Return the access token + return self.parse_oauth_token_response(response.text) + + else: + # Handle errors + raise Exception(f"Failed to obtain token: {response.status_code} - {response.text}") + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from OAuth token endpoint") + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + expires_in = None + if 'expires_in' not in json_node: + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + elif isinstance(json_node['expires_in'], int): + expires_in = json_node['expires_in'] + elif isinstance(json_node['expires_in'], str): + expires_in = int(json_node['expires_in']) + else: + raise RuntimeError("Invalid 'expires_in' field in OAuth token response") + return OAuthClientCredentials( + json_node['access_token'], + expires_in, + int(datetime.now().timestamp()) + ) + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + + print("######") + print(token.access_token) + print("######") + + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.get_managed_identity_token() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def is_expired(self) -> bool: + return False + + def get_expiration_time(self) -> Optional[str]: + return None + + + class OAuthClient: def __init__(self, token_endpoint: str, @@ -186,8 +274,10 @@ class AuthCredentialProviderFactory: @staticmethod def create_auth_credential_provider(profile: DeltaSharingProfile): if profile.share_credentials_version == 2: - if profile.type == "oauth_client_credentials": + if profile.type == "oauth_client_credentials" or profile.type == "oidc_client_credentials": return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "oidc_managed_identity": + return AuthCredentialProviderFactory.__oidc_managed_identity(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) elif (profile.share_credentials_version == 1 and @@ -224,6 +314,22 @@ def __oauth_client_credentials(profile): AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider return provider + @staticmethod + def __oidc_managed_identity(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + provider = ManagedIdentityAuthProvider() + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + @staticmethod def __auth_bearer_token(profile): return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index d8873c156..243239c84 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -99,6 +99,12 @@ def from_json(json) -> "DeltaSharingProfile": bearer_token=json["bearerToken"], expiration_time=json.get("expirationTime") ) + elif type == "oidc_managed_identity": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint + ) elif type == "basic": return DeltaSharingProfile( share_credentials_version=share_credentials_version, diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index e1103239a..bfec32760 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -157,6 +157,7 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._session.headers.update( { "User-Agent": DataSharingRestClient.USER_AGENT, + "Custom-Header-Recipient-Id": "7ccbb5da-b1b1-4519-ae53-190db7988199" } )