Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ SHS_MCP_DEBUG=true # Enable debug mode (default: false)
SHS_MCP_ADDRESS=0.0.0.0 # Address for MCP server (default: localhost)
SHS_MCP_TRANSPORT=streamable-http

# Transport Security Settings (DNS rebinding protection)
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
# SHS_MCP_TRANSPORT_SECURITY_ENABLE_DNS_REBINDING_PROTECTION=true
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_HOSTS=["localhost:*","127.0.0.1:*","your-gateway:*"]
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_ORIGINS=["http://localhost:*","http://127.0.0.1:*"]

# Spark History Server Settings
# SHS_SERVERS_*_URL - URL for a specific server
# SHS_SERVERS_*_AUTH_USERNAME - Username for a specific server
Expand Down
25 changes: 25 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,37 @@ mcp:
debug: true
address: localhost

# Transport security settings for DNS rebinding protection
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
# This is only relevant when actual mcp package version is higher than 1.23.0
#transport_security:
# Enable DNS rebinding protection. Set to true for production deployments
# with proper allowed_hosts/allowed_origins configuration.
#enable_dns_rebinding_protection: false

# List of allowed Host header values. Required when enable_dns_rebinding_protection is true.
# Supports wildcard ports (e.g., "localhost:*", "127.0.0.1:*", "your-gateway:*").
# allowed_hosts:
# - "localhost:*"
# - "127.0.0.1:*"
# - "your-proxy-domain:*"

# List of allowed Origin header values. Required when enable_dns_rebinding_protection is true.
# Supports wildcard ports (e.g., "http://localhost:*", "http://your-gateway:*").
# allowed_origins:
# - "http://localhost:*"
# - "http://127.0.0.1:*"
# - "http://your-proxy-domain:*"


# Available Environment Variables:
# SHS_MCP_PORT - Port for MCP server (default: 18888)
# SHS_MCP_DEBUG - Enable debug mode (default: false)
# SHS_MCP_ADDRESS - Address for MCP server (default: localhost)
# SHS_MCP_TRANSPORT - MCP transport mode (default: streamable-http)
# SHS_MCP_TRANSPORT_SECURITY_ENABLE_DNS_REBINDING_PROTECTION - Enable DNS rebinding protection (true/false)
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_HOSTS - JSON array of allowed hosts (e.g., '["localhost:*","127.0.0.1:*"]')
# SHS_MCP_TRANSPORT_SECURITY_ALLOWED_ORIGINS - JSON array of allowed origins (e.g., '["http://localhost:*"]')
# SHS_SERVERS_*_URL - URL for a specific server
# SHS_SERVERS_*_AUTH_USERNAME - Username for a specific server
# SHS_SERVERS_*_AUTH_PASSWORD - Password for a specific server
Expand Down
28 changes: 28 additions & 0 deletions src/spark_history_mcp/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ class ServerConfig(BaseSettings):
include_plan_description: Optional[bool] = None


class TransportSecurityConfig(BaseSettings):
"""Transport security configuration for DNS rebinding protection.

See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
"""

enable_dns_rebinding_protection: bool = Field(
default=False,
description="Enable DNS rebinding protection. Set to True for production "
"deployments with proper allowed_hosts configuration.",
)
allowed_hosts: List[str] = Field(
default_factory=list,
description="List of allowed Host header values. Supports wildcard ports "
'(e.g., "localhost:*", "127.0.0.1:*", "your-gateway:*").',
)
allowed_origins: List[str] = Field(
default_factory=list,
description="List of allowed Origin header values. Supports wildcard ports "
'(e.g., "http://localhost:*", "http://your-gateway:*").',
)
model_config = SettingsConfigDict(extra="ignore")


class McpConfig(BaseSettings):
"""Configuration for the MCP server."""

Expand All @@ -75,6 +99,10 @@ class McpConfig(BaseSettings):
address: Optional[str] = "localhost"
port: Optional[int | str] = "18888"
debug: Optional[bool] = False
transport_security: Optional[TransportSecurityConfig] = Field(
default=None,
description="Transport security settings for DNS rebinding protection.",
)
model_config = SettingsConfigDict(extra="ignore")


Expand Down
18 changes: 18 additions & 0 deletions src/spark_history_mcp/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@

from mcp.server.fastmcp import FastMCP

# For handling different mcp version
try:
# mcp version higher than 1.23.0 we are able to import TransportSecuritySettings
from mcp.server.transport_security import TransportSecuritySettings
except ImportError:
TransportSecuritySettings = None

Comment on lines +12 to +17
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will open an issue to remove this in the future.

from spark_history_mcp.api.emr_persistent_ui_client import EMRPersistentUIClient
from spark_history_mcp.api.spark_client import SparkRestClient
from spark_history_mcp.config.config import Config
Expand Down Expand Up @@ -74,6 +81,17 @@ def run(config: Config):
mcp.settings.host = config.mcp.address
mcp.settings.port = int(config.mcp.port)
mcp.settings.debug = bool(config.mcp.debug)

# Configure transport security settings for DNS rebinding protection
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
if config.mcp.transport_security:
ts_config = config.mcp.transport_security
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
allowed_hosts=ts_config.allowed_hosts,
allowed_origins=ts_config.allowed_origins,
)

mcp.run(transport=os.getenv("SHS_MCP_TRANSPORT", config.mcp.transports[0]))


Expand Down
243 changes: 242 additions & 1 deletion tests/unit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import yaml

from spark_history_mcp.config.config import AuthConfig, Config, ServerConfig
from spark_history_mcp.config.config import (
AuthConfig,
Config,
ServerConfig,
TransportSecurityConfig,
)


class TestConfig(unittest.TestCase):
Expand Down Expand Up @@ -185,3 +190,239 @@ def test_model_serialization(self):
# Test with explicit exclude
server_dict = server.model_dump(exclude={"auth"})
self.assertNotIn("auth", server_dict)


class TestTransportSecurityConfig(unittest.TestCase):
"""Test cases for TransportSecurityConfig.

See: https://github.com/modelcontextprotocol/python-sdk/issues/1798
"""

def test_transport_security_default_values(self):
"""Test that transport security defaults are set correctly."""
ts_config = TransportSecurityConfig()

# Default should be disabled for backwards compatibility
self.assertFalse(ts_config.enable_dns_rebinding_protection)
self.assertEqual(ts_config.allowed_hosts, [])
self.assertEqual(ts_config.allowed_origins, [])

def test_transport_security_from_yaml(self):
"""Test loading transport security from YAML config."""
config_data = {
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
"mcp": {
"transports": ["streamable-http"],
"port": "18888",
"transport_security": {
"enable_dns_rebinding_protection": True,
"allowed_hosts": ["localhost:*", "127.0.0.1:*", "my-gateway:*"],
"allowed_origins": ["http://localhost:*", "http://127.0.0.1:*"],
},
},
}

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
yaml.dump(config_data, temp_file)
temp_file_path = temp_file.name

try:
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
config = Config()

# Verify transport security config
ts = config.mcp.transport_security
self.assertIsNotNone(ts)
self.assertTrue(ts.enable_dns_rebinding_protection)
self.assertEqual(
ts.allowed_hosts, ["localhost:*", "127.0.0.1:*", "my-gateway:*"]
)
self.assertEqual(
ts.allowed_origins, ["http://localhost:*", "http://127.0.0.1:*"]
)
finally:
os.unlink(temp_file_path)

def test_transport_security_disabled_in_yaml(self):
"""Test explicitly disabling transport security in YAML."""
config_data = {
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
"mcp": {
"transports": ["streamable-http"],
"transport_security": {
"enable_dns_rebinding_protection": False,
},
},
}

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
yaml.dump(config_data, temp_file)
temp_file_path = temp_file.name

try:
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
config = Config()

ts = config.mcp.transport_security
self.assertIsNotNone(ts)
self.assertFalse(ts.enable_dns_rebinding_protection)
finally:
os.unlink(temp_file_path)

def test_transport_security_default_when_not_specified(self):
"""Test transport security defaults when not specified in config."""
config_data = {
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
"mcp": {"transports": ["streamable-http"]},
}

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
yaml.dump(config_data, temp_file)
temp_file_path = temp_file.name

try:
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
config = Config()

# Transport security should have default values
ts = config.mcp.transport_security
self.assertIsNotNone(ts)
self.assertFalse(ts.enable_dns_rebinding_protection)
self.assertEqual(ts.allowed_hosts, [])
self.assertEqual(ts.allowed_origins, [])
finally:
os.unlink(temp_file_path)

def test_transport_security_integration_with_mcp_library(self):
"""Test that transport security config integrates with MCP library."""
from mcp.server.transport_security import TransportSecuritySettings

# Create config with transport security enabled
ts_config = TransportSecurityConfig(
enable_dns_rebinding_protection=True,
allowed_hosts=["localhost:*", "127.0.0.1:*"],
allowed_origins=["http://localhost:*"],
)

# Convert to MCP library's TransportSecuritySettings
ts_settings = TransportSecuritySettings(
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
allowed_hosts=ts_config.allowed_hosts,
allowed_origins=ts_config.allowed_origins,
)

# Verify the settings are correctly transferred
self.assertTrue(ts_settings.enable_dns_rebinding_protection)
self.assertEqual(ts_settings.allowed_hosts, ["localhost:*", "127.0.0.1:*"])
self.assertEqual(ts_settings.allowed_origins, ["http://localhost:*"])

def test_transport_security_partial_config(self):
"""Test transport security with partial configuration."""
config_data = {
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
"mcp": {
"transports": ["streamable-http"],
"transport_security": {
"enable_dns_rebinding_protection": True,
# Only specifying allowed_hosts, not allowed_origins
"allowed_hosts": ["localhost:*"],
},
},
}

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
yaml.dump(config_data, temp_file)
temp_file_path = temp_file.name

try:
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
config = Config()

ts = config.mcp.transport_security
self.assertTrue(ts.enable_dns_rebinding_protection)
self.assertEqual(ts.allowed_hosts, ["localhost:*"])
# allowed_origins should default to empty list
self.assertEqual(ts.allowed_origins, [])
finally:
os.unlink(temp_file_path)

def test_transport_security_wildcard_patterns(self):
"""Test various wildcard patterns for hosts and origins."""
ts_config = TransportSecurityConfig(
enable_dns_rebinding_protection=True,
allowed_hosts=[
"localhost:*",
"127.0.0.1:*",
"192.168.1.100:*",
"my-gateway.example.com:*",
"internal-service:8080", # Specific port
],
allowed_origins=[
"http://localhost:*",
"https://localhost:*",
"http://127.0.0.1:*",
"https://my-gateway.example.com:*",
"http://internal-service:8080", # Specific port
],
)

# Verify all patterns are stored correctly
self.assertEqual(len(ts_config.allowed_hosts), 5)
self.assertEqual(len(ts_config.allowed_origins), 5)
self.assertIn("localhost:*", ts_config.allowed_hosts)
self.assertIn("internal-service:8080", ts_config.allowed_hosts)
self.assertIn("http://localhost:*", ts_config.allowed_origins)
self.assertIn("https://localhost:*", ts_config.allowed_origins)


class TestAppTransportSecurityIntegration(unittest.TestCase):
"""Test app.py integration with transport security settings."""

def test_app_run_configures_transport_security(self):
"""Test that app.run() correctly configures transport security."""
from mcp.server.transport_security import TransportSecuritySettings

from spark_history_mcp.core.app import mcp

config_data = {
"servers": {"local": {"url": "http://localhost:18080", "default": True}},
"mcp": {
"transports": ["streamable-http"],
"port": "18888",
"address": "localhost",
"debug": False,
"transport_security": {
"enable_dns_rebinding_protection": True,
"allowed_hosts": ["localhost:*", "test-gateway:*"],
"allowed_origins": ["http://localhost:*"],
},
},
}

with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
yaml.dump(config_data, temp_file)
temp_file_path = temp_file.name

try:
with patch.dict(os.environ, {"SHS_MCP_CONFIG": temp_file_path}):
config = Config()

# Manually apply the transport security settings as run() would
if config.mcp.transport_security:
ts_config = config.mcp.transport_security
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=ts_config.enable_dns_rebinding_protection,
allowed_hosts=ts_config.allowed_hosts,
allowed_origins=ts_config.allowed_origins,
)

# Verify settings were applied
ts = mcp.settings.transport_security
self.assertIsNotNone(ts)
self.assertTrue(ts.enable_dns_rebinding_protection)
self.assertEqual(ts.allowed_hosts, ["localhost:*", "test-gateway:*"])
self.assertEqual(ts.allowed_origins, ["http://localhost:*"])
finally:
os.unlink(temp_file_path)
# Reset to None to avoid affecting other tests
mcp.settings.transport_security = None