-
Couldn't load subscription status.
- Fork 207
OAuth Private Key support #733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6924cc2
f1daf9e
f00da44
99ad3f5
6c989fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,15 +15,17 @@ | |
| # | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from datetime import datetime | ||
| from datetime import datetime, timezone | ||
| from typing import Optional | ||
| import requests | ||
| import base64 | ||
| import json | ||
| from jwcrypto import jwk, jwt | ||
| import threading | ||
| import requests.sessions | ||
| import time | ||
| from typing import Dict | ||
| import uuid | ||
|
|
||
| from delta_sharing.protocol import ( | ||
| DeltaSharingProfile, | ||
|
|
@@ -113,28 +115,10 @@ def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): | |
| self.creation_timestamp = creation_timestamp | ||
|
|
||
|
|
||
| class OAuthClient: | ||
| def __init__( | ||
| self, token_endpoint: str, client_id: str, client_secret: str, scope: Optional[str] = None | ||
| ): | ||
| self.token_endpoint = token_endpoint | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.scope = scope | ||
|
|
||
| class OAuthClient(ABC): | ||
| @abstractmethod | ||
| def client_credentials(self) -> OAuthClientCredentials: | ||
| credentials = base64.b64encode( | ||
| f"{self.client_id}:{self.client_secret}".encode("utf-8") | ||
| ).decode("utf-8") | ||
| headers = { | ||
| "accept": "application/json", | ||
| "authorization": f"Basic {credentials}", | ||
| "content-type": "application/x-www-form-urlencoded", | ||
| } | ||
| body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" | ||
| response = requests.post(self.token_endpoint, headers=headers, data=body) | ||
| response.raise_for_status() | ||
| return self.parse_oauth_token_response(response.text) | ||
| pass | ||
|
|
||
| def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: | ||
| if not response: | ||
|
|
@@ -169,6 +153,94 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: | |
| ) | ||
|
|
||
|
|
||
| class ClientSecretOAuthClient(OAuthClient): | ||
| def __init__( | ||
| self, | ||
| token_endpoint: str, | ||
| client_id: str, | ||
| client_secret: str, | ||
| scope: Optional[str] = None, | ||
| ): | ||
| self.token_endpoint = token_endpoint | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.scope = scope | ||
|
|
||
| def client_credentials(self) -> OAuthClientCredentials: | ||
| credentials = base64.b64encode( | ||
| f"{self.client_id}:{self.client_secret}".encode("utf-8") | ||
| ).decode("utf-8") | ||
| headers = { | ||
| "accept": "application/json", | ||
| "authorization": f"Basic {credentials}", | ||
| "content-type": "application/x-www-form-urlencoded", | ||
| } | ||
| body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" | ||
| response = requests.post(self.token_endpoint, headers=headers, data=body) | ||
| response.raise_for_status() | ||
| return self.parse_oauth_token_response(response.text) | ||
|
|
||
|
|
||
| class PrivateKeyOAuthClient(OAuthClient): | ||
| def __init__( | ||
| self, | ||
| token_endpoint: str, | ||
| client_id: str, | ||
| key_id: str, | ||
| private_key: str, | ||
| issuer: str, | ||
| scope: Optional[str] = None, | ||
| resource: Optional[str] = None, | ||
| algorithm: Optional[str] = None, | ||
| ): | ||
| self.token_endpoint = token_endpoint | ||
| self.client_id = client_id | ||
| self.key_id = key_id | ||
| self.private_key = private_key | ||
| self.issuer = issuer | ||
| self.scope = scope | ||
| self.resource = resource | ||
| if algorithm is None: | ||
| algorithm = "RS256" | ||
| self.algorithm = algorithm | ||
|
|
||
| def client_credentials(self) -> OAuthClientCredentials: | ||
| timestamp = int(datetime.now(timezone.utc).timestamp()) | ||
| jwt_header = {"alg": self.algorithm, "kid": self.key_id} | ||
| jwt_claims = { | ||
| "aud": self.issuer, | ||
| "iss": self.client_id, | ||
| "iat": timestamp, | ||
| "exp": timestamp + 120, | ||
| "jti": str(uuid.uuid4()), | ||
| } | ||
| if self.scope: | ||
| jwt_claims["scope"] = self.scope | ||
| if self.resource: | ||
| jwt_claims["resource"] = self.resource # In OAuth 2 spec audience is called resource | ||
| signed_jwt = self._signed_jwt(jwt_header, jwt_claims) | ||
| body = { | ||
| "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", | ||
| "assertion": signed_jwt, | ||
| } | ||
| headers = { | ||
| "accept": "application/json", | ||
| "content-type": "application/x-www-form-urlencoded", | ||
| } | ||
| response = requests.post(self.token_endpoint, headers=headers, data=body) | ||
| response.raise_for_status() | ||
| return self.parse_oauth_token_response(response.text) | ||
|
|
||
| def _signed_jwt(self, jwt_header, jwt_claims): | ||
| """Generate a signed JWT token using the private key""" | ||
| jwt_token = jwt.JWT(header=jwt_header, claims=jwt_claims) | ||
| with open(self.private_key, "rb") as key_file: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes the key_file is available on disk and that the client can read it directly. How would this work in a Spark cluster environment, where the key isn’t necessarily stored on local disk? Where is the key expected to be stored in that case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have not tested it inside a spark cluster environment, since we envisioned that data scientists would want to read the data in their development context. However, we are happy to get feedback from you on how to get it. In spark, one could get the key from a secrets service, and write the contents to a local path which is passed to DeltaSharingProfile(). However, an alternative or secondary option could be to pass the secret value into DeltaSharingProfile() as parameter. |
||
| pem_data = key_file.read() | ||
| private_key = jwk.JWK.from_pem(pem_data) | ||
| jwt_token.make_signed_token(private_key) | ||
| return jwt_token.serialize() | ||
|
|
||
|
|
||
| class OAuthClientCredentialsAuthProvider(AuthCredentialProvider): | ||
| def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()): | ||
| self.auth_config = auth_config | ||
|
|
@@ -210,6 +282,8 @@ def create_auth_credential_provider(profile: DeltaSharingProfile): | |
| if profile.share_credentials_version == 2: | ||
| if profile.type == "oauth_client_credentials": | ||
| return AuthCredentialProviderFactory.__oauth_client_credentials(profile) | ||
| elif profile.type == "oauth_jwt_bearer_private_key_jwt": | ||
| return AuthCredentialProviderFactory.__oauth_jwt_bearer_private_key_jwt(profile) | ||
| elif profile.type == "basic": | ||
| return AuthCredentialProviderFactory.__auth_basic(profile) | ||
| elif profile.share_credentials_version == 1 and ( | ||
|
|
@@ -236,7 +310,7 @@ def __oauth_client_credentials(profile): | |
| if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: | ||
| return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] | ||
|
|
||
| oauth_client = OAuthClient( | ||
| oauth_client = ClientSecretOAuthClient( | ||
| token_endpoint=profile.token_endpoint, | ||
| client_id=profile.client_id, | ||
| client_secret=profile.client_secret, | ||
|
|
@@ -248,6 +322,41 @@ def __oauth_client_credentials(profile): | |
| AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider | ||
| return provider | ||
|
|
||
| @staticmethod | ||
| def __oauth_jwt_bearer_private_key_jwt(profile): | ||
| # Once a clientId/privateKey/keyId is exchanged for an accessToken, | ||
| # the accessToken can be reused until it expires. | ||
| # Resource-claim in JWT-grant is optional, value is set in config.share.audience | ||
| # The Python client re-creates DeltaSharingClient for different requests. | ||
| # To ensure the OAuth access_token is reused, | ||
| # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. | ||
| # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, | ||
| # ensuring the access_token can be reused. | ||
| if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: | ||
| return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] | ||
|
|
||
| # Extract private key configuration from nested structure | ||
| private_key_config = profile.private_key or {} | ||
| private_key_file = private_key_config.get("privateKeyFile") | ||
| key_id = private_key_config.get("keyId") | ||
| algorithm = private_key_config.get("algorithm") | ||
|
|
||
| oauth_client = PrivateKeyOAuthClient( | ||
| token_endpoint=profile.token_endpoint, | ||
| client_id=profile.client_id, | ||
| key_id=key_id, | ||
| private_key=private_key_file, | ||
| issuer=profile.issuer, | ||
| resource=profile.audience, | ||
| scope=profile.scope, | ||
| algorithm=algorithm, | ||
| ) | ||
| provider = OAuthClientCredentialsAuthProvider( | ||
| oauth_client=oauth_client, auth_config=AuthConfig() | ||
| ) | ||
| AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider | ||
| return provider | ||
|
|
||
| @staticmethod | ||
| def __auth_bearer_token(profile): | ||
| return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@paalvibe how do you intend to use this client? Specifically, do you plan to use this change with Delta Sharing over OIDC M2M?
From the PR, it sounds like you are sending a self-signed token to your configured token endpoint, which then returns a JWT.
I am curious what the resultant token from your token endpoint looks like?
Could you share a sample of the JWT that your token endpoint returns (and that the client ultimately sends to the Delta Sharing server), along with the corresponding OIDC federation policy configuration the server uses to authenticate the request?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, Moe,
Yes, we are planning to use Delta Sharing over OIDC M2M.
Here is an example of the decoded JWT returned by the token endpoint:
Here are the policy details needed in Databricks OIDC server configuration:
sky.maskinporten.nois the national company identity provider in Norway: https://docs-digdir-no.translate.goog/docs/Maskinporten/maskinporten_skyporten.html?_x_tr_sl=no&_x_tr_tl=en&_x_tr_hl=en&_x_tr_pto=wappIt is a machine to machine identity service, where access is granted in a maskinporten web interface or API according to organization number e.g. "0192:315300649" and scope e.g. "acme:customerdata.gold".
We have already implemented this for sharing cloud resources on AWS, Azure and GCP, as explained in the documentation above.
We have already tested it with our branch and it works. It would enable any data consumer identified with skyporten to read a Delta Share wherever they are, inside or outside Databricks, e.g. in a local notebook.
More info if of intersest:
At Samferdselsdata.no (Public Transport Sector Data Sharing initiative) we have working with Norwegian Digitalisation Agency (Digdir) to implement Delta Shares directly with the Maskinporten National Orgnumber OAuth2 service for authentication, like we have achieved for IAM based cloud access with Skyporten (https://docs.digdir.no/docs/Maskinporten/maskinporten_skyporten). This would mean that one could simply declare which org number should have access, and avoid any credentials exchange at all. Hopefully, a similar pattern will also be possible to use across Europe soon. In practice this enables country code+orgnumber based delta sharing instead of Entra or email-based access.