diff --git a/examples/python/api/secondary-market-research-api.py b/examples/python/api/secondary-market-research-api.py index c2be4f6d3..3ed3f31e1 100644 --- a/examples/python/api/secondary-market-research-api.py +++ b/examples/python/api/secondary-market-research-api.py @@ -43,14 +43,34 @@ redoc_url="/redoc" ) -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# Add CORS middleware with secure configuration +cors_origins = os.getenv("API_CORS_ORIGINS", "").split(",") +cors_origins = [origin.strip() for origin in cors_origins if origin.strip() and origin.strip() != "*"] + +# Default secure origins if none specified +if not cors_origins: + # Secure defaults for different environments + if os.getenv("ENVIRONMENT") == "production": + # In production, require explicit configuration + cors_origins = [] + else: + # Development defaults - restrict to local origins + cors_origins = [ + "http://localhost:3000", # Development frontend + "http://localhost:8000", # Local development + "http://127.0.0.1:3000", # Local development + "http://127.0.0.1:8000", # Local development + ] + +# Only add CORS middleware if origins are specified +if cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "Origin", "Accept"], + ) # Create directories for storing reports REPORTS_DIR = Path("generated_reports") diff --git a/src/praisonai-agents/praisonaiagents/app/config.py b/src/praisonai-agents/praisonaiagents/app/config.py index 3c1dcc84e..eefd477ce 100644 --- a/src/praisonai-agents/praisonaiagents/app/config.py +++ b/src/praisonai-agents/praisonaiagents/app/config.py @@ -21,7 +21,7 @@ class AgentAppConfig: host: Host address to bind to (default: "0.0.0.0") port: Port number to listen on (default: 8000) reload: Enable auto-reload for development (default: False) - cors_origins: List of allowed CORS origins (default: ["*"]) + cors_origins: List of allowed CORS origins (default: []) api_prefix: API route prefix (default: "/api") docs_url: URL for API documentation (default: "/docs") openapi_url: URL for OpenAPI schema (default: "/openapi.json") @@ -44,7 +44,7 @@ class AgentAppConfig: host: str = "0.0.0.0" port: int = 8000 reload: bool = False - cors_origins: List[str] = field(default_factory=lambda: ["*"]) + cors_origins: List[str] = field(default_factory=lambda: []) api_prefix: str = "/api" docs_url: str = "/docs" openapi_url: str = "/openapi.json" diff --git a/src/praisonai-agents/praisonaiagents/gateway/config.py b/src/praisonai-agents/praisonaiagents/gateway/config.py index b5d128192..f8d737522 100644 --- a/src/praisonai-agents/praisonaiagents/gateway/config.py +++ b/src/praisonai-agents/praisonaiagents/gateway/config.py @@ -57,7 +57,7 @@ class GatewayConfig: host: str = "127.0.0.1" port: int = 8765 - cors_origins: List[str] = field(default_factory=lambda: ["*"]) + cors_origins: List[str] = field(default_factory=lambda: []) auth_token: Optional[str] = None max_connections: int = 1000 max_sessions_per_agent: int = 0 # 0 = unlimited @@ -204,7 +204,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "MultiChannelGatewayConfig": gateway_config = GatewayConfig( host=gw_data.get("host", "127.0.0.1"), port=gw_data.get("port", 8765), - cors_origins=gw_data.get("cors_origins", ["*"]), + cors_origins=gw_data.get("cors_origins", []), auth_token=gw_data.get("auth_token"), max_connections=gw_data.get("max_connections", 1000), ) diff --git a/src/praisonai-agents/praisonaiagents/server/server.py b/src/praisonai-agents/praisonaiagents/server/server.py index 3a8987687..6eb1b6cb3 100644 --- a/src/praisonai-agents/praisonaiagents/server/server.py +++ b/src/praisonai-agents/praisonaiagents/server/server.py @@ -27,7 +27,7 @@ class ServerConfig: host: str = DEFAULT_HOST port: int = DEFAULT_PORT - cors_origins: List[str] = field(default_factory=lambda: ["*"]) + cors_origins: List[str] = field(default_factory=lambda: []) auth_token: Optional[str] = None max_connections: int = 100 @@ -200,8 +200,8 @@ async def info(request): app = CORSMiddleware( app, allow_origins=self.config.cors_origins, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "Origin", "Accept"], ) return app diff --git a/src/praisonai-agents/tests/unit/server/test_server.py b/src/praisonai-agents/tests/unit/server/test_server.py index 725d4052a..c93607010 100644 --- a/src/praisonai-agents/tests/unit/server/test_server.py +++ b/src/praisonai-agents/tests/unit/server/test_server.py @@ -22,7 +22,7 @@ def test_config_defaults(self): assert config.host == "127.0.0.1" assert config.port == 8765 - assert config.cors_origins == ["*"] + assert config.cors_origins == [] assert config.auth_token is None def test_config_custom(self): diff --git a/src/praisonai-agents/tests/unit/test_gateway_config.py b/src/praisonai-agents/tests/unit/test_gateway_config.py index c9754bcdd..e591e3c72 100644 --- a/src/praisonai-agents/tests/unit/test_gateway_config.py +++ b/src/praisonai-agents/tests/unit/test_gateway_config.py @@ -53,7 +53,7 @@ def test_gateway_config_defaults(self): config = GatewayConfig() assert config.host == "127.0.0.1" assert config.port == 8765 - assert config.cors_origins == ["*"] + assert config.cors_origins == [] assert config.auth_token is None assert config.max_connections == 1000 assert config.max_sessions_per_agent == 0 diff --git a/src/praisonai/praisonai/browser/server.py b/src/praisonai/praisonai/browser/server.py index ab84aa33e..4f8517b2d 100644 --- a/src/praisonai/praisonai/browser/server.py +++ b/src/praisonai/praisonai/browser/server.py @@ -7,6 +7,7 @@ import logging import signal import sys +import os from typing import Dict, Optional, Set from dataclasses import dataclass @@ -82,13 +83,37 @@ def _get_app(self): version="1.0.0", ) - # Enable CORS for extension + # Configure CORS origins based on environment + cors_origins = os.getenv("BROWSER_CORS_ORIGINS", "").split(",") + cors_origins = [origin.strip() for origin in cors_origins if origin.strip() and origin.strip() != "*"] + + # Default secure origins if none specified + if not cors_origins: + # Environment-specific defaults for security + if os.getenv("ENVIRONMENT") == "production": + # In production, require explicit configuration + cors_origins = [] + else: + # Development defaults - restrict to local origins + cors_origins = [ + "http://localhost:3000", # Development frontend + "http://localhost:8000", # Local development + "http://127.0.0.1:3000", # Local development + "http://127.0.0.1:8000", # Local development + ] + + # Enable CORS for extension with secure origins. + # allow_origin_regex enables Chrome extension support since extension IDs + # (chrome-extension://<32-char-id>) cannot be listed as exact strings + # and the glob pattern chrome-extension://* is NOT supported by CORSMiddleware. + # Set BROWSER_CORS_ORIGINS to restrict to a specific extension ID. app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=cors_origins, + allow_origin_regex=r"chrome-extension://[a-z0-9]{32}", allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "Origin", "Accept"], ) @app.get("/health") @@ -113,20 +138,32 @@ async def _handle_connection(self, websocket): import time import uuid import os + import re + + # Use same CORS origins configuration for WebSocket validation + cors_origins = os.getenv("BROWSER_CORS_ORIGINS", "").split(",") + cors_origins = [origin.strip() for origin in cors_origins if origin.strip() and origin.strip() != "*"] + + if not cors_origins: + if os.getenv("ENVIRONMENT") == "production": + cors_origins = [] + else: + cors_origins = [ + "http://localhost:3000", "http://localhost:8000", + "http://127.0.0.1:3000", "http://127.0.0.1:8000" + ] origin = websocket.headers.get("origin") - allowed_origins = os.environ.get("ALLOWED_ORIGINS", "").split(",") if origin: import urllib.parse parsed_origin = urllib.parse.urlparse(origin) is_allowed = False - if parsed_origin.scheme in ("http", "https") and parsed_origin.hostname in ("localhost", "127.0.0.1"): - is_allowed = True - elif parsed_origin.scheme == "chrome-extension": + # Check exact origin matches + if origin in cors_origins: is_allowed = True - - if any(origin == allowed.strip() for allowed in allowed_origins if allowed.strip()): + # Check chrome extension regex pattern (same as CORS middleware) + elif parsed_origin.scheme == "chrome-extension" and re.match(r"chrome-extension://[a-z0-9]{32}", origin): is_allowed = True if not is_allowed: diff --git a/src/praisonai/praisonai/jobs/server.py b/src/praisonai/praisonai/jobs/server.py index 32e3dd43a..e30d5221e 100644 --- a/src/praisonai/praisonai/jobs/server.py +++ b/src/praisonai/praisonai/jobs/server.py @@ -86,15 +86,36 @@ def create_app( lifespan=lifespan ) - # Add CORS middleware - origins = cors_origins or ["*"] - app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) + # Add CORS middleware with secure defaults + if cors_origins is None: + # Default secure origins based on environment + default_origins = os.getenv("JOBS_CORS_ORIGINS", "").split(",") + default_origins = [origin.strip() for origin in default_origins if origin.strip() and origin.strip() != "*"] + + if not default_origins: + # Secure defaults for different environments + if os.getenv("ENVIRONMENT") == "production": + origins = [] # No origins allowed in production without explicit config + else: + origins = [ + "http://localhost:3000", # Development frontend + "http://localhost:8000", # Local development + "http://127.0.0.1:3000", # Local development + "http://127.0.0.1:8000", # Local development + ] + else: + origins = default_origins + else: + origins = cors_origins + + if origins: # Only add CORS middleware if origins are specified + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Authorization", "Content-Type", "Origin", "Accept", "Idempotency-Key"], + ) # Add jobs router jobs_router = create_router(get_store(), get_executor()) diff --git a/src/praisonai/praisonai/mcp_server/transports/http_stream.py b/src/praisonai/praisonai/mcp_server/transports/http_stream.py index d0096390a..9211a401e 100644 --- a/src/praisonai/praisonai/mcp_server/transports/http_stream.py +++ b/src/praisonai/praisonai/mcp_server/transports/http_stream.py @@ -71,7 +71,23 @@ def __init__( self.host = host self.port = port self.endpoint = endpoint - self.cors_origins = cors_origins or ["*"] + # Environment-aware CORS origins for security + if cors_origins is None: + import os + if os.getenv("ENVIRONMENT") == "production": + # In production, require explicit configuration + self.cors_origins = [] + else: + # Development defaults - restrict to local origins + self.cors_origins = [ + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:8000", + "http://127.0.0.1:8000" + ] + else: + # Validate provided origins to reject wildcards + self.cors_origins = [origin for origin in cors_origins if origin != "*"] self.api_key = api_key self.session_ttl = session_ttl self.allow_client_termination = allow_client_termination @@ -352,7 +368,7 @@ async def root(request: Request) -> Response: CORSMiddleware, allow_origins=self.cors_origins, allow_methods=["GET", "POST", "DELETE", "OPTIONS"], - allow_headers=["*"], + allow_headers=["Authorization", "Content-Type", "Origin", "Accept", "Mcp-Session-Id", "Last-Event-Id"], ), ] diff --git a/src/praisonai/praisonai/recipe/serve.py b/src/praisonai/praisonai/recipe/serve.py index 03ac5371a..dcda7bb65 100644 --- a/src/praisonai/praisonai/recipe/serve.py +++ b/src/praisonai/praisonai/recipe/serve.py @@ -31,7 +31,7 @@ - my-recipe - another-recipe preload: true -cors_origins: "*" +cors_origins: "http://localhost:3000,http://localhost:8000" rate_limit: 100 # requests per minute (0 = disabled) max_request_size: 10485760 # 10MB default enable_metrics: false # Enable /metrics endpoint