2323 _inject_metadata_hook ,
2424 _sign_request_hook ,
2525)
26- from unittest .mock import MagicMock , Mock , patch
26+ from unittest .mock import MagicMock , Mock
2727
2828
2929def create_request_with_sigv4_headers (
@@ -343,87 +343,76 @@ async def test_hook_preserves_other_params(self):
343343class TestSignRequestHook :
344344 """Test cases for sign_request_hook function."""
345345
346- @patch ('mcp_proxy_for_aws.sigv4_helper.create_aws_session' )
347346 @pytest .mark .asyncio
348- async def test_sign_request_hook_signs_request (self , mock_create_session ):
347+ async def test_sign_request_hook_signs_request (self ):
349348 """Test that sign_request_hook properly signs requests."""
350349 # Setup mocks
351- mock_create_session . return_value = create_mock_session ()
350+ mock_session = create_mock_session ()
352351
353352 region = 'us-east-1'
354353 service = 'bedrock-agentcore'
355- profile = None
356354
357355 # Create request without signature headers
358356 request_body = json .dumps ({'test' : 'data' }).encode ('utf-8' )
359357 request = httpx .Request ('POST' , 'https://example.com/mcp' , content = request_body )
360358
361359 # Call the hook
362- await _sign_request_hook (region , service , profile , request )
360+ await _sign_request_hook (region , service , mock_session , request )
363361
364362 # Verify signature headers were added
365363 assert 'authorization' in request .headers
366364 assert 'x-amz-date' in request .headers
367365 assert 'x-amz-security-token' in request .headers
368366 assert request .headers ['content-length' ] == str (len (request_body ))
369367
370- @patch ('mcp_proxy_for_aws.sigv4_helper.create_aws_session' )
371368 @pytest .mark .asyncio
372- async def test_sign_request_hook_with_profile (self , mock_create_session ):
373- """Test that sign_request_hook uses profile when provided."""
369+ async def test_sign_request_hook_with_profile (self ):
370+ """Test that sign_request_hook uses session when provided."""
374371 # Setup mocks
375- mock_create_session . return_value = create_mock_session ()
372+ mock_session = create_mock_session ()
376373
377374 region = 'us-west-2'
378375 service = 'execute-api'
379- profile = 'test-profile'
380376
381377 request_body = b'test content'
382378 request = httpx .Request ('POST' , 'https://example.com/api' , content = request_body )
383379
384380 # Call the hook
385- await _sign_request_hook (region , service , profile , request )
386-
387- # Verify session was created with profile
388- mock_create_session .assert_called_once_with (profile )
381+ await _sign_request_hook (region , service , mock_session , request )
389382
390383 # Verify request was signed
391384 assert 'authorization' in request .headers
392385 assert 'x-amz-date' in request .headers
393386
394- @patch ('mcp_proxy_for_aws.sigv4_helper.create_aws_session' )
395387 @pytest .mark .asyncio
396- async def test_sign_request_hook_sets_content_length (self , mock_create_session ):
388+ async def test_sign_request_hook_sets_content_length (self ):
397389 """Test that sign_request_hook sets Content-Length header."""
398390 # Setup mocks
399- mock_create_session . return_value = create_mock_session ()
391+ mock_session = create_mock_session ()
400392
401393 region = 'eu-west-1'
402394 service = 'lambda'
403- profile = None
404395
405396 # Create request
406397 request_body = b'test content with specific length'
407398 request = httpx .Request ('POST' , 'https://example.com/api' , content = request_body )
408399
409- await _sign_request_hook (region , service , profile , request )
400+ await _sign_request_hook (region , service , mock_session , request )
410401
411402 # Verify Content-Length was set correctly
412403 assert request .headers ['content-length' ] == str (len (request_body ))
413404
414- @patch ('mcp_proxy_for_aws.sigv4_helper.create_aws_session' )
415405 @pytest .mark .asyncio
416- async def test_sign_request_hook_with_partial_application (self , mock_create_session ):
406+ async def test_sign_request_hook_with_partial_application (self ):
417407 """Test that sign_request_hook works with functools.partial."""
418408 # Setup mocks
419- mock_create_session . return_value = create_mock_session ()
409+ mock_session = create_mock_session ()
420410
421411 region = 'ap-southeast-1'
422412 service = 'execute-api'
423- profile = 'prod-profile'
424413
425414 # Create curried function using partial
426- curried_hook = partial (_sign_request_hook , region , service , profile )
415+ curried_hook = partial (_sign_request_hook , region , service , mock_session )
427416
428417 request_body = b'request data'
429418 request = httpx .Request ('POST' , 'https://example.com/mcp' , content = request_body )
@@ -434,4 +423,3 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess
434423 # Verify request was signed
435424 assert 'authorization' in request .headers
436425 assert 'x-amz-date' in request .headers
437- mock_create_session .assert_called_once_with (profile )
0 commit comments