diff --git a/.env.example b/.env.example index b72f5db..d80f983 100644 --- a/.env.example +++ b/.env.example @@ -21,10 +21,23 @@ REDIS_URL=redis://localhost:6379/0 REDIS_MAX_CONNECTIONS=10 # =================== -# Security +# Authentication # =================== -# Network-based security: API is only accessible from internal Docker network -# No API keys required - all external requests are rejected at network level +# AUTH_MODE: none (dev), psk (Docker Compose), jwt (portable/production) +AUTH_MODE=none + +# PSK Mode: Pre-shared key tokens (one per scope) +# Each token grants access to specific API operations +# AUTH_TOKEN_SUBMIT=your-submit-token-here # Scope: lens:submit +# AUTH_TOKEN_READ=your-read-token-here # Scope: lens:read +# AUTH_TOKEN_ADMIN=your-admin-token-here # Scope: lens:admin (includes all) + +# JWT Mode: Validate JWT tokens with external identity provider +# AUTH_JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY----- +# AUTH_JWT_JWKS_URL=https://your-idp.com/.well-known/jwks.json +# AUTH_JWT_ISSUER=https://your-idp.com +# AUTH_JWT_AUDIENCE=lens-api +# AUTH_JWT_SCOPE_CLAIM=scope # =================== # Rate Limiting @@ -71,47 +84,10 @@ HYPERLIQUID_WS_URL=wss://api.hyperliquid.xyz/ws # =================== # AI Models # =================== -# Comma-separated list of enabled models (chatgpt, gemini, claude, deepseek) -AI_MODELS=chatgpt,gemini,claude,deepseek - -# REQUIRED: Must explicitly choose AI mode - no default to prevent accidents # Set to true for production to use real AI models # Set to false for development/testing (returns deterministic stub decisions) -# Application will fail to start if not explicitly set USE_REAL_AI=false -# ChatGPT Configuration -MODEL_CHATGPT_PROVIDER=openai -MODEL_CHATGPT_API_KEY= -MODEL_CHATGPT_MODEL_ID=gpt-4o -MODEL_CHATGPT_TIMEOUT_MS=30000 -MODEL_CHATGPT_MAX_TOKENS=1000 -MODEL_CHATGPT_PROMPT_PATH=prompts/chatgpt_wrapper_v1.md - -# Gemini Configuration -MODEL_GEMINI_PROVIDER=google 
-MODEL_GEMINI_API_KEY= -MODEL_GEMINI_MODEL_ID=gemini-1.5-pro -MODEL_GEMINI_TIMEOUT_MS=30000 -MODEL_GEMINI_MAX_TOKENS=1000 -MODEL_GEMINI_PROMPT_PATH=prompts/gemini_wrapper_v1.md - -# Claude Configuration -MODEL_CLAUDE_PROVIDER=anthropic -MODEL_CLAUDE_API_KEY= -MODEL_CLAUDE_MODEL_ID=claude-sonnet-4-20250514 -MODEL_CLAUDE_TIMEOUT_MS=30000 -MODEL_CLAUDE_MAX_TOKENS=1000 -MODEL_CLAUDE_PROMPT_PATH=prompts/claude_wrapper_v1.md - -# DeepSeek Configuration -MODEL_DEEPSEEK_PROVIDER=deepseek -MODEL_DEEPSEEK_API_KEY= -MODEL_DEEPSEEK_MODEL_ID=deepseek-chat -MODEL_DEEPSEEK_TIMEOUT_MS=30000 -MODEL_DEEPSEEK_MAX_TOKENS=1000 -MODEL_DEEPSEEK_PROMPT_PATH=prompts/deepseek_wrapper_v1.md - # =================== # WebSocket # =================== diff --git a/README.md b/README.md index 73aaa88..c620b4c 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ SigmaPilot Lens analyzes trading signals in real-time using multiple AI models a - **Multi-Model AI Consensus** — Get perspectives from 4 different AI providers simultaneously - **Real-Time Enrichment** — Live market data from Hyperliquid with TA indicators - **Signal Validation** — Automatic rejection of stale or price-drifted signals +- **Runtime Configuration** — Manage LLM API keys and AI prompts via API without restarts - **Production Ready** — Load tested, observable, with comprehensive failure handling ## Quick Start @@ -25,15 +26,57 @@ SigmaPilot Lens analyzes trading signals in real-time using multiple AI models a ```bash # Setup cp .env.example .env -# Add your AI API keys to .env +# Edit .env: set AUTH_MODE and configure tokens # Run make build && make up && make migrate +# Configure AI models via API (requires admin token) +curl -X PUT http://localhost:8000/api/v1/llm-configs/chatgpt \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{"api_key": "sk-...", "enabled": true}' + # Verify make health ``` +## Authentication + +SigmaPilot Lens supports 3 authentication modes: + +| Mode | Use Case | 
Configuration | +|------|----------|---------------| +| `none` | Development | No auth required | +| `psk` | Docker Compose | Pre-shared key tokens | +| `jwt` | Production | External identity provider | + +### Scopes + +| Scope | Access | +|-------|--------| +| `lens:submit` | Submit signals | +| `lens:read` | Read events, decisions, DLQ | +| `lens:admin` | Admin operations (LLM configs, prompts, DLQ retry) + all above | + +### PSK Mode Example + +```bash +# .env +AUTH_MODE=psk +AUTH_TOKEN_SUBMIT=<your-submit-token> +AUTH_TOKEN_READ=<your-read-token> +AUTH_TOKEN_ADMIN=<your-admin-token> + +# Usage +curl -X POST http://localhost:8000/api/v1/signals \ + -H "Authorization: Bearer <your-submit-token>" \ + -H "Content-Type: application/json" \ + -d '{"symbol": "BTC-PERP", ...}' +``` + +See [Configuration Guide](docs/configuration.md#authentication) for full details. + ## Documentation | Guide | Description | diff --git a/config/feature_profiles.yaml b/config/feature_profiles.yaml index 1146b30..ae77068 100644 --- a/config/feature_profiles.yaml +++ b/config/feature_profiles.yaml @@ -1,5 +1,16 @@ # Feature Profile Configuration # Defines what market data and technical indicators to compute for each profile +# +# Available indicators: +# - ema: Exponential Moving Average (periods: list of integers) +# - sma: Simple Moving Average (periods: list of integers) +# - macd: Moving Average Convergence Divergence (fast, slow, signal) +# - rsi: Relative Strength Index (period) +# - atr: Average True Range (period) +# - bollinger: Bollinger Bands (period, std_dev) - includes BBW and rating +# - adx: Average Directional Index (period) - trend strength +# - stochastic: Stochastic Oscillator (k_period, d_period) +# - volume: Volume metrics (includes current volume and SMA20) trend_follow_v1: description: "Minimal trend-following indicators" @@ -11,6 +22,9 @@ trend_follow_v1: - name: ema params: periods: [9, 21, 50] + - name: sma + params: + periods: [20] - name: macd params: fast: 12 @@ -22,6 +36,18 @@ trend_follow_v1: - name: atr params: period: 14 + - 
name: bollinger + params: + period: 20 + std_dev: 2.0 + - name: adx + params: + period: 14 + - name: stochastic + params: + k_period: 14 + d_period: 3 + - name: volume market_data: - mid_price - spread_bps diff --git a/docker-compose.yml b/docker-compose.yml index 6bb30e7..b20af6b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,16 @@ services: - LOG_FORMAT=json - RATE_LIMIT_PER_MIN=${RATE_LIMIT_PER_MIN:-60} - RATE_LIMIT_BURST=${RATE_LIMIT_BURST:-120} + # Authentication + - AUTH_MODE=${AUTH_MODE:-none} + - AUTH_TOKEN_SUBMIT=${AUTH_TOKEN_SUBMIT:-} + - AUTH_TOKEN_READ=${AUTH_TOKEN_READ:-} + - AUTH_TOKEN_ADMIN=${AUTH_TOKEN_ADMIN:-} + - AUTH_JWT_PUBLIC_KEY=${AUTH_JWT_PUBLIC_KEY:-} + - AUTH_JWT_JWKS_URL=${AUTH_JWT_JWKS_URL:-} + - AUTH_JWT_ISSUER=${AUTH_JWT_ISSUER:-} + - AUTH_JWT_AUDIENCE=${AUTH_JWT_AUDIENCE:-} + - AUTH_JWT_SCOPE_CLAIM=${AUTH_JWT_SCOPE_CLAIM:-scope} depends_on: redis: condition: service_healthy diff --git a/docs/api-reference.md b/docs/api-reference.md index 83ff4ec..452fce2 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -2,13 +2,70 @@ Base URL: `http://gateway:8000/api/v1` (from within Docker network) -## Security +## Authentication -**Network-Level Security**: All endpoints are protected by network isolation. Only requests from within the Docker network (`lens-network`) are accepted. 
+SigmaPilot Lens supports 3 authentication modes controlled by `AUTH_MODE`: + +| Mode | Description | Use Case | +|------|-------------|----------| +| `none` | No authentication required | Development | +| `psk` | Pre-shared key tokens | Docker Compose deployments | +| `jwt` | JWT validation | Production with external IdP | + +### Authorization Header + +All protected endpoints accept a Bearer token: +``` +Authorization: Bearer <token> +``` + +### Scopes + +| Scope | Description | Required For | +|-------|-------------|--------------| +| `lens:submit` | Submit trading signals | `POST /signals` | +| `lens:read` | Read events, decisions, DLQ | `GET /events/*`, `GET /decisions/*`, `GET /dlq/*`, `WS /ws/stream` | +| `lens:admin` | Administrative operations | `POST /dlq/*/retry`, `POST /dlq/*/resolve`, `/llm-configs/*`, `/prompts/*` | + +> **Note**: The `lens:admin` scope includes all other scopes. + +### WebSocket Authentication + +WebSocket connections use the `Sec-WebSocket-Protocol` header: +``` +Sec-WebSocket-Protocol: bearer, <token> +``` + +### Error Responses + +`401 Unauthorized` - No valid token provided: +```json +{ + "error": { + "code": "UNAUTHORIZED", + "message": "Authentication required" + } +} +``` + +`403 Forbidden` - Insufficient permissions: +```json +{ + "error": { + "code": "FORBIDDEN", + "message": "Insufficient permissions. Required scope: lens:admin" + } +} +``` + +See [Configuration Guide](configuration.md#authentication) for setup details. + +## Network Security + +In addition to authentication, all endpoints are protected by network isolation. Only requests from within the Docker network (`lens-network`) are accepted. -- No API keys required - External requests are rejected with `403 Forbidden` -- Health endpoints (`/health`, `/ready`) are accessible for Docker health checks +- Health endpoints (`/health`, `/ready`) are always accessible --- @@ -22,6 +79,8 @@ Submit a trading signal for analysis. 
POST /signals ``` +**Scope**: `lens:submit` + **Headers**: | Header | Required | Description | |--------|----------|-------------| @@ -99,6 +158,8 @@ Retrieve a list of submitted events. GET /events ``` +**Scope**: `lens:read` + **Query Parameters**: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| @@ -141,6 +202,8 @@ Retrieve a specific event with full details. GET /events/{event_id} ``` +**Scope**: `lens:read` + **Response** `200 OK`: ```json { @@ -204,6 +267,8 @@ Get the current processing status of an event. GET /events/{event_id}/status ``` +**Scope**: `lens:read` + **Response** `200 OK`: ```json { @@ -226,6 +291,8 @@ Query AI model decisions. GET /decisions ``` +**Scope**: `lens:read` + **Query Parameters**: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| @@ -272,6 +339,8 @@ Get a specific decision by ID. GET /decisions/{decision_id} ``` +**Scope**: `lens:read` + **Response** `200 OK`: ```json { @@ -370,6 +439,15 @@ From within the Docker network: ws://gateway:8000/api/v1/ws/stream ``` +**Scope**: `lens:read` + +**Authentication**: When `AUTH_MODE` is `psk` or `jwt`, authenticate via the `Sec-WebSocket-Protocol` header: +``` +Sec-WebSocket-Protocol: bearer, <token> +``` + +The server echoes back `bearer` in the response protocol if authentication succeeds. + ### Messages **Subscribe**: @@ -428,6 +506,411 @@ ws://gateway:8000/api/v1/ws/stream --- +## LLM Configuration + +Manage LLM provider configurations at runtime. Allows updating API keys, enabling/disabling models, and testing connections without container restarts. + +> **Note**: All LLM configuration endpoints require `lens:admin` scope. 
+ +### List LLM Configurations + +``` +GET /llm-configs +``` + +**Scope**: `lens:admin` + +**Response** `200 OK`: +```json +{ + "items": [ + { + "model_name": "chatgpt", + "provider": "openai", + "model_id": "gpt-4o", + "enabled": true, + "timeout_ms": 30000, + "max_tokens": 1000, + "prompt_path": null, + "api_key_masked": "****sk12", + "validation_status": "ok", + "last_validated_at": "2024-01-15T10:00:00Z" + } + ], + "total": 1 +} +``` + +### Get LLM Configuration + +``` +GET /llm-configs/{model_name} +``` + +**Scope**: `lens:admin` + +**Path Parameters**: +| Parameter | Description | +|-----------|-------------| +| `model_name` | Model name (chatgpt, gemini, claude, deepseek) | + +**Response** `200 OK`: +```json +{ + "model_name": "chatgpt", + "provider": "openai", + "model_id": "gpt-4o", + "enabled": true, + "timeout_ms": 30000, + "max_tokens": 1000, + "prompt_path": null, + "api_key_masked": "****sk12", + "validation_status": "ok", + "last_validated_at": "2024-01-15T10:00:00Z" +} +``` + +### Create or Update LLM Configuration + +``` +PUT /llm-configs/{model_name} +``` + +**Scope**: `lens:admin` + +**Request Body**: +```json +{ + "api_key": "sk-your-api-key", + "model_id": "gpt-4o", + "enabled": true, + "timeout_ms": 30000, + "max_tokens": 1000 +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `api_key` | string | **Yes** | API key for the provider | +| `model_id` | string | No | Model identifier (uses default if not specified) | +| `enabled` | boolean | No | Whether model is enabled (default: true) | +| `timeout_ms` | int | No | Request timeout in ms (default: 30000) | +| `max_tokens` | int | No | Max response tokens (default: 1000) | + +**Note**: The `provider` is automatically determined by the model name and cannot be changed. 
+ +**Response** `200 OK`: +```json +{ + "model_name": "chatgpt", + "provider": "openai", + "model_id": "gpt-4o", + "enabled": true, + "timeout_ms": 30000, + "max_tokens": 1000, + "prompt_path": null, + "api_key_masked": "****sk12" +} +``` + +### Partial Update LLM Configuration + +``` +PATCH /llm-configs/{model_name} +``` + +**Scope**: `lens:admin` + +Update specific fields without providing all values. + +**Request Body**: +```json +{ + "enabled": false +} +``` + +### Delete LLM Configuration + +``` +DELETE /llm-configs/{model_name} +``` + +**Scope**: `lens:admin` + +**Response** `200 OK`: +```json +{ + "status": "deleted", + "model_name": "chatgpt" +} +``` + +### Test API Key + +Test if an API key is valid by making a minimal API call. + +``` +POST /llm-configs/{model_name}/test +``` + +**Scope**: `lens:admin` + +**Response** `200 OK`: +```json +{ + "model_name": "chatgpt", + "success": true, + "message": "API key is valid", + "latency_ms": 1250 +} +``` + +**Response** `200 OK` (failed test): +```json +{ + "model_name": "chatgpt", + "success": false, + "message": "401 Unauthorized: Invalid API key", + "latency_ms": 0 +} +``` + +### Enable/Disable Model + +Quick shortcuts to enable or disable a model. + +``` +POST /llm-configs/{model_name}/enable +POST /llm-configs/{model_name}/disable +``` + +**Scope**: `lens:admin` + +--- + +## Prompt Management + +Manage AI prompts at runtime. Prompts use a **core + wrapper** pattern: +- **Core prompts**: Shared decision logic (e.g., `core_decision`) +- **Wrapper prompts**: Provider-specific formatting (e.g., `chatgpt_wrapper`, `claude_wrapper`) + +> **Note**: All prompt management endpoints require `lens:admin` scope. 
+ +### List Prompts + +``` +GET /prompts +``` + +**Scope**: `lens:admin` + +**Query Parameters**: +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prompt_type` | string | - | Filter by type (`core` or `wrapper`) | +| `include_inactive` | boolean | `false` | Include inactive prompts | + +**Response** `200 OK`: +```json +{ + "items": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "core_decision", + "version": "v1", + "prompt_type": "core", + "model_name": null, + "content": "# Trading Signal Decision Framework...", + "content_hash": "a1b2c3d4...", + "is_active": true, + "description": "Core decision prompt for trading signals", + "created_at": "2024-01-15T10:00:00Z" + } + ], + "total": 5 +} +``` + +### Get Available Prompts + +Get a summary of available prompt versions grouped by type. + +``` +GET /prompts/available +``` + +**Scope**: `lens:admin` + +**Response** `200 OK`: +```json +{ + "core_versions": ["v1", "v2"], + "wrappers": { + "chatgpt": ["v1"], + "gemini": ["v1"], + "claude": ["v1"], + "deepseek": ["v1"] + } +} +``` + +### Get Prompt + +``` +GET /prompts/{name}/{version} +``` + +**Scope**: `lens:admin` + +**Path Parameters**: +| Parameter | Description | +|-----------|-------------| +| `name` | Prompt name (e.g., `core_decision`, `chatgpt_wrapper`) | +| `version` | Version string (e.g., `v1`, `v2`) | + +**Response** `200 OK`: +```json +{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "core_decision", + "version": "v1", + "prompt_type": "core", + "model_name": null, + "content": "# Trading Signal Decision Framework...", + "content_hash": "a1b2c3d4...", + "is_active": true, + "description": "Core decision prompt for trading signals", + "created_at": "2024-01-15T10:00:00Z" +} +``` + +### Create Prompt + +``` +POST /prompts +``` + +**Scope**: `lens:admin` + +**Request Body**: +```json +{ + "name": "core_decision", + "version": "v2", + "prompt_type": "core", + "content": "# Trading Signal 
Decision Framework v2...", + "description": "Updated core decision prompt" +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | **Yes** | Prompt name | +| `version` | string | **Yes** | Version string | +| `prompt_type` | string | **Yes** | `core` or `wrapper` | +| `content` | string | **Yes** | Prompt content | +| `model_name` | string | Conditional | Required for wrapper prompts | +| `description` | string | No | Optional description | + +**Response** `201 Created`: +```json +{ + "id": "550e8400-e29b-41d4-a716-446655440001", + "name": "core_decision", + "version": "v2", + "prompt_type": "core", + "model_name": null, + "content": "# Trading Signal Decision Framework v2...", + "content_hash": "e5f6g7h8...", + "is_active": true, + "description": "Updated core decision prompt", + "created_at": "2024-01-15T12:00:00Z" +} +``` + +### Update Prompt + +``` +PUT /prompts/{name}/{version} +``` + +**Scope**: `lens:admin` + +**Request Body**: +```json +{ + "content": "# Updated prompt content...", + "description": "Updated description" +} +``` + +### Partial Update Prompt + +``` +PATCH /prompts/{name}/{version} +``` + +**Scope**: `lens:admin` + +Partially update a prompt (e.g., activate/deactivate). + +**Request Body**: +```json +{ + "is_active": false +} +``` + +### Delete Prompt + +``` +DELETE /prompts/{name}/{version} +``` + +**Scope**: `lens:admin` + +**Response** `204 No Content` + +### Render Prompt + +Render a complete prompt with enriched event data and constraints. Useful for testing prompts before deployment. 
+ +``` +POST /prompts/render +``` + +**Scope**: `lens:admin` + +**Request Body**: +```json +{ + "model_name": "chatgpt", + "enriched_event": { + "symbol": "BTC", + "signal_direction": "long", + "entry_price": 42000.50 + }, + "constraints": { + "max_position_size_pct": 25, + "min_hold_minutes": 15 + }, + "core_version": "v1", + "wrapper_version": "v1" +} +``` + +**Response** `200 OK`: +```json +{ + "rendered_prompt": "# Trading Signal Decision...\n\n## Event Data\n{...}", + "prompt_version": "chatgpt_v1_core_v1", + "prompt_hash": "a1b2c3d4e5f6..." +} +``` + +--- + ## Dead Letter Queue (DLQ) ### List DLQ Entries @@ -438,6 +921,8 @@ Query failed processing entries. GET /dlq ``` +**Scope**: `lens:read` + **Query Parameters**: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| @@ -479,6 +964,8 @@ Get full details of a DLQ entry including payload. GET /dlq/{dlq_id} ``` +**Scope**: `lens:read` + **Response** `200 OK`: ```json { @@ -508,6 +995,8 @@ Re-enqueue a failed entry for processing. POST /dlq/{dlq_id}/retry ``` +**Scope**: `lens:admin` + **Response** `200 OK`: ```json { @@ -533,6 +1022,8 @@ Mark an entry as manually resolved. 
POST /dlq/{dlq_id}/resolve ``` +**Scope**: `lens:admin` + **Request Body**: ```json { @@ -556,7 +1047,8 @@ POST /dlq/{dlq_id}/resolve | Code | HTTP Status | Description | |------|-------------|-------------| | `VALIDATION_ERROR` | 400 | Invalid request body or parameters | -| `FORBIDDEN` | 403 | Access denied (external network) | +| `UNAUTHORIZED` | 401 | No valid token provided | +| `FORBIDDEN` | 403 | Insufficient permissions or external network access | | `NOT_FOUND` | 404 | Resource not found | | `RATE_LIMITED` | 429 | Rate limit exceeded | | `INTERNAL_ERROR` | 500 | Internal server error | diff --git a/docs/architecture.md b/docs/architecture.md index 48ba8cb..b9af984 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -42,23 +42,29 @@ SigmaPilot Lens follows an event-driven architecture with clear separation of co - FastAPI application with async request handling - Schema validation using Pydantic -- Network-level security (Docker network isolation) +- 3-mode authentication: `none` / `psk` / `jwt` +- Scope-based authorization (`lens:submit`, `lens:read`, `lens:admin`) - Rate limiting (60 req/min, burst 120) - Assigns `event_id`, `received_ts`, and validates `source` **Endpoints**: -- `POST /api/v1/signals` - Submit trading signal -- `GET /api/v1/events` - List events with filtering -- `GET /api/v1/events/{event_id}` - Get event details with timeline -- `GET /api/v1/events/{event_id}/status` - Get processing status -- `GET /api/v1/decisions` - List AI decisions with filtering -- `GET /api/v1/decisions/{id}` - Get decision details -- `GET /api/v1/dlq` - List dead letter queue entries -- `GET /api/v1/dlq/{id}` - Get DLQ entry details -- `POST /api/v1/dlq/{id}/retry` - Retry failed entry -- `POST /api/v1/dlq/{id}/resolve` - Mark entry as resolved -- `GET /api/v1/health` - Health check -- `GET /api/v1/ready` - Readiness check +- `POST /api/v1/signals` - Submit trading signal (scope: `lens:submit`) +- `GET /api/v1/events` - List events with filtering 
(scope: `lens:read`) +- `GET /api/v1/events/{event_id}` - Get event details with timeline (scope: `lens:read`) +- `GET /api/v1/events/{event_id}/status` - Get processing status (scope: `lens:read`) +- `GET /api/v1/decisions` - List AI decisions with filtering (scope: `lens:read`) +- `GET /api/v1/decisions/{id}` - Get decision details (scope: `lens:read`) +- `GET /api/v1/dlq` - List dead letter queue entries (scope: `lens:read`) +- `GET /api/v1/dlq/{id}` - Get DLQ entry details (scope: `lens:read`) +- `POST /api/v1/dlq/{id}/retry` - Retry failed entry (scope: `lens:admin`) +- `POST /api/v1/dlq/{id}/resolve` - Mark entry as resolved (scope: `lens:admin`) +- `GET /api/v1/llm-configs` - List LLM configurations (scope: `lens:admin`) +- `PUT /api/v1/llm-configs/{model}` - Configure LLM model (scope: `lens:admin`) +- `GET /api/v1/prompts` - List AI prompts (scope: `lens:admin`) +- `POST /api/v1/prompts` - Create AI prompt (scope: `lens:admin`) +- `PUT /api/v1/prompts/{name}/{version}` - Update AI prompt (scope: `lens:admin`) +- `GET /api/v1/health` - Health check (no auth) +- `GET /api/v1/ready` - Readiness check (no auth) - `GET /api/v1/metrics` - Prometheus metrics ### 2. 
Redis Streams Queue @@ -108,6 +114,7 @@ SigmaPilot Lens follows an event-driven architecture with clear separation of co - Strict output schema validation - Token economy controls per model - Fallback decisions for failed evaluations (IGNORE with 0 confidence) +- **Database-backed prompts**: Prompts loaded from PostgreSQL with in-memory caching (5-min TTL) - Prompt versioning and hash tracking for reproducibility **Supported Models**: @@ -126,9 +133,9 @@ SigmaPilot Lens follows an event-driven architecture with clear separation of co **Responsibility**: Real-time decision delivery -- Plain WebSocket server +- Plain WebSocket server (scope: `lens:read`) +- Authentication via `Sec-WebSocket-Protocol: bearer,` - Subscription filters: model, symbol, event_type -- Network-level security (Docker network isolation) - Broadcast to matching subscribers ### 6. PostgreSQL Storage @@ -141,6 +148,8 @@ SigmaPilot Lens follows an event-driven architecture with clear separation of co - `model_decisions` - Per-model decisions - `dlq_entries` - Failed processing records - `processing_timeline` - Status transitions +- `llm_configs` - Runtime LLM configuration (API keys, model IDs, enabled status) +- `prompts` - Versioned AI prompts (core + wrapper pattern) ## Data Flow @@ -173,10 +182,11 @@ SigmaPilot Lens follows an event-driven architecture with clear separation of co 4. AI Evaluation └─▶ Consume from lens:signals:enriched - └─▶ Load prompts for configured models + └─▶ Load prompts from database cache (core + wrapper) + └─▶ Render prompts with enriched data └─▶ Call models in parallel └─▶ Validate output schemas - └─▶ Persist decisions + └─▶ Persist decisions with prompt version/hash └─▶ Trigger publish 5. 
Publishing @@ -208,11 +218,66 @@ services: ## Security Model +### Authentication + +SigmaPilot Lens implements a flexible 3-mode authentication system designed for gradual migration from development to production: + +| Mode | Use Case | Description | +|------|----------|-------------| +| `none` | Development | No authentication required | +| `psk` | Docker Compose | Pre-shared key tokens | +| `jwt` | Production | JWT validation with external IdP | + +### Authorization Scopes + +Scope-based access control with 3 scopes: + +| Scope | Permissions | Endpoints | +|-------|-------------|-----------| +| `lens:submit` | Submit signals | `POST /signals` | +| `lens:read` | Read data | `GET /events/*`, `/decisions/*`, `/dlq/*`, `WS /ws/stream` | +| `lens:admin` | Admin + all above | `POST /dlq/*/retry`, `/dlq/*/resolve`, `/llm-configs/*`, `/prompts/*` | + +**Scope Hierarchy**: `lens:admin` includes all other scopes. + +### PSK Mode (Docker Compose) + +```bash +AUTH_MODE=psk +AUTH_TOKEN_SUBMIT=<submit-token> # Grants lens:submit +AUTH_TOKEN_READ=<read-token> # Grants lens:read +AUTH_TOKEN_ADMIN=<admin-token> # Grants lens:admin (all scopes) +``` + +Usage: `Authorization: Bearer <token>` + +### JWT Mode (Production) + +- Validates signature against JWKS endpoint or public key +- Checks `exp`, `iat` claims +- Reads scopes from configurable claim (default: `scope`) +- Supports RS256, ES256, HS256 algorithms + +```bash +AUTH_MODE=jwt +AUTH_JWT_JWKS_URL=https://idp.example.com/.well-known/jwks.json +AUTH_JWT_ISSUER=https://idp.example.com +AUTH_JWT_AUDIENCE=lens-api +AUTH_JWT_SCOPE_CLAIM=scope +``` + +### WebSocket Authentication + +WebSocket connections authenticate via the `Sec-WebSocket-Protocol` subprotocol header: + +``` +Sec-WebSocket-Protocol: bearer, <token> +``` + ### Network-Level Security - All services run inside isolated Docker network (`lens-network`) - No ports exposed to host machine in production -- External requests rejected at application level -- No API keys required - network isolation provides security +- Health endpoints (`/health`, 
`/ready`) bypass authentication ### Rate Limiting - Per-client rate limiting diff --git a/docs/configuration.md b/docs/configuration.md index 7359c67..89acadd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -39,16 +39,121 @@ DATABASE_URL=postgresql://lens:password@localhost:5432/lens REDIS_URL=redis://:password@localhost:6379/0 ``` -### Security +### Authentication -SigmaPilot Lens uses **network-level security** instead of API keys: +SigmaPilot Lens supports 3 authentication modes, allowing gradual migration from development to production: + +| Variable | Description | Default | Required | +|----------|-------------|---------|----------| +| `AUTH_MODE` | Authentication mode: `none`, `psk`, or `jwt` | `none` | No | + +**Authentication Modes**: + +1. **`none`** (Development): No authentication required. All requests are allowed with full admin access. +2. **`psk`** (Docker Compose): Pre-shared key tokens. Simple setup for internal deployments. +3. **`jwt`** (Portable): JWT validation. Enterprise-ready, integrates with external identity providers. 
+ +#### Scopes + +| Scope | Description | Endpoints | +|-------|-------------|-----------| +| `lens:submit` | Submit signals | `POST /signals` | +| `lens:read` | Read events, decisions, DLQ | `GET /events/*`, `GET /decisions/*`, `GET /dlq/*`, `WS /ws/stream` | +| `lens:admin` | Administrative operations (includes all scopes) | `POST /dlq/*/retry`, `POST /dlq/*/resolve`, `/llm-configs/*`, `/prompts/*` | + +#### PSK Mode Configuration + +Pre-shared key mode uses fixed tokens configured via environment variables: + +| Variable | Description | Granted Scope | +|----------|-------------|---------------| +| `AUTH_TOKEN_SUBMIT` | Token for signal submission | `lens:submit` | +| `AUTH_TOKEN_READ` | Token for read operations | `lens:read` | +| `AUTH_TOKEN_ADMIN` | Token for admin operations | `lens:admin` (all scopes) | + +**Example**: +```bash +AUTH_MODE=psk +AUTH_TOKEN_SUBMIT=submit-secret-token-abc123 +AUTH_TOKEN_READ=read-secret-token-def456 +AUTH_TOKEN_ADMIN=admin-secret-token-ghi789 +``` + +**Usage**: +```bash +# Submit a signal with submit token +curl -X POST http://gateway:8000/api/v1/signals \ + -H "Authorization: Bearer submit-secret-token-abc123" \ + -H "Content-Type: application/json" \ + -d '{"symbol": "BTC-PERP", ...}' + +# Read events with read token +curl http://gateway:8000/api/v1/events \ + -H "Authorization: Bearer read-secret-token-def456" + +# Admin operations with admin token +curl http://gateway:8000/api/v1/llm-configs \ + -H "Authorization: Bearer admin-secret-token-ghi789" +``` + +#### JWT Mode Configuration + +JWT mode validates tokens against external identity providers: + +| Variable | Description | Default | Required | +|----------|-------------|---------|----------| +| `AUTH_JWT_PUBLIC_KEY` | PEM-encoded public key for validation | - | Conditional* | +| `AUTH_JWT_JWKS_URL` | URL to JWKS endpoint | - | Conditional* | +| `AUTH_JWT_ISSUER` | Expected `iss` claim | - | No | +| `AUTH_JWT_AUDIENCE` | Expected `aud` claim | - | No | +| 
`AUTH_JWT_SCOPE_CLAIM` | Claim containing scopes | `scope` | No | + +*Either `AUTH_JWT_PUBLIC_KEY` or `AUTH_JWT_JWKS_URL` is required in JWT mode. + +**Example**: +```bash +AUTH_MODE=jwt +AUTH_JWT_JWKS_URL=https://your-idp.com/.well-known/jwks.json +AUTH_JWT_ISSUER=https://your-idp.com +AUTH_JWT_AUDIENCE=lens-api +AUTH_JWT_SCOPE_CLAIM=scope +``` + +**JWT Requirements**: +- Algorithm: RS256, ES256, or HS256 +- Required claims: `exp`, `iat` +- Scopes must be space-separated in the scope claim: `"scope": "lens:submit lens:read"` + +#### WebSocket Authentication + +WebSocket connections authenticate via the `Sec-WebSocket-Protocol` header: + +``` +Sec-WebSocket-Protocol: bearer, +``` + +The server echoes back `bearer` in the response protocol if authentication succeeds. + +**Example (JavaScript)**: +```javascript +const ws = new WebSocket('ws://gateway:8000/api/v1/ws/stream', ['bearer', 'your-token-here']); + +ws.onopen = () => { + // Protocol will be 'bearer' if auth succeeded + console.log('Connected with protocol:', ws.protocol); +}; +``` + +### Network Security + +In addition to authentication, SigmaPilot Lens uses network-level security: - All services are isolated within the Docker network (`lens-network`) -- No ports are exposed to the host machine +- No ports are exposed to the host machine by default - External requests are rejected at the application level -- No API keys required +- Health check endpoints (`/health`, `/ready`) are always accessible -This is configured automatically in `docker-compose.yml` - no additional configuration needed. +This is configured automatically in `docker-compose.yml`. 
### Rate Limiting @@ -110,66 +215,96 @@ This is configured automatically in `docker-compose.yml` - no additional configu | Variable | Description | Default | Required | |----------|-------------|---------|----------| -| `AI_MODELS` | Comma-separated model names | `chatgpt,gemini` | No | -| `USE_REAL_AI` | Enable real AI evaluation | `false` | **Yes for production** | +| `USE_REAL_AI` | Enable real AI evaluation | - | **Yes** | -> **⚠️ IMPORTANT**: `USE_REAL_AI` defaults to `false` for safety. When `false`, the system returns deterministic stub decisions instead of calling AI APIs. **You must set `USE_REAL_AI=true` in production** to use real AI models. +> **⚠️ IMPORTANT**: `USE_REAL_AI` must be explicitly set. When `false`, the system returns deterministic stub decisions instead of calling AI APIs. **Set `USE_REAL_AI=true` in production** to use real AI models. **Evaluation Modes**: -- `USE_REAL_AI=false` (default): Stub mode - returns deterministic decisions for testing/development +- `USE_REAL_AI=false`: Stub mode - returns deterministic decisions for testing/development - `USE_REAL_AI=true`: Real mode - calls configured AI models in parallel -#### Per-Model Configuration +#### LLM Configuration Management -For each model in `AI_MODELS`, configure: +LLM configurations (API keys, model IDs, enabled status) are managed at **runtime via API endpoints**, not environment variables. 
This allows: -| Variable Pattern | Description | Default | Required | -|-----------------|-------------|---------|----------| -| `MODEL_{NAME}_PROVIDER` | API provider | - | **Yes** | -| `MODEL_{NAME}_API_KEY` | API key | - | **Yes** | -| `MODEL_{NAME}_MODEL_ID` | Specific model ID | varies | No | -| `MODEL_{NAME}_TIMEOUT_MS` | Request timeout | `30000` | No | -| `MODEL_{NAME}_MAX_TOKENS` | Max response tokens | `1000` | No | -| `MODEL_{NAME}_PROMPT_PATH` | Path to prompt file | `prompts/{name}_v1.md` | No | +- **Hot reload**: Change API keys without container restart +- **Dynamic enable/disable**: Turn models on/off without redeployment +- **API key testing**: Validate keys before enabling +- **Secure storage**: Keys stored in PostgreSQL, masked in API responses -**Example (ChatGPT)**: -``` -MODEL_CHATGPT_PROVIDER=openai -MODEL_CHATGPT_API_KEY=sk-your-openai-key -MODEL_CHATGPT_MODEL_ID=gpt-4o -MODEL_CHATGPT_TIMEOUT_MS=30000 -MODEL_CHATGPT_MAX_TOKENS=1000 -MODEL_CHATGPT_PROMPT_PATH=/app/prompts/chatgpt_v1.md -``` +**Supported Models**: -**Example (Gemini)**: -``` -MODEL_GEMINI_PROVIDER=google -MODEL_GEMINI_API_KEY=your-google-ai-key -MODEL_GEMINI_MODEL_ID=gemini-1.5-pro -MODEL_GEMINI_TIMEOUT_MS=30000 -MODEL_GEMINI_MAX_TOKENS=1000 -MODEL_GEMINI_PROMPT_PATH=/app/prompts/gemini_v1.md -``` +| Model Name | Provider | Default Model ID | +|------------|----------|------------------| +| `chatgpt` | OpenAI | `gpt-4o` | +| `gemini` | Google | `gemini-1.5-pro` | +| `claude` | Anthropic | `claude-sonnet-4-20250514` | +| `deepseek` | DeepSeek | `deepseek-chat` | -**Example (Claude)**: -``` -MODEL_CLAUDE_PROVIDER=anthropic -MODEL_CLAUDE_API_KEY=sk-ant-your-anthropic-key -MODEL_CLAUDE_MODEL_ID=claude-sonnet-4-20250514 -MODEL_CLAUDE_TIMEOUT_MS=30000 -MODEL_CLAUDE_MAX_TOKENS=1000 -``` +**Configuration via API**: + +```bash +# Configure ChatGPT +curl -X PUT http://gateway:8000/api/v1/llm-configs/chatgpt \ + -H "Content-Type: application/json" \ + -d '{"api_key": "sk-...", "enabled": 
true}' -**Example (DeepSeek)**: +# Test the API key +curl -X POST http://gateway:8000/api/v1/llm-configs/chatgpt/test + +# List all configurations +curl http://gateway:8000/api/v1/llm-configs + +# Disable a model +curl -X PATCH http://gateway:8000/api/v1/llm-configs/chatgpt \ + -H "Content-Type: application/json" \ + -d '{"enabled": false}' ``` -MODEL_DEEPSEEK_PROVIDER=deepseek -MODEL_DEEPSEEK_API_KEY=your-deepseek-key -MODEL_DEEPSEEK_MODEL_ID=deepseek-chat -MODEL_DEEPSEEK_TIMEOUT_MS=30000 -MODEL_DEEPSEEK_MAX_TOKENS=1000 + +See [API Reference](api-reference.md#llm-configuration) for full endpoint documentation. + +#### Prompt Management + +AI prompts are stored in the database and managed via API endpoints. This enables: + +- **Versioning**: Multiple prompt versions can coexist (v1, v2, etc.) +- **Hot reload**: Update prompts without container restart +- **Audit trail**: Track who created/modified prompts and when +- **Core + Wrapper pattern**: Shared core logic with model-specific wrappers + +**Prompt Types**: +- `core`: Shared decision-making logic (e.g., `core_decision`) +- `wrapper`: Model-specific formatting (e.g., `chatgpt_wrapper`, `claude_wrapper`) + +**Configuration via API**: + +```bash +# List all prompts +curl http://gateway:8000/api/v1/prompts \ + -H "Authorization: Bearer <admin-token>" + +# Create a new prompt version +curl -X POST http://gateway:8000/api/v1/prompts \ + -H "Authorization: Bearer <admin-token>" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "core_decision", + "version": "v2", + "prompt_type": "core", + "content": "# Trading Decision Framework v2..." + }' + +# Deactivate a prompt version +curl -X PATCH http://gateway:8000/api/v1/prompts/core_decision/v1 \ + -H "Authorization: Bearer <admin-token>" \ + -H "Content-Type: application/json" \ + -d '{"is_active": false}' ``` +On first startup, if no prompts exist in the database, the service automatically seeds from the `prompts/` directory.
+ +See [API Reference](api-reference.md#prompt-management) for full endpoint documentation. + ### WebSocket | Variable | Description | Default | Required | @@ -306,10 +441,20 @@ RETENTION_DAYS=180 REDIS_URL=redis://localhost:6379/0 # =================== -# Security +# Authentication # =================== -# Network-based security: API is only accessible from internal Docker network -# No API keys required - all external requests are rejected at network level +# AUTH_MODE: none (dev), psk (Docker Compose), jwt (portable) +AUTH_MODE=none + +# PSK Mode tokens (uncomment and set for psk mode) +# AUTH_TOKEN_SUBMIT=your-submit-token +# AUTH_TOKEN_READ=your-read-token +# AUTH_TOKEN_ADMIN=your-admin-token + +# JWT Mode (uncomment for jwt mode) +# AUTH_JWT_JWKS_URL=https://your-idp.com/.well-known/jwks.json +# AUTH_JWT_ISSUER=https://your-idp.com +# AUTH_JWT_AUDIENCE=lens-api # =================== # Rate Limiting @@ -340,35 +485,9 @@ STALE_CTX_S=60 # =================== # AI Models # =================== -AI_MODELS=chatgpt,gemini,claude,deepseek - -# ChatGPT Configuration -MODEL_CHATGPT_PROVIDER=openai -MODEL_CHATGPT_API_KEY=sk-your-openai-api-key -MODEL_CHATGPT_MODEL_ID=gpt-4o -MODEL_CHATGPT_TIMEOUT_MS=30000 -MODEL_CHATGPT_MAX_TOKENS=1000 - -# Gemini Configuration -MODEL_GEMINI_PROVIDER=google -MODEL_GEMINI_API_KEY=your-google-ai-api-key -MODEL_GEMINI_MODEL_ID=gemini-1.5-pro -MODEL_GEMINI_TIMEOUT_MS=30000 -MODEL_GEMINI_MAX_TOKENS=1000 - -# Claude Configuration -MODEL_CLAUDE_PROVIDER=anthropic -MODEL_CLAUDE_API_KEY=sk-ant-your-anthropic-api-key -MODEL_CLAUDE_MODEL_ID=claude-sonnet-4-20250514 -MODEL_CLAUDE_TIMEOUT_MS=30000 -MODEL_CLAUDE_MAX_TOKENS=1000 - -# DeepSeek Configuration -MODEL_DEEPSEEK_PROVIDER=deepseek -MODEL_DEEPSEEK_API_KEY=your-deepseek-api-key -MODEL_DEEPSEEK_MODEL_ID=deepseek-chat -MODEL_DEEPSEEK_TIMEOUT_MS=30000 -MODEL_DEEPSEEK_MAX_TOKENS=1000 +# Set to true for production, false for stub decisions +USE_REAL_AI=false +# LLM API keys are configured via API: 
/api/v1/llm-configs # =================== # WebSocket @@ -427,8 +546,10 @@ On startup, the application validates: 1. All required environment variables are set 2. Database connection is valid 3. Redis connection is valid -4. AI model API keys are present for configured models +4. `USE_REAL_AI` is explicitly set (true or false) 5. Feature profile exists 6. Policy configuration is valid YAML +LLM configurations are loaded from the database on startup. If no models are configured, the system will log a warning but continue running. Configure models via the `/api/v1/llm-configs` endpoints. + Validation errors will prevent startup and log detailed error messages. diff --git a/docs/quickstart.md b/docs/quickstart.md index bdd774e..9941893 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -38,17 +38,23 @@ cp .env.example .env ### 2. Edit Configuration -Open `.env` and configure AI model API keys: +Open `.env` and configure authentication: ```bash -# Add AI model API keys (for Phase 2+) -MODEL_CHATGPT_API_KEY=sk-... -MODEL_GEMINI_API_KEY=... -MODEL_CLAUDE_API_KEY=sk-ant-... -MODEL_DEEPSEEK_API_KEY=... +# Authentication mode: none (dev), psk (Docker Compose), jwt (production) +AUTH_MODE=psk + +# Generate secure tokens (one per scope) +# python -c "import secrets; print(secrets.token_urlsafe(32))" +AUTH_TOKEN_SUBMIT= +AUTH_TOKEN_READ= +AUTH_TOKEN_ADMIN= ``` -**Note**: No API key authentication is required. SigmaPilot Lens uses network-level security - the API is only accessible from within the Docker network. +**Scopes**: +- `AUTH_TOKEN_SUBMIT` - Submit signals (`POST /signals`) +- `AUTH_TOKEN_READ` - Read events, decisions, DLQ +- `AUTH_TOKEN_ADMIN` - Admin operations (includes all scopes) ### 3. Start Services @@ -155,16 +161,17 @@ alembic current ## Testing the API -### Access from Within Docker Network +### Authenticated API Requests -The API is only accessible from within the Docker network. 
Use `docker-compose exec` to interact with the API: +All API requests (except `/health` and `/ready`) require authentication via Bearer token: ```bash -# Health check -docker-compose exec gateway curl http://localhost:8000/api/v1/health +# Health check (no auth required) +curl http://localhost:8000/api/v1/health -# Submit a test signal -docker-compose exec gateway curl -X POST http://localhost:8000/api/v1/signals \ +# Submit a signal (requires submit token) +curl -X POST http://localhost:8000/api/v1/signals \ + -H "Authorization: Bearer <submit-token>" \ -H "Content-Type: application/json" \ -d '{ "event_type": "OPEN_SIGNAL", @@ -176,6 +183,16 @@ docker-compose exec gateway curl -X POST http://localhost:8000/api/v1/signals \ "ts_utc": "2025-01-15T10:30:00Z", "source": "test-signal" }' + +# Read events (requires read token) +curl http://localhost:8000/api/v1/events \ + -H "Authorization: Bearer <read-token>" + +# Configure LLM models (requires admin token) +curl -X PUT http://localhost:8000/api/v1/llm-configs/chatgpt \ + -H "Authorization: Bearer <admin-token>" \ + -H "Content-Type: application/json" \ + -d '{"api_key": "sk-...", "enabled": true}' ``` Response: @@ -201,9 +218,9 @@ curl http://localhost:8000/api/v1/metrics ## WebSocket Testing -WebSocket connections are also restricted to the Docker network. +WebSocket connections require authentication via the `Sec-WebSocket-Protocol` header.
-### From a Container Connected to lens-network +### Authenticated WebSocket Connection ```python import asyncio @@ -211,9 +228,12 @@ import websockets import json async def subscribe(): - # Connect to gateway from within Docker network - uri = "ws://gateway:8000/api/v1/ws/stream" - async with websockets.connect(uri) as ws: + uri = "ws://localhost:8000/api/v1/ws/stream" + # Authenticate via subprotocol header + async with websockets.connect( + uri, + subprotocols=["bearer", ""] + ) as ws: # Subscribe await ws.send(json.dumps({ "action": "subscribe", @@ -229,6 +249,22 @@ async def subscribe(): asyncio.run(subscribe()) ``` +### JavaScript Example + +```javascript +const ws = new WebSocket( + 'ws://localhost:8000/api/v1/ws/stream', + ['bearer', ''] +); + +ws.onopen = () => { + console.log('Connected with protocol:', ws.protocol); + ws.send(JSON.stringify({ action: 'subscribe', filters: {} })); +}; + +ws.onmessage = (event) => console.log('Decision:', JSON.parse(event.data)); +``` + ### WebSocket Messages ```bash diff --git a/migrations/versions/20251217_0002_add_llm_configs.py b/migrations/versions/20251217_0002_add_llm_configs.py new file mode 100644 index 0000000..7a2cfdf --- /dev/null +++ b/migrations/versions/20251217_0002_add_llm_configs.py @@ -0,0 +1,44 @@ +"""Add llm_configs table for runtime API key management + +Revision ID: 0002 +Revises: 0001 +Create Date: 2025-12-17 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = '0002' +down_revision: Union[str, None] = '0001' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # LLM Configs table for runtime API key management + op.create_table( + 'llm_configs', + sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')), + sa.Column('model_name', sa.String(50), nullable=False, unique=True), + sa.Column('enabled', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('provider', sa.String(50), nullable=False), + sa.Column('api_key', sa.Text(), nullable=False), + sa.Column('model_id', sa.String(100), nullable=False), + sa.Column('timeout_ms', sa.Integer(), nullable=False, server_default='30000'), + sa.Column('max_tokens', sa.Integer(), nullable=False, server_default='1000'), + sa.Column('prompt_path', sa.String(200), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()')), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()')), + sa.Column('last_validated_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('validation_status', sa.String(20), nullable=True), + ) + op.create_index('idx_llm_configs_model_name', 'llm_configs', ['model_name']) + op.create_index('idx_llm_configs_enabled', 'llm_configs', ['enabled']) + + +def downgrade() -> None: + op.drop_table('llm_configs') diff --git a/migrations/versions/20251219_0003_add_prompts.py b/migrations/versions/20251219_0003_add_prompts.py new file mode 100644 index 0000000..ae4cb0b --- /dev/null +++ b/migrations/versions/20251219_0003_add_prompts.py @@ -0,0 +1,61 @@ +"""Add prompts table for database-backed prompt storage + +Revision ID: 0003 +Revises: 0002 +Create Date: 2025-12-19 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used 
by Alembic. +revision: str = '0003' +down_revision: Union[str, None] = '0002' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Prompts table for database-backed prompt storage + op.create_table( + 'prompts', + sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')), + sa.Column('name', sa.String(100), nullable=False), + sa.Column('version', sa.String(20), nullable=False), + sa.Column('prompt_type', sa.String(20), nullable=False), + sa.Column('model_name', sa.String(50), nullable=True), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('content_hash', sa.String(64), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()')), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()')), + sa.Column('created_by', sa.String(100), nullable=True), + ) + + # Index for looking up active prompts by name and version + op.create_index('idx_prompts_name_version', 'prompts', ['name', 'version']) + + # Index for filtering by prompt type + op.create_index('idx_prompts_type', 'prompts', ['prompt_type']) + + # Index for finding wrapper prompts by model + op.create_index('idx_prompts_model_name', 'prompts', ['model_name']) + + # Index for finding active prompts + op.create_index('idx_prompts_is_active', 'prompts', ['is_active']) + + # Unique constraint: only one active prompt per name+version + op.create_unique_constraint( + 'uq_prompts_name_version_active', + 'prompts', + ['name', 'version'], + postgresql_where=sa.text('is_active = true') + ) + + +def downgrade() -> None: + op.drop_table('prompts') diff --git a/prompts/gemini_v1.md b/prompts/gemini_v1.md index f2d45f6..14d1326 100644 --- a/prompts/gemini_v1.md +++ 
b/prompts/gemini_v1.md @@ -43,6 +43,7 @@ Respond with a valid JSON object containing: "atr_multiple": number (optional), "trail_pct": number (optional) }, + "size_pct": 0 to 100, "reasons": ["tag1", "tag2", ...] } diff --git a/src/api/v1/decisions.py b/src/api/v1/decisions.py index 293e523..4ea9171 100644 --- a/src/api/v1/decisions.py +++ b/src/api/v1/decisions.py @@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from src.core.auth import AuthContext, require_read from src.models.database import get_db_session from src.models.orm.decision import ModelDecision as DecisionORM from src.models.orm.event import Event @@ -86,6 +87,7 @@ async def list_decisions( limit: int = Query(50, ge=1, le=100, description="Max results"), offset: int = Query(0, ge=0, description="Pagination offset"), db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Query AI model decisions. @@ -162,6 +164,7 @@ async def list_decisions( async def get_decision( decision_id: str, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Get full details of a decision. diff --git a/src/api/v1/dlq.py b/src/api/v1/dlq.py index 287789d..f42143b 100644 --- a/src/api/v1/dlq.py +++ b/src/api/v1/dlq.py @@ -45,6 +45,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession +from src.core.auth import AuthContext, require_admin, require_read from src.models.database import get_db_session from src.models.orm.dlq import DLQEntry from src.observability.logging import get_logger @@ -174,6 +175,7 @@ async def list_dlq_entries( limit: int = Query(50, ge=1, le=100, description="Max results"), offset: int = Query(0, ge=0, description="Pagination offset"), db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Query DLQ entries. 
@@ -246,6 +248,7 @@ async def list_dlq_entries( async def get_dlq_entry( dlq_id: str, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Get full details of a DLQ entry. @@ -290,6 +293,7 @@ async def get_dlq_entry( async def retry_dlq_entry( dlq_id: str, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_admin), ): """ Retry processing a DLQ entry. @@ -468,6 +472,7 @@ async def resolve_dlq_entry( dlq_id: str, request: DLQResolveRequest, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_admin), ): """ Mark a DLQ entry as resolved. diff --git a/src/api/v1/events.py b/src/api/v1/events.py index 8b3b931..7f2252b 100644 --- a/src/api/v1/events.py +++ b/src/api/v1/events.py @@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from src.core.auth import AuthContext, require_read from src.models.database import get_db_session from src.models.orm.event import Event, EnrichedEvent, ProcessingTimeline from src.models.orm.decision import ModelDecision @@ -40,6 +41,7 @@ async def list_events( limit: int = Query(50, ge=1, le=100, description="Max results"), offset: int = Query(0, ge=0, description="Pagination offset"), db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ List events with optional filters. @@ -107,6 +109,7 @@ async def list_events( async def get_event( event_id: str, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Get full details of an event. @@ -186,6 +189,7 @@ async def get_event( async def get_event_status( event_id: str, db: AsyncSession = Depends(get_db_session), + _auth: AuthContext = Depends(require_read), ): """ Get the current processing status of an event. 
diff --git a/src/api/v1/llm_configs.py b/src/api/v1/llm_configs.py new file mode 100644 index 0000000..298d45d --- /dev/null +++ b/src/api/v1/llm_configs.py @@ -0,0 +1,389 @@ +"""LLM configuration management endpoints. + +This module provides REST API endpoints for managing LLM provider configurations +at runtime. Allows updating API keys, enabling/disabling models, and testing +connections without container restarts. + +Endpoints: + GET /llm-configs - List all LLM configurations + GET /llm-configs/{model} - Get configuration for a specific model + PUT /llm-configs/{model} - Create or update a model configuration + PATCH /llm-configs/{model} - Partial update (e.g., enable/disable) + DELETE /llm-configs/{model} - Delete a model configuration + POST /llm-configs/{model}/test - Test API key validity + +Security: + - Restricted to internal Docker network only + - API keys are masked in responses (only last 4 chars shown) + +Supported models: + - chatgpt (OpenAI) + - gemini (Google) + - claude (Anthropic) + - deepseek (DeepSeek) + +Usage: + # Add/update ChatGPT configuration + PUT /llm-configs/chatgpt + { + "api_key": "sk-...", + "model_id": "gpt-4o", + "enabled": true + } + + # Test the API key + POST /llm-configs/chatgpt/test + + # Disable a model + PATCH /llm-configs/chatgpt + {"enabled": false} +""" + +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from src.core.auth import AuthContext, require_admin +from src.observability.logging import get_logger +from src.services.llm_config import LLMConfigService, get_llm_config_service +from src.services.llm_config.service import MODEL_PROVIDERS, DEFAULT_MODEL_IDS + +logger = get_logger(__name__) + +router = APIRouter() + + +# Request/Response schemas +class LLMConfigCreate(BaseModel): + """Request schema for creating/updating LLM configuration.""" + + api_key: str = Field(..., min_length=1, description="API 
key for the provider") + model_id: Optional[str] = Field(None, description="Model identifier (e.g., gpt-4o). Uses default if not specified.") + enabled: bool = Field(True, description="Whether the model is enabled") + timeout_ms: int = Field(30000, ge=1000, le=120000, description="Request timeout in milliseconds") + max_tokens: int = Field(1000, ge=100, le=8000, description="Maximum tokens for response") + + +class LLMConfigPatch(BaseModel): + """Request schema for partial updates.""" + + enabled: Optional[bool] = Field(None, description="Enable/disable the model") + api_key: Optional[str] = Field(None, min_length=1, description="New API key") + model_id: Optional[str] = Field(None, description="New model identifier") + timeout_ms: Optional[int] = Field(None, ge=1000, le=120000) + max_tokens: Optional[int] = Field(None, ge=100, le=8000) + + +class LLMConfigResponse(BaseModel): + """Response schema for LLM configuration (API key masked).""" + + model_name: str + provider: str + model_id: str + enabled: bool + timeout_ms: int + max_tokens: int + prompt_path: Optional[str] + api_key_masked: str = Field(..., description="Masked API key (last 4 chars only)") + validation_status: Optional[str] = None + last_validated_at: Optional[datetime] = None + + +class LLMConfigListResponse(BaseModel): + """Response schema for listing configurations.""" + + items: List[LLMConfigResponse] + total: int + + +class LLMConfigTestResponse(BaseModel): + """Response schema for API key test.""" + + model_name: str + success: bool + message: str + latency_ms: int + + +def _mask_api_key(api_key: str) -> str: + """Mask API key, showing only last 4 characters.""" + if len(api_key) <= 4: + return "****" + return f"****{api_key[-4:]}" + + +def _get_service() -> LLMConfigService: + """Get the LLM config service instance.""" + return get_llm_config_service() + + +@router.get( + "", + response_model=LLMConfigListResponse, + summary="List all LLM configurations", + description="Get all configured LLM 
providers with masked API keys.", +) +async def list_llm_configs(_auth: AuthContext = Depends(require_admin)): + """List all LLM configurations.""" + service = _get_service() + configs = await service.list_all() + + items = [ + LLMConfigResponse( + model_name=c.model_name, + provider=c.provider, + model_id=c.model_id, + enabled=c.enabled, + timeout_ms=c.timeout_ms, + max_tokens=c.max_tokens, + prompt_path=c.prompt_path, + api_key_masked=_mask_api_key(c.api_key), + ) + for c in configs + ] + + return LLMConfigListResponse(items=items, total=len(items)) + + +@router.get( + "/{model_name}", + response_model=LLMConfigResponse, + summary="Get LLM configuration", + description="Get configuration for a specific model.", +) +async def get_llm_config(model_name: str, _auth: AuthContext = Depends(require_admin)): + """Get configuration for a specific model.""" + service = _get_service() + config = await service.get_config(model_name) + + if not config: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + return LLMConfigResponse( + model_name=config.model_name, + provider=config.provider, + model_id=config.model_id, + enabled=config.enabled, + timeout_ms=config.timeout_ms, + max_tokens=config.max_tokens, + prompt_path=config.prompt_path, + api_key_masked=_mask_api_key(config.api_key), + ) + + +@router.put( + "/{model_name}", + response_model=LLMConfigResponse, + summary="Create or update LLM configuration", + description="Create a new LLM configuration or update an existing one. Provider is determined automatically by model name.", +) +async def create_or_update_llm_config( + model_name: str, request: LLMConfigCreate, _auth: AuthContext = Depends(require_admin) +): + """Create or update an LLM configuration.""" + # Validate model name and get provider + if model_name not in MODEL_PROVIDERS: + raise HTTPException( + status_code=400, + detail=f"Invalid model name. 
Must be one of: {', '.join(MODEL_PROVIDERS.keys())}" + ) + + # Use predefined provider (not user-editable) + provider = MODEL_PROVIDERS[model_name] + + # Use default model_id if not specified + model_id = request.model_id or DEFAULT_MODEL_IDS[model_name] + + service = _get_service() + config = await service.create_or_update( + model_name=model_name, + provider=provider, + api_key=request.api_key, + model_id=model_id, + enabled=request.enabled, + timeout_ms=request.timeout_ms, + max_tokens=request.max_tokens, + prompt_path=None, + ) + + logger.info(f"LLM config updated for {model_name}") + + return LLMConfigResponse( + model_name=config.model_name, + provider=config.provider, + model_id=config.model_id, + enabled=config.enabled, + timeout_ms=config.timeout_ms, + max_tokens=config.max_tokens, + prompt_path=config.prompt_path, + api_key_masked=_mask_api_key(config.api_key), + ) + + +@router.patch( + "/{model_name}", + response_model=LLMConfigResponse, + summary="Partial update LLM configuration", + description="Partially update an LLM configuration (e.g., enable/disable).", +) +async def patch_llm_config( + model_name: str, request: LLMConfigPatch, _auth: AuthContext = Depends(require_admin) +): + """Partially update an LLM configuration.""" + service = _get_service() + + # Get existing config + existing = await service.get_config(model_name) + if not existing: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + # Build update with existing values as defaults + config = await service.create_or_update( + model_name=model_name, + provider=existing.provider, + api_key=request.api_key if request.api_key else existing.api_key, + model_id=request.model_id if request.model_id else existing.model_id, + enabled=request.enabled if request.enabled is not None else existing.enabled, + timeout_ms=request.timeout_ms if request.timeout_ms else existing.timeout_ms, + max_tokens=request.max_tokens if request.max_tokens else 
existing.max_tokens, + prompt_path=existing.prompt_path, + ) + + logger.info(f"LLM config patched for {model_name}") + + return LLMConfigResponse( + model_name=config.model_name, + provider=config.provider, + model_id=config.model_id, + enabled=config.enabled, + timeout_ms=config.timeout_ms, + max_tokens=config.max_tokens, + prompt_path=config.prompt_path, + api_key_masked=_mask_api_key(config.api_key), + ) + + +@router.delete( + "/{model_name}", + summary="Delete LLM configuration", + description="Delete an LLM configuration. The model will fall back to environment variables if configured.", +) +async def delete_llm_config(model_name: str, _auth: AuthContext = Depends(require_admin)): + """Delete an LLM configuration.""" + service = _get_service() + deleted = await service.delete(model_name) + + if not deleted: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + logger.info(f"LLM config deleted for {model_name}") + + return {"status": "deleted", "model_name": model_name} + + +@router.post( + "/{model_name}/test", + response_model=LLMConfigTestResponse, + summary="Test API key", + description="Test if the API key is valid by making a minimal API call.", +) +async def test_llm_config(model_name: str, _auth: AuthContext = Depends(require_admin)): + """Test if the API key is valid.""" + service = _get_service() + + # Check config exists + config = await service.get_config(model_name) + if not config: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + result = await service.test_api_key(model_name) + + return LLMConfigTestResponse( + model_name=model_name, + success=result["success"], + message=result["message"], + latency_ms=result["latency_ms"], + ) + + +@router.post( + "/{model_name}/enable", + response_model=LLMConfigResponse, + summary="Enable a model", + description="Enable a previously disabled model.", +) +async def enable_llm_config(model_name: 
str, _auth: AuthContext = Depends(require_admin)): + """Enable a model.""" + service = _get_service() + + success = await service.set_enabled(model_name, True) + if not success: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + config = await service.get_config(model_name) + return LLMConfigResponse( + model_name=config.model_name, + provider=config.provider, + model_id=config.model_id, + enabled=config.enabled, + timeout_ms=config.timeout_ms, + max_tokens=config.max_tokens, + prompt_path=config.prompt_path, + api_key_masked=_mask_api_key(config.api_key), + ) + + +@router.post( + "/{model_name}/disable", + response_model=LLMConfigResponse, + summary="Disable a model", + description="Disable a model without deleting its configuration.", +) +async def disable_llm_config(model_name: str, _auth: AuthContext = Depends(require_admin)): + """Disable a model.""" + service = _get_service() + + success = await service.set_enabled(model_name, False) + if not success: + raise HTTPException( + status_code=404, + detail=f"Configuration not found for model: {model_name}" + ) + + # Get config (will return None because disabled, so fetch directly) + async with get_llm_config_service()._cache_lock: + pass # Just to trigger refresh + await service._refresh_cache() + + # Return the disabled config + for c in await service.list_all(): + if c.model_name == model_name: + return LLMConfigResponse( + model_name=c.model_name, + provider=c.provider, + model_id=c.model_id, + enabled=c.enabled, + timeout_ms=c.timeout_ms, + max_tokens=c.max_tokens, + prompt_path=c.prompt_path, + api_key_masked=_mask_api_key(c.api_key), + ) + + raise HTTPException(status_code=404, detail="Configuration not found after update") diff --git a/src/api/v1/prompts.py b/src/api/v1/prompts.py new file mode 100644 index 0000000..776294c --- /dev/null +++ b/src/api/v1/prompts.py @@ -0,0 +1,421 @@ +"""Prompt management endpoints. 
+ +This module provides REST API endpoints for managing AI prompts +at runtime. Allows creating, updating, and versioning prompts +without container restarts. + +Endpoints: + GET /prompts - List all prompts + GET /prompts/available - Get available prompt versions + GET /prompts/{name}/{version} - Get a specific prompt + POST /prompts - Create a new prompt + PUT /prompts/{name}/{version} - Update a prompt + PATCH /prompts/{name}/{version} - Partial update (e.g., activate/deactivate) + DELETE /prompts/{name}/{version} - Delete a prompt + POST /prompts/render - Render a prompt with data + +Security: + - All endpoints require lens:admin scope + +Prompt Types: + - core: Shared decision logic (core_decision) + - wrapper: Provider-specific formatting (chatgpt_wrapper, gemini_wrapper, etc.) + +Usage: + # Create a new core prompt version + POST /prompts + { + "name": "core_decision", + "version": "v2", + "prompt_type": "core", + "content": "# Trading Signal Decision Framework..." + } + + # Update a wrapper prompt + PUT /prompts/chatgpt_wrapper/v1 + { + "content": "..." 
+ } + + # Deactivate a prompt version + PATCH /prompts/core_decision/v1 + {"is_active": false} +""" + +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from src.core.auth import AuthContext, require_admin +from src.observability.logging import get_logger +from src.services.prompt import PromptService, get_prompt_service + +logger = get_logger(__name__) + +router = APIRouter() + + +# Request/Response schemas +class PromptCreate(BaseModel): + """Request schema for creating a prompt.""" + + name: str = Field(..., min_length=1, max_length=100, description="Prompt name (e.g., core_decision, chatgpt_wrapper)") + version: str = Field(..., min_length=1, max_length=20, description="Version string (e.g., v1, v2)") + prompt_type: str = Field(..., description="Type: 'core' or 'wrapper'") + content: str = Field(..., min_length=1, description="Prompt content") + model_name: Optional[str] = Field(None, max_length=50, description="For wrapper prompts, which model this is for") + description: Optional[str] = Field(None, description="Optional description") + + +class PromptUpdate(BaseModel): + """Request schema for updating a prompt.""" + + content: str = Field(..., min_length=1, description="New prompt content") + description: Optional[str] = Field(None, description="Optional description") + + +class PromptPatch(BaseModel): + """Request schema for partial updates.""" + + content: Optional[str] = Field(None, min_length=1, description="New prompt content") + description: Optional[str] = Field(None, description="New description") + is_active: Optional[bool] = Field(None, description="Activate/deactivate") + + +class PromptResponse(BaseModel): + """Response schema for a prompt.""" + + id: str + name: str + version: str + prompt_type: str + model_name: Optional[str] + content: str + content_hash: str + is_active: bool + description: Optional[str] + created_at: datetime + + 
@router.get(
    "",
    response_model=PromptListResponse,
    summary="List all prompts",
    description="Get all prompts with optional filtering by type.",
)
async def list_prompts(
    prompt_type: Optional[str] = None,
    include_inactive: bool = False,
    _auth: AuthContext = Depends(require_admin),
):
    """Return the stored prompts, optionally filtered by prompt type.

    Requires the admin scope. ``prompt_type`` and ``include_inactive`` are
    forwarded unchanged to ``PromptService.list_all``.

    Raises:
        HTTPException: 400 (``INVALID_TYPE``) when ``prompt_type`` is given
            but is neither ``core`` nor ``wrapper``.
    """
    service = _get_service()

    # An empty or missing filter means "no filter"; only a non-empty value
    # is checked against the known types.
    type_filter_given = bool(prompt_type)
    if type_filter_given and prompt_type not in ("core", "wrapper"):
        raise HTTPException(
            status_code=400,
            detail={"error": {"code": "INVALID_TYPE", "message": "prompt_type must be 'core' or 'wrapper'"}},
        )

    records = await service.list_all(
        prompt_type=prompt_type,
        include_inactive=include_inactive,
    )

    # Attributes copied one-to-one from each stored prompt into the response.
    copied_fields = (
        "id",
        "name",
        "version",
        "prompt_type",
        "model_name",
        "content",
        "content_hash",
        "is_active",
        "description",
        "created_at",
    )
    items = []
    for record in records:
        payload = {field: getattr(record, field) for field in copied_fields}
        items.append(PromptResponse(**payload))

    return PromptListResponse(items=items, total=len(items))
@router.put(
    "/{name}/{version}",
    response_model=PromptResponse,
    summary="Update a prompt",
    description="Update a prompt's content and description.",
)
async def update_prompt(
    name: str,
    version: str,
    prompt: PromptUpdate,
    _auth: AuthContext = Depends(require_admin),
):
    """Replace the content and description of the prompt ``name``/``version``.

    Requires the admin scope. The new values are handed to
    ``PromptService.update``; whatever it returns is echoed back as a
    ``PromptResponse``.

    Raises:
        HTTPException: 404 (``NOT_FOUND``) when the service returns a falsy
            result for this name/version pair.
    """
    stored = await _get_service().update(
        name=name,
        version=version,
        content=prompt.content,
        description=prompt.description,
    )

    if stored:
        return PromptResponse(
            id=stored.id,
            name=stored.name,
            version=stored.version,
            prompt_type=stored.prompt_type,
            model_name=stored.model_name,
            content=stored.content,
            content_hash=stored.content_hash,
            is_active=stored.is_active,
            description=stored.description,
            created_at=stored.created_at,
        )

    raise HTTPException(
        status_code=404,
        detail={"error": {"code": "NOT_FOUND", "message": f"Prompt {name} version {version} not found"}},
    )
@router.post(
    "/render",
    response_model=PromptRenderResponse,
    summary="Render a prompt with data",
    description="Render a complete prompt with enriched event data and constraints.",
)
async def render_prompt(
    request: PromptRenderRequest,
    _auth: AuthContext = Depends(require_admin),
):
    """Render a prompt for one model from the supplied event and constraints.

    Requires the admin scope. The request fields are passed straight to
    ``PromptService.render_prompt``, which yields the rendered text together
    with the prompt version and hash.

    Args:
        request: Model name, enriched event, constraints and the core/wrapper
            prompt versions to combine.

    Returns:
        PromptRenderResponse carrying the rendered prompt, its version string
        and its hash.

    Raises:
        HTTPException: 400 (``RENDER_ERROR``) when the service raises
            ``ValueError``; the original error is kept as ``__cause__``.
    """
    service = _get_service()

    try:
        # Named ``prompt_hash`` so the builtin ``hash`` is not shadowed.
        rendered, version, prompt_hash = await service.render_prompt(
            model_name=request.model_name,
            enriched_event=request.enriched_event,
            constraints=request.constraints,
            core_version=request.core_version,
            wrapper_version=request.wrapper_version,
        )
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail={"error": {"code": "RENDER_ERROR", "message": str(e)}},
        ) from e

    return PromptRenderResponse(
        rendered_prompt=rendered,
        prompt_version=version,
        prompt_hash=prompt_hash,
    )
diff --git a/src/api/v1/ws.py b/src/api/v1/ws.py index 2f95346..90db487 100644 --- a/src/api/v1/ws.py +++ b/src/api/v1/ws.py @@ -1,7 +1,8 @@ """WebSocket endpoint for real-time decision streaming.""" -from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect +from src.core.auth import AuthContext, get_websocket_auth_context, require_read, Scope from src.core.config import settings from src.observability.logging import get_logger from src.services.publisher.ws_server import handle_websocket, ws_manager @@ -16,7 +17,16 @@ async def websocket_endpoint(websocket: WebSocket): """ WebSocket endpoint for streaming AI decisions in real-time. - Access is restricted to internal Docker network only (network-level security). + ## Authentication + + When AUTH_MODE is 'psk' or 'jwt', authenticate via the Sec-WebSocket-Protocol header: + + ``` + Sec-WebSocket-Protocol: bearer, + ``` + + The server will echo back "bearer" in the response protocol if auth succeeds. + Requires `lens:read` scope. 
## Connection @@ -77,14 +87,26 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.close(code=1000, reason="WebSocket disabled") return + # Check authentication for WebSocket + auth = await get_websocket_auth_context(websocket) + if settings.AUTH_MODE != "none": + if not auth.authenticated: + await websocket.close(code=4001, reason="Authentication required") + return + if not auth.has_scope(Scope.READ): + await websocket.close(code=4003, reason="Insufficient permissions") + return + # Check max connections if len(ws_manager.subscriptions) >= settings.WS_MAX_CONNECTIONS: await websocket.close(code=4029, reason="Too many connections") return # Handle the WebSocket connection + # If auth was via bearer protocol, we need to accept with that subprotocol + subprotocol = "bearer" if auth.token_type else None try: - await handle_websocket(websocket) + await handle_websocket(websocket, subprotocol=subprotocol) except WebSocketDisconnect: pass except Exception as e: diff --git a/src/core/auth.py b/src/core/auth.py new file mode 100644 index 0000000..ae85e69 --- /dev/null +++ b/src/core/auth.py @@ -0,0 +1,376 @@ +"""Authentication and authorization module. + +This module provides 3-mode authentication: + - none: No auth (development mode) + - psk: Pre-shared key tokens (Docker Compose deployments) + - jwt: JWT validation (portable/production deployments) + +Scopes: + - lens:submit: POST /signals + - lens:read: GET events, decisions, DLQ + - lens:admin: LLM configs, DLQ retry/resolve (includes all scopes) + +Usage: + # In route handlers, use the require_scope dependency: + @router.post("/signals") + async def submit_signal( + auth: AuthContext = Depends(require_scope("lens:submit")) + ): + ... 
+ + # For WebSocket, extract token from Sec-WebSocket-Protocol header: + # Sec-WebSocket-Protocol: bearer, +""" + +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Set + +from fastapi import Depends, Header, HTTPException, Request, WebSocket, status + +from src.core.config import settings +from src.observability.logging import get_logger + +logger = get_logger(__name__) + + +class Scope(str, Enum): + """Authorization scopes.""" + + SUBMIT = "lens:submit" + READ = "lens:read" + ADMIN = "lens:admin" + + +# Scope hierarchy: admin includes all other scopes +SCOPE_HIERARCHY = { + Scope.ADMIN: {Scope.SUBMIT, Scope.READ, Scope.ADMIN}, + Scope.SUBMIT: {Scope.SUBMIT}, + Scope.READ: {Scope.READ}, +} + + +@dataclass +class AuthContext: + """Authentication context for a request.""" + + authenticated: bool + scopes: Set[Scope] + token_type: Optional[str] = None # "psk" or "jwt" + subject: Optional[str] = None # For JWT: sub claim + + def has_scope(self, scope: Scope) -> bool: + """Check if context has the given scope (including via hierarchy).""" + for granted_scope in self.scopes: + if scope in SCOPE_HIERARCHY.get(granted_scope, {granted_scope}): + return True + return False + + +def _extract_bearer_token(authorization: Optional[str]) -> Optional[str]: + """Extract token from Authorization header.""" + if not authorization: + return None + parts = authorization.split() + if len(parts) == 2 and parts[0].lower() == "bearer": + return parts[1] + return None + + +def _validate_psk_token(token: str) -> Optional[AuthContext]: + """Validate a PSK token and return the auth context. 
# PyJWKClient instances keyed by JWKS URL. Reusing one client per URL lets
# its built-in key cache work; building a new client for every token would
# make a fresh JWKS request each time.
_JWKS_CLIENTS: dict = {}


def _validate_jwt_token(token: str) -> Optional[AuthContext]:
    """Validate a JWT and return the auth context.

    The signature is checked against the key from ``AUTH_JWT_JWKS_URL`` (if
    set) or else ``AUTH_JWT_PUBLIC_KEY``. ``exp`` and ``iat`` are required
    claims. Issuer and audience are passed to PyJWT when configured. Scopes
    are read from the claim named by ``AUTH_JWT_SCOPE_CLAIM`` and may be a
    space-separated string or a list; unknown scope strings are ignored.

    Args:
        token: The encoded JWT to validate.

    Returns:
        AuthContext if the token is valid and carries at least one known
        scope, otherwise None.

    Raises:
        HTTPException: 500 (``CONFIG_ERROR``) if PyJWT is not installed or
            neither a JWKS URL nor a public key is configured.
    """
    try:
        import jwt
        from jwt import PyJWKClient
    except ImportError:
        logger.error("PyJWT not installed. Install with: pip install PyJWT[crypto]")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": {"code": "CONFIG_ERROR", "message": "JWT support not configured"}},
        )

    # Verification options handed to jwt.decode. "require" makes a token
    # without exp/iat fail instead of being accepted as never expiring.
    verify_options = {
        "verify_signature": True,
        "verify_exp": True,
        "verify_iat": True,
        "require": ["exp", "iat"],
    }

    # Resolve the verification key.
    if settings.AUTH_JWT_JWKS_URL:
        jwks_url = settings.AUTH_JWT_JWKS_URL
        try:
            jwks_client = _JWKS_CLIENTS.get(jwks_url)
            if jwks_client is None:
                jwks_client = _JWKS_CLIENTS.setdefault(jwks_url, PyJWKClient(jwks_url))
            # The key is selected from the token header by the client.
            public_key = jwks_client.get_signing_key_from_jwt(token).key
        except Exception as e:
            logger.warning(f"Failed to get signing key from JWKS: {e}")
            return None
    elif settings.AUTH_JWT_PUBLIC_KEY:
        public_key = settings.AUTH_JWT_PUBLIC_KEY
    else:
        logger.error("JWT mode enabled but no public key or JWKS URL configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": {"code": "CONFIG_ERROR", "message": "JWT validation not configured"}},
        )

    try:
        payload = jwt.decode(
            token,
            public_key,
            # Asymmetric algorithms only: the configured key material is a
            # public key / JWKS, which must never be usable as an HMAC secret.
            algorithms=["RS256", "ES256"],
            issuer=settings.AUTH_JWT_ISSUER or None,
            audience=settings.AUTH_JWT_AUDIENCE or None,
            options=verify_options,
        )
    except jwt.ExpiredSignatureError:
        logger.warning("JWT has expired")
        return None
    except jwt.InvalidIssuerError:
        logger.warning("JWT has invalid issuer")
        return None
    except jwt.InvalidAudienceError:
        logger.warning("JWT has invalid audience")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning(f"Invalid JWT: {e}")
        return None
    except jwt.PyJWTError as e:
        # Key/algorithm errors are not InvalidTokenError subclasses; treat
        # them as a failed validation rather than an unhandled server error.
        logger.warning(f"JWT validation failed: {e}")
        return None

    # Scopes can be a space-separated string or a list.
    raw_scopes = payload.get(settings.AUTH_JWT_SCOPE_CLAIM, "")
    if isinstance(raw_scopes, str):
        scope_strings = raw_scopes.split()
    elif isinstance(raw_scopes, list):
        scope_strings = raw_scopes
    else:
        scope_strings = []

    scopes = set()
    for s in scope_strings:
        try:
            scopes.add(Scope(s))
        except ValueError:
            # Unknown scopes are ignored on purpose.
            continue

    if not scopes:
        logger.warning(f"JWT has no valid scopes: {scope_strings}")
        return None

    return AuthContext(
        authenticated=True,
        scopes=scopes,
        token_type="jwt",
        subject=payload.get("sub"),
    )
async def get_websocket_auth_context(websocket: WebSocket) -> AuthContext:
    """Build the auth context for a WebSocket handshake.

    The token travels in the ``Sec-WebSocket-Protocol`` header as a
    comma-separated list whose first entry is ``bearer`` and whose second
    entry is the token. The server should echo back "bearer" in the response
    protocol.

    Args:
        websocket: The WebSocket connection whose headers are inspected.

    Returns:
        An admin-scoped context when ``AUTH_MODE`` is ``none``; the context
        produced by the PSK or JWT validator when the token is accepted;
        otherwise an unauthenticated context with no scopes.
    """
    mode = settings.AUTH_MODE

    # Development mode: every connection is treated as an admin.
    if mode == "none":
        return AuthContext(
            authenticated=True,
            scopes={Scope.ADMIN},
            token_type=None,
        )

    anonymous = AuthContext(authenticated=False, scopes=set())

    header_value = websocket.headers.get("sec-websocket-protocol", "")
    token = None
    if header_value:
        offered = [entry.strip() for entry in header_value.split(",")]
        if len(offered) >= 2 and offered[0].lower() == "bearer":
            token = offered[1]

    if not token:
        return anonymous

    if mode == "psk":
        return _validate_psk_token(token) or anonymous

    if mode == "jwt":
        return _validate_jwt_token(token) or anonymous

    return anonymous
none (dev), psk (pre-shared keys), jwt (portable) + AUTH_MODE: str = Field( + default="none", + description="Authentication mode: none (dev), psk (pre-shared keys), jwt (portable)" + ) + + # PSK Mode tokens (each grants its associated scope) + AUTH_TOKEN_SUBMIT: Optional[str] = Field( + default=None, description="PSK token for lens:submit scope" + ) + AUTH_TOKEN_READ: Optional[str] = Field( + default=None, description="PSK token for lens:read scope" + ) + AUTH_TOKEN_ADMIN: Optional[str] = Field( + default=None, description="PSK token for lens:admin scope (includes all)" + ) + + # JWT Mode configuration + AUTH_JWT_PUBLIC_KEY: Optional[str] = Field( + default=None, description="PEM-encoded public key for JWT validation" + ) + AUTH_JWT_JWKS_URL: Optional[str] = Field( + default=None, description="URL to JWKS endpoint for JWT validation" + ) + AUTH_JWT_ISSUER: Optional[str] = Field( + default=None, description="Expected JWT issuer claim" + ) + AUTH_JWT_AUDIENCE: Optional[str] = Field( + default=None, description="Expected JWT audience claim" + ) + AUTH_JWT_SCOPE_CLAIM: str = Field( + default="scope", description="JWT claim containing scopes" + ) + # WebSocket WS_ENABLED: bool = True WS_PING_INTERVAL_S: int = 30 @@ -95,11 +130,6 @@ def timeframes_list(self) -> List[str]: """Parse TIMEFRAMES into a list.""" return [tf.strip() for tf in self.TIMEFRAMES.split(",")] - @property - def ai_models_list(self) -> List[str]: - """Parse AI_MODELS into a list.""" - return [m.strip() for m in self.AI_MODELS.split(",")] - @property def use_real_ai(self) -> bool: """Get USE_REAL_AI with deferred validation. 
@@ -134,38 +164,15 @@ def validate_feature_profile(cls, v: str) -> str: raise ValueError(f"FEATURE_PROFILE must be one of {valid_profiles}") return v - -class ModelConfig(BaseSettings): - """Per-model configuration loaded dynamically.""" - - model_config = SettingsConfigDict( - env_file=".env", - env_file_encoding="utf-8", - case_sensitive=True, - extra="ignore", - ) - - provider: str - api_key: str - model_id: str - timeout_ms: int = 30000 - max_tokens: int = 1000 - prompt_path: Optional[str] = None - + @field_validator("AUTH_MODE") @classmethod - def for_model(cls, model_name: str) -> "ModelConfig": - """Load configuration for a specific model.""" - import os - - prefix = f"MODEL_{model_name.upper()}_" - return cls( - provider=os.getenv(f"{prefix}PROVIDER", ""), - api_key=os.getenv(f"{prefix}API_KEY", ""), - model_id=os.getenv(f"{prefix}MODEL_ID", ""), - timeout_ms=int(os.getenv(f"{prefix}TIMEOUT_MS", "30000")), - max_tokens=int(os.getenv(f"{prefix}MAX_TOKENS", "1000")), - prompt_path=os.getenv(f"{prefix}PROMPT_PATH"), - ) + def validate_auth_mode(cls, v: str) -> str: + """Validate authentication mode.""" + valid_modes = {"none", "psk", "jwt"} + v = v.lower() + if v not in valid_modes: + raise ValueError(f"AUTH_MODE must be one of {valid_modes}") + return v @lru_cache diff --git a/src/main.py b/src/main.py index 84f044c..bf9d2d9 100644 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,8 @@ from src.models.database import close_db, init_db from src.observability.logging import get_logger, setup_logging from src.observability.metrics import metrics +from src.services.llm_config import get_llm_config_service +from src.services.prompt import get_prompt_service from src.services.queue import close_redis_client, get_redis_client, init_redis_client, reset_queue_producer logger = get_logger(__name__) @@ -36,15 +38,27 @@ async def lifespan(app: FastAPI): await init_rate_limiter(get_redis_client()) logger.info(f"Rate limiter initialized 
(enabled={settings.RATE_LIMIT_ENABLED}, limit={settings.RATE_LIMIT_PER_MIN}/min)") + # Initialize LLM config service (loads from database) + llm_config_service = get_llm_config_service() + await llm_config_service.initialize() + enabled_models = await llm_config_service.get_enabled_models() + logger.info(f"LLM config service initialized, enabled models: {enabled_models}") + + # Initialize prompt service (loads from database, seeds from files if empty) + prompt_service = get_prompt_service() + await prompt_service.initialize() + available_prompts = await prompt_service.get_available_prompts() + logger.info(f"Prompt service initialized, available: {available_prompts}") + # Set application info metrics metrics.set_app_info( version="0.1.0", feature_profile=settings.FEATURE_PROFILE, - ai_models=settings.AI_MODELS, + ai_models=",".join(enabled_models) if enabled_models else "none", ) logger.info(f"Feature profile: {settings.FEATURE_PROFILE}") - logger.info(f"AI models: {settings.ai_models_list}") + logger.info(f"AI models (enabled): {enabled_models}") logger.info(f"WebSocket enabled: {settings.WS_ENABLED}") yield diff --git a/src/models/orm/__init__.py b/src/models/orm/__init__.py index b5e4f06..87dc58e 100644 --- a/src/models/orm/__init__.py +++ b/src/models/orm/__init__.py @@ -3,6 +3,8 @@ from src.models.orm.event import Event, EnrichedEvent, ProcessingTimeline from src.models.orm.decision import ModelDecision from src.models.orm.dlq import DLQEntry +from src.models.orm.llm_config import LLMConfig +from src.models.orm.prompt import Prompt __all__ = [ "Event", @@ -10,4 +12,6 @@ "ProcessingTimeline", "ModelDecision", "DLQEntry", + "LLMConfig", + "Prompt", ] diff --git a/src/models/orm/llm_config.py b/src/models/orm/llm_config.py new file mode 100644 index 0000000..bf76b11 --- /dev/null +++ b/src/models/orm/llm_config.py @@ -0,0 +1,69 @@ +"""LLM configuration ORM model for runtime API key management.""" + +import uuid +from datetime import datetime, timezone +from typing 
class LLMConfig(Base):
    """LLM provider configuration stored in the database for runtime management.

    One row exists per logical model (``model_name`` is unique). Keeping the
    provider, API key and model settings in the database allows them to be
    changed without a container restart.

    NOTE: this model persists ``api_key`` exactly as it is assigned; nothing in
    this class encrypts it. Apply application-level encryption (or encrypted
    storage) before relying on it in production.
    """

    __tablename__ = "llm_configs"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )

    # Model identification (chatgpt, gemini, claude, deepseek)
    model_name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)

    # Enable/disable without deleting config
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    # Provider info, e.g., "openai", "google", "anthropic", "deepseek"
    provider: Mapped[str] = mapped_column(String(50), nullable=False)

    # API key (stored as given; should be encrypted at rest in production)
    api_key: Mapped[str] = mapped_column(Text, nullable=False)

    # Model settings, e.g., "gpt-4o", "gemini-1.5-pro", "claude-sonnet-4-20250514"
    model_id: Mapped[str] = mapped_column(String(100), nullable=False)

    timeout_ms: Mapped[int] = mapped_column(Integer, nullable=False, default=30000)
    max_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=1000)

    # Optional custom prompt path (relative to prompts/ directory)
    prompt_path: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)

    # Audit timestamps (timezone-aware UTC, refreshed on every ORM update)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    # Last validation status
    last_validated_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    validation_status: Mapped[Optional[str]] = mapped_column(
        String(20), nullable=True
    )  # ok, invalid_key, rate_limited, error

    def __repr__(self) -> str:
        """Return a debug representation that deliberately omits the API key."""
        return (
            f"<LLMConfig(model_name={self.model_name!r}, "
            f"provider={self.provider!r}, model_id={self.model_id!r}, "
            f"enabled={self.enabled!r})>"
        )
Each AI provider may need different + instructions for JSON output, system prompts, etc. + + Key Features: + - Versioning: Multiple versions can coexist (v1, v2, etc.) + - Activation: Toggle is_active to switch between versions without deletion + - Audit trail: Track creation/modification times and authors + - Content hashing: SHA-256 hash for detecting changes + + Database Constraints: + - Unique constraint on (name, version) for active prompts + - Indexes on name, version, prompt_type, model_name, and is_active + + Attributes: + id: UUID primary key + name: Prompt identifier (e.g., "core_decision", "chatgpt_wrapper") + version: Version string (e.g., "v1", "v2", "v1.1") + prompt_type: Either "core" or "wrapper" + model_name: For wrappers, the target model (e.g., "chatgpt", "gemini") + content: The actual prompt text with placeholders + description: Human-readable description for admin UI + is_active: Whether this prompt version is currently in use + content_hash: SHA-256 hash for change detection + created_at: When this prompt was created + updated_at: When this prompt was last modified + created_by: Username/identifier of the creator (from auth context) + """ + + __tablename__ = "prompts" + + # ========================================================================== + # Primary Key + # ========================================================================== + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + comment="Unique identifier for this prompt record" + ) + + # ========================================================================== + # Prompt Identification + # ========================================================================== + name: Mapped[str] = mapped_column( + String(100), + nullable=False, + comment="Prompt name, e.g., 'core_decision', 'chatgpt_wrapper'" + ) + + version: Mapped[str] = mapped_column( + String(20), + nullable=False, + comment="Version string, e.g., 'v1', 'v2', 
'v1.1'" + ) + + prompt_type: Mapped[str] = mapped_column( + String(20), + nullable=False, + comment="Either 'core' for shared logic or 'wrapper' for model-specific" + ) + + model_name: Mapped[Optional[str]] = mapped_column( + String(50), + nullable=True, + comment="For wrapper prompts: target model (chatgpt, gemini, claude, deepseek)" + ) + + # ========================================================================== + # Content + # ========================================================================== + content: Mapped[str] = mapped_column( + Text, + nullable=False, + comment="The prompt text with placeholders like {enriched_event}" + ) + + description: Mapped[Optional[str]] = mapped_column( + Text, + nullable=True, + comment="Human-readable description for admin UI" + ) + + # ========================================================================== + # Status and Versioning + # ========================================================================== + is_active: Mapped[bool] = mapped_column( + Boolean, + nullable=False, + default=True, + comment="Whether this version is active; only one active per name+version" + ) + + content_hash: Mapped[str] = mapped_column( + String(64), + nullable=False, + comment="SHA-256 hash of content for change detection" + ) + + # ========================================================================== + # Audit Fields + # ========================================================================== + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + comment="When this prompt was created" + ) + + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + comment="When this prompt was last modified" + ) + + created_by: Mapped[Optional[str]] = mapped_column( + String(100), + nullable=True, + comment="Username/subject from auth context who created this 
def __repr__(self) -> str:
    """Return a string representation for debugging.

    Shows the fields that identify a prompt record (name, version, type and
    activation flag). The prompt content itself is left out because it can be
    arbitrarily long.
    """
    return (
        f"<Prompt(name={self.name!r}, version={self.version!r}, "
        f"prompt_type={self.prompt_type!r}, is_active={self.is_active!r})>"
    )
@staticmethod
def calculate_sma(data: np.ndarray, period: int) -> float:
    """
    Calculate the Simple Moving Average of the most recent values.

    When fewer than ``period`` values are available, the mean of everything
    that is available is returned (a shortened window).

    Args:
        data: Array of values (oldest to newest)
        period: SMA period; must be at least 1

    Returns:
        Current SMA value, or NaN when ``data`` is empty

    Raises:
        ValueError: If ``period`` is less than 1
    """
    if period < 1:
        # data[-0:] is the whole array and a negative period slices from the
        # front, so a non-positive period would silently average the wrong window.
        raise ValueError(f"period must be >= 1, got {period}")

    values = np.asarray(data, dtype=float)
    if values.size == 0:
        # np.mean of an empty array also yields NaN but emits a RuntimeWarning.
        return float("nan")

    # A negative slice longer than the array simply returns the whole array,
    # which covers the "not enough data yet" case without a separate branch.
    return float(np.mean(values[-period:]))
+ + Args: + closes: Array of closing prices + period: SMA period (default 20) + std_dev: Standard deviation multiplier (default 2.0) + + Returns: + BollingerBandsResult with upper, middle, lower, bbw, rating, signal + """ + if len(closes) < period: + current_price = closes[-1] if len(closes) > 0 else 0 + return BollingerBandsResult( + upper=current_price, + middle=current_price, + lower=current_price, + bbw=0.0, + rating=0, + signal="NEUTRAL", + ) + + # Calculate middle band (SMA) + middle = float(np.mean(closes[-period:])) + + # Calculate standard deviation + std = float(np.std(closes[-period:])) + + # Calculate upper and lower bands + upper = middle + (std_dev * std) + lower = middle - (std_dev * std) + + # Calculate BBW (Band Width) + bbw = (upper - lower) / middle if middle != 0 else 0 + + # Calculate rating based on price position + current_price = closes[-1] + rating, signal = TACalculator._compute_bb_rating( + current_price, upper, middle, lower + ) + + return BollingerBandsResult( + upper=round(upper, 4), + middle=round(middle, 4), + lower=round(lower, 4), + bbw=round(bbw, 4), + rating=rating, + signal=signal, + ) + + @staticmethod + def _compute_bb_rating( + close: float, bb_upper: float, bb_middle: float, bb_lower: float + ) -> Tuple[int, str]: + """ + Compute Bollinger Band rating and signal. 
+ + Rating scale: + +3: Strong Buy (price above upper band) + +2: Buy (price in upper 50% of bands) + +1: Weak Buy (price above middle line) + 0: Neutral (price at middle line) + -1: Weak Sell (price below middle line) + -2: Sell (price in lower 50% of bands) + -3: Strong Sell (price below lower band) + + Args: + close: Current closing price + bb_upper: Upper Bollinger Band + bb_middle: Middle Bollinger Band (SMA) + bb_lower: Lower Bollinger Band + + Returns: + Tuple of (rating, signal) + """ + rating = 0 + if close > bb_upper: + rating = 3 + elif close > bb_middle + ((bb_upper - bb_middle) / 2): + rating = 2 + elif close > bb_middle: + rating = 1 + elif close < bb_lower: + rating = -3 + elif close < bb_middle - ((bb_middle - bb_lower) / 2): + rating = -2 + elif close < bb_middle: + rating = -1 + + signal = "NEUTRAL" + if rating >= 2: + signal = "BUY" + elif rating <= -2: + signal = "SELL" + + return rating, signal + + @staticmethod + def calculate_adx( + highs: np.ndarray, + lows: np.ndarray, + closes: np.ndarray, + period: int = 14, + ) -> float: + """ + Calculate Average Directional Index (ADX). 
+ + ADX measures trend strength (0-100): + 0-25: Weak or no trend + 25-50: Strong trend + 50-75: Very strong trend + 75-100: Extremely strong trend + + Args: + highs: Array of high prices + lows: Array of low prices + closes: Array of closing prices + period: ADX period (default 14) + + Returns: + Current ADX value (0-100) + """ + if len(closes) < period * 2: + return 25.0 # Default to weak trend if not enough data + + # Calculate True Range + prev_closes = np.roll(closes, 1) + prev_closes[0] = closes[0] + tr = np.maximum( + highs - lows, + np.maximum( + np.abs(highs - prev_closes), + np.abs(lows - prev_closes) + ) + ) + + # Calculate +DM and -DM + prev_highs = np.roll(highs, 1) + prev_lows = np.roll(lows, 1) + prev_highs[0] = highs[0] + prev_lows[0] = lows[0] + + plus_dm = np.where( + (highs - prev_highs) > (prev_lows - lows), + np.maximum(highs - prev_highs, 0), + 0 + ) + minus_dm = np.where( + (prev_lows - lows) > (highs - prev_highs), + np.maximum(prev_lows - lows, 0), + 0 + ) + + # Smooth using Wilder's method + def wilder_smooth(data: np.ndarray, period: int) -> np.ndarray: + result = np.zeros(len(data)) + result[:period] = np.nan + result[period - 1] = np.sum(data[:period]) + for i in range(period, len(data)): + result[i] = result[i - 1] - (result[i - 1] / period) + data[i] + return result + + atr_smooth = wilder_smooth(tr, period) + plus_dm_smooth = wilder_smooth(plus_dm, period) + minus_dm_smooth = wilder_smooth(minus_dm, period) + + # Calculate +DI and -DI + plus_di = 100 * (plus_dm_smooth / atr_smooth) + minus_di = 100 * (minus_dm_smooth / atr_smooth) + + # Calculate DX + dx = 100 * np.abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10) + + # Calculate ADX (smoothed DX) + adx = wilder_smooth(dx[period - 1:], period) + + return round(float(adx[-1]), 2) if not np.isnan(adx[-1]) else 25.0 + + @staticmethod + def calculate_stochastic( + highs: np.ndarray, + lows: np.ndarray, + closes: np.ndarray, + k_period: int = 14, + d_period: int = 3, + ) -> 
StochasticResult: + """ + Calculate Stochastic Oscillator. + + %K = (Current Close - Lowest Low) / (Highest High - Lowest Low) * 100 + %D = SMA of %K + + Interpretation: + > 80: Overbought + < 20: Oversold + + Args: + highs: Array of high prices + lows: Array of low prices + closes: Array of closing prices + k_period: %K period (default 14) + d_period: %D smoothing period (default 3) + + Returns: + StochasticResult with k, d, signal + """ + if len(closes) < k_period: + return StochasticResult(k=50.0, d=50.0, signal="NEUTRAL") + + # Calculate %K series + k_values = np.zeros(len(closes)) + k_values[:k_period - 1] = np.nan + + for i in range(k_period - 1, len(closes)): + highest_high = np.max(highs[i - k_period + 1:i + 1]) + lowest_low = np.min(lows[i - k_period + 1:i + 1]) + range_val = highest_high - lowest_low + + if range_val > 0: + k_values[i] = ((closes[i] - lowest_low) / range_val) * 100 + else: + k_values[i] = 50.0 + + # Calculate %D (SMA of %K) + valid_k = k_values[~np.isnan(k_values)] + if len(valid_k) >= d_period: + d_value = float(np.mean(valid_k[-d_period:])) + else: + d_value = float(valid_k[-1]) if len(valid_k) > 0 else 50.0 + + k_value = float(k_values[-1]) + + # Determine signal + if k_value > 80: + signal = "OVERBOUGHT" + elif k_value < 20: + signal = "OVERSOLD" + else: + signal = "NEUTRAL" + + return StochasticResult( + k=round(k_value, 2), + d=round(d_value, 2), + signal=signal, + ) + @classmethod def calculate_all( cls, candles: List[OHLCV], ema_periods: List[int] = None, + sma_periods: List[int] = None, macd_params: Dict = None, rsi_period: int = 14, atr_period: int = 14, + bollinger_params: Dict = None, + adx_period: int = 14, + stochastic_params: Dict = None, + include_volume: bool = True, ) -> Optional[TAResult]: """ Calculate all technical indicators from candle data. 
@@ -242,9 +580,14 @@ def calculate_all( Args: candles: List of OHLCV candles (oldest to newest) ema_periods: List of EMA periods to calculate (default [9, 21, 50]) + sma_periods: List of SMA periods to calculate (default [20]) macd_params: MACD parameters dict (fast, slow, signal) rsi_period: RSI period atr_period: ATR period + bollinger_params: Bollinger Bands parameters dict (period, std_dev) + adx_period: ADX period + stochastic_params: Stochastic parameters dict (k_period, d_period) + include_volume: Whether to include volume metrics Returns: TAResult with all indicators, or None if not enough data @@ -255,13 +598,23 @@ def calculate_all( if ema_periods is None: ema_periods = [9, 21, 50] + if sma_periods is None: + sma_periods = [20] + if macd_params is None: macd_params = {"fast": 12, "slow": 26, "signal": 9} + if bollinger_params is None: + bollinger_params = {"period": 20, "std_dev": 2.0} + + if stochastic_params is None: + stochastic_params = {"k_period": 14, "d_period": 3} + # Convert candles to numpy arrays closes = np.array([c.close for c in candles]) highs = np.array([c.high for c in candles]) lows = np.array([c.low for c in candles]) + volumes = np.array([c.volume for c in candles]) # Calculate EMAs ema_results = {} @@ -269,6 +622,12 @@ def calculate_all( ema_value = cls.calculate_ema(closes, period) ema_results[f"ema_{period}"] = round(ema_value, 4) + # Calculate SMAs + sma_results = {} + for period in sma_periods: + sma_value = cls.calculate_sma(closes, period) + sma_results[f"sma_{period}"] = round(sma_value, 4) + # Calculate MACD macd_result = cls.calculate_macd( closes, @@ -283,9 +642,41 @@ def calculate_all( # Calculate ATR atr = cls.calculate_atr(highs, lows, closes, atr_period) + # Calculate Bollinger Bands + bollinger = cls.calculate_bollinger_bands( + closes, + period=bollinger_params.get("period", 20), + std_dev=bollinger_params.get("std_dev", 2.0), + ) + + # Calculate ADX + adx = cls.calculate_adx(highs, lows, closes, adx_period) + + # 
Calculate Stochastic + stochastic = cls.calculate_stochastic( + highs, + lows, + closes, + k_period=stochastic_params.get("k_period", 14), + d_period=stochastic_params.get("d_period", 3), + ) + + # Calculate Volume metrics + current_volume = None + volume_sma = None + if include_volume and len(volumes) > 0: + current_volume = float(volumes[-1]) + volume_sma = cls.calculate_sma(volumes, 20) if len(volumes) >= 20 else float(np.mean(volumes)) + return TAResult( ema=ema_results, + sma=sma_results, macd=macd_result, rsi=rsi, atr=atr, + bollinger=bollinger, + adx=adx, + stochastic=stochastic, + volume=current_volume, + volume_sma=round(volume_sma, 2) if volume_sma else None, ) diff --git a/src/services/evaluation/models/factory.py b/src/services/evaluation/models/factory.py index b34f399..2d68de8 100644 --- a/src/services/evaluation/models/factory.py +++ b/src/services/evaluation/models/factory.py @@ -6,16 +6,17 @@ 3. Fallback decision generation for error cases Usage: - # Create adapter from environment config - adapter = create_adapter("chatgpt") + # Create adapter from LLMConfigData (from database) + from src.services.llm_config import LLMConfigData + config_data = LLMConfigData(...) 
def create_adapter(
    model_name: str,
    config_data: Optional["LLMConfigData"] = None,
) -> BaseModelAdapter:
    """Build the provider adapter for a model from its stored configuration.

    Args:
        model_name: Name of the model (e.g., 'chatgpt', 'gemini', 'claude')
        config_data: Configuration for the model. Although the parameter has a
            default of None, a value is required; omitting it raises.

    Returns:
        Adapter instance of the class registered for ``config_data.provider``,
        constructed with a ModelConfig built from ``config_data``

    Raises:
        ValueError: If config_data is not provided
    """
    if not config_data:
        raise ValueError(
            f"No configuration provided for model '{model_name}'. "
            f"Configure it via /api/v1/llm-configs/{model_name}"
        )

    # Translate the stored settings into the adapter-facing config object
    adapter_settings = {
        "model_name": model_name,
        "provider": config_data.provider,
        "api_key": config_data.api_key,
        "model_id": config_data.model_id,
        "timeout_ms": config_data.timeout_ms,
        "max_tokens": config_data.max_tokens,
    }
    model_config = ModelConfig(**adapter_settings)

    # Resolve the adapter class for this provider and instantiate it
    adapter_cls = get_adapter_class(config_data.provider)
    return adapter_cls(model_config)
@dataclass(frozen=True)
class LLMConfigData:
    """Immutable LLM configuration data for use by adapters.

    Instances are kept in the service cache and handed to callers, so the
    dataclass is frozen: attribute assignment raises instead of silently
    changing the shared cached object.

    The ``repr`` masks ``api_key`` so that logging or printing a config never
    writes the secret out.
    """

    model_name: str
    enabled: bool
    provider: str
    api_key: str
    model_id: str
    timeout_ms: int
    max_tokens: int
    prompt_path: Optional[str]

    def __repr__(self) -> str:
        """Return a dataclass-style repr with the API key masked."""
        # Only reveal whether a key is present, never its value.
        masked_key = "'***'" if self.api_key else "''"
        return (
            f"LLMConfigData(model_name={self.model_name!r}, "
            f"enabled={self.enabled!r}, provider={self.provider!r}, "
            f"api_key={masked_key}, model_id={self.model_id!r}, "
            f"timeout_ms={self.timeout_ms!r}, max_tokens={self.max_tokens!r}, "
            f"prompt_path={self.prompt_path!r})"
        )

    @classmethod
    def from_orm(cls, config: "LLMConfig") -> "LLMConfigData":
        """Create a detached snapshot from an ORM row.

        Args:
            config: Object exposing the LLMConfig column attributes

        Returns:
            New LLMConfigData holding copies of the row's field values
        """
        return cls(
            model_name=config.model_name,
            enabled=config.enabled,
            provider=config.provider,
            api_key=config.api_key,
            model_id=config.model_id,
            timeout_ms=config.timeout_ms,
            max_tokens=config.max_tokens,
            prompt_path=config.prompt_path,
        )
async def get_config(self, model_name: str) -> Optional[LLMConfigData]:
    """Return the usable configuration for one model, if there is one.

    The cached snapshot is reloaded first when its TTL has expired.

    Args:
        model_name: Model identifier (chatgpt, gemini, claude, deepseek)

    Returns:
        The cached LLMConfigData when the model is known, enabled, and has a
        non-empty API key; None in every other case
    """
    # Reload from the database before reading if the snapshot is stale
    cache_is_fresh = self._is_cache_valid()
    if not cache_is_fresh:
        await self._refresh_cache()

    entry = self._cache.get(model_name)
    if entry is None:
        # Unknown model: nothing configured under this name
        return None

    if not entry.enabled:
        logger.debug(f"Model {model_name} is disabled")
        return None

    # Enabled but without an API key is treated the same as not configured
    return entry if entry.api_key else None
+ + Args: + model_name: Unique model identifier + provider: Provider name (openai, google, anthropic, deepseek) + api_key: API key for the provider + model_id: Model identifier (e.g., gpt-4o) + enabled: Whether the model is enabled + timeout_ms: Request timeout in milliseconds + max_tokens: Maximum tokens for response + prompt_path: Optional custom prompt path + + Returns: + Updated configuration data + """ + async with get_db_context() as db: + # Check if exists + result = await db.execute( + select(LLMConfig).where(LLMConfig.model_name == model_name) + ) + config = result.scalar_one_or_none() + + now = datetime.now(timezone.utc) + + if config: + # Update existing + config.provider = provider + config.api_key = api_key + config.model_id = model_id + config.enabled = enabled + config.timeout_ms = timeout_ms + config.max_tokens = max_tokens + config.prompt_path = prompt_path + config.updated_at = now + logger.info(f"Updated LLM config for {model_name}") + else: + # Create new + config = LLMConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + model_id=model_id, + enabled=enabled, + timeout_ms=timeout_ms, + max_tokens=max_tokens, + prompt_path=prompt_path, + created_at=now, + updated_at=now, + ) + db.add(config) + logger.info(f"Created LLM config for {model_name}") + + await db.commit() + await db.refresh(config) + + # Invalidate cache + await self._refresh_cache() + + return LLMConfigData.from_orm(config) + + async def set_enabled(self, model_name: str, enabled: bool) -> bool: + """Enable or disable a model. 
async def test_api_key(self, model_name: str) -> Dict:
    """Test if an API key is valid by making a minimal API call.

    The outcome is also recorded through ``update_validation_status`` as one
    of ``ok``, ``invalid_key``, ``rate_limited`` or ``error``.

    Args:
        model_name: Model to test

    Returns:
        Dict with 'success', 'message', and 'latency_ms' keys. ``latency_ms``
        is the elapsed time of the attempt, including failed attempts; it is
        0 only when no call was made because no usable config exists.
    """
    config = await self.get_config(model_name)
    if not config:
        # get_config also returns None for disabled models and for configs
        # without an API key, so "not found" would be misleading here.
        return {
            "success": False,
            "message": (
                f"No enabled configuration with an API key found for {model_name}"
            ),
            "latency_ms": 0,
        }

    # perf_counter is monotonic, so the measurement cannot go negative or
    # jump when the wall clock is adjusted.
    started = time.perf_counter()

    try:
        # Import here to avoid circular imports
        from src.services.evaluation.models import create_adapter

        # Restart the clock so a first-time module import is not counted.
        started = time.perf_counter()
        adapter = create_adapter(model_name, config)

        # Make a minimal test call
        if hasattr(adapter, "test_connection"):
            result = await adapter.test_connection()
        else:
            # Fallback: try a minimal prompt
            response = await adapter.evaluate("Say 'ok' and nothing else.")
            result = response.is_success

        latency_ms = int((time.perf_counter() - started) * 1000)

        if result:
            await self.update_validation_status(model_name, "ok")
            return {
                "success": True,
                "message": "API key is valid",
                "latency_ms": latency_ms,
            }

        await self.update_validation_status(model_name, "error")
        return {
            "success": False,
            "message": "API call failed",
            "latency_ms": latency_ms,
        }

    except Exception as e:
        latency_ms = int((time.perf_counter() - started) * 1000)

        error_msg = str(e).lower()
        rate_limit_markers = (
            "rate limit",
            "rate_limit",
            "rate-limit",
            "ratelimit",
            "too many requests",
            "429",
        )
        if "invalid" in error_msg or "unauthorized" in error_msg or "401" in error_msg:
            status = "invalid_key"
        elif any(marker in error_msg for marker in rate_limit_markers):
            # A bare "rate" substring would also match words such as
            # "generate" or "operation" and mislabel ordinary errors.
            status = "rate_limited"
        else:
            status = "error"

        await self.update_validation_status(model_name, status)
        return {
            "success": False,
            "message": str(e),
            "latency_ms": latency_ms,
        }
b/src/services/prompt/__init__.py new file mode 100644 index 0000000..1d990bd --- /dev/null +++ b/src/services/prompt/__init__.py @@ -0,0 +1,13 @@ +"""Prompt service for database-backed prompt management.""" + +from src.services.prompt.service import ( + PromptService, + PromptData, + get_prompt_service, +) + +__all__ = [ + "PromptService", + "PromptData", + "get_prompt_service", +] diff --git a/src/services/prompt/service.py b/src/services/prompt/service.py new file mode 100644 index 0000000..85867b1 --- /dev/null +++ b/src/services/prompt/service.py @@ -0,0 +1,571 @@ +"""Prompt service with caching for database-backed prompt management. + +This service provides the core business logic for managing AI prompts: + +Features: + - Database-backed prompt storage with versioning + - In-memory caching with configurable TTL (default 5 minutes) + - CRUD operations for prompts via async methods + - Support for core + wrapper prompt pattern + - Automatic seeding from file-based prompts on first run + - Thread-safe cache operations with asyncio locks + +Architecture: + The service uses a singleton pattern (get_prompt_service()) to ensure + consistent caching across the application. The cache is refreshed on: + - TTL expiration (every 5 minutes by default) + - After any create/update/delete operation + - On service initialization + +Usage: + from src.services.prompt import get_prompt_service + + # Get the singleton service + service = get_prompt_service() + + # Initialize on startup (loads cache, seeds if empty) + await service.initialize() + + # Render a prompt for model evaluation + prompt, version, hash = await service.render_prompt( + model_name="chatgpt", + enriched_event={"signal": "BUY", ...}, + constraints={"max_position": 1000}, + ) + + # CRUD operations + await service.create(name="core_decision", version="v2", ...) + await service.update(name="core_decision", version="v2", content=...) 
+ await service.delete(name="core_decision", version="v2") +""" + +import asyncio +import hashlib +import json +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from sqlalchemy import select, and_ + +from src.models.database import get_db_context +from src.models.orm.prompt import Prompt +from src.observability.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class PromptData: + """Immutable data transfer object for prompt information. + + This dataclass is used to transfer prompt data between layers without + exposing the ORM model directly. It's immutable to prevent accidental + modifications and ensure thread-safety. + + Attributes: + id: UUID as string + name: Prompt identifier (e.g., "core_decision", "chatgpt_wrapper") + version: Version string (e.g., "v1", "v2") + prompt_type: Either "core" or "wrapper" + model_name: For wrapper prompts, the target model name + content: The actual prompt text + content_hash: SHA-256 hash for change detection + is_active: Whether this prompt is currently active + description: Human-readable description + created_at: Creation timestamp + """ + + id: str + name: str + version: str + prompt_type: str # "core" or "wrapper" + model_name: Optional[str] # For wrapper prompts only + content: str + content_hash: str + is_active: bool + description: Optional[str] + created_at: datetime + + @classmethod + def from_orm(cls, prompt: Prompt) -> "PromptData": + """Create a PromptData instance from an ORM Prompt model. 
+ + Args: + prompt: SQLAlchemy ORM Prompt instance + + Returns: + PromptData with copied values (immutable snapshot) + """ + return cls( + id=str(prompt.id), + name=prompt.name, + version=prompt.version, + prompt_type=prompt.prompt_type, + model_name=prompt.model_name, + content=prompt.content, + content_hash=prompt.content_hash, + is_active=prompt.is_active, + description=prompt.description, + created_at=prompt.created_at, + ) + + +class PromptService: + """Service for managing prompts with caching. + + Caches prompts in memory to avoid database lookups on every + evaluation. Cache is refreshed on: + - TTL expiration (default 5 minutes) + - Explicit invalidation after updates + - Service restart + """ + + CACHE_TTL_SECONDS = 300 # 5 minutes + + def __init__(self, prompts_dir: str = "prompts"): + self._cache: Dict[str, PromptData] = {} # key: f"{name}:{version}" + self._cache_timestamp: float = 0 + self._cache_lock = asyncio.Lock() + self._initialized = False + self._prompts_dir = Path(prompts_dir) + + async def initialize(self) -> None: + """Initialize service and load prompts from database. + + If database is empty, seeds from file-based prompts. 
+ """ + await self._refresh_cache() + + # If no prompts in database, seed from files + if not self._cache: + await self._seed_from_files() + await self._refresh_cache() + + self._initialized = True + logger.info(f"Prompt service initialized with {len(self._cache)} prompts") + + async def _seed_from_files(self) -> None: + """Seed database with prompts from file system.""" + if not self._prompts_dir.exists(): + logger.warning(f"Prompts directory not found: {self._prompts_dir}") + return + + seeded_count = 0 + for filepath in self._prompts_dir.glob("*.md"): + name = filepath.stem + content = filepath.read_text(encoding="utf-8") + + # Determine prompt type from filename + if name.startswith("core_decision_"): + prompt_type = "core" + prompt_name = "core_decision" + version = name.replace("core_decision_", "") + model_name = None + elif "_wrapper_" in name: + prompt_type = "wrapper" + parts = name.rsplit("_wrapper_", 1) + model_name = parts[0] + prompt_name = f"{model_name}_wrapper" + version = parts[1] if len(parts) > 1 else "v1" + else: + # Skip non-standard files + continue + + try: + await self.create( + name=prompt_name, + version=version, + prompt_type=prompt_type, + content=content, + model_name=model_name, + description=f"Seeded from {filepath.name}", + created_by="system", + ) + seeded_count += 1 + logger.info(f"Seeded prompt: {prompt_name} {version}") + except Exception as e: + logger.error(f"Failed to seed prompt {name}: {e}") + + logger.info(f"Seeded {seeded_count} prompts from files") + + async def _refresh_cache(self) -> None: + """Refresh the in-memory cache from database.""" + async with self._cache_lock: + try: + async with get_db_context() as db: + # Only load active prompts + result = await db.execute( + select(Prompt).where(Prompt.is_active == True) + ) + prompts = result.scalars().all() + + new_cache = {} + for prompt in prompts: + key = f"{prompt.name}:{prompt.version}" + new_cache[key] = PromptData.from_orm(prompt) + + self._cache = new_cache + 
self._cache_timestamp = time.time() + logger.debug(f"Prompt cache refreshed with {len(new_cache)} entries") + + except Exception as e: + logger.error(f"Failed to refresh prompt cache: {e}") + + def _is_cache_valid(self) -> bool: + """Check if cache is still valid based on TTL.""" + return (time.time() - self._cache_timestamp) < self.CACHE_TTL_SECONDS + + @staticmethod + def _compute_hash(content: str) -> str: + """Compute SHA-256 hash of content.""" + return hashlib.sha256(content.encode()).hexdigest() + + async def get_prompt( + self, + name: str, + version: str = "v1", + ) -> Optional[PromptData]: + """Get a specific prompt by name and version. + + Args: + name: Prompt name (e.g., 'core_decision', 'chatgpt_wrapper') + version: Prompt version (e.g., 'v1') + + Returns: + PromptData if found and active, None otherwise + """ + if not self._is_cache_valid(): + await self._refresh_cache() + + key = f"{name}:{version}" + return self._cache.get(key) + + async def get_core_prompt(self, version: str = "v1") -> Optional[PromptData]: + """Get the core decision prompt. + + Args: + version: Core prompt version + + Returns: + PromptData for core prompt, None if not found + """ + return await self.get_prompt("core_decision", version) + + async def get_wrapper_prompt( + self, + model_name: str, + version: str = "v1", + ) -> Optional[PromptData]: + """Get a wrapper prompt for a specific model. + + Args: + model_name: Model name (chatgpt, gemini, claude, deepseek) + version: Wrapper version + + Returns: + PromptData for wrapper prompt, None if not found + """ + return await self.get_prompt(f"{model_name}_wrapper", version) + + async def render_prompt( + self, + model_name: str, + enriched_event: dict, + constraints: dict, + core_version: str = "v1", + wrapper_version: str = "v1", + ) -> Tuple[str, str, str]: + """Render a complete prompt for model evaluation. 
+ + Args: + model_name: Model name (e.g., 'chatgpt', 'gemini') + enriched_event: The enriched signal event data + constraints: Trading constraints + core_version: Core prompt version + wrapper_version: Wrapper prompt version + + Returns: + Tuple of (rendered_prompt, prompt_version, prompt_hash) + + Raises: + ValueError: If required prompts not found + """ + # Get prompts + core = await self.get_core_prompt(core_version) + wrapper = await self.get_wrapper_prompt(model_name, wrapper_version) + + if not core: + raise ValueError(f"Core prompt version {core_version} not found") + if not wrapper: + raise ValueError(f"Wrapper prompt for {model_name} version {wrapper_version} not found") + + # Render core prompt with data + rendered_core = core.content.replace( + "{enriched_event}", + json.dumps(enriched_event, indent=2, default=str) + ).replace( + "{constraints}", + json.dumps(constraints, indent=2) + ) + + # Combine wrapper with rendered core + rendered_prompt = wrapper.content.replace("{core_prompt}", rendered_core) + + # Generate version and hash + prompt_version = f"{model_name}_{wrapper_version}_core_{core_version}" + prompt_hash = self._compute_hash(wrapper.content + core.content) + + return rendered_prompt, prompt_version, prompt_hash + + async def list_all( + self, + prompt_type: Optional[str] = None, + include_inactive: bool = False, + ) -> List[PromptData]: + """List all prompts with optional filtering. 
+ + Args: + prompt_type: Filter by type ('core' or 'wrapper') + include_inactive: Include inactive prompts + + Returns: + List of prompts matching criteria + """ + async with get_db_context() as db: + query = select(Prompt) + + if prompt_type: + query = query.where(Prompt.prompt_type == prompt_type) + + if not include_inactive: + query = query.where(Prompt.is_active == True) + + query = query.order_by(Prompt.name, Prompt.version) + + result = await db.execute(query) + prompts = result.scalars().all() + + return [PromptData.from_orm(p) for p in prompts] + + async def create( + self, + name: str, + version: str, + prompt_type: str, + content: str, + model_name: Optional[str] = None, + description: Optional[str] = None, + created_by: Optional[str] = None, + ) -> PromptData: + """Create a new prompt. + + Args: + name: Prompt name + version: Version string + prompt_type: 'core' or 'wrapper' + content: Prompt content + model_name: For wrapper prompts, which model this is for + description: Optional description + created_by: Who created this prompt + + Returns: + Created prompt data + + Raises: + ValueError: If prompt_type is invalid or duplicate exists + """ + if prompt_type not in ("core", "wrapper"): + raise ValueError(f"Invalid prompt_type: {prompt_type}") + + if prompt_type == "wrapper" and not model_name: + raise ValueError("model_name is required for wrapper prompts") + + content_hash = self._compute_hash(content) + now = datetime.now(timezone.utc) + + async with get_db_context() as db: + # Check for existing active prompt with same name/version + result = await db.execute( + select(Prompt).where( + and_( + Prompt.name == name, + Prompt.version == version, + Prompt.is_active == True, + ) + ) + ) + existing = result.scalar_one_or_none() + + if existing: + raise ValueError(f"Active prompt {name} version {version} already exists") + + prompt = Prompt( + name=name, + version=version, + prompt_type=prompt_type, + model_name=model_name, + content=content, + 
content_hash=content_hash, + description=description, + is_active=True, + created_by=created_by, + created_at=now, + updated_at=now, + ) + db.add(prompt) + await db.commit() + await db.refresh(prompt) + + logger.info(f"Created prompt: {name} {version}") + + # Invalidate cache + await self._refresh_cache() + + return PromptData.from_orm(prompt) + + async def update( + self, + name: str, + version: str, + content: Optional[str] = None, + description: Optional[str] = None, + is_active: Optional[bool] = None, + ) -> Optional[PromptData]: + """Update an existing prompt. + + Args: + name: Prompt name + version: Version string + content: New content (optional) + description: New description (optional) + is_active: New active status (optional) + + Returns: + Updated prompt data, or None if not found + """ + async with get_db_context() as db: + result = await db.execute( + select(Prompt).where( + and_(Prompt.name == name, Prompt.version == version) + ) + ) + prompt = result.scalar_one_or_none() + + if not prompt: + return None + + now = datetime.now(timezone.utc) + + if content is not None: + prompt.content = content + prompt.content_hash = self._compute_hash(content) + + if description is not None: + prompt.description = description + + if is_active is not None: + prompt.is_active = is_active + + prompt.updated_at = now + + await db.commit() + await db.refresh(prompt) + + logger.info(f"Updated prompt: {name} {version}") + + # Invalidate cache + await self._refresh_cache() + + return PromptData.from_orm(prompt) + + async def set_active( + self, + name: str, + version: str, + is_active: bool, + ) -> bool: + """Set active status for a prompt. + + Args: + name: Prompt name + version: Version string + is_active: New active status + + Returns: + True if updated, False if not found + """ + result = await self.update(name, version, is_active=is_active) + return result is not None + + async def delete(self, name: str, version: str) -> bool: + """Delete a prompt. 
+ + Args: + name: Prompt name + version: Version string + + Returns: + True if deleted, False if not found + """ + async with get_db_context() as db: + result = await db.execute( + select(Prompt).where( + and_(Prompt.name == name, Prompt.version == version) + ) + ) + prompt = result.scalar_one_or_none() + + if not prompt: + return False + + await db.delete(prompt) + await db.commit() + + logger.info(f"Deleted prompt: {name} {version}") + + # Invalidate cache + await self._refresh_cache() + + return True + + async def get_available_prompts(self) -> dict: + """List available prompts grouped by type. + + Returns: + Dict with core_versions and wrappers keys + """ + prompts = await self.list_all() + + result = { + "core_versions": [], + "wrappers": {}, + } + + for p in prompts: + if p.prompt_type == "core": + result["core_versions"].append(p.version) + elif p.prompt_type == "wrapper" and p.model_name: + if p.model_name not in result["wrappers"]: + result["wrappers"][p.model_name] = [] + result["wrappers"][p.model_name].append(p.version) + + return result + + def invalidate_cache(self) -> None: + """Force cache invalidation (synchronous).""" + self._cache_timestamp = 0 + + +# Singleton instance +_prompt_service: Optional[PromptService] = None + + +def get_prompt_service() -> PromptService: + """Get the singleton prompt service instance.""" + global _prompt_service + if _prompt_service is None: + _prompt_service = PromptService() + return _prompt_service diff --git a/src/services/publisher/ws_server.py b/src/services/publisher/ws_server.py index a86662e..ea611fb 100644 --- a/src/services/publisher/ws_server.py +++ b/src/services/publisher/ws_server.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Dict, List +from typing import Dict, List, Optional from uuid import uuid4 from fastapi import WebSocket, WebSocketDisconnect @@ -47,17 +47,18 @@ def __init__(self): self.subscriptions: Dict[str, 
Subscription] = {} self._lock = asyncio.Lock() - async def connect(self, websocket: WebSocket) -> str: + async def connect(self, websocket: WebSocket, subprotocol: Optional[str] = None) -> str: """ Accept a new WebSocket connection. Args: websocket: FastAPI WebSocket instance + subprotocol: Optional subprotocol to echo back (e.g., "bearer" for auth) Returns: Subscription ID """ - await websocket.accept() + await websocket.accept(subprotocol=subprotocol) sub_id = str(uuid4()) subscription = Subscription( @@ -221,14 +222,15 @@ def get_stats(self) -> dict: ws_manager = WebSocketManager() -async def handle_websocket(websocket: WebSocket) -> None: +async def handle_websocket(websocket: WebSocket, subprotocol: Optional[str] = None) -> None: """ Handle a WebSocket connection lifecycle. Args: websocket: FastAPI WebSocket instance + subprotocol: Optional subprotocol to echo back (e.g., "bearer" for auth) """ - sub_id = await ws_manager.connect(websocket) + sub_id = await ws_manager.connect(websocket, subprotocol=subprotocol) try: while True: diff --git a/src/workers/evaluation_worker.py b/src/workers/evaluation_worker.py index 708eae0..d3dd4f6 100644 --- a/src/workers/evaluation_worker.py +++ b/src/workers/evaluation_worker.py @@ -9,11 +9,11 @@ - Per-model error isolation (one failure doesn't block others) - Token usage tracking - Prompt version tracking + - Runtime LLM configuration via database (API key management without restart) Configuration: - AI_MODELS: Comma-separated list of models to use (e.g., "chatgpt,gemini,claude") - USE_REAL_AI: Set to true for real AI evaluation, false for stub mode (default: false) - MODEL_{NAME}_*: Per-model configuration (API_KEY, MODEL_ID, TIMEOUT_MS, etc.) 
+ USE_REAL_AI: Set to true for real AI evaluation, false for stub mode + LLM configs: Managed via /api/v1/llm-configs endpoints or env vars (legacy) Usage: worker = EvaluationWorker(redis_client, "worker-1") @@ -41,7 +41,8 @@ normalize_decision_output, validate_decision_output, ) -from src.services.evaluation.prompt_loader import get_prompt_for_model +from src.services.llm_config import get_llm_config_service +from src.services.prompt import get_prompt_service from src.services.publisher.publisher import publisher from src.services.queue import QueueConsumer, RedisClient @@ -81,18 +82,26 @@ def __init__(self, redis_client: RedisClient, consumer_name: str): ) self._adapters: Dict[str, BaseModelAdapter] = {} - def _get_adapter(self, model_name: str) -> Optional[BaseModelAdapter]: + async def _get_adapter(self, model_name: str) -> Optional[BaseModelAdapter]: """Get or create adapter for a model (lazy initialization). + Fetches configuration from LLM config service (database-backed with caching) + and creates adapter. Falls back to environment variables if not in database. 
+ Args: model_name: Name of the model (e.g., 'chatgpt') Returns: - Model adapter or None if creation fails + Model adapter or None if creation fails or model not configured """ if model_name not in self._adapters: try: - adapter = create_adapter(model_name) + # Get config from LLM config service (database-backed) + llm_service = get_llm_config_service() + config = await llm_service.get_config(model_name) + + # Create adapter with config (or fall back to env vars if None) + adapter = create_adapter(model_name, config) if adapter.is_configured: self._adapters[model_name] = adapter logger.info(f"Initialized adapter for {model_name}") @@ -121,8 +130,14 @@ async def process_message(self, event_id: str, payload: Dict[str, Any]) -> bool: log_stage(logger, "EVALUATION", event_id, status="started") try: - # Get models to evaluate - models = settings.ai_models_list + # Get enabled models from LLM config service (database-backed) + llm_service = get_llm_config_service() + models = await llm_service.get_enabled_models() + + if not models: + # Fall back to settings if no models configured in database + models = settings.ai_models_list + logger.debug("No models in database, using settings.ai_models_list") if settings.use_real_ai: # Parallel evaluation with real AI models @@ -258,15 +273,16 @@ async def _evaluate_with_model( start_time = time.time() try: - # Get adapter - adapter = self._get_adapter(model_name) + # Get adapter (async - fetches config from database) + adapter = await self._get_adapter(model_name) if adapter is None: logger.warning(f"No adapter available for {model_name}, skipping") return None - # Render prompt + # Render prompt using database-backed prompt service constraints = payload.get("constraints", {}) - prompt, prompt_version, prompt_hash = get_prompt_for_model( + prompt_service = get_prompt_service() + prompt, prompt_version, prompt_hash = await prompt_service.render_prompt( model_name=model_name, enriched_event=payload, constraints=constraints, diff 
--git a/tests/conftest.py b/tests/conftest.py index 115d162..7be9750 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ os.environ["REDIS_URL"] = "redis://localhost:6379" os.environ["LOG_LEVEL"] = "WARNING" os.environ["USE_REAL_AI"] = "false" # Use stub decisions in tests +os.environ["AUTH_MODE"] = "none" # Disable auth in tests @pytest.fixture(scope="session") diff --git a/tests/test_api/test_auth.py b/tests/test_api/test_auth.py new file mode 100644 index 0000000..b0ba47a --- /dev/null +++ b/tests/test_api/test_auth.py @@ -0,0 +1,833 @@ +"""API integration tests for authentication. + +These tests verify auth behavior at the API endpoint level, +complementing the unit tests in test_core/test_auth.py. +""" + +import pytest +from unittest.mock import patch, AsyncMock +from fastapi.testclient import TestClient + + +@pytest.fixture +def mock_db_session(): + """Mock database session for API tests.""" + mock = AsyncMock() + mock.commit = AsyncMock() + mock.rollback = AsyncMock() + return mock + + +@pytest.mark.unit +class TestAuthModeNone: + """Tests for AUTH_MODE=none (development mode).""" + + def test_signals_endpoint_allows_unauthenticated(self, mock_db_session): + """Test that signals endpoint allows unauthenticated requests in mode=none.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + # Import after patching + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization=None) + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_admin_endpoints_allow_unauthenticated(self, mock_db_session): + """Test that admin endpoints allow unauthenticated in mode=none.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization=None) + assert ctx.has_scope(Scope.ADMIN) is True + + +@pytest.mark.unit +class 
TestAuthModePSK: + """Tests for AUTH_MODE=psk (pre-shared key mode).""" + + def test_signals_requires_submit_token(self): + """Test that signals endpoint requires submit token in PSK mode.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + # No token - should fail + ctx = get_auth_context(authorization=None) + assert ctx.authenticated is False + + dependency = require_scope(Scope.SUBMIT) + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 401 + + def test_signals_accepts_submit_token(self): + """Test that signals endpoint accepts valid submit token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + + ctx = get_auth_context(authorization="Bearer submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + dependency = require_scope(Scope.SUBMIT) + result = dependency(ctx) + assert result.authenticated is True + + def test_signals_rejects_read_token(self): + """Test that signals endpoint rejects read token (wrong scope).""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + ctx = get_auth_context(authorization="Bearer 
read-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.SUBMIT) is False + + dependency = require_scope(Scope.SUBMIT) + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 403 + + def test_admin_token_grants_all_scopes(self): + """Test that admin token grants access to all scopes.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + + ctx = get_auth_context(authorization="Bearer admin-secret") + assert ctx.authenticated is True + + # Admin should satisfy all scope requirements + for scope in [Scope.SUBMIT, Scope.READ, Scope.ADMIN]: + dependency = require_scope(scope) + result = dependency(ctx) + assert result.authenticated is True + + def test_events_requires_read_token(self): + """Test that events endpoint requires read token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + + # Submit token should not work for read endpoints + ctx = get_auth_context(authorization="Bearer submit-secret") + assert ctx.has_scope(Scope.READ) is False + + # Read token should work + ctx = get_auth_context(authorization="Bearer read-secret") + assert ctx.has_scope(Scope.READ) is True + + def test_dlq_retry_requires_admin_token(self): + """Test that DLQ retry requires admin token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ 
= "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + # Read token should not work for admin endpoints + ctx = get_auth_context(authorization="Bearer read-secret") + dependency = require_scope(Scope.ADMIN) + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 403 + + # Admin token should work + ctx = get_auth_context(authorization="Bearer admin-secret") + result = dependency(ctx) + assert result.authenticated is True + + def test_llm_configs_requires_admin_token(self): + """Test that LLM configs endpoints require admin token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + # Submit token should not work + ctx = get_auth_context(authorization="Bearer submit-secret") + dependency = require_scope(Scope.ADMIN) + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 403 + + +@pytest.mark.unit +class TestWebSocketAuth: + """Tests for WebSocket authentication.""" + + @pytest.mark.asyncio + async def test_websocket_auth_mode_none(self): + """Test WebSocket allows all connections in mode=none.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + from src.core.auth import get_websocket_auth_context, Scope + + mock_ws = MagicMock() + mock_ws.headers = {} + + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + + @pytest.mark.asyncio + async def test_websocket_auth_psk_valid_token(self): + """Test 
WebSocket accepts valid PSK token.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_websocket_auth_context, Scope + + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer,read-secret"} + + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + + @pytest.mark.asyncio + async def test_websocket_auth_psk_missing_token(self): + """Test WebSocket rejects missing token in PSK mode.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_websocket_auth_context + + mock_ws = MagicMock() + mock_ws.headers = {} + + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is False + + @pytest.mark.asyncio + async def test_websocket_auth_psk_invalid_token(self): + """Test WebSocket rejects invalid PSK token.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_websocket_auth_context + + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer,wrong-token"} + + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is False + + +@pytest.mark.unit +class TestAuthErrorResponses: + """Tests for auth error response formats.""" + + def 
test_401_response_format(self): + """Test 401 error response has correct format.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + ctx = get_auth_context(authorization=None) + dependency = require_scope(Scope.SUBMIT) + + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail["error"]["code"] == "UNAUTHORIZED" + assert "WWW-Authenticate" in exc_info.value.headers + + def test_403_response_format(self): + """Test 403 error response has correct format.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + # Authenticated but wrong scope + ctx = get_auth_context(authorization="Bearer submit-secret") + dependency = require_scope(Scope.ADMIN) + + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail["error"]["code"] == "FORBIDDEN" + assert "lens:admin" in exc_info.value.detail["error"]["message"] + + +@pytest.mark.unit +class TestScopeEndpointMapping: + """Tests verifying correct scope requirements on endpoints.""" + + def test_submit_scope_endpoints(self): + """Verify endpoints that require lens:submit scope.""" + from src.core.auth import Scope + + # These should require SUBMIT + submit_endpoints = [ + ("POST", "/api/v1/signals"), + ] + + # Document the mapping (actual enforcement tested above) + for method, path in submit_endpoints: + assert Scope.SUBMIT.value == "lens:submit" + + def test_read_scope_endpoints(self): + 
"""Verify endpoints that require lens:read scope.""" + from src.core.auth import Scope + + # These should require READ + read_endpoints = [ + ("GET", "/api/v1/events"), + ("GET", "/api/v1/events/{event_id}"), + ("GET", "/api/v1/events/{event_id}/status"), + ("GET", "/api/v1/decisions"), + ("GET", "/api/v1/decisions/{decision_id}"), + ("GET", "/api/v1/dlq"), + ("GET", "/api/v1/dlq/{dlq_id}"), + ("WS", "/api/v1/ws/stream"), + ] + + # Document the mapping + for method, path in read_endpoints: + assert Scope.READ.value == "lens:read" + + def test_admin_scope_endpoints(self): + """Verify endpoints that require lens:admin scope.""" + from src.core.auth import Scope + + # These should require ADMIN + admin_endpoints = [ + ("POST", "/api/v1/dlq/{dlq_id}/retry"), + ("POST", "/api/v1/dlq/{dlq_id}/resolve"), + ("GET", "/api/v1/llm-configs"), + ("GET", "/api/v1/llm-configs/{model_name}"), + ("PUT", "/api/v1/llm-configs/{model_name}"), + ("PATCH", "/api/v1/llm-configs/{model_name}"), + ("DELETE", "/api/v1/llm-configs/{model_name}"), + ("POST", "/api/v1/llm-configs/{model_name}/test"), + ("POST", "/api/v1/llm-configs/{model_name}/enable"), + ("POST", "/api/v1/llm-configs/{model_name}/disable"), + ] + + # Document the mapping + for method, path in admin_endpoints: + assert Scope.ADMIN.value == "lens:admin" + + +@pytest.mark.unit +class TestAuthorizationHeaderEdgeCases: + """Tests for edge cases in Authorization header parsing.""" + + def test_bearer_case_insensitive(self): + """Test that Bearer prefix is case-insensitive.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + # Lowercase 'bearer' + ctx = get_auth_context(authorization="bearer submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + # Uppercase 
'BEARER' + ctx = get_auth_context(authorization="BEARER submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + # Mixed case 'BeArEr' + ctx = get_auth_context(authorization="BeArEr submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_missing_bearer_prefix(self): + """Test that token without Bearer prefix is rejected.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + # Just the token without Bearer prefix + ctx = get_auth_context(authorization="submit-secret") + assert ctx.authenticated is False + + def test_wrong_auth_scheme(self): + """Test that non-Bearer auth schemes are rejected.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + # Basic auth scheme + ctx = get_auth_context(authorization="Basic dXNlcjpwYXNz") + assert ctx.authenticated is False + + # Digest auth scheme + ctx = get_auth_context(authorization="Digest username=test") + assert ctx.authenticated is False + + def test_empty_authorization_header(self): + """Test that empty Authorization header is handled.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + ctx = get_auth_context(authorization="") + assert ctx.authenticated is False + + def test_bearer_only_no_token(self): + """Test that 'Bearer' without token is rejected.""" + with 
patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + # Just 'Bearer' with nothing else + ctx = get_auth_context(authorization="Bearer") + assert ctx.authenticated is False + + # 'Bearer ' with trailing space but no token + ctx = get_auth_context(authorization="Bearer ") + assert ctx.authenticated is False + + def test_extra_spaces_in_header(self): + """Test handling of extra spaces in Authorization header.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + # Multiple spaces between Bearer and token + # split() without args splits on any whitespace and removes empties + # So "Bearer submit-secret".split() returns ['Bearer', 'submit-secret'] + ctx = get_auth_context(authorization="Bearer submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_bearer_with_multiple_parts(self): + """Test that tokens with spaces are handled correctly.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + # Token with space (e.g., "Bearer token with space") + ctx = get_auth_context(authorization="Bearer token with space") + # Should extract only "token" and fail since it doesn't match + assert ctx.authenticated is False + + +@pytest.mark.unit +class TestPSKPartialConfiguration: + """Tests for partially configured PSK tokens.""" + + def test_only_submit_token_configured(self): + 
"""Test PSK mode with only submit token configured.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + # Submit token works + ctx = get_auth_context(authorization="Bearer submit-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + # Random token fails (no read/admin tokens to match) + ctx = get_auth_context(authorization="Bearer random-token") + assert ctx.authenticated is False + + def test_only_read_token_configured(self): + """Test PSK mode with only read token configured.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + # Read token works + ctx = get_auth_context(authorization="Bearer read-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + + # Cannot get submit scope + assert ctx.has_scope(Scope.SUBMIT) is False + + def test_only_admin_token_configured(self): + """Test PSK mode with only admin token configured.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, Scope + + # Admin token works and grants all scopes + ctx = get_auth_context(authorization="Bearer admin-secret") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.ADMIN) is True + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_empty_string_tokens_not_matched(self): + """Test that empty 
string tokens are not matched.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "" + mock_settings.AUTH_TOKEN_READ = "" + mock_settings.AUTH_TOKEN_ADMIN = "" + + from src.core.auth import get_auth_context + + # Empty token should not authenticate + ctx = get_auth_context(authorization="Bearer ") + assert ctx.authenticated is False + + # Empty string exact match should also fail + ctx = get_auth_context(authorization="Bearer") + assert ctx.authenticated is False + + def test_no_tokens_configured(self): + """Test PSK mode with no tokens configured at all.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context + + # Any token should fail + ctx = get_auth_context(authorization="Bearer any-token") + assert ctx.authenticated is False + + +@pytest.mark.unit +class TestWebSocketAuthEdgeCases: + """Tests for WebSocket authentication edge cases.""" + + @pytest.mark.asyncio + async def test_websocket_malformed_protocol_header(self): + """Test WebSocket with malformed sec-websocket-protocol header.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_websocket_auth_context + + # Missing comma separator + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer read-secret"} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is False + + # Just 'bearer' without token + mock_ws.headers = {"sec-websocket-protocol": "bearer"} + ctx = await get_websocket_auth_context(mock_ws) + assert 
ctx.authenticated is False + + # Empty protocol header + mock_ws.headers = {"sec-websocket-protocol": ""} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is False + + @pytest.mark.asyncio + async def test_websocket_bearer_case_sensitivity(self): + """Test that WebSocket bearer prefix is case-insensitive.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_websocket_auth_context, Scope + + # Lowercase 'bearer' + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer,read-secret"} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + + # Uppercase 'BEARER' + mock_ws.headers = {"sec-websocket-protocol": "BEARER,read-secret"} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + + @pytest.mark.asyncio + async def test_websocket_extra_protocols(self): + """Test WebSocket with additional protocol values.""" + from unittest.mock import MagicMock + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_websocket_auth_context, Scope + + # Extra protocol after token (should still work, takes parts[1]) + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer,read-secret,extra"} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + + @pytest.mark.asyncio + async def test_websocket_whitespace_in_protocol(self): + """Test WebSocket handles whitespace in protocol header.""" + from unittest.mock import MagicMock + + with 
patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_websocket_auth_context, Scope + + # Spaces around comma (should be stripped) + mock_ws = MagicMock() + mock_ws.headers = {"sec-websocket-protocol": "bearer, read-secret"} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + assert ctx.has_scope(Scope.READ) is True + + # Spaces on both sides + mock_ws.headers = {"sec-websocket-protocol": " bearer , read-secret "} + ctx = await get_websocket_auth_context(mock_ws) + assert ctx.authenticated is True + + +@pytest.mark.unit +class TestAuthContextHierarchy: + """Tests for scope hierarchy and AuthContext.has_scope().""" + + def test_admin_includes_submit_scope(self): + """Test that admin scope includes submit via hierarchy.""" + from src.core.auth import AuthContext, Scope + + ctx = AuthContext( + authenticated=True, + scopes={Scope.ADMIN}, + token_type="psk", + subject="admin" + ) + + assert ctx.has_scope(Scope.ADMIN) is True + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_read_does_not_include_submit(self): + """Test that read scope does not include submit.""" + from src.core.auth import AuthContext, Scope + + ctx = AuthContext( + authenticated=True, + scopes={Scope.READ}, + token_type="psk", + subject="read" + ) + + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.SUBMIT) is False + assert ctx.has_scope(Scope.ADMIN) is False + + def test_submit_does_not_include_read(self): + """Test that submit scope does not include read.""" + from src.core.auth import AuthContext, Scope + + ctx = AuthContext( + authenticated=True, + scopes={Scope.SUBMIT}, + token_type="psk", + subject="submit" + ) + + assert ctx.has_scope(Scope.SUBMIT) is True + assert ctx.has_scope(Scope.READ) is False 
+ assert ctx.has_scope(Scope.ADMIN) is False + + def test_empty_scopes_has_nothing(self): + """Test that empty scopes grants nothing.""" + from src.core.auth import AuthContext, Scope + + ctx = AuthContext( + authenticated=False, + scopes=set(), + ) + + assert ctx.has_scope(Scope.SUBMIT) is False + assert ctx.has_scope(Scope.READ) is False + assert ctx.has_scope(Scope.ADMIN) is False + + def test_multiple_scopes(self): + """Test context with multiple explicit scopes.""" + from src.core.auth import AuthContext, Scope + + ctx = AuthContext( + authenticated=True, + scopes={Scope.SUBMIT, Scope.READ}, + token_type="jwt", + subject="multi-scope-user" + ) + + assert ctx.has_scope(Scope.SUBMIT) is True + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.ADMIN) is False + + +@pytest.mark.unit +class TestInvalidTokenFormats: + """Tests for various invalid token formats.""" + + def test_token_with_special_characters(self): + """Test tokens containing special characters.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "token-with-special!@#$%" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization="Bearer token-with-special!@#$%") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_very_long_token(self): + """Test handling of very long tokens.""" + with patch("src.core.auth.settings") as mock_settings: + long_token = "a" * 1000 + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = long_token + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization=f"Bearer {long_token}") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_unicode_in_token(self): + 
"""Test handling of unicode characters in tokens.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "token-üñîçödé-🔐" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization="Bearer token-üñîçödé-🔐") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True + + def test_base64_like_token(self): + """Test tokens that look like base64 (common format).""" + with patch("src.core.auth.settings") as mock_settings: + # Realistic base64-encoded token + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "1IT0W8e-lVbahdtLQ7vGVc_doPYXKGBfLUosM8V57Ac" + mock_settings.AUTH_TOKEN_READ = None + mock_settings.AUTH_TOKEN_ADMIN = None + + from src.core.auth import get_auth_context, Scope + + ctx = get_auth_context(authorization="Bearer 1IT0W8e-lVbahdtLQ7vGVc_doPYXKGBfLUosM8V57Ac") + assert ctx.authenticated is True + assert ctx.has_scope(Scope.SUBMIT) is True \ No newline at end of file diff --git a/tests/test_api/test_prompts.py b/tests/test_api/test_prompts.py new file mode 100644 index 0000000..d2da4bb --- /dev/null +++ b/tests/test_api/test_prompts.py @@ -0,0 +1,497 @@ +"""Tests for prompt management.""" + +import pytest +from unittest.mock import patch, AsyncMock, MagicMock +from datetime import datetime, timezone +from httpx import AsyncClient + + +@pytest.fixture +def sample_prompt_data(): + """Sample prompt data for tests.""" + from src.services.prompt.service import PromptData + + return PromptData( + id="550e8400-e29b-41d4-a716-446655440000", + name="core_decision", + version="v1", + prompt_type="core", + model_name=None, + content="# Core Decision Prompt\n\n{enriched_event}\n{constraints}", + content_hash="abc123def456", + is_active=True, + description="Core decision prompt v1", + created_at=datetime.now(timezone.utc), + ) + + 
+@pytest.fixture +def sample_wrapper_prompt_data(): + """Sample wrapper prompt data for tests.""" + from src.services.prompt.service import PromptData + + return PromptData( + id="660e8400-e29b-41d4-a716-446655440001", + name="chatgpt_wrapper", + version="v1", + prompt_type="wrapper", + model_name="chatgpt", + content="# ChatGPT Wrapper\n\n{core_prompt}", + content_hash="def456abc123", + is_active=True, + description="ChatGPT wrapper prompt v1", + created_at=datetime.now(timezone.utc), + ) + + +@pytest.mark.unit +class TestPromptService: + """Tests for PromptService.""" + + @pytest.mark.asyncio + async def test_get_prompt_from_cache(self, sample_prompt_data): + """Test getting a prompt from cache.""" + from src.services.prompt.service import PromptService + + service = PromptService() + # Manually populate cache + service._cache = { + "core_decision:v1": sample_prompt_data, + } + service._cache_timestamp = 9999999999 # Far future + + result = await service.get_prompt("core_decision", "v1") + + assert result is not None + assert result.name == "core_decision" + assert result.version == "v1" + + @pytest.mark.asyncio + async def test_get_prompt_not_found(self): + """Test getting a non-existent prompt.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = {} + service._cache_timestamp = 9999999999 + + result = await service.get_prompt("nonexistent", "v1") + + assert result is None + + @pytest.mark.asyncio + async def test_get_core_prompt(self, sample_prompt_data): + """Test getting the core prompt.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = { + "core_decision:v1": sample_prompt_data, + } + service._cache_timestamp = 9999999999 + + result = await service.get_core_prompt("v1") + + assert result is not None + assert result.prompt_type == "core" + + @pytest.mark.asyncio + async def test_get_wrapper_prompt(self, sample_wrapper_prompt_data): + """Test 
getting a wrapper prompt.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = { + "chatgpt_wrapper:v1": sample_wrapper_prompt_data, + } + service._cache_timestamp = 9999999999 + + result = await service.get_wrapper_prompt("chatgpt", "v1") + + assert result is not None + assert result.prompt_type == "wrapper" + assert result.model_name == "chatgpt" + + @pytest.mark.asyncio + async def test_render_prompt(self, sample_prompt_data, sample_wrapper_prompt_data): + """Test rendering a complete prompt.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = { + "core_decision:v1": sample_prompt_data, + "chatgpt_wrapper:v1": sample_wrapper_prompt_data, + } + service._cache_timestamp = 9999999999 + + rendered, version, hash = await service.render_prompt( + model_name="chatgpt", + enriched_event={"signal": "BUY", "price": 100.0}, + constraints={"max_position": 1000}, + ) + + assert "chatgpt_v1_core_v1" in version + assert len(hash) == 64 # SHA-256 hash + assert "signal" in rendered + assert "BUY" in rendered + + @pytest.mark.asyncio + async def test_render_prompt_missing_core(self, sample_wrapper_prompt_data): + """Test render fails when core prompt is missing.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = { + "chatgpt_wrapper:v1": sample_wrapper_prompt_data, + } + service._cache_timestamp = 9999999999 + + with pytest.raises(ValueError) as exc_info: + await service.render_prompt( + model_name="chatgpt", + enriched_event={}, + constraints={}, + ) + + assert "Core prompt" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_render_prompt_missing_wrapper(self, sample_prompt_data): + """Test render fails when wrapper prompt is missing.""" + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache = { + "core_decision:v1": sample_prompt_data, + } + 
service._cache_timestamp = 9999999999 + + with pytest.raises(ValueError) as exc_info: + await service.render_prompt( + model_name="unknown", + enriched_event={}, + constraints={}, + ) + + assert "Wrapper prompt" in str(exc_info.value) + + def test_compute_hash(self): + """Test content hash computation.""" + from src.services.prompt.service import PromptService + + hash1 = PromptService._compute_hash("test content") + hash2 = PromptService._compute_hash("test content") + hash3 = PromptService._compute_hash("different content") + + assert hash1 == hash2 + assert hash1 != hash3 + assert len(hash1) == 64 + + def test_cache_validity(self): + """Test cache TTL validation.""" + import time + from src.services.prompt.service import PromptService + + service = PromptService() + + # Cache is invalid when timestamp is 0 + service._cache_timestamp = 0 + assert service._is_cache_valid() is False + + # Cache is valid when timestamp is recent + service._cache_timestamp = time.time() + assert service._is_cache_valid() is True + + # Cache is invalid when timestamp is old + service._cache_timestamp = time.time() - 600 # 10 min ago + assert service._is_cache_valid() is False + + def test_invalidate_cache(self): + """Test cache invalidation.""" + import time + from src.services.prompt.service import PromptService + + service = PromptService() + service._cache_timestamp = time.time() + + service.invalidate_cache() + + assert service._cache_timestamp == 0 + assert service._is_cache_valid() is False + + +@pytest.mark.unit +class TestPromptData: + """Tests for PromptData dataclass.""" + + def test_from_orm(self): + """Test creating PromptData from ORM model.""" + from src.services.prompt.service import PromptData + from src.models.orm.prompt import Prompt + import uuid + + now = datetime.now(timezone.utc) + prompt_id = uuid.uuid4() + + # Create a mock ORM object + orm_prompt = MagicMock(spec=Prompt) + orm_prompt.id = prompt_id + orm_prompt.name = "core_decision" + orm_prompt.version = 
"v1" + orm_prompt.prompt_type = "core" + orm_prompt.model_name = None + orm_prompt.content = "# Test" + orm_prompt.content_hash = "abc123" + orm_prompt.is_active = True + orm_prompt.description = "Test prompt" + orm_prompt.created_at = now + + data = PromptData.from_orm(orm_prompt) + + assert data.id == str(prompt_id) + assert data.name == "core_decision" + assert data.version == "v1" + assert data.prompt_type == "core" + assert data.content == "# Test" + assert data.is_active is True + + +@pytest.mark.unit +class TestPromptAuthRequirements: + """Tests for authentication requirements on prompt endpoints.""" + + def test_prompt_endpoints_require_admin_scope(self): + """Test that prompt endpoints require admin scope.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + + from src.core.auth import get_auth_context, require_scope, Scope + from fastapi import HTTPException + + # Read token should not work + ctx = get_auth_context(authorization="Bearer read-secret") + dependency = require_scope(Scope.ADMIN) + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 403 + + # Submit token should not work + ctx = get_auth_context(authorization="Bearer submit-secret") + with pytest.raises(HTTPException) as exc_info: + dependency(ctx) + assert exc_info.value.status_code == 403 + + # Admin token should work + ctx = get_auth_context(authorization="Bearer admin-secret") + result = dependency(ctx) + assert result.authenticated is True + + def test_prompt_endpoints_allow_all_in_mode_none(self): + """Test that prompt endpoints allow all requests in AUTH_MODE=none.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + from src.core.auth import get_auth_context, Scope + + ctx = 
get_auth_context(authorization=None) + assert ctx.has_scope(Scope.ADMIN) is True + + +@pytest.mark.unit +class TestPromptAPIIntegration: + """Integration tests for prompt API endpoints using async client.""" + + @pytest.mark.asyncio + async def test_list_prompts_endpoint(self, client: AsyncClient): + """Test GET /prompts endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_all = AsyncMock(return_value=[]) + mock_get_service.return_value = mock_service + + response = await client.get("/api/v1/prompts") + + assert response.status_code == 200 + data = response.json() + assert "items" in data + assert "total" in data + + @pytest.mark.asyncio + async def test_available_prompts_endpoint(self, client: AsyncClient): + """Test GET /prompts/available endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_available_prompts = AsyncMock(return_value={ + "core_versions": ["v1"], + "wrappers": {"chatgpt": ["v1"]}, + }) + mock_get_service.return_value = mock_service + + response = await client.get("/api/v1/prompts/available") + + assert response.status_code == 200 + data = response.json() + assert "core_versions" in data + assert "wrappers" in data + + @pytest.mark.asyncio + async def test_get_prompt_not_found(self, client: AsyncClient): + """Test GET /prompts/{name}/{version} returns 404 when not found.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_prompt = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = await client.get("/api/v1/prompts/nonexistent/v1") + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_create_prompt_endpoint(self, client: AsyncClient, sample_prompt_data): + """Test POST /prompts endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as 
mock_get_service: + mock_service = MagicMock() + mock_service.create = AsyncMock(return_value=sample_prompt_data) + mock_get_service.return_value = mock_service + + response = await client.post( + "/api/v1/prompts", + json={ + "name": "core_decision", + "version": "v1", + "prompt_type": "core", + "content": "# Core Prompt", + }, + ) + + assert response.status_code == 201 + + @pytest.mark.asyncio + async def test_delete_prompt_endpoint(self, client: AsyncClient): + """Test DELETE /prompts/{name}/{version} endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.delete = AsyncMock(return_value=True) + mock_get_service.return_value = mock_service + + response = await client.delete("/api/v1/prompts/core_decision/v1") + + assert response.status_code == 204 + + @pytest.mark.asyncio + async def test_render_prompt_endpoint(self, client: AsyncClient): + """Test POST /prompts/render endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.render_prompt = AsyncMock(return_value=( + "Rendered content", + "chatgpt_v1_core_v1", + "abc123def456", + )) + mock_get_service.return_value = mock_service + + response = await client.post( + "/api/v1/prompts/render", + json={ + "model_name": "chatgpt", + "enriched_event": {"signal": "BUY"}, + "constraints": {"max_position": 1000}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "rendered_prompt" in data + assert "prompt_version" in data + assert "prompt_hash" in data + + @pytest.mark.asyncio + async def test_update_prompt_endpoint(self, client: AsyncClient, sample_prompt_data): + """Test PUT /prompts/{name}/{version} endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.update = AsyncMock(return_value=sample_prompt_data) + mock_get_service.return_value = mock_service + + response = 
await client.put( + "/api/v1/prompts/core_decision/v1", + json={ + "content": "# Updated Content", + "description": "Updated description", + }, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_patch_prompt_endpoint(self, client: AsyncClient, sample_prompt_data): + """Test PATCH /prompts/{name}/{version} endpoint.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + sample_prompt_data.is_active = False + mock_service.update = AsyncMock(return_value=sample_prompt_data) + mock_get_service.return_value = mock_service + + response = await client.patch( + "/api/v1/prompts/core_decision/v1", + json={"is_active": False}, + ) + + assert response.status_code == 200 + assert response.json()["is_active"] is False + + @pytest.mark.asyncio + async def test_create_prompt_validation_error(self, client: AsyncClient): + """Test POST /prompts with validation error.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.create = AsyncMock( + side_effect=ValueError("model_name is required for wrapper prompts") + ) + mock_get_service.return_value = mock_service + + response = await client.post( + "/api/v1/prompts", + json={ + "name": "test_wrapper", + "version": "v1", + "prompt_type": "wrapper", + "content": "# Wrapper", + }, + ) + + assert response.status_code == 400 + assert "VALIDATION_ERROR" in response.json()["detail"]["error"]["code"] + + @pytest.mark.asyncio + async def test_render_prompt_error(self, client: AsyncClient): + """Test POST /prompts/render with missing prompts.""" + with patch("src.api.v1.prompts.get_prompt_service") as mock_get_service: + mock_service = MagicMock() + mock_service.render_prompt = AsyncMock( + side_effect=ValueError("Wrapper prompt for unknown version v1 not found") + ) + mock_get_service.return_value = mock_service + + response = await client.post( + "/api/v1/prompts/render", + json={ + 
"model_name": "unknown", + "enriched_event": {}, + "constraints": {}, + }, + ) + + assert response.status_code == 400 + assert "RENDER_ERROR" in response.json()["detail"]["error"]["code"] + + @pytest.mark.asyncio + async def test_list_prompts_filter_invalid_type(self, client: AsyncClient): + """Test GET /prompts with invalid prompt_type filter.""" + response = await client.get("/api/v1/prompts?prompt_type=invalid") + + assert response.status_code == 400 + assert "INVALID_TYPE" in response.json()["detail"]["error"]["code"] diff --git a/tests/test_core/__init__.py b/tests/test_core/__init__.py new file mode 100644 index 0000000..37c601d --- /dev/null +++ b/tests/test_core/__init__.py @@ -0,0 +1 @@ +"""Tests for core module.""" diff --git a/tests/test_core/test_auth.py b/tests/test_core/test_auth.py new file mode 100644 index 0000000..3288c3a --- /dev/null +++ b/tests/test_core/test_auth.py @@ -0,0 +1,356 @@ +"""Tests for authentication and authorization module. + +This module tests the 3-mode auth system: +- none: No auth (development mode) +- psk: Pre-shared key tokens +- jwt: JWT validation + +Scope hierarchy tests: +- lens:admin includes all scopes +- lens:submit only allows signal submission +- lens:read only allows data reading +""" + +import pytest +from unittest.mock import patch, MagicMock +from dataclasses import asdict + +from src.core.auth import ( + Scope, + AuthContext, + SCOPE_HIERARCHY, + _extract_bearer_token, + _validate_psk_token, + get_auth_context, + require_scope, +) + + +@pytest.mark.unit +class TestScope: + """Tests for Scope enum.""" + + def test_scope_values(self): + """Test scope enum values match expected strings.""" + assert Scope.SUBMIT.value == "lens:submit" + assert Scope.READ.value == "lens:read" + assert Scope.ADMIN.value == "lens:admin" + + def test_scope_hierarchy_admin_includes_all(self): + """Test that admin scope includes all other scopes.""" + admin_scopes = SCOPE_HIERARCHY[Scope.ADMIN] + assert Scope.SUBMIT in admin_scopes + 
assert Scope.READ in admin_scopes + assert Scope.ADMIN in admin_scopes + + def test_scope_hierarchy_submit_only(self): + """Test that submit scope only includes itself.""" + submit_scopes = SCOPE_HIERARCHY[Scope.SUBMIT] + assert Scope.SUBMIT in submit_scopes + assert Scope.READ not in submit_scopes + assert Scope.ADMIN not in submit_scopes + + def test_scope_hierarchy_read_only(self): + """Test that read scope only includes itself.""" + read_scopes = SCOPE_HIERARCHY[Scope.READ] + assert Scope.READ in read_scopes + assert Scope.SUBMIT not in read_scopes + assert Scope.ADMIN not in read_scopes + + +@pytest.mark.unit +class TestAuthContext: + """Tests for AuthContext dataclass.""" + + def test_create_authenticated_context(self): + """Test creating an authenticated context.""" + ctx = AuthContext( + authenticated=True, + scopes={Scope.SUBMIT}, + token_type="psk", + subject="test-user", + ) + + assert ctx.authenticated is True + assert Scope.SUBMIT in ctx.scopes + assert ctx.token_type == "psk" + assert ctx.subject == "test-user" + + def test_create_unauthenticated_context(self): + """Test creating an unauthenticated context.""" + ctx = AuthContext(authenticated=False, scopes=set()) + + assert ctx.authenticated is False + assert len(ctx.scopes) == 0 + + def test_has_scope_direct(self): + """Test has_scope for directly granted scope.""" + ctx = AuthContext(authenticated=True, scopes={Scope.SUBMIT}) + + assert ctx.has_scope(Scope.SUBMIT) is True + assert ctx.has_scope(Scope.READ) is False + assert ctx.has_scope(Scope.ADMIN) is False + + def test_has_scope_via_admin(self): + """Test that admin scope grants all permissions.""" + ctx = AuthContext(authenticated=True, scopes={Scope.ADMIN}) + + assert ctx.has_scope(Scope.SUBMIT) is True + assert ctx.has_scope(Scope.READ) is True + assert ctx.has_scope(Scope.ADMIN) is True + + +@pytest.mark.unit +class TestExtractBearerToken: + """Tests for bearer token extraction.""" + + def test_extract_valid_bearer_token(self): + """Test 
extracting a valid bearer token.""" + token = _extract_bearer_token("Bearer abc123") + assert token == "abc123" + + def test_extract_bearer_case_insensitive(self): + """Test that Bearer is case insensitive.""" + token = _extract_bearer_token("bearer abc123") + assert token == "abc123" + + token = _extract_bearer_token("BEARER abc123") + assert token == "abc123" + + def test_extract_none_when_missing(self): + """Test that None is returned when header is missing.""" + token = _extract_bearer_token(None) + assert token is None + + def test_extract_none_for_invalid_format(self): + """Test that None is returned for invalid format.""" + # No space + token = _extract_bearer_token("Bearerabc123") + assert token is None + + # Wrong scheme + token = _extract_bearer_token("Basic abc123") + assert token is None + + # No token + token = _extract_bearer_token("Bearer") + assert token is None + + +@pytest.mark.unit +class TestPSKValidation: + """Tests for PSK token validation.""" + + def test_validate_admin_token(self): + """Test validation of admin token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + + ctx = _validate_psk_token("admin-secret") + + assert ctx is not None + assert ctx.authenticated is True + assert Scope.ADMIN in ctx.scopes + assert ctx.token_type == "psk" + assert ctx.subject == "admin" + + def test_validate_submit_token(self): + """Test validation of submit token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + + ctx = _validate_psk_token("submit-secret") + + assert ctx is not None + assert ctx.authenticated is True + assert Scope.SUBMIT in ctx.scopes + assert ctx.token_type == "psk" + assert ctx.subject == "submit" + + def 
test_validate_read_token(self): + """Test validation of read token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + + ctx = _validate_psk_token("read-secret") + + assert ctx is not None + assert ctx.authenticated is True + assert Scope.READ in ctx.scopes + assert ctx.token_type == "psk" + assert ctx.subject == "read" + + def test_validate_invalid_token(self): + """Test that invalid token returns None.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_TOKEN_ADMIN = "admin-secret" + mock_settings.AUTH_TOKEN_SUBMIT = "submit-secret" + mock_settings.AUTH_TOKEN_READ = "read-secret" + + ctx = _validate_psk_token("wrong-token") + assert ctx is None + + def test_validate_when_tokens_not_configured(self): + """Test validation when tokens are not configured.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_TOKEN_ADMIN = None + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = None + + ctx = _validate_psk_token("any-token") + assert ctx is None + + +@pytest.mark.unit +class TestGetAuthContext: + """Tests for get_auth_context dependency.""" + + def test_mode_none_grants_all(self): + """Test that mode=none grants all scopes.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + ctx = get_auth_context(authorization=None) + + assert ctx.authenticated is True + assert Scope.ADMIN in ctx.scopes + + def test_mode_psk_with_valid_token(self): + """Test PSK mode with valid token.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_ADMIN = "admin-token" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = None + + ctx = get_auth_context(authorization="Bearer admin-token") + + assert ctx.authenticated is True 
+ assert Scope.ADMIN in ctx.scopes + + def test_mode_psk_without_token(self): + """Test PSK mode without token returns unauthenticated.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + + ctx = get_auth_context(authorization=None) + + assert ctx.authenticated is False + assert len(ctx.scopes) == 0 + + def test_mode_psk_with_invalid_token(self): + """Test PSK mode with invalid token returns unauthenticated.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + mock_settings.AUTH_TOKEN_ADMIN = "admin-token" + mock_settings.AUTH_TOKEN_SUBMIT = None + mock_settings.AUTH_TOKEN_READ = None + + ctx = get_auth_context(authorization="Bearer wrong-token") + + assert ctx.authenticated is False + assert len(ctx.scopes) == 0 + + +@pytest.mark.unit +class TestRequireScope: + """Tests for require_scope dependency factory.""" + + def test_mode_none_allows_all(self): + """Test that mode=none allows all scopes.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "none" + + dependency = require_scope(Scope.ADMIN) + # Create auth context for mode=none + auth_ctx = AuthContext(authenticated=True, scopes={Scope.ADMIN}) + + result = dependency(auth_ctx) + assert result.authenticated is True + + def test_requires_authentication(self): + """Test that unauthenticated requests are rejected.""" + from fastapi import HTTPException + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + + dependency = require_scope(Scope.SUBMIT) + auth_ctx = AuthContext(authenticated=False, scopes=set()) + + with pytest.raises(HTTPException) as exc_info: + dependency(auth_ctx) + + assert exc_info.value.status_code == 401 + + def test_requires_correct_scope(self): + """Test that incorrect scope is rejected.""" + from fastapi import HTTPException + + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + + dependency 
= require_scope(Scope.ADMIN) + # Only has submit scope, not admin + auth_ctx = AuthContext(authenticated=True, scopes={Scope.SUBMIT}) + + with pytest.raises(HTTPException) as exc_info: + dependency(auth_ctx) + + assert exc_info.value.status_code == 403 + + def test_admin_scope_grants_all(self): + """Test that admin scope satisfies any scope requirement.""" + with patch("src.core.auth.settings") as mock_settings: + mock_settings.AUTH_MODE = "psk" + + # Admin should satisfy submit requirement + dependency = require_scope(Scope.SUBMIT) + auth_ctx = AuthContext(authenticated=True, scopes={Scope.ADMIN}) + + result = dependency(auth_ctx) + assert result.authenticated is True + + # Admin should satisfy read requirement + dependency = require_scope(Scope.READ) + result = dependency(auth_ctx) + assert result.authenticated is True + + +@pytest.mark.unit +class TestScopeMapping: + """Tests for endpoint scope mapping (documented, not enforced here).""" + + def test_scope_assignments(self): + """Document expected scope assignments for endpoints.""" + # This test documents the expected scope assignments + # Actual enforcement is tested in integration tests + scope_map = { + "POST /signals": Scope.SUBMIT, + "GET /events": Scope.READ, + "GET /events/{id}": Scope.READ, + "GET /events/{id}/status": Scope.READ, + "GET /decisions": Scope.READ, + "GET /decisions/{id}": Scope.READ, + "GET /dlq": Scope.READ, + "GET /dlq/{id}": Scope.READ, + "POST /dlq/{id}/retry": Scope.ADMIN, + "POST /dlq/{id}/resolve": Scope.ADMIN, + "GET /llm-configs": Scope.ADMIN, + "GET /llm-configs/{model}": Scope.ADMIN, + "PUT /llm-configs/{model}": Scope.ADMIN, + "PATCH /llm-configs/{model}": Scope.ADMIN, + "DELETE /llm-configs/{model}": Scope.ADMIN, + "POST /llm-configs/{model}/test": Scope.ADMIN, + "WS /ws/stream": Scope.READ, + } + + # Verify all expected scopes are valid + for endpoint, scope in scope_map.items(): + assert isinstance(scope, Scope), f"Invalid scope for {endpoint}" diff --git 
a/tests/test_services/test_llm_config.py b/tests/test_services/test_llm_config.py new file mode 100644 index 0000000..fcfcfe8 --- /dev/null +++ b/tests/test_services/test_llm_config.py @@ -0,0 +1,395 @@ +"""Tests for LLM configuration service. + +This module tests the LLM configuration service which provides: +- Database-backed LLM configuration storage +- In-memory caching with configurable TTL +- CRUD operations for LLM configs +- API key masking for security +- Fallback to environment variables + +Key test scenarios: +- LLMConfigData dataclass creation +- Config service get/list operations +- Cache invalidation and TTL behavior +- Environment variable fallback +""" + +import pytest +from unittest.mock import patch, MagicMock, AsyncMock +from dataclasses import asdict + + +@pytest.mark.unit +class TestLLMConfigData: + """Tests for LLMConfigData dataclass.""" + + def test_create_from_values(self): + """Test creating LLMConfigData with direct values.""" + from src.services.llm_config.service import LLMConfigData + + config = LLMConfigData( + model_name="chatgpt", + enabled=True, + provider="openai", + api_key="sk-test-key", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path="prompts/test.md", + ) + + assert config.model_name == "chatgpt" + assert config.enabled is True + assert config.provider == "openai" + assert config.api_key == "sk-test-key" + assert config.model_id == "gpt-4o" + assert config.timeout_ms == 30000 + assert config.max_tokens == 1000 + assert config.prompt_path == "prompts/test.md" + + def test_create_from_orm(self): + """Test creating LLMConfigData from ORM model.""" + from src.services.llm_config.service import LLMConfigData + + # Mock ORM object + mock_orm = MagicMock() + mock_orm.model_name = "gemini" + mock_orm.enabled = True + mock_orm.provider = "google" + mock_orm.api_key = "goog-test-key" + mock_orm.model_id = "gemini-1.5-pro" + mock_orm.timeout_ms = 25000 + mock_orm.max_tokens = 2000 + mock_orm.prompt_path = None + 
+ config = LLMConfigData.from_orm(mock_orm) + + assert config.model_name == "gemini" + assert config.provider == "google" + assert config.api_key == "goog-test-key" + assert config.model_id == "gemini-1.5-pro" + assert config.prompt_path is None + + def test_predefined_model_providers(self): + """Test that model providers are predefined and not editable.""" + from src.services.llm_config.service import MODEL_PROVIDERS, DEFAULT_MODEL_IDS + + # Verify predefined mappings + assert MODEL_PROVIDERS["chatgpt"] == "openai" + assert MODEL_PROVIDERS["gemini"] == "google" + assert MODEL_PROVIDERS["claude"] == "anthropic" + assert MODEL_PROVIDERS["deepseek"] == "deepseek" + + # Verify default model IDs + assert DEFAULT_MODEL_IDS["chatgpt"] == "gpt-4o" + assert DEFAULT_MODEL_IDS["gemini"] == "gemini-1.5-pro" + assert DEFAULT_MODEL_IDS["claude"] == "claude-sonnet-4-20250514" + assert DEFAULT_MODEL_IDS["deepseek"] == "deepseek-chat" + + def test_supported_models_count(self): + """Test that all supported models have providers and defaults.""" + from src.services.llm_config.service import MODEL_PROVIDERS, DEFAULT_MODEL_IDS + + assert len(MODEL_PROVIDERS) == 4 + assert len(DEFAULT_MODEL_IDS) == 4 + assert set(MODEL_PROVIDERS.keys()) == set(DEFAULT_MODEL_IDS.keys()) + + +@pytest.mark.unit +class TestLLMConfigService: + """Tests for LLMConfigService.""" + + def test_cache_ttl_check(self): + """Test cache validity based on TTL.""" + import time + from src.services.llm_config.service import LLMConfigService + + service = LLMConfigService() + + # Cache is invalid when timestamp is 0 + assert service._is_cache_valid() is False + + # Set timestamp to now + service._cache_timestamp = time.time() + assert service._is_cache_valid() is True + + # Set timestamp to past TTL + service._cache_timestamp = time.time() - service.CACHE_TTL_SECONDS - 1 + assert service._is_cache_valid() is False + + def test_invalidate_cache(self): + """Test manual cache invalidation.""" + import time + from 
src.services.llm_config.service import LLMConfigService + + service = LLMConfigService() + service._cache_timestamp = time.time() + + # Cache should be valid + assert service._is_cache_valid() is True + + # Invalidate + service.invalidate_cache() + + # Cache should now be invalid + assert service._is_cache_valid() is False + assert service._cache_timestamp == 0 + + +@pytest.mark.unit +class TestLLMConfigServiceAsync: + """Async tests for LLMConfigService.""" + + @pytest.mark.asyncio + async def test_get_config_from_cache(self): + """Test getting config from in-memory cache.""" + import time + from src.services.llm_config.service import LLMConfigService, LLMConfigData + + service = LLMConfigService() + + # Populate cache directly + service._cache = { + "chatgpt": LLMConfigData( + model_name="chatgpt", + enabled=True, + provider="openai", + api_key="sk-cached-key", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ) + } + service._cache_timestamp = time.time() + + config = await service.get_config("chatgpt") + + assert config is not None + assert config.model_name == "chatgpt" + assert config.api_key == "sk-cached-key" + + @pytest.mark.asyncio + async def test_get_config_disabled_returns_none(self): + """Test that disabled models return None.""" + import time + from src.services.llm_config.service import LLMConfigService, LLMConfigData + + service = LLMConfigService() + + # Populate cache with disabled model + service._cache = { + "disabled_model": LLMConfigData( + model_name="disabled_model", + enabled=False, + provider="openai", + api_key="sk-key", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ) + } + service._cache_timestamp = time.time() + + config = await service.get_config("disabled_model") + + assert config is None + + @pytest.mark.asyncio + async def test_get_enabled_models(self): + """Test getting list of enabled models.""" + import time + from src.services.llm_config.service import 
LLMConfigService, LLMConfigData + + service = LLMConfigService() + + # Populate cache with mix of enabled/disabled + service._cache = { + "chatgpt": LLMConfigData( + model_name="chatgpt", + enabled=True, + provider="openai", + api_key="sk-key1", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ), + "gemini": LLMConfigData( + model_name="gemini", + enabled=True, + provider="google", + api_key="goog-key", + model_id="gemini-1.5-pro", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ), + "disabled": LLMConfigData( + model_name="disabled", + enabled=False, + provider="openai", + api_key="sk-disabled", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ), + } + service._cache_timestamp = time.time() + + # Patch at module level where settings is imported + with patch.object(service, '_is_cache_valid', return_value=True): + # Mock the settings import inside the method + import src.services.llm_config.service as llm_service_module + original_settings = llm_service_module.settings if hasattr(llm_service_module, 'settings') else None + + # The method imports settings inside, so we need to patch at import location + with patch("src.core.config.settings") as mock_settings: + mock_settings.ai_models_list = [] + enabled = await service.get_enabled_models() + + assert "chatgpt" in enabled + assert "gemini" in enabled + assert "disabled" not in enabled + + @pytest.mark.asyncio + async def test_list_all_returns_all_configs(self): + """Test listing all configs including disabled.""" + import time + from src.services.llm_config.service import LLMConfigService, LLMConfigData + + service = LLMConfigService() + + # Populate cache + service._cache = { + "model1": LLMConfigData( + model_name="model1", + enabled=True, + provider="openai", + api_key="key1", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ), + "model2": LLMConfigData( + model_name="model2", + enabled=False, + 
provider="google", + api_key="key2", + model_id="gemini", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + ), + } + service._cache_timestamp = time.time() + + all_configs = await service.list_all() + + assert len(all_configs) == 2 + model_names = [c.model_name for c in all_configs] + assert "model1" in model_names + assert "model2" in model_names + + +@pytest.mark.unit +class TestLLMConfigAPISchemas: + """Tests for LLM config API schemas.""" + + def test_llm_config_response_schema(self): + """Test LLMConfigResponse schema structure.""" + from src.api.v1.llm_configs import LLMConfigResponse + + response = LLMConfigResponse( + model_name="chatgpt", + enabled=True, + provider="openai", + api_key_masked="****-test", + model_id="gpt-4o", + timeout_ms=30000, + max_tokens=1000, + prompt_path=None, + validation_status="ok", + last_validated_at=None, + ) + + assert response.model_name == "chatgpt" + assert response.api_key_masked == "****-test" + assert response.validation_status == "ok" + + def test_llm_config_create_schema(self): + """Test LLMConfigCreate schema.""" + from src.api.v1.llm_configs import LLMConfigCreate + + # Provider is now automatically determined by model name + request = LLMConfigCreate( + api_key="sk-real-key", + model_id="gpt-4o", + enabled=True, + ) + + assert request.api_key == "sk-real-key" + assert request.model_id == "gpt-4o" + assert request.enabled is True + assert request.timeout_ms == 30000 # Default + assert request.max_tokens == 1000 # Default + + def test_llm_config_patch_allows_partial(self): + """Test that patch request allows partial updates.""" + from src.api.v1.llm_configs import LLMConfigPatch + + # Only updating enabled flag + request = LLMConfigPatch(enabled=False) + + assert request.enabled is False + assert request.api_key is None + assert request.model_id is None + + +@pytest.mark.unit +class TestAPIKeyMasking: + """Tests for API key masking in responses.""" + + def test_mask_api_key_function(self): + """Test API key 
masking shows only last 4 chars.""" + from src.api.v1.llm_configs import _mask_api_key + + # Normal key + masked = _mask_api_key("sk-abcdefghijklmnop") + assert masked == "****mnop" + assert "sk-" not in masked + + # Short key (4 chars or less returns just ****) + masked = _mask_api_key("abc") + assert masked == "****" + + # Empty key + masked = _mask_api_key("") + assert masked == "****" + + def test_mask_preserves_last_four_chars(self): + """Test that masking preserves exactly last 4 characters.""" + from src.api.v1.llm_configs import _mask_api_key + + key = "test-key-with-suffix1234" + masked = _mask_api_key(key) + + assert masked.endswith("1234") + assert masked == "****1234" + + +@pytest.mark.unit +class TestLLMConfigServiceSingleton: + """Tests for singleton pattern.""" + + def test_get_llm_config_service_returns_singleton(self): + """Test that get_llm_config_service returns same instance.""" + from src.services.llm_config import get_llm_config_service + + service1 = get_llm_config_service() + service2 = get_llm_config_service() + + assert service1 is service2