From bc3a309b7550b33e6d004ecbadc17bfb606f5114 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 11 Sep 2024 13:59:10 -0700 Subject: [PATCH 1/2] d2o python authorization code grant --- python/delta_sharing/_internal_auth.py | 223 +++++++++++++++++++++++++ python/delta_sharing/protocol.py | 16 ++ python/delta_sharing/rest_client.py | 3 +- python/delta_sharing/tests/test_e2e.py | 39 +++++ 4 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 python/delta_sharing/tests/test_e2e.py diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 55a15fb06..7ff31b625 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -188,6 +188,8 @@ def create_auth_credential_provider(profile: DeltaSharingProfile): if profile.share_credentials_version == 2: if profile.type == "oauth_client_credentials": return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "oauth_authorization_code": + return AuthCredentialProviderFactory.__oauth_authorization_code(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) elif (profile.share_credentials_version == 1 and @@ -224,6 +226,30 @@ def __oauth_client_credentials(profile): AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider return provider + @staticmethod + def __oauth_authorization_code(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + + provider = OAuthU2MAuthCredentialProvider( + client_id=profile.client_id, + client_secret=profile.client_secret, + token_url=profile.token_url, + auth_url=profile.auth_url, + scope=profile.scope, + redirect_uri=profile.redirect_uri + ) + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + @staticmethod def __auth_bearer_token(profile): return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) @@ -231,3 +257,200 @@ def __auth_bearer_token(profile): @staticmethod def __auth_basic(profile): return BasicAuthProvider(profile.endpoint, profile.username, profile.password) + + + + +import time +import json +from typing import Optional +from abc import ABC, abstractmethod +from urllib.parse import urlencode, urlparse, parse_qs +from urllib.request import Request, urlopen +from http.server import BaseHTTPRequestHandler, HTTPServer +import threading + +class AuthCredentialProvider(ABC): + @abstractmethod + def add_auth_header(self, session) -> None: + pass + + def is_expired(self) -> bool: + return False + + @abstractmethod + def get_expiration_time(self) -> Optional[str]: + return None + +import time +import json +import base64 +import hashlib +import secrets +import webbrowser +from typing import Optional +from urllib.parse import urlencode, urlparse, parse_qs +from urllib.request import Request, urlopen +from http.server import BaseHTTPRequestHandler, HTTPServer +import threading +from abc import ABC, abstractmethod + + +class AuthCredentialProvider(ABC): + @abstractmethod + def add_auth_header(self, session) -> None: + pass + + def is_expired(self) -> bool: + return False + + @abstractmethod + def get_expiration_time(self) -> Optional[str]: + return None + + +class OAuthU2MAuthCredentialProvider(AuthCredentialProvider): + def __init__(self, client_id: str, client_secret: str, token_url: str, auth_url: str, scope: str, redirect_uri: str, port_range=range(8080, 8081)): + self.client_id = client_id + self.client_secret = client_secret + self.token_url = token_url + self.auth_url = auth_url + self.scope = scope + self.redirect_uri = redirect_uri + self.access_token = None + self.token_expiry = None + self.server = None + self.authorization_code = None + self.port_range = port_range + + def start_http_server(self, port: int): + """ + Starts an HTTP server to listen for the OAuth provider's redirect with the authorization code. + """ + class OAuthCallbackHandler(BaseHTTPRequestHandler): + def do_GET(self): + # Parse the query parameters from the redirect URI + query_components = parse_qs(urlparse(self.path).query) + code = query_components.get('code', [None])[0] + + if code: + self.server.provider.authorization_code = code + self.send_response(200) + self.end_headers() + self.wfile.write(b"OAuth Authentication successful! You can close this window.") + else: + self.send_response(400) + self.end_headers() + self.wfile.write(b"Error: Missing authorization code.") + def log_message(self, format, *args): + pass # Suppress log messages + + OAuthCallbackHandler.server = self + + self.server = HTTPServer(('localhost', port), OAuthCallbackHandler) + self.server.provider = self + #print(f"Starting HTTP server at http://localhost:{port}") + threading.Thread(target=self.server.serve_forever).start() + + def stop_http_server(self): + """ + Stops the HTTP server if it's running. + """ + if self.server: + self.server.shutdown() + self.server.server_close() + #print("HTTP server stopped.") + + def _get_access_token(self) -> None: + """ + Fetches a new access token from the OAuth provider using the authorization code. + """ + if not self.authorization_code: + self._get_authorization_code() + + token_data = { + 'grant_type': 'authorization_code', + 'code': self.authorization_code, + 'redirect_uri': self.redirect_uri, + 'client_id': self.client_id, + 'code_verifier': self.code_verifier # Include the correct code_verifier here + } + + token_request = Request(self.token_url, data=urlencode(token_data).encode('utf-8')) + token_request.add_header('Content-Type', 'application/x-www-form-urlencoded') + #print(token_data) + + import urllib.request # Import the urllib.request module for making HTTP requests + import urllib.error # Import the urllib.error module for handling HTTP errors + try: + with urlopen(token_request) as response: + token_response = json.loads(response.read().decode('utf-8')) + self.access_token = token_response.get('access_token') + expires_in = token_response.get('expires_in') + self.token_expiry = time.time() + expires_in + + except urllib.error.HTTPError as e: + error_response = e.read().decode('utf-8') + print(f"HTTP Error: {e.code} - {e.reason}\nResponse: {error_response}") + + def _get_authorization_code(self): + """ + Directs the user to the OAuth provider's authorization URL to get the authorization code. + """ + import base64 + import hashlib + import secrets + + # Generate code_verifier and code_challenge + self.code_verifier = secrets.token_urlsafe(128) + self.code_challenge = base64.urlsafe_b64encode(hashlib.sha256(self.code_verifier.encode()).digest()).rstrip(b'=').decode('utf-8') + + auth_params = urlencode({ + 'response_type': 'code', + 'client_id': self.client_id, + 'redirect_uri': self.redirect_uri, + 'scope': self.scope, + 'code_challenge': self.code_challenge, + 'code_challenge_method': 'S256' + }) + auth_url = f"{self.auth_url}?{auth_params}" + #print(f"Initiating U2M OAuth: {auth_url}") + + + webbrowser.open(auth_url) # Open the browser for user authorization + + for port in self.port_range: + try: + self.start_http_server(port) + self.redirect_uri = f"http://localhost:{port}" + while not self.authorization_code: + time.sleep(1) + break + except OSError: + continue + + self.stop_http_server() + + def add_auth_header(self, session) -> None: + """ + Adds the OAuth token to the request headers if needed. + """ + if not self.access_token or self.is_expired(): + self._get_access_token() + + session.headers['Authorization'] = f'Bearer {self.access_token}' + + def is_expired(self) -> bool: + """ + Checks if the current access token is expired. + """ + return self.token_expiry is None or time.time() >= self.token_expiry + + def get_expiration_time(self) -> Optional[str]: + """ + Returns the expiration time of the current access token. + """ + if self.token_expiry: + return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.token_expiry)) + return None + diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index d8873c156..adbc4f99c 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -37,6 +37,11 @@ class DeltaSharingProfile: password: Optional[str] = None scope: Optional[str] = None + token_url: Optional[str] = None + auth_url: Optional[str] = None + scope: Optional[str] = None + redirect_uri: Optional[str] = None + def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: raise ValueError( @@ -107,6 +112,17 @@ def from_json(json) -> "DeltaSharingProfile": username=json["username"], password=json["password"], ) + elif type == "oauth_authorization_code": + return DeltaSharingProfile( + share_credentials_version=share_credentials_version, + type=type, + endpoint=endpoint, + client_id=json["clientId"], + redirect_uri=json["redirectUri"], + token_url=json["tokenEndpoint"], + auth_url=json["authorizeEndpoint"], + scope=json.get("scope"), + ) else: raise ValueError( f"The current release does not supports {type} type. " diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index e1103239a..c4c02a0c9 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -156,7 +156,8 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._session.headers.update( { - "User-Agent": DataSharingRestClient.USER_AGENT, + "Custom-Header-Recipient-Id": "f1894b13-8362-42a1-a618-5ca99e6ca642", + "User-Agent": "f1894b13-8362-42a1-a618-5ca99e6ca642" } ) diff --git a/python/delta_sharing/tests/test_e2e.py b/python/delta_sharing/tests/test_e2e.py new file mode 100644 index 000000000..c1b66b67a --- /dev/null +++ b/python/delta_sharing/tests/test_e2e.py @@ -0,0 +1,39 @@ +import delta_sharing + +# Point to the profile file. It can be a file on the local file system or a file on a remote storage. +profile_file = "/Users/Moe.Derakhshani/Documents/oauth/demo/u2m.share" + +# Create a SharingClient. +client = delta_sharing.SharingClient(profile_file) +# +# List all shared tables. +tables = client.list_all_tables() + +print(tables) + + +# +# # Create a url to access a shared table. +# # A table path is the profile file path following with `#` and the fully qualified name of a table +# # (`..`). +table_url = profile_file + "#demo-d2o-identity-federation.my_schema.my_table" + +# Fetch 10 rows from a table and convert it to a Pandas DataFrame. This can be used to read sample data +# from a table that cannot fit in the memory. +df = delta_sharing.load_as_pandas(table_url, limit=10) + +print(df) + +# +# Load a table as a Pandas DataFrame. This can be used to process tables that can fit in the memory. +delta_sharing.load_as_pandas(table_url) + +# Load a table as a Pandas DataFrame explicitly using Delta Format +#delta_sharing.load_as_pandas(table_url, use_delta_format = True) + +# # If the code is running with PySpark, you can use `load_as_spark` to load the table as a Spark DataFrame. +# delta_sharing.load_as_spark(table_url) + + + +print("DONE") From 244244787ed0df07ae6511c2af6d65942aca0667 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 13 Sep 2024 17:00:01 -0700 Subject: [PATCH 2/2] updated recipientId --- python/delta_sharing/rest_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index c4c02a0c9..0cbebd025 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -156,8 +156,8 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._session.headers.update( { - "Custom-Header-Recipient-Id": "f1894b13-8362-42a1-a618-5ca99e6ca642", - "User-Agent": "f1894b13-8362-42a1-a618-5ca99e6ca642" + "Custom-Header-Recipient-Id": "7ccbb5da-b1b1-4519-ae53-190db7988199", + "User-Agent": "Python-Delta-Sharing-Client" } )