1515"""Tests for the aws-mcp-proxy Server."""
1616
1717import pytest
18- from aws_mcp_proxy .server import main , parse_args , setup_mcp_mode
18+ from aws_mcp_proxy .server import (
19+ add_retry_middleware ,
20+ add_tool_filtering_middleware ,
21+ main ,
22+ parse_args ,
23+ setup_mcp_mode ,
24+ )
1925from aws_mcp_proxy .sigv4_helper import create_sigv4_client
2026from aws_mcp_proxy .utils import determine_service_name
27+ from fastmcp .server .server import FastMCP
2128from unittest .mock import AsyncMock , Mock , patch
2229
2330
@@ -26,10 +33,22 @@ class TestServer:
2633
2734 @patch ('aws_mcp_proxy.server.create_transport_with_sigv4' )
2835 @patch ('aws_mcp_proxy.server.FastMCP.as_proxy' )
29- async def test_setup_mcp_mode (self , mock_as_proxy , mock_create_transport ):
36+ @patch ('aws_mcp_proxy.server.determine_aws_region' )
37+ @patch ('aws_mcp_proxy.server.determine_service_name' )
38+ @patch ('aws_mcp_proxy.server.add_tool_filtering_middleware' )
39+ @patch ('aws_mcp_proxy.server.add_retry_middleware' )
40+ async def test_setup_mcp_mode (
41+ self ,
42+ mock_add_retry ,
43+ mock_add_filtering ,
44+ mock_determine_service ,
45+ mock_determine_region ,
46+ mock_as_proxy ,
47+ mock_create_transport ,
48+ ):
3049 """Test that MCP mode is set up correctly."""
3150 # Arrange
32- mock_mcp = Mock ()
51+ local_mcp = Mock (spec = FastMCP )
3352 mock_args = Mock ()
3453 mock_args .endpoint = 'https://test.example.com'
3554 mock_args .service = 'test-service'
@@ -38,6 +57,10 @@ async def test_setup_mcp_mode(self, mock_as_proxy, mock_create_transport):
3857 mock_args .read_only = True
3958 mock_args .retries = 1
4059
60+ # Mock return values
61+ mock_determine_service .return_value = 'test-service'
62+ mock_determine_region .return_value = 'us-east-1'
63+
4164 # Mock the transport and proxy
4265 mock_transport = Mock ()
4366 mock_create_transport .return_value = mock_transport
@@ -46,25 +69,46 @@ async def test_setup_mcp_mode(self, mock_as_proxy, mock_create_transport):
4669 mock_as_proxy .return_value = mock_proxy
4770
4871 # Act
49- await setup_mcp_mode (mock_mcp , mock_args )
72+ await setup_mcp_mode (local_mcp , mock_args )
5073
5174 # Assert
52- mock_create_transport .assert_called_once ()
75+ mock_determine_service .assert_called_once_with ('https://test.example.com' , 'test-service' )
76+ mock_determine_region .assert_called_once_with ('https://test.example.com' , 'us-east-1' )
77+ mock_create_transport .assert_called_once_with (
78+ 'https://test.example.com' , 'test-service' , 'us-east-1' , None
79+ )
5380 mock_as_proxy .assert_called_once_with (mock_transport )
81+ mock_add_filtering .assert_called_once_with (mock_proxy , True )
82+ mock_add_retry .assert_called_once_with (mock_proxy , 1 )
83+ mock_proxy .run_async .assert_called_once ()
5484
5585 @patch ('aws_mcp_proxy.server.create_transport_with_sigv4' )
5686 @patch ('aws_mcp_proxy.server.FastMCP.as_proxy' )
57- async def test_setup_mcp_mode_with_tools (self , mock_as_proxy , mock_create_transport ):
58- """Test that MCP mode registers tools correctly."""
87+ @patch ('aws_mcp_proxy.server.determine_aws_region' )
88+ @patch ('aws_mcp_proxy.server.determine_service_name' )
89+ @patch ('aws_mcp_proxy.server.add_tool_filtering_middleware' )
90+ async def test_setup_mcp_mode_no_retries (
91+ self ,
92+ mock_add_filtering ,
93+ mock_determine_service ,
94+ mock_determine_region ,
95+ mock_as_proxy ,
96+ mock_create_transport ,
97+ ):
98+ """Test that MCP mode setup without retries doesn't add retry middleware."""
5999 # Arrange
60- mock_mcp = Mock ()
100+ local_mcp = Mock (spec = FastMCP )
61101 mock_args = Mock ()
62102 mock_args .endpoint = 'https://test.example.com'
63103 mock_args .service = 'test-service'
64104 mock_args .region = 'us-east-1'
65- mock_args .profile = None
66- mock_args .read_only = True
67- mock_args .retries = 1
105+ mock_args .profile = 'test-profile'
106+ mock_args .read_only = False
107+ mock_args .retries = 0 # No retries
108+
109+ # Mock return values
110+ mock_determine_service .return_value = 'test-service'
111+ mock_determine_region .return_value = 'us-east-1'
68112
69113 # Mock the transport and proxy
70114 mock_transport = Mock ()
@@ -74,17 +118,88 @@ async def test_setup_mcp_mode_with_tools(self, mock_as_proxy, mock_create_transp
74118 mock_as_proxy .return_value = mock_proxy
75119
76120 # Act
77- await setup_mcp_mode (mock_mcp , mock_args )
121+ await setup_mcp_mode (local_mcp , mock_args )
78122
79123 # Assert
80- mock_create_transport .assert_called_once ()
124+ mock_determine_service .assert_called_once_with ('https://test.example.com' , 'test-service' )
125+ mock_determine_region .assert_called_once_with ('https://test.example.com' , 'us-east-1' )
126+ mock_create_transport .assert_called_once_with (
127+ 'https://test.example.com' , 'test-service' , 'us-east-1' , 'test-profile'
128+ )
81129 mock_as_proxy .assert_called_once_with (mock_transport )
130+ mock_add_filtering .assert_called_once_with (mock_proxy , False )
131+ mock_proxy .run_async .assert_called_once ()
132+
133+ def test_add_tool_filtering_middleware (self ):
134+ """Test that tool filtering middleware is added correctly."""
135+ # Arrange
136+ mock_mcp = Mock ()
137+
138+ # Act
139+ add_tool_filtering_middleware (mock_mcp , read_only = True )
140+
141+ # Assert
142+ mock_mcp .add_middleware .assert_called_once ()
143+ # Verify that the middleware added is a ToolFilteringMiddleware
144+ call_args = mock_mcp .add_middleware .call_args [0 ][0 ]
145+ from aws_mcp_proxy .middleware .tool_filter import ToolFilteringMiddleware
146+
147+ assert isinstance (call_args , ToolFilteringMiddleware )
148+ assert call_args .read_only is True
149+
150+ def test_add_retry_middleware (self ):
151+ """Test that retry middleware is added correctly."""
152+ # Arrange
153+ mock_mcp = Mock ()
154+
155+ # Act
156+ add_retry_middleware (mock_mcp , retries = 5 )
157+
158+ # Assert
159+ mock_mcp .add_middleware .assert_called_once ()
160+ # Verify that the middleware added is a RetryMiddleware
161+ call_args = mock_mcp .add_middleware .call_args [0 ][0 ]
162+ from fastmcp .server .middleware .error_handling import RetryMiddleware
163+
164+ assert isinstance (call_args , RetryMiddleware )
82165
83166 @patch ('sys.argv' , ['test' , 'https://test.example.com' ])
84167 def test_parse_args_default (self ):
85168 """Test parse_args with default arguments."""
86169 args = parse_args ()
87170 assert args .endpoint == 'https://test.example.com'
171+ assert args .service is None
172+ assert args .region is None
173+ assert args .profile is None
174+ assert args .read_only is False
175+ assert args .log_level == 'INFO'
176+ assert args .retries == 0
177+
178+ @patch (
179+ 'sys.argv' ,
180+ [
181+ 'test' ,
182+ 'https://test.example.com' ,
183+ '--service' ,
184+ 'custom-service' ,
185+ '--region' ,
186+ 'us-west-2' ,
187+ '--read-only' ,
188+ '--log-level' ,
189+ 'DEBUG' ,
190+ '--retries' ,
191+ '5' ,
192+ ],
193+ )
194+ def test_parse_args_custom (self ):
195+ """Test parse_args with custom arguments."""
196+ args = parse_args ()
197+ assert args .endpoint == 'https://test.example.com'
198+ assert args .service == 'custom-service'
199+ assert args .region == 'us-west-2'
200+ assert args .read_only is True
201+ assert args .log_level == 'DEBUG'
202+ assert args .retries == 5
88203
89204 @patch ('aws_mcp_proxy.server.asyncio.run' )
90205 @patch ('sys.argv' , ['test' , 'https://test.example.com' ])
0 commit comments