Skip to content

Commit 85df16d

Browse files
authored
fix: reuse the boto3 session to sign request (#122)
* fix: reuse the boto3 session to sign request instead of creating a new session to sign every request * test: mock the create session when session is not provided
1 parent 92078b8 commit 85df16d

File tree

6 files changed

+90
-45
lines changed

6 files changed

+90
-45
lines changed

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def create_sigv4_client(
107107
region: str,
108108
timeout: Optional[httpx.Timeout] = None,
109109
profile: Optional[str] = None,
110+
session: Optional[boto3.Session] = None,
110111
headers: Optional[Dict[str, str]] = None,
111112
metadata: Optional[Dict[str, Any]] = None,
112113
**kwargs: Any,
@@ -115,7 +116,8 @@ def create_sigv4_client(
115116
116117
Args:
117118
service: AWS service name for SigV4 signing
118-
profile: AWS profile to use (optional)
119+
profile: AWS profile to use (optional, only used if session is not provided)
120+
session: AWS boto3 session to use (optional, takes precedence over profile)
119121
region: AWS region (optional, defaults to AWS_REGION env var or us-east-1)
120122
timeout: Timeout configuration for the HTTP client
121123
headers: Headers to include in requests
@@ -125,6 +127,10 @@ def create_sigv4_client(
125127
Returns:
126128
httpx.AsyncClient with SigV4 authentication
127129
"""
130+
# Create or use provided AWS session
131+
if session is None:
132+
session = create_aws_session(profile)
133+
128134
# Create a copy of kwargs to avoid modifying the passed dict
129135
client_kwargs = {
130136
'follow_redirects': True,
@@ -151,7 +157,7 @@ def create_sigv4_client(
151157
'response': [_handle_error_response],
152158
'request': [
153159
partial(_inject_metadata_hook, metadata or {}),
154-
partial(_sign_request_hook, region, service, profile),
160+
partial(_sign_request_hook, region, service, session),
155161
],
156162
},
157163
)
@@ -210,7 +216,7 @@ async def _handle_error_response(response: httpx.Response) -> None:
210216
async def _sign_request_hook(
211217
region: str,
212218
service: str,
213-
profile: Optional[str],
219+
session: boto3.Session,
214220
request: httpx.Request,
215221
) -> None:
216222
"""Request hook to sign HTTP requests with AWS SigV4.
@@ -222,14 +228,13 @@ async def _sign_request_hook(
222228
Args:
223229
region: AWS region for SigV4 signing
224230
service: AWS service name for SigV4 signing
225-
profile: AWS profile to use (optional)
231+
session: AWS boto3 session to use for credentials
226232
request: The HTTP request object to sign (modified in-place)
227233
"""
228234
# Set Content-Length for signing
229235
request.headers['Content-Length'] = str(len(request.content))
230236

231-
# Get AWS credentials
232-
session = create_aws_session(profile)
237+
# Get AWS credentials from the session
233238
credentials = session.get_credentials()
234239
logger.info('Signing request with credentials for access key: %s', credentials.access_key)
235240

mcp_proxy_for_aws/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
from fastmcp.client.transports import StreamableHttpTransport
22-
from mcp_proxy_for_aws.sigv4_helper import create_sigv4_client
22+
from mcp_proxy_for_aws.sigv4_helper import create_aws_session, create_sigv4_client
2323
from typing import Any, Dict, Optional, Tuple
2424
from urllib.parse import urlparse
2525

@@ -49,6 +49,9 @@ def create_transport_with_sigv4(
4949
Returns:
5050
StreamableHttpTransport instance with SigV4 authentication
5151
"""
52+
# Create AWS session once and reuse it for all httpx clients
53+
logger.debug('Creating AWS session with profile: %s', profile)
54+
session = create_aws_session(profile)
5255

5356
def client_factory(
5457
headers: Optional[Dict[str, str]] = None,
@@ -57,7 +60,7 @@ def client_factory(
5760
) -> httpx.AsyncClient:
5861
return create_sigv4_client(
5962
service=service,
60-
profile=profile,
63+
session=session,
6164
region=region,
6265
headers=headers,
6366
timeout=custom_timeout,

tests/unit/test_hooks.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_inject_metadata_hook,
2424
_sign_request_hook,
2525
)
26-
from unittest.mock import MagicMock, Mock, patch
26+
from unittest.mock import MagicMock, Mock
2727

2828

2929
def create_request_with_sigv4_headers(
@@ -343,87 +343,76 @@ async def test_hook_preserves_other_params(self):
343343
class TestSignRequestHook:
344344
"""Test cases for sign_request_hook function."""
345345

346-
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
347346
@pytest.mark.asyncio
348-
async def test_sign_request_hook_signs_request(self, mock_create_session):
347+
async def test_sign_request_hook_signs_request(self):
349348
"""Test that sign_request_hook properly signs requests."""
350349
# Setup mocks
351-
mock_create_session.return_value = create_mock_session()
350+
mock_session = create_mock_session()
352351

353352
region = 'us-east-1'
354353
service = 'bedrock-agentcore'
355-
profile = None
356354

357355
# Create request without signature headers
358356
request_body = json.dumps({'test': 'data'}).encode('utf-8')
359357
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
360358

361359
# Call the hook
362-
await _sign_request_hook(region, service, profile, request)
360+
await _sign_request_hook(region, service, mock_session, request)
363361

364362
# Verify signature headers were added
365363
assert 'authorization' in request.headers
366364
assert 'x-amz-date' in request.headers
367365
assert 'x-amz-security-token' in request.headers
368366
assert request.headers['content-length'] == str(len(request_body))
369367

370-
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
371368
@pytest.mark.asyncio
372-
async def test_sign_request_hook_with_profile(self, mock_create_session):
373-
"""Test that sign_request_hook uses profile when provided."""
369+
async def test_sign_request_hook_with_profile(self):
370+
"""Test that sign_request_hook uses session when provided."""
374371
# Setup mocks
375-
mock_create_session.return_value = create_mock_session()
372+
mock_session = create_mock_session()
376373

377374
region = 'us-west-2'
378375
service = 'execute-api'
379-
profile = 'test-profile'
380376

381377
request_body = b'test content'
382378
request = httpx.Request('POST', 'https://example.com/api', content=request_body)
383379

384380
# Call the hook
385-
await _sign_request_hook(region, service, profile, request)
386-
387-
# Verify session was created with profile
388-
mock_create_session.assert_called_once_with(profile)
381+
await _sign_request_hook(region, service, mock_session, request)
389382

390383
# Verify request was signed
391384
assert 'authorization' in request.headers
392385
assert 'x-amz-date' in request.headers
393386

394-
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
395387
@pytest.mark.asyncio
396-
async def test_sign_request_hook_sets_content_length(self, mock_create_session):
388+
async def test_sign_request_hook_sets_content_length(self):
397389
"""Test that sign_request_hook sets Content-Length header."""
398390
# Setup mocks
399-
mock_create_session.return_value = create_mock_session()
391+
mock_session = create_mock_session()
400392

401393
region = 'eu-west-1'
402394
service = 'lambda'
403-
profile = None
404395

405396
# Create request
406397
request_body = b'test content with specific length'
407398
request = httpx.Request('POST', 'https://example.com/api', content=request_body)
408399

409-
await _sign_request_hook(region, service, profile, request)
400+
await _sign_request_hook(region, service, mock_session, request)
410401

411402
# Verify Content-Length was set correctly
412403
assert request.headers['content-length'] == str(len(request_body))
413404

414-
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
415405
@pytest.mark.asyncio
416-
async def test_sign_request_hook_with_partial_application(self, mock_create_session):
406+
async def test_sign_request_hook_with_partial_application(self):
417407
"""Test that sign_request_hook works with functools.partial."""
418408
# Setup mocks
419-
mock_create_session.return_value = create_mock_session()
409+
mock_session = create_mock_session()
420410

421411
region = 'ap-southeast-1'
422412
service = 'execute-api'
423-
profile = 'prod-profile'
424413

425414
# Create curried function using partial
426-
curried_hook = partial(_sign_request_hook, region, service, profile)
415+
curried_hook = partial(_sign_request_hook, region, service, mock_session)
427416

428417
request_body = b'request data'
429418
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
@@ -434,4 +423,3 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess
434423
# Verify request was signed
435424
assert 'authorization' in request.headers
436425
assert 'x-amz-date' in request.headers
437-
mock_create_session.assert_called_once_with(profile)

tests/unit/test_server.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,17 +401,25 @@ def test_validate_service_name_service_parsing(self):
401401
result = determine_service_name(endpoint)
402402
assert result == expected_service
403403

404+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
404405
@patch('mcp_proxy_for_aws.sigv4_helper.httpx.AsyncClient')
405-
def test_create_sigv4_client(self, mock_async_client):
406+
def test_create_sigv4_client(self, mock_async_client, mock_create_session):
406407
"""Test creating SigV4 authenticated client with request hooks.
407408
408409
Note: Session creation and signing now happens in sign_request_hook,
409410
not during client creation.
410411
"""
412+
# Mock session creation
413+
mock_session = Mock()
414+
mock_session.get_credentials.return_value = Mock(access_key='test-key')
415+
mock_create_session.return_value = mock_session
416+
411417
# Act
412418
create_sigv4_client(service='test-service', region='us-west-2', profile='test-profile')
413419

414420
# Assert
421+
# Verify session was created with profile
422+
mock_create_session.assert_called_once_with('test-profile')
415423
# Verify AsyncClient was called (signing happens via hooks)
416424
assert mock_async_client.call_count == 1
417425
call_args = mock_async_client.call_args
@@ -422,12 +430,16 @@ def test_create_sigv4_client(self, mock_async_client):
422430
# Should have metadata injection + sign hooks
423431
assert len(call_args[1]['event_hooks']['request']) == 2
424432

425-
def test_create_sigv4_client_no_credentials(self):
433+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
434+
def test_create_sigv4_client_no_credentials(self, mock_create_session):
426435
"""Test that credential check happens in sign_request_hook, not during client creation.
427436
428437
Note: With the refactoring, client creation no longer validates credentials.
429438
Credential validation now happens in sign_request_hook when the request is signed.
430439
"""
440+
mock_session = Mock()
441+
mock_create_session.return_value = mock_session
442+
431443
# Client creation should succeed even without credentials
432444
# (credentials are checked when signing happens)
433445
client = create_sigv4_client(service='test-service', region='test-region')

tests/unit/test_sigv4_helper.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,14 @@ def test_create_aws_session_creation_failure(self, mock_session_class):
119119
class TestCreateSigv4Client:
120120
"""Test cases for the create_sigv4_client function."""
121121

122+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
122123
@patch('httpx.AsyncClient')
123-
def test_create_sigv4_client_default(self, mock_client_class):
124+
def test_create_sigv4_client_default(self, mock_client_class, mock_create_session):
124125
"""Test creating SigV4 client with default parameters."""
125126
mock_client = Mock()
126127
mock_client_class.return_value = mock_client
128+
mock_session = Mock()
129+
mock_create_session.return_value = mock_session
127130

128131
# Test client creation
129132
result = create_sigv4_client(service='test-service', region='test-region')
@@ -139,11 +142,14 @@ def test_create_sigv4_client_default(self, mock_client_class):
139142
assert call_args[1]['headers']['Accept'] == 'application/json, text/event-stream'
140143
assert result == mock_client
141144

145+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
142146
@patch('httpx.AsyncClient')
143-
def test_create_sigv4_client_with_custom_headers(self, mock_client_class):
147+
def test_create_sigv4_client_with_custom_headers(self, mock_client_class, mock_create_session):
144148
"""Test creating SigV4 client with custom headers."""
145149
mock_client = Mock()
146150
mock_client_class.return_value = mock_client
151+
mock_session = Mock()
152+
mock_create_session.return_value = mock_session
147153

148154
# Test client creation with custom headers
149155
custom_headers = {'Custom-Header': 'custom-value'}
@@ -160,25 +166,38 @@ def test_create_sigv4_client_with_custom_headers(self, mock_client_class):
160166
assert call_args[1]['headers'] == expected_headers
161167
assert result == mock_client
162168

169+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
163170
@patch('httpx.AsyncClient')
164-
def test_create_sigv4_client_with_custom_service_and_region(self, mock_client_class):
171+
def test_create_sigv4_client_with_custom_service_and_region(
172+
self, mock_client_class, mock_create_session
173+
):
165174
"""Test creating SigV4 client with custom service and region."""
166175
mock_client = Mock()
167176
mock_client_class.return_value = mock_client
168177

178+
# Mock session creation
179+
mock_session = Mock()
180+
mock_session.get_credentials.return_value = Mock(access_key='test-key')
181+
mock_create_session.return_value = mock_session
182+
169183
# Test client creation with custom parameters
170184
result = create_sigv4_client(
171185
service='custom-service', profile='test-profile', region='us-east-1'
172186
)
173187

188+
# Verify session was created with profile
189+
mock_create_session.assert_called_once_with('test-profile')
174190
# Verify client was created
175191
assert result == mock_client
176192

193+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
177194
@patch('httpx.AsyncClient')
178-
def test_create_sigv4_client_with_kwargs(self, mock_client_class):
195+
def test_create_sigv4_client_with_kwargs(self, mock_client_class, mock_create_session):
179196
"""Test creating SigV4 client with additional kwargs."""
180197
mock_client = Mock()
181198
mock_client_class.return_value = mock_client
199+
mock_session = Mock()
200+
mock_create_session.return_value = mock_session
182201

183202
# Test client creation with additional kwargs
184203
result = create_sigv4_client(
@@ -194,8 +213,9 @@ def test_create_sigv4_client_with_kwargs(self, mock_client_class):
194213
assert call_args[1]['proxies'] == {'http': 'http://proxy:8080'}
195214
assert result == mock_client
196215

216+
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
197217
@patch('httpx.AsyncClient')
198-
def test_create_sigv4_client_with_prompt_context(self, mock_client_class):
218+
def test_create_sigv4_client_with_prompt_context(self, mock_client_class, mock_create_session):
199219
"""Test creating SigV4 client when prompts exist in the system context.
200220
201221
This test simulates the scenario where the sigv4_helper is used in a context
@@ -204,6 +224,8 @@ def test_create_sigv4_client_with_prompt_context(self, mock_client_class):
204224
"""
205225
mock_client = Mock()
206226
mock_client_class.return_value = mock_client
227+
mock_session = Mock()
228+
mock_create_session.return_value = mock_session
207229

208230
# Test client creation with headers that might be used when prompts exist
209231
prompt_context_headers = {

0 commit comments

Comments
 (0)