Skip to content

Commit 4cc43fd

Browse files
committed
MCP server fails to response due to 421 Misdirected Request / Invalid Host Header
Signed-off-by: zemin-piao <[email protected]>
1 parent 22208b3 commit 4cc43fd

File tree

5 files changed

+319
-1
lines changed

5 files changed

+319
-1
lines changed

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ SHS_MCP_DEBUG=true # Enable debug mode (default: false)
66
SHS_MCP_ADDRESS=0.0.0.0 # Address for MCP server (default: localhost)
77
SHS_MCP_TRANSPORT=streamable-http
88

9+
# Transport Security Settings (DNS rebinding protection)
10+
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
11+
# SHS_MCP_TRANSPORT_SECURITY_ENABLE_DNS_REBINDING_PROTECTION=true
12+
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_HOSTS=["localhost:*","127.0.0.1:*","your-gateway:*"]
13+
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_ORIGINS=["http://localhost:*","http://127.0.0.1:*"]
14+
915
# Spark History Server Settings
1016
# SHS_SERVERS_*_URL - URL for a specific server
1117
# SHS_SERVERS_*_AUTH_USERNAME - Username for a specific server

config.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,37 @@ mcp:
4242
debug: true
4343
address: localhost
4444

45+
# Transport security settings for DNS rebinding protection
46+
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
47+
# This is only relevant when actual mcp package version is higher than 1.23.0
48+
#transport_security:
49+
# Enable DNS rebinding protection. Set to true for production deployments
50+
# with proper allowed_hosts/allowed_origins configuration.
51+
#enable_dns_rebinding_protection: false
52+
53+
# List of allowed Host header values. Required when enable_dns_rebinding_protection is true.
54+
# Supports wildcard ports (e.g., "localhost:*", "127.0.0.1:*", "your-gateway:*").
55+
# allowed_hosts:
56+
# - "localhost:*"
57+
# - "127.0.0.1:*"
58+
# - "your-proxy-domain:*"
59+
60+
# List of allowed Origin header values. Required when enable_dns_rebinding_protection is true.
61+
# Supports wildcard ports (e.g., "http://localhost:*", "http://your-gateway:*").
62+
# allowed_origins:
63+
# - "http://localhost:*"
64+
# - "http://127.0.0.1:*"
65+
# - "http://your-proxy-domain:*"
66+
4567

4668
# Available Environment Variables:
4769
# SHS_MCP_PORT - Port for MCP server (default: 18888)
4870
# SHS_MCP_DEBUG - Enable debug mode (default: false)
4971
# SHS_MCP_ADDRESS - Address for MCP server (default: localhost)
5072
# SHS_MCP_TRANSPORT - MCP transport mode (default: streamable-http)
73+
# SHS_MCP_TRANSPORT_SECURITY_ENABLE_DNS_REBINDING_PROTECTION - Enable DNS rebinding protection (true/false)
74+
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_HOSTS - JSON array of allowed hosts (e.g., '["localhost:*","127.0.0.1:*"]')
75+
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_ORIGINS - JSON array of allowed origins (e.g., '["http://localhost:*"]')
5176
# SHS_SERVERS_*_URL - URL for a specific server
5277
# SHS_SERVERS_*_AUTH_USERNAME - Username for a specific server
5378
# SHS_SERVERS_*_AUTH_PASSWORD - Password for a specific server

src/spark_history_mcp/config/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@ class ServerConfig(BaseSettings):
6666
include_plan_description: Optional[bool] = None
6767

6868

69+
class TransportSecurityConfig(BaseSettings):
70+
"""Transport security configuration for DNS rebinding protection.
71+
72+
See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
73+
"""
74+
75+
enable_dns_rebinding_protection: bool = Field(
76+
default=False,
77+
description="Enable DNS rebinding protection. Set to True for production "
78+
"deployments with proper allowed_hosts configuration.",
79+
)
80+
allowed_hosts: List[str] = Field(
81+
default_factory=list,
82+
description="List of allowed Host header values. Supports wildcard ports "
83+
'(e.g., "localhost:*", "127.0.0.1:*", "your-gateway:*").',
84+
)
85+
allowed_origins: List[str] = Field(
86+
default_factory=list,
87+
description="List of allowed Origin header values. Supports wildcard ports "
88+
'(e.g., "http://localhost:*", "http://your-gateway:*").',
89+
)
90+
model_config = SettingsConfigDict(extra="ignore")
91+
92+
6993
class McpConfig(BaseSettings):
7094
"""Configuration for the MCP server."""
7195

@@ -75,6 +99,10 @@ class McpConfig(BaseSettings):
7599
address: Optional[str] = "localhost"
76100
port: Optional[int | str] = "18888"
77101
debug: Optional[bool] = False
102+
transport_security: Optional[TransportSecurityConfig] = Field(
103+
default=None,
104+
description="Transport security settings for DNS rebinding protection.",
105+
)
78106
model_config = SettingsConfigDict(extra="ignore")
79107

80108

src/spark_history_mcp/core/app.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88

99
from mcp.server.fastmcp import FastMCP
1010

11+
# For handling different mcp version
12+
try:
13+
# mcp version higher than 1.23.0 we are able to import TransportSecuritySettings
14+
from mcp.server.transport_security import TransportSecuritySettings
15+
except ImportError:
16+
TransportSecuritySettings = None
17+
1118
from spark_history_mcp.api.emr_persistent_ui_client import EMRPersistentUIClient
1219
from spark_history_mcp.api.spark_client import SparkRestClient
1320
from spark_history_mcp.config.config import Config
@@ -74,6 +81,17 @@ def run(config: Config):
7481
mcp.settings.host = config.mcp.address
7582
mcp.settings.port = int(config.mcp.port)
7683
mcp.settings.debug = bool(config.mcp.debug)
84+
85+
# Configure transport security settings for DNS rebinding protection
86+
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
87+
if config.mcp.transport_security:
88+
ts_config = config.mcp.transport_security
89+
mcp.settings.transport_security = TransportSecuritySettings(
90+
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
91+
allowed_hosts=ts_config.allowed_hosts,
92+
allowed_origins=ts_config.allowed_origins,
93+
)
94+
7795
mcp.run(transport=os.getenv("SHS_MCP_TRANSPORT", config.mcp.transports[0]))
7896

7997

tests/unit/config.py

Lines changed: 242 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55

66
import yaml
77

8-
from spark_history_mcp.config.config import AuthConfig, Config, ServerConfig
8+
from spark_history_mcp.config.config import (
9+
AuthConfig,
10+
Config,
11+
ServerConfig,
12+
TransportSecurityConfig,
13+
)
914

1015

1116
class TestConfig(unittest.TestCase):
@@ -185,3 +190,239 @@ def test_model_serialization(self):
185190
# Test with explicit exclude
186191
server_dict = server.model_dump(exclude={"auth"})
187192
self.assertNotIn("auth", server_dict)
193+
194+
195+
class TestTransportSecurityConfig(unittest.TestCase):
196+
"""Test cases for TransportSecurityConfig.
197+
198+
See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
199+
"""
200+
201+
def test_transport_security_default_values(self):
202+
"""Test that transport security defaults are set correctly."""
203+
ts_config = TransportSecurityConfig()
204+
205+
# Default should be disabled for backwards compatibility
206+
self.assertFalse(ts_config.enable_dns_rebinding_protection)
207+
self.assertEqual(ts_config.allowed_hosts, [])
208+
self.assertEqual(ts_config.allowed_origins, [])
209+
210+
def test_transport_security_from_yaml(self):
211+
"""Test loading transport security from YAML config."""
212+
config_data = {
213+
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
214+
"mcp": {
215+
"transports": ["streamable-http"],
216+
"port": "18888",
217+
"transport_security": {
218+
"enable_dns_rebinding_protection": True,
219+
"allowed_hosts": ["localhost:*", "127.0.0.1:*", "my-gateway:*"],
220+
"allowed_origins": ["http://localhost:*", "http://127.0.0.1:*"],
221+
},
222+
},
223+
}
224+
225+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
226+
yaml.dump(config_data, temp_file)
227+
temp_file_path = temp_file.name
228+
229+
try:
230+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
231+
config = Config()
232+
233+
# Verify transport security config
234+
ts = config.mcp.transport_security
235+
self.assertIsNotNone(ts)
236+
self.assertTrue(ts.enable_dns_rebinding_protection)
237+
self.assertEqual(
238+
ts.allowed_hosts, ["localhost:*", "127.0.0.1:*", "my-gateway:*"]
239+
)
240+
self.assertEqual(
241+
ts.allowed_origins, ["http://localhost:*", "http://127.0.0.1:*"]
242+
)
243+
finally:
244+
os.unlink(temp_file_path)
245+
246+
def test_transport_security_disabled_in_yaml(self):
247+
"""Test explicitly disabling transport security in YAML."""
248+
config_data = {
249+
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
250+
"mcp": {
251+
"transports": ["streamable-http"],
252+
"transport_security": {
253+
"enable_dns_rebinding_protection": False,
254+
},
255+
},
256+
}
257+
258+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
259+
yaml.dump(config_data, temp_file)
260+
temp_file_path = temp_file.name
261+
262+
try:
263+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
264+
config = Config()
265+
266+
ts = config.mcp.transport_security
267+
self.assertIsNotNone(ts)
268+
self.assertFalse(ts.enable_dns_rebinding_protection)
269+
finally:
270+
os.unlink(temp_file_path)
271+
272+
def test_transport_security_default_when_not_specified(self):
273+
"""Test transport security defaults when not specified in config."""
274+
config_data = {
275+
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
276+
"mcp": {"transports": ["streamable-http"]},
277+
}
278+
279+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
280+
yaml.dump(config_data, temp_file)
281+
temp_file_path = temp_file.name
282+
283+
try:
284+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
285+
config = Config()
286+
287+
# Transport security should have default values
288+
ts = config.mcp.transport_security
289+
self.assertIsNotNone(ts)
290+
self.assertFalse(ts.enable_dns_rebinding_protection)
291+
self.assertEqual(ts.allowed_hosts, [])
292+
self.assertEqual(ts.allowed_origins, [])
293+
finally:
294+
os.unlink(temp_file_path)
295+
296+
def test_transport_security_integration_with_mcp_library(self):
297+
"""Test that transport security config integrates with MCP library."""
298+
from mcp.server.transport_security import TransportSecuritySettings
299+
300+
# Create config with transport security enabled
301+
ts_config = TransportSecurityConfig(
302+
enable_dns_rebinding_protection=True,
303+
allowed_hosts=["localhost:*", "127.0.0.1:*"],
304+
allowed_origins=["http://localhost:*"],
305+
)
306+
307+
# Convert to MCP library's TransportSecuritySettings
308+
ts_settings = TransportSecuritySettings(
309+
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
310+
allowed_hosts=ts_config.allowed_hosts,
311+
allowed_origins=ts_config.allowed_origins,
312+
)
313+
314+
# Verify the settings are correctly transferred
315+
self.assertTrue(ts_settings.enable_dns_rebinding_protection)
316+
self.assertEqual(ts_settings.allowed_hosts, ["localhost:*", "127.0.0.1:*"])
317+
self.assertEqual(ts_settings.allowed_origins, ["http://localhost:*"])
318+
319+
def test_transport_security_partial_config(self):
320+
"""Test transport security with partial configuration."""
321+
config_data = {
322+
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
323+
"mcp": {
324+
"transports": ["streamable-http"],
325+
"transport_security": {
326+
"enable_dns_rebinding_protection": True,
327+
# Only specifying allowed_hosts, not allowed_origins
328+
"allowed_hosts": ["localhost:*"],
329+
},
330+
},
331+
}
332+
333+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
334+
yaml.dump(config_data, temp_file)
335+
temp_file_path = temp_file.name
336+
337+
try:
338+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
339+
config = Config()
340+
341+
ts = config.mcp.transport_security
342+
self.assertTrue(ts.enable_dns_rebinding_protection)
343+
self.assertEqual(ts.allowed_hosts, ["localhost:*"])
344+
# allowed_origins should default to empty list
345+
self.assertEqual(ts.allowed_origins, [])
346+
finally:
347+
os.unlink(temp_file_path)
348+
349+
def test_transport_security_wildcard_patterns(self):
350+
"""Test various wildcard patterns for hosts and origins."""
351+
ts_config = TransportSecurityConfig(
352+
enable_dns_rebinding_protection=True,
353+
allowed_hosts=[
354+
"localhost:*",
355+
"127.0.0.1:*",
356+
"192.168.1.100:*",
357+
"my-gateway.example.com:*",
358+
"internal-service:8080", # Specific port
359+
],
360+
allowed_origins=[
361+
"http://localhost:*",
362+
"https://localhost:*",
363+
"http://127.0.0.1:*",
364+
"https://my-gateway.example.com:*",
365+
"http://internal-service:8080", # Specific port
366+
],
367+
)
368+
369+
# Verify all patterns are stored correctly
370+
self.assertEqual(len(ts_config.allowed_hosts), 5)
371+
self.assertEqual(len(ts_config.allowed_origins), 5)
372+
self.assertIn("localhost:*", ts_config.allowed_hosts)
373+
self.assertIn("internal-service:8080", ts_config.allowed_hosts)
374+
self.assertIn("http://localhost:*", ts_config.allowed_origins)
375+
self.assertIn("https://localhost:*", ts_config.allowed_origins)
376+
377+
378+
class TestAppTransportSecurityIntegration(unittest.TestCase):
379+
"""Test app.py integration with transport security settings."""
380+
381+
def test_app_run_configures_transport_security(self):
382+
"""Test that app.run() correctly configures transport security."""
383+
from mcp.server.transport_security import TransportSecuritySettings
384+
385+
from spark_history_mcp.core.app import mcp
386+
387+
config_data = {
388+
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
389+
"mcp": {
390+
"transports": ["streamable-http"],
391+
"port": "18888",
392+
"address": "localhost",
393+
"debug": False,
394+
"transport_security": {
395+
"enable_dns_rebinding_protection": True,
396+
"allowed_hosts": ["localhost:*", "test-gateway:*"],
397+
"allowed_origins": ["http://localhost:*"],
398+
},
399+
},
400+
}
401+
402+
with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
403+
yaml.dump(config_data, temp_file)
404+
temp_file_path = temp_file.name
405+
406+
try:
407+
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
408+
config = Config()
409+
410+
# Manually apply the transport security settings as run() would
411+
if config.mcp.transport_security:
412+
ts_config = config.mcp.transport_security
413+
mcp.settings.transport_security = TransportSecuritySettings(
414+
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
415+
allowed_hosts=ts_config.allowed_hosts,
416+
allowed_origins=ts_config.allowed_origins,
417+
)
418+
419+
# Verify settings were applied
420+
ts = mcp.settings.transport_security
421+
self.assertIsNotNone(ts)
422+
self.assertTrue(ts.enable_dns_rebinding_protection)
423+
self.assertEqual(ts.allowed_hosts, ["localhost:*", "test-gateway:*"])
424+
self.assertEqual(ts.allowed_origins, ["http://localhost:*"])
425+
finally:
426+
os.unlink(temp_file_path)
427+
# Reset to None to avoid affecting other tests
428+
mcp.settings.transport_security = None

0 commit comments

Comments
 (0)