Skip to content

Commit cca4b54

Browse files
author
DanielePalaia
committed
adding oauth configuration
1 parent 055c0e1 commit cca4b54

File tree

9 files changed

+186
-16
lines changed

9 files changed

+186
-16
lines changed

poetry.lock

Lines changed: 80 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ readme = "README.md"
99
[tool.poetry.dependencies]
1010
python = "^3.9"
1111
python-qpid-proton = "^0.39.0"
12+
jwt = "^1.3.1"
13+
pyjwt = "^2.10.1"
1214

1315
[tool.poetry.group.dev.dependencies]
1416
flake8 = "^7.1.1"

rabbitmq_amqp_python_client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ExchangeSpecification,
1111
ExchangeToExchangeBindingSpecification,
1212
ExchangeToQueueBindingSpecification,
13+
OAuth2Options,
1314
OffsetSpecification,
1415
RecoveryConfiguration,
1516
StreamOptions,
@@ -89,4 +90,5 @@
8990
"Environment",
9091
"ExchangeCustomSpecification",
9192
"RecoveryConfiguration",
93+
"OAuth2Options",
9294
]

rabbitmq_amqp_python_client/connection.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
from .address_helper import validate_address
1616
from .consumer import Consumer
17-
from .entities import RecoveryConfiguration, StreamOptions
17+
from .entities import (
18+
OAuth2Options,
19+
RecoveryConfiguration,
20+
StreamOptions,
21+
)
1822
from .exceptions import (
1923
ArgumentOutOfRangeException,
2024
ValidationCodeException,
@@ -60,6 +64,7 @@ def __init__(
6064
ssl_context: Union[
6165
PosixSslConfigurationContext, WinSslConfigurationContext, None
6266
] = None,
67+
oauth2_options: Optional[OAuth2Options] = None,
6368
recovery_configuration: RecoveryConfiguration = RecoveryConfiguration(),
6469
):
6570
"""
@@ -93,6 +98,7 @@ def __init__(
9398
self._index: int = -1
9499
self._publishers: list[Publisher] = []
95100
self._consumers: list[Consumer] = []
101+
self._oauth2_options = oauth2_options
96102

97103
# Some recovery_configuration validation
98104
if recovery_configuration.back_off_reconnect_interval < timedelta(seconds=1):
@@ -109,19 +115,8 @@ def _set_environment_connection_list(self, connections: []): # type: ignore
109115
def _open_connections(self, reconnect_handlers: bool = False) -> None:
110116

111117
logger.debug("inside connection._open_connections creating connection")
112-
if self._recovery_configuration.active_recovery is False:
113-
self._conn = BlockingConnection(
114-
url=self._addr,
115-
urls=self._addrs,
116-
ssl_domain=self._ssl_domain,
117-
)
118-
else:
119-
self._conn = BlockingConnection(
120-
url=self._addr,
121-
urls=self._addrs,
122-
ssl_domain=self._ssl_domain,
123-
on_disconnection_handler=self._on_disconnection,
124-
)
118+
119+
self._create_connection()
125120

126121
if reconnect_handlers is True:
127122
logger.debug("reconnecting managements, publishers and consumers handlers")
@@ -137,6 +132,35 @@ def _open_connections(self, reconnect_handlers: bool = False) -> None:
137132
# Update the broken connection and sender in the consumer
138133
self._consumers[i]._update_connection(self._conn)
139134

135+
def _create_connection(self) -> None:
136+
137+
user = None
138+
password = None
139+
140+
if self._oauth2_options is not None:
141+
user = ""
142+
password = self._oauth2_options.token
143+
144+
if self._recovery_configuration.active_recovery is False:
145+
self._conn = BlockingConnection(
146+
url=self._addr,
147+
urls=self._addrs,
148+
oauth2_options=self._oauth2_options,
149+
ssl_domain=self._ssl_domain,
150+
user=user,
151+
password=password,
152+
)
153+
else:
154+
self._conn = BlockingConnection(
155+
url=self._addr,
156+
urls=self._addrs,
157+
oauth2_options=self._oauth2_options,
158+
ssl_domain=self._ssl_domain,
159+
on_disconnection_handler=self._on_disconnection,
160+
user=user,
161+
password=password,
162+
)
163+
140164
def dial(self) -> None:
141165
"""
142166
Establish a connection to the AMQP server.

rabbitmq_amqp_python_client/entities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,8 @@ class RecoveryConfiguration:
257257
active_recovery: bool = True
258258
back_off_reconnect_interval: timedelta = timedelta(seconds=5)
259259
MaxReconnectAttempts: int = 5
260+
261+
262+
@dataclass
263+
class OAuth2Options:
264+
token: str

rabbitmq_amqp_python_client/environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111
from .connection import Connection
12-
from .entities import RecoveryConfiguration
12+
from .entities import OAuth2Options, RecoveryConfiguration
1313
from .ssl_configuration import (
1414
PosixSslConfigurationContext,
1515
WinSslConfigurationContext,
@@ -41,6 +41,7 @@ def __init__(
4141
ssl_context: Union[
4242
PosixSslConfigurationContext, WinSslConfigurationContext, None
4343
] = None,
44+
oauth2_options: Optional[OAuth2Options] = None,
4445
recovery_configuration: RecoveryConfiguration = RecoveryConfiguration(),
4546
):
4647
"""
@@ -66,6 +67,7 @@ def __init__(
6667
self._ssl_context = ssl_context
6768
self._recovery_configuration = recovery_configuration
6869
self._connections: list[Connection] = []
70+
self._oauth2_options = oauth2_options
6971

7072
def connection(
7173
self,

tests/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
from datetime import datetime, timedelta
34
from typing import Optional
45

56
import pytest
@@ -9,6 +10,7 @@
910
AMQPMessagingHandler,
1011
Environment,
1112
Event,
13+
OAuth2Options,
1214
PKCS12Store,
1315
PosixClientCert,
1416
PosixSslConfigurationContext,
@@ -22,6 +24,7 @@
2224
)
2325

2426
from .http_requests import delete_all_connections
27+
from .utils import token
2528

2629
os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
2730

@@ -36,6 +39,20 @@ def environment(pytestconfig):
3639
environment.close()
3740

3841

42+
@pytest.fixture()
43+
def environment_auth(pytestconfig):
44+
token_string = token(datetime.now() + timedelta(milliseconds=2500))
45+
environment = Environment(
46+
uri="amqp://guest:guest@localhost:5672/",
47+
oauth2_options=OAuth2Options(token=token_string),
48+
)
49+
try:
50+
yield environment
51+
52+
finally:
53+
environment.close()
54+
55+
3956
@pytest.fixture()
4057
def connection(pytestconfig):
4158
environment = Environment(uri="amqp://guest:guest@localhost:5672/")

tests/test_connection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def test_connection_ssl(ssl_context) -> None:
4343
environment.close()
4444

4545

46+
def test_connection_auth(environment_auth: Environment) -> None:
47+
48+
connection = environment_auth.connection()
49+
connection.dial()
50+
51+
4652
def test_environment_connections_management() -> None:
4753

4854
environment = Environment(uri="amqp://guest:guest@localhost:5672/")

tests/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import base64
2+
from datetime import datetime
13
from typing import Optional
24

5+
import jwt
6+
37
from rabbitmq_amqp_python_client import (
48
AddressHelper,
59
Connection,
@@ -92,3 +96,32 @@ def cleanup_dead_lettering(management: Management, bind_path: str) -> None:
9296
management.unbind(bind_path)
9397
management.delete_exchange(exchange_dead_lettering)
9498
management.delete_queue(queue_dead_lettering)
99+
100+
101+
def token(duration: datetime) -> str:
102+
# Decode the base64 key
103+
decoded_key = base64.b64decode("abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGH")
104+
105+
# Define the claims
106+
claims = {
107+
"iss": "unit_test",
108+
"aud": "rabbitmq",
109+
"exp": duration,
110+
"scope": ["rabbitmq.configure:*/*", "rabbitmq.write:*/*", "rabbitmq.read:*/*"],
111+
"random": random_string(6),
112+
}
113+
114+
# Create the token with the claims and sign it
115+
token = jwt.encode(
116+
claims, decoded_key, algorithm="HS256", headers={"kid": "token-key"}
117+
)
118+
119+
return token
120+
121+
122+
# Helper function to generate a random string (replace with your implementation)
123+
def random_string(length: int) -> str:
124+
import random
125+
import string
126+
127+
return "".join(random.choices(string.ascii_letters + string.digits, k=length))

0 commit comments

Comments
 (0)