Skip to content

Commit 1ae68d2

Browse files
authored
fix(server.py,-utils.py): fix issue where regions weren't parsed properly (#37)
Ensure that regions are resolved in the following order: 1. from --region param 2. from url 3. from environment variable and throws an error if none of these are available. unit tests were also updated to match.
1 parent e64a95f commit 1ae68d2

File tree

4 files changed

+184
-30
lines changed

4 files changed

+184
-30
lines changed

aws_mcp_proxy/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
4747

4848
# Validate and determine service
4949
service = determine_service_name(args.endpoint, args.service)
50+
logger.debug('Using service: %s', service)
5051

5152
# Validate and determine region
5253
region = determine_aws_region(args.endpoint, args.region)
54+
logger.debug('Using region: %s', region)
5355

5456
# Get profile
5557
profile = args.profile
@@ -134,7 +136,7 @@ def parse_args():
134136
parser.add_argument(
135137
'--region',
136138
help='AWS region to use (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
137-
default=os.getenv('AWS_REGION', 'us-east-1'),
139+
default=None,
138140
)
139141

140142
parser.add_argument(

aws_mcp_proxy/utils.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@
1515
"""Utility functions for the AWS MCP Proxy."""
1616

1717
import httpx
18+
import logging
19+
import os
1820
import re
1921
from aws_mcp_proxy.sigv4_helper import create_sigv4_client
2022
from fastmcp.client.transports import StreamableHttpTransport
2123
from typing import Dict, Optional
2224
from urllib.parse import urlparse
2325

2426

27+
logger = logging.getLogger(__name__)
28+
29+
2530
def create_transport_with_sigv4(
2631
url: str,
2732
service: str,
@@ -92,7 +97,7 @@ def determine_service_name(endpoint: str, service: Optional[str] = None) -> str:
9297
return determined_service
9398

9499

95-
def determine_aws_region(endpoint: str, region: Optional[str] = None) -> str:
100+
def determine_aws_region(endpoint: str, region: Optional[str]) -> str:
96101
"""Validate and determine the AWS region.
97102
98103
Args:
@@ -106,6 +111,7 @@ def determine_aws_region(endpoint: str, region: Optional[str] = None) -> str:
106111
ValueError: If region cannot be determined
107112
"""
108113
if region:
114+
logger.debug('Region determined through explicit parameter')
109115
return region
110116

111117
# Parse AWS region from endpoint URL
@@ -114,11 +120,16 @@ def determine_aws_region(endpoint: str, region: Optional[str] = None) -> str:
114120

115121
# Extract region name (pattern: service.region.api.aws or service-name.region.api.aws)
116122
region_match = re.search(r'\.([a-z0-9-]+)\.api\.aws', hostname)
117-
determined_region = region_match.group(1) if region_match else None
118-
119-
if not determined_region:
120-
raise ValueError(
121-
f"Could not determine AWS region from endpoint '{endpoint}'. "
122-
'Please provide the region explicitly using --region argument.'
123-
)
124-
return determined_region
123+
if region_match:
124+
logger.debug('Region determined through endpoint URL')
125+
return region_match.group(1)
126+
127+
environment_region = os.getenv('AWS_REGION')
128+
if environment_region:
129+
logger.debug('Region determined through environment variable')
130+
return environment_region
131+
132+
raise ValueError(
133+
f"Could not determine AWS region from endpoint '{endpoint}' or from environment variable AWS_REGION. "
134+
'Please provide the region explicitly using --region argument.'
135+
)

tests/unit/test_server.py

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@
1515
"""Tests for the aws-mcp-proxy Server."""
1616

1717
import pytest
18-
from aws_mcp_proxy.server import main, parse_args, setup_mcp_mode
18+
from aws_mcp_proxy.server import (
19+
add_retry_middleware,
20+
add_tool_filtering_middleware,
21+
main,
22+
parse_args,
23+
setup_mcp_mode,
24+
)
1925
from aws_mcp_proxy.sigv4_helper import create_sigv4_client
2026
from aws_mcp_proxy.utils import determine_service_name
27+
from fastmcp.server.server import FastMCP
2128
from unittest.mock import AsyncMock, Mock, patch
2229

2330

@@ -26,10 +33,22 @@ class TestServer:
2633

2734
@patch('aws_mcp_proxy.server.create_transport_with_sigv4')
2835
@patch('aws_mcp_proxy.server.FastMCP.as_proxy')
29-
async def test_setup_mcp_mode(self, mock_as_proxy, mock_create_transport):
36+
@patch('aws_mcp_proxy.server.determine_aws_region')
37+
@patch('aws_mcp_proxy.server.determine_service_name')
38+
@patch('aws_mcp_proxy.server.add_tool_filtering_middleware')
39+
@patch('aws_mcp_proxy.server.add_retry_middleware')
40+
async def test_setup_mcp_mode(
41+
self,
42+
mock_add_retry,
43+
mock_add_filtering,
44+
mock_determine_service,
45+
mock_determine_region,
46+
mock_as_proxy,
47+
mock_create_transport,
48+
):
3049
"""Test that MCP mode is set up correctly."""
3150
# Arrange
32-
mock_mcp = Mock()
51+
local_mcp = Mock(spec=FastMCP)
3352
mock_args = Mock()
3453
mock_args.endpoint = 'https://test.example.com'
3554
mock_args.service = 'test-service'
@@ -38,6 +57,10 @@ async def test_setup_mcp_mode(self, mock_as_proxy, mock_create_transport):
3857
mock_args.read_only = True
3958
mock_args.retries = 1
4059

60+
# Mock return values
61+
mock_determine_service.return_value = 'test-service'
62+
mock_determine_region.return_value = 'us-east-1'
63+
4164
# Mock the transport and proxy
4265
mock_transport = Mock()
4366
mock_create_transport.return_value = mock_transport
@@ -46,25 +69,46 @@ async def test_setup_mcp_mode(self, mock_as_proxy, mock_create_transport):
4669
mock_as_proxy.return_value = mock_proxy
4770

4871
# Act
49-
await setup_mcp_mode(mock_mcp, mock_args)
72+
await setup_mcp_mode(local_mcp, mock_args)
5073

5174
# Assert
52-
mock_create_transport.assert_called_once()
75+
mock_determine_service.assert_called_once_with('https://test.example.com', 'test-service')
76+
mock_determine_region.assert_called_once_with('https://test.example.com', 'us-east-1')
77+
mock_create_transport.assert_called_once_with(
78+
'https://test.example.com', 'test-service', 'us-east-1', None
79+
)
5380
mock_as_proxy.assert_called_once_with(mock_transport)
81+
mock_add_filtering.assert_called_once_with(mock_proxy, True)
82+
mock_add_retry.assert_called_once_with(mock_proxy, 1)
83+
mock_proxy.run_async.assert_called_once()
5484

5585
@patch('aws_mcp_proxy.server.create_transport_with_sigv4')
5686
@patch('aws_mcp_proxy.server.FastMCP.as_proxy')
57-
async def test_setup_mcp_mode_with_tools(self, mock_as_proxy, mock_create_transport):
58-
"""Test that MCP mode registers tools correctly."""
87+
@patch('aws_mcp_proxy.server.determine_aws_region')
88+
@patch('aws_mcp_proxy.server.determine_service_name')
89+
@patch('aws_mcp_proxy.server.add_tool_filtering_middleware')
90+
async def test_setup_mcp_mode_no_retries(
91+
self,
92+
mock_add_filtering,
93+
mock_determine_service,
94+
mock_determine_region,
95+
mock_as_proxy,
96+
mock_create_transport,
97+
):
98+
"""Test that MCP mode setup without retries doesn't add retry middleware."""
5999
# Arrange
60-
mock_mcp = Mock()
100+
local_mcp = Mock(spec=FastMCP)
61101
mock_args = Mock()
62102
mock_args.endpoint = 'https://test.example.com'
63103
mock_args.service = 'test-service'
64104
mock_args.region = 'us-east-1'
65-
mock_args.profile = None
66-
mock_args.read_only = True
67-
mock_args.retries = 1
105+
mock_args.profile = 'test-profile'
106+
mock_args.read_only = False
107+
mock_args.retries = 0 # No retries
108+
109+
# Mock return values
110+
mock_determine_service.return_value = 'test-service'
111+
mock_determine_region.return_value = 'us-east-1'
68112

69113
# Mock the transport and proxy
70114
mock_transport = Mock()
@@ -74,17 +118,88 @@ async def test_setup_mcp_mode_with_tools(self, mock_as_proxy, mock_create_transp
74118
mock_as_proxy.return_value = mock_proxy
75119

76120
# Act
77-
await setup_mcp_mode(mock_mcp, mock_args)
121+
await setup_mcp_mode(local_mcp, mock_args)
78122

79123
# Assert
80-
mock_create_transport.assert_called_once()
124+
mock_determine_service.assert_called_once_with('https://test.example.com', 'test-service')
125+
mock_determine_region.assert_called_once_with('https://test.example.com', 'us-east-1')
126+
mock_create_transport.assert_called_once_with(
127+
'https://test.example.com', 'test-service', 'us-east-1', 'test-profile'
128+
)
81129
mock_as_proxy.assert_called_once_with(mock_transport)
130+
mock_add_filtering.assert_called_once_with(mock_proxy, False)
131+
mock_proxy.run_async.assert_called_once()
132+
133+
def test_add_tool_filtering_middleware(self):
134+
"""Test that tool filtering middleware is added correctly."""
135+
# Arrange
136+
mock_mcp = Mock()
137+
138+
# Act
139+
add_tool_filtering_middleware(mock_mcp, read_only=True)
140+
141+
# Assert
142+
mock_mcp.add_middleware.assert_called_once()
143+
# Verify that the middleware added is a ToolFilteringMiddleware
144+
call_args = mock_mcp.add_middleware.call_args[0][0]
145+
from aws_mcp_proxy.middleware.tool_filter import ToolFilteringMiddleware
146+
147+
assert isinstance(call_args, ToolFilteringMiddleware)
148+
assert call_args.read_only is True
149+
150+
def test_add_retry_middleware(self):
151+
"""Test that retry middleware is added correctly."""
152+
# Arrange
153+
mock_mcp = Mock()
154+
155+
# Act
156+
add_retry_middleware(mock_mcp, retries=5)
157+
158+
# Assert
159+
mock_mcp.add_middleware.assert_called_once()
160+
# Verify that the middleware added is a RetryMiddleware
161+
call_args = mock_mcp.add_middleware.call_args[0][0]
162+
from fastmcp.server.middleware.error_handling import RetryMiddleware
163+
164+
assert isinstance(call_args, RetryMiddleware)
82165

83166
@patch('sys.argv', ['test', 'https://test.example.com'])
84167
def test_parse_args_default(self):
85168
"""Test parse_args with default arguments."""
86169
args = parse_args()
87170
assert args.endpoint == 'https://test.example.com'
171+
assert args.service is None
172+
assert args.region is None
173+
assert args.profile is None
174+
assert args.read_only is False
175+
assert args.log_level == 'INFO'
176+
assert args.retries == 0
177+
178+
@patch(
179+
'sys.argv',
180+
[
181+
'test',
182+
'https://test.example.com',
183+
'--service',
184+
'custom-service',
185+
'--region',
186+
'us-west-2',
187+
'--read-only',
188+
'--log-level',
189+
'DEBUG',
190+
'--retries',
191+
'5',
192+
],
193+
)
194+
def test_parse_args_custom(self):
195+
"""Test parse_args with custom arguments."""
196+
args = parse_args()
197+
assert args.endpoint == 'https://test.example.com'
198+
assert args.service == 'custom-service'
199+
assert args.region == 'us-west-2'
200+
assert args.read_only is True
201+
assert args.log_level == 'DEBUG'
202+
assert args.retries == 5
88203

89204
@patch('aws_mcp_proxy.server.asyncio.run')
90205
@patch('sys.argv', ['test', 'https://test.example.com'])

tests/unit/test_utils.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,40 +153,66 @@ def test_validate_service_name_invalid_url_failure(self):
153153
class TestDetermineRegion:
154154
"""Test cases for determine_aws_region function."""
155155

156-
def test_determine_region_with_region(self):
156+
@patch('os.getenv')
157+
def test_determine_region_with_region(self, mock_getenv):
157158
"""Test determination when region is provided."""
158159
endpoint = 'https://mcp.us-east-1.api.aws/mcp'
159160
region = 'custom-region'
160161

161162
result = determine_aws_region(endpoint, region)
162163

163164
assert result == region
165+
# Environment variable should not be checked when region is provided
166+
mock_getenv.assert_not_called()
164167

165-
def test_determine_region_without_region_success(self):
168+
@patch('os.getenv')
169+
def test_determine_region_without_region_success(self, mock_getenv):
166170
"""Test determination when region is not provided but can be parsed."""
167171
endpoint = 'https://mcp.us-east-1.api.aws/mcp'
168172
expected_region = 'us-east-1'
173+
mock_getenv.return_value = None
169174

170-
result = determine_aws_region(endpoint)
175+
result = determine_aws_region(endpoint, None)
171176

172177
assert result == expected_region
178+
# Environment variable should not be checked when region can be parsed from endpoint
173179

174-
def test_determine_region_with_complex_service_name(self):
180+
@patch('os.getenv')
181+
def test_determine_region_with_complex_service_name(self, mock_getenv):
175182
"""Test parsing region from endpoint with complex service name."""
176183
endpoint = 'https://eks-mcp-beta.us-west-2.api.aws/mcp'
177184
expected_region = 'us-west-2'
185+
mock_getenv.return_value = None
178186

179-
result = determine_aws_region(endpoint)
187+
result = determine_aws_region(endpoint, None)
180188

181189
assert result == expected_region
190+
# Environment variable should not be checked when region can be parsed from endpoint
182191

183-
def test_determine_region_without_region_failure(self):
192+
@patch('os.getenv')
193+
def test_determine_region_without_region_failure(self, mock_getenv):
184194
"""Test determination when region cannot be determined."""
185195
endpoint = 'https://service.example.com'
196+
mock_getenv.return_value = None
186197

187198
with pytest.raises(ValueError) as exc_info:
188-
determine_aws_region(endpoint)
199+
determine_aws_region(endpoint, None)
189200

190201
assert 'Could not determine AWS region' in str(exc_info.value)
191202
assert endpoint in str(exc_info.value)
192203
assert '--region argument' in str(exc_info.value)
204+
mock_getenv.assert_called_once_with('AWS_REGION')
205+
206+
@patch('os.getenv')
207+
def test_determine_region_from_environment(self, mock_getenv):
208+
"""Test determination from environment variable when endpoint doesn't contain region."""
209+
# Arrange
210+
endpoint = 'https://test-service.example.com'
211+
mock_getenv.return_value = 'us-west-1'
212+
213+
# Act
214+
result = determine_aws_region(endpoint, None)
215+
216+
# Assert
217+
assert result == 'us-west-1'
218+
mock_getenv.assert_called_once_with('AWS_REGION')

0 commit comments

Comments
 (0)