Skip to content

Commit 834dca6

Browse files
committed
Client credentials flow working
1 parent de5d8e6 commit 834dca6

File tree

3 files changed

+164
-6
lines changed

3 files changed

+164
-6
lines changed

datajoint/axon.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import sys
77
import flask
88
import webbrowser
9+
import requests_oauthlib
10+
import oauthlib
11+
from oauthlib.oauth2 import BackendApplicationClient
12+
from requests_oauthlib import OAuth2Session
13+
import requests
914
import urllib
1015
import http.client
1116
import botocore
@@ -249,4 +254,69 @@ def refreshable_session(self) -> boto3.Session:
249254
session._credentials = refreshable_credentials
250255
autorefresh_session = boto3.Session(botocore_session=session)
251256

252-
return autorefresh_session
257+
return autorefresh_session
258+
259+
def get_s3_client(
260+
aws_account_id: str,
261+
s3_role: str,
262+
auth_client_id: str,
263+
auth_client_secret: str = None,
264+
bearer_token: str = None,
265+
well_known_url: str = "https://keycloak-qa.datajoint.io/realms/datajoint/.well-known/openid-configuration",
266+
):
267+
"""
268+
Get S3 client with the given credentials.
269+
270+
Parameters
271+
----------
272+
aws_account_id : str
273+
AWS account ID
274+
275+
s3_role : str
276+
S3 role
277+
278+
auth_client_id : str
279+
Auth client ID
280+
281+
auth_client_secret : str (optional)
282+
Auth client secret
283+
284+
bearer_token : str (optional)
285+
Bearer token
286+
287+
well_known_url : str (optional)
288+
Well-known URL for the OpenID configuration
289+
290+
Returns
291+
-------
292+
boto3.client
293+
S3 client
294+
"""
295+
# Get token URL from well-known URL
296+
well_known_resp = requests.get(well_known_url)
297+
assert well_known_resp.status_code == 200, f"Failed to get well-known URL: {well_known_url}"
298+
well_known_data = well_known_resp.json()
299+
token_url = well_known_data.get("token_endpoint")
300+
assert token_url, f"Token URL not found in well-known data: {well_known_data=}"
301+
302+
# Client credentials flow
303+
token = _client_credentials_flow(
304+
auth_client_id, auth_client_secret, token_url
305+
)
306+
307+
#
308+
309+
310+
def _client_credentials_flow(client_id, client_secret, token_url):
311+
client = BackendApplicationClient(client_id=client_id)
312+
oauth = OAuth2Session(client=client)
313+
try:
314+
return oauth.fetch_token(
315+
token_url=token_url,
316+
client_id=client_id,
317+
client_secret=client_secret,
318+
)
319+
except oauthlib.oauth2.rfc6749.errors.UnauthorizedClientError as e:
320+
msg = f"Error getting OAuth2 client: {e.description}"
321+
log.error(msg)
322+
raise ValueError(msg) from e

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,15 @@ classifiers = [
4545
test = [
4646
"pytest",
4747
"pytest-cov",
48+
"pytest-dotenv",
4849
"black==24.2.0",
4950
"flake8",
5051
"moto[s3]>=4.2.13",
5152
]
5253
axon = [
5354
"boto3",
55+
"requests_oauthlib",
56+
"requests",
5457
]
5558

5659
[project.urls]

tests/test_axon.py

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,106 @@
1-
from datajoint.axon import Session
1+
import os
2+
from datajoint.axon import Session, get_s3_client
3+
import json
24
import pytest
35
import boto3
46
from moto import mock_aws
7+
import dotenv
8+
dotenv.load_dotenv(dotenv.find_dotenv())
59

610

711
@pytest.fixture
812
def moto_account_id():
913
"""Default account ID for moto"""
1014
return "123456789012"
1115

16+
17+
@pytest.fixture
18+
def keycloak_client_secret():
19+
secret = os.getenv("OAUTH_CLIENT_SECRET")
20+
if not secret:
21+
pytest.skip("No client secret found")
22+
else:
23+
return secret
24+
25+
26+
@pytest.fixture
27+
def keycloak_client_id():
28+
return os.getenv("OAUTH_CLIENT_ID", "works")
29+
30+
31+
@pytest.fixture(scope="function")
32+
def aws_credentials():
33+
"""Mocked AWS Credentials for moto."""
34+
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
35+
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
36+
os.environ["AWS_SECURITY_TOKEN"] = "testing"
37+
os.environ["AWS_SESSION_TOKEN"] = "testing"
38+
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
39+
40+
41+
@pytest.fixture(scope="function")
42+
def s3_client(aws_credentials):
43+
"""
44+
Return a mocked S3 client
45+
"""
46+
with mock_aws():
47+
yield boto3.client("s3", region_name="us-east-1")
48+
49+
50+
@pytest.fixture(scope="function")
51+
def iam_client(aws_credentials):
52+
"""
53+
Return a mocked S3 client
54+
"""
55+
with mock_aws():
56+
yield boto3.client("iam", region_name="us-east-1")
57+
58+
59+
@pytest.fixture
60+
def s3_policy(iam_client):
61+
"""Create a policy with S3 read access using boto3."""
62+
policy_doc = {
63+
"Version": "2012-10-17",
64+
"Statement": [
65+
{
66+
"Effect": "Allow",
67+
"Action": "s3:GetObject",
68+
"Resource": "arn:aws:s3:::mybucket/*",
69+
}
70+
],
71+
}
72+
return iam_client.create_policy(
73+
PolicyName="test-policy",
74+
Path="/",
75+
PolicyDocument=json.dumps(policy_doc),
76+
Description="Test policy",
77+
)
78+
79+
@pytest.fixture
80+
def s3_role(moto_account_id, s3_policy):
81+
"""Create a mock role and policy document for testing"""
82+
return "123456789012"
83+
84+
1285
@mock_aws
86+
@pytest.mark.skip
1387
class TestSession:
14-
def test_can_init(self):
88+
def test_can_init(self, s3_role, keycloak_client_id, keycloak_client_secret, moto_account_id):
1589
session = Session(
1690
aws_account_id=moto_account_id,
17-
s3_role="test-role",
18-
auth_client_id="test-client-id",
19-
auth_client_secret="test-client-secret",
91+
s3_role=s3_role,
92+
auth_client_id=keycloak_client_id,
93+
auth_client_secret=keycloak_client_secret,
2094
)
2195
assert session.bearer_token, "Bearer token not set"
96+
97+
def test_get_s3_client(s3_role, keycloak_client_id, keycloak_client_secret, moto_account_id):
98+
client = get_s3_client(
99+
auth_client_id=keycloak_client_id,
100+
auth_client_secret=keycloak_client_secret,
101+
aws_account_id=moto_account_id,
102+
s3_role=s3_role,
103+
bearer_token=None,
104+
)
105+
assert client
106+

0 commit comments

Comments
 (0)